Example-34: Orbit (ORM)

[1]:
# In this example orbit response matrix is computed for ideal lattice
# Next, orbit distortion is introduced (quadrupole and sextupole transverse shifts)
# Orbit is corrected using ideal ORM (experiment to design correction)
# Next, additional gradient and skew error are introduced in quadrupoles
# With this, real ORM matrix will be off, but design ORM still can be used to correct orbit
[2]:
# Import

from pprint import pprint

import torch

from pathlib import Path
from matplotlib import pyplot as plt

from twiss import twiss

from model.library.corrector import Corrector
from model.library.line import Line

from model.command.util import chop
from model.command.util import evaluate
from model.command.util import series

from model.command.external import load_sdds
from model.command.external import load_lattice

from model.command.build import build

from model.command.wrapper import group

from model.command.orbit import orbit
from model.command.orbit import parametric_orbit
[3]:
# Load ELEGANT twiss

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

nu_qx:float = parameters['nux'] % 1
nu_qy:float = parameters['nuy'] % 1
[4]:
# Build and setup lattice

# Quadrupoles are splitted into 2**2 parts, Dipoles -- 2**4 part
# Correctors are inserted between parts

path = Path('ic.lte')
data = load_lattice(path)

ring:Line = build('RING', 'ELEGANT', data)
ring.propagate = True
ring.flatten()
ring.merge()

ring.split((None, ['BPM'], None, None))
ring.roll(1)

n_q = 2**2
n_d = 2**4

for name in [name for name, kind, *_ in ring.layout() if kind == 'Quadrupole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(n_q - 1))
    ring.split((n_q, None, [name], None), paste=[corrector])

for name in [name for name, kind, *_ in ring.layout() if kind == 'Dipole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(n_d - 1))
    ring.split((n_d, None, [name], None), paste=[corrector])

for element in ring:
    if element.__class__.__name__ == 'Dipole':
        element.linear = True

ring.splice()
[5]:
# Compare linear tunes

state = torch.tensor(4*[0.0], dtype=torch.float64)
matrix = torch.func.jacrev(ring)(state)
(nuqx, nuqy), *_ = twiss(matrix)

print(nu_qx - nuqx)
print(nu_qy - nuqy)
tensor(1.4433e-15, dtype=torch.float64)
tensor(-9.9920e-16, dtype=torch.float64)
[6]:
# Compute closed orbit

fp = 1.0E-3*torch.randn(4, dtype=torch.float64)
fp, *_ = orbit(ring, fp, [], alignment=False, limit=8, epsilon=1.0E-12)

# Chop small values

fp = [fp]
chop(fp)
fp, *_ = fp

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric closed orbit (1st order derivatives wrt cx and cy at each monitor location)

n_kick = ring.describe['Corrector']

cx = torch.tensor(n_kick*[0.0], dtype=torch.float64)
cy = torch.tensor(n_kick*[0.0], dtype=torch.float64)

pox, *_ = parametric_orbit(ring,
                           fp,
                           [cx],
                           (1, 'cx', ['Corrector'], None, None),
                           alignment=False,
                           advance=True,
                           full=False)


poy, *_ = parametric_orbit(ring,
                           fp,
                           [cy],
                           (1, 'cy', ['Corrector'], None, None),
                           alignment=False,
                           advance=True,
                           full=False)

chop(pox)
chop(poy)
[8]:
# Compute orbit response matrix

# [qx_1, qx_2, ..., qx_n, qy_1, qy_2, ..., qy_n] = M @ [cx_1, cx_2, ..., cx_k, cy_1, cy_2, ..., cy_k]

# qx_i/qy_i -- qx/qy orbit at BPM i
# cx_i/cy_i -- cx/cy angle at corrector i

def qxqy(cxy, pox, poy):
    cx, cy = cxy.reshape(1 + 1, -1)
    qx, _, qy, _ = torch.stack([evaluate(tx, [fp, cx]) + evaluate(ty, [fp, cy]) for tx, ty in zip(pox, poy)]).T
    return torch.cat([qx, qy])

cx = torch.tensor(n_kick*[0.0], dtype=torch.float64)
cy = torch.tensor(n_kick*[0.0], dtype=torch.float64)

cxy = torch.cat([cx, cy])
orm = torch.func.jacrev(qxqy)(cxy, pox, poy)

print(cxy.shape)
print(qxqy(cxy, pox, poy).shape)
print(orm.shape)

data = orm.clone()
data[data==0.0] = torch.nan

plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([72])
torch.Size([32])
torch.Size([32, 72])
../_images/examples_model-33_8_1.png
[9]:
# Set corrector errors

cx = 50.0E-6*torch.randn_like(cx)
cy = 50.0E-6*torch.randn_like(cy)

# Find closed orbit with errors

points, *_ = orbit(ring, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-12)

# Set wrapper

start, *_, end = ring.names
mapping, *_ = group(ring, start, end, ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False)

# Propagate estimated closed orbit

point, *_ = points
print(point)
print(mapping(point, cx, cy))
print(torch.allclose(point, mapping(point, cx, cy), rtol=1.0E-12, atol=1.0E-12))
print()

# Set orbit

qx, _, qy, _ = points.T

# Compute orbit from known errors

Qx, Qy = (orm @ torch.cat([cx, cy])).reshape(1 + 1, -1)

# qx vs Qx

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qx.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), Qx.cpu().numpy(), fmt=' ', color='black', marker='x', ms=8, alpha=1)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()

