Example-07: Sextupole (element)
[1]:
# Comparison of sextupole element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system
import torch
from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.sextupole import Sextupole
[3]:
# Tracking (paraxial)

# Track one particle through a sextupole with MADX-PTC (reference) and with the Sextupole element
# The madx executable is invoked through the shell and has to be available on PATH

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = False
align = False

ms = 10.0
dp = 0.005
length = 0.25

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag (all values are zero if the flag is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
mag: sextupole, l={length},k2={ms} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

# Temporary files are removed in the finally clause, so a failed run leaves nothing behind
# A stale observation file is removed first, so an old result can not be mistaken for a new one
# line is reset before the scan, so a value left in the kernel by another cell is never reused

line = None

try:
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'madx (exit status {status}) did not create {obs}')
    # Only the last line of the observation file is used
    with obs.open('r') as stream:
        for line in stream:
            continue
finally:
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

if line is None:
    raise RuntimeError(f'{obs} is empty')

# The first two columns are skipped, the next four are used as the reference (qx, px, qy, py)

_, _, qx, px, qy, py, *_ = line.split()
ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)

S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)
res = S(state, alignment=align, data={**S.data(), **error})

# Reference, result and difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.008745689261382875, -0.005080234821325765, -0.00476590910682766, 0.0008855562050471031]
[0.008745689261382871, -0.00508023482132578, -0.00476590910682768, 0.0008855562050471008]
[3.469446951953614e-18, 1.5612511283791264e-17, 1.9949319973733282e-17, 2.2768245622195593e-18]
[4]:
# Tracking (exact)

# Track one particle through a sextupole with MADX-PTC (reference) and with the Sextupole element
# The exact flag is raised both for PTC and for the element
# The madx executable is invoked through the shell and has to be available on PATH

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = False

ms = 10.0
dp = 0.005
length = 0.25

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag (all values are zero if the flag is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
mag: sextupole, l={length},k2={ms} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

# Temporary files are removed in the finally clause, so a failed run leaves nothing behind
# A stale observation file is removed first, so an old result can not be mistaken for a new one
# line is reset before the scan, so a value left in the kernel by another cell is never reused

line = None

try:
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'madx (exit status {status}) did not create {obs}')
    # Only the last line of the observation file is used
    with obs.open('r') as stream:
        for line in stream:
            continue
finally:
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

if line is None:
    raise RuntimeError(f'{obs} is empty')

# The first two columns are skipped, the next four are used as the reference (qx, px, qy, py)

_, _, qx, px, qy, py, *_ = line.split()
ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)

S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)
res = S(state, alignment=align, data={**S.data(), **error})

# Reference, result and difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.008745672936611987, -0.005080234654232574, -0.004765906047078759, 0.000885556338827788]
[0.00874567293661202, -0.0050802346542325495, -0.0047659060470788064, 0.0008855563388277869]
[-3.2959746043559335e-17, -2.42861286636753e-17, 4.7704895589362195e-17, 1.0842021724855044e-18]
[5]:
# Tracking (exact, alignment)

# Track one particle through a misaligned sextupole with MADX-PTC (reference) and with the Sextupole element
# The exact and alignment flags are raised, alignment errors are passed to PTC with ealign
# The madx executable is invoked through the shell and has to be available on PATH

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = True

