Example-46: Twiss (Optics correction)
[1]:
# In this example, model response matrices of normal and chromatic Twiss parameters are used for optics correction
# ML-style optimization is also performed for optics correction
[2]:
# Import
from pprint import pprint
import torch
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True
from model.library.line import Line
from model.command.util import select
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper
from model.command.tune import tune
from model.command.twiss import twiss
from model.command.twiss import chromatic_twiss
[3]:
# Load ELEGANT twiss (reference values)
path = Path('ic.twiss')
parameters, columns = load_sdds(path)
# Fractional tunes
nu_qx:Tensor = torch.tensor(parameters['nux'] % 1, dtype=torch.float64)
nu_qy:Tensor = torch.tensor(parameters['nuy'] % 1, dtype=torch.float64)
# Keep only the rows that belong to BPMs (ELEGANT element type 'MONI')
kinds = select(columns, 'ElementType', keep=False)
is_bpm = [kind == 'MONI' for kind in kinds.values()]
def at_bpm(column):
    """Return the values of an ELEGANT twiss column at BPM rows (in the order returned by select)."""
    values = select(columns, column, keep=False).values()
    return [value for value, flag in zip(values, is_bpm) if flag]
a_qx:Tensor = torch.tensor(at_bpm('alphax'), dtype=torch.float64)
b_qx:Tensor = torch.tensor(at_bpm('betax'), dtype=torch.float64)
a_qy:Tensor = torch.tensor(at_bpm('alphay'), dtype=torch.float64)
b_qy:Tensor = torch.tensor(at_bpm('betay'), dtype=torch.float64)
# BPM positions (ELEGANT 's' column), used as x-axis tick locations in the plots
positions = at_bpm('s')
[4]:
# Build and set up the model lattice from the ELEGANT lattice file
# Load ELEGANT lattice table
path = Path('ic.lte')
data = load_lattice(path)
# Build the 'RING' line from the ELEGANT table, then flatten it
ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()
# Merge drifts
ring.merge()
# Split BPMs (the second selector entry appears to select by element kind, as in ('kn', ['Quadrupole'], None, None) below)
ring.split((None, ['BPM'], None, None))
# Roll lattice start (shift by 1; presumably one element — confirm against Line.roll)
ring.roll(1)
# Set the linear flag on every dipole
for element in ring:
    if element.__class__.__name__ == 'Dipole':
        element.linear = True
# Split lattice into lines by BPMs (later cells compare twiss at these line boundaries with ELEGANT values at BPMs)
ring.splice()
# Count elements of different kinds (nb: BPMs, nq: quadrupoles, ns: sextupoles)
nb = ring.describe['BPM']
nq = ring.describe['Quadrupole']
ns = ring.describe['Sextupole']
[5]:
# Compare model tunes with the fractional ELEGANT tunes
nuqx, nuqy = tune(ring, [], alignment=False, matched=True)
for expected, computed in ((nu_qx, nuqx), (nu_qy, nuqy)):
    print(torch.allclose(expected, computed))
True
True
[6]:
# Compare model twiss at BPMs with ELEGANT values (alpha_x, beta_x, alpha_y, beta_y)
aqx, bqx, aqy, bqy = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T
for expected, computed in ((a_qx, aqx), (b_qx, bqx), (a_qy, aqy), (b_qy, bqy)):
    print(torch.allclose(expected, computed))
True
True
True
True
[7]:
# Test derivatives with respect to kn and ks at the lattice start
kn = torch.zeros(nq, dtype=torch.float64)
ks = torch.zeros(nq, dtype=torch.float64)
for knob, values in (('kn', kn), ('ks', ks)):
    # The lambda is evaluated immediately by jacrev, so it sees the current knob name
    jacobian = torch.func.jacrev(lambda x: twiss(ring, [x], (knob, ['Quadrupole'], None, None), matched=True))(values)
    pprint(jacobian)
    print()
