Example-39: Mapping (Transformations around closed orbit)
[1]:
# This example illustrates how to define mappings (transformations) between lattice elements
# These mappings are differentiable with respect to the state and to different deviation groups
# Additionally, mappings can be defined around the closed orbit
[2]:
# Import
from pathlib import Path
from pprint import pprint
from random import random
from random import seed

import torch

from model.library.line import Line
from model.library.corrector import Corrector
from model.command.external import load_lattice
from model.command.build import build
from model.command.wrapper import group
from model.command.orbit import orbit
from model.command.mapping import mapping
from model.command.mapping import matrix
[3]:
# Build and set up the lattice
# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts
# Correctors are inserted between parts
# Random errors are assigned to correctors, so that the origin is not preserved
# Fix the random seeds so that the corrector errors (and the torch.randn draws in later cells)
# are reproducible under Restart & Run All
SEED = 0
seed(SEED)
torch.manual_seed(SEED)
# Load ELEGANT table
path = Path('ic.lte')
data = load_lattice(path)
# Build the ring from the ELEGANT table
ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()
# Merge drifts
ring.merge()
# Split quadrupoles, then dipoles, pasting a randomly kicked corrector between the parts.
# The split counts live in their own names; nq is reserved below for the number of quadrupoles.
# NOTE(review): factor=1/(parts - 1) presumably spreads each kick over the (parts - 1) pasted
# correctors -- confirm Corrector.factor semantics
nd = 2**4
parts_by_kind = {'Quadrupole': 2**2, 'Dipole': nd}
for kind, parts in parts_by_kind.items():
    # The layout is re-read for each kind, after the previous kind has been split
    targets = [name for name, element_kind, *_ in ring.layout() if element_kind == kind]
    for name in targets:
        corrector = Corrector(f'{name}_CXY', factor=1/(parts - 1))
        corrector.cx = 1.0E-3*(random() - 0.5)
        corrector.cy = 1.0E-3*(random() - 0.5)
        ring.split((parts, None, [name], None), paste=[corrector])
# Set linear flag in dipoles
for element in ring:
    if element.__class__.__name__ == 'Dipole':
        element.linear = True
# Number of elements of each kind
nb = ring.describe['BPM']
nc = ring.describe['Corrector']
nq = ring.describe['Quadrupole']
ns = ring.describe['Sextupole']
[4]:
# model.command.wrapper.group defines parametric and differentiable transformations
# Each transformation propagates an initial state from the start of the given (probe) element
# to the end of the given (other) element
# Start and end elements can be given by name (the first occurrence in the line sequence is used)
# or by integer position (negative values are allowed; the position is taken modulo the number of elements)
# Since correctors have non-zero angles, zero is not mapped to zero
state = torch.zeros(4, dtype=torch.float64)
print(ring(state))
# Build the same one-turn transformation twice: first from element names (assumed to be distinct),
# then from element positions in the lattice sequence
first, *_, last = ring.names
for probe, other in ((first, last), (0, len(ring) - 1)):
    transformation, *_ = group(ring, probe, other)
    print(transformation(state))
tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)
tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)
tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)
[5]:
# Compute the closed orbit and check the one-turn transformation around it
fps, _ = orbit(ring, torch.zeros(4, dtype=torch.float64), [], advance=True)
# Orbit at the lattice start
fp = fps[0]
# A closed orbit is a fixed point of the one-turn map, so these two lines should agree
print(fp)
print(ring(fp))
print()
# Deviation from the closed orbit after one turn (expected at round-off level)
print(transformation(state + fp) - fp)
tensor([ 0.0024, -0.0054, 0.0036, 0.0051], dtype=torch.float64)
tensor([ 0.0024, -0.0054, 0.0036, 0.0051], dtype=torch.float64)
tensor([-1.7347e-18, -2.6021e-18, 6.9389e-18, 6.9389e-18],
dtype=torch.float64)
[6]:
# model.command.mapping.mapping can be used as an alias to model.command.wrapper.group
# Additionally, it can construct parametric and differentiable transformations from one element
# to another that are built around the closed orbit
# Without the matched flag, the closed orbit has to be added and subtracted by hand
transformation, _ = mapping(ring, probe, other, matched=False)
print(transformation(state + fp) - fp)
# With the matched flag, the closed orbit is computed on each invocation (limit sets the number of iterations)
# To speed up computations, a known fixed point can be passed as guess with the iteration count set to zero;
# in this case probe is assumed to be the lattice start
# Also, to compute derivatives, limit can be set to one
# Set epsilon to None for vmap computations
for options in ({'limit': 8, 'epsilon': 1.0E-9}, {'guess': fp, 'limit': 0, 'epsilon': None}):
    transformation, _ = mapping(ring, probe, other, matched=True, **options)
    print(transformation(state))
