ELETTRA-07: ID tune shift fit & correction (global tune knob)

[1]:
# In this example, the effects of an insertion device (ID) are presented; the ID (an APPLE-II device) is represented by a linear 4x4 symplectic matrix

# First, the tune shift introduced by the ID is used to fit an ID model
# Next, global tune correction is performed using the fitted model, and the correction result is evaluated on the lattice with the exact ID
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
# Render all figure text through LaTeX (this requires a working LaTeX installation)
matplotlib.rcParams['text.usetex'] = True

from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix

from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling

from model.command.wrapper import Wrapper
[3]:
# Numerical setup: double precision on the CPU
# The same values are stored as class-level defaults on Element and bound to
# module-level names; `dtype` is reused below whenever tensors are created

dtype = torch.float64
device = torch.device('cpu')
Element.dtype, Element.device = dtype, device
[4]:
# Load lattice (ELEGANT table)
# Note, lattice is allowed to have repeated elements
# The path is relative, so 'elettra.lte' is looked up in the current working directory

path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice
# The setup calls below appear to modify `ring` in place (their return values are
# not used), so the order of the steps matters

ring:Line = build('RING', 'ELEGANT', data)

# Flatten sublines

ring.flatten()

# Remove all marker elements but the ones starting with MLL (long straight section centers)
# The negative lookahead ^(?!MLL_) matches every marker name that does NOT start with 'MLL_'

ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles

def factory(element:Element) -> Quadrupole:
    """
    Build the quadrupole that replaces a sextupole.

    The sextupole's serialized parameters are reused as ``Quadrupole``
    keyword arguments, except ``ms`` (dropped; presumably the sextupole
    strength, which ``Quadrupole`` does not take — confirm).

    Parameters
    ----------
    element: Element
        sextupole element to be replaced

    Returns
    -------
    Quadrupole

    """
    # Copy first so that `pop` never mutates the mapping returned by `serialize`
    table = dict(element.serialize)
    table.pop('ms', None)
    return Quadrupole(**table)

ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles (every Dipole gets `linear = True`)

def apply(element:Element) -> None:
    """Mark *element* as linear; passed to ``ring.apply`` below with ``kinds=['Dipole']``."""
    setattr(element, 'linear', True)

ring.apply(apply, kinds=['Dipole'])

# Merge drifts

ring.merge()

# Change lattice start

ring.start = "BPM_S01_01"

# Split BPMs

ring.split((None, ['BPM'], None, None))

# Roll lattice

ring.roll(1)

# Splice lattice

ring.splice()

# Describe
# Bare expression: as the last line of the cell, its value (element count per kind) is displayed

ring.describe
[5]:
{'BPM': 168, 'Drift': 708, 'Dipole': 156, 'Quadrupole': 360, 'Marker': 12}
[6]:
# Compute tunes (fractional part)
# Nominal tunes of the bare lattice (no ID); used as the reference for all tune shifts below

nux, nuy = tune(ring, [], matched=True, limit=1)
[7]:
# Compute dispersion
# Nominal dispersion of the bare lattice, computed around a 4-component zero orbit

orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx, etapx, etaqy, etapy = dispersion(ring, orbit, [], limit=1)
[8]:
# Compute twiss parameters
# Nominal alpha/beta in both planes; the result is transposed so each row unpacks into one quantity

ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T
[9]:
# Compute phase advances
# Nominal horizontal/vertical phase advances (transposed so each row unpacks into one plane)

mux, muy = advance(ring, [], alignment=False, matched=True).T
[10]:
# Compute coupling
# Nominal coupling (minimal tune distance) of the bare lattice

c = coupling(ring, [])
[11]:
# Define ID model
# The ID is a zero-length Matrix element. Only the upper-triangular part (diagonal included)
# of the A and B matrices is passed, flattened row by row:
# a11, a12, a13, a14, a22, a23, a24, a33, a34, a44

A = torch.tensor([[-0.03484222052711237, 1.0272120741819959E-7, -4.698931299341201E-9, 0.0015923185492594811],
                  [1.0272120579834892E-7, -0.046082787920135176, 0.0017792061173117564, 3.3551298301095784E-8],
                  [-4.6989312853101E-9, 0.0017792061173117072, 0.056853750760983084, -1.5929605363332683E-7],
                  [0.0015923185492594336, 3.3551298348653296E-8, -1.5929605261642905E-7, 0.08311631737263032]], dtype=dtype)

B = torch.tensor([[0.03649353186115209, 0.0015448347221877217, 0.00002719892025520868, -0.0033681183134964482],
                  [0.0015448347221877217, 0.13683886657005795, -0.0033198692682377406, 0.00006140578258682469],
                  [0.00002719892025520868, -0.0033198692682377406, -0.05260095308967722, 0.005019907688182885],
                  [-0.0033681183134964482, 0.00006140578258682469, 0.005019907688182885, -0.2531573249456863]], dtype=dtype)

# Row/column indices of the upper triangle, ordered by row then by column
rows, cols = torch.triu_indices(*A.shape)

ID = Matrix('ID',
            length=0.0,
            A=A[rows, cols].tolist(),
            B=B[rows, cols].tolist())
[12]:
# Insert ID into the existing lattice
# This will replace the target marker
# A clone is used, so `ring` itself stays the bare (nominal) lattice

error = ring.clone()
error.flatten()
error.insert(ID, 'MLL_S01', position=0.0)
error.splice()
[13]:
# Measure tunes with ID
# `target` holds the same tunes as (nux_id, nuy_id); it is the observable the ID model is fitted to
# NOTE(review): unlike the nominal tune computation, `limit=1` is not passed here — confirm this is intended

nux_id, nuy_id = target = tune(error, [], matched=True)

print((nux - nux_id).abs())
print((nuy - nuy_id).abs())
tensor(0.0260, dtype=torch.float64)
tensor(0.0114, dtype=torch.float64)
[14]:
# Compute dispersion
# Dispersion of the lattice with ID, around the same 4-component zero orbit as the nominal case

orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx_id, etapx_id, etaqy_id, etapy_id = dispersion(error, orbit, [], limit=1)
[15]:
# Compute twiss parameters
# Twiss parameters of the lattice with ID (same call as for the nominal lattice)

ax_id, bx_id, ay_id, by_id = twiss(error, [], matched=True, advance=True, full=False).T
[16]:
# Compute phase advances
# Phase advances of the lattice with ID (same call as for the nominal lattice)

mux_id, muy_id = advance(error, [], alignment=False, matched=True).T
[17]:
# Compute coupling
# Coupling (minimal tune distance) introduced by the ID

c_id = coupling(error, [])
[18]:
# Create a ring with ID to be fitted to measured observables (tunes)
# TM is a placeholder Matrix element (constructed with default parameters), inserted at the same
# location as the ID; the entries of its A matrix are the fit parameters

TM = Matrix('TM')

model = ring.clone()
model.flatten()
model.insert(TM, 'MLL_S01', position=0.0)
model.splice()
[19]:
# Define parametric observable
# Tunes of `model` as a function of the 10 upper-triangular entries of TM's A matrix
# (order: a11, a12, a13, a14, a22, a23, a24, a33, a34, a44)

def observable(knobs):
    """
    Return the matched tunes of `model` with TM's A entries set to *knobs*.

    *knobs* must hold exactly 10 values, otherwise the unpacking below fails.
    Each value is passed to `tune` as a length-1 slice and paired with the
    matching ('aij', None, ['TM'], None) parameter spec.

    """
    a11, a12, a13, a14, a22, a23, a24, a33, a34, a44 = knobs.reshape(-1, 1)
    values = [a11, a12, a13, a14, a22, a23, a24, a33, a34, a44]
    specs = [(key, None, ['TM'], None)
             for key in ('a11', 'a12', 'a13', 'a14', 'a22', 'a23', 'a24', 'a33', 'a34', 'a44')]
    return tune(model, values, *specs, matched=True)
[20]:
# Test observable with known ID model
# Plugging the exact ID entries into TM should reproduce the measured tunes

print(target)
print(observable(ID.A))
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
[21]:
# Define objective function to be fitted
# Euclidean distance between measured (`target`) and modelled tunes
# Sanity check: zero knobs give the full ID-induced mismatch, the exact ID entries give zero

def objective(knobs):
    """Return the 2-norm of the tune residual for the given TM knobs."""
    residual = target - observable(knobs)
    return residual.norm()

for scale in (0.0, 1.0):
    print(objective(scale*ID.A))
tensor(0.0284, dtype=torch.float64)
tensor(0., dtype=torch.float64)
[22]:
# Fit (AdamW)
# Note: torch.optim.AdamW is used (decoupled weight decay, default weight_decay=0.01), not plain Adam
# Knobs start from zero; the learning rate is multiplied by 0.95 every 32 epochs
# NOTE(review): Wrapper presumably exposes `knobs` as trainable parameters of `objective` — confirm

knobs = torch.tensor(10*[0.0], dtype=dtype)

wrapper = Wrapper(objective, knobs)
optimizer = torch.optim.AdamW(wrapper.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.95)

for epoch in range(256):
    # The printed value is the objective evaluated before this epoch's optimizer step
    value = wrapper()
    value.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    print(value.detach().item())
0.028369700506047313
0.027604979574733145
0.02684207653388388
0.026081063931985574
0.02532202986496677
0.02456503415221196
0.02381020965125639
0.023057643096075053
0.022307458020095296
0.021559802115037493
0.020814824133035887
0.020072691583410796
0.01933360633793344
0.01859779624814611
0.017865514537404256
0.01713705406561508
0.016412757428888126
0.015693019888152106
0.014978298873072853
0.014269131037937175
0.013566148256137492
0.012870095625237974
0.012181858106829188
0.011502492417944167
0.010833263166714534
0.010175691458854214
0.009531616132536076
0.0089032612390469
0.008293314881859004
0.00770502009416522
0.007142257433254706
0.006609596503016408
0.006112278035565941
0.005677920085235823
0.005285758836406501
0.004940049910706599
0.004643573081830862
0.0043965995525112
0.004196015429417217
0.004035000235607864
0.0039035410383780065
0.0037896923390195983
0.0036811553570592894
0.0035666798848923665
0.0034369804317109483
0.003285134365852536
0.003106597814067926
0.0028990209855239667
0.002662026802255918
0.0023970851335187657
0.0021076221174704873
0.0017995806340871346
0.001482886230260837
0.0011748123315649792
0.0009064586311714966
0.0007258771735991629
0.0006602409194314186
0.0006563256746629608
0.0006389218078180901
0.0005756323523207124
0.0004770652334132349
0.00040482048739814416
0.0004488134079377258
0.0005723447614733553
0.0006825347339017069
0.0007375622822843643
0.0007355602918931942
0.0006810064358945258
0.0005860382549705243
0.0004729294078276646
0.00037926505707123115
0.0003391704912980229
0.0003252472180930264
0.0002793347411800816
0.0001741432239347568
7.986430633214279e-05
0.00020694664461870272
0.00028467317875998787
0.00028737486365218256
0.0002349386205344058
0.00017869777807106698
0.00018210407210867278
0.00018394985032149243
0.00012200096293579978
2.504163199412578e-05
9.657132662324322e-05
9.342239305619409e-05
7.163867003840293e-05
8.063948051910279e-05
3.276989876051909e-05
9.287973607744902e-05
0.00011803087678276688
7.866800880937734e-05
9.610084779320682e-05
0.0001054906028501998
3.6598965960343354e-05
0.00010871857518170825
0.00014499645021069337
0.00010249816781572995
7.709395404407496e-05
0.00011473143281842893
8.447119574935473e-05
2.4138519635949214e-05
3.927730797423161e-05
4.5872961895850735e-05
3.615966267351201e-05
4.238990558167267e-05
2.5711734237141967e-05
6.722204484355513e-05
5.6237176581074446e-05
3.845990376887628e-05
3.5386012152311766e-05
5.463536908683197e-05
5.09545300326953e-05
3.7796217039974957e-05
3.0385406281355677e-05
6.080789589346814e-05
5.953057208598726e-05
2.9211774539882354e-05
2.6270378576523216e-05
5.8214918671829745e-05
5.293282262883786e-05
3.863131956949567e-05
3.629274046056883e-05
4.617446746188101e-05
4.03041094160775e-05
4.833493153181574e-05
4.2695343559576496e-05
4.297793319726718e-05
3.973455242313906e-05
4.172546219346975e-05
3.6131181626745114e-05
4.495459159393313e-05
4.171390245170822e-05
3.898845084458991e-05
3.40525616799877e-05
4.589294169984328e-05
4.22290410999437e-05
3.799126157978804e-05
3.346594407677967e-05
4.561314730373438e-05
4.180245361997472e-05
3.7657136599340985e-05
3.313284793812229e-05
4.5528090843893147e-05
4.1932270138886904e-05
3.6543947430367706e-05
3.18205652575772e-05
4.6589644378946656e-05
4.328422866191937e-05
3.426727169561934e-05
2.9413522880561434e-05
4.867666694102279e-05
4.556636180281018e-05
3.1263398376916696e-05
2.6447558661577688e-05
5.113199868108273e-05
4.800207880127862e-05
2.8348048485094676e-05
2.3773482254170528e-05
5.310392144287346e-05
4.989847245283575e-05
2.2505232483953626e-05
1.8922583760740437e-05
5.276224326107467e-05
4.857750571547126e-05
2.399680558197707e-05
2.0716221147737427e-05
5.048274380409993e-05
4.616866284917532e-05
2.5921201684188245e-05
2.249842355029714e-05
4.8574493461097755e-05
4.4469918266027036e-05
2.698449777660231e-05
2.3403719577676886e-05
4.747691695490115e-05
4.353538193467468e-05
2.7430166628320154e-05
2.3805448031834916e-05
4.6767250638575955e-05
4.2916757329182764e-05
2.7673187480954627e-05
2.4060070352808957e-05
4.617237618013766e-05
4.2385219797625957e-05
2.7856487907968014e-05
2.4234027976466777e-05
4.5705985210299664e-05
4.1999072428320274e-05
2.7869341596320038e-05
2.4198601353925668e-05
4.550055295993997e-05
4.205696123370083e-05
2.396418601692664e-05
2.043356025807998e-05
4.555562847965381e-05
4.2161956060550543e-05
2.354372112083146e-05
1.9990299374537987e-05
4.5766146502044445e-05
4.240805859375417e-05
2.3020777359054405e-05
1.9466818965866686e-05
4.605170524952427e-05
4.2723041572449765e-05
2.245147304912021e-05
1.8900730015085198e-05
4.6391080083101903e-05
4.3091386482800466e-05
2.1835399380092975e-05
1.828266818229918e-05
4.679414305050118e-05
4.351573450817071e-05
2.1176753718748183e-05
1.7624353193379887e-05
4.724479620515181e-05
4.398395474735512e-05
2.0491183017546885e-05
1.6944578498890487e-05
4.7724316938788e-05
4.448238093420612e-05
1.9784646815767107e-05
1.624368362199292e-05
4.823203981133701e-05
4.516591499447908e-05
1.5699650075826235e-05
1.2341150668018344e-05
4.873264396162423e-05
4.5675987504275135e-05
1.5014283095429034e-05
1.1665291024011255e-05
4.923892542290422e-05
4.619420779000219e-05
1.432470729659923e-05
1.0980951927595938e-05
4.976107908584259e-05
4.672690054924609e-05
1.3629601562839224e-05
1.0294192172370253e-05
5.028990354832106e-05
4.726384844695133e-05
1.2937594152729637e-05
9.608561027333687e-06
5.082705145289385e-05
4.781189403084756e-05
1.2238287230536143e-05
8.914247968399803e-06
5.1377728361687444e-05
4.8371756358532895e-05
1.1534707036896961e-05
8.218884916578543e-06
5.192768264520092e-05
4.8923645138056975e-05
1.085328843377851e-05
7.568294157060619e-06
[23]:
# Fit (LBFGS)
# Same objective, minimized with L-BFGS (strong Wolfe line search) starting from zero knobs
# LBFGS.step(closure) returns the loss from the first closure evaluation of each step,
# so the printed values are the objective at the start of each outer iteration

knobs = torch.tensor(10*[0.0], dtype=dtype, requires_grad=True)
optimizer = torch.optim.LBFGS([knobs], lr=0.1, line_search_fn="strong_wolfe")

def closure():
    # Re-evaluate objective and gradient; called (possibly several times) by LBFGS.step
    optimizer.zero_grad()
    value = objective(knobs)
    value.backward()
    return value

for epoch in range(8):
    value = optimizer.step(closure)
    print(value.item())
0.028369700506047313
4.066909132862517e-07
3.358642506148264e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
1.9754381178133033e-10
[24]:
# Check fitted tunes
# Compare the measured tunes with the model evaluated at the exact ID entries and at the fitted knobs,
# then print the absolute residuals of both

print(target)
print(observable(ID.A))
print(observable(knobs.detach()))
print()

print((target - observable(ID.A)).abs())
print((target - observable(knobs.detach())).abs())
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)

tensor([0., 0.], dtype=torch.float64)
tensor([1.8011e-10, 8.1134e-11], dtype=torch.float64)
[25]:
# Define fitted model
# Store the fitted entries in TM so that `model` itself now carries the fitted ID;
# its tunes (computed without knobs) should reproduce `target`

TM.A = knobs.detach()

print(target)
print(tune(model, [], matched=True, limit=1))
tensor([0.2735, 0.1723], dtype=torch.float64)
tensor([0.2735, 0.1723], dtype=torch.float64)
[26]:
# Define target tunes
# The correction target is the nominal (bare-lattice) tune; this rebinds `target`, so the objective
# below drives the fitted model back to the nominal working point

