Example-11: Corrector (element)

[1]:
# Comparison of corrector element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system

import torch
from model.library.corrector import Corrector
[3]:
# Tracking

# Track one particle through a MAD-X/PTC line made of a thin hkicker and a thin vkicker,
# then compare the final (qx, px, qy, py) with the Corrector element using the same
# kicks and momentum deviation.

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = False

cx = +0.001
cy = -0.005

dp = 0.005

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors; multiplying by the boolean flag zeroes them when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
cx:hkicker,l=0.0,kick={cx};
cy:vkicker,l=0.0,kick={cy};
map:line=(cx, cy) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

with ptc.open('w') as stream:
    stream.write(code)

# Remove any stale observation file so a failed MAD-X run cannot silently reuse old results
obs.unlink(missing_ok=True)

try:
    status = system(f'madx < {ptc} > /dev/null')
    if status != 0:
        raise RuntimeError(f'MAD-X exited with status {status}')

    # The last non-empty line holds the final coordinates; its 3rd-6th columns are x, px, y, py
    with obs.open('r') as stream:
        rows = [line.split() for line in stream if line.strip()]
    if not rows or len(rows[-1]) < 6:
        raise RuntimeError(f'No tracking data found in {obs}')

    ref = torch.tensor([float(value) for value in rows[-1][2:6]], dtype=torch.float64)
finally:
    # Clean up generated files even if MAD-X or parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

C = Corrector('C', cx=cx, cy=cy, dp=dp)
res = C(state, alignment=align, data={**C.data(), **error})

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0, 0.001, 0.0, -0.005]
[0.0, 0.001, 0.0, -0.005]
[0.0, 0.0, 0.0, 0.0]
[4]:
# Tracking (alignment)

# Only dx and dy alignment errors seem to work as expected in MAD-X
# Also, the wz roll is modeled on the MAD-X side through the element tilt
#
# NOTE(review): the regex in `select,flag=error,pattern="mag"` does not appear to match the
# element names `cx`/`cy`, so `ealign` may not be applied to any element at all; for a
# zero-length kicker dx/dy presumably leave the final state unchanged anyway — confirm before
# drawing conclusions about MAD-X alignment support.

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = True

cx = +0.001
cy = -0.005
dp = 0.005
state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors; multiplying by the boolean flag zeroes them when align is False
dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.0, dtype=torch.float64)

wx = align*torch.tensor(0.0, dtype=torch.float64)
wy = align*torch.tensor(0.0, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
cx:hkicker,l=0.0,kick={cx},tilt={wz.item()};
cy:vkicker,l=0.0,kick={cy},tilt={wz.item()};
map:line=(cx, cy) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

with ptc.open('w') as stream:
    stream.write(code)

# Remove any stale observation file so a failed MAD-X run cannot silently reuse old results
obs.unlink(missing_ok=True)

try:
    status = system(f'madx < {ptc} > /dev/null')
    if status != 0:
        raise RuntimeError(f'MAD-X exited with status {status}')

    # The last non-empty line holds the final coordinates; its 3rd-6th columns are x, px, y, py
    with obs.open('r') as stream:
        rows = [line.split() for line in stream if line.strip()]
    if not rows or len(rows[-1]) < 6:
        raise RuntimeError(f'No tracking data found in {obs}')

    ref = torch.tensor([float(value) for value in rows[-1][2:6]], dtype=torch.float64)
finally:
    # Clean up generated files even if MAD-X or parsing fails
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

C = Corrector('C', cx=cx, cy=cy, dp=dp)
res = C(state, alignment=align, data={**C.data(), **error})

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.0, 0.0014941712485121669, 0.0, -0.004875187409743301]
[-6.938893903907228e-18, 0.0014941712485121667, 0.0, -0.004875187409743301]
[6.938893903907228e-18, 2.168404344971009e-19, 0.0, 0.0]
[5]:
# Deviation/error variables

cx = +0.001
cy = -0.005
dp = 0.005
state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

dx = torch.tensor(+0.01, dtype=torch.float64)
dy = torch.tensor(-0.01, dtype=torch.float64)
dz = torch.tensor(0.0, dtype=torch.float64)

wx = torch.tensor(0.0, dtype=torch.float64)
wy = torch.tensor(0.0, dtype=torch.float64)
wz = torch.tensor(torch.pi, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

C = Corrector('C', cx, cy, dp)

# Each element has two variants of the call method
# In the first case only the state is passed; it is transformed using the parameters specified on initialization

print(C(state))
print()

# Deviation variables can also be passed to the call method
# They are added to the corresponding parameters specified on initialization
# For example, the kick strengths can be changed: adding -cx and -cy cancels both kicks

print(C(state, data={**C.data(), **{'cx': -cx, 'cy': -cy}}))
print()

# In the above, C.data() creates the default deviation dictionary (with zero values for each deviation)
# {**C.data(), **{'cx': -cx, 'cy': -cy}} replaces the 'cx' and 'cy' key values

# Additionally, alignment errors are passed as deviation variables
# They are used only if the alignment flag is raised

print(C(state, data={**C.data(), **error}, alignment=True))
print()


# The following elements can be made equivalent using deviation variables

CA = Corrector('CA', cx, cy, dp)
CB = Corrector('CB', cx - 0.001, cy, dp)

print(CA(state) - CB(state, data={**CB.data(), **{'cx': + 0.001}}))
tensor([ 0.0000,  0.0010,  0.0000, -0.0050], dtype=torch.float64)

tensor([0., 0., 0., 0.], dtype=torch.float64)

tensor([ 0.0000, -0.0010,  0.0000,  0.0050], dtype=torch.float64)

tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Mapping over a set of initial conditions

# The call method can be used to map over a set of initial conditions
# Note, the device can be set to cpu or gpu via base element class variables

# Seed the generator so the random initial conditions and deviations are reproducible
torch.manual_seed(0)

cx = +0.001
cy = -0.005
dp = 0.005

C = Corrector('C', cx, cy, dp)

state = 1.0E-3*torch.randn((512, 4), dtype=C.dtype, device=C.device)

print(torch.vmap(C)(state).shape)

# To map over deviation parameters, a wrapper function (or a lambda expression) can be used

def wrapper_kicks(state, dcx, dcy):
    """Apply C to state with deviations dcx and dcy passed for the 'cx' and 'cy' parameters."""
    return C(state, data={**C.data(), **{'cx': dcx, 'cy': dcy}})

# Batches of kick deviations (kept separate from the scalar kicks cx and cy above)
dcx_batch = 1.0E-3*torch.randn(512, dtype=C.dtype, device=C.device)
dcy_batch = 1.0E-3*torch.randn(512, dtype=C.dtype, device=C.device)

print(torch.vmap(wrapper_kicks)(state, dcx_batch, dcy_batch).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[7]:
# Differentiability

# Both call methods are differentiable:
#  - the derivative with respect to the state can be computed directly
#  - for deviation variables, a wrapper exposing them as an explicit argument is required

cx, cy, dp = +0.001, -0.005, 0.005
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (offsets, then rotations)
dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (+0.01, -0.01, 0.0))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.0, 0.0, torch.pi))
error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

C = Corrector('C', cx, cy, dp)

# Jacobian of the output state with respect to the input state

print(torch.func.jacrev(C)(state))
print()

# Jacobian of the output state with respect to the (cx, cy) deviation variables

dcx = torch.tensor(0.0, dtype=torch.float64)
dcy = torch.tensor(0.0, dtype=torch.float64)
dc = torch.stack([dcx, dcy])

def wrapper(state, dc):
    """Apply C to state, taking the 'cx' and 'cy' deviations from the 2-element tensor dc."""
    kick_x, kick_y = dc
    return C(state, data={**C.data(), 'cx': kick_x, 'cy': kick_y})

print(torch.func.jacrev(wrapper, argnums=1)(state, dc))
print()
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], dtype=torch.float64)

tensor([[0., 0.],
        [1., 0.],
        [0., 0.],
        [0., 1.]], dtype=torch.float64)