# qy vs Qy

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qy.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), Qy.cpu().numpy(), fmt=' ', color='black', marker='x', ms=8, alpha=1)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()
tensor([-0.0003,  0.0004,  0.0007,  0.0013], dtype=torch.float64)
tensor([-0.0003,  0.0004,  0.0007,  0.0013], dtype=torch.float64)
True

../_images/examples_model-33_9_1.png
../_images/examples_model-33_9_2.png
[10]:
# Perform one correction step (use measured orbit and model matrix)

dcx, dcy = - (torch.linalg.pinv(orm) @ torch.cat([qx, qy])).reshape(1 + 1, -1)

# Find closed orbit with errors and add corrections

points, *_ = orbit(ring, fp, [cx + dcx, cy + dcy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-12)

# Set orbit

dqx, _, dqy, _ = points.T

# qx vs dqx

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qx.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), dqx.cpu().numpy(), fmt='-', color='red', marker='o', ms=8, alpha=1)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()

# qy vs dqy

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qy.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), dqy.cpu().numpy(), fmt='-', color='red', marker='o', ms=8, alpha=1)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()
../_images/examples_model-33_10_0.png
../_images/examples_model-33_10_1.png
[11]:
# Correction lool (use measured orbit and model matrix)

# Given measured orbit values qx and qy (differences with reference orbit)
# New corrector settings are computed and applied
# Orbit is remeasured and procedure is repeated

cx = 50.0E-6*torch.randn_like(cx)
cy = 50.0E-6*torch.randn_like(cy)

for _ in range(16):
    points, *_ = orbit(ring, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-12)
    qx, _, qy, _ = points.T
    dcx, dcy = - 0.5*(torch.linalg.pinv(orm) @ torch.cat([qx, qy])).reshape(1 + 1, -1)
    cx += dcx
    cy += dcy
    print(torch.cat([qx, qy]).norm())
