Example-36: Orbit (ORM optics correction)

[1]:
# In this example ORM based optics correction is explored
# Note, effects due to dispersion are not included

# Focusing errors are added to quadrupoles and the corresponding ORM is computed (observed ORM)
# The goal is to adjust quadrupole settings in the model to match the observed ORM
# Once parameters are fitted, they can be applied (possibly with some weight) and the ORM can be measured again

# In this example the derivative of the ORM is computed with respect to quadrupole strength deviations
# Given a model response, a procedure similar to the closed orbit correction can be performed
# Specific ORM blocks can be targeted in correction, e.g. coupling correction
# ML correction loops are also demonstrated, including a mini-batch version
# The response matrix can also be used with sparse solvers
[2]:
# Import

from pprint import pprint

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from sklearn.linear_model import Lasso
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.linear_model import Lars

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True

from twiss import twiss
from twiss import propagate
from twiss import wolski_to_cs

from model.library.corrector import Corrector
from model.library.line import Line

from model.command.util import chop
from model.command.util import select

from model.command.external import load_sdds
from model.command.external import load_lattice

from model.command.build import build

from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper

from model.command.orbit import orbit
from model.command.orbit import ORM
from model.command.orbit import ORM_IJ

from model.command.layout import Layout
[3]:
# Load ELEGANT twiss

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

nu_qx:float = parameters['nux'] % 1
nu_qy:float = parameters['nuy'] % 1

# Set twiss parameters at BPMs

kinds = select(columns, 'ElementType', keep=False)

a_qx = select(columns, 'alphax', keep=False)
b_qx = select(columns, 'betax' , keep=False)
a_qy = select(columns, 'alphay', keep=False)
b_qy = select(columns, 'betay' , keep=False)

a_qx = {key: value for (key, value), kind in zip(a_qx.items(), kinds.values()) if kind == 'MONI'}
b_qx = {key: value for (key, value), kind in zip(b_qx.items(), kinds.values()) if kind == 'MONI'}
a_qy = {key: value for (key, value), kind in zip(a_qy.items(), kinds.values()) if kind == 'MONI'}
b_qy = {key: value for (key, value), kind in zip(b_qy.items(), kinds.values()) if kind == 'MONI'}

positions = select(columns, 's', keep=False).items()
positions = [value for (key, value), kind in zip(positions, kinds.values()) if kind == 'MONI']
[4]:
# Build and setup lattice

# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts
# Correctors are inserted between the parts

path = Path('ic.lte')
data = load_lattice(path)

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()
ring.merge()
ring.split((None, ['BPM'], None, None))
ring.roll(1)

n_q = 2**2
n_d = 2**4

for name in [name for name, kind, *_ in ring.layout() if kind == 'Quadrupole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(n_q - 1))
    ring.split((n_q, None, [name], None), paste=[corrector])

for name in [name for name, kind, *_ in ring.layout() if kind == 'Dipole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(n_d - 1))
    ring.split((n_d, None, [name], None), paste=[corrector])

for element in ring:
    if element.__class__.__name__ == 'Dipole':
        element.linear = True

ring.splice()
[5]:
# Compare linear tunes

state = torch.tensor(4*[0.0], dtype=torch.float64)
matrix = torch.func.jacrev(ring)(state)
(nuqx, nuqy), _, w = twiss(matrix)

print(nu_qx - nuqx)
print(nu_qy - nuqy)

# Compute twiss parameters at bpms

ring.propagate = False
ring.matrix = True

state = torch.tensor(4*[0.0], dtype=torch.float64)
ring(state)

# Propagate twiss parameters

ws = [w]
*ms, _ = ring.container_matrix
for m in ms:
    w = propagate(w, m)
    ws.append(w)
ws = torch.stack(ws)

# Remove matrix flag

ring.matrix = False

# Convert to CS and plot beta functions

ax, bx, ay, by = torch.vmap(wolski_to_cs)(ws).T

layout = Layout(ring)
_, _, lengths, *_ = layout.slicing_table()
rectangles = layout.profile_1d(scale=0.75, shift=0.0, text=False,  exclude=['BPM', 'Corrector'])

plt.figure(figsize=(16, 2))
plt.plot(positions, b_qx.values(), color='red', alpha=0.75, marker='o', label=r'$\beta_x$')
plt.errorbar(ring.locations().cpu().numpy(), bx.cpu().numpy(), fmt=' ', ms=8, color='black', alpha=0.75, marker='x')
plt.plot(positions, b_qy.values(), color='blue', alpha=0.75, marker='o', label=r'$\beta_y$')
plt.errorbar(ring.locations().cpu().numpy(), by.cpu().numpy(), fmt=' ', ms=8, color='black', alpha=0.75, marker='x')
plt.xticks(ticks=positions, labels=b_qx.keys())
plt.legend()
for rectangle in rectangles:
    plt.gca().add_patch(Rectangle(**rectangle))
plt.ylim(-1.0, 5.0)
plt.tight_layout()
plt.show()
tensor(1.4433e-15, dtype=torch.float64)
tensor(-9.9920e-16, dtype=torch.float64)
../_images/examples_model-35_5_1.png
[6]:
# Compute closed orbit

fp = 1.0E-3*torch.randn(4, dtype=torch.float64)
fp, *_ = orbit(ring, fp, [], alignment=False, limit=8, epsilon=1.0E-12)

# Chop small values

fp = [fp]
chop(fp)
fp, *_ = fp

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute ORM (without errors)

orm = ORM(ring, fp, [], limit=1, start=0, epsilon=None)
print(orm.shape)

data = orm.clone()
data[data==0.0] = torch.nan

plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-35_7_1.png
[8]:
# Set lattice with focusing errors

# The perturbed lattice is a copy, the model (ring) is left unchanged

error:Line = ring.clone()

nq = error.describe['Quadrupole']

# Draw kn and ks errors from a seeded generator
# Without a fixed seed the 'measured' lattice (and every result below) changes on each Restart & Run All

generator = torch.Generator().manual_seed(0)

error_kn = 0.05*torch.randn(nq, dtype=torch.float64, generator=generator)
error_ks = 0.05*torch.randn(nq, dtype=torch.float64, generator=generator)

# index counts runs of consecutive quadrupole elements that share one name (1-based)
# label holds the name of the most recently visited quadrupole element
# All elements of one run receive the same error (presumably these are the parts of one split quadrupole)

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            # A name change marks the start of the next run, i.e. the next error value
            if label != element.name:
                index +=1
            label = element.name
            element.kn = (element.kn + error_kn[index - 1]).item()
            element.ks = (element.ks + error_ks[index - 1]).item()
[9]:
# Setup twiss computation

def task(line):
    """
    Compute tunes and twiss parameters along a line.

    The input line is cloned first, so the flags set below are applied to the copy only.

    Parameters
    ----------
    line: Line
        input line

    Returns
    -------
    tuple
        ((nux, nuy), (ax, bx, ay, by)), tunes as returned by twiss and
        the outputs of wolski_to_cs evaluated for each propagated twiss matrix

    """
    copy:Line = line.clone()

    # One-turn matrix around the zero state and twiss data at the entrance

    origin = torch.tensor(4*[0.0], dtype=torch.float64)
    (nux, nuy), _, w = twiss(torch.func.jacrev(copy)(origin))

    # Pass the zero state once more with the matrix flag set to fill container_matrix

    copy.propagate = False
    copy.matrix = True

    origin = torch.tensor(4*[0.0], dtype=torch.float64)
    copy(origin)

    # Propagate twiss data with all stored matrices except the last one

    *transport, _ = copy.container_matrix
    ws = [w]
    for m in transport:
        ws.append(propagate(ws[-1], m))
    ws = torch.stack(ws)

    ax, bx, ay, by = torch.vmap(wolski_to_cs)(ws).T

    tunes = (nux, nuy)
    table = (ax, bx, ay, by)

    return tunes, table

# Unperturbed twiss

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)

# Perturbed twiss

(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)

# Compare twiss

print((nux_model - nux_error).norm())
print((nuy_model - nuy_error).norm())
print()

print((ax_model - ax_error).norm())
print((bx_model - bx_error).norm())
print((ay_model - ay_error).norm())
print((by_model - by_error).norm())
print()

# Plot beta beating

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
tensor(0.0045, dtype=torch.float64)
tensor(0.0089, dtype=torch.float64)

tensor(0.7893, dtype=torch.float64)
tensor(0.4524, dtype=torch.float64)
tensor(0.2795, dtype=torch.float64)
tensor(0.1634, dtype=torch.float64)

../_images/examples_model-35_9_1.png
[10]:
# Compute (measure) ORM for lattice with errors

orm_error = ORM(error, fp, [], limit=1, start=0, epsilon=None)
print(orm_error.shape)

data = orm_error.clone()
data[data==0.0] = torch.nan

plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-35_10_1.png
[11]:
# Plot difference of model and observed ORMs

data = (orm - orm_error).clone()
data[data==0.0] = torch.nan

plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
../_images/examples_model-35_11_0.png
[12]:
# Define parametric ORM

def ORM_knks(knks):
    """
    Model ORM as a function of packed quadrupole deviations.

    Parameters
    ----------
    knks: Tensor
        1D tensor, first half is passed as the 'kn' deviation group and second half as the 'ks' group

    Returns
    -------
    Tensor
        ORM of the model (ring) evaluated at the fixed point fp

    """
    kn, ks = knks.reshape(2, -1)
    groups = (
        ('kn', ['Quadrupole'], None, None),
        ('ks', ['Quadrupole'], None, None),
    )
    return ORM(ring, fp, [kn, ks], *groups, limit=1, start=0, epsilon=None)

# Test correction with exact error values

print(torch.allclose(orm_error, ORM_knks(torch.cat([error_kn, error_ks]))))
True
[13]:
# Parametric ORM is differentiable with respect to deviation groups

# Set deviation tensors

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

knks = torch.cat([kn, ks])

# Compute derivative of model ORM with respect to kn and ks deviations

response = torch.func.jacrev(ORM_knks)(knks)

print(orm.shape)
print(response.shape)
torch.Size([32, 72])
torch.Size([32, 72, 56])
[14]:
# Test ORM response

error_knks = torch.cat([error_kn, error_ks])

matrix = response.reshape(-1, *knks.shape)

print((orm_error - (orm + (matrix @ (0*error_knks)).reshape_as(orm))).norm())
print((orm_error - (orm + (matrix @ (1*error_knks)).reshape_as(orm))).norm())
tensor(3.1476, dtype=torch.float64)
tensor(0.1936, dtype=torch.float64)
[15]:
# Perform correction (model to experiment)

lr = 0.5

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

for _ in range(16):
    orm_fit = ORM_knks(torch.cat([kn, ks]))
    dkn, dks = - lr*torch.linalg.lstsq(matrix, (orm_fit - orm_error).flatten(), driver='gelsd').solution.reshape(1 + 1, -1)
    kn += dkn
    ks += dks
    print((orm_fit - orm_error).norm())

# Plot final quadrupole settings

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_ks)), error_ks.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(ks)), +ks.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor(3.1476, dtype=torch.float64)
tensor(1.5975, dtype=torch.float64)
tensor(0.8304, dtype=torch.float64)
tensor(0.4394, dtype=torch.float64)
tensor(0.2358, dtype=torch.float64)
tensor(0.1279, dtype=torch.float64)
tensor(0.0700, dtype=torch.float64)
tensor(0.0385, dtype=torch.float64)
tensor(0.0213, dtype=torch.float64)
tensor(0.0118, dtype=torch.float64)
tensor(0.0066, dtype=torch.float64)
tensor(0.0037, dtype=torch.float64)
tensor(0.0020, dtype=torch.float64)
tensor(0.0011, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
../_images/examples_model-35_15_1.png
../_images/examples_model-35_15_2.png
[16]:
# Apply corrections

# Corrections are applied to a copy of the perturbed lattice, error itself is left unchanged

lattice:Line = error.clone()

# index counts runs of consecutive quadrupole elements that share one name (1-based)
# label holds the name of the most recently visited quadrupole element

index = 0
label = ''

# Note, the loop variable line stays bound at module scope after the loop

for line in lattice.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            # A name change marks the start of the next run, i.e. the next fitted value
            if label != element.name:
                index +=1
            label = element.name
            # Fitted deviations are subtracted (kn and ks were fitted so that the model reproduces orm_error)
            element.kn = (element.kn - kn[index - 1]).item()
            element.ks = (element.ks - ks[index - 1]).item()
[17]:
# Compare ORM with model before and after correction

orm_lattice = ORM(lattice, fp, [], limit=1, start=0, epsilon=None)

data = (orm - orm_error).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()

data = (orm - orm_lattice).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
../_images/examples_model-35_17_0.png
../_images/examples_model-35_17_1.png
[18]:
# Compare twiss parameters with model before and after correction

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)
(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)
(nux_lattice, nuy_lattice), (ax_lattice, bx_lattice, ay_lattice, by_lattice) = task(lattice)

# Before

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()

# After

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_lattice)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_lattice)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
../_images/examples_model-35_18_0.png
../_images/examples_model-35_18_1.png
[19]:
# ORM function computes all blocks of the orbit response matrix [[Rxx, Rxy], [Ryx, Ryy]]
# It might be useful to compute the blocks separately
# This can be done directly with the orbit function as follows

def Rxx(knks):
    """
    Compute the Rxx block: derivative of qx (first coordinate of the points returned by orbit) with respect to cx.

    Parameters
    ----------
    knks: Tensor
        1D tensor, first half is passed as the 'kn' deviation group and second half as the 'ks' group

    Returns
    -------
    Tensor
        jacobian of qx with respect to the 'cx' deviation group, evaluated at zero cx

    """
    kn, ks = knks.reshape(1 + 1, -1)
    count = ring.describe['Corrector']
    # dtype and device are taken from ring (the lattice passed to orbit)
    # The global name line is a leftover loop variable from other cells and should not be relied upon here
    cx = torch.zeros(count, dtype=ring.dtype, device=ring.device)
    def observe(cx):
        points, _ = orbit(ring, fp, [cx, kn, ks], ('cx', ['Corrector'], None, None), ('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None), advance=True, full=False, alignment=False, limit=1, epsilon=None)
        qx, _, _, _ = points.T
        return qx
    return torch.func.jacrev(observe)(cx)


def Rxy(knks):
    """
    Compute the Rxy block: derivative of qx (first coordinate of the points returned by orbit) with respect to cy.

    Parameters
    ----------
    knks: Tensor
        1D tensor, first half is passed as the 'kn' deviation group and second half as the 'ks' group

    Returns
    -------
    Tensor
        jacobian of qx with respect to the 'cy' deviation group, evaluated at zero cy

    """
    kn, ks = knks.reshape(1 + 1, -1)
    count = ring.describe['Corrector']
    # dtype and device are taken from ring (the lattice passed to orbit)
    # The global name line is a leftover loop variable from other cells and should not be relied upon here
    cy = torch.zeros(count, dtype=ring.dtype, device=ring.device)
    def observe(cy):
        points, _ = orbit(ring, fp, [cy, kn, ks], ('cy', ['Corrector'], None, None), ('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None), advance=True, full=False, alignment=False, limit=1, epsilon=None)
        qx, _, _, _ = points.T
        return qx
    return torch.func.jacrev(observe)(cy)


def Ryx(knks):
    """
    Compute the Ryx block: derivative of qy (third coordinate of the points returned by orbit) with respect to cx.

    Parameters
    ----------
    knks: Tensor
        1D tensor, first half is passed as the 'kn' deviation group and second half as the 'ks' group

    Returns
    -------
    Tensor
        jacobian of qy with respect to the 'cx' deviation group, evaluated at zero cx

    """
    kn, ks = knks.reshape(1 + 1, -1)
    count = ring.describe['Corrector']
    # dtype and device are taken from ring (the lattice passed to orbit)
    # The global name line is a leftover loop variable from other cells and should not be relied upon here
    cx = torch.zeros(count, dtype=ring.dtype, device=ring.device)
    # The differentiated argument must be the one that is passed to orbit
    # If the inner function ignores its argument and uses the outer tensor, the jacobian is identically zero
    def observe(cx):
        points, _ = orbit(ring, fp, [cx, kn, ks], ('cx', ['Corrector'], None, None), ('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None), advance=True, full=False, alignment=False, limit=1, epsilon=None)
        _, _, qy, _ = points.T
        return qy
    return torch.func.jacrev(observe)(cx)


def Ryy(knks):
    """
    Compute the Ryy block: derivative of qy (third coordinate of the points returned by orbit) with respect to cy.

    Parameters
    ----------
    knks: Tensor
        1D tensor, first half is passed as the 'kn' deviation group and second half as the 'ks' group

    Returns
    -------
    Tensor
        jacobian of qy with respect to the 'cy' deviation group, evaluated at zero cy

    """
    kn, ks = knks.reshape(1 + 1, -1)
    count = ring.describe['Corrector']
    # dtype and device are taken from ring (the lattice passed to orbit)
    # The global name line is a leftover loop variable from other cells and should not be relied upon here
    cy = torch.zeros(count, dtype=ring.dtype, device=ring.device)
    def observe(cy):
        points, _ = orbit(ring, fp, [cy, kn, ks], ('cy', ['Corrector'], None, None), ('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None), advance=True, full=False, alignment=False, limit=1, epsilon=None)
        _, _, qy, _ = points.T
        return qy
    return torch.func.jacrev(observe)(cy)
[20]:
# Compute blocks

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)
knks = torch.cat([kn, ks])

rxx = Rxx(knks)
rxy = Rxy(knks)
ryx = Ryx(knks)
ryy = Ryy(knks)

# Assemble [[Rxx, Rxy], [Ryx, Ryy]] and compare with the full ORM
# Note, the lower-left block is ryx (using rxy there leaves Ryx untested)

torch.allclose(orm, torch.vstack([torch.hstack([rxx, rxy]), torch.hstack([ryx, ryy])]))
[20]:
True
[21]:
# Blocks can be used for more targeted correction
# For example, coupling correction can be performed using only rxy and ryx blocks

def objective(kn, ks):
    """
    Flattened off-diagonal ORM blocks (Rxy stacked with Ryx) for given kn and ks deviations.

    """
    packed = torch.cat([kn, ks])
    blocks = [Rxy(packed), Ryx(packed)]
    return torch.stack(blocks).flatten()

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

matrix = torch.func.jacrev(objective, 1)(kn, ks)
vector_error = objective(error_kn, error_ks)

lr = 0.5

for _ in range(8):
    vector_fit = objective(kn, ks)
    dks = - lr*torch.linalg.lstsq(matrix, (vector_fit - vector_error), driver='gelsd').solution
    ks += dks
    print((vector_fit - vector_error).norm())

# Plot fitted parameters

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_ks)), error_ks.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(ks)), +ks.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor(1.0791, dtype=torch.float64)
tensor(0.5400, dtype=torch.float64)
tensor(0.2708, dtype=torch.float64)
tensor(0.1374, dtype=torch.float64)
tensor(0.0729, dtype=torch.float64)
tensor(0.0439, dtype=torch.float64)
tensor(0.0330, dtype=torch.float64)
tensor(0.0296, dtype=torch.float64)
../_images/examples_model-35_21_1.png
[22]:
# Use rxx and ryy blocks to fit kn values

def objective(kn, ks):
    """
    Flattened diagonal ORM blocks (Rxx stacked with Ryy) for given kn and ks deviations.

    """
    packed = torch.cat([kn, ks])
    blocks = [Rxx(packed), Ryy(packed)]
    return torch.stack(blocks).flatten()

matrix = torch.func.jacrev(objective, 0)(kn, ks)
vector_error = objective(error_kn, error_ks)

lr = 0.5

for _ in range(8):
    vector_fit = objective(kn, ks)
    dkn = - lr*torch.linalg.lstsq(matrix, (vector_fit - vector_error), driver='gelsd').solution
    kn += dkn
    print((vector_fit - vector_error).norm())

# Plot fitted parameters

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor(2.7973, dtype=torch.float64)
tensor(1.4148, dtype=torch.float64)
tensor(0.7288, dtype=torch.float64)
tensor(0.3806, dtype=torch.float64)
tensor(0.2011, dtype=torch.float64)
tensor(0.1074, dtype=torch.float64)
tensor(0.0578, dtype=torch.float64)
tensor(0.0314, dtype=torch.float64)
../_images/examples_model-35_22_1.png
[23]:
# ML style correction (full ORM)

# Set parametric ORM

def ORM_knks(kn, ks):
    """
    Model ORM (ring, fixed point fp) with kn and ks passed as quadrupole deviation groups.

    """
    groups = (
        ('kn', ['Quadrupole'], None, None),
        ('ks', ['Quadrupole'], None, None),
    )
    return ORM(ring, fp, [kn, ks], *groups, limit=1, start=0, epsilon=None)

# Setup objective function

def objective(kn, ks):
    """
    Norm of the difference between the parametric model ORM and the measured ORM (orm_error).

    """
    residual = ORM_knks(kn, ks) - orm_error
    return residual.norm()

# Set initial values

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

# Test objective function

print(objective(error_kn, error_ks))

# Setup normalized objective

objective = normalize(objective, [(-0.5, 0.5), (-0.5, 0.5)])

# Test normalized objective

print(objective(*forward([error_kn, error_ks], [(-0.5, 0.5), (-0.5, 0.5)])))

# Normalize initial corrector settings

kn, ks, *_ = forward([kn, ks], [(-0.5, 0.5), (-0.5, 0.5)])

# Set model (forward returns evaluated objective)

model = Wrapper(objective, kn, ks)

# Set optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

# Perform optimization

for epoch in range(64):
    value = model()
    value.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(value.detach())
tensor(0., dtype=torch.float64)
tensor(2.6115e-13, dtype=torch.float64)
tensor(3.1476, dtype=torch.float64)
tensor(2.4511, dtype=torch.float64)
tensor(2.1673, dtype=torch.float64)
tensor(1.6036, dtype=torch.float64)
tensor(1.2814, dtype=torch.float64)
tensor(1.3156, dtype=torch.float64)
tensor(1.2044, dtype=torch.float64)
tensor(1.1573, dtype=torch.float64)
tensor(1.2635, dtype=torch.float64)
tensor(1.2003, dtype=torch.float64)
tensor(0.9719, dtype=torch.float64)
tensor(0.8235, dtype=torch.float64)
tensor(0.7653, dtype=torch.float64)
tensor(0.5897, dtype=torch.float64)
tensor(0.5532, dtype=torch.float64)
tensor(0.6488, dtype=torch.float64)
tensor(0.6361, dtype=torch.float64)
tensor(0.6368, dtype=torch.float64)
tensor(0.5782, dtype=torch.float64)
tensor(0.4452, dtype=torch.float64)
tensor(0.4537, dtype=torch.float64)
tensor(0.4271, dtype=torch.float64)
tensor(0.4672, dtype=torch.float64)
tensor(0.4747, dtype=torch.float64)
tensor(0.4334, dtype=torch.float64)
tensor(0.3978, dtype=torch.float64)
tensor(0.2903, dtype=torch.float64)
tensor(0.3264, dtype=torch.float64)
tensor(0.3193, dtype=torch.float64)
tensor(0.3527, dtype=torch.float64)
tensor(0.2962, dtype=torch.float64)
tensor(0.2854, dtype=torch.float64)
tensor(0.2589, dtype=torch.float64)
tensor(0.3143, dtype=torch.float64)
tensor(0.2854, dtype=torch.float64)
tensor(0.2890, dtype=torch.float64)
tensor(0.2751, dtype=torch.float64)
tensor(0.2480, dtype=torch.float64)
tensor(0.2566, dtype=torch.float64)
tensor(0.2515, dtype=torch.float64)
tensor(0.2427, dtype=torch.float64)
tensor(0.2202, dtype=torch.float64)
tensor(0.2097, dtype=torch.float64)
tensor(0.2235, dtype=torch.float64)
tensor(0.2249, dtype=torch.float64)
tensor(0.1987, dtype=torch.float64)
tensor(0.1980, dtype=torch.float64)
tensor(0.2068, dtype=torch.float64)
tensor(0.1977, dtype=torch.float64)
tensor(0.1842, dtype=torch.float64)
tensor(0.1782, dtype=torch.float64)
tensor(0.1792, dtype=torch.float64)
tensor(0.1709, dtype=torch.float64)
tensor(0.1703, dtype=torch.float64)
tensor(0.1781, dtype=torch.float64)
tensor(0.1855, dtype=torch.float64)
tensor(0.1845, dtype=torch.float64)
tensor(0.1834, dtype=torch.float64)
tensor(0.1941, dtype=torch.float64)
tensor(0.1998, dtype=torch.float64)
tensor(0.1804, dtype=torch.float64)
tensor(0.1682, dtype=torch.float64)
tensor(0.1681, dtype=torch.float64)
tensor(0.1829, dtype=torch.float64)
[24]:
# Apply corrections

# Map fitted parameters back with the same intervals that were passed to forward
# NOTE(review): kn and ks are rebound here, re-running this cell alone would apply inverse to already converted values

kn, ks = inverse([kn, ks], [(-0.5, 0.5), (-0.5, 0.5)])

# Corrections are applied to a copy of the perturbed lattice, error itself is left unchanged

lattice:Line = error.clone()

# index counts runs of consecutive quadrupole elements that share one name (1-based)
# label holds the name of the most recently visited quadrupole element

index = 0
label = ''

# Note, the loop variable line stays bound at module scope after the loop

for line in lattice.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            # A name change marks the start of the next run, i.e. the next fitted value
            if label != element.name:
                index +=1
            label = element.name
            # Fitted deviations are subtracted from the current element settings
            element.kn = (element.kn - kn[index - 1]).item()
            element.ks = (element.ks - ks[index - 1]).item()
[25]:
# Compare ORM with model before and after correction

orm_lattice = ORM(lattice, fp, [], limit=1, start=0, epsilon=None)

data = (orm - orm_error).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()

data = (orm - orm_lattice).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
../_images/examples_model-35_25_0.png
../_images/examples_model-35_25_1.png
[26]:
# Compare twiss parameters with model before and after correction

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)
(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)
(nux_lattice, nuy_lattice), (ax_lattice, bx_lattice, ay_lattice, by_lattice) = task(lattice)

# Before

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()

# After

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_lattice)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_lattice)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
../_images/examples_model-35_26_0.png
../_images/examples_model-35_26_1.png
[27]:
# ORM_IJ function can be used to compute derivatives of the closed orbit at a specified BPM (I) with respect to a specified corrector (J)
# In this case the lattice should be flat
# Is and Js should match BPM and Corrector locations (this is in fact a requirement for correctors)

# Set flat line

line = error.clone()
line.flatten()

# Set locations
# Note, index method returns a tensor, but ints are expected by ORM_IJ

_, Is = line.index('BPM')
_, Js = line.index('Corrector')

# Compute ORM_IJ

nb = line.describe['BPM']
nc = line.describe['Corrector']

index_i = 5
index_j = 10

I = Is[index_i]
J = Js[index_j]

print(ORM_IJ(line, fp, I.item(), J.item(), [], limit=1, epsilon=None))
print()

print(orm_error[index_i, index_j])           # d(qx)/d(cx)
print(orm_error[index_i, index_j + nc])      # d(qx)/d(cy)
print(orm_error[index_i + nb, index_j])      # d(qy)/d(cx)
print(orm_error[index_i + nb, index_j + nc]) # d(qy)/d(cy)
tensor([[ 0.4677,  0.0016],
        [-0.0080,  0.3623]], dtype=torch.float64)

tensor(0.4677, dtype=torch.float64)
tensor(0.0016, dtype=torch.float64)
tensor(-0.0080, dtype=torch.float64)
tensor(0.3623, dtype=torch.float64)
[28]:
# ML style correction (batched)

# Set parametric ORM
# The flattened model is bound to a dedicated name
# Other cells rebind the generic name line (loop variable), which would silently change the lattice seen by ORM_knks

flat = ring.clone()
flat.flatten()

nb = flat.describe['BPM']
nc = flat.describe['Corrector']

_, Is = flat.index('BPM')
_, Js = flat.index('Corrector')

# Set batched function
# This function is too slow to be of practical interest
# As is the case for orbit correction, evaluating the full ORM and selecting specific elements is faster
# Here, only several iterations are performed for demonstration

def ORM_knks(IJs, kn, ks):
    """
    Evaluate ORM_IJ for each (I, J) location pair in a batch and stack the results.

    Parameters
    ----------
    IJs: Tensor
        batch of (I, J) location pairs (BPM location, corrector location)
    kn, ks: Tensor
        quadrupole deviation groups

    Returns
    -------
    Tensor
        stacked ORM_IJ outputs, one per location pair

    """
    return torch.stack([
        ORM_IJ(flat, fp, *IJ.tolist(), [kn, ks], ('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None), limit=1, epsilon=None)
        for IJ in IJs
    ])

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

# Normalize objective

ORM_knks = normalize(ORM_knks, [(None, None), (-0.5, 0.5), (-0.5, 0.5)])

# Normalize initial corrector settings

kn, ks, *_ = forward([kn, ks], [(-0.5, 0.5), (-0.5, 0.5)])

# Set model

model = Wrapper(ORM_knks, kn, ks)

# Set optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)

# Set features and labels
# Each feature is a pair of (BPM, corrector) locations
# Each label is the corresponding 2x2 block [[d(qx)/d(cx), d(qx)/d(cy)], [d(qy)/d(cx), d(qy)/d(cy)]] of the measured ORM

X = []
y = []
for i in range(nb):
    for j in range(nc):
        X.append(torch.stack([Is[i], Js[j]]))
        y.append(torch.stack([orm_error[i, j], orm_error[i, j + nc], orm_error[i + nb, j], orm_error[i + nb, j + nc]]).reshape(2, 2))
X = torch.stack(X)
y = torch.stack(y)

# Set dataset

batch_size = 16
dataset = TensorDataset(X.clone(), y.clone())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Set loss function

lf = torch.nn.MSELoss()

# Perform optimization
# Batch tensors use their own names, so that the full X and y are not overwritten by the last batch

for epoch in range(1):
    for batch, (X_batch, y_batch) in enumerate(dataloader):
        y_hat = model(X_batch)
        value = lf(y_hat, y_batch)
        value.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(epoch, batch, len(dataloader) - 1, value.detach())
0 0 35 tensor(0.0038, dtype=torch.float64)
0 1 35 tensor(0.0035, dtype=torch.float64)
0 2 35 tensor(0.0020, dtype=torch.float64)
0 3 35 tensor(0.0017, dtype=torch.float64)
0 4 35 tensor(0.0009, dtype=torch.float64)
0 5 35 tensor(0.0008, dtype=torch.float64)
0 6 35 tensor(0.0008, dtype=torch.float64)
0 7 35 tensor(0.0010, dtype=torch.float64)
0 8 35 tensor(0.0015, dtype=torch.float64)
0 9 35 tensor(0.0006, dtype=torch.float64)
0 10 35 tensor(0.0007, dtype=torch.float64)
0 11 35 tensor(0.0009, dtype=torch.float64)
0 12 35 tensor(0.0009, dtype=torch.float64)
0 13 35 tensor(0.0007, dtype=torch.float64)
0 14 35 tensor(0.0007, dtype=torch.float64)
0 15 35 tensor(0.0007, dtype=torch.float64)
0 16 35 tensor(0.0006, dtype=torch.float64)
0 17 35 tensor(0.0003, dtype=torch.float64)
0 18 35 tensor(0.0006, dtype=torch.float64)
0 19 35 tensor(0.0004, dtype=torch.float64)
0 20 35 tensor(0.0005, dtype=torch.float64)
0 21 35 tensor(0.0004, dtype=torch.float64)
0 22 35 tensor(0.0002, dtype=torch.float64)
0 23 35 tensor(0.0003, dtype=torch.float64)
0 24 35 tensor(0.0003, dtype=torch.float64)
0 25 35 tensor(0.0003, dtype=torch.float64)
0 26 35 tensor(0.0002, dtype=torch.float64)
0 27 35 tensor(0.0001, dtype=torch.float64)
0 28 35 tensor(0.0003, dtype=torch.float64)
0 29 35 tensor(0.0003, dtype=torch.float64)
0 30 35 tensor(0.0001, dtype=torch.float64)
0 31 35 tensor(0.0003, dtype=torch.float64)
0 32 35 tensor(0.0002, dtype=torch.float64)
0 33 35 tensor(0.0001, dtype=torch.float64)
0 34 35 tensor(0.0001, dtype=torch.float64)
0 35 35 tensor(0.0002, dtype=torch.float64)
[29]:
# Apply corrections

# Map fitted parameters back with the same intervals that were passed to forward
# NOTE(review): kn and ks are rebound here, re-running this cell alone would apply inverse to already converted values

kn, ks = inverse([kn, ks], [(-0.5, 0.5), (-0.5, 0.5)])

# Corrections are applied to a copy of the perturbed lattice, error itself is left unchanged

lattice:Line = error.clone()

# index counts runs of consecutive quadrupole elements that share one name (1-based)
# label holds the name of the most recently visited quadrupole element

index = 0
label = ''

# Note, the loop variable line is rebound at module scope by this loop

for line in lattice.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            # A name change marks the start of the next run, i.e. the next fitted value
            if label != element.name:
                index +=1
            label = element.name
            # Fitted deviations are subtracted from the current element settings
            element.kn = (element.kn - kn[index - 1]).item()
            element.ks = (element.ks - ks[index - 1]).item()
[30]:
# Compare ORM with model before and after correction

orm_lattice = ORM(lattice, fp, [], limit=1, start=0, epsilon=None)

data = (orm - orm_error).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()

data = (orm - orm_lattice).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
../_images/examples_model-35_30_0.png
../_images/examples_model-35_30_1.png
[31]:
# Compare twiss parameters with model before and after correction

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)
(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)
(nux_lattice, nuy_lattice), (ax_lattice, bx_lattice, ay_lattice, by_lattice) = task(lattice)

# Before

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()

# After

plt.figure(figsize=(16, 2))
plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_lattice)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_lattice)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
../_images/examples_model-35_31_0.png
../_images/examples_model-35_31_1.png
[32]:
# Set parametric ORM

def ORM_knks(kn, ks):
    """
    Evaluate ORM of the ring for given kn and ks values.

    Thin wrapper around ORM: ring and fp are taken from the notebook scope,
    [kn, ks] is passed as the list of knob values, followed by one group
    specification per knob (presumably selecting 'kn' and 'ks' of all
    elements of kind 'Quadrupole' -- confirm against ORM documentation).

    Parameters
    ----------
    kn, ks
        knob values forwarded to ORM as [kn, ks]

    Returns
    -------
    result of ORM(...) called with limit=1, start=0 and epsilon=None

    """
    groups = (
        ('kn', ['Quadrupole'], None, None),
        ('ks', ['Quadrupole'], None, None),
    )
    return ORM(ring, fp, [kn, ks], *groups, limit=1, start=0, epsilon=None)
[33]:
# Convert data to numpy
# X is the response with trailing dimensions matching knks.shape (all leading dimensions are merged)
# y is the flattened difference between orm_error and orm

X = response.reshape(-1, *knks.shape)
X = X.cpu().numpy()

y = orm_error - orm
y = y.reshape(-1).cpu().numpy()
[34]:
# Lasso

lasso = Lasso(alpha=0.0005)
solution = lasso.fit(X, y).coef_

# Fitted coefficients are split into two equal parts (kn first, ks second)

kn_out, ks_out = torch.tensor(solution, dtype=torch.float64).reshape(2, -1)

# Check error (without and with fitted values)

for factor in (0.0, 1.0):
    print((orm_error - ORM_knks(factor*kn_out, factor*ks_out)).norm())

# Plot kn and ks (red for errors, blue for fitted values)

for reference, fitted in ((error_kn, kn_out), (error_ks, ks_out)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(reference)), reference.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(fitted)), fitted.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
tensor(3.1476, dtype=torch.float64)
tensor(0.6863, dtype=torch.float64)
../_images/examples_model-35_34_1.png
../_images/examples_model-35_34_2.png
[35]:
# OMP

omp = OrthogonalMatchingPursuit(n_nonzero_coefs=16)
solution = omp.fit(X, y).coef_

# Fitted coefficients are split into two equal parts (kn first, ks second)

kn_out, ks_out = torch.tensor(solution, dtype=torch.float64).reshape(2, -1)

# Check error (without and with fitted values)

for factor in (0.0, 1.0):
    print((orm_error - ORM_knks(factor*kn_out, factor*ks_out)).norm())

# Plot kn and ks (red for errors, blue for fitted values)

for reference, fitted in ((error_kn, kn_out), (error_ks, ks_out)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(reference)), reference.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(fitted)), fitted.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
tensor(3.1476, dtype=torch.float64)
tensor(0.4987, dtype=torch.float64)
../_images/examples_model-35_35_1.png
../_images/examples_model-35_35_2.png
[36]:
# Lars

lars = Lars(n_nonzero_coefs=16)
solution = lars.fit(X, y).coef_

# Fitted coefficients are split into two equal parts (kn first, ks second)

kn_out, ks_out = torch.tensor(solution, dtype=torch.float64).reshape(2, -1)

# Check error (without and with fitted values)

for factor in (0.0, 1.0):
    print((orm_error - ORM_knks(factor*kn_out, factor*ks_out)).norm())

# Plot kn and ks (red for errors, blue for fitted values)

for reference, fitted in ((error_kn, kn_out), (error_ks, ks_out)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(reference)), reference.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(fitted)), fitted.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
tensor(3.1476, dtype=torch.float64)
tensor(1.8210, dtype=torch.float64)
../_images/examples_model-35_36_1.png
../_images/examples_model-35_36_2.png