Example-62: Kick map (element)
[1]:
# This example illustrates an AT-style kick map element.
# The kick table is generated with Radia and saved as a .mat file.
# The effects on linear optics of a linear ID model and of the kick map are compared.
# Horizontal phase-space trajectories are compared with and without the ID.
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
from tqdm import tqdm
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True
from model.library.element import Element
from model.library.line import Line
from model.library.corrector import Corrector
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix
from model.library.kickmap import KM
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.tune import chromaticity
from model.command.orbit import dispersion
from model.command.orbit import ORM
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling
[3]:
# Double precision on CPU; published to every element through Element's class attributes
dtype = torch.float64
device = torch.device('cpu')
Element.dtype, Element.device = dtype, device
[4]:
# Read the lattice description from an ELEGANT (.lte) table.
# Repeated elements are allowed in the table.
path = Path(
    'elettra.lte'
)
data = load_lattice(path)
[5]:
# Build and set up the ring lattice from the loaded ELEGANT data
ring:Line = build('RING', 'ELEGANT', data)
# Flatten sublines into a single sequence of elements
ring.flatten()
# Remove all marker elements except those whose names start with MSS_ or MLL_
# (MLL_ markers are the long straight section centers)
ring.remove_group(pattern=r'^(?!MSS_)(?!MLL_).*', kinds=['Marker'])
# Set sextupole integration order and step size
ring.order = (('Sextupole', 1), )
ring.ns = (('Sextupole', 0.01), )
# Switch every dipole to its linear model
def apply(element:Element) -> None:
    """Set ``element.linear = True``; applied below to all elements of kind 'Dipole'."""
    element.linear = True
ring.apply(apply, kinds=['Dipole'])
# Insert a corrector named '<name>_CXY' at every element whose name starts with 'CH'
# NOTE(review): presumably split((1 + 1, ...), paste=[corrector]) cuts the element into two
# pieces and pastes the corrector between them; confirm against Line.split.
# NOTE(review): ring is modified while iterating ring.layout(); this assumes layout()
# returns a snapshot rather than a live view.
for name, *_ in ring.layout():
    if name.startswith('CH'):
        corrector = Corrector(f'{name}_CXY', factor=1)
        ring.split((1 + 1, None, [name], None), paste=[corrector])
