Example-05: Drift (element)
[1]:
# Comparison of drift element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system
import torch
from model.library.drift import Drift
[3]:
# Tracking (paraxial)
# Track one particle through a drift with MADX-PTC and compare with the Drift element

# Scratch files: MAD-X input script and the observation file written by ptc_track (file=track)
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

# Drift model flag and alignment switch
exact = False
align = False

# Momentum deviation, element length and initial state (qx, px, qy, py)
dp = 0.005
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag, so all of them are zero when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X/PTC reference script
code = f"""
mag:drift,l={length} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    # Remove an observation file left over from an earlier run,
    # so that a failed MAD-X run cannot be mistaken for a fresh reference
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise FileNotFoundError(f'madx did not produce {obs} (exit status {status})')
    with obs.open('r') as stream:
        rows = stream.read().splitlines()
    if not rows:
        raise ValueError(f'{obs} is empty')
    # Only the last line is used; its fields 2..5 are read as qx, px, qy, py
    columns = rows[-1].split()
    if len(columns) < 6:
        raise ValueError(f'unexpected last line in {obs}: {rows[-1]!r}')
    ref = torch.tensor([float(column) for column in columns[2:6]], dtype=torch.float64)
finally:
    # Scratch files are removed even if MAD-X or the parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

D = Drift('D', length=length, dp=dp, exact=exact)
res = D(state, alignment=align, data={**D.data(), **error})

# Reference (PTC), result (Drift) and their difference
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0025373134328358204, -0.005, -0.003507462686567164, 0.001]
[0.0025373134328358204, -0.005, -0.0035074626865671645, 0.001]
[0.0, 0.0, 4.336808689942018e-19, 0.0]
[4]:
# Tracking (exact)
# Track one particle through a drift with MADX-PTC and compare with the Drift element

# Scratch files: MAD-X input script and the observation file written by ptc_track (file=track)
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

# Drift model flag and alignment switch
exact = True
align = False

# Momentum deviation, element length and initial state (qx, px, qy, py)
dp = 0.005
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag, so all of them are zero when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X/PTC reference script
code = f"""
mag:drift,l={length} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    # Remove an observation file left over from an earlier run,
    # so that a failed MAD-X run cannot be mistaken for a fresh reference
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise FileNotFoundError(f'madx did not produce {obs} (exit status {status})')
    with obs.open('r') as stream:
        rows = stream.read().splitlines()
    if not rows:
        raise ValueError(f'{obs} is empty')
    # Only the last line is used; its fields 2..5 are read as qx, px, qy, py
    columns = rows[-1].split()
    if len(columns) < 6:
        raise ValueError(f'unexpected last line in {obs}: {rows[-1]!r}')
    ref = torch.tensor([float(column) for column in columns[2:6]], dtype=torch.float64)
finally:
    # Scratch files are removed even if MAD-X or the parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

D = Drift('D', length=length, dp=dp, exact=exact)
res = D(state, alignment=align, data={**D.data(), **error})

# Reference (PTC), result (Drift) and their difference
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.002537217378977325, -0.005, -0.003507443475795465, 0.001]
[0.002537217378977325, -0.005, -0.0035074434757954654, 0.001]
[0.0, 0.0, 4.336808689942018e-19, 0.0]
[5]:
# Tracking (alignment)
# Track one particle through a misaligned drift with MADX-PTC and compare with the Drift element
# NOTE(review): the exact flag is False below, so the paraxial drift is used here,
# although this cell was originally titled "exact, alignment" -- confirm which one was intended

# Scratch files: MAD-X input script and the observation file written by ptc_track (file=track)
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

# Drift model flag and alignment switch
exact = False
align = True

