Example-52: Normalized dispersion

[1]:
# In this example, normalized dispersion (eta/sqrt(beta)) is used for optics correction together with Courant-Snyder (CS) twiss parameters
[2]:
# Import

from pprint import pprint

import torch
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from model.library.line import Line

from model.command.util import select

from model.command.external import load_sdds
from model.command.external import load_lattice

from model.command.build import build

from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper

from model.command.orbit import dispersion
from model.command.tune import tune
from model.command.twiss import twiss
[3]:
# Load ELEGANT twiss table (global parameters and per-element columns)

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

# Fractional parts of the tunes

nu_qx:Tensor = torch.tensor(parameters['nux'] % 1, dtype=torch.float64)
nu_qy:Tensor = torch.tensor(parameters['nuy'] % 1, dtype=torch.float64)

# Element kinds; only entries of kind 'MONI' (BPMs) are kept from each column

kinds = select(columns, 'ElementType', keep=False)
monitors = [kind == 'MONI' for kind in kinds.values()]

def bpm_column(name:str) -> list:
    """Return the values of twiss column `name` at elements of kind 'MONI', in table order."""
    values = select(columns, name, keep=False).values()
    return [value for value, keep in zip(values, monitors) if keep]

# Twiss parameters at BPMs

a_qx:Tensor = torch.tensor(bpm_column('alphax'), dtype=torch.float64)
b_qx:Tensor = torch.tensor(bpm_column('betax'), dtype=torch.float64)
a_qy:Tensor = torch.tensor(bpm_column('alphay'), dtype=torch.float64)
b_qy:Tensor = torch.tensor(bpm_column('betay'), dtype=torch.float64)

# Dispersion at BPMs

eta_qx:Tensor = torch.tensor(bpm_column('etax'), dtype=torch.float64)
eta_px:Tensor = torch.tensor(bpm_column('etaxp'), dtype=torch.float64)
eta_qy:Tensor = torch.tensor(bpm_column('etay'), dtype=torch.float64)
eta_py:Tensor = torch.tensor(bpm_column('etayp'), dtype=torch.float64)

# BPM longitudinal positions (used as plot tick locations)

positions = bpm_column('s')
[4]:
# Build and set up the ring lattice from the ELEGANT lattice file

path = Path('ic.lte')
data = load_lattice(path)

# Build the 'RING' line from the ELEGANT table and flatten it

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()

# Merge drifts, split BPMs, then roll the lattice start (order matters)

ring.merge()
ring.split((None, ['BPM'], None, None))
ring.roll(1)

# Switch every dipole to its linear model

for dipole in (item for item in ring if item.__class__.__name__ == 'Dipole'):
    dipole.linear = True

# Split lattice into lines by BPMs

ring.splice()

# Number of BPMs, quadrupoles and sextupoles

nb, nq, ns = (ring.describe[kind] for kind in ('BPM', 'Quadrupole', 'Sextupole'))
[5]:
# Compare model tunes with the ELEGANT reference tunes

nuqx, nuqy = tune(ring, [], alignment=False, matched=True)

for reference, computed in ((nu_qx, nuqx), (nu_qy, nuqy)):
    print(torch.allclose(reference, computed))
True
True
[6]:
# Compare model twiss parameters at BPMs with the ELEGANT reference values

aqx, bqx, aqy, bqy = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T

for reference, computed in zip((a_qx, b_qx, a_qy, b_qy), (aqx, bqx, aqy, bqy)):
    print(torch.allclose(reference, computed))
True
True
True
True
[7]:
# Compare model dispersion at BPMs with the ELEGANT reference values
# Each reference component is checked against the same computed component (qx, px, qy, py);
# the vertical components are expected to be zero here (no coupling), so a crossed
# comparison would also print True and hide a mismatch

guess = torch.tensor(4*[0.0], dtype=torch.float64)

etaqx, etapx, etaqy, etapy = dispersion(ring, guess, [], alignment=False)

print(torch.allclose(eta_qx, etaqx))
print(torch.allclose(eta_px, etapx))
print(torch.allclose(eta_qy, etaqy))
print(torch.allclose(eta_py, etapy))
True
True
True
True
[8]:
# Define parametric normalized dispersion

def normalized_dispersion(kn, line=ring):
    """
    Normalized dispersion eta/sqrt(beta) for both transverse planes.

    Parameters
    ----------
    kn: Tensor
        quadrupole kn deviations, applied through the ('kn', ['Quadrupole'], None, None) group
    line: Line
        lattice to evaluate (default: `ring`, bound when this cell is executed)

    Returns
    -------
    Tensor
        stacked [eta_qx/sqrt(beta_qx), eta_qy/sqrt(beta_qy)], one row per plane

    """
    quadrupoles = ('kn', ['Quadrupole'], None, None)
    start = torch.zeros(4, dtype=torch.float64)
    eta_x, _, eta_y, _ = dispersion(line, start, [kn], quadrupoles, alignment=False)
    _, beta_x, _, beta_y = twiss(line, [kn], quadrupoles, alignment=False, matched=True, advance=True, full=False, convert=True).T
    rows = [eta_x/beta_x.sqrt(), eta_y/beta_y.sqrt()]
    return torch.stack(rows)
[9]:
# Jacobians of twiss parameters and normalized dispersion with respect to quadrupole kn deviations

kn = torch.zeros(nq, dtype=torch.float64)

def twiss_kn(kn):
    """Twiss parameters at BPMs of `ring` as a function of quadrupole kn deviations."""
    return twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)

dtwiss_dkn = torch.func.jacrev(twiss_kn)(kn)

# jacrev of normalized_dispersion is (plane, location, nq); swap to (location, plane, nq)
dnormal_dkn = torch.func.jacrev(normalized_dispersion)(kn).swapaxes(0, 1)

print(dtwiss_dkn.shape)
print(dnormal_dkn.shape)
torch.Size([16, 4, 28])
torch.Size([16, 2, 28])
[10]:
# Set lattice with focusing errors (no coupling)

error:Line = ring.clone()

nq = error.describe['Quadrupole']

# Seed the RNG so the error realization (and every downstream result) is reproducible
# on Restart & Run All; change the seed to try another realization

torch.manual_seed(42)
error_kn = 0.1*torch.randn(nq, dtype=torch.float64)

# The error index advances each time a quadrupole's name differs from the previous
# quadrupole's name (tracked across line boundaries), so consecutive same-named
# quadrupole pieces receive the same error value

index = 0
label = ''

for line in error.sequence:
    for element in line:
        if element.__class__.__name__ == 'Quadrupole':
            if label != element.name:
                index +=1
            label = element.name
            element.kn = (element.kn + error_kn[index - 1]).item()
[11]:
# Compute twiss for the model and for the lattice with errors, then plot beta beating

twiss_options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)

ax_model, bx_model, ay_model, by_model = twiss(ring, [], **twiss_options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **twiss_options).T

# Norm of the model/error difference for alpha_x, beta_x, alpha_y, beta_y

for model_value, error_value in ((ax_model, ax_error), (bx_model, bx_error), (ay_model, ay_error), (by_model, by_error)):
    print((model_value - error_value).norm())
print()

# Beta beating in percent (red: x, blue: y)

locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']

plt.figure(figsize=(16, 2))
plt.plot(locations, 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(locations, 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
tensor(2.4390, dtype=torch.float64)
tensor(1.4076, dtype=torch.float64)
tensor(2.3071, dtype=torch.float64)
tensor(1.4153, dtype=torch.float64)

../_images/examples_model-51_11_1.png
[12]:
# Compute and plot the normalized dispersion difference (model minus lattice with errors)

netaqx_model, netaqy_model = normalized_dispersion(kn, ring)
netaqx_error, netaqy_error = normalized_dispersion(kn, error)

delta_x = netaqx_model - netaqx_error
delta_y = netaqy_model - netaqy_error

print(delta_x.norm())
print(delta_y.norm())
print()

locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']

plt.figure(figsize=(16, 2))
plt.plot(locations, delta_x.cpu().numpy(), color='red', alpha=0.75, marker='o')
plt.plot(locations, delta_y.cpu().numpy(), color='blue', alpha=0.75, marker='o')
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
tensor(0.0782, dtype=torch.float64)
tensor(0., dtype=torch.float64)

../_images/examples_model-51_12_1.png
[13]:
# Perform correction (model to experiment): fit quadrupole kn deviations of the model
# so that its twiss parameters and normalized dispersion match the lattice with errors

# Set response matrix
# Both parts have rows ordered (location, component): dtwiss_dkn is (location, 4, nq)
# and dnormal_dkn is (location, plane, nq), so residuals must be flattened in that order

matrix = torch.vstack([dtwiss_dkn.reshape(-1, nq), dnormal_dkn.reshape(-1, nq)])

# Set target twiss parameters

twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)

# Set initial values (no kn deviations); done before the targets so this cell does not
# depend on the value of `kn` left over from earlier cells

kn = torch.zeros_like(error_kn)

# Set target normalized dispersion

normal_error = normalized_dispersion(torch.zeros_like(kn), error)

# Set learning rate (fraction of the least-squares step applied per iteration)

lr = 0.1

# Fit

for _ in range(64):
    twiss_model = twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)
    normal_model = normalized_dispersion(kn, ring)
    # normalized dispersion is (plane, location); swap to (location, plane) so each
    # residual entry lines up with its row of dnormal_dkn in the response matrix
    residual = torch.cat([(twiss_model - twiss_error).flatten(), (normal_model - normal_error).swapaxes(0, 1).flatten()])
    dkn = - lr*torch.linalg.lstsq(matrix, residual, driver='gelsd').solution
    kn += dkn
    print(torch.stack([(twiss_model - twiss_error).norm(), (normal_model - normal_error).norm()]))

# Plot final quadrupole settings (red: applied errors, blue: fitted deviations)

plt.figure(figsize=(16, 2))
plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)
plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
tensor([3.9058, 0.0782], dtype=torch.float64)
tensor([3.5081, 0.0731], dtype=torch.float64)
tensor([3.1346, 0.0685], dtype=torch.float64)
tensor([2.7876, 0.0644], dtype=torch.float64)
tensor([2.4690, 0.0608], dtype=torch.float64)
tensor([2.1800, 0.0576], dtype=torch.float64)
tensor([1.9211, 0.0548], dtype=torch.float64)
tensor([1.6919, 0.0523], dtype=torch.float64)
tensor([1.4908, 0.0501], dtype=torch.float64)
tensor([1.3159, 0.0481], dtype=torch.float64)
tensor([1.1645, 0.0462], dtype=torch.float64)
tensor([1.0341, 0.0445], dtype=torch.float64)
tensor([0.9219, 0.0429], dtype=torch.float64)
tensor([0.8253, 0.0414], dtype=torch.float64)
tensor([0.7420, 0.0400], dtype=torch.float64)
tensor([0.6700, 0.0386], dtype=torch.float64)
tensor([0.6076, 0.0372], dtype=torch.float64)
tensor([0.5531, 0.0359], dtype=torch.float64)
tensor([0.5054, 0.0346], dtype=torch.float64)
tensor([0.4634, 0.0334], dtype=torch.float64)
tensor([0.4263, 0.0322], dtype=torch.float64)
tensor([0.3932, 0.0310], dtype=torch.float64)
tensor([0.3636, 0.0299], dtype=torch.float64)
tensor([0.3371, 0.0287], dtype=torch.float64)
tensor([0.3131, 0.0277], dtype=torch.float64)
tensor([0.2914, 0.0266], dtype=torch.float64)
tensor([0.2716, 0.0256], dtype=torch.float64)
tensor([0.2535, 0.0246], dtype=torch.float64)
tensor([0.2370, 0.0236], dtype=torch.float64)
tensor([0.2217, 0.0227], dtype=torch.float64)
tensor([0.2077, 0.0218], dtype=torch.float64)
tensor([0.1948, 0.0210], dtype=torch.float64)
tensor([0.1828, 0.0201], dtype=torch.float64)
tensor([0.1716, 0.0193], dtype=torch.float64)
tensor([0.1613, 0.0185], dtype=torch.float64)
tensor([0.1517, 0.0178], dtype=torch.float64)
tensor([0.1427, 0.0171], dtype=torch.float64)
tensor([0.1343, 0.0164], dtype=torch.float64)
tensor([0.1265, 0.0157], dtype=torch.float64)
tensor([0.1192, 0.0151], dtype=torch.float64)
tensor([0.1123, 0.0144], dtype=torch.float64)
tensor([0.1059, 0.0138], dtype=torch.float64)
tensor([0.0999, 0.0133], dtype=torch.float64)
tensor([0.0942, 0.0127], dtype=torch.float64)
tensor([0.0889, 0.0122], dtype=torch.float64)
tensor([0.0839, 0.0117], dtype=torch.float64)
tensor([0.0792, 0.0112], dtype=torch.float64)
tensor([0.0748, 0.0107], dtype=torch.float64)
tensor([0.0706, 0.0102], dtype=torch.float64)
tensor([0.0667, 0.0098], dtype=torch.float64)
tensor([0.0630, 0.0094], dtype=torch.float64)
tensor([0.0595, 0.0090], dtype=torch.float64)
tensor([0.0563, 0.0086], dtype=torch.float64)
tensor([0.0532, 0.0082], dtype=torch.float64)
tensor([0.0503, 0.0079], dtype=torch.float64)
tensor([0.0475, 0.0075], dtype=torch.float64)
tensor([0.0449, 0.0072], dtype=torch.float64)
tensor([0.0425, 0.0069], dtype=torch.float64)
tensor([0.0402, 0.0066], dtype=torch.float64)
tensor([0.0380, 0.0063], dtype=torch.float64)
tensor([0.0359, 0.0060], dtype=torch.float64)
tensor([0.0340, 0.0058], dtype=torch.float64)
tensor([0.0322, 0.0055], dtype=torch.float64)
tensor([0.0304, 0.0053], dtype=torch.float64)
../_images/examples_model-51_13_1.png
[14]:
# Apply corrections: subtract the fitted kn deviations from a copy of the lattice with errors

lattice:Line = error.clone()

index = 0
label = ''

for line in lattice.sequence:
    for element in line:
        if element.__class__.__name__ != 'Quadrupole':
            continue
        # A new quadrupole name advances the index; consecutive pieces with the same name share it
        if element.name != label:
            index += 1
            label = element.name
        element.kn = (element.kn - kn[index - 1]).item()
[15]:
# Compute twiss for the model, the lattice with errors and the corrected lattice,
# then plot beta beating before (o) and after (x) correction

twiss_options = dict(alignment=False, matched=True, advance=True, full=False, convert=True)

ax_model, bx_model, ay_model, by_model = twiss(ring, [], **twiss_options).T
ax_error, bx_error, ay_error, by_error = twiss(error, [], **twiss_options).T
ax_final, bx_final, ay_final, by_final = twiss(lattice, [], **twiss_options).T

locations = ring.locations().cpu().numpy()
bpm_labels = ['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04']

plt.figure(figsize=(16, 2))
for bx_other, by_other, marker in ((bx_error, by_error, 'o'), (bx_final, by_final, 'x')):
    plt.plot(locations, 100*((bx_model - bx_other)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker=marker)
    plt.plot(locations, 100*((by_model - by_other)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker=marker)
plt.xticks(ticks=positions, labels=bpm_labels)
plt.tight_layout()
plt.show()
../_images/examples_model-51_15_0.png