Example-39: Mapping (Transformations around closed orbit)

[1]:
# In this example the steps to define mappings between elements are illustrated
# These mappings are differentiable with respect to the state and to different deviation groups
# Additionally, mappings can be defined around the closed orbit
[2]:
# Import

from random import random
from pprint import pprint

import torch

from pathlib import Path

from model.library.line import Line
from model.library.corrector import Corrector

from model.command.external import load_lattice
from model.command.build import build
from model.command.wrapper import group
from model.command.orbit import orbit
from model.command.mapping import mapping
from model.command.mapping import matrix
[3]:
# Build the lattice and prepare it for the examples below

# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts
# Correctors are inserted between the parts
# Random errors are assigned to correctors, so that the origin is not preserved

# Load ELEGANT table

path = Path('ic.lte')
data = load_lattice(path)

# Build the ring from the loaded table and flatten it

ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()

# Merge drifts

ring.merge()

# Number of parts for each kind of split element

nq = 2**2
nd = 2**4

# Split quadrupoles first, then dipoles, and insert correctors
# Names are collected from the current layout before the elements of the given kind are split
# Corrector angles (cx, cy) are drawn from 1.0E-3*(random() - 0.5)

for kind, parts in (('Quadrupole', nq), ('Dipole', nd)):
    selected = [name for name, label, *_ in ring.layout() if label == kind]
    for name in selected:
        corrector = Corrector(f'{name}_CXY', factor=1/(parts - 1))
        corrector.cx = 1.0E-3*(random() - 0.5)
        corrector.cy = 1.0E-3*(random() - 0.5)
        ring.split((parts, None, [name], None), paste=[corrector])

# Set linear flag in dipoles (elements are selected by class name)

for element in ring:
    if type(element).__name__ == 'Dipole':
        element.linear = True

# Set number of elements of different kinds
# Note, nq is rebound here from the number of parts to the number of quadrupoles

counts = ring.describe
nb = counts['BPM']
nc = counts['Corrector']
nq = counts['Quadrupole']
ns = counts['Sextupole']
[4]:
# model.command.wrapper.group can be used to define parametric and differentiable transformations
# These transformations propagate an initial state from the start of a given (probe) element to the end of a given (other) element
# Start and end elements can be specified by names (the first occurrence in the line sequence is matched)
# Or they can be specified by integers (can be negative, mod number of elements in sequence is used to define the transformation)

# Since correctors have non-zero angles, zero is not mapped to zero

state = torch.zeros(4, dtype=torch.float64)
image = ring(state)
print(image)

# Define transformation using names (assumed to be different)
# The first and the last names of the ring are used

probe, *_, other = ring.names
transformation, *_ = group(ring, probe, other)
print(transformation(state))

# Define transformation using element positions in the lattice sequence

probe, other = 0, len(ring) - 1
transformation, *_ = group(ring, probe, other)
print(transformation(state))
tensor([-0.0013,  0.0007,  0.0002, -0.0010], dtype=torch.float64)
tensor([-0.0013,  0.0007,  0.0002, -0.0010], dtype=torch.float64)
tensor([-0.0013,  0.0007,  0.0002, -0.0010], dtype=torch.float64)
[5]:
# Compute closed orbit and test transformation around it

# The search starts from a zero initial guess
# The first entry of the returned orbit is used as the fixed point of the whole ring

guess = torch.zeros(4, dtype=torch.float64)
fps, _ = orbit(ring, guess, [], advance=True)
fp = fps[0]

# Fixed point should be mapped to itself by the ring

print(fp, ring(fp), sep='\n')
print()

# Deviation from the fixed point after one pass (zero initial deviation)

print(transformation(state + fp) - fp)
tensor([ 0.0024, -0.0054,  0.0036,  0.0051], dtype=torch.float64)
tensor([ 0.0024, -0.0054,  0.0036,  0.0051], dtype=torch.float64)

tensor([-1.7347e-18, -2.6021e-18,  6.9389e-18,  6.9389e-18],
       dtype=torch.float64)
[6]:
# model.command.mapping.mapping can be used as an alias to model.command.wrapper.group
# Additionally, it can be used to construct parametric and differentiable transformations from one element to the other that are built around the closed orbit

# Without matched flag, closed orbit is added and subtracted explicitly

transformation, _ = mapping(ring, probe, other, matched=False)
shifted = transformation(state + fp)
print(shifted - fp)

# Transformation around closed orbit

options = dict(matched=True, limit=8, epsilon=1.0E-9)
transformation, _ = mapping(ring, probe, other, **options)
print(transformation(state))

# With matched flag, closed orbit will be computed on each invocation
# To speed up computations, a known fixed point can be passed and the number of iterations set to zero
# In this case probe is assumed to be the lattice start

options = dict(matched=True, guess=fp, limit=0, epsilon=None)
transformation, _ = mapping(ring, probe, other, **options)
print(transformation(state))

# Also, to compute derivatives, limit can be set to one
# Set epsilon to None for vmap computations
tensor([-1.7347e-18, -2.6021e-18,  6.9389e-18,  6.9389e-18],
       dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Transformation between given elements

probe, other = 'BPM07', 'BPM10'

# Propagate element by element
# Elements from probe position up to (not including) other position are applied in sequence

first = ring.position(probe)
last = ring.position(other)

local = state.clone()
for element in ring[first:last]:
    local = element(local)
print(local)

# Propagate using group and then using mapping (same arguments, results are printed in this order)

for constructor in (group, mapping):
    transformation, *_ = constructor(ring, probe, other)
    print(transformation(state))
tensor([-0.0001,  0.0009,  0.0005,  0.0004], dtype=torch.float64)
tensor([-0.0001,  0.0009,  0.0005,  0.0004], dtype=torch.float64)
tensor([-0.0001,  0.0009,  0.0005,  0.0004], dtype=torch.float64)
[8]:
# Transformation between elements around closed orbit

# Set closed orbit values at probe and other

index_probe = ring.position(probe)
index_other = ring.position(other)

fp_probe = fps[index_probe]
fp_other = fps[index_other]

# Propagate element by element
# Initial state is shifted by the closed orbit at probe, final state is compared with the closed orbit at other

local = fp_probe + state
for element in ring[index_probe:index_other]:
    local = element(local)
print(local - fp_other)

# Propagate using group and then using mapping (not matched)
# In both cases closed orbit is added and subtracted explicitly

for constructor, options in ((group, {}), (mapping, {'matched': False})):
    transformation, *_ = constructor(ring, probe, other, **options)
    print(transformation(state + fp_probe) - fp_other)

# Propagate using mapping (matched)
# Here state is passed relative to the closed orbit

transformation, *_ = mapping(ring, probe, other, matched=True)
print(transformation(state))
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Parametric mapping

# Set initial value (relative to closed orbit)

state = torch.tensor(4*[0.0], dtype=torch.float64)

# Without root and matched flags, tensor elements are bound only to matched elements (unique names) within the selected range of elements

transformation, ((_, names_line, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=False, matched=False)

# Set random quadrupole errors within the line

kn_line = 0.01*torch.randn(len(names_line), dtype=torch.float64)

# With zero deviations the closed orbit at probe is mapped to the closed orbit at other
# With non-zero deviations the closed orbit is changed, so the difference is not zero

print(transformation(state + fp_probe, 0*kn_line) - fp_other)
print(transformation(state + fp_probe, 1*kn_line) - fp_other)
print()

# With root flag, tensor elements are bound to all matched elements (unique names)

transformation, ((_, names_ring, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=True, matched=False)

# To match the previous result, line deviations should be placed at the positions of the line names within the ring names
# Positions are looked up by name, so the placement is correct even if the first matched name has index zero
# (testing a zero start index for truthiness would treat it as not set) or if matched names are not contiguous

position = {name: i for i, name in enumerate(names_ring)}
index = torch.tensor([position[name] for name in names_line], dtype=torch.int64)

kn_ring = torch.zeros(len(names_ring), dtype=torch.float64)
kn_ring[index] = kn_line

# Propagate state

print(transformation(state + fp_probe, 0*kn_ring) - fp_other)
print(transformation(state + fp_probe, 1*kn_ring) - fp_other)
print()
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 1.3384e-05, -1.2968e-05,  1.0788e-05,  1.8095e-05],
       dtype=torch.float64)

tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 1.3384e-05, -1.2968e-05,  1.0788e-05,  1.8095e-05],
       dtype=torch.float64)

[10]:
# Parametric mappings around closed orbit
# In this case root parameter is ignored
# Elements are ordered according to their appearance in the input lattice (similar to changing start in orbit function with respect flag)

# Deviation group used in this cell (kn of all quadrupoles)

quadrupoles = ('kn', ['Quadrupole'], None, None)

# Compute closed orbit with quadrupole errors

guess = torch.zeros(4, dtype=torch.float64)
fps, _ = orbit(ring, guess, [kn_ring], quadrupoles, advance=True, limit=8, epsilon=1.0E-9)
fp = fps[0]

# Set closed orbit at probe and other

fp_probe = fps[ring.position(probe)]
fp_other = fps[ring.position(other)]

# Test closed orbit at lattice start
# Matched names are kept to set up the groups below

line = ring.clone()
transformation, ((_, names, _), *_) = mapping(line, 0, len(line) - 1, quadrupoles, matched=False)
print(fp - transformation(fp, kn_ring))

# Test closed orbit at probe and other
# A ring clone is started from the given element and the full turn map is applied to the corresponding fixed point
# Note, groups are set up using returned matched names

for location, fixed in ((probe, fp_probe), (other, fp_other)):
    line = ring.clone()
    line.start = location
    transformation, _ = mapping(line, 0, len(line) - 1, ('kn', None, names, None), matched=False)
    print(fixed - transformation(fixed, kn_ring))

# Test mapping (closed orbit is added and subtracted explicitly)

transformation, _ = mapping(ring, probe, other, quadrupoles, matched=False, limit=8, epsilon=1.0E-9)
for scale in (0, 1):
    print(transformation(scale*state + fp_probe, kn_ring) - fp_other)

# Test mapping around closed orbit (state is relative to closed orbit)

transformation, _ = mapping(ring, probe, other, quadrupoles, matched=True, limit=8, epsilon=1.0E-9)
for scale in (0, 1):
    print(transformation(scale*state, kn_ring))
tensor([-3.0358e-18,  1.3878e-17,  3.4694e-18,  7.8063e-18],
       dtype=torch.float64)
tensor([ 1.8648e-17, -3.2960e-17,  5.5565e-18,  8.2399e-18],
       dtype=torch.float64)
tensor([ 1.9516e-17, -2.6888e-17, -3.9031e-18, -4.3368e-18],
       dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[11]:
# Vectorized mapping over states / knobs
# In this case, epsilon should be set to None (relevant only for the case around closed orbit)

transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=None)

# Random batches of states and knobs (states are generated first)

size = 128
states = 1.0E-3*torch.randn(size, *state.shape, dtype=torch.float64)
knobs = 1.0E-3*torch.randn(size, *kn_ring.shape, dtype=torch.float64)

# Map over the leading dimension of one argument, the other argument is shared

over_states = torch.vmap(transformation, in_dims=(0, None))
over_knobs = torch.vmap(transformation, in_dims=(None, 0))

print(over_states(states, kn_ring).shape)
print(over_knobs(state, knobs).shape)
torch.Size([128, 4])
torch.Size([128, 4])
[12]:
# Differentiability with respect to state and knobs

state = torch.zeros(4, dtype=torch.float64)

# First pass  -- mapping (matched=False)
# Second pass -- mapping around closed orbit (matched=True)
# For each mapping, derivatives with respect to state (argument 0) and knobs (argument 1) are printed

for matched in (False, True):
    transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=matched, limit=8, epsilon=1.0E-9)
    for argnums in (0, 1):
        jacobian = torch.func.jacrev(transformation, argnums)
        pprint(jacobian(state, kn_ring))
    print()
tensor([[-1.7095e+00, -1.5994e+00,  4.3872e-03, -2.3450e-03],
        [ 4.0853e+00,  3.2372e+00, -1.5448e-02,  1.0439e-02],
        [-6.1373e-03, -7.3648e-03, -2.0667e-01, -5.2196e-01],
        [-3.3938e-03, -5.6320e-03,  1.2684e+00, -1.6353e+00]],
       dtype=torch.float64)
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.4504e-05,  2.6462e-05,
          9.3523e-06,  1.6029e-05,  7.2119e-06,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -7.6792e-05, -3.0233e-05,
         -9.6717e-06,  1.4552e-05,  4.2945e-05,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.9984e-05,  1.2823e-04,
          1.7506e-04,  6.9073e-05,  1.3380e-05,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  8.3031e-06,  2.0438e-04,
          3.2004e-04,  2.1256e-04,  8.4305e-05,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00]], dtype=torch.float64)

tensor([[-1.7055, -1.5948, -0.0292,  0.0098],
        [ 4.0919,  3.2404,  0.0995, -0.0680],
        [ 0.0399,  0.0477, -0.1600, -0.5528],
        [-0.0165,  0.0056,  1.3239, -1.6700]], dtype=torch.float64)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.]], dtype=torch.float64)