Example-28: Transformation

[1]:
# In this example, other wrappers are used to construct parametric transformations between elements
# Given two elements, a transformation can be constructed from the entrance frame of the first element to the exit frame of the second element
# If the first element appears after the second one in the line, the inverse transformation is constructed
# Note, these transformations are defined around the initial reference orbit
[2]:
# Import

import torch

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from twiss import twiss

from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.pfp import parametric_fixed_point

from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.sextupole import Sextupole
from model.library.dipole import Dipole
from model.library.line import Line

from model.command.wrapper import group
[3]:
# Define simple FODO based lattice using nested lines
# Note, all elements have unique names and every slot in a line holds a distinct element object
# Each FODO cell has eight drift slots, hence eight drifts DR1-DR8 per cell
# (DR8 occupies the slot between SD1 and QD; all drifts have the same length)

QF_A = Quadrupole('QF_A', 1.0, +0.20)
QD_A = Quadrupole('QD_A', 1.0, -0.19)
QF_B = Quadrupole('QF_B', 1.0, +0.20)
QD_B = Quadrupole('QD_B', 1.0, -0.19)
QF_C = Quadrupole('QF_C', 1.0, +0.20)
QD_C = Quadrupole('QD_C', 1.0, -0.19)
QF_D = Quadrupole('QF_D', 1.0, +0.20)
QD_D = Quadrupole('QD_D', 1.0, -0.19)

SF1_A = Sextupole('SF1_A', 0.25, 0.00)
SD1_A = Sextupole('SD1_A', 0.25, 0.00)
SF2_A = Sextupole('SF2_A', 0.25, 0.00)
SD2_A = Sextupole('SD2_A', 0.25, 0.00)
SF1_B = Sextupole('SF1_B', 0.25, 0.00)
SD1_B = Sextupole('SD1_B', 0.25, 0.00)
SF2_B = Sextupole('SF2_B', 0.25, 0.00)
SD2_B = Sextupole('SD2_B', 0.25, 0.00)
SF1_C = Sextupole('SF1_C', 0.25, 0.00)
SD1_C = Sextupole('SD1_C', 0.25, 0.00)
SF2_C = Sextupole('SF2_C', 0.25, 0.00)
SD2_C = Sextupole('SD2_C', 0.25, 0.00)
SF1_D = Sextupole('SF1_D', 0.25, 0.00)
SD1_D = Sextupole('SD1_D', 0.25, 0.00)
SF2_D = Sextupole('SF2_D', 0.25, 0.00)
SD2_D = Sextupole('SD2_D', 0.25, 0.00)

BM1_A = Dipole('BM1_A', 3.50, torch.pi/4.0)
BM2_A = Dipole('BM2_A', 3.50, torch.pi/4.0)
BM1_B = Dipole('BM1_B', 3.50, torch.pi/4.0)
BM2_B = Dipole('BM2_B', 3.50, torch.pi/4.0)
BM1_C = Dipole('BM1_C', 3.50, torch.pi/4.0)
BM2_C = Dipole('BM2_C', 3.50, torch.pi/4.0)
BM1_D = Dipole('BM1_D', 3.50, torch.pi/4.0)
BM2_D = Dipole('BM2_D', 3.50, torch.pi/4.0)

DR1_A = Drift('DR1_A', 0.25)
DR2_A = Drift('DR2_A', 0.25)
DR3_A = Drift('DR3_A', 0.25)
DR4_A = Drift('DR4_A', 0.25)
DR5_A = Drift('DR5_A', 0.25)
DR6_A = Drift('DR6_A', 0.25)
DR7_A = Drift('DR7_A', 0.25)
DR8_A = Drift('DR8_A', 0.25)
DR1_B = Drift('DR1_B', 0.25)
DR2_B = Drift('DR2_B', 0.25)
DR3_B = Drift('DR3_B', 0.25)
DR4_B = Drift('DR4_B', 0.25)
DR5_B = Drift('DR5_B', 0.25)
DR6_B = Drift('DR6_B', 0.25)
DR7_B = Drift('DR7_B', 0.25)
DR8_B = Drift('DR8_B', 0.25)
DR1_C = Drift('DR1_C', 0.25)
DR2_C = Drift('DR2_C', 0.25)
DR3_C = Drift('DR3_C', 0.25)
DR4_C = Drift('DR4_C', 0.25)
DR5_C = Drift('DR5_C', 0.25)
DR6_C = Drift('DR6_C', 0.25)
DR7_C = Drift('DR7_C', 0.25)
DR8_C = Drift('DR8_C', 0.25)
DR1_D = Drift('DR1_D', 0.25)
DR2_D = Drift('DR2_D', 0.25)
DR3_D = Drift('DR3_D', 0.25)
DR4_D = Drift('DR4_D', 0.25)
DR5_D = Drift('DR5_D', 0.25)
DR6_D = Drift('DR6_D', 0.25)
DR7_D = Drift('DR7_D', 0.25)
DR8_D = Drift('DR8_D', 0.25)

# One FODO cell per suffix; no element object appears twice in any line

FODO_A = Line('FODO_A', [QF_A, DR1_A, SF1_A, DR2_A, BM1_A, DR3_A, SD1_A, DR8_A, QD_A, DR4_A, SD2_A, DR5_A, BM2_A, DR6_A, SF2_A, DR7_A], propagate=True)
FODO_B = Line('FODO_B', [QF_B, DR1_B, SF1_B, DR2_B, BM1_B, DR3_B, SD1_B, DR8_B, QD_B, DR4_B, SD2_B, DR5_B, BM2_B, DR6_B, SF2_B, DR7_B], propagate=True)
FODO_C = Line('FODO_C', [QF_C, DR1_C, SF1_C, DR2_C, BM1_C, DR3_C, SD1_C, DR8_C, QD_C, DR4_C, SD2_C, DR5_C, BM2_C, DR6_C, SF2_C, DR7_C], propagate=True)
FODO_D = Line('FODO_D', [QF_D, DR1_D, SF1_D, DR2_D, BM1_D, DR3_D, SD1_D, DR8_D, QD_D, DR4_D, SD2_D, DR5_D, BM2_D, DR6_D, SF2_D, DR7_D], propagate=True)

# The ring is a line of FODO cells (nested lines), flattened in place below

RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], propagate=True, dp=0.0, exact=False, output=False, matrix=False)

RING.flatten()
[4]:
# Create parametric transformation from one element to another and its inverse
# Both transformations are built from one shared parameter specification, so they accept
# the same (kn, ms, dp) arguments: kn for quadrupoles, ms for sextupoles, and dp

probe = 'SD2_A'
other = 'SF1_D'

parameters = (
    ('kn', ['Quadrupole'], None, None),
    ('ms', ['Sextupole'], None, None),
    ('dp', None, None, None),
)
options = dict(root=True, alignment=False)

# Swapping the element order yields the inverse transformation
forward, *_ = group(RING, probe, other, *parameters, **options)
inverse, *_ = group(RING, other, probe, *parameters, **options)
[5]:
# Test propagation and inverse transformation
# Random parameter deviations come from a seeded generator so that re-running the notebook is reproducible

torch.manual_seed(0)

state = torch.tensor([0.001, 0.005, -0.005, 0.001], dtype=torch.float64)

kn = 1.0E-3*torch.randn( 8, dtype=torch.float64)
ms = 1.0E-3*torch.randn(16, dtype=torch.float64)
dp = torch.tensor([0.001], dtype=torch.float64)

print(local := state.clone())
print(local := forward(local, kn, ms, dp))
print(local := inverse(local, kn, ms, dp))
print(state - local)

# Forward followed by inverse (same parameters) must return the initial state up to round-off
torch.testing.assert_close(local, state, rtol=0.0, atol=1.0E-12)

tensor([ 0.0010,  0.0050, -0.0050,  0.0010], dtype=torch.float64)
tensor([-0.0041,  0.0022,  0.0047, -0.0001], dtype=torch.float64)
tensor([ 0.0010,  0.0050, -0.0050,  0.0010], dtype=torch.float64)
tensor([ 4.3368e-19, -8.6736e-19, -4.3368e-18,  2.1684e-19],
       dtype=torch.float64)
[6]:
# Test derivatives at zero state with all parameter deviations set to zero

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

kn = torch.zeros( 8, dtype=torch.float64)
ms = torch.zeros(16, dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Transport matrix (Jacobian with respect to the state, the first argument)
# The inverse of the forward matrix must coincide with the matrix of the inverse transformation

forward_matrix = torch.func.jacrev(forward)(state, kn, ms, dp)
inverse_matrix = torch.func.jacrev(inverse)(state, kn, ms, dp)

print(forward_matrix.inverse())
print(inverse_matrix)
print()

torch.testing.assert_close(forward_matrix.inverse(), inverse_matrix)

# Derivatives of transport matrix trace with respect to quadrupole deviations

def matrix(kn, ms, dp):
    """Return the trace of the forward transport matrix.

    The transport matrix is the Jacobian of ``forward`` with respect to its
    first argument (the state), evaluated at the global ``state`` (read at
    call time) for the given deviations ``kn``, ``ms`` and ``dp``.
    """
    jacobian = torch.func.jacrev(forward)(state, kn, ms, dp)
    return jacobian.trace()

# argnums=0 (the default) differentiates with respect to kn, i.e. the quadrupole deviations
print(torch.func.jacrev(matrix, argnums=0)(kn, ms, dp))
print()
tensor([[ 0.3917,  0.5359,  0.0000,  0.0000],
        [ 0.4481,  3.1656,  0.0000,  0.0000],
        [-0.0000, -0.0000, -1.1626, -3.2206],
        [ 0.0000,  0.0000,  0.2044, -0.2939]], dtype=torch.float64)
tensor([[ 0.3917,  0.5359,  0.0000,  0.0000],
        [ 0.4481,  3.1656,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -1.1626, -3.2206],
        [ 0.0000,  0.0000,  0.2044, -0.2939]], dtype=torch.float64)

tensor([ 0.0000,  0.0000,  2.6498, 39.4992, 30.9318, 39.5209,  2.0561,  0.0000],
       dtype=torch.float64)