# Momentum deviation, element length and initial state (qx, px, qy, py)
dp = 0.005
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag (True here, so the values are kept)
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X/PTC reference script
code = f"""
mag:drift,l={length} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    # Remove an observation file left over from an earlier run,
    # so that a failed MAD-X run cannot be mistaken for a fresh reference
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise FileNotFoundError(f'madx did not produce {obs} (exit status {status})')
    with obs.open('r') as stream:
        rows = stream.read().splitlines()
    if not rows:
        raise ValueError(f'{obs} is empty')
    # Only the last line is used; its fields 2..5 are read as qx, px, qy, py
    columns = rows[-1].split()
    if len(columns) < 6:
        raise ValueError(f'unexpected last line in {obs}: {rows[-1]!r}')
    ref = torch.tensor([float(column) for column in columns[2:6]], dtype=torch.float64)
finally:
    # Scratch files are removed even if MAD-X or the parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

D = Drift('D', length=length, dp=dp, exact=exact)
res = D(state, alignment=align, data={**D.data(), **error})

# Reference (PTC), result (Drift) and their difference
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())

# Note, for some reason drift is not invariant under WX and WY rotations in MADX
[0.0025372170792497145, -0.0049999999999999975, -0.003507395297032926, 0.0010000000000000009]
[0.0025372170792497166, -0.004999999999999999, -0.0035073952970329247, 0.0010000000000000018]
[-2.168404344971009e-18, 1.734723475976807e-18, -1.3010426069826053e-18, -8.673617379884035e-19]
[6]:
# Deviation/error variables

dp = 0.005
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors, collected into a dictionary keyed by name
dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

D = Drift('D', length, dp)

# Each element has two variants of the call method
# In the first one only the state is passed; it is transformed using the parameters specified on initialization
print(D(state), end='\n\n')

# Deviation variables can also be passed to the call method
# They are added to the corresponding parameters specified on initialization
# For example, the element length can be changed (here dl = -length)
print(D(state, data={**D.data(), 'dl': torch.tensor(-length, dtype=D.dtype)}), end='\n\n')

# In the call above, D.data() creates the default deviation dictionary (zero value for each deviation)
# Unpacking it first and then giving the 'dl' key replaces only the 'dl' value
# Additionally, alignment errors are passed as deviation variables
# They are used only if the alignment flag is raised
print(D(state, alignment=True, data={**D.data(), **error}), end='\n\n')

# The following elements can be made equivalent using deviation variables
DA = Drift('DA', length, dp)
DB = Drift('DB', length - 0.1, dp)
print(DA(state) - DB(state, data={**DB.data(), 'dl': torch.tensor(+0.1, dtype=DB.dtype)}))
tensor([ 0.0025, -0.0050, -0.0035, 0.0010], dtype=torch.float64)
tensor([ 0.0100, -0.0050, -0.0050, 0.0010], dtype=torch.float64)
tensor([ 0.0025, -0.0050, -0.0035, 0.0010], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Insertion element
# In this mode elements are treated as thin insertions (at the center)
# Using the parameters specified on initialization, two transport matrices are computed
# These matrices are used to insert the element
# The input state is transformed from the element center to its entrance
# Next, the transformation from the entrance frame to the exit frame is performed
# This transformation can contain errors
# The final step is to transform the state from the exit frame back to the element center
# Without errors, this results in the identity transformation for linear elements

dp = 0.0
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (not referenced further in this cell)
dx = torch.tensor(0.05, dtype=torch.float64)
dy = torch.tensor(-0.02, dtype=torch.float64)
dz = torch.tensor(0.05, dtype=torch.float64)
wx = torch.tensor(0.005, dtype=torch.float64)
wy = torch.tensor(-0.005, dtype=torch.float64)
wz = torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

D = Drift('D', length, dp, exact=False, insertion=True)

# Identity transformation without errors
print(D(state) - state)

# Represents the effect of an error
# The deviation is passed as a tensor with the element dtype, like the other deviation values in this notebook
print(D(state, data={**D.data(), **{'dl': torch.tensor(0.1, dtype=D.dtype)}}) - state)

# Exact tracking corresponds to the inclusion of the kinematic term as an error
D = Drift('D', length, dp, exact=True, insertion=True)
print(D(state) - state)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([-0.0005, 0.0000, 0.0001, 0.0000], dtype=torch.float64)
tensor([-9.7502e-08, 0.0000e+00, 1.9500e-08, 0.0000e+00],
dtype=torch.float64)
[8]:
# Mapping over a set of initial conditions
# The call method can be mapped over a set of initial conditions with torch.vmap
# Note, device can be set to cpu or gpu via base element class variables

dp = 0.0
length = 1.5

# Alignment errors (not referenced further in this cell)
dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

D = Drift('D', length, dp, exact=True)

# 512 random initial conditions, one per row
state = 1.0E-3*torch.randn((512, 4), dtype=D.dtype, device=D.device)
print(torch.vmap(D)(state).shape)

# To map over deviation parameters, a wrapper function (or a lambda expression) can be used
def wrapper(state, dp):
    """Apply D to state with the 'dp' deviation variable taken from the second argument."""
    return D(state, data={**D.data(), 'dp': dp})

# One 'dp' deviation value per initial condition
dp = 1.0E-3*torch.randn(512, dtype=D.dtype, device=D.device)
print(torch.vmap(wrapper)(state, dp).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[9]:
# Differentiability
# Both call methods are differentiable
# The derivative with respect to the state can be computed directly
# For deviation variables, wrapping is required

dp = 0.0
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (not referenced further in this cell)
dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

D = Drift('D', length, dp, exact=False)

# Compute derivative with respect to state
print(torch.func.jacrev(D)(state), end='\n\n')

# Compute derivative with respect to a deviation variable
dl = torch.tensor(0.0, dtype=torch.float64)

def wrapper(state, dl):
    """Apply D to state with the 'dl' deviation variable taken from the second argument."""
    return D(state, data={**D.data(), 'dl': dl})

print(torch.func.jacrev(wrapper, argnums=1)(state, dl), end='\n\n')

# Compositional derivative (derivative of the jacobian trace with respect to the momentum deviation)
dp = torch.tensor(0.0, dtype=torch.float64)

def trace(state, dp):
    """Return the trace of the jacobian of D with respect to state for the given 'dp' deviation."""
    def propagate(state):
        return D(state, data={**D.data(), 'dp': dp})
    return torch.func.jacrev(propagate)(state).trace()

torch.func.jacrev(trace, argnums=1)(state, dp)
tensor([[1.0000, 1.5000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000, 1.5000],
[0.0000, 0.0000, 0.0000, 1.0000]], dtype=torch.float64)
tensor([-0.0050, 0.0000, 0.0010, 0.0000], dtype=torch.float64)
[9]:
tensor(-0., dtype=torch.float64)
[10]:
# Output at each step
# It is possible to collect the state or the tangent matrix at each integration step
# The number of integration steps is controlled by the ns parameter on initialization
# Alternatively, the desired integration step length can be passed
# The number of integration steps is then computed as ceil(length/ds)

dp = 0.0
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (not referenced further in this cell)
dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

D = Drift('D', length, dp, exact=False, ns=10, output=True, matrix=True)

# The final state is still returned
print(D(state))

# Per-step data is stored in dedicated attributes (state and tangent matrix)
print(D.container_output.shape, D.container_matrix.shape, sep='\n')

# The number of integration steps can be changed after initialization
D.ns = 100
D(state)
print(D.container_output.shape, D.container_matrix.shape, sep='\n')
tensor([ 0.0025, -0.0050, -0.0035, 0.0010], dtype=torch.float64)
torch.Size([10, 4])
torch.Size([10, 4, 4])
torch.Size([100, 4])
torch.Size([100, 4, 4])
[11]:
# Integration order is set on initialization (default value is zero)
# This order n is related to the difference order as 2n + 2
# Thus, zero corresponds to the second order difference method

dp = 0.0
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (not referenced further in this cell)
dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

D = Drift('D', length, dp, order=1, exact=True)

# For a drift, integration is performed only with the exact flag
# In this case, the kinematic term error is added
# This term actually commutes with the paraxial drift map
# But integration is still performed for consistency with the matrix-kick-matrix split
# Only one integration step is required to get the exact result

# Reference with a single integration step
D.ns = 1
ref = D(state)

# Same element with ten integration steps
D.ns = 10
res = D(state)

print(ref.tolist(), res.tolist(), (ref - res).tolist(), sep='\n', end='\n\n')

# Integrator parameters are stored in the data attribute (if integration is actually performed)
maps, weights = D._data
print(maps, weights, sep='\n')
[0.002499902498098708, -0.005, -0.003499980499619743, 0.001]
[0.002499902498098709, -0.005, -0.0034999804996197477, 0.001]
[-8.673617379884035e-19, 0.0, 4.7704895589362195e-18, 0.0]
[0, 1, 0, 1, 0, 1, 0]
[0.6756035959798289, 1.3512071919596578, -0.17560359597982877, -1.7024143839193153, -0.17560359597982877, 1.3512071919596578, 0.6756035959798289]