ELETTRA-37: ID linear optics distortion (one-turn matrix factorization)

[1]:
# This example illustrates one-turn matrix factorization
# The perturbed one-turn matrix can be expressed as M exp(S A), where M is the unperturbed one-turn matrix and A is a symmetric matrix (Aij = Aji; block-diagonal if there is no coupling) that describes the cumulative effect of the perturbations
# Here the matrix A is constructed perturbatively using the Magnus expansion
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True

from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.marker import Marker
from model.library.matrix import Matrix

from model.command.external import load_lattice
from model.command.build import build
[3]:
# Set data type and device
# Both are stored as class attributes on Element (presumably read by every element)
# and also kept as module-level names for the tensors created below

dtype = torch.float64
device = torch.device('cpu')
Element.dtype = dtype
Element.device = device
[4]:
# Load lattice (ELEGANT table)
# Note, lattice is allowed to have repeated elements
# The path is relative to the current working directory; `data` is consumed by `build` below

path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice
# Build the line named 'RING' from the loaded ELEGANT table

ring:Line = build('RING', 'ELEGANT', data)

# Flatten sublines

ring.flatten()

# Remove all marker elements but the ones starting with MLL (long straight section centers)
# The negative lookahead matches every name that does NOT start with 'MLL_'

ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles

def factory(element:Element) -> Quadrupole:
    """Build a quadrupole with the same serialized parameters as ``element``.

    The 'ms' entry (presumably the sextupole strength, which a quadrupole does
    not accept) is dropped if present; every other serialized field is passed
    to ``Quadrupole`` as a keyword argument.
    """
    # Copy before popping so the source element's serialization is never mutated,
    # whether or not `serialize` returns a fresh dictionary
    table = dict(element.serialize)
    table.pop('ms', None)
    return Quadrupole(**table)

# An empty pattern matches every name, so all elements of kind 'Sextupole' are replaced

ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles

def apply(element:Element) -> None:
    """Mark ``element`` as linear by setting its ``linear`` flag to True (in place)."""
    element.linear = True

# Applied to every element of kind 'Dipole'

ring.apply(apply, kinds=['Dipole'])

# Merge drifts
# (presumably combines consecutive drift elements into single drifts)

ring.merge()

# Change lattice start
# This BPM is also the insertion point of the first (identity) ID defined below

ring.start = "BPM_S01_01"

# Describe
# Last expression of the cell: displays the element count per kind

ring.describe
[5]:
{'BPM': 168, 'Drift': 708, 'Dipole': 156, 'Quadrupole': 360, 'Marker': 12}
[6]:
# Unperturbed one-turn matrix
# Jacobian of the ring map evaluated at the origin of phase space

state = torch.zeros(4, dtype=dtype)
M0 = torch.func.jacrev(ring)(state)
print(M0)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
[7]:
# Define IDs
# Every ID is a zero-length Matrix element parametrized by the upper triangle
# (selected by `mask`) of a scaled copy of the diagonal exponent A

# The first ID is 'fake' and acts as an identity (zero scale)
# It is also located at the lattice start

ca, cb, cc, cd = -0.034441907232402175, -0.04458009513208418, 0.056279356423643276, 0.08037110220505986
A = torch.diag(torch.tensor([ca, cb, cc, cd], dtype=dtype))
mask = torch.triu(torch.ones_like(A, dtype=torch.bool))

ID1, ID2, ID3, ID4, ID5, ID6 = [
    Matrix(name, length=0.0, A=(scale*A[mask]).tolist())
    for name, scale in (('ID1', 0.0), ('ID2', 1.0), ('ID3', 1.1), ('ID4', 0.8), ('ID5', 0.5), ('ID6', 1.2))
]
[8]:
# Insert IDs

# Each ID (or other perturbation) is a thin insertion matrix
# The matrices are inserted after the given locations (one location per ID)

elements = [ID1, ID2, ID3, ID4, ID5, ID6]
markers = ['BPM_S01_01', 'MLL_S01', 'MLL_S02', 'MLL_S03', 'MLL_S05', 'MLL_S08']

# zip() would silently drop unmatched IDs or locations, and the factorization below
# also pairs IDs with locations by position, so require a one-to-one match
assert len(elements) == len(markers), 'each ID needs exactly one insertion location'

error = ring.clone()
for element, marker in zip(elements, markers):
    error.insert(element, marker)

# Describe

error.describe
[8]:
{'BPM': 168,
 'Matrix': 6,
 'Drift': 708,
 'Dipole': 156,
 'Quadrupole': 360,
 'Marker': 12}
[9]:
# Perturbed one-turn matrix (lattice with IDs inserted)
# Shown below the unperturbed one for comparison

state = torch.zeros(4, dtype=dtype)
M1 = torch.func.jacrev(error)(state)

print(M0, M1, sep='\n')
print()
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)

\(\hat M = \prod_i T_{i, i + 1} K_i\)

[10]:
# Skew identity matrix (symplectic unit: S^T = -S, S @ S = -I)

S = torch.tensor([[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]], dtype=dtype)

# Compute perturbations
# Each ID stores only the upper triangle (diagonal included) of its exponent A_i,
# selected with `mask`. The symmetric matrix A_i (Aij = Aji) is rebuilt by mirroring
# the upper triangle into the lower one. Scattering alone leaves the lower triangle
# zero, which is only correct for diagonal A_i.
# The corresponding thin-insertion matrix is K_i = exp(S A_i)

Ais = []
Kis = []
for element in elements:
    Ai = torch.zeros((4, 4), dtype=dtype)
    Ai[mask] = element.A
    Ai = torch.where(mask, Ai, Ai.T)
    Ki = torch.linalg.matrix_exp(S @ Ai)
    Ais.append(Ai)
    Kis.append(Ki)

Ais = torch.stack(Ais)
Kis = torch.stack(Kis)

# Check that each K_i matches the Jacobian of the corresponding element
# (this also confirms that the element interprets A as a symmetric matrix)

print([torch.allclose(Ki, torch.func.jacrev(element)(state)) for Ki, element in zip(Kis, elements)])
print()

# Compute matrices between IDs (using markers in the unperturbed lattice)
# Segment i runs from the location of ID i to the location of ID i + 1. The last
# segment's slice ends at markers[0], which is the lattice start. The first and
# last elements of each slice are dropped.

Tis = []
for i in range(len(markers)):
    _, *sequence, _ = ring[markers[i] : markers[(i + 1) % len(markers)]]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)

# Construct unperturbed one-turn matrix as the ordered product of segments
# (later segments multiply from the left)

hM = torch.eye(4, dtype=dtype)
for Ti in Tis:
    hM = Ti @ hM

print(M0)
print(hM)
print(torch.allclose(M0, hM))
print()

# Construct perturbed one-turn matrix: each ID matrix K_i acts before segment T_{i, i + 1}

hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hM = Ti @ Ki @ hM

print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
[True, True, True, True, True, True]

tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
True

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
True

\(\hat M = \prod_i T_{i, i + 1} K_i = M \prod_i \hat K_i\)

\(\hat K_i = T_i^{-1} K_i T_i\)

[11]:
# Compute unperturbed transport matrices from lattice start to each ID

Tis = []
for marker in markers:
    # The last element of the slice (the end location) is excluded
    *sequence, _ = ring[ring.start : marker]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)

# Accumulate the similarity-transformed ID matrices hK_i = T_i^{-1} K_i T_i
# (later IDs multiply from the left), then apply the unperturbed one-turn matrix

hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hKi = torch.linalg.inv(Ti) @ Ki @ Ti
    hM = hKi @ hM
hM = M0 @ hM

print(M1, hM, sep='\n')
print(torch.allclose(M1, hM))
print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
True

\(\hat M = \prod_i T_{i, i + 1} K_i = M \prod_i \hat K_i\)

\(\hat K_i = T_i^{-1} K_i T_i = \exp ( T_i^{-1} S A_i T_i ) = \exp ( S \hat A_i )\)

\(\hat A_i = T_i^{T} A_i T_i\)

[12]:
# Construct shifted exponents hA_i = T_i^T A_i T_i
# This relies on T_i being symplectic (T_i S T_i^T = S, so T_i^{-1} S = S T_i^T);
# the allclose check below confirms the result

hAis = torch.stack([Ti.T @ Ai @ Ti for Ti, Ai in zip(Tis, Ais)])

# Rebuild the one-turn matrix from the exp(S hA_i) factors (later IDs on the left)

hM = torch.eye(4, dtype=dtype)
for hAi in hAis:
    hM = torch.linalg.matrix_exp(S @ hAi) @ hM
hM = M0 @ hM

print(M1, hM, sep='\n')
print(torch.allclose(M1, hM))
print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
True

\(\hat M = M \prod_i \exp(S \hat A_i) = M \exp(S (\varepsilon \hat A^{(1)} + \varepsilon^2 \hat A^{(2)} + \dots))\)

\(\hat A^{(1)} = \sum_i \hat A_i\)

[13]:
# Construct 1st order product approximation: hA1 = sum_i hA_i

hA1 = sum(hAis, torch.zeros((4, 4), dtype=dtype))

# Construct one-turn matrix with perturbation

hM = M0 @ torch.linalg.matrix_exp(S @ hA1)

# Compare the exact perturbed matrix with the unperturbed one, then with the
# 1st-order approximation

for approximation in (M0, hM):
    print(M1, approximation, sep='\n')
    print((M1 - approximation).norm())
    print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor(1.3791, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.2006,  4.7250,  0.0000,  0.0000],
        [-0.1836,  0.6597,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1595,  2.2923],
        [ 0.0000,  0.0000, -0.4798,  0.6261]], dtype=torch.float64)
tensor(0.3967, dtype=torch.float64)

\(\hat M = M \prod_i \exp(S \hat A_i) = M \exp(S (\varepsilon \hat A^{(1)} + \varepsilon^2 \hat A^{(2)} + \dots))\)

\(S \hat A^{(1)} = S \sum_i \hat A_i\)

\(S \hat A^{(2)} = \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} \{ S \hat A_i, S \hat A_j \} = S \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} (\hat A_i S \hat A_j - \hat A_j S \hat A_i)\)

[14]:
# Define bracket (commutator)

def bracket(X, Y):
    """Return the matrix commutator [X, Y] = X Y - Y X."""
    forward = X @ Y
    backward = Y @ X
    return forward - backward
[15]:
# Construct 2nd order product approximation
# With X_i = S hA_i: S hA2 = sum_{i >= j} 1/(1 + delta_ij) 1/2 [X_i, X_j]

Xs = [S @ hAi for hAi in hAis]

hA2 = torch.zeros((4, 4), dtype=dtype)
for i, Xi in enumerate(Xs):
    for j in range(i + 1):
        weight = 1/(1 + (i == j))/2
        hA2 += weight*bracket(Xi, Xs[j])

# Convert S hA2 back to hA2 (S^{-1} = -S)

hA2 = - S @ hA2

# Compare the exact perturbed one-turn matrix with the unperturbed one and with
# the 1st- and 2nd-order approximations

print(M1)
print(M0)
print((M1 - M0).norm())
print()

for exponent in (hA1, hA1 + hA2):
    hM = M0 @ torch.linalg.matrix_exp(S @ exponent)
    print(M1, hM, sep='\n')
    print((M1 - hM).norm())
    print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor(1.3791, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.2006,  4.7250,  0.0000,  0.0000],
        [-0.1836,  0.6597,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1595,  2.2923],
        [ 0.0000,  0.0000, -0.4798,  0.6261]], dtype=torch.float64)
tensor(0.3967, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.3914,  4.5444,  0.0000,  0.0000],
        [-0.1772,  0.4978,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1571,  2.3014],
        [ 0.0000,  0.0000, -0.4770,  0.6229]], dtype=torch.float64)
tensor(0.1004, dtype=torch.float64)

