Example-06: Quadrupole (element)

[1]:
# Comparison of quadrupole element with MADX-PTC and other features
[2]:
from pathlib import Path
from os import system

import torch
from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
[3]:
# Tracking (paraxial)

# A single particle is tracked through one quadrupole with MADX-PTC (exact=false, no alignment errors)
# The last row of the PTC observation file is used as a reference for the Quadrupole element output

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = False
align = False

kn = - 2.0
ks = + 1.5
dp = 0.005
length = 1.0
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors (multiplication by the boolean flag zeroes them when align is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X input (doubled braces are f-string escapes for literal braces)

code = f"""
mag:quadrupole,l={length},k1={kn},k1s={ks} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    with ptc.open('w') as stream:
        stream.write(code)

    # Fail early with a clear message instead of a confusing error when reading the output
    status = system(f'madx < {str(ptc)} > /dev/null')
    if status != 0:
        raise RuntimeError(f'madx exited with non-zero status {status}')

    # Blank rows are dropped so that a trailing empty line does not hide the last data row
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
    if not rows:
        raise RuntimeError(f'no tracking data found in {obs}')

finally:
    # Temporary files are removed even if MAD-X fails or the output cannot be read
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

# Fields 2..5 of the last row are read as (qx, px, qy, py)
# The initial coordinates and IPython's '_' are left untouched

fields = rows[-1].split()
if len(fields) < 6:
    raise RuntimeError(f'unexpected tracking row format: {rows[-1]!r}')

ref = torch.tensor([float(value) for value in fields[2:6]], dtype=torch.float64)

Q = Quadrupole('Q', length=length, kn=kn, ks=ks, dp=dp, exact=exact)
res = Q(state, alignment=align, data={**Q.data(), **error})

# Reference (PTC), result (element) and their difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.012268608165994052, 0.012991610983278109, 0.005825218798687177, 0.01752224400608683]
[0.012268608165994056, 0.012991610983278081, 0.005825218798687173, 0.017522244006086804]
[-3.469446951953614e-18, 2.7755575615628914e-17, 4.336808689942018e-18, 2.42861286636753e-17]
[4]:
# Tracking (exact)

# A single particle is tracked through one quadrupole with MADX-PTC (exact=true, no alignment errors)
# The last row of the PTC observation file is used as a reference for the Quadrupole element output

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = False

kn = - 2.0
ks = + 1.5
dp = 0.005
length = 1.0
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors (multiplication by the boolean flag zeroes them when align is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X input (doubled braces are f-string escapes for literal braces)

code = f"""
mag:quadrupole,l={length},k1={kn},k1s={ks} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    with ptc.open('w') as stream:
        stream.write(code)

    # Fail early with a clear message instead of a confusing error when reading the output
    status = system(f'madx < {str(ptc)} > /dev/null')
    if status != 0:
        raise RuntimeError(f'madx exited with non-zero status {status}')

    # Blank rows are dropped so that a trailing empty line does not hide the last data row
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
    if not rows:
        raise RuntimeError(f'no tracking data found in {obs}')

finally:
    # Temporary files are removed even if MAD-X fails or the output cannot be read
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

# Fields 2..5 of the last row are read as (qx, px, qy, py)
# The initial coordinates and IPython's '_' are left untouched

fields = rows[-1].split()
if len(fields) < 6:
    raise RuntimeError(f'unexpected tracking row format: {rows[-1]!r}')

ref = torch.tensor([float(value) for value in fields[2:6]], dtype=torch.float64)

Q = Quadrupole('Q', length=length, kn=kn, ks=ks, dp=dp, exact=exact, order=5, ns=5)
res = Q(state, alignment=align, data={**Q.data(), **error})

# Reference (PTC), result (element) and their difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[0.012269208914523159, 0.012992157908766264, 0.005826335256074007, 0.017521822791072554]
[0.012269208914522952, 0.012992157908766015, 0.005826335256074073, 0.017521822791072776]
[2.0643209364124004e-16, 2.498001805406602e-16, -6.591949208711867e-17, -2.220446049250313e-16]
[5]:
# Tracking (exact, alignment)

# A single particle is tracked through one misaligned quadrupole with MADX-PTC (exact=true)
# The last row of the PTC observation file is used as a reference for the Quadrupole element output

ptc = Path('ptc')
obs = Path('track.obs0001.p0001')

exact = True
align = True

kn = - 2.0
ks = + 1.5
dp = 0.005
length = 1.0
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)
qx, px, qy, py = state.tolist()

# Alignment errors (multiplication by the boolean flag zeroes them when align is False)

dx = align*torch.tensor(0.05, dtype=torch.float64)
dy = align*torch.tensor(-0.02, dtype=torch.float64)
dz = align*torch.tensor(0.05, dtype=torch.float64)

wx = align*torch.tensor(0.005, dtype=torch.float64)
wy = align*torch.tensor(-0.005, dtype=torch.float64)
wz = align*torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

# MAD-X input (doubled braces are f-string escapes for literal braces)

code = f"""
mag:quadrupole,l={length},k1={kn},k1s={ks} ;
map:line=(mag) ;
beam,energy=1.0E+6,particle=electron ;
set,format="20.20f","-20s" ;
use,period=map ;
select,flag=error,pattern="mag" ;
ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;
ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;
ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;
ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;
ptc_align ;
ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;
ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;
ptc_track_end ;
ptc_end ;
"""

try:
    with ptc.open('w') as stream:
        stream.write(code)

    # Fail early with a clear message instead of a confusing error when reading the output
    status = system(f'madx < {str(ptc)} > /dev/null')
    if status != 0:
        raise RuntimeError(f'madx exited with non-zero status {status}')

    # Blank rows are dropped so that a trailing empty line does not hide the last data row
    with obs.open('r') as stream:
        rows = [row for row in stream if row.strip()]
    if not rows:
        raise RuntimeError(f'no tracking data found in {obs}')

finally:
    # Temporary files are removed even if MAD-X fails or the output cannot be read
    ptc.unlink(missing_ok=True)
    obs.unlink(missing_ok=True)

# Fields 2..5 of the last row are read as (qx, px, qy, py)
# The initial coordinates and IPython's '_' are left untouched

fields = rows[-1].split()
if len(fields) < 6:
    raise RuntimeError(f'unexpected tracking row format: {rows[-1]!r}')

ref = torch.tensor([float(value) for value in fields[2:6]], dtype=torch.float64)

Q = Quadrupole('Q', length=length, kn=kn, ks=ks, dp=dp, exact=exact, order=5, ns=5)
res = Q(state, alignment=align, data={**Q.data(), **error})

# Reference (PTC), result (element) and their difference

print(ref.tolist())
print(res.tolist())
print((ref - res).tolist())
[-0.022075488016794924, -0.09165584224601611, -0.04570124622656498, -0.08629975808408008]
[-0.02207548801679271, -0.09165584224601468, -0.04570124622656417, -0.08629975808408101]
[-2.213507155346406e-15, -1.429412144204889e-15, -8.118505867571457e-16, 9.298117831235686e-16]
[6]:
# Deviation/error variables

kn, ks, dp, length = -2.0, +1.5, 0.005, 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (shifts and rotations)

dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (0.05, -0.02, 0.05))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.005, -0.005, 0.1))

error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

Q = Quadrupole('Q', length, kn, ks, dp)

# Each element has two variants of the call method
# In the first case only the state is passed, it is transformed using parameters specified on initialization

print(Q(state), end='\n\n')

# Deviation errors can also be passed to the call method
# These variables are added to the corresponding parameters specified on initialization
# For example, the element length can be changed

print(Q(state, data={**Q.data(), 'dl': -Q.length}), end='\n\n')

# In the above, Q.data() creates the default deviation dictionary (with zero values for each deviation)
# and the 'dl' key value is then replaced

# Additionally, alignment errors are passed as deviation variables
# They are used if the alignment flag is raised

print(Q(state, data={**Q.data(), **error}, alignment=True), end='\n\n')

# The following elements can be made equivalent using deviation variables

QA = Quadrupole('QA', length, kn, ks, dp)
QB = Quadrupole('QB', length - 0.1, kn, ks, dp)

print(QA(state) - QB(state, data={**QB.data(), 'dl': torch.tensor(+0.1, dtype=QB.dtype)}))

# Note, while in some cases float values can be passed as values of deviation variables,
# the correct behaviour is guaranteed only for tensors
tensor([0.0242, 0.0380, 0.0152, 0.0200], dtype=torch.float64)

tensor([ 0.0100, -0.0050, -0.0050,  0.0010], dtype=torch.float64)

tensor([-0.0908, -0.2335, -0.0963, -0.1316], dtype=torch.float64)

tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Insertion element

# In this mode elements are treated as thin insertions (at the center)
# Using parameters specified on initialization, two transport matrices are computed
# These matrices are used to insert the element
# Input state is transformed from the element center to its entrance
# Next, transformation from the entrance frame to the exit frame is performed
# This transformation can contain errors
# The final step is to transform state from the exit frame back to the element center
# Without errors, this results in identity transformation for linear elements

kn = - 2.0
ks = + 1.5
dp = 0.005
length = 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined for completeness, not passed to the element in this cell)

dx = torch.tensor(0.05, dtype=torch.float64)
dy = torch.tensor(-0.02, dtype=torch.float64)
dz = torch.tensor(0.05, dtype=torch.float64)

wx = torch.tensor(0.005, dtype=torch.float64)
wy = torch.tensor(-0.005, dtype=torch.float64)
wz = torch.tensor(0.1, dtype=torch.float64)

error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}

Q = Quadrupole('Q', length, kn, ks, dp, exact=False, insertion=True)

# Identity transformation without errors

print(Q(state) - state)

# Represents effect of an error
# Deviation values are passed as tensors (correct behaviour is guaranteed only for tensors)

deviation = {
    'dl': torch.tensor(0.1, dtype=Q.dtype),
    'kn': torch.tensor(-0.1, dtype=Q.dtype),
}

print(Q(state, data={**Q.data(), **deviation}) - state)

# Exact tracking corresponds to inclusion of the kinematic term as an error

Q = Quadrupole('Q', length, kn, ks, dp, exact=True, insertion=True, ns=100, order=1)

print(Q(state) - state)
tensor([-5.2042e-18,  1.0408e-17, -3.4694e-18, -3.4694e-18],
       dtype=torch.float64)
tensor([-0.0002,  0.0037,  0.0003,  0.0031], dtype=torch.float64)
tensor([-2.2924e-06, -3.9787e-06, -9.4215e-07,  2.1943e-07],
       dtype=torch.float64)
[8]:
# Mapping over a set of initial conditions

# The call method can be used to map over a set of initial conditions
# Note, device can be set to cpu or gpu via base element class variables

kn, ks, dp, length = -2.0, +1.5, 0.0, 1.5

# Alignment errors (defined for completeness, not passed to the element in this cell)

dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (0.05, -0.02, 0.05))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.005, -0.005, 0.1))

error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

Q = Quadrupole('Q', length, kn, ks, dp, exact=True)

# A batch of 512 random initial conditions

state = 1.0E-3*torch.randn(512, 4, dtype=Q.dtype, device=Q.device)

print(torch.vmap(Q)(state).shape)

# To map over deviation parameters a wrapper function (or a lambda expression) can be used

def wrapper(state, dp):
    """Apply Q to state with the momentum deviation dp passed as the 'dp' deviation variable."""
    return Q(state, data={**Q.data(), 'dp': dp})

# One random momentum deviation per initial condition