tensor(0.0019, dtype=torch.float64)
tensor(0.0009, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
tensor(5.6358e-05, dtype=torch.float64)
tensor(2.8163e-05, dtype=torch.float64)
tensor(1.4078e-05, dtype=torch.float64)
tensor(7.0379e-06, dtype=torch.float64)
tensor(3.5187e-06, dtype=torch.float64)
tensor(1.7593e-06, dtype=torch.float64)
tensor(8.7964e-07, dtype=torch.float64)
tensor(4.3982e-07, dtype=torch.float64)
tensor(2.1991e-07, dtype=torch.float64)
tensor(1.0996e-07, dtype=torch.float64)
tensor(5.4978e-08, dtype=torch.float64)
[12]:
# In the above, errors were passed as deviaton variables
# Another option is to add errors to the main attributes

# Generate lattice with errors (errors are added to the main attributes)

error:Line = ring.clone()

cx = 50.0E-6*torch.randn_like(cx)
cy = 50.0E-6*torch.randn_like(cx)

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ == 'Corrector':
            if label != element.name:
                index +=1
            label = element.name
            element.cx = cx[index - 1].item()
            element.cy = cy[index - 1].item()

# Perform correction

cx = torch.zeros_like(cx)
cy = torch.zeros_like(cx)

for _ in range(16):
    points, *_ = orbit(error, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=False, advance=True, full=False, limit=16, epsilon=1.0E-12)
    qx, _, qy, _ = points.T
    dcx, dcy = - 0.5*(torch.linalg.pinv(orm) @ torch.cat([qx, qy])).reshape(1 + 1, -1)
    cx += dcx
    cy += dcy
    print(torch.cat([qx, qy]).norm())
tensor(0.0013, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0003, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
tensor(7.8633e-05, dtype=torch.float64)
tensor(3.9300e-05, dtype=torch.float64)
tensor(1.9644e-05, dtype=torch.float64)
tensor(9.8195e-06, dtype=torch.float64)
tensor(4.9086e-06, dtype=torch.float64)
tensor(2.4538e-06, dtype=torch.float64)
tensor(1.2266e-06, dtype=torch.float64)
tensor(6.1319e-07, dtype=torch.float64)
tensor(3.0653e-07, dtype=torch.float64)
tensor(1.5324e-07, dtype=torch.float64)
tensor(7.6602e-08, dtype=torch.float64)
tensor(3.8294e-08, dtype=torch.float64)
[13]:
# Add alignment and focusing errors to quadrupoles

# Note, adding the same alignmet elemets to parts is not valid for all types of alignmet errors

error:Line = ring.clone()

n_quad = error.describe['Quadrupole']

dx = 100.0E-6*torch.randn(n_quad, dtype=torch.float64)
dy = 100.0E-6*torch.randn(n_quad, dtype=torch.float64)

kn = 0.1*torch.randn(n_quad, dtype=torch.float64)
ks = 0.1*torch.randn(n_quad, dtype=torch.float64)

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            if label != element.name:
                index +=1
            label = element.name
            element.dx = dx[index - 1].item()
            element.dy = dy[index - 1].item()
            element.kn = (element.kn + kn[index - 1]).item()
            element.ks = (element.ks + ks[index - 1]).item()

# Compute closed orbit with zero corrector
# Note, alignment is on

cx = torch.zeros_like(cx)
cy = torch.zeros_like(cx)

points, *_ = orbit(error, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=True, advance=True, full=False, limit=16, epsilon=1.0E-12)

# Test closed orbit

point, *_ = points

print(point)
print(error(point, alignment=True))
print(torch.allclose(point, error(point, alignment=True), rtol=1.0E-12, atol=1.0E-12))
print()

# Plot orbit

qx_initial, _, qy_initial, _ = points.T

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qx_initial.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qy_initial.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()
tensor([ 0.0007, -0.0005, -0.0005, -0.0012], dtype=torch.float64)
tensor([ 0.0007, -0.0005, -0.0005, -0.0012], dtype=torch.float64)
True

../_images/examples_model-33_13_1.png
../_images/examples_model-33_13_2.png
[14]:
# Compure ORM for model with errors (measured ORM)

n_kick = ring.describe['Corrector']

cx = torch.tensor(n_kick*[0.0], dtype=torch.float64)
cy = torch.tensor(n_kick*[0.0], dtype=torch.float64)

pox, *_ = parametric_orbit(error,
                           fp,
                           [cx],
                           (1, 'cx', ['Corrector'], None, None),
                           alignment=False,
                           advance=True,
                           full=False)


poy, *_ = parametric_orbit(error,
                           fp,
                           [cy],
                           (1, 'cy', ['Corrector'], None, None),
                           alignment=False,
                           advance=True,
                           full=False)

chop(pox)
chop(poy)

cx = torch.tensor(n_kick*[0.0], dtype=torch.float64)
cy = torch.tensor(n_kick*[0.0], dtype=torch.float64)

cxy = torch.cat([cx, cy])
orm_error = torch.func.jacrev(qxqy)(cxy, pox, poy)

data = orm.clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()

data = orm_error.clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
cax = plt.gcf().add_axes([plt.gca().get_position().x1 + 0.01, plt.gca().get_position().y0, 0.02, plt.gca().get_position().height])
plt.colorbar(img, cax=cax)
plt.show()
../_images/examples_model-33_14_0.png
../_images/examples_model-33_14_1.png
[15]:
# Find corrector settings to minimize orbit distortion (model ORM)

cx = torch.zeros_like(cx)
cy = torch.zeros_like(cx)

for _ in range(4):
    points, *_ = orbit(error, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=True, advance=True, full=False, limit=16, epsilon=1.0E-12)
    qx, _, qy, _ = points.T
    dcx, dcy = - 0.75*(torch.linalg.pinv(orm) @ torch.cat([qx, qy])).reshape(1 + 1, -1)
    cx += dcx
    cy += dcy
    print(torch.cat([qx, qy]).norm())

qx_model = qx.clone()
qy_model = qy.clone()
tensor(0.0064, dtype=torch.float64)
tensor(0.0020, dtype=torch.float64)
tensor(0.0006, dtype=torch.float64)
tensor(0.0002, dtype=torch.float64)
[16]:
# Find corrector settings to minimize orbit distortion (measured ORM)

cx = torch.zeros_like(cx)
cy = torch.zeros_like(cx)

for _ in range(4):
    points, *_ = orbit(error, fp, [cx, cy], ('cx', ['Corrector'], None, None), ('cy', ['Corrector'], None, None), alignment=True, advance=True, full=False, limit=16, epsilon=1.0E-12)
    qx, _, qy, _ = points.T
    dcx, dcy = - 0.75*(torch.linalg.pinv(orm_error) @ torch.cat([qx, qy])).reshape(1 + 1, -1)
    cx += dcx
    cy += dcy
    print(torch.cat([qx, qy]).norm())

qx_error = qx.clone()
qy_error = qy.clone()
tensor(0.0064, dtype=torch.float64)
tensor(0.0019, dtype=torch.float64)
tensor(0.0005, dtype=torch.float64)
tensor(0.0001, dtype=torch.float64)
[17]:
# Compare orbits after correction

# qx

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qx_initial.cpu().numpy(), fmt='-', color='black', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), qx_model.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), qx_error.cpu().numpy(), fmt='-', color='red', marker='o', ms=8, alpha=0.75)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()

# qy

plt.figure(figsize=(16, 2))
plt.errorbar(ring.locations().cpu().numpy(), qy_initial.cpu().numpy(), fmt='-', color='black', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), qy_model.cpu().numpy(), fmt='-', color='blue', marker='o', ms=8, alpha=0.75)
plt.errorbar(ring.locations().cpu().numpy(), qy_error.cpu().numpy(), fmt='-', color='red', marker='o', ms=8, alpha=0.75)
plt.xticks(ticks=ring.locations(), labels=dict.fromkeys([name for name, kind, *_ in ring.layout() if kind == 'BPM']))
plt.tight_layout()
plt.show()
../_images/examples_model-33_17_0.png
../_images/examples_model-33_17_1.png