Example-09: Multipole (element)
[1]:
# Comparison of multipole element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system
import torch
from model.library.multipole import Multipole
[3]:
# Tracking (paraxial)
# Track one particle through the element with MAD-X/PTC and with Multipole, then compare
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')
exact = False
align = False
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 110.0
dp = 0.005
length = 0.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()
# Multiplying by the boolean flag zeroes all alignment errors when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}
code = f"""
mag:quadrupole,l={length},knl={{0.0,{kn*length},{ms*length},{mo*length}}},ksl={{0.0,{ks*length}}};
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""
# Remove a stale observation file so a failed MAD-X run is never compared against old data
obs.unlink(missing_ok=True)
try:
    ptc.write_text(code)
    system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'MAD-X did not produce {obs}')
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
finally:
    # Clean up temporary files even if the MAD-X run or parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)
if not rows:
    raise RuntimeError(f'{obs} contains no tracking data')
# Final coordinates are on the last data line: skip two columns, then x, px, y, py
_, _, *columns = rows[-1].split()
ref = torch.tensor([float(value) for value in columns[:4]], dtype=torch.float64)
M = Multipole('M', length=length, kn=kn, ks=ks, ms=ms, mo=mo, dp=dp, exact=exact, order=5, ns=10)
res = M(state, alignment=align, data={**M.data(), **error})
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.00892946929535585, 0.0009165531774213284, -0.0017930189940976037, 0.01128986832436645]
[0.008929469295355376, 0.0009165531774212371, -0.001793018994097562, 0.011289868324366148]
[4.735795089416683e-16, 9.128982292327947e-17, -4.163336342344337e-17, 3.0184188481996443e-16]
[4]:
# Tracking (exact)
# Track one particle through the element with MAD-X/PTC and with Multipole, then compare
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')
exact = True
align = False
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 110.0
dp = 0.005
length = 0.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()
# Multiplying by the boolean flag zeroes all alignment errors when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}
code = f"""
mag:quadrupole,l={length},knl={{0.0,{kn*length},{ms*length},{mo*length}}},ksl={{0.0,{ks*length}}};
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""
# Remove a stale observation file so a failed MAD-X run is never compared against old data
obs.unlink(missing_ok=True)
try:
    ptc.write_text(code)
    system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'MAD-X did not produce {obs}')
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
finally:
    # Clean up temporary files even if the MAD-X run or parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)
if not rows:
    raise RuntimeError(f'{obs} contains no tracking data')