\(\hat M = M \prod_i \exp(S \hat A_i) = M \exp(S (\varepsilon \hat A^{(1)} + \varepsilon^2 \hat A^{(2)} + \dots))\)

\(S \hat A^{(1)} = S \sum_i \hat A_i\)

\(S \hat A^{(2)} = \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} \{ S \hat A_i, S \hat A_j \} = S \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} (\hat A_i S \hat A_j - \hat A_j S \hat A_i)\)

\(S\,\hat A^{(3)}=\sum_{i \ge j \ge k} \frac{1}{1+\delta_{ij}} \frac{1}{1+\delta_{ik}+\delta_{jk}} (\frac{1}{4} \{\{S\hat A_i,S\hat A_j\},S\hat A_k\}-\frac{1}{12}\{\{S\hat A_i,S\hat A_k\},S\hat A_j\}-\frac{1}{12}\{\{S\hat A_j,S\hat A_k\},S\hat A_i\})\)

[16]:
# Construct 3rd order product approximation
# With X_i = S hA_i, the sum runs over i >= j >= k with the repeated-index weights
# of the formula above (boolean comparisons count as 0/1)

Xs = [S @ hAi for hAi in hAis]

hA3 = torch.zeros((4, 4), dtype=dtype)
for i, Xi in enumerate(Xs):
    for j, Xj in enumerate(Xs[:i + 1]):
        for k, Xk in enumerate(Xs[:j + 1]):
            factor = (1.0/(1.0 + (i == j)))*(1.0/(1.0 + (i == k) + (j == k)))
            hA3 += factor*(
                0.25*bracket(bracket(Xi, Xj), Xk)
                - (1/12)*bracket(bracket(Xi, Xk), Xj)
                - (1/12)*bracket(bracket(Xj, Xk), Xi)
            )

# Convert S hA3 back to hA3 (S^{-1} = -S)

hA3 = - S @ hA3

# Compare the exact perturbed one-turn matrix with the unperturbed one and with
# the 1st- to 3rd-order approximations

print(M1)
print(M0)
print((M1 - M0).norm())
print()

for exponent in (hA1, hA1 + hA2, hA1 + hA2 + hA3):
    hM = M0 @ torch.linalg.matrix_exp(S @ exponent)
    print(M1, hM, sep='\n')
    print((M1 - hM).norm())
    print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor(1.3791, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.2006,  4.7250,  0.0000,  0.0000],
        [-0.1836,  0.6597,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1595,  2.2923],
        [ 0.0000,  0.0000, -0.4798,  0.6261]], dtype=torch.float64)
tensor(0.3967, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.3914,  4.5444,  0.0000,  0.0000],
        [-0.1772,  0.4978,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1571,  2.3014],
        [ 0.0000,  0.0000, -0.4770,  0.6229]], dtype=torch.float64)
tensor(0.1004, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4452,  4.5011,  0.0000,  0.0000],
        [-0.1722,  0.5048,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1575,  2.3013],
        [ 0.0000,  0.0000, -0.4772,  0.6232]], dtype=torch.float64)
tensor(0.0320, dtype=torch.float64)

\(\hat M = M \prod_i \exp(S \hat A_i) = M \exp(S (\varepsilon \hat A^{(1)} + \varepsilon^2 \hat A^{(2)} + \dots))\)

\(S \hat A^{(1)} = S \sum_i \hat A_i\)

\(S \hat A^{(2)} = \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} \{ S \hat A_i, S \hat A_j \} = S \sum_{i \ge j} \frac{1}{1 + \delta_{i, j}} \frac{1}{2} (\hat A_i S \hat A_j - \hat A_j S \hat A_i)\)

\(S \hat A^{(3)}=\sum_{i \ge j \ge k} \frac{1}{1+\delta_{ij}} \frac{1}{1+\delta_{ik}+\delta_{jk}} (\frac{1}{4} \{\{S\hat A_i,S\hat A_j\},S\hat A_k\}-\frac{1}{12}\{\{S\hat A_i,S\hat A_k\},S\hat A_j\}-\frac{1}{12}\{\{S\hat A_j,S\hat A_k\},S\hat A_i\})\)

