Example-50: Advance (Phase advance sensitivity)

[1]:
# In this example, the effect of systematic quadrupole errors on phase advances is illustrated
[2]:
# Import

from random import random
from pprint import pprint

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from model.library.line import Line

from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.advance import advance
[3]:
# Load ELEGANT twiss

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

# Keep only the fractional parts of the loaded 'nux' and 'nuy' parameters (compared with computed tunes later)

nu_qx:Tensor
nu_qy:Tensor
nu_qx, nu_qy = (torch.tensor(parameters[name] % 1, dtype=torch.float64) for name in ('nux', 'nuy'))
[4]:
# Build and setup lattice

# Note, sextupoles are turned off and dipoles are linear

# Read the ELEGANT lattice file

path = Path('ic.lte')
data = load_lattice(path)

# Construct the ring from the loaded table and flatten it

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()

# Merge drifts

ring.merge()

# Split BPMs

ring.split((None, ['BPM'], None, None))

# Roll lattice start

ring.roll(1)

# Set linear dipoles (elements are selected by their class name)

for dipole in filter(lambda element: type(element).__name__ == 'Dipole', ring):
    dipole.linear = True

# Split lattice into lines by BPMs

ring.splice()

# Set number of elements of different kinds (BPMs, quadrupoles and sextupoles)

nb, nq, ns = (ring.describe[kind] for kind in ('BPM', 'Quadrupole', 'Sextupole'))
[5]:
# Compute tunes (fractional part)

# Initial guess is a zero vector with four components

guess = torch.zeros(4, dtype=torch.float64)
nuqx, nuqy = tune(ring, [], alignment=False, matched=True, guess=guess, limit=8, epsilon=1.0E-9)

# Compare with elegant (one boolean per line, x first, then y)

print(torch.allclose(nu_qx, nuqx), torch.allclose(nu_qy, nuqy), sep='\n')
True
True
[6]:
# Compute nominal phase advances between BPMs

# All quadrupole 'kn' values are passed as zeros

kn = torch.zeros(nq, dtype=torch.float64)

# Transposing the result lets it be unpacked into the x and y parts

nominal = advance(ring, [kn], ('kn', ['Quadrupole'], None, None), matched=True, limit=1, epsilon=None)
muqx_model, muqy_model = nominal.T

print(muqx_model.shape, muqy_model.shape, sep='\n')
torch.Size([16])
torch.Size([16])
[7]:
# Compute phase advances between BPMs using MC

# Draw the random quadrupole 'kn' samples from a locally seeded generator,
# so that Restart & Run All reproduces the same samples (global RNG state is left untouched)

generator = torch.Generator().manual_seed(0)
kns = 0.01*torch.randn((8192, nq), generator=generator, dtype=torch.float64)

# Evaluate advance for every sample (vmap over the first axis, processed in chunks of 1024)
# Swapping the first and the last axes allows unpacking into the x and y parts

muqxs, muqys = torch.vmap(lambda kn: advance(ring, [kn], ('kn', ['Quadrupole'], None, None), matched=True, limit=1, epsilon=None), chunk_size=1024)(kns).swapaxes(0, -1)

# Deviations from the nominal (model) values

dmuqxs = muqxs - muqx_model.unsqueeze(1)
dmuqys = muqys - muqy_model.unsqueeze(1)

print(dmuqxs.shape)
print(dmuqys.shape)
torch.Size([16, 8192])
torch.Size([16, 8192])
[8]:
# Plot histograms

# One figure per set of deviations (dmuqxs first, then dmuqys), one panel per row of the set
# Both figures use the same size rule, scaled with the number of panels
# squeeze=False keeps axs two-dimensional, so a single panel can be iterated too

for deviations in (dmuqxs, dmuqys):
    fig, axs = plt.subplots(1, len(ring), figsize=(len(ring)*2.5, 2.5), squeeze=False)
    for deviation, ax in zip(deviations, axs.flat):
        ax.hist(deviation.cpu().numpy(), bins=100, color='blue', alpha=0.7)
    plt.tight_layout()
    plt.show()
../_images/examples_model-49_8_0.png
../_images/examples_model-49_8_1.png
[9]:
# Compute and plot spreads

# Standard deviation over the second axis (one value per row)

sigma_dmuqxs = dmuqxs.std(dim=1)
sigma_dmuqys = dmuqys.std(dim=1)

# One bar chart per set of spreads (x first, then y)

for sigmas in (sigma_dmuqxs, sigma_dmuqys):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(sigmas)), sigmas.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
../_images/examples_model-49_9_0.png
../_images/examples_model-49_9_1.png
[10]:
# Compute phase advance derivatives and estimate spread from linear surrogate model using MC

# Compute derivatives (jacobian of advance with respect to the quadrupole 'kn' values, evaluated at zeros)
# Swapping the first two axes allows unpacking into the x and y parts

kn = torch.zeros(nq, dtype=torch.float64)
dmuqx_dk, dmuqy_dk = torch.func.jacrev(lambda kn: advance(ring, [kn], ('kn', ['Quadrupole'], None, None), matched=True, limit=1, epsilon=None), chunk_size=1024)(kn).swapaxes(0, 1)

print(dmuqx_dk.shape)
print(dmuqy_dk.shape)

# Sample (locally seeded generator, so that the samples are reproducible on rerun)

generator = torch.Generator().manual_seed(1)
kns = 0.01*torch.randn((8192, nq), generator=generator, dtype=torch.float64)

# Propagate samples through the linear model
# Note, the y result is stored in dmuqys, so that the spread below is computed from the surrogate samples

dmuqxs = dmuqx_dk @ kns.T
dmuqys = dmuqy_dk @ kns.T

# Compute and plot spreads

sigma_dmuqxs = dmuqxs.std(1)
sigma_dmuqys = dmuqys.std(1)

plt.figure(figsize=(16, 2))
plt.bar(range(len(sigma_dmuqxs)), sigma_dmuqxs.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()

plt.figure(figsize=(16, 2))
plt.bar(range(len(sigma_dmuqys)), sigma_dmuqys.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
plt.tight_layout()
plt.show()
torch.Size([16, 28])
torch.Size([16, 28])
../_images/examples_model-49_10_1.png
../_images/examples_model-49_10_2.png
[11]:
# Compute spread using error propagation

# Diagonal matrix with the squared 0.01 scale on the diagonal, shared by both planes

covariance = (0.01*torch.eye(nq,  dtype=torch.float64))**2

# Square root of the diagonal of (derivatives @ covariance @ derivatives.T)

sigma_dmuqxs = (dmuqx_dk @ covariance @ dmuqx_dk.T).diag().sqrt()
sigma_dmuqys = (dmuqy_dk @ covariance @ dmuqy_dk.T).diag().sqrt()

# One bar chart per set of spreads (x first, then y)

for sigmas in (sigma_dmuqxs, sigma_dmuqys):
    plt.figure(figsize=(16, 2))
    plt.bar(range(len(sigmas)), sigmas.cpu().numpy(), color='blue', alpha=0.75, width=0.75)
    plt.tight_layout()
    plt.show()
../_images/examples_model-49_11_0.png
../_images/examples_model-49_11_1.png