"""
Wolski
------
Compute coupled Wolski twiss matrices, fractional tunes and normalization matrix in standard gauge
Input matrix is assumed stable and can have arbitrary even dimension
Main function (twiss) can be mapped over a batch of input matrices (e.g. with torch.func.vmap) and is differentiable
"""
from math import pi
import torch
from torch import Tensor
from twiss.util import mod
from twiss.matrix import rotation
[docs]
def twiss(m:Tensor, *,
          epsilon:float=1.0E-12) -> tuple[Tensor, Tensor, Tensor]:
    """
    Compute coupled Wolski twiss parameters for a given one-turn input matrix

    Returns fractional tunes, normalization matrix (standard gauge) and Wolski twiss matrices
    Input matrix can have arbitrary even dimension
    Input matrix stability is not checked (use is_stable)

    Symplectic block is [[0, 1], [-1, 0]]
    Rotation block is [[cos(alpha), sin(alpha)], [-sin(alpha), cos(alpha)]]
    Complex block is 1/sqrt(2)*[[1, 1j], [1, -1j]]

    Parameters
    ----------
    m: Tensor, even-dimension, symplectic
        input one-turn matrix
    epsilon: float, default=1.0E-12
        tolerance epsilon for the per-plane symplecticity check of the normalization matrix
        (planes failing the check get tune |1 - t| and conjugated eigenvectors)

    Returns
    -------
    tuple[Tensor, Tensor, Tensor]
        fractional tunes [..., T_I, ...]
        normalization matrix (standard gauge) N
        twiss matrices W = [..., W_I, ...]

    Note
    ----
    M = N R N^-1 = ... + W_I S sin(2*pi*T_I) - (W_I S)**2 cos(2*pi*T_I) + ...

    Examples
    --------
    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = rotation(2*pi*torch.tensor(0.88, dtype=torch.float64))
    >>> t, n, w = twiss(m)
    >>> t
    tensor([0.8800], dtype=torch.float64)
    >>> n
    tensor([[1.0000, 0.0000],
            [0.0000, 1.0000]], dtype=torch.float64)
    >>> w
    tensor([[[1.0000, 0.0000],
             [0.0000, 1.0000]]], dtype=torch.float64)

    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = rotation(*(2*pi*torch.linspace(0.1, 0.9, 9, dtype=torch.float64)))
    >>> t, n, w = twiss(m)
    >>> t
    tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000],
           dtype=torch.float64)

    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = torch.func.vmap(rotation)(2*pi*torch.linspace(0.1, 0.9, 9, dtype=torch.float64))
    >>> t, n, w = torch.func.vmap(twiss)(m)
    >>> t
    tensor([[0.1000],
            [0.2000],
            [0.3000],
            [0.4000],
            [0.5000],
            [0.6000],
            [0.7000],
            [0.8000],
            [0.9000]], dtype=torch.float64)

    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> def fn(k):
    ...     m = rotation(2*pi*torch.tensor(0.88, dtype=torch.float64))
    ...     i = torch.ones_like(k)
    ...     o = torch.zeros_like(k)
    ...     m = m @ torch.stack([i, k, o, i]).reshape(m.shape)
    ...     t, *_ = twiss(m)
    ...     return t
    >>> k = torch.tensor(0.0, dtype=torch.float64)
    >>> fn(k)
    tensor([0.8800], dtype=torch.float64)
    >>> torch.func.jacfwd(fn)(k)
    tensor([0.0796], dtype=torch.float64)

    """
    # Constant helper matrices depend only on dtype, device and dimension of m,
    # so they are built without gradient tracking
    with torch.no_grad():
        dtype = m.dtype
        device = m.device
        # real dtype corresponding to m's dtype (abs() maps complex to real)
        rdtype = torch.tensor(1, dtype=dtype).abs().dtype
        # complex dtype corresponding to m's dtype
        cdtype = (1j*torch.tensor(1, dtype=dtype)).dtype
        # number of planes (2x2 blocks)
        d = len(m) // 2
        # 2x2 identity, symplectic and complex building blocks
        b_p = torch.tensor([[1, 0], [0, 1]], dtype=rdtype, device=device)
        b_s = torch.tensor([[0, 1], [-1, 0]], dtype=rdtype, device=device)
        b_c = 0.5**0.5*torch.tensor([[1, +1j], [1, -1j]], dtype=cdtype, device=device)
        # m_p[j]: (2d, 2d) block-diagonal projector keeping only the j-th 2x2 block
        m_p = torch.stack([torch.block_diag(*[b_p*(i == j) for i in range(d)]) for j in range(d)])
        # full block-diagonal symplectic unit S and complex transformation
        m_s = torch.block_diag(*[b_s for _ in range(d)])
        m_c = torch.block_diag(*[b_c for _ in range(d)])

    # Eigen-decomposition; eigenvalues are grouped into d consecutive pairs
    # NOTE(review): assumes torch.linalg.eig places each complex-conjugate pair
    # in consecutive positions — TODO confirm
    e, v = torch.linalg.eig(m)
    # rows of v.T are the eigenvectors (columns of v), grouped as (d, 2, 2d)
    e, v = e.reshape(d, -1), v.T.reshape(d, -1, 2*d)
    # scale each eigenvector pair by sqrt(|-1j * v1 S v2|) (symplectic normalization)
    u = torch.zeros_like(v)
    for i, (v1, v2) in enumerate(v):
        u[i] = v[i]/(-1j*(v1 @ m_s.to(cdtype) @ v2)).abs().sqrt()
    # within each pair, order eigenvalues and eigenvectors by ascending
    # argument of the eigenvalue (imaginary part of its complex log)
    k = torch.zeros_like(e)
    v = torch.zeros_like(u)
    for i in range(d):
        o = torch.clone(e[i].log()).imag.argsort()
        k[i], v[i] = e[i, o], u[i, o]
    # fractional tunes: 1 - mean(|log(lambda)|)/(2 pi) over each pair
    # (|log(lambda)| equals |arg(lambda)| for eigenvalues on the unit circle)
    t = 1.0 - k.log().abs().mean(-1)/(2.0*pi)
    # real normalization matrix built from the conjugate-transposed eigenvectors
    n = torch.cat([*v]).H
    n = (n @ m_c).real
    # preliminary twiss matrices, used only to assign pairs to planes
    w = n @ m_p @ n.T
    # order pairs by the position of the largest diagonal element of their twiss matrix
    o = torch.stack([w[i].diag().argmax() for i in range(d)]).argsort()
    t, v = t[o], v[o]
    n = torch.cat([*v]).H
    n = (n @ m_c).real
    # f[i]: absolute entries of columns 2i, 2i + 1 of (N^T S N - S) sum to at most epsilon,
    # i.e. the i-th plane of N is symplectic within tolerance
    f = (torch.stack(torch.hsplit(n.T @ m_s @ n - m_s, d)).abs().sum((1, -1)) <= epsilon)
    g = f.logical_not()
    # planes failing the check: report |1 - t| and use the conjugate eigenvectors
    # (arithmetic masks instead of branching, presumably to stay vmap-compatible)
    for i in range(d):
        t[i] = f[i]*t[i] + g[i]*(1.0 - t[i]).abs()
        v[i] = f[i]*v[i] + g[i]*v[i].conj()
    n = torch.cat([*v]).H
    n = (n @ m_c).real
    # standard gauge: rotate each plane so that n[2i, 2i + 1] vanishes
    # (given rotation blocks [[cos, sin], [-sin, cos]] as stated above)
    s = torch.arange(d, dtype=torch.int64, device=device)
    a = (n[2*s, 2*s + 1] + 1j*n[2*s, 2*s]).angle() - 0.5*pi
    n = n @ rotation(*a)
    # final twiss matrices W_I = N P_I N^T
    w = n @ m_p @ n.T
    return t, n, w
[docs]
def is_stable(m:Tensor, *,
              epsilon:float=1.0E-12) -> bool:
    """
    Check one-turn matrix stability

    Parameters
    ----------
    m: Tensor, even-dimension
        input one-turn matrix (leading batch dimensions are accepted)
    epsilon: float, default=1.0E-12
        tolerance epsilon on the distance of eigenvalue moduli from one

    Returns
    -------
    bool
        True if every eigenvalue modulus differs from one by at most epsilon
        (for a batch of matrices, True only if all of them pass)

    Note
    ----
    Input matrix is stable if eigenvalues are on the unit circle
    The check is a single tensor reduction rather than a Python loop over eigenvalues

    Examples
    --------
    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = rotation(2*pi*torch.tensor(0.1234, dtype=torch.float64))
    >>> is_stable(m)
    True
    >>> is_stable(m + 1.0E-3)
    False

    """
    # distance of each eigenvalue from the unit circle
    distance = (torch.linalg.eigvals(m).abs() - 1.0).abs()
    # reduce on the tensor side (works for batched eigvals) and return a Python bool
    return bool((distance <= epsilon).all())
[docs]
def propagate(w:Tensor,
              m:Tensor) -> Tensor:
    """
    Propagate Wolski twiss matrices through a given transport matrix

    Parameters
    ----------
    w: Tensor, even-dimension
        Wolski twiss matrices (a stack of matrices is transformed element-wise by broadcasting)
    m: Tensor, even-dimension, symplectic
        transport matrix

    Returns
    -------
    Tensor
        propagated Wolski twiss matrices

    Note
    ----
    W_I = M W_I M.T

    Examples
    --------
    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = rotation(2*pi*torch.tensor(0.88, dtype=torch.float64))
    >>> *_, w = twiss(m)
    >>> propagate(w, torch.tensor([[1.0, 0.1], [0.0, 1.0]], dtype=torch.float64))
    tensor([[[1.0100, 0.1000],
             [0.1000, 1.0000]]], dtype=torch.float64)

    """
    # left product first, matching the evaluation order of M W M^T
    left = torch.matmul(m, w)
    return torch.matmul(left, m.T)