# Merge drifts
ring.merge()
# Change the lattice start
ring.start = "BPM_S01_01"
# Split BPMs
ring.split((None, ['BPM'], None, None))
# Roll lattice
ring.roll(1)
# Splice
ring.splice()
# Describe (cell output: element counts by kind)
ring.describe
[5]:
{'BPM': 168,
'Drift': 744,
'Dipole': 156,
'Sextupole': 240,
'Quadrupole': 120,
'Corrector': 24,
'Marker': 24}
[6]:
# Fractional tunes of the bare ring (reference for the tune shifts below)
(nux, nuy) = tune(
    ring, [],
    matched=True, limit=1,
)
[7]:
# Dispersion of the bare ring, computed with an all-zero 4D reference orbit
orbit = torch.zeros(4, dtype=dtype, device=device)
(etaqx, etapx, etaqy, etapy) = dispersion(
    ring, orbit, [],
    limit=1,
)
[8]:
# Twiss parameters of the bare ring; the transposed table is unpacked column-wise
(ax, bx, ay, by) = twiss(
    ring, [],
    matched=True, advance=True, full=False,
).T
[9]:
# Phase advances of the bare ring (x and y columns after transposing)
(mux, muy) = advance(
    ring, [],
    alignment=False, matched=True,
).T
[10]:
# Coupling (minimal tune distance) of the bare ring
c = coupling(
    ring, [],
)
[11]:
# Chromaticity of the bare ring (matched computation)
psi = chromaticity(
    ring, [],
    matched=True,
)
[12]:
# Linear ID model: a 4x4 matrix whose only nonzero entries lie on the diagonal
# NOTE(review): the physical meaning of these coefficients is defined by Matrix; confirm there
A = torch.diag(torch.tensor([-0.0344386, -0.0445673, 0.056303, 0.0804237], dtype=dtype))
CENTER = 0.0
# Matrix receives the upper-triangular entries of A, flattened in row-major order
ID = Matrix('ID', length=0.0, A=A[tuple(torch.triu_indices(4, 4))].tolist())
[13]:
# Insert the linear ID model into a clone of the ring
# NOTE(review): the anchor is the element returned by error.next('MLL_S01'); CENTER is
# scaled by 1.0E-3, presumably converting mm to m
error = ring.clone()
error.flatten()
error.insert(
    ID,
    error.next('MLL_S01').name,
    position=1.0E-3*CENTER,
)
error.splice()
error.describe
[13]:
{'BPM': 168,
'Drift': 745,
'Dipole': 156,
'Sextupole': 240,
'Quadrupole': 120,
'Corrector': 24,
'Marker': 24,
'Matrix': 1}
[14]:
# Fractional tunes with the linear ID model inserted
(nux_id_lm, nuy_id_lm) = tune(
    error, [],
    matched=True, limit=1,
)
[15]:
# Dispersion with the linear ID model inserted.
# The zero reference orbit uses the configured dtype AND device, like the bare-ring
# computation; omitting device breaks as soon as `device` is set to a non-CPU device.
orbit = torch.zeros(4, dtype=dtype, device=device)
etaqx_id_lm, etapx_id_lm, etaqy_id_lm, etapy_id_lm = dispersion(error, orbit, [], limit=1)
[16]:
# Twiss parameters with the linear ID model inserted (transposed, unpacked column-wise)
(ax_id_lm, bx_id_lm, ay_id_lm, by_id_lm) = twiss(
    error, [],
    matched=True, advance=True, full=False,
).T
[17]:
# Phase advances with the linear ID model inserted
(mux_id_lm, muy_id_lm) = advance(
    error, [],
    alignment=False, matched=True,
).T
[18]:
# Coupling (minimal tune distance) with the linear ID model inserted
c_id_lm = coupling(
    error, [],
)
[19]:
# Chromaticity with the linear ID model inserted.
# matched=True, as for the bare-ring reference `psi` and every other optics call on `error`,
# so that psi_id_lm is directly comparable with psi.
psi_id_lm = chromaticity(error, [], matched=True)
[20]:
# Kick-map ID built from the Radia-generated table stored in id.mat
# NOTE(review): energy is presumably in GeV and count the number of kicks; confirm against KM
CENTER = 0.0
ID = KM(
    name='ID',
    path=Path('id.mat'),
    energy=2.4,
    count=40,
    insertion=True,
)
[21]:
# Insert the kick-map ID into a fresh clone of the ring (replaces the linear-model `error`)
# NOTE(review): the anchor is the element returned by error.next('MLL_S01'); CENTER is
# scaled by 1.0E-3, presumably converting mm to m
error = ring.clone()
error.flatten()
error.insert(
    ID,
    error.next('MLL_S01').name,
    position=1.0E-3*CENTER,
)
error.splice()
error.describe
[21]:
{'BPM': 168,
'Drift': 745,
'Dipole': 156,
'Sextupole': 240,
'Quadrupole': 120,
'Corrector': 24,
'Marker': 24,
'KM': 1}
[22]:
# Fractional tunes with the kick-map ID inserted
(nux_id_km, nuy_id_km) = tune(
    error, [],
    matched=True, limit=1,
)
[23]:
# Dispersion with the kick-map ID inserted.
# The zero reference orbit uses the configured dtype AND device, like the bare-ring
# computation; omitting device breaks as soon as `device` is set to a non-CPU device.
orbit = torch.zeros(4, dtype=dtype, device=device)
etaqx_id_km, etapx_id_km, etaqy_id_km, etapy_id_km = dispersion(error, orbit, [], limit=1)
[24]:
# Twiss parameters with the kick-map ID inserted (transposed, unpacked column-wise)
(ax_id_km, bx_id_km, ay_id_km, by_id_km) = twiss(
    error, [],
    matched=True, advance=True, full=False,
).T
[25]:
# Phase advances with the kick-map ID inserted
(mux_id_km, muy_id_km) = advance(
    error, [],
    alignment=False, matched=True,
).T
[26]:
# Coupling (minimal tune distance) with the kick-map ID inserted
c_id_km = coupling(
    error, [],
)
[27]:
# Chromaticity with the kick-map ID inserted.
# matched=True, as for the bare-ring reference `psi` and every other optics call on `error`,
# so that psi_id_km is directly comparable with psi.
psi_id_km = chromaticity(error, [], matched=True)
[28]:
# Tune shifts (bare minus perturbed): linear-model pair, blank line, kick-map pair, blank line
print(
    nux - nux_id_lm, nuy - nuy_id_lm, '',
    nux - nux_id_km, nuy - nuy_id_km, '',
    sep='\n',
)
tensor(0.0257, dtype=torch.float64)
tensor(-0.0112, dtype=torch.float64)
tensor(0.0256, dtype=torch.float64)
tensor(-0.0113, dtype=torch.float64)
[29]:
# Coupling (minimal tune distance): bare ring, linear model, kick map; one per line
print(c, c_id_lm, c_id_km, sep='\n')
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(9.8422e-09, dtype=torch.float64)
[30]:
# Chromaticity: bare ring, linear model, kick map; one per line
print(psi, psi_id_lm, psi_id_km, sep='\n')
tensor([2.0296, 2.0131], dtype=torch.float64)
tensor([3.2067, 1.5121], dtype=torch.float64)
tensor([3.2342, 1.4944], dtype=torch.float64)
[31]:
# Beta-beating and dispersion distortion caused by the ID:
# linear ID model (LM) vs kick map (KM), both relative to the bare ring
def rms(x):
    """Return the root-mean-square of all entries of tensor ``x`` (as a 0-d tensor)."""
    return (x**2).mean().sqrt()