# Note: first-order derivatives with respect to ks are identically zero, as expected
# Second-order derivatives are not identically zero in general
# In the following, only first-order derivatives are used for optics corrections (lattice without coupling)
tensor([[-0.0893, 0.4014, 1.2554, -0.9068, -1.5491, -0.9866, -0.6147, -0.7071,
-1.9186, 0.2045, -0.1659, 0.4221, 1.9239, 2.2147, 0.5854, -0.4487,
-0.4684, -1.9595, -0.0112, -0.2204, -0.9209, -1.5967, -0.0541, 1.5081,
0.5988, -0.3222, -0.4638, 0.8415],
[ 0.0182, 0.1496, 0.5733, -0.6730, -0.7756, -0.5178, -0.4104, -0.5325,
-0.9651, 0.2699, -0.0145, 0.1734, 0.9035, 1.1954, 0.3374, -0.2984,
-0.4055, -1.0062, 0.1660, -0.0147, -0.4518, -0.8235, 0.1829, 0.8281,
0.3562, -0.2319, -0.4848, 0.2358],
[ 1.6183, -0.0219, -0.2993, -0.0049, 0.3526, 0.2005, -0.5750, -0.4693,
0.1832, 0.0065, -0.5948, -2.2667, -0.8472, 0.8119, 2.2819, 0.7206,
0.0759, -0.2127, 0.4684, 0.6358, -0.0792, -0.2893, -0.0394, 0.3584,
0.2001, -1.2685, -0.6598, 0.2486],
[-0.7202, 0.1496, 0.2258, -0.0342, -0.1608, -0.0228, 0.3921, 0.2789,
-0.1348, 0.0622, 0.4565, 1.3632, 0.4758, -0.5180, -1.3618, -0.3315,
0.0206, 0.1024, -0.2828, -0.3332, 0.1447, 0.2237, -0.0135, -0.1681,
0.0184, 1.0088, 0.4647, -0.2653]], dtype=torch.float64)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.]], dtype=torch.float64)
[8]:
# Compute twiss derivatives with respect to quadrupole settings (normal and chromatic)
def fn_dtwiss_dkn(kn):
    """Twiss parameters at line (BPM) locations as a function of quadrupole kn deviations."""
    knobs = ('kn', ['Quadrupole'], None, None)
    return twiss(ring, [kn], knobs,
                 alignment=False, matched=True, advance=True, full=False, convert=True)
def fn_dtwiss_dp_dkn(kn):
    """Chromatic twiss parameters at line (BPM) locations as a function of quadrupole kn deviations."""
    knobs = ('kn', ['Quadrupole'], None, None)
    return chromatic_twiss(ring, [kn], knobs,
                           alignment=False, matched=True, advance=True, full=False, convert=True)
# Evaluate both Jacobians at zero quadrupole deviation
kn = torch.zeros(nq, dtype=torch.float64)
dtwiss_dkn, dtwiss_dp_dkn = (torch.func.jacrev(fn)(kn) for fn in (fn_dtwiss_dkn, fn_dtwiss_dp_dkn))
for jacobian in (dtwiss_dkn, dtwiss_dp_dkn):
    print(jacobian.shape)
torch.Size([16, 4, 28])
torch.Size([16, 4, 28])
[9]:
# Set lattice with focusing errors (no coupling)
error:Line = ring.clone()
nq = error.describe['Quadrupole']
# Random normal focusing errors with standard deviation 0.1, one value per entry counted by describe['Quadrupole']
error_kn = 0.1*torch.randn(nq, dtype=torch.float64)
# Consecutive quadrupoles sharing a name (presumably slices of one magnet) receive the same error value
quadrupoles = (element for line in error.sequence for element in line if element.__class__.__name__ == 'Quadrupole')
index = 0
label = ''
for element in quadrupoles:
    if element.name != label:
        index, label = index + 1, element.name
    element.kn = (element.kn + error_kn[index - 1]).item()
[10]:
# Compute twiss at BPMs for the model and error lattices
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
ax_model, bx_model, ay_model, by_model = twiss(ring, [], **options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **options).T
# Compare twiss: norm of the model-error difference for alpha_x, beta_x, alpha_y, beta_y
for model_value, error_value in ((ax_model, ax_error), (bx_model, bx_error), (ay_model, ay_error), (by_model, by_error)):
    print((model_value - error_value).norm())
