ELETTRA-04: Tune response matrix

[1]:
# In this example, the tune response matrix is constructed
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True

from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole

from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
[3]:
# Set data type and device
# (module-level names `dtype`/`device` are reused below; the same values are
# also pushed onto the Element class so the model library picks them up)

dtype = torch.float64
device = torch.device('cpu')
Element.dtype, Element.device = dtype, device
[4]:
# Load lattice (ELEGANT table)
# Note: the lattice is allowed to contain repeated elements

path = Path('elettra.lte')  # relative path: resolved against the current working directory
data = load_lattice(path)   # parsed table, consumed by build(...) in the next cell
[5]:
# Build and setup lattice
# Build the line named 'RING' from the ELEGANT table loaded above

ring:Line = build('RING', 'ELEGANT', data)

# Flatten sublines

ring.flatten()

# Remove all marker elements but the ones starting with MLL (long straight section centers)
# The negative lookahead ^(?!MLL_) matches every name that does NOT start with 'MLL_',
# so among Marker elements only those named 'MLL_...' are kept.

ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles

def factory(element:Element) -> Quadrupole:
    """Return a Quadrupole built from *element*'s serialized parameters.

    Every entry of ``element.serialize`` is forwarded to ``Quadrupole`` as a
    keyword argument, except ``'ms'`` (presumably the sextupole strength,
    which ``Quadrupole`` does not take — confirm against the element classes).
    The serialized table is copied rather than popped, so whatever object
    ``serialize`` returns is never mutated.
    """
    table = {key: value for key, value in element.serialize.items() if key != 'ms'}
    return Quadrupole(**table)

# An empty pattern presumably selects every name; `kinds` restricts the replacement to sextupoles.
ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles

def apply(element:Element) -> None:
    """Turn on the ``linear`` flag of *element* in place; returns nothing."""
    setattr(element, 'linear', True)

# Only Dipole elements are visited.
ring.apply(apply, kinds=['Dipole'])

# Merge drifts
# (presumably joins adjacent Drift elements into single drifts — confirm against Line.merge)

ring.merge()

# Change lattice start

ring.start = "BPM_S01_01"

# Split BPMs
# NOTE(review): the four tuple slots are interpreted by Line.split; only the
# second slot is set here (to ['BPM']) — TODO document what each slot means.

ring.split((None, ['BPM'], None, None))

# Roll lattice (presumably rotates the element sequence by one position)

ring.roll(1)

# Splice lattice

ring.splice()

# Describe
# Last expression of the cell, so the notebook displays its value (element counts per kind).

ring.describe
[5]:
{'BPM': 168, 'Drift': 708, 'Dipole': 156, 'Quadrupole': 360, 'Marker': 12}
[6]:
# Set quadrupole names for global tune correction
# Names look like '<FAMILY>_S<ss>_<ii>': all twelve S01..S12 with index 02 first,
# then all twelve again with index 03. QD mirrors QF with the family prefix swapped.

QF = [f'QF_S{s:02}_{idx:02}' for idx in (2, 3) for s in range(1, 13)]
QD = [name.replace('QF', 'QD', 1) for name in QF]
[7]:
# Compute response matrix (individual quadrupoles)

def observable(knobs):
    """Return the tunes computed by ``tune`` for per-quadrupole ``kn`` knobs.

    ``knobs`` holds one value per name in ``QF + QD`` (same order) and is passed
    directly as the single ``kn`` group of ``tune``.
    """
    return tune(ring, [knobs], ('kn', None, QF + QD, None), matched=True, limit=1)

knobs = torch.zeros(len(QF + QD), dtype=dtype)
matrix = torch.func.jacrev(observable)(knobs)

print(matrix.shape)
print(matrix)

# Sum the columns belonging to each family (QF columns first, then QD).
# Slicing at len(QF) keeps this correct even when the two families differ in size.
print(torch.stack([matrix[:, :len(QF)].sum(-1), matrix[:, len(QF):].sum(-1)], -1))
torch.Size([2, 48])
tensor([[ 0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,
          0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,
          0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,  0.2439,
          0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,
          0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,
          0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873,  0.0873],
        [-0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
         -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
         -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247, -0.1247,
         -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525,
         -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525,
         -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525, -0.0525]],
       dtype=torch.float64)
tensor([[ 5.8543,  2.0964],
        [-2.9918, -1.2602]], dtype=torch.float64)
[8]:
# Compute response matrix (two quadrupole families)

def observable(knobs):
    """Return the tunes computed by ``tune`` with one shared ``kn`` knob per family.

    ``knobs`` is a pair: the first value is applied to every name in ``QF``,
    the second to every name in ``QD`` (laid out in ``QF + QD`` order).
    """
    k_qf, k_qd = knobs
    per_magnet = [k_qf] * len(QF) + [k_qd] * len(QD)
    return tune(ring, [torch.stack(per_magnet)], ('kn', None, QF + QD, None), matched=True, limit=1)

knobs = torch.zeros(2, dtype=dtype)
matrix = torch.func.jacrev(observable)(knobs)

print(matrix.shape)
print(matrix)
torch.Size([2, 2])
tensor([[ 5.8543,  2.0964],
        [-2.9918, -1.2602]], dtype=torch.float64)