dp = 1.0E-3*torch.randn(512, dtype=Q.dtype, device=Q.device)

print(torch.vmap(wrapper)(state, dp).shape)
torch.Size([512, 4])
torch.Size([512, 4])
[9]:
# Differentiability

# Both call methods are differentiable
# Derivative with respect to state can be computed directly
# For deviation variables, wrapping is required

kn, ks, dp, length = -2.0, +1.5, 0.0, 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined for completeness, not passed to the element in this cell)

dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (0.05, -0.02, 0.05))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.005, -0.005, 0.1))

error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

Q = Quadrupole('Q', length, kn, ks, dp, exact=False)

# Compute derivative with respect to state

print(torch.func.jacrev(Q)(state), end='\n\n')

# Compute derivative with respect to a deviation variable

kn = torch.tensor(0.0, dtype=torch.float64)

def wrapper(state, kn):
    """Apply Q to state with kn passed as the 'kn' deviation variable."""
    return Q(state, data={**Q.data(), 'kn': kn})

print(torch.func.jacrev(wrapper, argnums=1)(state, kn), end='\n\n')

# Compositional derivative (derivative of the jacobian trace with respect to quadrupole strengths)

length = 0.5
knf, knd = +0.2, -0.2

QF = Quadrupole('QF', length, knf)
QD = Quadrupole('QD', length, knd)
DR = Drift('DR', 5.0)

dknf = torch.tensor(0.0, dtype=torch.float64)
dknd = torch.tensor(0.0, dtype=torch.float64)
dkn = torch.stack((dknf, dknd))

def fodo(state, dkn):
    """Propagate state through QF, DR, QD, QD, DR, QF with 'kn' deviations (dknf, dknd) taken from dkn."""
    dknf, dknd = dkn
    lattice = (
        (QF, {'kn': dknf}),
        (DR, None),
        (QD, {'kn': dknd}),
        (QD, {'kn': dknd}),
        (DR, None),
        (QF, {'kn': dknf}),
    )
    for element, deviation in lattice:
        # Drifts are called without deviation data
        if deviation is None:
            state = element(state)
        else:
            state = element(state, data={**element.data(), **deviation})
    return state

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

def trace(dkn):
    """Trace of the jacobian of fodo with respect to state (evaluated at the global state)."""
    jacobian = torch.func.jacrev(fodo)(state, dkn)
    return jacobian.trace()

torch.func.jacrev(trace)(dkn)
tensor([[ 4.7923,  3.0672,  1.8367,  0.8757],
        [ 7.4479,  4.7923,  2.8495,  1.8367],
        [ 1.8367,  0.8757, -0.1057,  0.7321],
        [ 2.8495,  1.8367, -0.1507, -0.1057]], dtype=torch.float64)

tensor([-0.0175, -0.0354, -0.0029, -0.0029], dtype=torch.float64)

[9]:
tensor([-12.7901,  12.7901], dtype=torch.float64)
[10]:
# Output at each step

# It is possible to collect output of state or tangent matrix at each integration step
# Number of integration steps is controlled by the ns parameter on initialization
# Alternatively, desired integration step length can be passed
# Number of integration steps is computed as ceil(length/ds)

kn, ks, dp, length = -2.0, +1.5, 0.0, 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined for completeness, not passed to the element in this cell)

dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (0.05, -0.02, 0.05))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.005, -0.005, 0.1))

error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

Q = Quadrupole('Q', length, kn, ks, dp, exact=False, ns=10, output=True, matrix=True)

# Final state is still returned

print(Q(state))

# Data is added to special attributes (state and tangent matrix)

print(Q.container_output.shape, Q.container_matrix.shape, sep='\n')

# Number of integration steps can be changed

Q.ns = 100