target = tune(ring, [], matched=True, limit=1)
print(target)
tensor([0.2994, 0.1608], dtype=torch.float64)
[27]:
# Define parametric observable
# Correction knobs: one per QF/QD quadrupole in families _02 and _03 of sectors S01..S12 (48 in total)
# Knobs are presumably offsets to each quadrupole's `kn`: zero knobs reproduce the current model tunes,
# and the correction is later applied to the lattice as `kn + knob`

QF = [f'QF_S{i:02}_{j:02}' for j in [2, 3] for i in range(1, 12 + 1)]
QD = [f'QD_S{i:02}_{j:02}' for j in [2, 3] for i in range(1, 12 + 1)]

def observable(kn):
    # Tunes of the fitted model (with fitted TM) as a function of the quadrupole knobs
    return tune(model, [kn], ('kn', None, QF + QD, None), matched=True, limit=1)

knobs = torch.zeros(len(QF + QD), dtype=dtype)
print(observable(knobs))
tensor([0.2735, 0.1723], dtype=torch.float64)
[28]:
# Define objective
# Distance between the nominal (target) tunes and the fitted-model tunes for given quadrupole knobs

def objective(knobs):
    """Return the 2-norm of the tune residual for the given quadrupole knobs."""
    residual = target - observable(knobs)
    return residual.norm()

print(objective(knobs))
tensor(0.0284, dtype=torch.float64)
[29]:
# Fit
# Solve for the quadrupole knobs with L-BFGS (strong Wolfe line search), starting from zero
# Printed values are the objective at the start of each LBFGS step (value of the first closure call)

knobs = torch.tensor(len(QF + QD)*[0.0], dtype=dtype, requires_grad=True)
optimizer = torch.optim.LBFGS([knobs], lr=0.1, line_search_fn="strong_wolfe")

def closure():
    # Re-evaluate objective and gradient; called (possibly several times) by LBFGS.step
    optimizer.zero_grad()
    value = objective(knobs)
    value.backward()
    return value

for epoch in range(8):
    value = optimizer.step(closure)
    print(value.item())
0.028369700703583963
1.250945355652117e-07
2.8035556938341184e-09
2.1928367264705342e-09
1.6978970603827018e-09
3.292485872195638e-10
2.0558624258379998e-10
2.0558624258379998e-10
[30]:
# Apply corrections to model with the exact ID
# The knobs fitted on the ID model are added to the quadrupole strengths of a copy of `error`

result = error.clone()
result.flatten()
for name, delta in zip(QF + QD, knobs.detach()):
    quadrupole = result[name]
    quadrupole.kn = (quadrupole.kn + delta).item()
result.splice()
[31]:
# Compute tunes (fractional part)
# Tunes of the corrected lattice (exact ID + fitted correction)

nux_result, nuy_result = tune(result, [], matched=True, limit=1)
[32]:
# Compute dispersion
# Dispersion of the corrected lattice, around the same 4-component zero orbit as before

orbit = torch.tensor(4*[0.0], dtype=dtype)
etaqx_result, etapx_result, etaqy_result, etapy_result = dispersion(result, orbit, [], limit=1)
[33]:
# Compute twiss parameters
# Twiss parameters of the corrected lattice (same call as for the nominal lattice)

ax_result, bx_result, ay_result, by_result = twiss(result, [], matched=True, advance=True, full=False).T
[34]:
# Compute phase advances
# Phase advances of the corrected lattice (same call as for the nominal lattice)

mux_result, muy_result = advance(result, [], alignment=False, matched=True).T
[35]:
# Compute coupling
# Coupling (minimal tune distance) of the corrected lattice

c_result = coupling(result, [])
[36]:
# Tune shifts
# Absolute deviation from the nominal tunes: first with the ID, then after correction
# (each group prints the horizontal value, the vertical value, then a blank line)

for shifted in ((nux_id, nuy_id), (nux_result, nuy_result)):
    for nominal, other in zip((nux, nuy), shifted):
        print((nominal - other).abs())
    print()
tensor(0.0260, dtype=torch.float64)
tensor(0.0114, dtype=torch.float64)

tensor(0.0001, dtype=torch.float64)
tensor(8.5259e-05, dtype=torch.float64)

[37]:
# Coupling (minimal tune distance)
# Printed for the bare lattice, the lattice with ID, and the corrected lattice

print(c)
print(c_id)
print(c_result)
tensor(0., dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
[38]:
# Dispersion
# Deviation from the nominal dispersion: x markers = lattice with ID, o markers = after correction
# (blue = horizontal, red = vertical); labels are valid both with and without text.usetex

s = ring.locations().cpu().numpy()

plt.figure(figsize=(12, 4))
plt.errorbar(s, (etaqx - etaqx_id).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75, label=r'$x$, ID')
plt.errorbar(s, (etaqy - etaqy_id).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75, label=r'$y$, ID')
plt.errorbar(s, (etaqx - etaqx_result).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75, label=r'$x$, corrected')
plt.errorbar(s, (etaqy - etaqy_result).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75, label=r'$y$, corrected')
plt.xlabel(r'$s$')
plt.ylabel(r'$\Delta\eta$')
plt.legend()
plt.tight_layout()
plt.show()
../_images/examples_elettra-06_38_0.png
[39]:
# Beta-beating
# Relative beta deviation from nominal in percent: x markers = lattice with ID, o markers = after correction
# (blue = horizontal, red = vertical)
# The printed values are the RMS beta-beating (percent) with the ID, then after correction;
# global tune correction restores the tunes, but it is not expected to remove the beating

s = ring.locations().cpu().numpy()

plt.figure(figsize=(12, 4))
plt.errorbar(s, 100*((bx - bx_id)/bx).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75, label=r'$x$, ID')
plt.errorbar(s, 100*((by - by_id)/by).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75, label=r'$y$, ID')
plt.errorbar(s, 100*((bx - bx_result)/bx).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75, label=r'$x$, corrected')
plt.errorbar(s, 100*((by - by_result)/by).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75, label=r'$y$, corrected')
plt.xlabel(r'$s$')
plt.ylabel(r'$\Delta\beta/\beta$ $[\%]$')
plt.legend()
plt.tight_layout()
plt.show()

print(100*(((bx - bx_id)/bx)**2).mean().sqrt())
print(100*(((by - by_id)/by)**2).mean().sqrt())
print()

print(100*(((bx - bx_result)/bx)**2).mean().sqrt())
print(100*(((by - by_result)/by)**2).mean().sqrt())
print()
../_images/examples_elettra-06_39_0.png
tensor(11.5994, dtype=torch.float64)
tensor(1.7916, dtype=torch.float64)

tensor(13.5893, dtype=torch.float64)
tensor(1.0242, dtype=torch.float64)

[40]:
# Phase advance
# Relative phase-advance deviation from nominal in percent: x markers = lattice with ID,
# o markers = after correction (blue = horizontal, red = vertical)
# The printed values are the RMS deviation (percent) with the ID, then after correction

s = ring.locations().cpu().numpy()

plt.figure(figsize=(12, 4))
plt.errorbar(s, 100*((mux - mux_id)/mux).cpu().numpy(), fmt='-', marker='x', color='blue', alpha=0.75, label=r'$x$, ID')
plt.errorbar(s, 100*((muy - muy_id)/muy).cpu().numpy(), fmt='-', marker='x', color='red', alpha=0.75, label=r'$y$, ID')
plt.errorbar(s, 100*((mux - mux_result)/mux).cpu().numpy(), fmt='-', marker='o', color='blue', alpha=0.75, label=r'$x$, corrected')
plt.errorbar(s, 100*((muy - muy_result)/muy).cpu().numpy(), fmt='-', marker='o', color='red', alpha=0.75, label=r'$y$, corrected')
plt.xlabel(r'$s$')
plt.ylabel(r'$\Delta\mu/\mu$ $[\%]$')
plt.legend()
plt.tight_layout()
plt.show()

print(100*(((mux - mux_id)/mux)**2).mean().sqrt())
print(100*(((muy - muy_id)/muy)**2).mean().sqrt())
print()

print(100*(((mux - mux_result)/mux)**2).mean().sqrt())
print(100*(((muy - muy_result)/muy)**2).mean().sqrt())
print()
../_images/examples_elettra-06_40_0.png
tensor(8.7941, dtype=torch.float64)
tensor(1.7778, dtype=torch.float64)

tensor(10.1208, dtype=torch.float64)
tensor(1.0428, dtype=torch.float64)