def _sign(x) -> str:
    """Return '-' when ``x`` is negative, else the LaTeX '~' (non-breaking space).

    Used to print |x| after an explicit sign slot so that values line up in monospace text.
    Kept as a named helper because an inline lambda that nests the same quote character
    inside an f-string is a SyntaxError before Python 3.12.
    """
    return '-' if x < 0 else '~'

# Relative beta-beating in percent: 100*(bare - perturbed)/bare
bx_bb_lm = 100.0*(bx - bx_id_lm) / bx
by_bb_lm = 100.0*(by - by_id_lm) / by
bx_bb_km = 100.0*(bx - bx_id_km) / bx
by_bb_km = 100.0*(by - by_id_km) / by
rms_x_lm = rms(bx_bb_lm).item()
ptp_x_lm = (bx_bb_lm.max() - bx_bb_lm.min()).item()
rms_y_lm = rms(by_bb_lm).item()
ptp_y_lm = (by_bb_lm.max() - by_bb_lm.min()).item()
rms_x_km = rms(bx_bb_km).item()
ptp_x_km = (bx_bb_km.max() - bx_bb_km.min()).item()
rms_y_km = rms(by_bb_km).item()
ptp_y_km = (by_bb_km.max() - by_bb_km.min()).item()
s = ring.locations().cpu().numpy()
bx_np_lm = bx_bb_lm.cpu().numpy()
by_np_lm = by_bb_lm.cpu().numpy()
bx_np_km = bx_bb_km.cpu().numpy()
by_np_km = by_bb_km.cpu().numpy()
# Dispersion differences: bare minus perturbed
etax_lm = etaqx - etaqx_id_lm
etay_lm = etaqy - etaqy_id_lm
etax_km = etaqx - etaqx_id_km
etay_km = etaqy - etaqy_id_km
rms_etax_lm = rms(etax_lm).item()
ptp_etax_lm = (etax_lm.max() - etax_lm.min()).item()
rms_etay_lm = rms(etay_lm).item()
ptp_etay_lm = (etay_lm.max() - etay_lm.min()).item()
rms_etax_km = rms(etax_km).item()
ptp_etax_km = (etax_km.max() - etax_km.min()).item()
rms_etay_km = rms(etay_km).item()
ptp_etay_km = (etay_km.max() - etay_km.min()).item()
etax_np_lm = etax_lm.cpu().numpy()
etay_np_lm = etay_lm.cpu().numpy()
etax_np_km = etax_km.cpu().numpy()
etay_np_km = etay_km.cpu().numpy()
# Tune shifts (bare minus perturbed) shown in the figure header
dnux_lm, dnuy_lm = nux - nux_id_lm, nuy - nuy_id_lm
dnux_km, dnuy_km = nux - nux_id_km, nuy - nuy_id_km
# Dedicated axis names: `ax`/`ay` already hold values unpacked from twiss(ring, ...)
fig, (axis_beta, axis_eta) = plt.subplots(2, 1, figsize=(16, 10), sharex=True, gridspec_kw={'hspace': 0.1})
curve_style = dict(fmt='-', alpha=0.75, lw=2.0, markerfacecolor='none', ms=8)
# Beta-beating panel: blue = x, red = y; circles = LM, squares = KM
axis_beta.errorbar(s, bx_np_lm, marker='o', color='blue', label=r'$x$ LM', **curve_style)
axis_beta.errorbar(s, by_np_lm, marker='o', color='red', label=r'$y$ LM', **curve_style)
axis_beta.errorbar(s, bx_np_km, marker='s', color='blue', label=r'$x$ KM', **curve_style)
axis_beta.errorbar(s, by_np_km, marker='s', color='red', label=r'$y$ KM', **curve_style)
# No x-label on the top panel: it shares the s axis with the bottom panel
axis_beta.set_ylabel(r'$\Delta \beta / \beta$ [\%]', fontsize=18)
axis_beta.tick_params(width=2, labelsize=16)
axis_beta.tick_params(axis='x', length=8, direction='in')
axis_beta.tick_params(axis='y', length=8, direction='in')
title = (
    rf'RMS_LM=({rms_x_lm:6.3f},{rms_y_lm:6.3f})\% \quad '
    rf'RMS_KM=({rms_x_km:6.3f},{rms_y_km:6.3f})\% \quad '
)
axis_beta.text(0.0, 1.15, title, transform=axis_beta.transAxes, ha='left', va='bottom', fontsize=16, fontfamily='monospace')
title = (
    rf'DNU_LM=({_sign(dnux_lm)}{dnux_lm.abs().item():05.4f},{_sign(dnuy_lm)}{dnuy_lm.abs().item():05.4f}) \quad '
    rf'DNU_KM=({_sign(dnux_km)}{dnux_km.abs().item():05.4f},{_sign(dnuy_km)}{dnuy_km.abs().item():05.4f}) \quad '
)
axis_beta.text(0.0, 1.075, title, transform=axis_beta.transAxes, ha='left', va='bottom', fontsize=16, fontfamily='monospace')
# Coupling can be tiny (order 1e-8); fixed-point '.6f' would print it as 0.000000
title = (
    rf'\quad C=({c_id_lm.item():.2e}, {c_id_km.item():.2e})'
)
axis_beta.text(0.0, 1.00, title, transform=axis_beta.transAxes, ha='left', va='bottom', fontsize=16, fontfamily='monospace')
axis_beta.legend(loc='upper right', frameon=False, fontsize=14, ncol=6)
axis_beta.set_ylim(-20, 20)
# Dispersion panel: differences converted to micrometres (factor 10**6), same color/marker code
axis_eta.errorbar(s, 10**6*etax_np_lm, marker='o', color='blue', label=r'$x$ LM', **curve_style)
axis_eta.errorbar(s, 10**6*etay_np_lm, marker='o', color='red', label=r'$y$ LM', **curve_style)
axis_eta.errorbar(s, 10**6*etax_np_km, marker='s', color='blue', label=r'$x$ KM', **curve_style)
axis_eta.errorbar(s, 10**6*etay_np_km, marker='s', color='red', label=r'$y$ KM', **curve_style)
axis_eta.set_xlabel('s [m]', fontsize=18)
axis_eta.set_ylabel(r'$\Delta \eta$ [$\mu$m]', fontsize=18)
axis_eta.tick_params(width=2, labelsize=16)
axis_eta.tick_params(axis='x', length=8, direction='in')
axis_eta.tick_params(axis='y', length=8, direction='in')
axis_eta.legend(loc='upper right', frameon=False, fontsize=14, ncol=6)
plt.setp(axis_beta.spines.values(), linewidth=2.0)
plt.setp(axis_eta.spines.values(), linewidth=2.0)
plt.show()
[32]:
# Horizontal phase space without ID: track 32 particles, launched with
# qx in [0, 0.01] and all other coordinates zero, for 2**10 passes through the bare ring
qx = torch.linspace(0.0, 0.01, 32, dtype=dtype, device=device)
px = torch.zeros_like(qx)
qy = torch.zeros_like(qx)
py = torch.zeros_like(qx)
state = torch.stack([qx, px, qy, py]).T
# Batched transport through the ring; built once instead of on every turn
ring_batched = torch.vmap(ring)
orbit = []
for _ in tqdm(range(2**10)):
    state = ring_batched(state)
    orbit.append(state)