tensor([-1.7347e-18, -2.6021e-18, 6.9389e-18, 6.9389e-18],
dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Transformation between given elements
probe = 'BPM07'
other = 'BPM10'
# Propagate element by element.
# group/mapping act from the start of probe to the end of other, so other itself must be
# included in the slice (+ 1). Without it, the check only agrees when other leaves the state unchanged.
local = state.clone()
for element in ring[ring.position(probe):ring.position(other) + 1]:
    local = element(local)
print(local)
# Propagate using group
transformation, *_ = group(ring, probe, other)
print(transformation(state))
# Propagate using mapping
transformation, *_ = mapping(ring, probe, other)
print(transformation(state))
tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)
tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)
tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)
[8]:
# Transformation between elements around the closed orbit
# Closed orbit values at probe and other
# NOTE(review): fp_other = fps[position(other)] is compared with states propagated through the END of other.
# This is exact when other does not change the orbit (presumably true for a BPM), or when fps is indexed
# by element exit -- confirm the orbit(..., advance=True) convention
fp_probe = fps[ring.position(probe)]
fp_other = fps[ring.position(other)]
# Propagate element by element, from the start of probe through the end of other (inclusive, as group does)
local = fp_probe + state
for element in ring[ring.position(probe):ring.position(other) + 1]:
    local = element(local)
print(local - fp_other)
# Propagate using group
transformation, *_ = group(ring, probe, other)
print(transformation(state + fp_probe) - fp_other)
# Propagate using mapping
transformation, *_ = mapping(ring, probe, other, matched=False)
print(transformation(state + fp_probe) - fp_other)
# Propagate using mapping (matched): the closed orbit is handled internally
transformation, *_ = mapping(ring, probe, other, matched=True)
print(transformation(state))
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Parametric mapping
# Set initial value (relative to closed orbit)
state = torch.tensor(4*[0.0], dtype=torch.float64)
# Without root and matched flags, tensor elements are bound only to matched elements (unique names)
# within the selected range of elements
transformation, ((_, names_line, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=False, matched=False)
# Set random quadrupole errors within the line
kn_line = 0.01*torch.randn(len(names_line), dtype=torch.float64)
# Non-zero deviations change the closed orbit, so the second result is no longer zero
print(transformation(state + fp_probe, 0*kn_line) - fp_other)
print(transformation(state + fp_probe, 1*kn_line) - fp_other)
print()
# With the root flag, tensor elements are bound to all matched elements (unique names)
transformation, ((_, names_ring, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=True, matched=False)
# To match the previous result, place each line deviation at the ring index of the element with the same name.
# The name lookup works whether or not the line elements form a contiguous, equally ordered block of names_ring
# (including a block that starts at index 0). It raises KeyError instead of silently misaligning if a name is missing.
ring_index = {name: index for index, name in enumerate(names_ring)}
kn_ring = torch.zeros(len(names_ring), dtype=torch.float64)
kn_ring[torch.tensor([ring_index[name] for name in names_line], dtype=torch.long)] = kn_line
# Propagate state
print(transformation(state + fp_probe, 0*kn_ring) - fp_other)
print(transformation(state + fp_probe, 1*kn_ring) - fp_other)
print()
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 1.3384e-05, -1.2968e-05, 1.0788e-05, 1.8095e-05],
dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 1.3384e-05, -1.2968e-05, 1.0788e-05, 1.8095e-05],
dtype=torch.float64)
[10]:
# Parametric mappings around the closed orbit
# In this case the root parameter is ignored
# Deviation entries are ordered according to the appearance of elements in the input lattice
# Closed orbit in the presence of the quadrupole errors
fps, _ = orbit(ring, torch.zeros(4, dtype=torch.float64), [kn_ring], ('kn', ['Quadrupole'], None, None), advance=True, limit=8, epsilon=1.0E-9)
fp = fps[0]
# Closed orbit at probe and other
fp_probe, fp_other = (fps[ring.position(name)] for name in (probe, other))
# One turn from the lattice start should return fp (up to round-off)
line = ring.clone()
transformation, ((_, names, _), *_) = mapping(line, 0, len(line) - 1, ('kn', ['Quadrupole'], None, None), matched=False)
print(fp - transformation(fp, kn_ring))
# Same check with the line restarted at probe and then at other
# Note, groups are set up using the returned matched names
for anchor, point in ((probe, fp_probe), (other, fp_other)):
    line = ring.clone()
    line.start = anchor
    transformation, _ = mapping(line, 0, len(line) - 1, ('kn', None, names, None), matched=False)
    print(point - transformation(point, kn_ring))
# probe -> other mapping, with the closed orbit added and subtracted by hand
transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=False, limit=8, epsilon=1.0E-9)
for scale in (0, 1):
    print(transformation(scale*state + fp_probe, kn_ring) - fp_other)
# probe -> other mapping built around the closed orbit
transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=1.0E-9)
for scale in (0, 1):
    print(transformation(scale*state, kn_ring))
tensor([-3.0358e-18, 1.3878e-17, 3.4694e-18, 7.8063e-18],
dtype=torch.float64)
tensor([ 1.8648e-17, -3.2960e-17, 5.5565e-18, 8.2399e-18],
dtype=torch.float64)
tensor([ 1.9516e-17, -2.6888e-17, -3.9031e-18, -4.3368e-18],
dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[11]:
# Vectorized mapping over states / knobs
# In this case, epsilon should be set to None (relevant only for the case around the closed orbit)
transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=None)
batch = 128
# States are drawn before knobs, keeping the random stream order
states = 1.0E-3*torch.randn((batch, *state.shape), dtype=torch.float64)
knobs = 1.0E-3*torch.randn((batch, *kn_ring.shape), dtype=torch.float64)
# Batch over states with fixed knobs, then over knobs with a fixed state
over_states = torch.vmap(lambda sample: transformation(sample, kn_ring))
over_knobs = torch.vmap(lambda knob: transformation(state, knob))
print(over_states(states).shape)
print(over_knobs(knobs).shape)
torch.Size([128, 4])
torch.Size([128, 4])
[12]:
# Differentiability with respect to state and knobs
state = torch.zeros(4, dtype=torch.float64)
# First the plain mapping, then the mapping around the closed orbit;
# for each, the Jacobian with respect to the state (argument 0) and to the knobs (argument 1)
for matched in (False, True):
    transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=matched, limit=8, epsilon=1.0E-9)
    for argnum in (0, 1):
        pprint(torch.func.jacrev(transformation, argnum)(state, kn_ring))
    print()
tensor([[-1.7095e+00, -1.5994e+00, 4.3872e-03, -2.3450e-03],
[ 4.0853e+00, 3.2372e+00, -1.5448e-02, 1.0439e-02],
[-6.1373e-03, -7.3648e-03, -2.0667e-01, -5.2196e-01],
[-3.3938e-03, -5.6320e-03, 1.2684e+00, -1.6353e+00]],
dtype=torch.float64)
tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.4504e-05, 2.6462e-05,
9.3523e-06, 1.6029e-05, 7.2119e-06, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -7.6792e-05, -3.0233e-05,
-9.6717e-06, 1.4552e-05, 4.2945e-05, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9984e-05, 1.2823e-04,
1.7506e-04, 6.9073e-05, 1.3380e-05, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3031e-06, 2.0438e-04,
3.2004e-04, 2.1256e-04, 8.4305e-05, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00]], dtype=torch.float64)
tensor([[-1.7055, -1.5948, -0.0292, 0.0098],
[ 4.0919, 3.2404, 0.0995, -0.0680],
[ 0.0399, 0.0477, -0.1600, -0.5528],
[-0.0165, 0.0056, 1.3239, -1.6700]], dtype=torch.float64)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.]], dtype=torch.float64)