ELETTRA-07: ID tune shift fit & correction (global tune knob)
[1]:
# In this example the effects of an ID (an APPLE-II device represented by a linear 4x4 symplectic matrix) are studied
# The tune shift introduced by the ID is first used to fit a matrix model of the ID
# Next, global tune correction is performed using the fitted model, and the correction result is evaluated
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True
from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling
from model.command.wrapper import Wrapper
[3]:
# Set data type and device
Element.dtype = dtype = torch.float64
Element.device = device = torch.device('cpu')
[4]:
# Load lattice (ELEGANT table)
# Note: the lattice is allowed to contain repeated elements
path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice
ring:Line = build('RING', 'ELEGANT', data)
# Flatten sublines
ring.flatten()
# Remove all marker elements except those starting with MLL_ (long straight section centers)
ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])
# Replace all sextupoles with quadrupoles
def factory(element:Element) -> Quadrupole:
    table = element.serialize
    table.pop('ms', None)
    return Quadrupole(**table)
ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])
# Set linear dipoles
def apply(element:Element) -> None:
    element.linear = True
ring.apply(apply, kinds=['Dipole'])
# Merge drifts
ring.merge()
# Change lattice start
ring.start = "BPM_S01_01"
# Split BPMs
ring.split((None, ['BPM'], None, None))
# Roll lattice
ring.roll(1)
# Splice lattice
ring.splice()
# Describe
ring.describe
[5]:
{'BPM': 168, 'Drift': 708, 'Dipole': 156, 'Quadrupole': 360, 'Marker': 12}
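[*]:
# Aside: individual elements can be accessed by name, as is done for the correction below
# (illustrative; assumes name lookup works on the spliced line as it does on the flattened one)
print(ring['QF_S01_02'].kn)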
[6]:
# Compute tunes (fractional part)
nux, nuy = tune(ring, [], matched=True, limit=1)
[7]:
# Compute dispersion
orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx, etapx, etaqy, etapy = dispersion(ring, orbit, [], limit=1)
[8]:
# Compute twiss parameters
ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T
[9]:
# Compute phase advances
mux, muy = advance(ring, [], alignment=False, matched=True).T
[10]:
# Compute coupling
c = coupling(ring, [])
[11]:
# Define ID model
# Note: only the flattened upper-triangular part of the symmetric A and B matrices is passed
A = torch.tensor([[-0.03484222052711237, 1.0272120741819959E-7, -4.698931299341201E-9, 0.0015923185492594811],
[1.0272120579834892E-7, -0.046082787920135176, 0.0017792061173117564, 3.3551298301095784E-8],
[-4.6989312853101E-9, 0.0017792061173117072, 0.056853750760983084, -1.5929605363332683E-7],
[0.0015923185492594336, 3.3551298348653296E-8, -1.5929605261642905E-7, 0.08311631737263032]], dtype=dtype)
B = torch.tensor([[0.03649353186115209, 0.0015448347221877217, 0.00002719892025520868, -0.0033681183134964482],
[0.0015448347221877217, 0.13683886657005795, -0.0033198692682377406, 0.00006140578258682469],
[0.00002719892025520868, -0.0033198692682377406, -0.05260095308967722, 0.005019907688182885],
[-0.0033681183134964482, 0.00006140578258682469, 0.005019907688182885, -0.2531573249456863]], dtype=dtype)
ID = Matrix('ID',
            length=0.0,
            A=A[torch.triu(torch.ones_like(A, dtype=torch.bool))].tolist(),
            B=B[torch.triu(torch.ones_like(B, dtype=torch.bool))].tolist())
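[*]:
# Aside: the ten values are the row-major upper triangle (a11, a12, a13, a14, a22, ..., a44);
# the full symmetric matrix can be rebuilt from them with plain torch (a minimal sketch)
mask = torch.triu(torch.ones((4, 4), dtype=torch.bool))
packed = A[mask]
M = torch.zeros((4, 4), dtype=dtype)
M[mask] = packed
M = M + M.T - torch.diag(M.diagonal())
print(torch.allclose(M, A))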
[12]:
# Insert ID into the existing lattice
# This will replace the target marker
error = ring.clone()
error.flatten()
error.insert(ID, 'MLL_S01', position=0.0)
error.splice()
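[*]:
# Sanity check: after insertion the MLL_S01 marker is replaced, so the element counts
# should show one Matrix element and one fewer Marker
print(error.describe)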
[13]:
# Measure tunes with ID
nux_id, nuy_id = target = tune(error, [], matched=True)
print((nux - nux_id).abs())
print((nuy - nuy_id).abs())
tensor(0.0260, dtype=torch.float64)
tensor(0.0114, dtype=torch.float64)
[14]:
# Compute dispersion
orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx_id, etapx_id, etaqy_id, etapy_id = dispersion(error, orbit, [], limit=1)
[15]:
# Compute twiss parameters
ax_id, bx_id, ay_id, by_id = twiss(error, [], matched=True, advance=True, full=False).T
[16]:
# Compute phase advances
mux_id, muy_id = advance(error, [], alignment=False, matched=True).T
[17]:
# Compute coupling
c_id = coupling(error, [])
[18]:
# Create a ring with a placeholder ID matrix (TM) whose elements are to be fitted to the measured observables (tunes)
TM = Matrix('TM')
model = ring.clone()
model.flatten()
model.insert(TM, 'MLL_S01', position=0.0)
model.splice()
[19]:
# Define parametric observable: the ten upper-triangular elements of TM act as knobs
def observable(knobs):
    a11, a12, a13, a14, a22, a23, a24, a33, a34, a44 = knobs.reshape(-1, 1)
    return tune(model,
                [a11, a12, a13, a14, a22, a23, a24, a33, a34, a44],
                ('a11', None, ['TM'], None),
                ('a12', None, ['TM'], None),
                ('a13', None, ['TM'], None),
                ('a14', None, ['TM'], None),
                ('a22', None, ['TM'], None),
                ('a23', None, ['TM'], None),
                ('a24', None, ['TM'], None),
                ('a33', None, ['TM'], None),
                ('a34', None, ['TM'], None),
                ('a44', None, ['TM'], None),
                matched=True)
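[*]:
# Aside: tune is differentiable with respect to the knobs, so the 2x10 sensitivity of the
# tunes to the matrix elements can be computed directly (a minimal sketch using torch.autograd)
J = torch.autograd.functional.jacobian(observable, torch.zeros(10, dtype=dtype))
print(J.shape)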
[20]:
# Test observable with known ID model
print(target)
print(observable(ID.A))
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
[21]:
# Define the objective function to be minimized
def objective(knobs):
    return (target - observable(knobs)).norm()
print(objective(0.0*ID.A))
print(objective(1.0*ID.A))
tensor(0.0284, dtype=torch.float64)
tensor(0., dtype=torch.float64)
[22]:
# Fit (ADAM)
knobs = torch.tensor(10*[0.0], dtype=dtype)
wrapper = Wrapper(objective, knobs)
optimizer = torch.optim.AdamW(wrapper.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.95)
for epoch in range(256):
    value = wrapper()
    value.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    print(value.detach().item())
0.028369700506047313
0.027604979574733145
0.02684207653388388
0.026081063931985574
0.02532202986496677
...
1.1534707036896961e-05
8.218884916578543e-06
5.192768264520092e-05
4.8923645138056975e-05
1.085328843377851e-05
7.568294157060619e-06
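[*]:
# Aside: a variant of the loop above that records the loss history for plotting
# (a minimal sketch; the scheduler is omitted for brevity)
history = []
wrapper = Wrapper(objective, torch.tensor(10*[0.0], dtype=dtype))
optimizer = torch.optim.AdamW(wrapper.parameters(), lr=0.001)
for epoch in range(256):
    value = wrapper()
    value.backward()
    optimizer.step()
    optimizer.zero_grad()
    history.append(value.detach().item())
plt.figure(figsize=(12, 4))
plt.semilogy(history)
plt.tight_layout()
plt.show()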
[23]:
# Fit (LBFGS)
knobs = torch.tensor(10*[0.0], dtype=dtype, requires_grad=True)
optimizer = torch.optim.LBFGS([knobs], lr=0.1, line_search_fn="strong_wolfe")
def closure():
    optimizer.zero_grad()
    value = objective(knobs)
    value.backward()
    return value
for epoch in range(8):
    value = optimizer.step(closure)
    print(value.item())
0.028369700506047313
4.066909132862517e-07
3.358642506148264e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
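[*]:
# Aside: the residual is nearly linear in the knobs, so a single minimum-norm (pseudoinverse)
# step should already land close to the target (a minimal sketch; further steps refine it)
start = torch.zeros(10, dtype=dtype)
J = torch.autograd.functional.jacobian(observable, start)
step = torch.linalg.pinv(J) @ (target - observable(start))
print(objective(step))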
[24]:
# Check fitted tunes
print(target)
print(observable(ID.A))
print(observable(knobs.detach()))
print()
print((target - observable(ID.A)).abs())
print((target - observable(knobs.detach())).abs())
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0., 0.], dtype=torch.float64)
tensor([1.8011e-10, 8.1134e-11], dtype=torch.float64)
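[*]:
# Note: with ten free elements and only two measured tunes the fit is underdetermined,
# so the fitted elements need not coincide elementwise with the true ID matrix even
# though the tunes are reproduced exactly
print((knobs.detach() - ID.A).abs().max())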
[25]:
# Set the fitted elements on TM to define the fitted ID model
TM.A = knobs.detach()
print(target)
print(tune(model, [], matched=True, limit=1))
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
[26]:
# Define target tunes (the bare lattice tunes, without ID)
target = tune(ring, [], matched=True, limit=1)
print(target)
tensor([0.2994, 0.1608], dtype=torch.float64)
[27]:
# Define parametric observable: 48 quadrupole strength deviations (24 QF + 24 QD) serve as correction knobs
QF = [f'QF_S{i:02}_{j:02}' for j in [2, 3] for i in range(1, 12 + 1)]
QD = [f'QD_S{i:02}_{j:02}' for j in [2, 3] for i in range(1, 12 + 1)]
def observable(kn):
    return tune(model, [kn], ('kn', None, QF + QD, None), matched=True, limit=1)
knobs = torch.zeros(len(QF + QD), dtype=dtype)
print(observable(knobs))
tensor([0.2735, 0.1723], dtype=torch.float64)
[28]:
# Define objective
def objective(knobs):
    return (target - observable(knobs)).norm()
print(objective(knobs))
tensor(0.0284, dtype=torch.float64)
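[*]:
# Aside: the correction can also be estimated in one linearized step from the 2x48 tune
# response matrix; pinv yields the minimum-norm knob vector (a minimal sketch)
zero = torch.zeros(len(QF + QD), dtype=dtype)
R = torch.autograd.functional.jacobian(observable, zero)
dk = torch.linalg.pinv(R) @ (target - observable(zero))
print(objective(dk))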
[29]:
# Fit
knobs = torch.tensor(len(QF + QD)*[0.0], dtype=dtype, requires_grad=True)
optimizer = torch.optim.LBFGS([knobs], lr=0.1, line_search_fn="strong_wolfe")
def closure():
    optimizer.zero_grad()
    value = objective(knobs)
    value.backward()
    return value
for epoch in range(8):
    value = optimizer.step(closure)
    print(value.item())
0.028369700703583963
1.250945355652117e-07
2.8035556938341184e-09
2.1928367264705342e-09
1.6978970603827018e-09
3.292485872195638e-10
2.0558624258379998e-10
2.0558624258379998e-10
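[*]:
# Quick check of the fitted knob magnitudes (the strength deviations are expected to be small)
print(knobs.detach().abs().max())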
[30]:
# Apply the fitted corrections to the lattice containing the exact ID
result = error.clone()
result.flatten()
for name, knob in zip(QF + QD, knobs.detach()):
    result[name].kn = (result[name].kn + knob).item()
result.splice()
[31]:
# Compute tunes (fractional part)
nux_result, nuy_result = tune(result, [], matched=True, limit=1)
[32]:
# Compute dispersion
orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx_result, etapx_result, etaqy_result, etapy_result = dispersion(result, orbit, [], limit=1)
[33]:
# Compute twiss parameters
ax_result, bx_result, ay_result, by_result = twiss(result, [], matched=True, advance=True, full=False).T
[34]:
# Compute phase advances
mux_result, muy_result = advance(result, [], alignment=False, matched=True).T
[35]:
# Compute coupling
c_result = coupling(result, [])
[36]:
# Tune shifts: before correction (with ID) and after correction
print((nux - nux_id).abs())
print((nuy - nuy_id).abs())
print()
print((nux - nux_result).abs())
print((nuy - nuy_result).abs())
print()
tensor(0.0260, dtype=torch.float64)
tensor(0.0114, dtype=torch.float64)

tensor(0.0001, dtype=torch.float64)
tensor(8.5259e-05, dtype=torch.float64)
[37]:
# Coupling (minimal tune distance)
print(c)
print(c_id)
print(c_result)
tensor(0., dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
[38]:
# Dispersion
plt.figure(figsize=(12, 4))
plt.errorbar(ring.locations().cpu().numpy(), (etaqx - etaqx_id).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), (etaqy - etaqy_id).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), (etaqx - etaqx_result).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), (etaqy - etaqy_result).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75)
plt.tight_layout()
plt.show()
[39]:
# Beta-beating (the global tune knob constrains only the tunes; note the horizontal beating is not reduced)
plt.figure(figsize=(12, 4))
plt.errorbar(ring.locations().cpu().numpy(), 100*((bx - bx_id)/bx).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((by - by_id)/by).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((bx - bx_result)/bx).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((by - by_result)/by).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75)
plt.tight_layout()
plt.show()
print(100*(((bx - bx_id)/bx)**2).mean().sqrt())
print(100*(((by - by_id)/by)**2).mean().sqrt())
print()
print(100*(((bx - bx_result)/bx)**2).mean().sqrt())
print(100*(((by - by_result)/by)**2).mean().sqrt())
print()
tensor(11.5994, dtype=torch.float64)
tensor(1.7916, dtype=torch.float64)

tensor(13.5893, dtype=torch.float64)
tensor(1.0242, dtype=torch.float64)
[40]:
# Phase advance (relative change; as with beta-beating, the horizontal plane is not improved by the tune correction)
plt.figure(figsize=(12, 4))
plt.errorbar(ring.locations().cpu().numpy(), 100*((mux - mux_id)/mux).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((muy - muy_id)/muy).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((mux - mux_result)/mux).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), 100*((muy - muy_result)/muy).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75)
plt.tight_layout()
plt.show()
print(100*(((mux - mux_id)/mux)**2).mean().sqrt())
print(100*(((muy - muy_id)/muy)**2).mean().sqrt())
print()
print(100*(((mux - mux_result)/mux)**2).mean().sqrt())
print(100*(((muy - muy_result)/muy)**2).mean().sqrt())
print()
tensor(8.7941, dtype=torch.float64)
tensor(1.7778, dtype=torch.float64)

tensor(10.1208, dtype=torch.float64)
tensor(1.0428, dtype=torch.float64)
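[*]:
# Helper capturing the repeated RMS-percent expressions above (illustrative)
def rms_percent(reference:Tensor, perturbed:Tensor) -> Tensor:
    return 100*(((reference - perturbed)/reference)**2).mean().sqrt()
print(rms_percent(mux, mux_result))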