Example-24: Wrapper
[1]:
# This example illustrates the construction of parametric call wrappers
# For elements, all deviation parameters are passed as a dictionary
# Wrapped elements are invoked using positional arguments
[2]:
# Import
from pprint import pprint
import torch
from twiss import twiss
from ndmap.pfp import parametric_fixed_point
from ndmap.evaluate import evaluate
from ndmap.signature import chop
from model.library.drift import Drift
from model.library.multipole import Multipole
from model.library.dipole import Dipole
from model.library.line import Line
from model.command.wrapper import wrapper
[3]:
# Define simple FODO based lattice using nested lines

# Elements reused by every FODO line
# NOTE(review): positional arguments are presumably (name, length, strength/angle) -- confirm against model.library
QF = Multipole('QF', 0.5, +0.20)
QD = Multipole('QD', 0.5, -0.19)
DR = Drift('DR', 0.75)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# Element sequence shared (as the same list object) by all four FODO lines
FODO = [QF, DR, BM, DR, QD, QD, DR, BM, DR, QF]

# Keyword options passed unchanged to every line at every nesting level
LINE_OPTIONS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)

# Innermost lines
FODO_A, FODO_B, FODO_C, FODO_D = (
    Line(name, FODO, **LINE_OPTIONS)
    for name in ('FODO_A', 'FODO_B', 'FODO_C', 'FODO_D')
)

# Intermediate lines, each holding two FODO lines
LINE_AB = Line('LINE_AB', [FODO_A, FODO_B], **LINE_OPTIONS)
LINE_CD = Line('LINE_CD', [FODO_C, FODO_D], **LINE_OPTIONS)

