Example-38: Orbit (chromatic ORM)

[1]:
# Import

from pprint import pprint

import torch

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from twiss import twiss
from twiss import propagate
from twiss import wolski_to_cs

from model.library.line import Element
from model.library.corrector import Corrector
from model.library.line import Line

from model.command.util import chop

from model.command.external import load_lattice

from model.command.build import build

from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper

from model.command.orbit import orbit
from model.command.orbit import ORM
from model.command.orbit import ORM_DP
[2]:
# Set element data type and device
# Use CUDA when it is available and fall back to CPU otherwise, so the example
# also runs on machines without a GPU (later cells create their tensors with
# device=Element.device; torch.cuda.empty_cache() is a no-op without CUDA)

Element.dtype = torch.float64
Element.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
[3]:
# Build and setup lattice

# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts
# Correctors are inserted between parts

# Load ELEGANT table

path = Path('ic.lte')
data = load_lattice(path)

# Build ELEGANT table

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()

# Merge drifts

ring.merge()

# Split BPMs

ring.split((None, ['BPM'], None, None))

# Roll lattice start

ring.roll(1)

# Split quadrupoles and insert correctors
# NOTE: here nq is the number of parts per quadrupole; it is overwritten at the
# end of this cell with the number of quadrupoles

nq = 2**2

# Splitting into nq parts leaves nq - 1 gaps, i.e. nq - 1 pasted correctors per
# magnet (all sharing the name <magnet>_CXY); factor=1/(nq - 1) presumably makes
# their combined kick equal to one corrector -- confirm against Corrector docs
for name in [name for name, kind, *_ in ring.layout() if kind == 'Quadrupole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(nq - 1))
    ring.split((nq, None, [name], None), paste=[corrector])

# Split dipoles and insert correctors (same scheme with nd parts per dipole)

nd = 2**4

for name in [name for name, kind, *_ in ring.layout() if kind == 'Dipole']:
    corrector = Corrector(f'{name}_CXY', factor=1/(nd - 1))
    ring.split((nd, None, [name], None), paste=[corrector])

# Set linear flag in dipoles; enable the alignment flag only for sextupoles and
# quadrupoles (every other element gets alignment = False)

for element in ring:
    element.alignment = False
    if element.__class__.__name__ == 'Dipole':
        element.linear = True
    if element.__class__.__name__ == 'Sextupole':
        element.alignment = True
    if element.__class__.__name__ == 'Quadrupole':
        element.alignment = True

# Split lattice into lines by BPMs

ring.splice()

# Set number of elements of different kinds
# describe appears to count distinct element names (all slices of one magnet
# count once), e.g. 28 quadrupoles in the Jacobian shapes printed later.
# From here on nq is the quadrupole count, not the split count used above.

nb = ring.describe['BPM']
nc = ring.describe['Corrector']
nq = ring.describe['Quadrupole']
ns = ring.describe['Sextupole']
[4]:
# Compute closed orbit
# Start from a small random guess; for the ideal lattice the printed orbit is zero

guess = 1.0E-3*torch.randn(4, dtype=Element.dtype, device=Element.device)
fp, *_ = orbit(ring, guess, [], alignment=False, limit=8, epsilon=1.0E-12)

# Chop small values
# chop takes a list of tensors and its return value is not used, so wrap the
# orbit in a list, chop it, then take the (possibly replaced) first entry back

wrapped = [fp]
chop(wrapped)
fp = wrapped[0]

print(fp)
tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)
[5]:
# Compute ORM of the model lattice around the closed orbit fp

orm = ORM(ring, fp, [], limit=1, start=0, epsilon=None)
print(orm.shape)

# Plot the matrix; exact zeros are replaced by NaN so they render as empty cells

data = orm.masked_fill(orm == 0.0, torch.nan)
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
box = plt.gca().get_position()
cax = plt.gcf().add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-37_5_1.png
[6]:
# Compute ORM derivative with respect to momentum deviation (model lattice)

orm_dp = ORM_DP(ring, fp, [], limit=1, start=0, epsilon=None)
print(orm_dp.shape)

# Plot the matrix; exact zeros are replaced by NaN so they render as empty cells

data = orm_dp.masked_fill(orm_dp == 0.0, torch.nan)
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
box = plt.gca().get_position()
cax = plt.gcf().add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-37_6_1.png
[7]:
# Set lattice with focusing errors
# Each quadrupole gets random kn/ks offsets (normal distribution, std 0.1).
# Consecutive slices sharing a name (one split magnet) receive the same offset.

error:Line = ring.clone()

nq = error.describe['Quadrupole']

error_kn = 0.1*torch.randn(nq, dtype=Element.dtype, device=Element.device)
error_ks = 0.1*torch.randn(nq, dtype=Element.dtype, device=Element.device)

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ != 'Quadrupole':
            continue
        # A new name starts a new magnet: advance to the next error entry
        if element.name != label:
            index += 1
            label = element.name
        element.kn = (element.kn + error_kn[index - 1]).item()
        element.ks = (element.ks + error_ks[index - 1]).item()
[8]:
# Setup twiss computation

def task(line):
    """Compute tunes and Courant-Snyder parameters for a (spliced) lattice.

    The lattice is cloned first, so the argument is not modified.

    Parameters
    ----------
    line : Line
        Lattice to analyse (e.g. ``ring``, ``error`` or ``lattice``).

    Returns
    -------
    tuple
        ``(nux, nuy), (ax, bx, ay, by)``. The first pair holds the tunes as
        returned by ``twiss``. The second holds 1-D tensors with one entry per
        sub-line of the spliced lattice, taken at each sub-line entrance
        (presumably the BPM locations used by the plots).
    """
    local:Line = line.clone()

    # Transport matrix of the whole line: Jacobian of the map at the origin
    origin = torch.zeros(4, dtype=Element.dtype, device=Element.device)
    (nux, nuy), _, w = twiss(torch.func.jacrev(local)(origin))

    # Re-run in matrix mode so the line fills container_matrix
    # (presumably one transport matrix per sub-line)
    local.propagate = False
    local.matrix = True
    local(torch.zeros(4, dtype=Element.dtype, device=Element.device))

    # Propagate the Wolski matrix through every sub-line except the last one
    *transports, _ = local.container_matrix
    wolski = [w]
    for transport in transports:
        wolski.append(propagate(wolski[-1], transport))

    ax, bx, ay, by = torch.vmap(wolski_to_cs)(torch.stack(wolski)).T
    return (nux, nuy), (ax, bx, ay, by)

# Unperturbed (model) and perturbed (error) twiss

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)
(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)

# Compare twiss: norm of (model - error) for the tunes, then for ax, bx, ay, by

for model_value, error_value in [(nux_model, nux_error), (nuy_model, nuy_error)]:
    print((model_value - error_value).norm())
print()

for model_value, error_value in [(ax_model, ax_error), (bx_model, bx_error), (ay_model, ay_error), (by_model, by_error)]:
    print((model_value - error_value).norm())
print()

# Plot relative beta beating in percent (x in red, y in blue)

locations = ring.locations().cpu().numpy()
plt.figure(figsize=(16, 2))
plt.plot(locations, 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(locations, 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=locations, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])
plt.tight_layout()
plt.show()
tensor(0.0341, device='cuda:0', dtype=torch.float64)
tensor(0.0089, device='cuda:0', dtype=torch.float64)

tensor(2.4572, device='cuda:0', dtype=torch.float64)
tensor(1.5588, device='cuda:0', dtype=torch.float64)
tensor(2.4890, device='cuda:0', dtype=torch.float64)
tensor(1.5438, device='cuda:0', dtype=torch.float64)

../_images/examples_model-37_8_1.png
[9]:
# Compute (measure) ORM for lattice with errors

orm_error = ORM(error, fp, [], limit=1, start=0, epsilon=None)
print(orm_error.shape)

# Plot model - measured difference; exact zeros become NaN (empty cells)

difference = orm - orm_error
data = difference.masked_fill(difference == 0.0, torch.nan)
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
box = plt.gca().get_position()
cax = plt.gcf().add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-37_9_1.png
[10]:
# Compute (measure) ORM derivative for lattice with errors

orm_dp_error = ORM_DP(error, fp, [], limit=1, start=0, epsilon=None)
# Report the shape of the matrix just computed (previously printed orm_dp by mistake)
print(orm_dp_error.shape)

# Plot model - measured difference; exact zeros become NaN (empty cells)

data = (orm_dp - orm_dp_error).clone()
data[data==0.0] = torch.nan
plt.figure(figsize=(34/4, 72/4))
img = plt.imshow(data.cpu().numpy(), cmap='magma', interpolation='nearest')
box = plt.gca().get_position()
cax = plt.gcf().add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
plt.colorbar(img, cax=cax)
plt.show()
torch.Size([32, 72])
../_images/examples_model-37_10_1.png
[11]:
# Define parametric ORM + ORM_DP

def jacobian(*args, **kwargs):
    """Jacobian transform handed to ORM/ORM_DP: forwards everything to torch.func.jacrev."""
    return torch.func.jacrev(*args, **kwargs)

def _quadrupole_group(key):
    """Deviation-group spec selecting attribute ``key`` of all Quadrupole elements."""
    return (key, ['Quadrupole'], None, None)

def ORM_kn(kn):
    """Model ORM (ring around fp) as a function of quadrupole kn deviations."""
    return ORM(ring, fp, [kn], _quadrupole_group('kn'), limit=1, start=0, epsilon=None, jacobian=jacobian)

def ORM_ks(ks):
    """Model ORM (ring around fp) as a function of quadrupole ks deviations."""
    return ORM(ring, fp, [ks], _quadrupole_group('ks'), limit=1, start=0, epsilon=None, jacobian=jacobian)

def ORM_DP_kn(kn):
    """Model ORM momentum derivative as a function of quadrupole kn deviations."""
    return ORM_DP(ring, fp, [kn], _quadrupole_group('kn'), limit=1, start=0, epsilon=None, jacobian=jacobian)

def ORM_DP_ks(ks):
    """Model ORM momentum derivative as a function of quadrupole ks deviations."""
    return ORM_DP(ring, fp, [ks], _quadrupole_group('ks'), limit=1, start=0, epsilon=None, jacobian=jacobian)
[12]:
# Parametric ORMs are differentiable with respect to deviation groups

# Set deviation tensors (all zero, i.e. derivatives are taken at the model lattice)

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

# Compute derivatives of the model ORM and ORM_DP with respect to kn and ks deviations
# Each Jacobian is flattened to (4*nb*nc, nq): one row per matrix entry
# (the matrices are 2*nb x 2*nc, see the printed shapes), one column per quadrupole

def _flat_derivative(transform, function, deviation):
    """Apply Jacobian ``transform`` to ``function`` at ``deviation``, flatten it, free cached GPU memory and print the shape."""
    value = transform(function)(deviation).reshape(4*nb*nc, nq)
    torch.cuda.empty_cache()
    print(value.shape)
    return value

dorm_dkn = _flat_derivative(torch.func.jacrev, ORM_kn, kn)
dorm_dks = _flat_derivative(torch.func.jacrev, ORM_ks, ks)
dorm_dp_dkn = _flat_derivative(torch.func.jacfwd, ORM_DP_kn, kn)
dorm_dp_dks = _flat_derivative(torch.func.jacfwd, ORM_DP_ks, ks)
torch.Size([2304, 28])
torch.Size([2304, 28])
torch.Size([2304, 28])
torch.Size([2304, 28])
[13]:
# Set response matrix

# [..., orm_ij, ..., dorm_dp_ij, ....] = M [..., kn_i, ..., ks_i, ...]
# Rows: ORM entries followed by ORM_DP entries; columns: kn deviations followed by ks deviations

kn_columns = torch.cat([dorm_dkn, dorm_dp_dkn], dim=0)
ks_columns = torch.cat([dorm_dks, dorm_dp_dks], dim=0)
response = torch.cat([kn_columns, ks_columns], dim=1)
print(response.shape)
torch.Size([4608, 56])
[14]:
# Test response matrix
# Residual between measured and model data, first without (scale 0) and then
# with (scale 1) the linear prediction response @ vector for the true errors

result_error = torch.stack([orm_error, orm_dp_error]).flatten()

result = torch.stack([orm, orm_dp]).flatten()
vector = torch.cat([error_kn, error_ks])
prediction = response @ vector

for scale in (0, 1):
    print((result_error - (result + scale*prediction)).norm())
tensor(483.8618, device='cuda:0', dtype=torch.float64)
tensor(164.2502, device='cuda:0', dtype=torch.float64)
[15]:
# Perform correction (model to experiment)
# Iteratively fit kn/ks deviations of the model so that its ORM and ORM_DP match
# the measured ones: each step is a damped (lr) least-squares solve with the
# fixed response matrix. The printed residual is the one before each update.

lr = 0.5

kn = torch.zeros_like(error_kn)
ks = torch.zeros_like(error_ks)

groups = (('kn', ['Quadrupole'], None, None), ('ks', ['Quadrupole'], None, None))

for _ in range(32):
    orm_fit = ORM(ring, fp, [kn, ks], *groups, limit=1, start=0, epsilon=None, jacobian=jacobian)
    orm_dp_fit = ORM_DP(ring, fp, [kn, ks], *groups, limit=1, start=0, epsilon=None, jacobian=jacobian)
    result_fit = torch.stack([orm_fit, orm_dp_fit]).flatten()
    residual = result_fit - result_error
    step = torch.linalg.lstsq(response, residual, driver='gels').solution
    dkn, dks = (-lr*step).split((nq, nq))
    kn += dkn
    ks += dks
    print(residual.norm())

