Example-47: Optimize (Multistart)
[1]:
# In this example, optics correction is performed starting from different initial quadrupole settings (multistart optimization)
[2]:
# Import
import torch
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True
from model.library.line import Line
from model.command.util import select
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper
from model.command.tune import tune
from model.command.twiss import twiss
from model.command.optimize import adam
from model.command.optimize import newton
[3]:
# Load the reference tunes and twiss parameters computed by ELEGANT
path = Path('ic.twiss')
parameters, columns = load_sdds(path)
# Fractional parts of the reference tunes
nu_qx:Tensor = torch.tensor(parameters['nux'] % 1, dtype=torch.float64)
nu_qy:Tensor = torch.tensor(parameters['nuy'] % 1, dtype=torch.float64)
# Row mask: True for rows whose ElementType is 'MONI' (BPMs)
kinds = select(columns, 'ElementType', keep=False)
is_bpm = [kind == 'MONI' for kind in kinds.values()]
def _at_bpms(name):
    """Select column `name` from `columns` and keep only the BPM ('MONI') rows."""
    column = select(columns, name, keep=False)
    return [value for value, keep in zip(column.values(), is_bpm) if keep]
# Twiss parameters at BPMs (select order kept: alphax, betax, alphay, betay, s)
a_qx:Tensor = torch.tensor(_at_bpms('alphax'), dtype=torch.float64)
b_qx:Tensor = torch.tensor(_at_bpms('betax'), dtype=torch.float64)
a_qy:Tensor = torch.tensor(_at_bpms('alphay'), dtype=torch.float64)
b_qy:Tensor = torch.tensor(_at_bpms('betay'), dtype=torch.float64)
# Longitudinal positions of the BPMs (used later as x-axis ticks)
positions = _at_bpms('s')
[4]:
# Build and set up the ring model from the ELEGANT lattice file
path = Path('ic.lte')
data = load_lattice(path)
ring:Line = build('RING', 'ELEGANT', data)
# Flatten the lattice, then merge drifts
ring.flatten()
ring.merge()
# Split BPMs, then roll the lattice start (roll(1))
ring.split((None, ['BPM'], None, None))
ring.roll(1)
# Switch every Dipole element to its linear model
dipoles = (item for item in ring if item.__class__.__name__ == 'Dipole')
for dipole in dipoles:
    dipole.linear = True
# Split the lattice into lines delimited by BPMs
ring.splice()
# Number of BPMs, quadrupoles and sextupoles
nb, nq, ns = (ring.describe[kind] for kind in ('BPM', 'Quadrupole', 'Sextupole'))
[5]:
# Check that the model tunes agree with the ELEGANT reference
nuqx, nuqy = tune(ring, [], alignment=False, matched=True)
for reference, computed in ((nu_qx, nuqx), (nu_qy, nuqy)):
    print(torch.allclose(reference, computed))
True
True
[6]:
# Check that the model twiss parameters at BPMs agree with the ELEGANT reference
aqx, bqx, aqy, bqy = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T
for reference, computed in zip((a_qx, b_qx, a_qy, b_qy), (aqx, bqx, aqy, bqy)):
    print(torch.allclose(reference, computed))
True
True
True
True
[7]:
# Build a copy of the ring with random quadrupole focusing errors (no coupling)
error:Line = ring.clone()
nq = error.describe['Quadrupole']
error_kn = 0.1*torch.randn(nq, dtype=torch.float64)
# Walk all quadrupoles in order; consecutive quadrupoles with the same name
# share one entry of error_kn (a new entry is used whenever the name changes)
index = 0
label = ''
for segment in error.sequence:
    quadrupoles = (item for item in segment if item.__class__.__name__ == 'Quadrupole')
    for element in quadrupoles:
        if element.name != label:
            index, label = index + 1, element.name
        element.kn = (element.kn + error_kn[index - 1]).item()
# Histogram of the applied errors
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
ax.hist(error_kn.cpu().numpy(), bins=8, range=(-0.5, 0.5), color='blue', alpha=0.7)
plt.tight_layout()
plt.show()
[8]:
# Twiss parameters at BPMs for the design ring and for the ring with errors
ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T
# Norm of the model/error difference for alpha_x, beta_x, alpha_y, beta_y
pairs = ((ax_model, ax_error), (bx_model, bx_error), (ay_model, ay_error), (by_model, by_error))
for model, perturbed in pairs:
    print((model - perturbed).norm())
print()
# Relative beta beating in percent (red: x, blue: y)
s = ring.locations().cpu().numpy()
plt.figure(figsize=(16, 2))
plt.plot(s, 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(s, 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
tensor(1.5845, dtype=torch.float64)
tensor(0.9961, dtype=torch.float64)
tensor(0.6124, dtype=torch.float64)
tensor(0.3541, dtype=torch.float64)
[9]:
# Batch objective
# Target: twiss parameters at BPMs of the ring with errors
twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)
# One rolled copy of the ring per line (Twiss parameters will be computed at each ring start)
rings:list[Line] = []
for shift, _ in enumerate(ring):
    rolled = ring.clone()
    rolled.roll(shift)
    rings.append(rolled)
# Names of the quadrupole knobs
_, ((_, names, _), *_), _ = group(ring, 0, len(ring) - 1, ('kn', ['Quadrupole'], None, None))
def evaluate(X, knobs):
    """Stack the start-of-ring twiss of rings[x] for each x in X, with quadrupole knobs applied by name."""
    return torch.stack([twiss(rings[x], [knobs], ('kn', None, names, None), alignment=False, matched=True, convert=True) for x in X])
# Wrap so that knobs are given in normalized coordinates (range (-0.25, 0.25))
evaluate = normalize(evaluate, [(None, None), (-0.25, 0.25)])
# Mean squared error loss function
lf = torch.nn.MSELoss()
# Features are ring indices, labels are the target twiss rows
X = torch.arange(len(ring))
y = twiss_error.clone()
# Unshuffled mini-batches of size 8
batch_size = 8
dataset = TensorDataset(X.clone(), y.clone())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
def objective(knobs, X, y):
    """Return the MSE between the model twiss for batch X and the targets y."""
    return lf(evaluate(X, knobs), y)
[10]:
# Optimization with the adam optimizer (single initial point)
# Initial knob values: zero correction, mapped into normalized coordinates
kn, *_ = forward([torch.zeros_like(error_kn)], [(-0.25, 0.25)])
kn = adam(objective, kn, dataloader, count=64, lr=0.005)
# Map the optimized (normalized) knobs back to quadrupole knob values
kn_out, *_ = inverse([kn], [(-0.25, 0.25)])
# Quadrupole settings: applied errors (red) vs. fitted knobs (blue)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn_out)), +kn_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
def _corrected_twiss(knobs):
    """Twiss at BPMs of `ring` with quadrupole `kn` knobs `knobs` applied."""
    return twiss(ring, [knobs], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)
# Twiss without (0.0*kn_out) and with (1.0*kn_out) the fitted knobs
twiss_start = _corrected_twiss(0.0*kn_out)
twiss_final = _corrected_twiss(1.0*kn_out)
# Mismatch to the target before and after the fit
print((twiss_error - twiss_start).norm())
print((twiss_error - twiss_final).norm())
# Beta beating before (marker 'o') and after (marker 'x') the fit
plt.figure(figsize=(16, 2))
s = ring.locations().cpu().numpy()
for table, marker in ((twiss_start, 'o'), (twiss_final, 'x')):
    _, bx_model, _, by_model = table.T
    plt.plot(s, 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(s, 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
tensor(2.0008, dtype=torch.float64)
tensor(0.0355, dtype=torch.float64)
[11]:
# To improve the accuracy of a selected solution, several Newton optimization steps can be performed
# (kn is still in normalized coordinates here)
kn = newton(objective, kn, dataloader, count=4, lr=1.0)
# Map the refined knobs back to quadrupole knob values
kn, *_ = inverse([kn], [(-0.25, 0.25)])
# Quadrupole settings: applied errors (red) vs. refined knobs (blue)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
# Mismatch to the target: adam-only result, then adam + Newton result
for knobs in (kn_out, kn):
    print((twiss_error - twiss(ring, [knobs], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)).norm())
tensor(0.0355, dtype=torch.float64)
tensor(6.5866e-09, dtype=torch.float64)
[12]:
# Optimization from multiple random initial points (normalized knob coordinates)
# Note: the newton optimizer is also vmappable over initial points
# NOTE(review): torch.rand samples [0, 1); assumes that matches the normalized range produced by `forward` - confirm
n_starts = 1024
def _optimize(initial):
    """Run the same adam optimization as the single-start case from `initial`."""
    return adam(objective, initial, dataloader, count=64, lr=0.005)
initials = torch.rand((n_starts, nq), dtype=torch.float64)
kns = torch.vmap(_optimize, randomness='same', chunk_size=n_starts)(initials)
print(kns.shape)
torch.Size([1024, 28])
[13]:
# Histogram of the final twiss mismatch over all multistart solutions
def residual(kn):
    """Return the norm of the twiss mismatch to `twiss_error` for normalized knobs `kn`.

    Named `residual` rather than `error` so that the lattice with errors (`error`,
    built earlier in the notebook and used to compute the target) is not rebound;
    otherwise re-running the earlier cells would fail.
    """
    # Map normalized knobs back to quadrupole knob values
    knobs = [*inverse([kn], [(-0.25, 0.25)])]
    return (twiss_error - twiss(ring, knobs, ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)).norm()
values = torch.vmap(residual)(kns)
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
ax.hist(values.cpu().numpy(), bins=50, color='blue', alpha=0.7)
plt.tight_layout()
plt.show()