Example-60: Chromatic matrix (element)

[1]:
# Can be used to model 4x4 linear elements with leading order chromatic effects
# Transport matrix is constructed from the upper triangular elements of the symmetric matrices A and B
# M = exp(S A) exp(dp S B)
[2]:
from scipy.linalg import logm

import torch
from model.library.drift import Drift
from model.library.matrix import Matrix
[3]:
# Drift element
# Serves as the reference element in this example: its parametric matrix is
# computed first and is then reproduced with a generic Matrix element

D = Drift('D', length=1.5)
[4]:
# Parametric drift matrix

def matrix(state, dp):
    """Return the Jacobian of the drift map with respect to the state.

    The drift D is evaluated with its own data, except that the 'dp' entry
    is replaced by the given dp; the Jacobian is computed with reverse-mode
    differentiation at the given state.
    """
    def transport(point):
        # Element data is rebuilt on each call, with 'dp' overridden
        return D(point, data={**D.data(), 'dp': dp})

    return torch.func.jacrev(transport)(state)

# Evaluate the matrix around the zero state
state = torch.zeros(4, dtype=torch.float64)

# Matrix for dp = 0.0, followed by an empty line
print(matrix(state, 0.0), end='\n\n')

# Matrix for dp = 0.001, followed by an empty line
print(matrix(state, 0.001), end='\n\n')
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.5000],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

tensor([[1.0000, 1.4985, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.4985],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

[5]:
# Construct A and B matrices
# Model: M(dp) = exp(S A) exp(dp S B)
# At dp = 0, M = exp(S A) = m, hence A = inv(S) log(m)
# To first order in dp, dM/dp = exp(S A) S B = m S B, hence B = inv(S) inv(m) dM/dp
# Since S @ S = - I, inv(S) = - S and B = - S inv(m) dM/dp

S = torch.tensor([[0.,  1.,  0.,  0.],
                  [-1., 0.,  0.,  0.],
                  [0.,  0.,  0.,  1.],
                  [0.,  0., -1.,  0.]], dtype=torch.float64)

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor(0.0, dtype=torch.float64)

# Matrix at dp = 0 and its derivative with respect to dp (second argument)
m = matrix(state, dp)
dmdp = torch.func.jacrev(matrix, 1)(state, dp)

A = (torch.linalg.inv(S) @ torch.tensor(logm(m), dtype=torch.float64))

# The derivative is multiplied by inv(m), not by m
# (both give the same result for a drift, but not for a general element)
B = - S @ torch.linalg.inv(m) @ dmdp

# Compare the exact matrix with the model matrix for dp = 0.001
print(matrix(state, 0.001))
print()

print(torch.matrix_exp(S @ A) @ torch.matrix_exp(0.001 * S @ B))
print()
tensor([[1.0000, 1.4985, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.4985],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

tensor([[1.0000, 1.4985, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.4985],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

[6]:
# Define matrix element
# Only the upper triangular parts (diagonal included) of A and B are passed,
# flattened into plain lists in row-major order

M = Matrix(
    'M',
    length=1.5,
    A=A[torch.ones_like(A, dtype=torch.bool).triu()].tolist(),
    B=B[torch.ones_like(B, dtype=torch.bool).triu()].tolist(),
)
[7]:
# Compare with drift

def matrix(state, dp):
    """Return the Jacobian of the matrix element map with respect to the state.

    The element M is evaluated with its own data, except that the 'dp' entry
    is replaced by the given dp; the Jacobian is computed with reverse-mode
    differentiation at the given state.
    """
    def transport(point):
        # Element data is rebuilt on each call, with 'dp' overridden
        return M(point, data={**M.data(), 'dp': dp})

    return torch.func.jacrev(transport)(state)

# Evaluate the matrix around the zero state
state = torch.zeros(4, dtype=torch.float64)

# Matrix for dp = 0.0, followed by an empty line
print(matrix(state, 0.0), end='\n\n')

# Matrix for dp = 0.001, followed by an empty line
print(matrix(state, 0.001), end='\n\n')
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.5000],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

tensor([[1.0000, 1.4985, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 1.4985],
        [0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)

[ ]: