Example-60: Chromatic matrix (element)
[1]:
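# The Matrix element can be used to model 4x4 linear elements with leading-order chromatic effects.
# Its transport matrix is constructed from the (upper-triangular) elements of two symmetric matrices A and B:
# M = exp(S A) exp(dp S B), where S is the 4x4 symplectic unit matrix and dp is the relative momentum deviation.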
[2]:
from scipy.linalg import logm
import torch
from model.library.drift import Drift
from model.library.matrix import Matrix
[3]:
# Drift element named 'D' with length 1.5; its transport matrix (and the
# derivative of that matrix with respect to 'dp') is the reference from which
# the generators A and B of the Matrix element are derived further down.
D = Drift('D', length=1.5)
[4]:
# Parametric drift matrix: Jacobian of the drift map D with respect to the
# phase-space state, evaluated with 'dp' overriding the entry in D's data.
def matrix(state, dp):
    def transport(coordinates):
        return D(coordinates, data={**D.data(), 'dp': dp})
    return torch.func.jacrev(transport)(state)


state = torch.zeros(4, dtype=torch.float64)
print(matrix(state, 0.0))
print()
print(matrix(state, 0.001))
print()
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.5000],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
tensor([[1.0000, 1.4985, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.4985],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[5]:
# Construct A and B matrices from the transport matrix and its dp-derivative.
# With M(dp) = exp(S A) exp(dp S B):
#   M(0)           = exp(S A)   =>  A = S^{-1} logm(M(0))
#   dM/ddp at dp=0 = M(0) S B   =>  B = S^{-1} M(0)^{-1} dM/ddp
# S is block-diagonal [[0, 1], [-1, 0]], so S @ S = -I and S^{-1} = -S.
S = torch.tensor([[0., 1., 0., 0.], [-1., 0., 0., 0.], [0., 0., 0., 1.], [0., 0., -1., 0.]], dtype=torch.float64)
state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor(0.0, dtype=torch.float64)
m = matrix(state, dp)
dmdp = torch.func.jacrev(matrix, 1)(state, dp)
A = (torch.linalg.inv(S) @ torch.tensor(logm(m), dtype=torch.float64))
# m^{-1} @ dmdp (via a linear solve) is required here; using m @ dmdp only
# coincides with it for special cases such as a drift, where (m - m^{-1}) @ dmdp == 0.
B = - S @ torch.linalg.solve(m, dmdp)
# Check: the first-order reconstruction should reproduce the exact matrix at small dp.
print(matrix(state, 0.001))
print()
print(torch.matrix_exp(S @ A) @ torch.matrix_exp(0.001 * S @ B))
print()
tensor([[1.0000, 1.4985, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.4985],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
tensor([[1.0000, 1.4985, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.4985],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[6]:
# Define matrix element.
# Only the upper-triangular entries (row-major, diagonal included) of A and B are
# passed to Matrix, so any antisymmetric part would be silently dropped; the
# generators of M = exp(S A) exp(dp S B) are expected to be symmetric, so check first.
def _upper_triangle(generator, name):
    """Return the upper-triangular entries of a square symmetric tensor as a list.

    Raises ValueError if the tensor is not symmetric (within torch.allclose
    tolerance), since packing only the upper triangle would then lose entries.
    """
    if not torch.allclose(generator, generator.T):
        raise ValueError(f'{name} is not symmetric; upper-triangular packing would lose entries')
    mask = torch.triu(torch.ones_like(generator, dtype=torch.bool))
    return generator[mask].tolist()


M = Matrix('M',
           length=1.5,
           A=_upper_triangle(A, 'A'),
           B=_upper_triangle(B, 'B'))
[7]:
# Compare with drift: Jacobian of the Matrix element M with respect to the
# phase-space state, evaluated with 'dp' overriding the entry in M's data.
def matrix(state, dp):
    def transport(coordinates):
        return M(coordinates, data={**M.data(), 'dp': dp})
    return torch.func.jacrev(transport)(state)


state = torch.zeros(4, dtype=torch.float64)
print(matrix(state, 0.0))
print()
print(matrix(state, 0.001))
print()
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.5000],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
tensor([[1.0000, 1.4985, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.4985],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
[ ]: