ELETTRA-04: Tune response matrix
[1]:
# In this example, the tune response matrix is constructed
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True
from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
[3]:
# Numerical setup shared by every lattice element: double precision on the CPU
dtype = torch.float64
device = torch.device('cpu')
Element.dtype, Element.device = dtype, device
[4]:
# Read the lattice description (ELEGANT table format) from the working directory.
# The lattice is allowed to contain repeated elements.
path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice
# Build the 'RING' line from the ELEGANT table loaded above (all following steps mutate `ring` in place)
ring:Line = build('RING', 'ELEGANT', data)
# Flatten sublines
ring.flatten()
# Remove all marker elements but the ones starting with MLL (long straight section centers)
# The negative lookahead `^(?!MLL_)` matches every name that does NOT start with 'MLL_'
ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])
# Replace all sextupoles with quadrupoles
def factory(element:Element) -> Quadrupole:
    """Build a quadrupole from the serialized parameters of another element.

    The 'ms' entry (presumably the sextupole strength; not a Quadrupole
    parameter) is dropped, and all remaining entries are forwarded to the
    Quadrupole constructor.

    Parameters
    ----------
    element: Element
        element to replace (selected below with kinds=['Sextupole'])

    Returns
    -------
    Quadrupole
        replacement element

    """
    # Work on a copy so the element's own serialized table is never mutated
    table = dict(element.serialize)
    table.pop('ms', None)
    return Quadrupole(**table)

# Empty pattern matches every name; selection is done by kind only
ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])
# Set linear dipoles
def apply(element: Element) -> None:
    """Set the ``linear`` flag of a single element to True (in place)."""
    setattr(element, 'linear', True)

# Applied to every element of kind 'Dipole'
ring.apply(apply, kinds=['Dipole'])
# Merge drifts
# NOTE(review): presumably merges consecutive drifts (e.g. those left around removed markers) — confirm
ring.merge()
# Change lattice start (make the element named 'BPM_S01_01' the start of the line)
ring.start = "BPM_S01_01"
# Split BPMs
# NOTE(review): meaning of each tuple field is defined by Line.split; only the kind list ['BPM'] is set here — confirm
ring.split((None, ['BPM'], None, None))
# Roll lattice
# NOTE(review): presumably rotates the element sequence by one position — confirm direction against Line.roll
ring.roll(1)
# Splice lattice
ring.splice()
# Describe
# Last expression of the cell: in this run it displays a dict of element counts per kind
ring.describe
[5]:
{'BPM': 168, 'Drift': 708, 'Dipole': 156, 'Quadrupole': 360, 'Marker': 12}
[6]:
# Quadrupole names used as global tune-correction knobs.
# For each family: positions 02 and 03 (outer), sectors S01..S12 (inner).
QF, QD = (
    [f'{family}_S{sector:02}_{position:02}' for position in (2, 3) for sector in range(1, 12 + 1)]
    for family in ('QF', 'QD')
)
[7]:
# Compute response matrix (individual quadrupoles)
def observable(knobs: Tensor) -> Tensor:
    """Evaluate tunes for per-quadrupole 'kn' deviations.

    Parameters
    ----------
    knobs: Tensor
        one 'kn' deviation per quadrupole, ordered as QF + QD

    Returns
    -------
    Tensor
        result of tune(...) (two components in this run, per the printed matrix shape)

    """
    return tune(ring, [knobs], ('kn', None, QF + QD, None), matched=True, limit=1)

knobs = torch.zeros(len(QF + QD), dtype=dtype)
matrix = torch.func.jacrev(observable)(knobs)
print(matrix.shape)
print(matrix)
# Per-family response: sum the QF columns and the QD columns separately.
# Slicing at len(QF) stays correct when the families differ in size
# (reshaping into two equal halves would silently mix them).
families = torch.stack([matrix[:, :len(QF)].sum(-1), matrix[:, len(QF):].sum(-1)], dim=-1)
print(families)
torch.Size([2, 48])
tensor([[ 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439,
0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439,
0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439, 0.2439,
0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873,
0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873,
0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873, 0.0873],
[-0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
-0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
-0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
-0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525,
-0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525,
-0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525]],
dtype=torch.float64)
tensor([[ 5.8543, 2.0964],
[-2.9918, -1.2602]], dtype=torch.float64)
[8]:
# Compute response matrix (two quadrupole families)
def observable(knobs):
    """Evaluate tunes with one shared 'kn' deviation per family.

    knobs holds exactly two values: the first is applied to every QF
    quadrupole, the second to every QD quadrupole.
    """
    kf, kd = knobs
    names = QF + QD
    kn = torch.stack([kf for _ in QF] + [kd for _ in QD])
    return tune(ring, [kn], ('kn', None, names, None), matched=True, limit=1)

knobs = torch.zeros(2, dtype=dtype)
matrix = torch.func.jacrev(observable)(knobs)
print(matrix.shape)
print(matrix)
torch.Size([2, 2])
tensor([[ 5.8543, 2.0964],
[-2.9918, -1.2602]], dtype=torch.float64)