# Top level line
RING = Line('RING', [LINE_AB, LINE_CD], **LINE_OPTIONS)
[4]:
# Deviation variables are passed to elements/lines as dictionaries
# In order to compute derivatives with respect to a deviation variable,
# a tensor should be bound to the corresponding leaf value of the deviation dictionary
# Print the nested deviation table of the ring (requested with alignment=False), keeping insertion order
pprint(RING.data(alignment=False), sort_dicts=False)
{'LINE_AB': {'FODO_A': {'QF': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'DR': {'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'BM': {'dw': tensor(0., dtype=torch.float64),
'e1': tensor(0., dtype=torch.float64),
'e2': tensor(0., dtype=torch.float64),
'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'QD': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)}},
'FODO_B': {'QF': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'DR': {'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'BM': {'dw': tensor(0., dtype=torch.float64),
'e1': tensor(0., dtype=torch.float64),
'e2': tensor(0., dtype=torch.float64),
'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'QD': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)}}},
'LINE_CD': {'FODO_C': {'QF': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'DR': {'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'BM': {'dw': tensor(0., dtype=torch.float64),
'e1': tensor(0., dtype=torch.float64),
'e2': tensor(0., dtype=torch.float64),
'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'QD': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)}},
'FODO_D': {'QF': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'DR': {'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'BM': {'dw': tensor(0., dtype=torch.float64),
'e1': tensor(0., dtype=torch.float64),
'e2': tensor(0., dtype=torch.float64),
'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)},
'QD': {'kn': tensor(0., dtype=torch.float64),
'ks': tensor(0., dtype=torch.float64),
'ms': tensor(0., dtype=torch.float64),
'mo': tensor(0., dtype=torch.float64),
'dp': tensor(0., dtype=torch.float64),
'dl': tensor(0., dtype=torch.float64)}}}}
[5]:
# Compute parametric closed orbit (first order with respect to momentum deviation)
# Without wrapping, all momentum deviation occurrences should be bound to a single tensor
# Hence, the deviation table should be traversed recursively down to all leaves
def scan(data, name, target):
    """
    Bind a target to every matching leaf of a nested dictionary (in place).

    Parameters
    ----------
    data: dict
        (possibly nested) dictionary to traverse, modified in place
    name: hashable
        leaf key to look for
    target: Any
        object stored under each non-dictionary value whose key equals name

    Returns
    -------
    None

    Note, dictionary values are always descended into, even if their key equals name

    """
    for key in data:
        entry = data[key]
        if not isinstance(entry, dict):
            # Leaf value, rebind it only if the key matches
            if key == name:
                data[key] = target
            continue
        # Nested table, descend
        scan(entry, name, target)
# Set ring function
def ring(state, dp):
    """
    Apply RING to a state with all momentum deviations bound to one value.

    Parameters
    ----------
    state:
        input passed to RING
    dp:
        iterable (a one-element tensor in this example), only its first item is used

    Returns
    -------
    result of RING(state, data=...) with every 'dp' leaf replaced by the first item of dp

    """
    delta, *_ = dp
    # Fresh deviation table, so that each call binds its own tensor
    table = RING.data()
    scan(table, 'dp', delta)
    return RING(state, data=table)
# Set deviations (initial fixed point guess and momentum deviation, all zeros)
fp = torch.zeros(4, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Compute pfp (order one with respect to the single parameter group [dp])
pfp, *_ = parametric_fixed_point((1, ), fp, [dp], ring)
# NOTE(review): chop presumably cleans small values in place, its return value is not used -- confirm
chop(pfp)
pfp
[5]:
[tensor([0., 0., 0., 0.], dtype=torch.float64),
tensor([[4.4462],
[0.0000],
[0.0000],
[0.0000]], dtype=torch.float64)]
[6]:
# Using wrapper we can define the above ring function as follows
# (None, None, 'dp') binds the first positional parameter to all 'dp' leaves
fn = wrapper(RING, (None, None, 'dp'))
# Set deviations (initial fixed point guess and momentum deviation, all zeros)
fp = torch.zeros(4, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Compute pfp (order one with respect to the single parameter group [dp])
pfp, *_ = parametric_fixed_point((1, ), fp, [dp], fn)
# NOTE(review): chop presumably cleans small values in place, its return value is not used -- confirm
chop(pfp)
pfp
[6]:
[tensor([0., 0., 0., 0.], dtype=torch.float64),
tensor([[4.4462],
[0.0000],
[0.0000],
[0.0000]], dtype=torch.float64)]
[7]:
# Compute chromaticity (without wrapping)
# scan() is reused from the closed orbit example above (the definition repeated here was identical)
# Set ring function
def ring(state, dp):
    """
    Apply RING to a state with all momentum deviations bound to one value.

    Parameters
    ----------
    state:
        input passed to RING
    dp:
        iterable (a one-element tensor in this example), only its first item is used

    Returns
    -------
    result of RING(state, data=...) with every 'dp' leaf replaced by the first item of dp

    """
    dp, *_ = dp
    # Fresh deviation table, so that each call binds its own tensor
    data = RING.data()
    scan(data, 'dp', dp)
    return RING(state, data=data)
# Set ring function around pfp
def pfp_ring(state, dp):
    """
    Ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp) - orbit

# Set tune function
def tune(dp):
    """
    Tunes as a function of dp.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp)
    tunes, *_ = twiss(jacobian)
    return tunes

state = torch.zeros(4, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Tunes and their first derivative with respect to dp (chromaticity)
print(tune(dp))
print(torch.func.jacrev(tune)(dp).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
[8]:
# Compute chromaticity (with wrapping)
# Set ring function ('dp' is bound to all leaves by the wrapper)
fn = wrapper(RING, (None, None, 'dp'))

# Set ring function around pfp
def pfp_ring(state, dp):
    """
    Wrapped ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return fn(state + orbit, dp) - orbit

# Set tune function
def tune(dp):
    """
    Tunes as a function of dp.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp)
    tunes, *_ = twiss(jacobian)
    return tunes

state = torch.zeros(4, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Tunes and their first derivative with respect to dp (chromaticity)
print(tune(dp))
print(torch.func.jacrev(tune)(dp).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
[9]:
# Compute chromaticity derivative with respect to sextupole amplitudes (without wrapping)
# scan() is reused from the closed orbit example above (the definition repeated here was identical)
def ring(state, dp, dms):
    """
    Apply RING to a state with momentum and sextupole deviations bound.

    Parameters
    ----------
    state:
        input passed to RING
    dp:
        iterable, its first item is bound to every 'dp' leaf
    dms:
        iterable, its first item is bound to 'ms' of every QF and its second item to 'ms' of every QD

    Returns
    -------
    result of RING(state, data=...) for the modified deviation table

    """
    dp, *_ = dp
    dmsf, dmsd, *_ = dms
    # Fresh deviation table, so that each call binds its own tensors
    data = RING.data()
    scan(data, 'dp', dp)
    # Explicit paths to every innermost FODO line of the ring
    paths = (
        ('LINE_AB', 'FODO_A'),
        ('LINE_AB', 'FODO_B'),
        ('LINE_CD', 'FODO_C'),
        ('LINE_CD', 'FODO_D'),
    )
    for line, fodo in paths:
        data[line][fodo]['QF']['ms'] = dmsf
        data[line][fodo]['QD']['ms'] = dmsd
    return RING(state, data=data)
def pfp_ring(state, dp, dms):
    """
    Ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp, dms) - orbit

def tune(dp, dms):
    """
    Tunes as a function of dp and dms.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp, dms)
    tunes, *_ = twiss(jacobian)
    return tunes

def chromaticity(dms):
    """
    Derivative of tunes with respect to dp (first argument of tune) as a function of dms.

    Note, dp is taken from the enclosing (notebook) scope

    """
    return torch.func.jacrev(tune)(dp, dms)

state = torch.zeros(4, dtype=torch.float64)
dms = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Tunes, chromaticity and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
tensor([[ 25.8500, 1.0470],
[ -9.0271, -16.4821]], dtype=torch.float64)
[10]:
# Compute chromaticity derivative with respect to sextupole amplitudes (with wrapping)
# With the wrapper no manual traversal of the deviation table is required
# First positional parameter is bound to all 'dp' leaves
# Second positional parameter is bound to 'ms' in QF and QD elements (in all lines)
ring = wrapper(RING, (None, None, 'dp'), (None, ['QF', 'QD'], 'ms'))
def pfp_ring(state, dp, dms):
    """
    Wrapped ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp, dms) - orbit

def tune(dp, dms):
    """
    Tunes as a function of dp and dms.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp, dms)
    tunes, *_ = twiss(jacobian)
    return tunes

def chromaticity(dms):
    """
    Derivative of tunes with respect to dp (first argument of tune) as a function of dms.

    Note, dp is taken from the enclosing (notebook) scope

    """
    return torch.func.jacrev(tune)(dp, dms)

state = torch.zeros(4, dtype=torch.float64)
dms = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
# Tunes, chromaticity and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
tensor([[ 25.8500, 1.0470],
[ -9.0271, -16.4821]], dtype=torch.float64)
[11]:
# The above examples demonstrate how to bind tensors to all leaves or to given elements (in all lines)
# (None, None, parameter:str) -- bind tensor to all leaf parameters
# (None, names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements
# (path:list[str], names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements in given path (path to specific line)
[12]:
# Bind QF and QD in all sublines of a given line (1/2 of sextupoles)
# 'dp' is bound in all elements, 'ms' only in QF and QD elements found under LINE_AB
ring = wrapper(
    RING,
    (None, None, 'dp'),
    (['LINE_AB'], ['QF', 'QD'], 'ms'),
)

def pfp_ring(state, dp, dms):
    """
    Wrapped ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp, dms) - orbit

def tune(dp, dms):
    """
    Tunes as a function of dp and dms.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp, dms)
    tunes, *_ = twiss(jacobian)
    return tunes

def chromaticity(dms):
    """
    Derivative of tunes with respect to dp (first argument of tune) as a function of dms.

    Note, dp is taken from the enclosing (notebook) scope

    """
    return torch.func.jacrev(tune)(dp, dms)

state = torch.zeros(4, dtype=torch.float64)
dms = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
# Factor 2 rescales the result for comparison with the case where all sextupoles are bound
print(2*torch.func.jacrev(chromaticity)(dms).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
tensor([[ 25.8500, 1.0470],
[ -9.0271, -16.4821]], dtype=torch.float64)
[13]:
# Bind QF and QD in a given leaf line (1/4 of sextupoles)
# 'dp' is bound in all elements, 'ms' only in QF and QD elements of the FODO_A line inside LINE_AB
ring = wrapper(
    RING,
    (None, None, 'dp'),
    (['LINE_AB', 'FODO_A'], ['QF', 'QD'], 'ms'),
)

def pfp_ring(state, dp, dms):
    """
    Wrapped ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp, dms) - orbit

def tune(dp, dms):
    """
    Tunes as a function of dp and dms.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp, dms)
    tunes, *_ = twiss(jacobian)
    return tunes

def chromaticity(dms):
    """
    Derivative of tunes with respect to dp (first argument of tune) as a function of dms.

    Note, dp is taken from the enclosing (notebook) scope

    """
    return torch.func.jacrev(tune)(dp, dms)

state = torch.zeros(4, dtype=torch.float64)
dms = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
# Factor 4 rescales the result for comparison with the case where all sextupoles are bound
print(4*torch.func.jacrev(chromaticity)(dms).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
tensor([[ 25.8500, 1.0470],
[ -9.0271, -16.4821]], dtype=torch.float64)
[14]:
# Several sextupole groups
# One positional parameter per group: 'dp' (all elements), 'ms' under LINE_AB and 'ms' under LINE_CD
ring = wrapper(
    RING,
    (None, None, 'dp'),
    (['LINE_AB'], ['QF', 'QD'], 'ms'),
    (['LINE_CD'], ['QF', 'QD'], 'ms'),
)

def pfp_ring(state, dp, dms_ab, dms_cd):
    """
    Wrapped ring function expressed relative to the parametric fixed point.

    The fixed point evaluated at dp is added before and subtracted after the ring call

    """
    orbit = evaluate(pfp, [dp])
    return ring(state + orbit, dp, dms_ab, dms_cd) - orbit

def tune(dp, dms_ab, dms_cd):
    """
    Tunes as a function of dp and both sextupole groups.

    The matrix is the derivative of pfp_ring with respect to its first argument (state)

    """
    jacobian = torch.func.jacrev(pfp_ring)(state, dp, dms_ab, dms_cd)
    tunes, *_ = twiss(jacobian)
    return tunes

def chromaticity(dms_ab, dms_cd):
    """
    Derivative of tunes with respect to dp (first argument of tune) as a function of both groups.

    Note, dp is taken from the enclosing (notebook) scope

    """
    return torch.func.jacrev(tune)(dp, dms_ab, dms_cd)

state = torch.zeros(4, dtype=torch.float64)
dms_ab = torch.zeros(2, dtype=torch.float64)
dms_cd = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
print(tune(dp, dms_ab, dms_cd))
print(chromaticity(dms_ab, dms_cd).squeeze())
# Chromaticity derivative with respect to each group separately
print(torch.func.jacrev(chromaticity, argnums=0)(dms_ab, dms_cd).squeeze())
print(torch.func.jacrev(chromaticity, argnums=1)(dms_ab, dms_cd).squeeze())

def fn(dms):
    """
    Chromaticity as a function of both groups stacked into a single tensor.

    """
    group_ab, group_cd = dms
    return chromaticity(group_ab, group_cd)

# Chromaticity derivative with respect to both groups at once
print(torch.func.jacrev(fn)(torch.stack((dms_ab, dms_cd))).squeeze())
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([-2.0649, -0.8260], dtype=torch.float64)
tensor([[12.9250, 0.5235],
[-4.5135, -8.2411]], dtype=torch.float64)
tensor([[12.9250, 0.5235],
[-4.5135, -8.2411]], dtype=torch.float64)
tensor([[[12.9250, 0.5235],
[12.9250, 0.5235]],
[[-4.5135, -8.2411],
[-4.5135, -8.2411]]], dtype=torch.float64)