ELETTRA-02: Linear optics comparison
[1]:
# In this example, linear optics is computed for the Elettra 2.0 storage ring and compared with ELEGANT
# Note, elements are sliced to get smooth optics output; slicing makes the dispersion slightly off (there is no difference without slicing)
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True
from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.command.util import select
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.layout import Layout
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
[3]:
# Choose the floating point type and the device used in this notebook
dtype = torch.float64
device = torch.device('cpu')

# Store both choices on the Element class
Element.dtype = dtype
Element.device = device
[4]:
# Load ELEGANT reference data and keep the values at BPMs
# Note, column selection creates a python dictionary using element names as keys
# Repeated elements are collapsed
path = Path('elettra.twiss')
parameters, columns = load_sdds(path)

# Element kind (used to pick the entries that belong to monitors)
kinds = select(columns, 'ElementType', keep=False)


def _bpm_values(column:str) -> list:
    """Return entries of the given column for elements of kind 'MONI'."""
    table = select(columns, column, keep=False)
    return [value for value, kind in zip(table.values(), kinds.values()) if kind == 'MONI']


def _bpm_tensor(column:str) -> Tensor:
    """Return entries of the given column for 'MONI' elements as a tensor."""
    return torch.tensor(_bpm_values(column), dtype=dtype)


# Tune
nu_x:Tensor = torch.tensor(parameters['nux'], dtype=dtype)
nu_y:Tensor = torch.tensor(parameters['nuy'], dtype=dtype)

# Chromaticity
xi_x:Tensor = torch.tensor(parameters['dnux/dp'], dtype=dtype)
xi_y:Tensor = torch.tensor(parameters['dnuy/dp'], dtype=dtype)

# Dispersion
eta_qx = _bpm_tensor('etax')
eta_px = _bpm_tensor('etaxp')
eta_qy = _bpm_tensor('etay')
eta_py = _bpm_tensor('etayp')

# Twiss
a_x:Tensor = _bpm_tensor('alphax')
b_x:Tensor = _bpm_tensor('betax')
a_y:Tensor = _bpm_tensor('alphay')
b_y:Tensor = _bpm_tensor('betay')

# Advance (accumulated from lattice start to each BPM position)
mu_x:Tensor = _bpm_tensor('psix')
mu_y:Tensor = _bpm_tensor('psiy')

# BPM positions (kept as a plain list, used as plot abscissa)
positions = _bpm_values('s')
[5]:
# Load lattice (ELEGANT table)
# Note, lattice is allowed to have repeated elements
# The file is looked up relative to the current working directory
path = Path('elettra.lte')
# The loaded data is passed to build when the ring is constructed
data = load_lattice(path)
[6]:
# Build and setup lattice
ring:Line = build('RING', 'ELEGANT', data)

# Flatten sublines
ring.flatten()

# Remove all marker elements but the ones starting with MLL (long straight section centers)
ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles
def factory(element:Element) -> Quadrupole:
    """Create a Quadrupole from the serialized table of the given element."""
    table = element.serialize
    # Drop the 'ms' entry before passing the table on (presumably the sextupole strength, not a Quadrupole argument)
    table.pop('ms', None)
    return Quadrupole(**table)

ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles
def apply(element:Element) -> None:
    """Set the linear flag of the given element."""
    element.linear = True

ring.apply(apply, kinds=['Dipole'])

# Merge drifts
ring.merge()

# Split elements (smooth twiss output)
# Elements with zero length are skipped, the number of parts is int(length/step)
# NOTE(review): split is called while iterating over ring, confirm Line iteration tolerates this
step = 0.05
for element in ring:
    if length := element.length:
        ring.split((int(length/step), None, [element.name], None))

# Describe
ring.describe
[6]:
{'Marker': 12, 'Drift': 708, 'BPM': 168, 'Quadrupole': 360, 'Dipole': 156}
[7]:
# Compute tunes (fractional part)
nux, nuy = tune(ring, [], matched=True, limit=1)

# Compare with elegant (only the fractional part of the reference tunes is used)
for reference, computed in ((nu_x, nux), (nu_y, nuy)):
    print(torch.allclose(reference % 1, computed))
True
True
[8]:
# Compute and compare dispersion
# Note, the values are not matched exactly due to splitting of dipoles
orbit = torch.zeros(4, dtype=dtype)
etaqx, etapx, etaqy, etapy = dispersion(ring, orbit, [], limit=1)

# Keep only the entries at BPMs (the second field of each layout record is the element kind)
mask = [kind == 'BPM' for (_, kind, *_) in ring.layout()]
etaqx_bpm, etapx_bpm, etaqy_bpm, etapy_bpm = (eta[mask] for eta in (etaqx, etapx, etaqy, etapy))

# Compare with elegant
for reference, computed in zip((eta_qx, eta_px, eta_qy, eta_py), (etaqx_bpm, etapx_bpm, etaqy_bpm, etapy_bpm)):
    print(torch.allclose(reference, computed))

# Lattice profile to draw under the curve (lengths is not used further in this cell)
layout = Layout(ring)
_, _, lengths, *_ = layout.slicing_table()
rectangles = layout.profile_1d(scale=0.005, shift=-0.005, text=False, exclude=['Marker', 'BPM'])

# Plot horizontal dispersion (line) and elegant values at BPMs (crosses)
plt.figure(figsize=(12, 6))
plt.errorbar(ring.locations().cpu().numpy(), etaqx.cpu().numpy(), fmt='-', color='red', alpha=0.75)
plt.errorbar(positions, eta_qx.cpu().numpy(), fmt=' ', color='black', alpha=1.0, marker='x')
axes = plt.gca()
for rectangle in rectangles:
    axes.add_patch(Rectangle(**rectangle))
plt.tight_layout()

# Show the range from 0.5 to 1.5 of one twelfth of the ring length
plt.xlim(0.5*ring.length/12.0, 1.5*ring.length/12.0)
plt.show()
False
False
True
True
[9]:
# Compute and compare twiss parameters
ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T

# Keep only the entries at BPMs (the second field of each layout record is the element kind)
mask = [kind == 'BPM' for (_, kind, *_) in ring.layout()]
ax_bpm, bx_bpm, ay_bpm, by_bpm = (value[mask] for value in (ax, bx, ay, by))

# Compare with elegant
for reference, computed in zip((a_x, b_x, a_y, b_y), (ax_bpm, bx_bpm, ay_bpm, by_bpm)):
    print(torch.allclose(reference, computed))

# Lattice profile to draw under the curves (lengths is not used further in this cell)
layout = Layout(ring)
_, _, lengths, *_ = layout.slicing_table()
rectangles = layout.profile_1d(scale=2.0, shift=-1.0, text=False, exclude=['Marker', 'BPM'])

# Plot beta functions (red for x, blue for y) and elegant values at BPMs (crosses)
plt.figure(figsize=(12, 6))
for beta, color in ((bx, 'red'), (by, 'blue')):
    plt.errorbar(ring.locations().cpu().numpy(), beta.cpu().numpy(), fmt='-', color=color, alpha=0.75)
for beta in (b_x, b_y):
    plt.errorbar(positions, beta.cpu().numpy(), fmt=' ', color='black', alpha=1.0, marker='x')
axes = plt.gca()
for rectangle in rectangles:
    axes.add_patch(Rectangle(**rectangle))

# Show the range from 0.5 to 1.5 of one twelfth of the ring length
plt.xlim(0.5*ring.length/12.0, 1.5*ring.length/12.0)
plt.tight_layout()
plt.show()
True
True
True
True
[10]:
# Compute and compare phase advances
# Note, advance values are accumulated with cumsum before comparison with the elegant values
mux, muy = (mu.cumsum(-1) for mu in advance(ring, [], alignment=False, matched=True).T)

# Keep only the entries at BPMs (the second field of each layout record is the element kind)
mask = [kind == 'BPM' for (_, kind, *_) in ring.layout()]
mux_bpm = mux[mask]
muy_bpm = muy[mask]

# Compare with elegant
for computed, reference in ((mux_bpm, mu_x), (muy_bpm, mu_y)):
    print(torch.allclose(computed, reference))

# Lattice profile to draw under the curves (lengths is not used further in this cell)
layout = Layout(ring)
_, _, lengths, *_ = layout.slicing_table()
rectangles = layout.profile_1d(scale=1.5, shift=-0.5, text=False, exclude=['Marker', 'BPM'])

# Plot accumulated advances (red for x, blue for y) and elegant values at BPMs (crosses)
plt.figure(figsize=(12, 6))
for mu, color in ((mux, 'red'), (muy, 'blue')):
    plt.errorbar(ring.locations().cpu().numpy(), mu.cpu().numpy(), fmt='-', color=color, alpha=0.75)
for mu in (mu_x, mu_y):
    plt.errorbar(positions, mu.cpu().numpy(), fmt=' ', color='black', alpha=1.0, marker='x')
axes = plt.gca()
for rectangle in rectangles:
    axes.add_patch(Rectangle(**rectangle))

# Show the range from 0.5 to 1.5 of one twelfth of the ring length
plt.xlim(0.5*ring.length/12.0, 1.5*ring.length/12.0)

# Upper limit is set from the larger of the two elegant tunes
upper = 2*torch.pi*max(nu_x, nu_y)/8
plt.ylim(-1, upper)
plt.tight_layout()
plt.show()
True
True
[ ]: