Example-54: Coupling (coupling correction based on the minimal tune distance)
[1]:
# In this example, the minimal tune distance (delta Q min) is used as the target for coupling correction
# Given a measured value, fit the lattice model to reproduce the measurement, then apply the fitted correction
[2]:
# Import
from pprint import pprint
import torch
from torch import Tensor
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
from model.library.line import Line
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper
from model.command.build import build
from model.command.coupling import coupling
[3]:
# Build and set up the ring lattice
# Read the ELEGANT lattice file and build the 'RING' line from it
path = Path('ic.lte')
data = load_lattice(path)
ring:Line = build('RING', 'ELEGANT', data)
# Flatten the line, then merge drifts
ring.flatten()
ring.merge()
# Switch every dipole to its linear model
for element in ring:
    if type(element).__name__ != 'Dipole':
        continue
    element.linear = True
# Number of BPMs, quadrupoles and sextupoles in the ring
nb, nq, ns = (ring.describe[kind] for kind in ('BPM', 'Quadrupole', 'Sextupole'))
[4]:
# Set lattice with errors: add random deviations to the 'ks' attribute of the quadrupoles
# Seed torch's global RNG so that the error realization is reproducible under Restart & Run All.
# Every later torch.rand/randn draw (e.g. the random initial guess of the fit) is reproducible too.
torch.manual_seed(0)
error:Line = ring.clone()
nq = error.describe['Quadrupole']
# Errors are normally distributed with standard deviation 0.1, one value per counted quadrupole
error_ks = 0.1*torch.randn(nq, dtype=torch.float64)
index = 0
label = ''
for element in error.sequence:
    if element.__class__.__name__ == 'Quadrupole':
        # A new error index is taken whenever the quadrupole name changes,
        # so consecutive elements sharing a name receive the same error value.
        # NOTE(review): assumes the number of such name changes does not exceed nq;
        # otherwise error_ks[index - 1] raises IndexError. Confirm against `describe`.
        if label != element.name:
            index += 1
            label = element.name
        element.ks = (element.ks + error_ks[index - 1]).item()
[5]:
# Delta Q min of the ideal ring, then of the lattice with errors (one value per output line)
print(*(coupling(lattice, []) for lattice in (ring, error)), sep='\n')
tensor(0., dtype=torch.float64)
tensor(0.0109, dtype=torch.float64)
[6]:
# Correction (model to experiment)
# Fit deviations of the quadrupole 'ks' values of the ideal ring so that its delta Q min
# reproduces the value "measured" on the lattice with errors.
# Target delta Q min
coupling_error = coupling(error, [])
# Bounds of the fitted deviations. Normalization and forward mapping must use the same bounds,
# so they are defined once here.
bounds = [(-0.5, 0.5)]
# Learning rate and number of optimization steps
lr = 0.001
epochs = 64

def coupling_model(ks):
    """Return delta Q min of the ideal ring with deviations `ks` added to the quadrupole 'ks' values.

    A small value (2.5E-16) is added to the deviations to avoid nan values.
    """
    return coupling(ring, [ks + 2.5E-16], ('ks', ['Quadrupole'], None, None))

def objective_physical(ks):
    """Return the norm of the mismatch between target and modeled delta Q min (deviations in physical units)."""
    return (coupling_error - coupling_model(ks)).norm()

# Test the objective in physical units: zero deviations vs. the true errors
print(objective_physical(0.0*error_ks))
print(objective_physical(1.0*error_ks))
print()
# Normalized objective: takes arguments in coordinates normalized to `bounds`
objective = normalize(objective_physical, bounds)
# Test the normalized objective with the same two points mapped by `forward`
print(objective(*forward([0.0*error_ks], bounds)))
print(objective(*forward([1.0*error_ks], bounds)))
print()
# Initial settings, drawn in normalized coordinates.
# Note: it is better to use a random initial guess along with multi-start.
# NOTE(review): a single scalar is matched with nq free parameters, so the fitted
# deviations are not unique and need not coincide with error_ks.
ks = torch.rand(nq, dtype=torch.float64)
# Model whose forward pass returns the evaluated (normalized) objective
model = Wrapper(objective, ks)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
for epoch in range(epochs):
    value = model()
    value.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Objective value evaluated before this step's update
    print(value.detach())
tensor(0.0109, dtype=torch.float64)
tensor(1.8041e-16, dtype=torch.float64)
tensor(0.0109, dtype=torch.float64)
tensor(1.4051e-16, dtype=torch.float64)
tensor(0.0117, dtype=torch.float64)
tensor(0.0104, dtype=torch.float64)
tensor(0.0092, dtype=torch.float64)
tensor(0.0079, dtype=torch.float64)
tensor(0.0067, dtype=torch.float64)
tensor(0.0054, dtype=torch.float64)
tensor(0.0041, dtype=torch.float64)
tensor(0.0027, dtype=torch.float64)
tensor(0.0013, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0018, dtype=torch.float64)
tensor(0.0020, dtype=torch.float64)
tensor(0.0019, dtype=torch.float64)
tensor(0.0016, dtype=torch.float64)
tensor(0.0011, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(5.3763e-05, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(1.9189e-05, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(9.4410e-06, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
[7]:
# Apply corrections
# The optimizer works in normalized coordinates (the objective was wrapped with `normalize`),
# so take the optimized parameters from the model and map them back to physical deviations.
# NOTE(review): the bounds below must match those used to normalize the objective in the fit cell.
ks_fit, *_ = inverse([parameter.detach() for parameter in model.parameters()], [(-0.5, 0.5)])
lattice:Line = error.clone()
index = 0
label = ''
# Iterate with `element` as the loop variable. The previous `for line in ...` left `element`
# bound to a stale object from an earlier cell, so the clone was never modified.
for element in lattice.sequence:
    if element.__class__.__name__ == 'Quadrupole':
        # Same per-name indexing as used to assign the errors: consecutive quadrupoles
        # sharing a name use one fitted value
        if label != element.name:
            index += 1
            label = element.name
        element.ks = (element.ks - ks_fit[index - 1]).item()
# Delta Q min before and after correction
print(coupling(error, []))
print(coupling(lattice, []))
tensor(0.0109, dtype=torch.float64)
tensor(0.0109, dtype=torch.float64)