Example-03: Element factory

This example illustrates the procedure for constructing a generic accelerator element using elementary.element_factory. The generic accelerator hamiltonian is:

$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \end{align} $

where \(\beta\) and \(\gamma\) are the relativistic factors, \(h(s)\) is the reference trajectory curvature, \(t(s)\) is the reference trajectory torsion, \(a_x(q_x, q_y, q_s; s)\), \(a_y(q_x, q_y, q_s; s)\) and \(a_s(q_x, q_y, q_s; s)\) are the scaled vector potential components, and \(\varphi(q_x, q_y, q_s; s)\) is the scaled scalar potential. Additionally, the longitudinal coordinate and momentum are given by:

$ \begin{align} & q_s = \frac{s}{\beta} - c t \\ & p_s = \frac{E}{c P} - \frac{1}{\beta} \end{align} $

The corresponding element can be constructed by passing either a hamiltonian function or other parameters (e.g. the vector potential). The returned element has the following signature:

def element(qsps:Array, length:Array, start:Array, *args:Array) -> Array:
  qs, ps = jax.numpy.reshape(qsps, (2, -1))
  q_x, q_y, q_s = qs
  p_x, p_y, p_s = ps
  ...
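
The first two lines of the signature body fix the layout of the flat state vector: coordinates first, then momenta. A minimal sketch of that unpacking in plain JAX, using the initial condition from this example (nothing here calls the elementary library):

```python
import jax.numpy as jnp

# Flat state: three coordinates followed by three momenta
qsps = jnp.array([0.001, -0.005, 0.0, 0.005, -0.001, 0.0001])

# Reshape to (2, 3): row 0 holds coordinates, row 1 holds momenta
qs, ps = jnp.reshape(qsps, (2, -1))
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps

print(qs.shape, ps.shape)
```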

The following explicit hamiltonian is used as an example:

$ H = p_s - \left(\sqrt{(1 + p_s)^2 - p_x^2 - p_y^2} + a_s\right) $ with \((a_x, a_y, a_s) = \left(0, 0, \frac{1}{2} k_n \left(1 + \sin\left(2 \pi \frac{s}{l}\right)\right) \left(q_x^2 - q_y^2\right)\right)\)
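
As a cross-check of the formula, the explicit hamiltonian can be written directly in plain JAX and differentiated to obtain Hamilton's equations. This is only a sketch: the function names `a_s` and `explicit_hamiltonian` are illustrative and are not part of the elementary library.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def a_s(qs, s, l, kn):
    # Longitudinal vector potential component with s-dependent amplitude
    q_x, q_y, q_s = qs
    return 0.5*kn*(1.0 + jnp.sin(2.0*jnp.pi*s/l))*(q_x**2 - q_y**2)

def explicit_hamiltonian(qs, ps, s, l, kn):
    # H = p_s - (sqrt((1 + p_s)^2 - p_x^2 - p_y^2) + a_s)
    p_x, p_y, p_s = ps
    return p_s - (jnp.sqrt((1.0 + p_s)**2 - p_x**2 - p_y**2) + a_s(qs, s, l, kn))

qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
s, l, kn = 0.5, 1.0, 1.0

H0 = explicit_hamiltonian(qs, ps, s, l, kn)

# Hamilton's equations: dq/ds = dH/dp and dp/ds = -dH/dq
dq_ds = jax.grad(explicit_hamiltonian, argnums=1)(qs, ps, s, l, kn)
dp_ds = -jax.grad(explicit_hamiltonian, argnums=0)(qs, ps, s, l, kn)

print(H0)
print(dq_ds)
print(dp_ds)
```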

[1]:
# Import

import jax
from jax import Array
from jax import jit
from jax import jacrev

from elementary import fold
from elementary import nest
from elementary import tao
from elementary import sequence

from elementary.hamiltonian import hamiltonian_factory
from elementary.hamiltonian import autonomize
from elementary.element import element_factory

jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Set data type

jax.config.update("jax_enable_x64", True)
[3]:
# Set device

device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set parameters (initial value of the independent variable, integration step length, element length and quadrupole amplitude)

si = jax.numpy.float64(0.5)
ds = jax.numpy.float64(0.01)

l = jax.numpy.float64(1.0)
kn = jax.numpy.float64(1.0)
[5]:
# Set initial condition

qs = jax.numpy.array([0.001, -0.005, 0.0])
ps = jax.numpy.array([0.005, -0.001, 0.0001])

qsps = jax.numpy.hstack([qs, ps])
[6]:
# Define non-autonomous and extended hamiltonian (factory)

def vector(qs:Array, s:Array, l:Array, kn:Array, *args:Array) -> tuple[Array, Array, Array]:
    q_x, q_y, q_s = qs
    a_x, a_y, a_s = jax.numpy.zeros_like(qs)
    a_s = 1/2*kn*(1 + jax.numpy.sin(2*jax.numpy.pi*s/l))*(q_x**2 - q_y**2)
    return a_x, a_y, a_s

def scalar(qs:Array, s:Array, l:Array, kn:Array, *args:Array) -> Array:
    q_x, q_y, q_s = qs
    return jax.numpy.zeros_like(s)

hamiltonian = hamiltonian_factory(vector, scalar)
extended = autonomize(hamiltonian)
[7]:
# Set extended initial condition

Qs = jax.numpy.concat([qs, si.reshape(-1)])
Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])
QsPs = jax.numpy.hstack([Qs, Ps])
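
The extended initial condition above appends the independent variable to the coordinates and the negated hamiltonian value to the momenta. A minimal sketch of the underlying idea in plain JAX follows. The helper `autonomize_sketch` and its call signature are assumptions made for illustration; the actual `elementary.hamiltonian.autonomize` may be implemented and called differently.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def hamiltonian(qs, ps, s, l, kn):
    # Explicit non-autonomous hamiltonian of this example
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    a_s = 0.5*kn*(1.0 + jnp.sin(2.0*jnp.pi*s/l))*(q_x**2 - q_y**2)
    return p_s - (jnp.sqrt((1.0 + p_s)**2 - p_x**2 - p_y**2) + a_s)

def autonomize_sketch(hamiltonian):
    # Hypothetical helper: K(Q, P) = H(q, p, s) + p_ext, with s the last coordinate
    def extended(Qs, Ps, *args):
        qs, s = Qs[:-1], Qs[-1]
        ps, p_ext = Ps[:-1], Ps[-1]
        return hamiltonian(qs, ps, s, *args) + p_ext
    return extended

extended = autonomize_sketch(hamiltonian)

qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
si = jnp.float64(0.5)
l, kn = 1.0, 1.0

Qs = jnp.concatenate([qs, si.reshape(-1)])
Ps = jnp.concatenate([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])

# With p_ext = -H the extended hamiltonian vanishes at the initial condition
K0 = extended(Qs, Ps, l, kn)

# The appended coordinate advances at unit rate: dK/dp_ext = 1
ds_dtau = jax.grad(extended, argnums=1)(Qs, Ps, l, kn)[-1]

print(K0, ds_dtau)
```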
[8]:
# Set tao integration step

integrator = jit(fold(sequence(0, 2**1, [tao(extended)], merge=False)))
[9]:
# Set and compile element

element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05  1.500000000000e+00  9.664543767678e-03  3.998219958148e-03  1.000000000000e-04  9.999624216145e-01]
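
The element above applies a single integration step int(l/ds) times. The repetition itself can be sketched in plain JAX with jax.lax.fori_loop. Both `nest_sketch` and the stand-in `step` (an exact phase space rotation, not the tao integrator) are illustrative assumptions and say nothing about how elementary.nest is implemented.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def step(x, ds):
    # Stand-in single step: exact rotation of (q, p) by angle ds
    q, p = x
    c, s = jnp.cos(ds), jnp.sin(ds)
    return jnp.stack([c*q + s*p, -s*q + c*p])

def nest_sketch(n, step):
    # Hypothetical helper: apply the same step n times, passing extra arguments through
    def wrapper(x, *args):
        return jax.lax.fori_loop(0, n, lambda _, x: step(x, *args), x)
    return wrapper

ds = jnp.float64(0.01)
l = jnp.float64(1.0)
n = int(l/ds)

element = jax.jit(nest_sketch(n, step))
out = element(jnp.array([1.0, 0.0]), ds)

# n steps of angle ds amount to a rotation by the total angle n*ds
print(n, out)
```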
[10]:
# Generate element from hamiltonian
# Note, phase space extension is handled internally

element = element_factory(vector=None,
                          scalar=None,
                          curvature=None,
                          torsion=None,
                          hamiltonian=hamiltonian,
                          driver=tao,
                          order=2**1,
                          iterations=int(l/ds),
                          autonomous=False)

element = jit(element)

out = element(qsps, l, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05  9.664543767678e-03  3.998219958148e-03  1.000000000000e-04]
[11]:
# Generate element from potential

element = element_factory(vector=vector,
                          scalar=None,
                          curvature=None,
                          torsion=None,
                          hamiltonian=None,
                          driver=tao,
                          order=2**1,
                          iterations=int(l/ds),
                          autonomous=False)

element = jit(element)

out = element(qsps, l, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05  9.664543767678e-03  3.998219958148e-03  1.000000000000e-04]
[12]:
%%timeit

# Note, first call also performs compilation
# Subsequent calls use compiled function, but operations like map and jacobian will trigger recompilation

element(qsps, l, si, l, kn).block_until_ready()
1.17 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
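
The note in the timing cell can be illustrated without the library: a jitted function is compiled on its first call, and wrapping it in a transformation such as jax.jacrev produces a new function that is traced and compiled separately. The sketch below uses a stand-in map (a drift followed by a thin quadrupole-like kick) with the same flat state layout; `stand_in` is hypothetical and is not the element constructed above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def stand_in(qsps, length, kn):
    # Hypothetical map on the flat (coordinates, momenta) state
    qs, ps = jnp.reshape(qsps, (2, -1))
    qs = qs + length*ps                            # drift
    ps = ps + kn*qs*jnp.array([-1.0, 1.0, 0.0])    # thin kick
    return jnp.hstack([qs, ps])

fast = jax.jit(stand_in)
qsps = jnp.array([0.001, -0.005, 0.0, 0.005, -0.001, 0.0001])

# First call traces and compiles; block_until_ready waits for the result
out = fast(qsps, 1.0, 1.0).block_until_ready()

# Second call with the same shapes and dtypes reuses the compiled function
out = fast(qsps, 1.0, 1.0).block_until_ready()

# The jacobian is a different function, so it is compiled on its own
matrix = jax.jit(jax.jacrev(stand_in))(qsps, 1.0, 1.0)

# Symplecticity check M^T J M = J for the (coordinates, momenta) ordering
identity, zero = jnp.eye(3), jnp.zeros((3, 3))
J = jnp.block([[zero, identity], [-identity, zero]])
error = jnp.max(jnp.abs(matrix.T @ J @ matrix - J))

print(matrix.shape, error)
```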