[docs]
def advance(n:Tensor,
            m:Tensor) -> tuple[Tensor, Tensor]:
    """
    Compute phase advance and final normalization matrix given input normalization and transport matrices

    Parameters
    ----------
    n: Tensor, even-dimension, symplectic
        normalization matrix
    m: Tensor, even-dimension, symplectic
        transport matrix

    Returns
    -------
    tuple[Tensor, Tensor]
        phase advance (one value per plane)
        final normalization matrix

    Note
    ----
    Output phase advance is mod 2 pi
    The per-plane angle is atan2 of elements (2i, 2i + 1) and (2i, 2i) of M N;
    the final normalization matrix is M N rotated back by that angle

    Examples
    --------
    >>> from math import pi
    >>> import torch
    >>> from twiss.matrix import rotation
    >>> m = rotation(2*pi*torch.tensor(0.88, dtype=torch.float64))
    >>> _, n1, _ = twiss(m)
    >>> t = torch.tensor([[1.0, 0.1], [0.0, 1.0]], dtype=torch.float64)
    >>> mu, n2 = advance(n1, t)
    >>> torch.allclose(t, n2 @ rotation(*mu) @ n1.inverse())
    True

    """
    # first row and second column index of each 2x2 plane block
    planes = torch.arange(len(n) // 2, dtype=torch.int64, device=n.device)
    row, col = 2*planes, 2*planes + 1
    # normalization matrix transported to the end of the element
    transported = m @ n
    mu = mod(torch.arctan2(transported[row, col], transported[row, row]), 2.0*pi)
    return mu, transported @ rotation(*(-mu))