print()
# Plot relative beta beating in percent (red: x plane, blue: y plane)
locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']
plt.figure(figsize=(16, 2))
plt.plot(locations, 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(locations, 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
tensor(1.3708, dtype=torch.float64)
tensor(0.8085, dtype=torch.float64)
tensor(0.5866, dtype=torch.float64)
tensor(0.3774, dtype=torch.float64)
![../_images/examples_model-45_10_1.png](../_images/examples_model-45_10_1.png)
[11]:
# Test Twiss response: mismatch without and with the first-order (linear) response to the errors
twiss_error = torch.stack([ax_error, bx_error, ay_error, by_error])
twiss_model = torch.stack([ax_model, bx_model, ay_model, by_model])
response = (dtwiss_dkn @ error_kn).T
for scale in (0.0, 1.0):
    print((twiss_error - (twiss_model + scale*response)).norm())
tensor(1.7376, dtype=torch.float64)
tensor(0.2810, dtype=torch.float64)
[12]:
# Perform correction (model to experiment)
# Response matrix: rows are flattened (location, twiss parameter) pairs, columns are quadrupole knobs
matrix = dtwiss_dkn.reshape(-1, nq)
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
# Target twiss parameters (error lattice)
twiss_error = twiss(error, [], **options)
# Step size and initial knob values
lr = 0.1
kn = torch.zeros_like(error_kn)
# Iterative least-squares fit with a fixed response matrix (evaluated at kn = 0)
for _ in range(32):
    twiss_model = twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), **options)
    residual = twiss_model - twiss_error
    dkn = - lr*torch.linalg.lstsq(matrix, residual.flatten(), driver='gelsd').solution
    kn += dkn
    # Residual norm before this iteration's update
    print(residual.norm())
# Compare fitted knobs (blue) with the introduced errors (red)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor(1.7376, dtype=torch.float64)
tensor(1.5638, dtype=torch.float64)
tensor(1.4067, dtype=torch.float64)
tensor(1.2649, dtype=torch.float64)
tensor(1.1371, dtype=torch.float64)
tensor(1.0220, dtype=torch.float64)
tensor(0.9184, dtype=torch.float64)
tensor(0.8252, dtype=torch.float64)
tensor(0.7415, dtype=torch.float64)
tensor(0.6663, dtype=torch.float64)
tensor(0.5989, dtype=torch.float64)
tensor(0.5384, dtype=torch.float64)
tensor(0.4842, dtype=torch.float64)
tensor(0.4357, dtype=torch.float64)
tensor(0.3922, dtype=torch.float64)
tensor(0.3533, dtype=torch.float64)
tensor(0.3185, dtype=torch.float64)
tensor(0.2873, dtype=torch.float64)
tensor(0.2595, dtype=torch.float64)
tensor(0.2345, dtype=torch.float64)
tensor(0.2122, dtype=torch.float64)
tensor(0.1922, dtype=torch.float64)
tensor(0.1743, dtype=torch.float64)
tensor(0.1583, dtype=torch.float64)
tensor(0.1440, dtype=torch.float64)
tensor(0.1311, dtype=torch.float64)
tensor(0.1195, dtype=torch.float64)
tensor(0.1091, dtype=torch.float64)
tensor(0.0998, dtype=torch.float64)
tensor(0.0914, dtype=torch.float64)
tensor(0.0838, dtype=torch.float64)
tensor(0.0770, dtype=torch.float64)
![../_images/examples_model-45_12_1.png](../_images/examples_model-45_12_1.png)
[13]:
# Apply corrections: subtract the fitted knobs from a copy of the error lattice
lattice:Line = error.clone()
quadrupoles = (element for line in lattice.sequence for element in line if element.__class__.__name__ == 'Quadrupole')
index = 0
label = ''
for element in quadrupoles:
    # A new name starts the next knob; consecutive elements sharing a name share one knob
    if element.name != label:
        index, label = index + 1, element.name
    element.kn = (element.kn - kn[index - 1]).item()