ms = 10.0
dp = 0.005
length = 0.25

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors are multiplied by the align flag (all values are zero if the flag is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

code = f"""
mag: sextupole, l={length},k2={ms} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

# Temporary files are removed in the finally clause, so a failed run leaves nothing behind
# A stale observation file is removed first, so an old result can not be mistaken for a new one
# line is reset before the scan, so a value left in the kernel by another cell is never reused

line = None

try:
    obs.unlink(missing_ok=True)
    with ptc.open('w') as stream:
        stream.write(code)
    status = system(f'madx < {str(ptc)} > /dev/null')
    if not obs.exists():
        raise RuntimeError(f'madx (exit status {status}) did not create {obs}')
    # Only the last line of the observation file is used
    with obs.open('r') as stream:
        for line in stream:
            continue
finally:
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

if line is None:
    raise RuntimeError(f'{obs} is empty')

# The first two columns are skipped, the next four are used as the reference (qx, px, qy, py)

_, _, qx, px, qy, py, *_ = line.split()
ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)

S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)
res = S(state, alignment=align, data={**S.data(), **error})

# Reference, result and difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.008663885569968804, -0.006258052120536049, -0.004896687680053297, -0.000915022709372755]
[0.008663885569968922, -0.006258052120536033, -0.004896687680053253, -0.000915022709372733]
[-1.1796119636642288e-16, -1.6479873021779667e-17, -4.423544863740858e-17, -2.200930410145574e-17]
[6]:
# Deviation/error variables

ms, dp, length = 10.0, 0.005, 0.25

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (shifts dx, dy, dz and rotations wx, wy, wz)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)

error = dict(zip(('dx', 'dy', 'dz', 'wx', 'wy', 'wz'), (dx, dy, dz, wx, wy, wz)))

S = Sextupole('S', length, ms, dp)

# Each element has two variants of the call method
# In the first one only the state is passed, it is transformed using parameters specified on initialization

print(S(state))
print()

# In the second one deviation variables are also passed to the call method
# These variables are added to the corresponding parameters specified on initialization
# For example, the element length can be changed
# S.data() creates the default deviation dictionary (with zero value for each deviation)
# {**S.data(), 'dl': -S.length} replaces the value of the 'dl' key

print(S(state, data={**S.data(), 'dl': -S.length}))
print()

# Additionally, alignment errors are passed as deviation variables
# They are used only if the alignment flag is raised

print(S(state, data={**S.data(), **error}, alignment=True))
print()

# The following elements can be made equivalent using deviation variables

SA = Sextupole('SA', length, ms, dp)
SB = Sextupole('SB', length - 0.1, ms, dp)

dl = torch.tensor(+0.1, dtype=SB.dtype)
print(SA(state) - SB(state, data={**SB.data(), 'dl': dl}))

# Note, while in some cases float values can be passed as values of deviation variables
# The correct behaviour is guaranteed only for tensors
tensor([ 0.0087, -0.0051, -0.0048, 0.0009], dtype=torch.float64)
tensor([ 0.0100, -0.0050, -0.0050, 0.0010], dtype=torch.float64)
tensor([ 0.0087, -0.0062, -0.0049, -0.0009], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Insertion element

# In this mode elements are treated as thin insertions (at the center)
# Using parameters specified on initialization, two transport matrices are computed
# These matrices are used to insert the element
# Input state is transformed from the element center to its entrance
# Next, transformation from the entrance frame to the exit frame is performed
# This transformation can contain errors
# The final step is to transform state from the exit frame back to the element center
# Without errors, this results in identity transformation for linear elements

ms = 10.0
dp = 0.005
length = 1.5

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined, but not used below)

dx = torch.tensor(0.05, dtype=torch.float64)
dy = torch.tensor(-0.02, dtype=torch.float64)
dz = torch.tensor(0.05, dtype=torch.float64)

wx = torch.tensor(0.005, dtype=torch.float64)
wy = torch.tensor(-0.005, dtype=torch.float64)
wz = torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

S = Sextupole('S', length, ms, dp, exact=False, insertion=True)

# Since sextupole is a nonlinear element, insertion is an identity transformation only for zero strength
# Deviation values are passed as tensors, since the correct behaviour is guaranteed only for tensors

print(S(state) - state)
print(S(state, data={**S.data(), **{'ms': torch.tensor(-ms, dtype=S.dtype)}}) - state)

# Represents effect of an error (any nonzero value of strength or a change in other parameter)

print(S(state, data={**S.data(), **{'dl': torch.tensor(0.1, dtype=S.dtype)}}) - state)

# Exact tracking corresponds to inclusion of kinematic term as errors

S = Sextupole('S', length, ms, dp, exact=True, insertion=True, ns=100, order=1)
print(S(state) - state)
tensor([ 0.0000, -0.0006, 0.0000, -0.0008], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([-5.2560e-04, -5.6465e-04, 6.1078e-05, -7.7234e-04],
dtype=torch.float64)
tensor([-1.3875e-04, -5.5306e-04, -9.3764e-05, -7.7778e-04],
dtype=torch.float64)
[8]:
# Mapping over a set of initial conditions

# The call method can be mapped over a set of initial conditions with torch.vmap
# Note, device can be set to cpu or gpu via base element class variables

ms, dp, length = 10.0, 0.0, 1.5

# Alignment errors (defined, but not used below)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)

error = dict(zip(('dx', 'dy', 'dz', 'wx', 'wy', 'wz'), (dx, dy, dz, wx, wy, wz)))

S = Sextupole('S', length, ms, dp, exact=True)

# Batch of random initial conditions

state = 1.0E-3*torch.randn((512, 4), dtype=S.dtype, device=S.device)

batched = torch.vmap(S)
print(batched(state).shape)

# To map over deviation parameters a wrapper function (or a lambda expression) can be used

def wrapper(state, dp):
    """Transform state with momentum deviation passed as a deviation variable."""
    return S(state, data={**S.data(), 'dp': dp})

# Batch of random momentum deviations (one for each initial condition)

dp = 1.0E-3*torch.randn(512, dtype=S.dtype, device=S.device)

batched = torch.vmap(wrapper)
print(batched(state, dp).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[9]:
# Differentiability

# Both call methods are differentiable
# Derivative with respect to state can be computed directly
# For deviation variables, wrapping is required

ms, dp, length = 10.0, 0.0, 1.5

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined, but not used below)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)

error = dict(zip(('dx', 'dy', 'dz', 'wx', 'wy', 'wz'), (dx, dy, dz, wx, wy, wz)))

S = Sextupole('S', length, ms, dp, exact=False)

# Derivative with respect to state

jacobian = torch.func.jacrev(S)
print(jacobian(state))
print()

# Derivative with respect to a deviation variable (strength deviation, evaluated at zero)

ms = torch.tensor(0.0, dtype=torch.float64)

def wrapper(state, ms):
    """Transform state with strength deviation passed as a deviation variable."""
    return S(state, data={**S.data(), 'ms': ms})

jacobian = torch.func.jacrev(wrapper, argnums=1)
print(jacobian(state, ms))
print()
tensor([[ 0.9297, 1.4473, -0.0478, -0.0359],
[-0.0938, 0.9297, -0.0638, -0.0478],
[-0.0478, -0.0359, 1.0703, 1.5527],
[-0.0638, -0.0478, 0.0938, 1.0703]], dtype=torch.float64)
tensor([-1.1813e-05, -1.5750e-05, -2.9883e-05, -3.9844e-05],
dtype=torch.float64)
[10]:
# Output at each step

# It is possible to collect output of state or tangent matrix at each integration step
# Number of integration steps is controlled by ns parameter on initialization
# Alternatively, desired integration step length can be passed
# Number of integration steps is computed as ceil(length/ds)

ms, dp, length = 10.0, 0.0, 1.5

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined, but not used below)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)

error = dict(zip(('dx', 'dy', 'dz', 'wx', 'wy', 'wz'), (dx, dy, dz, wx, wy, wz)))

S = Sextupole('S', length, ms, dp, exact=False, ns=10, output=True, matrix=True)

# Final state is still returned

print(S(state))

# Data is added to special attributes (state and tangent matrix)

for container in (S.container_output, S.container_matrix):
    print(container.shape)

# Number of integration steps can be changed

S.ns = 100
S(state)

for container in (S.container_output, S.container_matrix):
    print(container.shape)
tensor([ 0.0023, -0.0052, -0.0039, 0.0006], dtype=torch.float64)
torch.Size([10, 4])
torch.Size([10, 4, 4])
torch.Size([100, 4])
torch.Size([100, 4, 4])
[11]:
# Integration order is set on initialization (default value is zero)
# This order is related to difference order as 2n + 2
# Thus, zero corresponds to second order difference method

ms, dp, length = 10.0, 0.0, 1.5

state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined, but not used below)

dx, dy, dz, wx, wy, wz = (
    torch.tensor(value, dtype=torch.float64)
    for value in (0.05, -0.02, 0.05, 0.005, -0.005, 0.1)
)

error = dict(zip(('dx', 'dy', 'dz', 'wx', 'wy', 'wz'), (dx, dy, dz, wx, wy, wz)))

S = Sextupole('S', length, ms, dp, order=0, exact=True)

# For sextupole integration is always performed
# In exact case, kinematic term error is added
# The same state is propagated with 10 and with 100 integration steps

S.ns = 10
ref = S(state)

S.ns = 100
res = S(state)

for values in (ref, res, ref - res):
    print(values.tolist())
print()

# Integrator parameters are stored in data attribute (if integration is actually performed)

maps, weights = S._data
for values in (maps, weights):
    print(values)
[0.0022880403040043407, -0.005176627884407675, -0.0038882248469424585, 0.0005834231060909752]
[0.0022871610802911667, -0.005176794799148338, -0.0038891562083014203, 0.0005832353574511327]
[8.792237131739246e-07, 1.6691474066278522e-07, 9.313613589618727e-07, 1.877486398425181e-07]
[0, 1, 2, 1, 0]
[0.5, 0.5, 1.0, 0.5, 0.5]
[12]:
# Derivatives of twiss parameters (chromaticity)
# pip install git+https://github.com/i-a-morozov/twiss.git@main
# pip install git+https://github.com/i-a-morozov/ndmap.git@main
from twiss import twiss
from ndmap.pfp import parametric_fixed_point
from ndmap.evaluate import evaluate
# Define elements (quadrupoles, sextupoles and drifts of the cell)
# Note, both drifts are created with the same 'DR' name

QF, QD = Quadrupole('QF', 0.5, +0.21), Quadrupole('QD', 0.5, -0.19)
SF, SD = (Sextupole(name, 0.25) for name in ('SF', 'SD'))
DA, DB = (Drift('DR', size) for size in (0.25, 4.00))
# Define one-turn transformation

def fodo(state, dp, ms):
    """
    Propagate state through the cell.

    The first component of dp is passed to every element as the 'dp' deviation
    The first two components of ms are passed as the 'ms' deviation to SF and SD sextupoles

    """
    dp, *_ = dp
    msf, msd, *_ = ms
    # First half of the cell as (element, extra deviations) pairs
    # The second half is the first one in the reversed order
    half = [
        (QF, {}),
        (DA, {}),
        (SF, {'ms': msf}),
        (DB, {}),
        (SD, {'ms': msd}),
        (DA, {}),
        (QD, {}),
    ]
    for element, extra in [*half, *reversed(half)]:
        state = element(state, data={**element.data(), 'dp': dp, **extra})
    return state
# Set deviation parameters (sextupole strength deviations and momentum deviation, all zero)

msf, msd = (torch.tensor(0.0, dtype=torch.float64) for _ in range(2))
ms = torch.stack((msf, msd))

dp = torch.zeros(1, dtype=torch.float64)

# Set fixed point

fp = torch.zeros(4, dtype=torch.float64)

# Compute parametric fixed point (first order in momentum deviation)
# Note, all parameters must be vectors

pfp, *_ = parametric_fixed_point((1, ), fp, [dp], fodo, ms)
# Define transformation around fixed point

def pfp_fodo(state, dp, ms):
    """Propagate a deviation from the parametric fixed point through the cell."""
    shift = evaluate(pfp, [dp])
    return fodo(state + shift, dp, ms) - shift
# Tune

def tune(dp, ms):
    """Compute the first output of twiss for the jacobian of pfp_fodo with respect to state at fp."""
    jacobian = torch.func.jacrev(pfp_fodo)
    result, *_ = twiss(jacobian(fp, dp, ms))
    return result
# Chromaticity

def chromaticity(ms):
    """Compute derivative of tune with respect to momentum deviation (evaluated at the global dp)."""
    derivative = torch.func.jacrev(tune, argnums=0)
    return derivative(dp, ms)
# Compute tunes

tunes = tune(dp, ms)
print(tunes)

# Compute chromaticity

chromaticities = chromaticity(ms)
print(chromaticities.squeeze())

# Compute derivative of chromaticities with respect to sextupole strength deviations
# The result is zero, since there is no dispersion to feed sextupoles down

derivative = torch.func.jacrev(chromaticity)
print(derivative(ms).squeeze())
tensor([0.2107, 0.1703], dtype=torch.float64)
tensor([-0.2279, -0.2107], dtype=torch.float64)
tensor([[0., 0.],
[0., 0.]], dtype=torch.float64)