# Plot final quadrupole settings: introduced errors (red) vs fitted deviations (blue)

for target, fitted in ((error_kn, kn), (error_ks, ks)):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(target)), target.cpu().numpy(), color='red', alpha=0.75, width=1)
    plt.bar(range(len(fitted)), +fitted.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
tensor(483.8618, device='cuda:0', dtype=torch.float64)
tensor(165.7608, device='cuda:0', dtype=torch.float64)
tensor(54.7430, device='cuda:0', dtype=torch.float64)
tensor(30.1470, device='cuda:0', dtype=torch.float64)
tensor(17.5798, device='cuda:0', dtype=torch.float64)
tensor(11.3230, device='cuda:0', dtype=torch.float64)
tensor(7.7388, device='cuda:0', dtype=torch.float64)
tensor(5.4908, device='cuda:0', dtype=torch.float64)
tensor(3.9718, device='cuda:0', dtype=torch.float64)
tensor(2.9000, device='cuda:0', dtype=torch.float64)
tensor(2.1269, device='cuda:0', dtype=torch.float64)
tensor(1.5630, device='cuda:0', dtype=torch.float64)
tensor(1.1498, device='cuda:0', dtype=torch.float64)
tensor(0.8462, device='cuda:0', dtype=torch.float64)
tensor(0.6231, device='cuda:0', dtype=torch.float64)
tensor(0.4590, device='cuda:0', dtype=torch.float64)
tensor(0.3383, device='cuda:0', dtype=torch.float64)
tensor(0.2494, device='cuda:0', dtype=torch.float64)
tensor(0.1840, device='cuda:0', dtype=torch.float64)
tensor(0.1358, device='cuda:0', dtype=torch.float64)
tensor(0.1002, device='cuda:0', dtype=torch.float64)
tensor(0.0740, device='cuda:0', dtype=torch.float64)
tensor(0.0547, device='cuda:0', dtype=torch.float64)
tensor(0.0404, device='cuda:0', dtype=torch.float64)
tensor(0.0299, device='cuda:0', dtype=torch.float64)
tensor(0.0221, device='cuda:0', dtype=torch.float64)
tensor(0.0163, device='cuda:0', dtype=torch.float64)
tensor(0.0121, device='cuda:0', dtype=torch.float64)
tensor(0.0089, device='cuda:0', dtype=torch.float64)
tensor(0.0066, device='cuda:0', dtype=torch.float64)
tensor(0.0049, device='cuda:0', dtype=torch.float64)
tensor(0.0036, device='cuda:0', dtype=torch.float64)
../_images/examples_model-37_15_1.png
../_images/examples_model-37_15_2.png
[16]:
# Apply corrections
# Subtract the fitted kn/ks deviations from each quadrupole of a copy of the
# perturbed lattice. Consecutive slices sharing a name get the same correction.

lattice:Line = error.clone()

index = 0
label = ''

for line in lattice.sequence:
    for element in line:
        if element.__class__.__name__ != 'Quadrupole':
            continue
        # A new name starts a new magnet: advance to the next fitted entry
        if element.name != label:
            index += 1
            label = element.name
        element.kn = (element.kn - kn[index - 1]).item()
        element.ks = (element.ks - ks[index - 1]).item()
[17]:
# Compare twiss parameters with model before and after correction

(nux_model, nuy_model), (ax_model, bx_model, ay_model, by_model) = task(ring)
(nux_error, nuy_error), (ax_error, bx_error, ay_error, by_error) = task(error)
(nux_lattice, nuy_lattice), (ax_lattice, bx_lattice, ay_lattice, by_lattice) = task(lattice)

# Relative beta beating in percent (x in red, y in blue):
# first figure before correction (error), second after correction (lattice)

locations = ring.locations().cpu().numpy()
labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']

for bx_other, by_other in ((bx_error, by_error), (bx_lattice, by_lattice)):
    plt.figure(figsize=(16, 2))
    plt.plot(locations, 100*((bx_model - bx_other)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
    plt.plot(locations, 100*((by_model - by_other)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
    plt.xticks(ticks=locations, labels=labels)
    plt.tight_layout()
    plt.show()
../_images/examples_model-37_17_0.png
../_images/examples_model-37_17_1.png