Example-13: Kick (element)
[1]:
# Comparison of kick element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system
import torch
from model.library.kick import Kick
[3]:
# Tracking
# Compare the Kick element with a thin MAD-X multipole tracked by PTC (alignment errors switched off)

# Scratch files: the MAD-X input script and the observation file written by ptc_track
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = False

# ms and mo are placed in the third and fourth knl entries of the MAD-X multipole
# dp is passed to ptc_start as pt
ms = +10.0
mo = -50.0
dp = 0.005

# Initial condition (qx, px, qy, py)
state = torch.tensor([0.001, 0.0, -0.005, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors, multiplying by the align flag sets all of them to zero when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X/PTC script tracking one particle for one turn through the multipole
code = f"""
mag: multipole,knl={{0.0,0.0,{ms},{mo}}};
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    with ptc.open('w') as stream:
        stream.write(code)
    system(f'madx < {str(ptc)} > /dev/null')
    # Keep only non-blank lines, the last one holds the coordinates after tracking
    with obs.open('r') as stream:
        records = [line for line in stream if line.strip()]
finally:
    # Remove the scratch files even if MAD-X fails or its output cannot be read
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

if not records:
    raise RuntimeError(f'{obs} contains no tracking data')

# The first two columns are skipped, the next four are read as (qx, px, qy, py)
# Separate names are used so that the initial qx, px, qy, py are not overwritten with strings
_, _, qx_ref, px_ref, qy_ref, py_ref, *_ = records[-1].split()
ref = torch.tensor([float(x) for x in (qx_ref, px_ref, qy_ref, py_ref)], dtype=torch.float64)

# Same particle through the Kick element, alignment errors are passed as deviation variables
K = Kick('K', ms=ms, mo=mo, dp=dp)
res = K(state, alignment=align, data={**K.data(), **error})

# Reference (PTC), result (Kick) and their difference
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.001, 0.00011938333333333334, -0.005, -5.091666666666667e-05]
[0.001, 0.00011938333333333334, -0.005, -5.0916666666666666e-05]
[0.0, 0.0, 0.0, -6.776263578034403e-21]
[4]:
# Tracking (alignment)
# Compare the Kick element with a thin MAD-X multipole tracked by PTC (alignment errors switched on)

# Scratch files: the MAD-X input script and the observation file written by ptc_track
ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = True

# ms and mo are placed in the third and fourth knl entries of the MAD-X multipole
# dp is passed to ptc_start as pt
ms = +10.0
mo = -50.0
dp = 0.005

# Initial condition (qx, px, qy, py)
state = torch.tensor([0.001, 0.0, -0.005, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors, multiplying by the align flag keeps them as given when align is True
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)
wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)
error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X/PTC script tracking one particle for one turn through the misaligned multipole
code = f"""
mag: multipole,knl={{0.0,0.0,{ms},{mo}}};
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    with ptc.open('w') as stream:
        stream.write(code)
    system(f'madx < {str(ptc)} > /dev/null')
    # Keep only non-blank lines, the last one holds the coordinates after tracking
    with obs.open('r') as stream:
        records = [line for line in stream if line.strip()]
finally:
    # Remove the scratch files even if MAD-X fails or its output cannot be read
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

if not records:
    raise RuntimeError(f'{obs} contains no tracking data')

# The first two columns are skipped, the next four are read as (qx, px, qy, py)
# Separate names are used so that the initial qx, px, qy, py are not overwritten with strings
_, _, qx_ref, px_ref, qy_ref, py_ref, *_ = records[-1].split()
ref = torch.tensor([float(x) for x in (qx_ref, px_ref, qy_ref, py_ref)], dtype=torch.float64)

# Same particle through the Kick element, alignment errors are passed as deviation variables
K = Kick('K', ms=ms, mo=mo, dp=dp)
res = K(state, alignment=align, data={**K.data(), **error})

# Reference (PTC), result (Kick) and their difference
print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0014217329730612119, -0.00853058669498686, -0.004440545938333818, -0.011316334457474285]
[0.0014217329730612169, -0.008530586694986856, -0.004440545938333821, -0.011316334457474281]
[-4.9873299934333204e-18, -3.469446951953614e-18, 2.6020852139652106e-18, -3.469446951953614e-18]
[5]:
# Deviation/error variables

ms, mo, dp = +10.0, -50.0, 0.005
state = torch.tensor([0.001, 0.0, -0.005, 0.0], dtype=torch.float64)

# Alignment errors keyed by deviation name
error = {
    'dx': torch.tensor(+0.01, dtype=torch.float64),
    'dy': torch.tensor(-0.01, dtype=torch.float64),
    'dz': torch.tensor(0.0, dtype=torch.float64),
    'wx': torch.tensor(0.0, dtype=torch.float64),
    'wy': torch.tensor(0.0, dtype=torch.float64),
    'wz': torch.tensor(torch.pi, dtype=torch.float64),
}
dx, dy, dz, wx, wy, wz = error.values()

K = Kick('K', ms, mo, dp)

# Each element has two variants of the call method
# In the first variant only the state is passed
# It is transformed using the parameters specified on initialization
print(K(state))
print()

# Deviation variables can also be passed to the call method
# They are added to the corresponding parameters specified on initialization
# K.data() creates the default deviation dictionary (with zero values for each deviation)
# Here the 'ms' and 'mo' entries are replaced, so both strengths are shifted by -ms and -mo
cancelled = {**K.data(), 'ms': -ms, 'mo': -mo}
print(K(state, data=cancelled))
print()

# Additionally, alignment errors are passed as deviation variables
# They are used only if the alignment flag is raised
misaligned = {**K.data(), **error}
print(K(state, alignment=True, data=misaligned))
print()

# The following elements can be made equivalent using deviation variables
KA = Kick('KA', ms, mo, dp)
KB = Kick('KB', ms - 0.1, mo, dp)
difference = KA(state) - KB(state, data={**KB.data(), 'ms': +0.1})
print(difference)
tensor([ 1.0000e-03, 1.1938e-04, -5.0000e-03, -5.0917e-05],
dtype=torch.float64)
tensor([ 0.0010, 0.0000, -0.0050, 0.0000], dtype=torch.float64)
tensor([ 0.0010, 0.0003, -0.0050, 0.0004], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Mapping over a set of initial conditions
# The call method can be mapped over a batch of initial conditions with torch.vmap
# Note, the device can be set to cpu or gpu via the base element class variables

ms, mo, dp = +10.0, -50.0, 0.005

# Alignment errors keyed by deviation name
error = {
    'dx': torch.tensor(+0.01, dtype=torch.float64),
    'dy': torch.tensor(-0.01, dtype=torch.float64),
    'dz': torch.tensor(0.0, dtype=torch.float64),
    'wx': torch.tensor(0.0, dtype=torch.float64),
    'wy': torch.tensor(0.0, dtype=torch.float64),
    'wz': torch.tensor(torch.pi, dtype=torch.float64),
}
dx, dy, dz, wx, wy, wz = error.values()

K = Kick('K', ms, mo, dp)

# Batch of 512 random initial conditions, one per row
state = 1.0E-3*torch.randn((512, 4), dtype=K.dtype, device=K.device)
batched = torch.vmap(K)
print(batched(state).shape)


# To map over deviation parameters a wrapper function (or a lambda expression) can be used
def wrapper(state, ms, mo):
    """Apply K to state with the given 'ms' and 'mo' deviation values."""
    deviation = {**K.data(), 'ms': ms, 'mo': mo}
    return K(state, data=deviation)


# One random deviation per initial condition (kn is mapped to 'ms', ks to 'mo')
kn = 1.0E-3*torch.randn(512, dtype=K.dtype, device=K.device)
ks = 1.0E-3*torch.randn(512, dtype=K.dtype, device=K.device)
batched = torch.vmap(wrapper)
print(batched(state, kn, ks).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[7]:
# Differentiability
# Both variants of the call method are differentiable
# The derivative with respect to the state can be computed directly
# For deviation variables, a wrapper function is required

ms, mo, dp = +10.0, -50.0, 0.005
state = torch.tensor([0.001, 0.0, -0.005, 0.0], dtype=torch.float64)

# Alignment errors keyed by deviation name
error = {
    'dx': torch.tensor(+0.01, dtype=torch.float64),
    'dy': torch.tensor(-0.01, dtype=torch.float64),
    'dz': torch.tensor(0.0, dtype=torch.float64),
    'wx': torch.tensor(0.0, dtype=torch.float64),
    'wy': torch.tensor(0.0, dtype=torch.float64),
    'wz': torch.tensor(torch.pi, dtype=torch.float64),
}
dx, dy, dz, wx, wy, wz = error.values()

K = Kick('K', ms, mo, dp)

# Derivative of the output state with respect to the input state
state_jacobian = torch.func.jacrev(K)
print(state_jacobian(state))
print()

# Derivative with respect to deviation variables, evaluated at zero deviations
dms = torch.tensor(0.0, dtype=torch.float64)
dmo = torch.tensor(0.0, dtype=torch.float64)
dm = torch.stack([dms, dmo])


def wrapper(state, dm):
    """Apply K to state with the 'ms' and 'mo' deviations taken from the two entries of dm."""
    delta_ms, delta_mo = dm
    return K(state, data={**K.data(), 'ms': delta_ms, 'mo': delta_mo})


# argnums=1 selects dm as the argument to differentiate with respect to
deviation_jacobian = torch.func.jacrev(wrapper, argnums=1)
print(deviation_jacobian(state, dm))
print()
tensor([[ 1.0000, 0.0000, 0.0000, 0.0000],
[-0.0106, 1.0000, -0.0498, 0.0000],
[ 0.0000, 0.0000, 1.0000, 0.0000],
[-0.0498, 0.0000, 0.0106, 1.0000]], dtype=torch.float64)
tensor([[ 0.0000e+00, 0.0000e+00],
[ 1.2000e-05, 1.2333e-08],
[ 0.0000e+00, 0.0000e+00],
[-5.0000e-06, 1.8333e-08]], dtype=torch.float64)