# (turns, particles, 4) -> (4, particles, turns); keep the horizontal coordinates
qx, px, *_ = torch.stack(orbit).swapaxes(0, -1)
fig, axis = plt.subplots(figsize=(6, 6))
axis.scatter(qx.cpu().numpy(), px.cpu().numpy(), s=1, color='black')
axis.set_xlim(-0.0125, 0.0100)
axis.set_ylim(-0.005, 0.002)
axis.set_xlabel(r'$q_x$', fontsize=18)
axis.set_ylabel(r'$p_x$', fontsize=18)
axis.set_title('Without ID', fontsize=18)
plt.show()
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [18:19<00:00, 1.07s/it]
[33]:
# Horizontal phase space with ID: same initial conditions as the bare-ring case, tracked
# for 2**10 passes through `error` (at this point the ring with the kick-map ID inserted)
qx = torch.linspace(0.0, 0.01, 32, dtype=dtype, device=device)
px = torch.zeros_like(qx)
qy = torch.zeros_like(qx)
py = torch.zeros_like(qx)
state = torch.stack([qx, px, qy, py]).T
# Batched transport through the perturbed ring; built once instead of on every turn
error_batched = torch.vmap(error)
orbit = []
for _ in tqdm(range(2**10)):
    state = error_batched(state)
    orbit.append(state)
# (turns, particles, 4) -> (4, particles, turns); keep the horizontal coordinates
qx, px, *_ = torch.stack(orbit).swapaxes(0, -1)
fig, axis = plt.subplots(figsize=(6, 6))
axis.scatter(qx.cpu().numpy(), px.cpu().numpy(), s=1, color='black')
axis.set_xlim(-0.0125, 0.0100)
axis.set_ylim(-0.005, 0.002)
axis.set_xlabel(r'$q_x$', fontsize=18)
axis.set_ylabel(r'$p_x$', fontsize=18)
axis.set_title('With ID (kick map)', fontsize=18)
plt.show()
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [19:02<00:00, 1.12s/it]