Q(state)
print(Q.container_output.shape, Q.container_matrix.shape, sep='\n')
tensor([0.0243, 0.0381, 0.0153, 0.0200], dtype=torch.float64)
torch.Size([10, 4])
torch.Size([10, 4, 4])
torch.Size([100, 4])
torch.Size([100, 4, 4])
[11]:
# Integration order is set on initialization (default value is zero)
# This order is related to difference order as 2n + 2
# Thus, zero corresponds to second order difference method

kn, ks, dp, length = -2.0, +1.5, 0.0, 1.5
state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)

# Alignment errors (defined for completeness, not passed to the element in this cell)

dx, dy, dz = (torch.tensor(value, dtype=torch.float64) for value in (0.05, -0.02, 0.05))
wx, wy, wz = (torch.tensor(value, dtype=torch.float64) for value in (0.005, -0.005, 0.1))

error = dict(dx=dx, dy=dy, dz=dz, wx=wx, wy=wy, wz=wz)

Q = Quadrupole('Q', length, kn, ks, dp, order=1, exact=True)

# For quadrupole, integration is performed only with the exact flag
# In this case, the kinematic term error is added

# Single integration step (reference)

Q.ns = 1
ref = Q(state)

# Ten integration steps (result)

Q.ns = 10
res = Q(state)

print(ref.tolist(), res.tolist(), (ref - res).tolist(), sep='\n')
print()

# Integrator parameters are stored in data attribute (if integration is actually performed)

maps, weights = Q._data
print(maps, weights, sep='\n')
[0.024284271092022615, 0.03811354319280181, 0.015254854998694, 0.01995832080452666]
[0.024286764143785347, 0.038112599406681526, 0.015255420211007705, 0.01995809327180016]
[-2.4930517627322346e-06, 9.437861202832298e-07, -5.652123137040582e-07, 2.2753272650027911e-07]

[0, 1, 0, 1, 0, 1, 0]
[0.6756035959798289, 1.3512071919596578, -0.17560359597982877, -1.7024143839193153, -0.17560359597982877, 1.3512071919596578, 0.6756035959798289]
[12]:
# Derivatives of twiss parameters

# pip install git+https://github.com/i-a-morozov/twiss.git@main

from twiss import twiss

length = 0.5
knf = +0.21
knd = -0.19

QF = Quadrupole('QF', length, knf)
QD = Quadrupole('QD', length, knd)
DR = Drift('DR', 5.0)

# Strength deviations (expansion point for the jacobian)

dknf = torch.tensor(0.0, dtype=torch.float64)
dknd = torch.tensor(0.0, dtype=torch.float64)
dkn = torch.stack([dknf, dknd])

def fodo(state, dkn):
    """Propagate state through QF, DR, QD, QD, DR, QF with 'kn' deviations (dknf, dknd) taken from dkn."""
    dknf, dknd = dkn
    state = QF(state, data={**QF.data(), **{'kn': dknf}})
    state = DR(state)
    state = QD(state, data={**QD.data(), **{'kn': dknd}})
    state = QD(state, data={**QD.data(), **{'kn': dknd}})
    state = DR(state)
    state = QF(state, data={**QF.data(), **{'kn': dknf}})
    return state

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)

def tune(dkn):
    """Return the first output of twiss for the jacobian of fodo with respect to state (at the global state)."""
    matrix = torch.func.jacrev(fodo)(state, dkn)
    # A distinct local name avoids shadowing the enclosing function name
    tunes, *_ = twiss(matrix)
    return tunes

# Compute tunes and jacobian

values = tune(dkn)
jacobian = torch.func.jacrev(tune)(dkn)

# Test jacobian
# The first order prediction uses the displacement from the expansion point,
# so the check stays valid for a non-zero dkn as well

delta = torch.full_like(dkn, 1.0E-3)

print(values)
print(tune(dkn + delta))
print(values + jacobian @ delta)
tensor([0.2107, 0.1703], dtype=torch.float64)
tensor([0.2126, 0.1681], dtype=torch.float64)
tensor([0.2126, 0.1681], dtype=torch.float64)