# Final coordinates are on the last data line: skip two columns, then x, px, y, py
_, _, *columns = rows[-1].split()
ref = torch.tensor([float(value) for value in columns[:4]], dtype=torch.float64)
M = Multipole('M', length=length, kn=kn, ks=ks, ms=ms, mo=mo, dp=dp, exact=exact, order=5, ns=10)
res = M(state, alignment=align, data={**M.data(), **error})
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.00892945162589142, 0.0009165626425029898, -0.0017929090323601758, 0.011289834389170184]
[0.008929451625891005, 0.0009165626425029993, -0.001792909032360122, 0.01128983438916985]
[4.145989107584569e-16, -9.432558900623889e-18, -5.377642775528102e-17, 3.3480163086352377e-16]
[5]:
# Tracking (exact, alignment)
# Track one particle through the misaligned element with MAD-X/PTC and with Multipole, then compare
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')
exact = True
align = True
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 110.0
dp = 0.005
length = 0.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()
# Multiplying by the boolean flag zeroes all alignment errors when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}
code = f"""
mag:quadrupole,l={length},knl={{0.0,{kn*length},{ms*length},{mo*length}}},ksl={{0.0,{ks*length}}};
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""
# Remove a stale observation file so a failed MAD-X run is never compared against old data
obs.unlink(missing_ok=True)
try:
    ptc.write_text(code)
    system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'MAD-X did not produce {obs}')
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
finally:
    # Clean up temporary files even if the MAD-X run or parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)
if not rows:
    raise RuntimeError(f'{obs} contains no tracking data')
# Final coordinates are on the last data line: skip two columns, then x, px, y, py
_, _, *columns = rows[-1].split()
ref = torch.tensor([float(value) for value in columns[:4]], dtype=torch.float64)
M = Multipole('M', length=length, kn=kn, ks=ks, ms=ms, mo=mo, dp=dp, exact=exact, order=5, ns=10)
res = M(state, alignment=align, data={**M.data(), **error})
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.001664975746516068, -0.04016261215593401, -0.015906155451718536, -0.05363979979736713]
[0.0016649757465192622, -0.04016261215593149, -0.015906155451718116, -0.053639799797364655]
[-3.194059600142296e-15, -2.518818487118324e-15, -4.198030811863873e-16, -2.4771851236948805e-15]
[6]:
# Deviation/error variables
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
dx = torch.tensor(0.05, dtype=torch.float64)
dy = torch.tensor(-0.02, dtype=torch.float64)
dz = torch.tensor(0.05, dtype=torch.float64)
wx = torch.tensor(0.005, dtype=torch.float64)
wy = torch.tensor(-0.005, dtype=torch.float64)
wz = torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}
M = Multipole('M', length, kn, ks, ms, mo, dp)
# Each element has two variants of the call method
# In the first case only the state is passed; it is transformed using the parameters specified on initialization
print(M(state))
print()
# Deviation errors can also be passed to the call method
# These variables are added to the corresponding parameters specified on initialization
# For example, the element length can be changed (here it is reduced to zero)
print(M(state, data={**M.data(), **{'dl': -M.length}}))
print()
# In the above, M.data() creates the default deviation dictionary (with zero values for each deviation)
# {**M.data(), **{'dl': -M.length}} replaces the 'dl' key value
# Additionally, alignment errors are passed as deviation variables
# They are used if the alignment flag is raised
print(M(state, data={**M.data(), **error}, alignment=True))
print()
# The following elements can be made equivalent using deviation variables
MA = Multipole('MA', length, kn, ks, ms, mo, dp)
MB = Multipole('MB', length - 0.1, kn, ks, ms, mo, dp)
print(MA(state) - MB(state, data={**MB.data(), **{'dl': torch.tensor(+0.1, dtype=MB.dtype)}}))
# Note: while in some cases float values can be passed as deviation variables,
# correct behaviour is guaranteed only for tensors
tensor([ 0.0092, -0.0028, -0.0043, 0.0055], dtype=torch.float64)
tensor([ 0.0100, -0.0050, -0.0050, 0.0010], dtype=torch.float64)
tensor([ 0.0085, -0.0159, -0.0060, -0.0224], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Insertion element
# In this mode elements are treated as thin insertions (at the center)
# Using the parameters specified on initialization, two transport matrices are computed
# These matrices are used to insert the element
# The input state is transformed from the element center to its entrance
# Next, the transformation from the entrance frame to the exit frame is performed
# This transformation can contain errors
# The final step is to transform the state from the exit frame back to the element center
# Without errors, this results in an identity transformation for linear elements
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
M = Multipole('M', length, kn, ks, ms, mo, dp, exact=False, insertion=True)
# Since this multipole is a nonlinear element (non-zero sextupole or octupole),
# insertion is an identity transformation only when those nonlinear strengths are zero
print(M(state) - state)
# Cancel the sextupole and octupole strengths via deviation variables
# (passed as tensors, since correct behaviour is only guaranteed for tensor deviations)
print(M(state, data={**M.data(), **{'ms': torch.tensor(-ms, dtype=M.dtype), 'mo': torch.tensor(-mo, dtype=M.dtype)}}) - state)
# Represents the effect of an error (any nonzero strength or a change in another parameter)
print(M(state, data={**M.data(), **{'dl': torch.tensor(0.1, dtype=M.dtype)}}) - state)
# Exact tracking corresponds to the inclusion of the kinematic term as errors
M = Multipole('M', length, kn, ks, ms, mo, dp, exact=True, insertion=True, ns=20, order=1)
print(M(state) - state)
tensor([ 6.9389e-18, -1.8792e-04, 8.6736e-19, -2.5229e-04],
dtype=torch.float64)
tensor([ 6.9389e-18, -4.3368e-18, 8.6736e-19, -8.6736e-19],
dtype=torch.float64)
tensor([-0.0004, 0.0009, 0.0002, 0.0021], dtype=torch.float64)
tensor([-7.9785e-07, -1.9093e-04, -5.6642e-07, -2.5071e-04],
dtype=torch.float64)
[8]:
# Mapping over a set of initial conditions
# The call method can be used to map over a set of initial conditions
# Note: the device can be set to cpu or gpu via base element class variables
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
M = Multipole('M', length, kn, ks, ms, mo, dp, exact=True)
# Fix the seed so that the random batch of initial conditions is reproducible
torch.manual_seed(0)
state = 1.0E-3*torch.randn((512, 4), dtype=M.dtype, device=M.device)
print(torch.vmap(M)(state).shape)
# To map over deviation parameters a wrapper function (or a lambda expression) can be used
def wrapper_dp(state, delta):
    """Apply M to state with the 'dp' deviation variable set to delta."""
    return M(state, data={**M.data(), **{'dp': delta}})
# Batch of 'dp' deviations (a separate name, so the scalar dp above is not overwritten)
dp_batch = 1.0E-3*torch.randn(512, dtype=M.dtype, device=M.device)
print(torch.vmap(wrapper_dp)(state, dp_batch).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[9]:
# Differentiability
# Both call methods are differentiable
# The derivative with respect to the state can be computed directly
# For deviation variables, wrapping is required
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
M = Multipole('M', length, kn, ks, ms, mo, dp, exact=False)
# Compute the derivative with respect to the state
print(torch.func.jacrev(M)(state))
print()
# Compute the derivative with respect to a deviation variable
# (a separate name is used so that the strength kn defined above is not overwritten)
dkn = torch.tensor(0.0, dtype=torch.float64)
def wrapper_kn(state, dkn):
    """Apply M to state with the 'kn' deviation variable set to dkn."""
    return M(state, data={**M.data(), **{'kn': dkn}})
print(torch.func.jacrev(wrapper_kn, 1)(state, dkn))
tensor([[ 1.0353, 0.2012, 0.0274, 0.0017],
[ 0.3588, 1.0353, 0.2757, 0.0274],
[ 0.0274, 0.0017, 0.9653, 0.1969],
[ 0.2757, 0.0274, -0.3449, 0.9653]], dtype=torch.float64)
tensor([-1.9468e-04, -1.9487e-03, -9.6920e-05, -9.5528e-04],
dtype=torch.float64)
[10]:
# Output at each step
# It is possible to collect the state or tangent matrix output at each integration step
# The number of integration steps is controlled by the ns parameter on initialization
# Alternatively, the desired integration step length can be passed
# The number of integration steps is then computed as ceil(length/ds)
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
M = Multipole('M', length, kn, ks, ms, mo, dp, exact=False, ns=10, output=True, matrix=True)
# The final state is still returned
print(M(state))
# Data is added to special attributes (state and tangent matrix)
print(M.container_output.shape)
print(M.container_matrix.shape)
# The number of integration steps can be changed
M.ns = 100
# Call again only to refill the containers; the returned state is not needed here
M(state)
print(M.container_output.shape)
print(M.container_matrix.shape)
tensor([ 0.0092, -0.0028, -0.0043, 0.0055], dtype=torch.float64)
torch.Size([10, 4])
torch.Size([10, 4, 4])
torch.Size([100, 4])
torch.Size([100, 4, 4])
[11]:
# Integration order is set on initialization (default value is zero)
# An integration order n corresponds to a difference method of order 2n + 2
# Thus, order zero corresponds to a second order difference method
kn = - 2.0
ks = + 1.5
ms = 25.0
mo = 50.0
dp = 0.005
length = 0.2
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
dx = torch.tensor(0.05, dtype=torch.float64)
dy = torch.tensor(-0.02, dtype=torch.float64)
dz = torch.tensor(0.05, dtype=torch.float64)
wx = torch.tensor(0.005, dtype=torch.float64)
wy = torch.tensor(-0.005, dtype=torch.float64)
wz = torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}
M = Multipole('M', length, kn, ks, ms, mo, dp, order=0, exact=True)
# For a multipole with non-zero sextupole and/or octupole, integration is always performed
# In the exact case, the kinematic term error is added
# Compare a coarse (10 steps) and a fine (100 steps) integration
M.ns = 10
ref = M(state)
M.ns = 100
res = M(state)
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
print()
# Integrator parameters are stored in the data attribute (if integration is actually performed)
# NOTE(review): _data is a private attribute; it is read here only for illustration
maps, weights = M._data
print(maps)
print(weights)
[0.009228705737231328, -0.002766282719416843, -0.004341639948599635, 0.005542200849934217]
[0.009228699312521686, -0.0027663108074714475, -0.004341646792491436, 0.00554221421387834]
[6.424709642766091e-09, 2.808805460441377e-08, 6.843891801368296e-09, -1.3363944122331273e-08]
[0, 1, 2, 1, 0]
[0.5, 0.5, 1.0, 0.5, 0.5]