Example-11: Corrector (element)
[1]:
# Comparison of corrector element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system
import torch
from model.library.corrector import Corrector
[3]:
# Tracking
# A single particle is tracked with MADX-PTC through a line made of a horizontal and a vertical kicker
# The last line of the PTC track file is compared with the result of the Corrector element

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = False

cx = +0.001
cy = -0.005
dp = 0.005

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Each alignment error is multiplied by the align flag, i.e. all errors are zero when align is False

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
cx:hkicker,l=0.0,kick={cx};
cy:vkicker,l=0.0,kick={cy};
map:line=(cx, cy) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

# Temporary files are removed in the finally clause, so a failed run leaves no stale
# script or track file behind (a stale track file could otherwise be read by a later run)

try:
    with ptc.open('w') as stream:
        stream.write(code)
    system(f'madx < {str(ptc)} > /dev/null')
    with obs.open('r') as stream:
        lines = stream.read().splitlines()
    if not lines:
        raise RuntimeError(f'{obs} is empty, no tracking data to compare with')
    # Only the last line of the track file is used, columns 2-5 are read as (qx, px, qy, py)
    # Note, qx, px, qy, py are rebound to strings here
    _, _, qx, px, qy, py, *_ = lines[-1].split()
finally:
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)

C = Corrector('C', cx=cx, cy=cy, dp=dp)
res = C(state, alignment=align, data={**C.data(), **error})

# Reference (MADX-PTC), result (Corrector) and their difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0, 0.001, 0.0, -0.005]
[0.0, 0.001, 0.0, -0.005]
[0.0, 0.0, 0.0, 0.0]
[4]:
# Tracking (alignment)
# Only dx and dy alignment errors seem to work as expected in MADX
# Also, wz rotation matches tilt
# NOTE(review): the select pattern "mag" below does not appear to match the element names (cx, cy),
# confirm that the ealign errors are actually applied to the kickers

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = True

cx = +0.001
cy = -0.005
dp = 0.005

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Each alignment error is multiplied by the align flag
# Here dz, wx and wy are set to zero, wz is passed to MADX as the tilt of both kickers

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.0, dtype=torch.float64)

wx = align*torch.tensor(0.0, dtype=torch.float64)
wy = align*torch.tensor(0.0, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
cx:hkicker,l=0.0,kick={cx},tilt={wz.item()};
cy:vkicker,l=0.0,kick={cy},tilt={wz.item()};
map:line=(cx, cy) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

# Temporary files are removed in the finally clause, so a failed run leaves no stale
# script or track file behind (a stale track file could otherwise be read by a later run)

try:
    with ptc.open('w') as stream:
        stream.write(code)
    system(f'madx < {str(ptc)} > /dev/null')
    with obs.open('r') as stream:
        lines = stream.read().splitlines()
    if not lines:
        raise RuntimeError(f'{obs} is empty, no tracking data to compare with')
    # Only the last line of the track file is used, columns 2-5 are read as (qx, px, qy, py)
    # Note, qx, px, qy, py are rebound to strings here
    _, _, qx, px, qy, py, *_ = lines[-1].split()
finally:
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)

C = Corrector('C', cx=cx, cy=cy, dp=dp)
res = C(state, alignment=align, data={**C.data(), **error})

# Reference (MADX-PTC), result (Corrector) and their difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0, 0.0014941712485121669, 0.0, -0.004875187409743301]
[-6.938893903907228e-18, 0.0014941712485121667, 0.0, -0.004875187409743301]
[6.938893903907228e-18, 2.168404344971009e-19, 0.0, 0.0]
[5]:
# Deviation/error variables

cx = +0.001
cy = -0.005
dp = 0.005

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

# Alignment errors (scalar float64 tensors)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (+0.01, -0.01, 0.0, 0.0, 0.0, torch.pi)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

C = Corrector('C', cx, cy, dp)

# Each element has two variants of the call method
# In the first variant only the state is passed, it is transformed using the parameters specified on initialization

print(C(state), end='\n\n')

# Deviations can also be passed to the call method
# These values are added to the corresponding parameters specified on initialization
# For example, the corrector kicks can be changed (here the deviations are the negated kicks)

print(C(state, data={**C.data(), 'cx': -cx, 'cy': -cy}), end='\n\n')

# In the above, C.data() creates the default deviation dictionary (with zero value for each deviation)
# The 'cx' and 'cy' entries listed after **C.data() replace the corresponding default values

# Additionally, alignment errors are passed as deviation variables
# They are used if the alignment flag is raised

print(C(state, data={**C.data(), **error}, alignment=True), end='\n\n')

# The following elements can be made equivalent using deviation variables

CA = Corrector('CA', cx, cy, dp)
CB = Corrector('CB', cx - 0.001, cy, dp)

print(CA(state) - CB(state, data={**CB.data(), 'cx': + 0.001}))
tensor([ 0.0000, 0.0010, 0.0000, -0.0050], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 0.0000, -0.0010, 0.0000, 0.0050], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Mapping over a set of initial conditions
# The call method can be used to map over a set of initial conditions
# Note, device can be set to cpu or gpu via base element class variables

cx = +0.001
cy = -0.005
dp = 0.005

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (+0.01, -0.01, 0.0, 0.0, 0.0, torch.pi)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

C = Corrector('C', cx, cy, dp)

# A batch of 512 random initial conditions, mapped over the leading dimension

state = 1.0E-3*torch.randn((512, 4), dtype=C.dtype, device=C.device)
mapped = torch.vmap(C)(state)
print(mapped.shape)

# To map over deviation parameters a wrapper function (or a lambda expression) can be used

def wrapper(state, cx, cy):
    """Apply C to state with cx and cy passed as deviations of the corresponding parameters."""
    return C(state, data={**C.data(), 'cx': cx, 'cy': cy})

# One pair of kick deviations for each initial condition

cx = 1.0E-3*torch.randn(512, dtype=C.dtype, device=C.device)
cy = 1.0E-3*torch.randn(512, dtype=C.dtype, device=C.device)
mapped = torch.vmap(wrapper)(state, cx, cy)
print(mapped.shape)
torch.Size([512, 4])
torch.Size([512, 4])
[7]:
# Differentiability
# Both call methods are differentiable
# Derivative with respect to state can be computed directly
# For deviation variables, wrapping is required

cx = +0.001
cy = -0.005
dp = 0.005

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (+0.01, -0.01, 0.0, 0.0, 0.0, torch.pi)
)
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

C = Corrector('C', cx, cy, dp)

# Compute derivative with respect to state

print(torch.func.jacrev(C)(state), end='\n\n')

# Compute derivative with respect to deviation variables
# Both kick deviations are stacked into a single tensor, which is the differentiated argument

dcx = torch.tensor(0.0, dtype=torch.float64)
dcy = torch.tensor(0.0, dtype=torch.float64)
dc = torch.stack([dcx, dcy])

def wrapper(state, dc):
    """Apply C to state with the two entries of dc passed as the 'cx' and 'cy' deviations."""
    kick_x, kick_y = dc
    return C(state, data={**C.data(), 'cx': kick_x, 'cy': kick_y})

print(torch.func.jacrev(wrapper, argnums=1)(state, dc), end='\n\n')
tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]], dtype=torch.float64)
tensor([[0., 0.],
[1., 0.],
[0., 0.],
[0., 1.]], dtype=torch.float64)