\(S \hat A^{(4)} = \sum_{i \ge j \ge k \ge l} \frac{1}{1+\delta_{ij}} \frac{1}{1+\delta_{ik}+\delta_{jk}} \frac{1}{1+\delta_{il}+\delta_{jl}+\delta_{kl}} \frac{1}{12} (\{\{\{S\hat A_i,S\hat A_j\},S\hat A_k\},S\hat A_l\} + \{S\hat A_i,\{\{S\hat A_j,S\hat A_k\},S\hat A_l\}\} + \{S\hat A_i,\{S\hat A_j,\{S\hat A_k,S\hat A_l\}\}\} + \{S\hat A_j,\{S\hat A_k,\{S\hat A_l,S\hat A_i\}\}\})\)

[17]:
# Construct 4th order product approximation
# With X_i = S hA_i, the sum runs over i >= j >= k >= l with the repeated-index
# weights of the formula above (boolean comparisons count as 0/1)

Xs = [S @ hAi for hAi in hAis]

hA4 = torch.zeros((4, 4), dtype=dtype)
for i, Xi in enumerate(Xs):
    for j, Xj in enumerate(Xs[:i + 1]):
        for k, Xk in enumerate(Xs[:j + 1]):
            for l, Xl in enumerate(Xs[:k + 1]):
                factor = (1.0 / (1.0 + (i == j))) * (1.0 / (1.0 + (i == k) + (j == k))) * (1.0 / (1.0 + (i == l) + (j == l) + (k == l)))
                hA4 += factor*(
                    1/12*bracket(bracket(bracket(Xi, Xj), Xk), Xl) +
                    1/12*bracket(Xi, bracket(bracket(Xj, Xk), Xl)) +
                    1/12*bracket(Xi, bracket(Xj, bracket(Xk, Xl))) +
                    1/12*bracket(Xj, bracket(Xk, bracket(Xl, Xi)))
                )

# Convert S hA4 back to hA4 (S^{-1} = -S)

hA4 = - S @ hA4

# Compare the exact perturbed one-turn matrix with the unperturbed one and with
# the 1st- to 4th-order approximations

print(M1)
print(M0)
print((M1 - M0).norm())
print()

for exponent in (hA1, hA1 + hA2, hA1 + hA2 + hA3, hA1 + hA2 + hA3 + hA4):
    hM = M0 @ torch.linalg.matrix_exp(S @ exponent)
    print(M1, hM, sep='\n')
    print((M1 - hM).norm())
    print()
tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[-0.5038,  3.9513,  0.0000,  0.0000],
        [-0.2394, -0.1073,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1784,  1.9752],
        [ 0.0000,  0.0000, -0.4264,  0.8845]], dtype=torch.float64)
tensor(1.3791, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.2006,  4.7250,  0.0000,  0.0000],
        [-0.1836,  0.6597,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1595,  2.2923],
        [ 0.0000,  0.0000, -0.4798,  0.6261]], dtype=torch.float64)
tensor(0.3967, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.3914,  4.5444,  0.0000,  0.0000],
        [-0.1772,  0.4978,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1571,  2.3014],
        [ 0.0000,  0.0000, -0.4770,  0.6229]], dtype=torch.float64)
tensor(0.1004, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4452,  4.5011,  0.0000,  0.0000],
        [-0.1722,  0.5048,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1575,  2.3013],
        [ 0.0000,  0.0000, -0.4772,  0.6232]], dtype=torch.float64)
tensor(0.0320, dtype=torch.float64)

tensor([[ 0.4728,  4.4862,  0.0000,  0.0000],
        [-0.1703,  0.4987,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor([[ 0.4635,  4.4927,  0.0000,  0.0000],
        [-0.1711,  0.4991,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1576,  2.3014],
        [ 0.0000,  0.0000, -0.4772,  0.6233]], dtype=torch.float64)
tensor(0.0114, dtype=torch.float64)