[14]:
# Compute twiss at BPMs for the model, error and corrected lattices
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
ax_model, bx_model, ay_model, by_model = twiss(ring, [], **options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **options).T
ax_final, bx_final, ay_final, by_final = twiss(lattice, [], **options).T
# Plot relative beta beating in percent (red: x, blue: y; circles: before correction, crosses: after correction)
locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']
plt.figure(figsize=(16, 2))
for marker, bx, by in (('o', bx_error, by_error), ('x', bx_final, by_final)):
    plt.plot(locations, 100*((bx_model - bx)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(locations, 100*((by_model - by)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
![../_images/examples_model-45_14_0.png](../_images/examples_model-45_14_0.png)
[15]:
# Test Twiss response (chromatic): mismatch without and with the first-order response to the errors
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
twiss_error = chromatic_twiss(error, [], **options)
twiss_model = chromatic_twiss(ring, [], **options)
response = dtwiss_dp_dkn @ error_kn
for scale in (0.0, 1.0):
    print((twiss_error - (twiss_model + scale*response)).norm())
tensor(69.5123, dtype=torch.float64)
tensor(5.5803, dtype=torch.float64)
[16]:
# Perform correction (model to experiment) including chromatic twiss
# Stacked response matrix: normal twiss rows first, chromatic twiss rows second
matrix = torch.vstack([dtwiss_dkn.reshape(-1, nq), dtwiss_dp_dkn.reshape(-1, nq)])
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
# Target normal and chromatic twiss parameters (error lattice)
twiss_error = twiss(error, [], **options)
chromatic_twiss_error = chromatic_twiss(error, [], **options)
# Step size and initial knob values
lr = 0.1
kn = torch.zeros_like(error_kn)
# Iterative least-squares fit with a fixed response matrix (evaluated at kn = 0)
for _ in range(64):
    twiss_model = twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), **options)
    chromatic_twiss_model = chromatic_twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), **options)
    # Residual ordering matches the matrix row ordering (normal first, then chromatic)
    residual = torch.stack([twiss_model - twiss_error, chromatic_twiss_model - chromatic_twiss_error])
    dkn = - lr*torch.linalg.lstsq(matrix, residual.flatten(), driver='gelsd').solution
    kn += dkn
    print(residual.norm())
# Compare fitted knobs (blue) with the introduced errors (red)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor(69.5340, dtype=torch.float64)
tensor(62.5709, dtype=torch.float64)
tensor(56.2831, dtype=torch.float64)
tensor(50.6092, dtype=torch.float64)
tensor(45.4934, dtype=torch.float64)
tensor(40.8845, dtype=torch.float64)
tensor(36.7358, dtype=torch.float64)
tensor(33.0042, dtype=torch.float64)
tensor(29.6499, dtype=torch.float64)
tensor(26.6367, dtype=torch.float64)
tensor(23.9311, dtype=torch.float64)
tensor(21.5026, dtype=torch.float64)
tensor(19.3235, dtype=torch.float64)
tensor(17.3686, dtype=torch.float64)
tensor(15.6148, dtype=torch.float64)
tensor(14.0417, dtype=torch.float64)
tensor(12.6304, dtype=torch.float64)
tensor(11.3641, dtype=torch.float64)
tensor(10.2279, dtype=torch.float64)
tensor(9.2081, dtype=torch.float64)
tensor(8.2926, dtype=torch.float64)
tensor(7.4706, dtype=torch.float64)
tensor(6.7322, dtype=torch.float64)
tensor(6.0687, dtype=torch.float64)
tensor(5.4725, dtype=torch.float64)
tensor(4.9364, dtype=torch.float64)
tensor(4.4543, dtype=torch.float64)
tensor(4.0206, dtype=torch.float64)
tensor(3.6303, dtype=torch.float64)
tensor(3.2790, dtype=torch.float64)
tensor(2.9626, dtype=torch.float64)
tensor(2.6776, dtype=torch.float64)
tensor(2.4208, dtype=torch.float64)
tensor(2.1893, dtype=torch.float64)
tensor(1.9806, dtype=torch.float64)
tensor(1.7924, dtype=torch.float64)
tensor(1.6225, dtype=torch.float64)
tensor(1.4692, dtype=torch.float64)
tensor(1.3308, dtype=torch.float64)
tensor(1.2059, dtype=torch.float64)
tensor(1.0930, dtype=torch.float64)
tensor(0.9909, dtype=torch.float64)
tensor(0.8987, dtype=torch.float64)
tensor(0.8153, dtype=torch.float64)
tensor(0.7399, dtype=torch.float64)
tensor(0.6717, dtype=torch.float64)
tensor(0.6099, dtype=torch.float64)
tensor(0.5540, dtype=torch.float64)
tensor(0.5034, dtype=torch.float64)
tensor(0.4575, dtype=torch.float64)
tensor(0.4159, dtype=torch.float64)
tensor(0.3783, dtype=torch.float64)
tensor(0.3441, dtype=torch.float64)
tensor(0.3131, dtype=torch.float64)
tensor(0.2850, dtype=torch.float64)
tensor(0.2595, dtype=torch.float64)
tensor(0.2364, dtype=torch.float64)
tensor(0.2154, dtype=torch.float64)
tensor(0.1963, dtype=torch.float64)
tensor(0.1789, dtype=torch.float64)
tensor(0.1632, dtype=torch.float64)
tensor(0.1489, dtype=torch.float64)
tensor(0.1358, dtype=torch.float64)
tensor(0.1240, dtype=torch.float64)
![../_images/examples_model-45_16_1.png](../_images/examples_model-45_16_1.png)
[17]:
# Apply corrections: subtract the fitted knobs from a copy of the error lattice
lattice:Line = error.clone()
quadrupoles = (element for line in lattice.sequence for element in line if element.__class__.__name__ == 'Quadrupole')
index = 0
label = ''
for element in quadrupoles:
    # A new name starts the next knob; consecutive elements sharing a name share one knob
    if element.name != label:
        index, label = index + 1, element.name
    element.kn = (element.kn - kn[index - 1]).item()
[18]:
# Compute twiss at BPMs for the model, error and corrected lattices
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
ax_model, bx_model, ay_model, by_model = twiss(ring, [], **options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **options).T
ax_final, bx_final, ay_final, by_final = twiss(lattice, [], **options).T
# Plot relative beta beating in percent (red: x, blue: y; circles: before correction, crosses: after correction)
locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']
plt.figure(figsize=(16, 2))
for marker, bx, by in (('o', bx_error, by_error), ('x', bx_final, by_final)):
    plt.plot(locations, 100*((bx_model - bx)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(locations, 100*((by_model - by)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
![../_images/examples_model-45_18_0.png](../_images/examples_model-45_18_0.png)
[19]:
# ML style correction (model to experiment)
# Target: twiss parameters of the error lattice
twiss_error = twiss(error, [], alignment=False, matched=True, advance=True,
                    full=False, convert=True)
# Learning rate used by the Adam optimizer below
lr = 0.005
# Set parametric twiss
def twiss_model(kn):
    """Model-lattice twiss as a (differentiable) function of quadrupole kn deviations."""
    knobs = ('kn', ['Quadrupole'], None, None)
    return twiss(ring, [kn], knobs,
                 alignment=False, matched=True, advance=True, full=False, convert=True)
# Set objective function
def objective(kn):
    """Norm of the mismatch between target (error lattice) twiss and model twiss at kn."""
    mismatch = twiss_error - twiss_model(kn)
    return mismatch.norm()
# Set initial values and knob bounds used for normalization
kn = torch.zeros_like(error_kn)
bounds = [(-0.5, 0.5)]
# Test objective function (expected to vanish at kn = error_kn)
for scale in (0.0, 1.0):
    print(objective(scale*error_kn))
print()
# Wrap the objective so that it takes normalized knobs
objective = normalize(objective, bounds)
# Test normalized objective
for scale in (0.0, 1.0):
    print(objective(*forward([scale*error_kn], bounds)))
print()
# Normalize initial settings
kn, *_ = forward([kn], bounds)
# Set model (calling it evaluates the objective at the wrapped knobs)
model = Wrapper(objective, kn)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Perform optimization, printing the objective evaluated before each step
for _ in range(64):
    loss = model()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.detach())
tensor(1.7376, dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(1.7376, dtype=torch.float64)
tensor(4.8796e-13, dtype=torch.float64)
tensor(1.7376, dtype=torch.float64)
tensor(1.5623, dtype=torch.float64)
tensor(1.4006, dtype=torch.float64)
tensor(1.2536, dtype=torch.float64)
tensor(1.1222, dtype=torch.float64)
tensor(1.0066, dtype=torch.float64)
tensor(0.9054, dtype=torch.float64)
tensor(0.8187, dtype=torch.float64)
tensor(0.7487, dtype=torch.float64)
tensor(0.6968, dtype=torch.float64)
tensor(0.6613, dtype=torch.float64)
tensor(0.6364, dtype=torch.float64)
tensor(0.6152, dtype=torch.float64)
tensor(0.5918, dtype=torch.float64)
tensor(0.5634, dtype=torch.float64)
tensor(0.5298, dtype=torch.float64)
tensor(0.4917, dtype=torch.float64)
tensor(0.4498, dtype=torch.float64)
tensor(0.4053, dtype=torch.float64)
tensor(0.3615, dtype=torch.float64)
tensor(0.3244, dtype=torch.float64)
tensor(0.2996, dtype=torch.float64)
tensor(0.2876, dtype=torch.float64)
tensor(0.2831, dtype=torch.float64)
tensor(0.2812, dtype=torch.float64)
tensor(0.2804, dtype=torch.float64)
tensor(0.2789, dtype=torch.float64)
tensor(0.2738, dtype=torch.float64)
tensor(0.2647, dtype=torch.float64)
tensor(0.2542, dtype=torch.float64)
tensor(0.2450, dtype=torch.float64)
tensor(0.2369, dtype=torch.float64)
tensor(0.2295, dtype=torch.float64)
tensor(0.2242, dtype=torch.float64)
tensor(0.2220, dtype=torch.float64)
tensor(0.2203, dtype=torch.float64)
tensor(0.2160, dtype=torch.float64)
tensor(0.2092, dtype=torch.float64)
tensor(0.2011, dtype=torch.float64)
tensor(0.1925, dtype=torch.float64)
tensor(0.1854, dtype=torch.float64)
tensor(0.1817, dtype=torch.float64)
tensor(0.1799, dtype=torch.float64)
tensor(0.1769, dtype=torch.float64)
tensor(0.1718, dtype=torch.float64)
tensor(0.1649, dtype=torch.float64)
tensor(0.1571, dtype=torch.float64)
tensor(0.1504, dtype=torch.float64)
tensor(0.1459, dtype=torch.float64)
tensor(0.1423, dtype=torch.float64)
tensor(0.1390, dtype=torch.float64)
tensor(0.1354, dtype=torch.float64)
tensor(0.1310, dtype=torch.float64)
tensor(0.1267, dtype=torch.float64)
tensor(0.1231, dtype=torch.float64)
tensor(0.1199, dtype=torch.float64)
tensor(0.1169, dtype=torch.float64)
tensor(0.1127, dtype=torch.float64)
tensor(0.1073, dtype=torch.float64)
tensor(0.1017, dtype=torch.float64)
tensor(0.0971, dtype=torch.float64)
tensor(0.0936, dtype=torch.float64)
tensor(0.0902, dtype=torch.float64)
tensor(0.0865, dtype=torch.float64)
[20]:
# Apply corrections: map the optimized knobs back from normalized units, then subtract them from a copy of the error lattice
kn, *_ = inverse([kn], [(-0.5, 0.5)])
lattice:Line = error.clone()
quadrupoles = (element for line in lattice.sequence for element in line if element.__class__.__name__ == 'Quadrupole')
index = 0
label = ''
for element in quadrupoles:
    # A new name starts the next knob; consecutive elements sharing a name share one knob
    if element.name != label:
        index, label = index + 1, element.name
    element.kn = (element.kn - kn[index - 1]).item()
[21]:
# Compute twiss at BPMs for the model, error and corrected lattices
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
ax_model, bx_model, ay_model, by_model = twiss(ring, [], **options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **options).T
ax_final, bx_final, ay_final, by_final = twiss(lattice, [], **options).T
# Plot relative beta beating in percent (red: x, blue: y; circles: before correction, crosses: after correction)
locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']
plt.figure(figsize=(16, 2))
for marker, bx, by in (('o', bx_error, by_error), ('x', bx_final, by_final)):
    plt.plot(locations, 100*((bx_model - bx)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(locations, 100*((by_model - by)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
![../_images/examples_model-45_21_0.png](../_images/examples_model-45_21_0.png)
[22]:
# AdEMAMix optimizer
# https://arxiv.org/abs/2409.03137
# https://github.com/apple/ml-ademamix
import math
import torch
from torch.optim import Optimizer
def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
    """Linearly interpolate from alpha_start to alpha_end during the first ``warmup`` steps.

    Once ``step`` is no longer below ``warmup``, alpha_end is returned unchanged.
    """
    if not step < warmup:
        return alpha_end
    fraction = step / float(warmup)
    return (1.0-fraction) * alpha_start + fraction * alpha_end
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
    """Warm up an EMA coefficient from beta_start to beta_end during the first ``warmup`` steps.

    The interpolation is linear in t(beta) = log(0.5)/log(beta) - 1 (the t for which
    beta**(t + 1) == 0.5) rather than in beta itself; a small eps guards log(0).
    Once ``step`` is no longer below ``warmup``, beta_end is returned unchanged.
    """
    if not step < warmup:
        return beta_end

    def to_t(beta, eps=1e-8):
        return math.log(0.5)/math.log(beta+eps)-1

    def from_t(t):
        return math.pow(0.5, 1/(t+1))

    fraction = step / float(warmup)
    return from_t((1.0-fraction) * to_t(beta_start) + fraction * to_t(beta_end))
class AdEMAMix(Optimizer):
    """Implements the AdEMAMix algorithm (https://arxiv.org/abs/2409.03137).

    Each step mixes a bias-corrected fast EMA of the gradient (beta1) with a slow EMA
    of the gradient (beta3, weighted by alpha, not bias corrected), and divides by the
    bias-corrected square root of the EMA of squared gradients (beta2).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients beta_1, beta_2, beta_3
            used for the fast gradient EMA, the squared-gradient EMA and the slow
            gradient EMA respectively (default: (0.9, 0.999, 0.9999))
        alpha (float, optional): coefficient mixing the slow EMA into the update (default: 2.0)
        beta3_warmup (int, optional): number of warmup steps used to increase beta3,
            starting from beta1 (default: None, no warmup)
        alpha_warmup (int, optional): number of warmup steps used to increase alpha,
            starting from 0 (default: None, no warmup)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay as in AdamW (default: 0)

    Raises:
        ValueError: if lr, eps, alpha or weight_decay is negative, if betas does not
            contain exactly three values, or if any beta is outside [0, 1).
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.999, 0.9999),
                 alpha=2.0,
                 beta3_warmup=None,
                 alpha_warmup=None,
                 eps=1e-8,
                 weight_decay=0):
        # Validate hyperparameters up front so a misconfiguration fails here rather than
        # inside step() (e.g. beta1 == 1 or beta2 == 1 makes a bias correction zero)
        if not 0.0 <= lr:
            raise ValueError(f'Invalid learning rate: {lr}')
        if not 0.0 <= eps:
            raise ValueError(f'Invalid epsilon value: {eps}')
        if len(betas) != 3:
            raise ValueError(f'Expected three betas (beta1, beta2, beta3), got: {betas}')
        for position, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(f'Invalid beta parameter at index {position}: {beta}')
        if not 0.0 <= alpha:
            raise ValueError(f'Invalid alpha value: {alpha}')
        if not 0.0 <= weight_decay:
            raise ValueError(f'Invalid weight_decay value: {weight_decay}')
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        alpha=alpha,
                        beta3_warmup=beta3_warmup,
                        alpha_warmup=alpha_warmup,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            lmbda = group["weight_decay"]
            eps = group["eps"]
            beta1, beta2, beta3_final = group["betas"]
            beta3_warmup = group["beta3_warmup"]
            alpha_final = group["alpha"]
            alpha_warmup = group["alpha_warmup"]
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdEMAMix does not support sparse gradients.')
                state = self.state[p]
                # Lazy state initialization on the first step for this parameter
                if len(state) == 0:
                    state['step'] = 0
                    # The fast EMA buffer is skipped when beta1 == 0 (the raw gradient is used instead)
                    if beta1 != 0.0:
                        state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    else:
                        state['exp_avg_fast'] = None
                    state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']
                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                # Effective alpha and beta3 (optionally warmed up)
                if alpha_warmup is not None:
                    alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)
                else:
                    alpha = alpha_final
                if beta3_warmup is not None:
                    beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)
                else:
                    beta3 = beta3_final
                # Update the fast, slow and squared-gradient EMAs
                if beta1 != 0.0:
                    exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)
                else:
                    exp_avg_fast = grad
                exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                # Only the fast EMA is bias corrected; the slow EMA enters scaled by alpha
                update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
                # Decoupled weight decay: added to the update, not to the gradient
                update.add_(p, alpha=lmbda)
                p.add_(-lr * update)
        return loss
[23]:
# ML style correction (batched)
# Target: twiss parameters of the error lattice
twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)
# Learning rate used by the AdEMAMix optimizer below
lr = 0.005
# One rolled copy of the ring per line, so that twiss computed at each copy's start
# corresponds to the twiss at that line
def rolled(shift):
    shifted = ring.clone()
    shifted.roll(shift)
    return shifted
rings:list[Line] = [rolled(shift) for shift, _ in enumerate(ring)]
# Quadrupole names for the ('kn', ['Quadrupole'], None, None) selection; task selects knobs by
# these names (presumably so every rolled copy uses the same knob order)
_, ((_, names, _), *_), _ = group(ring, 0, len(ring) - 1, ('kn', ['Quadrupole'], None, None))
def task(Is, kn):
    """Twiss at the start of each rolled ring selected by the indices in ``Is``.

    Quadrupole knobs are addressed by name (``names``) in every rolled copy.
    The per-index results are stacked along a new leading dimension.
    """
    return torch.stack([
        twiss(rings[i], [kn], ('kn', None, names, None), alignment=False, matched=True, convert=True)
        for i in Is
    ])
# Set initial values
kn = torch.zeros_like(error_kn)
# Normalize task ((None, None) presumably leaves the index argument unscaled; knobs use (-0.5, 0.5))
task = normalize(task, [(None, None), (-0.5, 0.5)])
# Normalize initial settings
kn, *_ = forward([kn], [(-0.5, 0.5)])
# Model and optimizer
model = Wrapper(task, kn)
optimizer = AdEMAMix(model.parameters(), lr=lr)
# Features: line indices; labels: target twiss at those lines
X = torch.arange(len(ring))
y = twiss_error.clone()
# Note: the full set is used as one batch here (smaller batches are too small)
batch_size = 16
dataset = TensorDataset(X.clone(), y.clone())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Set loss function
lf = torch.nn.MSELoss()
# Perform optimization, printing the last batch loss of each epoch
for epoch in range(64):
    for batch, (X, y) in enumerate(dataloader):
        y_hat = model(X)
        value = lf(y_hat, y)
        value.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(value.detach())
tensor(0.0472, dtype=torch.float64)
tensor(0.0381, dtype=torch.float64)
tensor(0.0307, dtype=torch.float64)
tensor(0.0247, dtype=torch.float64)
tensor(0.0199, dtype=torch.float64)
tensor(0.0161, dtype=torch.float64)
tensor(0.0133, dtype=torch.float64)
tensor(0.0111, dtype=torch.float64)
tensor(0.0094, dtype=torch.float64)
tensor(0.0082, dtype=torch.float64)
tensor(0.0075, dtype=torch.float64)
tensor(0.0069, dtype=torch.float64)
tensor(0.0066, dtype=torch.float64)
tensor(0.0063, dtype=torch.float64)
tensor(0.0060, dtype=torch.float64)
tensor(0.0057, dtype=torch.float64)
tensor(0.0053, dtype=torch.float64)
tensor(0.0049, dtype=torch.float64)
tensor(0.0045, dtype=torch.float64)
tensor(0.0040, dtype=torch.float64)
tensor(0.0035, dtype=torch.float64)
tensor(0.0031, dtype=torch.float64)
tensor(0.0027, dtype=torch.float64)
tensor(0.0023, dtype=torch.float64)
tensor(0.0020, dtype=torch.float64)
tensor(0.0018, dtype=torch.float64)
tensor(0.0016, dtype=torch.float64)
tensor(0.0015, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0014, dtype=torch.float64)
tensor(0.0013, dtype=torch.float64)
tensor(0.0013, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0011, dtype=torch.float64)
tensor(0.0010, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
[24]:
# Apply corrections: map the optimized knobs back from normalized units, then subtract them from a copy of the error lattice
kn, *_ = inverse([kn], [(-0.5, 0.5)])
lattice:Line = error.clone()
quadrupoles = (element for line in lattice.sequence for element in line if element.__class__.__name__ == 'Quadrupole')
index = 0
label = ''
for element in quadrupoles:
    # A new name starts the next knob; consecutive elements sharing a name share one knob
    if element.name != label:
        index, label = index + 1, element.name
    element.kn = (element.kn - kn[index - 1]).item()
[25]:
# Compute twiss at BPMs for the model, error and corrected lattices
options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)
ax_model, bx_model, ay_model, by_model = twiss(ring, [], **options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **options).T
ax_final, bx_final, ay_final, by_final = twiss(lattice, [], **options).T
# Plot relative beta beating in percent (red: x, blue: y; circles: before correction, crosses: after correction)
locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']
plt.figure(figsize=(16, 2))
for marker, bx, by in (('o', bx_error, by_error), ('x', bx_final, by_final)):
    plt.plot(locations, 100*((bx_model - bx)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(locations, 100*((by_model - by)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
![../_images/examples_model-45_25_0.png](../_images/examples_model-45_25_0.png)