Example-35: Orbit (closed orbit correction)
[1]:
# In this example, closed orbit correction is illustrated
# Normally, global orbit correction is performed using a model or measured orbit response matrix (ORM)
# In this case, accelerator corrector settings are actively altered to correct the closed orbit to a specific target orbit
# Usually, SVD-based inversion is used to solve the linear system in this case
# Here, SVD (pseudo-inverse) and lstsq are used to correct the observed model to the design model
# Another option would be to fit the design model to reproduce the observed orbit distortion
# Having obtained corrector settings that reproduce the observed behaviour, they can be applied to the observed model with flipped signs
# A new observation can then be performed, followed by the next correction iteration if necessary
# Again, global SVD-style correction can be performed in this case
# We also perform an ML-style optimization loop, including a mini-batch version
[2]:
# Import
from pprint import pprint
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from pathlib import Path
from matplotlib import pyplot as plt
from twiss import twiss
from model.library.corrector import Corrector
from model.library.line import Line
from model.command.util import chop
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper
from model.command.orbit import orbit
from model.command.orbit import ORM
[3]:
# Load the ELEGANT twiss (SDDS) output and keep only the fractional parts of the reference tunes
path = Path('ic.twiss')
parameters, columns = load_sdds(path)
nu_qx: float
nu_qy: float
nu_qx, nu_qy = (parameters[key] % 1 for key in ('nux', 'nuy'))
[4]:
# Build and set up the lattice
# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts
# A combined corrector (<name>_CXY) is pasted between consecutive parts
path = Path('ic.lte')
data = load_lattice(path)
ring: Line = build('RING', 'ELEGANT', data)
ring.propagate = True
ring.flatten()
ring.merge()
ring.split((None, ['BPM'], None, None))
ring.roll(1)
n_q = 2**2
n_d = 2**4
# Quadrupoles are processed first, then dipoles; the dipole names are collected only after
# all quadrupole splits are done, exactly as in a sequential pair of loops
for kind, count in (('Quadrupole', n_q), ('Dipole', n_d)):
    names = [name for name, element_kind, *_ in ring.layout() if element_kind == kind]
    for name in names:
        # NOTE(review): factor 1/(count - 1) presumably spreads one corrector setting
        # over the count - 1 inserted correctors of this element — confirm
        corrector = Corrector(f'{name}_CXY', factor=1/(count - 1))
        ring.split((count, None, [name], None), paste=[corrector])
# Dipoles use their linear model
for element in ring:
    if type(element).__name__ == 'Dipole':
        element.linear = True
ring.splice()
[5]:
# Compare linear tunes: ELEGANT reference minus tunes computed by twiss()
# from the Jacobian of the ring map at the zero state
state = torch.zeros(4, dtype=torch.float64)
jacobian = torch.func.jacrev(ring)
matrix = jacobian(state)
(nuqx, nuqy), *_ = twiss(matrix)
print(nu_qx - nuqx)
print(nu_qy - nuqy)
tensor(1.4433e-15, dtype=torch.float64)
tensor(-9.9920e-16, dtype=torch.float64)
[6]:
# Compute the closed orbit starting from a small random initial guess
guess = 1.0E-3*torch.randn(4, dtype=torch.float64)
fp, *_ = orbit(ring, guess, [], alignment=False, limit=8, epsilon=1.0E-12)
# chop() operates on a list of tensors: wrap the fixed point, chop small values, unwrap
wrapped = [fp]
chop(wrapped)
fp = wrapped[0]
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute the orbit response matrix (ORM) at the closed orbit and display it as a heat map
# NOTE(review): the correction cells below split pinv(orm) @ [qx, qy] into (cx, cy), so rows
# presumably hold BPM readings (x then y) and columns corrector settings (cx then cy) — confirm
orm = ORM(ring, fp, [], limit=1, start=0, epsilon=None)
print(orm.shape)
# Plot a copy with zero entries masked as NaN; use a dedicated name so the lattice `data`
# loaded earlier is not overwritten by an unrelated tensor
orm_image = orm.clone()
orm_image[orm_image == 0.0] = torch.nan
fig, ax = plt.subplots(figsize=(34/4, 72/4))
img = ax.imshow(orm_image.cpu().numpy(), cmap='magma', interpolation='nearest')
# Colorbar in a thin axis placed just to the right of the image
box = ax.get_position()
cax = fig.add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
fig.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
[8]:
# Usually, given the goal and observed orbit, corrector settings can be adjusted to reproduce the goal orbit
# Set number of correctors
nc = ring.describe['Corrector']
# Set random corrector errors; the first half of the correctors is left error-free
error_cx = 100.0E-6*torch.randn(nc, dtype=torch.float64)
error_cy = 100.0E-6*torch.randn(nc, dtype=torch.float64)
error_cx[:nc//2] = 0.0
error_cy[:nc//2] = 0.0
# Set number of BPMs
nb = ring.describe['BPM']
# Target orbit: zero at all BPMs
qx_target = torch.zeros(nb, dtype=torch.float64)
qy_target = torch.zeros(nb, dtype=torch.float64)
# Correction loop (SVD based): each step solves ORM @ dc = -(q - q_target) via the pseudo-inverse
# and applies a damped fraction lr of the solution
lr = 0.75
cx = torch.zeros_like(error_cx)
cy = torch.zeros_like(error_cy)
# The ORM does not change inside the loop, so its pseudo-inverse is computed once
orm_pinv = torch.linalg.pinv(orm)
for _ in range(8):
    # Observed orbit: errors plus current correction, each plane with its own error vector
    points, *_ = orbit(ring, fp, [cx + error_cx, cy + error_cy],
                       ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                       alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
    qx, _, qy, _ = points.T
    residual = torch.cat([qx - qx_target, qy - qy_target])
    dcx, dcy = -lr*(orm_pinv @ residual).reshape(2, -1)
    cx += dcx
    cy += dcy
    # Residual norm measured before this step's correction is applied
    print(residual.norm())
# Plot final corrector settings: negated errors (red) vs accumulated correction (blue)
for _error, _correction in ((error_cx, cx), (error_cy, cy)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(_error)), -_error.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(_correction)), +_correction.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
print(cx.norm())
print(cy.norm())
tensor(0.0034, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(5.3507e-05, dtype=torch.float64)
tensor(1.3377e-05, dtype=torch.float64)
tensor(3.3442e-06, dtype=torch.float64)
tensor(8.3604e-07, dtype=torch.float64)
tensor(2.0901e-07, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
[9]:
# Same as above, but using lstsq instead of an explicit pseudo-inverse computation
lr = 0.75
cx = torch.zeros_like(error_cx)
cy = torch.zeros_like(error_cy)
for _ in range(8):
    # Observed orbit: errors plus current correction, each plane with its own error vector
    points, *_ = orbit(ring, fp, [cx + error_cx, cy + error_cy],
                       ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                       alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
    qx, _, qy, _ = points.T
    residual = torch.cat([qx - qx_target, qy - qy_target])
    # Solve ORM @ dc = residual and apply a damped (lr) step in the opposite direction
    dcx, dcy = -lr*torch.linalg.lstsq(orm, residual, driver='gelsd').solution.reshape(2, -1)
    cx += dcx
    cy += dcy
    # Residual norm measured before this step's correction is applied
    print(residual.norm())
# Plot final corrector settings: negated errors (red) vs accumulated correction (blue)
for _error, _correction in ((error_cx, cx), (error_cy, cy)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(_error)), -_error.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(_correction)), +_correction.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
print(cx.norm())
print(cy.norm())
tensor(0.0034, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(5.3507e-05, dtype=torch.float64)
tensor(1.3377e-05, dtype=torch.float64)
tensor(3.3442e-06, dtype=torch.float64)
tensor(8.3604e-07, dtype=torch.float64)
tensor(2.0901e-07, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
[10]:
# Assume the lattice has accidental non-zero corrector settings, which result in a closed orbit distortion
# Given a measured orbit at all BPMs, the goal is to adjust the model to match the observed orbit (fit model to experiment)
# Such an arrangement of corrector settings in the model is not necessarily unique
nc = ring.describe['Corrector']
# Random corrector errors (x first, then y); the first half of the correctors is left error-free
error_cx, error_cy = (100.0E-6*torch.randn(nc, dtype=torch.float64) for _ in range(2))
error_cx[:nc//2] = error_cy[:nc//2] = 0.0
# Closed orbit produced by the errors
points, *_ = orbit(ring, fp, [error_cx, error_cy],
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
# Transport map from the first to the last element, parametrized by the corrector settings
start, *_, end = ring.names
mapping, *_ = group(ring, start, end, ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False)
# Consistency check: propagating the first orbit point through the map must return it unchanged
point, *_ = points
print(point)
print(mapping(point, error_cx, error_cy))
print(torch.allclose(point, mapping(point, error_cx, error_cy), rtol=1.0E-12, atol=1.0E-12))
print()
# The observed orbit becomes the fit target
qx_target, _, qy_target, _ = points.T
# Plot observed orbit (x, then y)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _observed in (qx_target, qy_target):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _observed.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
tensor([-4.2773e-04, 9.6366e-04, 4.1688e-06, -7.8074e-05],
dtype=torch.float64)
tensor([-4.2773e-04, 9.6366e-04, 4.1688e-06, -7.8074e-05],
dtype=torch.float64)
True
[11]:
# Correct model to experiment using the ORM and an iterative lstsq fit:
# model corrector settings are adjusted until the model orbit matches the observed orbit
lr = 0.75
cx, cy = torch.zeros_like(error_cx), torch.zeros_like(error_cy)
for _ in range(8):
    points, *_ = orbit(ring, fp, [cx, cy],
                       ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                       alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
    qx, _, qy, _ = points.T
    residual = torch.stack([qx - qx_target, qy - qy_target]).flatten()
    step = torch.linalg.lstsq(orm, residual, driver='gelsd').solution
    dcx, dcy = -lr*step.reshape(2, -1)
    cx += dcx
    cy += dcy
    print(residual.norm())
# Plot final corrector settings: true errors (red) vs fitted model settings (blue)
for _error, _fit in ((error_cx, cx), (error_cy, cy)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(_error)), _error.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(_fit)), +_fit.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
# Plot orbits: observed (blue circles) vs fitted model (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _observed, _model in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _observed.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _model.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
tensor(0.0021, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(3.2555e-05, dtype=torch.float64)
tensor(8.1539e-06, dtype=torch.float64)
tensor(2.0432e-06, dtype=torch.float64)
tensor(5.1223e-07, dtype=torch.float64)
tensor(1.2849e-07, dtype=torch.float64)
[12]:
# Apply the negative of the fitted corrector settings on top of the errors and compute the corrected orbit
settings = [error_cx - cx, error_cy - cy]
points, *_ = orbit(ring, fp, settings,
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx, _, qy, _ = points.T
# Plot corrected orbit: observed before correction (blue circles) vs after correction (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _before, _after in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _before.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _after.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
[13]:
# ML style correction (full orbit)
# Find corrector settings that reproduce the observed orbit; they can then be applied with a negative sign
# (and some weight) and the procedure repeated for a new observation
# Bounds used to normalize the corrector settings (same for both planes)
corrector_bounds = [(-150*1E-6, 150*1E-6), (-150*1E-6, 150*1E-6)]

def qxqy_full(cx, cy):
    """Return the closed orbit at all BPMs for corrector settings (cx, cy).

    The result is torch.stack([qx, qy]): row 0 is qx, row 1 is qy.
    """
    points, _ = orbit(ring, fp, [cx, cy],
                      ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                      advance=True, full=False, alignment=False, limit=8, epsilon=1.0E-6)
    qx, _, qy, _ = points.T
    return torch.stack([qx, qy])

def objective(cx, cy):
    """Euclidean distance between the model orbit for (cx, cy) and the observed orbit (qx_target, qy_target)."""
    qx, qy = qxqy_full(cx, cy)
    return ((qx - qx_target)**2 + (qy - qy_target)**2).sum().sqrt()

# Initial corrector settings (can also be set to random values)
cx = torch.zeros_like(error_cx)
cy = torch.zeros_like(error_cy)
# Sanity check: the true errors reproduce the observed orbit, so the objective should be ~0
print(objective(error_cx, error_cy))
# Objective in normalized corrector coordinates (kept under its own name; `objective` stays physical)
objective_normalized = normalize(objective, corrector_bounds)
print(objective_normalized(*forward([error_cx, error_cy], corrector_bounds)))
# Normalize initial corrector settings
cx, cy, *_ = forward([cx, cy], corrector_bounds)
# Model: calling it evaluates the objective at the current (trainable) settings
model = Wrapper(objective_normalized, cx, cy)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
for epoch in range(128):
    error = model()
    error.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(error.detach())
tensor(1.2776e-17, dtype=torch.float64)
tensor(1.2868e-17, dtype=torch.float64)
tensor(0.0021, dtype=torch.float64)
tensor(0.0022, dtype=torch.float64)
tensor(0.0015, dtype=torch.float64)
tensor(0.0017, dtype=torch.float64)
tensor(0.0017, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0012, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0007, dtype=torch.float64)
tensor(0.0008, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0004, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
[14]:
# Map the optimized (normalized) corrector settings back through inverse() and compute the model orbit
cx_out, cy_out = inverse([cx, cy], [(-150*1E-6, 150*1E-6), (-150*1E-6, 150*1E-6)])
points, *_ = orbit(ring, fp, [cx_out, cy_out],
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx, _, qy, _ = points.T
# Plot final corrector settings: true errors (red) vs fitted settings (blue)
for _error, _normalized, _fit in ((error_cx, cx, cx_out), (error_cy, cy, cy_out)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(_error)), _error.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(_normalized)), +_fit.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
# Plot orbits: observed (blue circles) vs fitted model (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _observed, _model in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _observed.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _model.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
[15]:
# Apply the negative of the fitted corrector settings on top of the errors and compute the corrected orbit
settings = [error_cx - cx_out, error_cy - cy_out]
points, *_ = orbit(ring, fp, settings,
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx, _, qy, _ = points.T
# Plot corrected orbit: observed before correction (blue circles) vs after correction (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _before, _after in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _before.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _after.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
[16]:
# ML style correction (batched orbit)
# Find corrector settings that reproduce the observed orbit; they can then be applied with a negative sign
# (and some weight) and the procedure repeated for a new observation
# Helpers below compute the closed orbit at selected locations (indices `starts`) for given corrector settings

def qxqy_loop(starts, cx, cy):
    """Closed orbit (qx, qy) at each location in `starts`, one orbit search per location.

    The orbit at each location is obtained by changing the lattice start.
    Kept as an alternative (slower) implementation; it is not used below.
    Returns a tensor with one row per start and columns (qx, qy).
    """
    points = []
    for start in starts:
        point, _ = orbit(ring, fp, [cx, cy],
                         ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                         start=int(start), advance=False, full=False, alignment=False, limit=8, epsilon=1.0E-6)
        points.append(point)
    qx, _, qy, _ = torch.stack(points).T
    return torch.stack([qx, qy]).T

def qxqy(starts, cx, cy):
    """Closed orbit (qx, qy) at the BPMs selected by `starts`.

    It is faster to compute the whole orbit once and select the requested rows.
    Returns a tensor with one row per start and columns (qx, qy).
    """
    points, _ = orbit(ring, fp, [cx, cy],
                      ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                      advance=True, full=False, alignment=False, limit=8, epsilon=1.0E-6)
    qx, _, qy, _ = points.T
    return torch.stack([qx, qy]).T[starts]

# Normalization bounds: corrector settings in both planes; the BPM index argument gets (None, None)
# (presumably meaning "not normalized" — confirm against normalize/forward)
corrector_bounds = [(-150*1E-6, 150*1E-6), (-150*1E-6, 150*1E-6)]
bounds = [(None, None), *corrector_bounds]
# Initial corrector settings (can also be set to random values)
cx = torch.zeros_like(error_cx)
cy = torch.zeros_like(error_cy)
# Test the orbit function at the first two BPMs with the true errors
print(qxqy([0, 1], error_cx, error_cy))
# Normalized version kept under its own name (the plain `qxqy` stays available)
qxqy_normalized = normalize(qxqy, bounds)
print(qxqy_normalized(*forward([[0, 1], error_cx, error_cy], bounds)))
# Normalize initial corrector settings
cx, cy, *_ = forward([cx, cy], corrector_bounds)
# Model: maps BPM indices to the model orbit at the current (trainable) settings
model = Wrapper(qxqy_normalized, cx, cy)
print(model(*forward([[0, 1], error_cx, error_cy], bounds)))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
# Features are BPM indices, labels the observed (qx, qy) at each BPM
features = torch.arange(nb)
labels = torch.stack([qx_target, qy_target]).T
batch_size = 8
dataset = TensorDataset(features.clone(), labels.clone())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Loss function
lf = torch.nn.MSELoss()
# Perform optimization (mini-batches over BPMs); the last batch loss of each epoch is printed
for epoch in range(128):
    for X_batch, y_batch in dataloader:
        # y_hat and y_batch both have shape (batch, 2), so no squeeze is needed
        y_hat = model(X_batch)
        error = lf(y_hat, y_batch)
        error.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(error.detach())
tensor([[-4.2773e-04, 4.1688e-06],
[ 6.2837e-04, -1.2416e-04]], dtype=torch.float64)
tensor([[-4.2773e-04, 4.1688e-06],
[ 6.2837e-04, -1.2416e-04]], dtype=torch.float64)
tensor([[-4.2773e-04, 4.1688e-06],
[ 6.2837e-04, -1.2416e-04]], dtype=torch.float64)
tensor(1.0103e-07, dtype=torch.float64)
tensor(8.5893e-08, dtype=torch.float64)
tensor(6.9922e-08, dtype=torch.float64)
tensor(2.9579e-08, dtype=torch.float64)
tensor(4.0506e-08, dtype=torch.float64)
tensor(2.0698e-08, dtype=torch.float64)
tensor(8.7432e-09, dtype=torch.float64)
tensor(1.6769e-08, dtype=torch.float64)
tensor(1.4832e-08, dtype=torch.float64)
tensor(1.5364e-08, dtype=torch.float64)
tensor(8.4098e-09, dtype=torch.float64)
tensor(1.6792e-08, dtype=torch.float64)
tensor(7.4891e-09, dtype=torch.float64)
tensor(1.0057e-08, dtype=torch.float64)
tensor(1.0049e-08, dtype=torch.float64)
tensor(6.9731e-09, dtype=torch.float64)
tensor(4.2305e-09, dtype=torch.float64)
tensor(3.9217e-09, dtype=torch.float64)
tensor(5.4500e-09, dtype=torch.float64)
tensor(3.3462e-09, dtype=torch.float64)
tensor(3.1018e-09, dtype=torch.float64)
tensor(2.2604e-09, dtype=torch.float64)
tensor(4.6111e-09, dtype=torch.float64)
tensor(2.6192e-09, dtype=torch.float64)
tensor(2.6802e-09, dtype=torch.float64)
tensor(1.7728e-09, dtype=torch.float64)
tensor(1.7442e-09, dtype=torch.float64)
tensor(1.6963e-09, dtype=torch.float64)
tensor(1.8014e-09, dtype=torch.float64)
tensor(1.2343e-09, dtype=torch.float64)
tensor(7.6003e-10, dtype=torch.float64)
tensor(1.2165e-09, dtype=torch.float64)
tensor(1.3753e-09, dtype=torch.float64)
tensor(1.9674e-09, dtype=torch.float64)
tensor(8.7657e-10, dtype=torch.float64)
tensor(1.5341e-09, dtype=torch.float64)
tensor(1.3392e-09, dtype=torch.float64)
tensor(9.5075e-10, dtype=torch.float64)
tensor(5.1002e-10, dtype=torch.float64)
tensor(1.0587e-09, dtype=torch.float64)
tensor(9.2140e-10, dtype=torch.float64)
tensor(7.1149e-10, dtype=torch.float64)
tensor(5.9892e-10, dtype=torch.float64)
tensor(7.4379e-10, dtype=torch.float64)
tensor(9.8351e-10, dtype=torch.float64)
tensor(6.6518e-10, dtype=torch.float64)
tensor(1.0349e-09, dtype=torch.float64)
tensor(6.2190e-10, dtype=torch.float64)
tensor(9.9667e-10, dtype=torch.float64)
tensor(8.0508e-10, dtype=torch.float64)
tensor(8.9738e-10, dtype=torch.float64)
tensor(1.3208e-09, dtype=torch.float64)
tensor(5.8232e-10, dtype=torch.float64)
tensor(1.2484e-09, dtype=torch.float64)
tensor(1.0029e-09, dtype=torch.float64)
tensor(7.3745e-10, dtype=torch.float64)
tensor(6.6399e-10, dtype=torch.float64)
tensor(9.1979e-10, dtype=torch.float64)
tensor(6.0548e-10, dtype=torch.float64)
tensor(5.7826e-10, dtype=torch.float64)
tensor(4.2145e-10, dtype=torch.float64)
tensor(4.2077e-10, dtype=torch.float64)
tensor(1.8459e-10, dtype=torch.float64)
tensor(8.9805e-10, dtype=torch.float64)
tensor(4.4755e-10, dtype=torch.float64)
tensor(2.7856e-10, dtype=torch.float64)
tensor(4.0177e-10, dtype=torch.float64)
tensor(2.9036e-10, dtype=torch.float64)
tensor(4.6847e-10, dtype=torch.float64)
tensor(5.9662e-10, dtype=torch.float64)
tensor(3.4559e-10, dtype=torch.float64)
tensor(4.5530e-10, dtype=torch.float64)
tensor(4.6208e-10, dtype=torch.float64)
tensor(2.5914e-10, dtype=torch.float64)
tensor(3.8662e-10, dtype=torch.float64)
tensor(2.9855e-10, dtype=torch.float64)
tensor(3.1526e-10, dtype=torch.float64)
tensor(3.8357e-10, dtype=torch.float64)
tensor(2.9299e-10, dtype=torch.float64)
tensor(3.2051e-10, dtype=torch.float64)
tensor(2.0474e-10, dtype=torch.float64)
tensor(3.7196e-10, dtype=torch.float64)
tensor(2.8321e-10, dtype=torch.float64)
tensor(2.9790e-10, dtype=torch.float64)
tensor(1.6840e-10, dtype=torch.float64)
tensor(1.4103e-10, dtype=torch.float64)
tensor(2.0365e-10, dtype=torch.float64)
tensor(1.5910e-10, dtype=torch.float64)
tensor(3.1184e-10, dtype=torch.float64)
tensor(1.7404e-10, dtype=torch.float64)
tensor(1.7092e-10, dtype=torch.float64)
tensor(2.6387e-10, dtype=torch.float64)
tensor(1.6192e-10, dtype=torch.float64)
tensor(2.7748e-10, dtype=torch.float64)
tensor(2.0869e-10, dtype=torch.float64)
tensor(2.7848e-10, dtype=torch.float64)
tensor(2.7529e-10, dtype=torch.float64)
tensor(2.9607e-10, dtype=torch.float64)
tensor(2.7802e-10, dtype=torch.float64)
tensor(2.2078e-10, dtype=torch.float64)
tensor(4.1643e-10, dtype=torch.float64)
tensor(1.1816e-10, dtype=torch.float64)
tensor(2.0145e-10, dtype=torch.float64)
tensor(1.7290e-10, dtype=torch.float64)
tensor(2.1903e-10, dtype=torch.float64)
tensor(2.0926e-10, dtype=torch.float64)
tensor(1.6237e-10, dtype=torch.float64)
tensor(2.2521e-10, dtype=torch.float64)
tensor(1.6359e-10, dtype=torch.float64)
tensor(1.4044e-10, dtype=torch.float64)
tensor(2.2042e-10, dtype=torch.float64)
tensor(1.2348e-10, dtype=torch.float64)
tensor(1.5806e-10, dtype=torch.float64)
tensor(1.6219e-10, dtype=torch.float64)
tensor(8.0410e-11, dtype=torch.float64)
tensor(1.3570e-10, dtype=torch.float64)
tensor(1.0485e-10, dtype=torch.float64)
tensor(1.2346e-10, dtype=torch.float64)
tensor(1.3002e-10, dtype=torch.float64)
tensor(8.5194e-11, dtype=torch.float64)
tensor(1.6804e-10, dtype=torch.float64)
tensor(1.2795e-10, dtype=torch.float64)
tensor(1.3855e-10, dtype=torch.float64)
tensor(1.1846e-10, dtype=torch.float64)
tensor(1.4354e-10, dtype=torch.float64)
tensor(1.6040e-10, dtype=torch.float64)
tensor(7.4035e-11, dtype=torch.float64)
tensor(1.0228e-10, dtype=torch.float64)
[17]:
# Map the optimized (normalized) corrector settings back through inverse() and compute the model orbit
cx_out, cy_out = inverse([cx, cy], [(-150*1E-6, 150*1E-6), (-150*1E-6, 150*1E-6)])
points, *_ = orbit(ring, fp, [cx_out, cy_out],
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx, _, qy, _ = points.T
# Plot final corrector settings: true errors (red) vs fitted settings (blue)
for _error, _normalized, _fit in ((error_cx, cx, cx_out), (error_cy, cy, cy_out)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(_error)), _error.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(_normalized)), +_fit.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
# Plot orbits: observed (blue circles) vs fitted model (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _observed, _model in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _observed.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _model.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
[18]:
# Apply the negative of the fitted corrector settings on top of the errors and compute the corrected orbit
settings = [error_cx - cx_out, error_cy - cy_out]
points, *_ = orbit(ring, fp, settings,
                   ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None),
                   alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-6)
qx, _, qy, _ = points.T
# Plot corrected orbit: observed before correction (blue circles) vs after correction (red crosses)
bpm_labels = dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM'])
for _before, _after in ((qx_target, qx), (qy_target, qy)):
    positions = ring.locations()
    plt.figure(figsize=(16, 2))
    plt.errorbar(positions.cpu().numpy(), _before.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
    plt.errorbar(positions.cpu().numpy(), _after.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
    plt.xticks(ticks=positions, labels=bpm_labels)
    plt.tight_layout()
    plt.show()
[19]:
# In most cases, global correction will change a large number of correctors
# If the goal is to change just several correctors,
# the linear system can be solved with Lasso, OrthogonalMatchingPursuit or Lars,
# or with another sparse solver
from sklearn.linear_model import Lasso
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.linear_model import Lars
[20]:
# Build the linear system X @ corrections = y for the sparse solvers below
# X is the response matrix orm, y is the orbit deviation from the target produced by the corrector errors
# The uncorrected orbit is recomputed here on purpose: the previous correction cell overwrites qx and qy
# with the already corrected orbit, so using them directly would fit the small residual instead of the
# deviation caused by error_cx/error_cy (the sparse solutions are compared against error_cx/error_cy
# and applied as error - solution in the following cells)
# NOTE(review): if the earlier observed orbit contained extra effects (e.g. noise), use that orbit here instead
points, *_ = orbit(ring, fp, [error_cx, error_cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-6)
qx_error, _, qy_error, _ = points.T
# Convert data to numpy (rows: horizontal then vertical BPM readings, stacked in the same order as before)
X = orm.cpu().numpy()
y = torch.stack([qx_error - qx_target, qy_error - qy_target]).flatten().cpu().numpy()
[21]:
# For Lasso, the alpha parameter can be scanned to get the desired number of non-zero components
# fit_intercept=False: the system to solve is exactly X @ coef = y (response matrix times corrector settings)
# With the sklearn default (fit_intercept=True) an extra constant term outside this model is fitted,
# and the returned coefficients no longer solve X @ coef = y
lasso = Lasso(alpha=0.0001, fit_intercept=False)
lasso.fit(X, y)
solution = lasso.coef_
# Split the stacked solution into horizontal and vertical corrector settings
cx_out, cy_out = torch.tensor(solution, dtype=torch.float64).reshape(1 + 1, -1)
# Compute orbit for corrector settings error - solution
points, *_ = orbit(ring, fp, [error_cx - cx_out, error_cy - cy_out], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx_out, _, qy_out, _ = points.T
# BPM positions (on CPU, like the plotted data) and names for the orbit plots
bpm_locations = ring.locations().cpu().numpy()
bpm_names = list(dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
# Plot final corrector settings: errors (red, wide bars) vs sparse solution (blue, narrow bars)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cx)), error_cx.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cx_out)), cx_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cy)), error_cy.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cy_out)), cy_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
# Plot orbits: target (blue circles) vs corrected orbit (red crosses)
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qx_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qx_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qy_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qy_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
[22]:
# For OMP, the number of non-zero coefficients can be passed directly
# fit_intercept=False: the system to solve is exactly X @ coef = y (response matrix times corrector settings)
# With the sklearn default (fit_intercept=True) an extra constant term outside this model is fitted,
# and the returned coefficients no longer solve X @ coef = y
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=8, fit_intercept=False)
omp.fit(X, y)
solution = omp.coef_
# Split the stacked solution into horizontal and vertical corrector settings
cx_out, cy_out = torch.tensor(solution, dtype=torch.float64).reshape(1 + 1, -1)
# Compute orbit for corrector settings error - solution
points, *_ = orbit(ring, fp, [error_cx - cx_out, error_cy - cy_out], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx_out, _, qy_out, _ = points.T
# BPM positions (on CPU, like the plotted data) and names for the orbit plots
bpm_locations = ring.locations().cpu().numpy()
bpm_names = list(dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
# Plot final corrector settings: errors (red, wide bars) vs sparse solution (blue, narrow bars)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cx)), error_cx.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cx_out)), cx_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cy)), error_cy.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cy_out)), cy_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
# Plot orbits: target (blue circles) vs corrected orbit (red crosses)
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qx_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qx_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qy_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qy_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
[23]:
# For Lars, the number of non-zero coefficients can be passed directly
# fit_intercept=False: the system to solve is exactly X @ coef = y (response matrix times corrector settings)
# With the sklearn default (fit_intercept=True) an extra constant term outside this model is fitted,
# and the returned coefficients no longer solve X @ coef = y
lars = Lars(n_nonzero_coefs=8, fit_intercept=False)
lars.fit(X, y)
solution = lars.coef_
# Split the stacked solution into horizontal and vertical corrector settings
cx_out, cy_out = torch.tensor(solution, dtype=torch.float64).reshape(1 + 1, -1)
# Compute orbit for corrector settings error - solution
points, *_ = orbit(ring, fp, [error_cx - cx_out, error_cy - cy_out], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=8, epsilon=1.0E-6)
qx_out, _, qy_out, _ = points.T
# BPM positions (on CPU, like the plotted data) and names for the orbit plots
bpm_locations = ring.locations().cpu().numpy()
bpm_names = list(dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
# Plot final corrector settings: errors (red, wide bars) vs sparse solution (blue, narrow bars)
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cx)), error_cx.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cx_out)), cx_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.bar(range(len(error_cy)), error_cy.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(cy_out)), cy_out.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
# Plot orbits: target (blue circles) vs corrected orbit (red crosses)
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qx_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qx_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
plt.figure(figsize=(16, 2))
plt.errorbar(bpm_locations, qy_target.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(bpm_locations, qy_out.cpu().numpy(), fmt='-', color='red', marker='x', ms=8, alpha=0.75)
plt.xticks(ticks=bpm_locations, labels=bpm_names)
plt.tight_layout()
plt.show()
[ ]: