Example-47: Optimize (Multistart)

[1]:
# In this example, optics correction is performed starting from many different initial quadrupole settings (multistart optimization)
[2]:
# Import

import torch
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from model.library.line import Line

from model.command.util import select

from model.command.external import load_sdds
from model.command.external import load_lattice

from model.command.build import build

from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper

from model.command.tune import tune
from model.command.twiss import twiss

from model.command.optimize import adam
from model.command.optimize import newton
[3]:
# Load ELEGANT twiss (reference tunes and twiss parameters of the unperturbed lattice)

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

# Keep only the fractional part of the tunes

nu_qx:Tensor = torch.tensor(parameters['nux'] % 1, dtype=torch.float64)
nu_qy:Tensor = torch.tensor(parameters['nuy'] % 1, dtype=torch.float64)

# Set twiss parameters at BPMs (elements with ELEGANT type 'MONI')

kinds = select(columns, 'ElementType', keep=False)

def _at_bpms(column:str) -> list:
    """Return the values of an ELEGANT twiss column at BPM ('MONI') elements, in table order.

    NOTE(review): assumes the mapping returned by select() iterates in the same order as `kinds`.
    """
    table = select(columns, column, keep=False)
    return [value for (_, value), kind in zip(table.items(), kinds.values()) if kind == 'MONI']

a_qx:Tensor = torch.tensor(_at_bpms('alphax'), dtype=torch.float64)
b_qx:Tensor = torch.tensor(_at_bpms('betax'), dtype=torch.float64)
a_qy:Tensor = torch.tensor(_at_bpms('alphay'), dtype=torch.float64)
b_qy:Tensor = torch.tensor(_at_bpms('betay'), dtype=torch.float64)

# BPM longitudinal positions (used as x-tick locations in the plots)

positions = _at_bpms('s')
[4]:
# Build and setup lattice

# Load ELEGANT lattice table

path = Path('ic.lte')
data = load_lattice(path)

# Build the RING line from the ELEGANT table and flatten it

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()

# Merge drifts
ring.merge()

# Split BPMs
ring.split((None, ['BPM'], None, None))

# Roll lattice start (shift by 1)
ring.roll(1)

# Switch every dipole to its linear model
for element in (item for item in ring if item.__class__.__name__ == 'Dipole'):
    element.linear = True

# Split lattice into lines by BPMs
ring.splice()

# Number of BPMs, quadrupoles and sextupoles in the lattice
counts = ring.describe
nb, nq, ns = counts['BPM'], counts['Quadrupole'], counts['Sextupole']
[5]:
# Compare model tunes with the ELEGANT reference (fractional parts)

nuqx, nuqy = tune(ring, [], alignment=False, matched=True)

for reference, computed in ((nu_qx, nuqx), (nu_qy, nuqy)):
    print(torch.allclose(reference, computed))
True
True
[6]:
# Compare model twiss parameters at BPMs with the ELEGANT reference (alpha x, beta x, alpha y, beta y)

aqx, bqx, aqy, bqy = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T

for reference, computed in zip((a_qx, b_qx, a_qy, b_qy), (aqx, bqx, aqy, bqy)):
    print(torch.allclose(reference, computed))
True
True
True
True
[7]:
# Set lattice with focusing errors (no coupling)

# Seed the global torch RNG so the generated errors (and all results below) are reproducible

torch.manual_seed(0)

error:Line = ring.clone()

nq = error.describe['Quadrupole']

# Additive quadrupole kn errors, normally distributed with standard deviation 0.1

error_kn = 0.1*torch.randn(nq, dtype=torch.float64)

# Apply one error per quadrupole name: the index advances only when a quadrupole's name differs
# from the previous quadrupole's name, so successive quadrupole entries with the same name
# (presumably slices of one split magnet) share the same error

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            if label != element.name:
                index += 1
            label = element.name
            element.kn = (element.kn + error_kn[index - 1]).item()

# Histogram of the generated errors

fig, ax = plt.subplots(1, 1, figsize=(16, 4))
ax.hist(error_kn.cpu().numpy(), bins=8, range=(-0.5, 0.5), color='blue', alpha=0.7)
ax.set(xlabel='quadrupole kn error', ylabel='count')
plt.tight_layout()
plt.show()
../_images/examples_model-46_7_0.png
[8]:
# Compute twiss and plot beta beating

twiss_options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)

ax_model, bx_model, ay_model, by_model = twiss(ring, [], **twiss_options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **twiss_options).T

# Norm of the model vs. perturbed difference for alpha x, beta x, alpha y, beta y

for reference, perturbed in ((ax_model, ax_error), (bx_model, bx_error), (ay_model, ay_error), (by_model, by_error)):
    print((reference - perturbed).norm())
print()

# Plot relative beta beating in percent (red: x plane, blue: y plane)

s_locations = ring.locations().cpu().numpy()

plt.figure(figsize=(16, 2))
for reference, perturbed, color in ((bx_model, bx_error, 'red'), (by_model, by_error, 'blue')):
    plt.plot(s_locations, 100*((reference - perturbed)/reference).cpu().numpy(), color=color, alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
tensor(1.5845, dtype=torch.float64)
tensor(0.9961, dtype=torch.float64)
tensor(0.6124, dtype=torch.float64)
tensor(0.3541, dtype=torch.float64)

../_images/examples_model-46_8_1.png
[9]:
# Batch objective

# Set target twiss parameters (computed for the lattice with errors)

twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)

# Define rings (Twiss parameters will be computed at each ring start)
# rings[i] is a clone of the model ring rolled by i; indices 0 .. len(ring) - 1 match the features X below

rings:list[Line] = []

for i in range(len(ring)):
    line = ring.clone()
    line.roll(i)
    rings.append(line)

# Set batched function
# `names` holds the quadrupole names returned by grouping 'kn' knobs over the whole ring

_, ((_, names, _), *_), _ = group(ring, 0, len(ring) - 1, ('kn', ['Quadrupole'], None, None))

def evaluate(X, knobs):
    """Return stacked twiss parameters computed at the start of each ring selected by X.

    X: iterable of integer indices into `rings` (one batch of features).
    knobs: quadrupole 'kn' knob values applied to the elements listed in `names`.
    """
    return torch.stack([twiss(rings[x], [knobs], ('kn', None, names, None), alignment=False, matched=True, convert=True) for x in X])

# Normalize objective: knobs are passed in normalized form (presumably mapped back to the
# (-0.25, 0.25) range), while the first argument X is presumably left untouched by (None, None)

evaluate = normalize(evaluate, [(None, None), (-0.25, 0.25)])

# Set loss function

lf = torch.nn.MSELoss()

# Set features (ring indices) and labels (target twiss at each ring start)

X = torch.arange(len(ring))
y = twiss_error.clone()

# Set dataset (no shuffling: batches follow ring order)

batch_size = 8
dataset = TensorDataset(X.clone(), y.clone())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Set batch objective

def objective(knobs, X, y):
    """MSE between the model twiss for batch indices X (normalized knobs) and the target rows y."""
    return lf(evaluate(X, knobs), y)
[10]:
# Optimization with adam optimizer (single initial)

# Set initial knob values (zero knobs, i.e. start from the unperturbed model)

kn = torch.zeros_like(error_kn)

# Normalize (knob range (-0.25, 0.25))

kn, *_ = forward([kn], [(-0.25, 0.25)])

# Perform optimization (kn is kept in normalized form for further refinement)

kn = adam(objective, kn, dataloader, count=64, lr=0.005)

# Transform normalized result back to physical knob values

kn_out, *_ = inverse([kn], [(-0.25, 0.25)])

def _twiss_with(knobs):
    """Twiss at BPMs of the model ring with the given quadrupole 'kn' knobs applied."""
    return twiss(ring, [knobs], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)

# Plot quadrupole settings (red: applied errors, blue: fitted knobs)

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn_out)), +kn_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()

# Compare twiss (zero knobs vs. fitted knobs); each twiss is computed once and reused for the plot

twiss_start = _twiss_with(0.0*kn_out)
twiss_final = _twiss_with(1.0*kn_out)

print((twiss_error - twiss_start).norm())
print((twiss_error - twiss_final).norm())

# Plot beta beating (marker o: zero knobs, marker x: fitted knobs)
# Local names are used so the model bx_model / by_model computed earlier are not overwritten

s_locations = ring.locations().cpu().numpy()

plt.figure(figsize=(16, 2))

for table, marker in ((twiss_start, 'o'), (twiss_final, 'x')):
    _, bx_fit, _, by_fit = table.T
    plt.plot(s_locations, 100*((bx_fit - bx_error)/bx_fit).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(s_locations, 100*((by_fit - by_error)/by_fit).cpu().numpy(), color='blue', alpha=0.75, marker=marker)

plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
../_images/examples_model-46_10_0.png
tensor(2.0008, dtype=torch.float64)
tensor(0.0355, dtype=torch.float64)
../_images/examples_model-46_10_2.png
[11]:
# To improve accuracy of a selected solution, several Newton optimization steps can be performed
# Newton starts from the normalized adam result `kn`; results are stored under new names so `kn`
# is not overwritten with physical values and this cell can be re-run safely

kn_newton_normalized = newton(objective, kn, dataloader, count=4, lr=1.0)
kn_newton, *_ = inverse([kn_newton_normalized], [(-0.25, 0.25)])

# Plot quadrupole settings (red: applied errors, blue: Newton-refined knobs)

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn_newton)), +kn_newton.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()

# Compare twiss mismatch: adam result vs. Newton-refined result

print((twiss_error - twiss(ring, [kn_out], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)).norm())
print((twiss_error - twiss(ring, [kn_newton], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)).norm())
../_images/examples_model-46_11_0.png
tensor(0.0355, dtype=torch.float64)
tensor(6.5866e-09, dtype=torch.float64)
[12]:
# Optimization (multiple initials)
# Note, newton optimizer is also vmappable over initials

n_initials = 1024

def _adam_from(kn_initial):
    """Run adam from one normalized initial knob vector."""
    return adam(objective, kn_initial, dataloader, count=64, lr=0.005)

# Initial knobs drawn uniformly from [0, 1) in normalized coordinates, one row per start

kns_initial = torch.rand((n_initials, nq), dtype=torch.float64)
kns = torch.vmap(_adam_from, randomness='same', chunk_size=n_initials)(kns_initial)
print(kns.shape)
torch.Size([1024, 28])
[13]:
# Plot the distribution of final twiss errors over all initial conditions

def residual(kn):
    """Norm of the twiss mismatch w.r.t. the target for one normalized knob vector kn.

    Named `residual` (not `error`) so it does not shadow the `error` lattice built earlier.
    """
    return (twiss_error - twiss(ring, [*inverse([kn], [(-0.25, 0.25)])], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)).norm()

values = torch.vmap(residual)(kns)

fig, ax = plt.subplots(1, 1, figsize=(16, 4))
ax.hist(values.cpu().numpy(), bins=50, color='blue', alpha=0.7)
ax.set(xlabel='twiss error norm', ylabel='count')
plt.tight_layout()
plt.show()
../_images/examples_model-46_13_0.png