Example-03: Element factory
This example illustrates the procedure for constructing a generic accelerator element using elementary.element_factory. The generic accelerator Hamiltonian is $ \begin{align}
& H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\
& \\
& P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\
& P_x = p_x - a_x(q_x, q_y, q_s; s) \\
& P_y = p_y - a_y(q_x, q_y, q_s; s)
\end{align} $
where \(\beta\) and \(\gamma\) are the relativistic factors, \(h(s)\) is the reference trajectory curvature, \(t(s)\) is the reference trajectory torsion, \(a_x(q_x, q_y, q_s; s)\), \(a_y(q_x, q_y, q_s; s)\) and \(a_s(q_x, q_y, q_s; s)\) are the scaled vector potential components, and \(\varphi(q_x, q_y, q_s; s)\) is the scaled scalar potential. Additionally, the longitudinal coordinate and momentum are defined as:
$ \begin{align} & q_s = \frac{s}{\beta} - c t \\ & p_s = \frac{E}{c P} - \frac{1}{\beta} \end{align} $
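The Hamiltonian above can be written directly as a JAX function. The sketch below is not the elementary implementation: the function name generic_hamiltonian, its argument order, and the convention of passing the potentials, curvature and torsion as plain callables are assumptions made for illustration. A useful sanity check is the reference particle with no fields on a straight orbit: since \(1/\beta^2 - 1/(\beta^2\gamma^2) = 1\), the square root equals 1, so \(H = -1\) and \(\partial H / \partial p_s = 0\), i.e. \(q_s\) stays constant along the reference orbit.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def generic_hamiltonian(qs, ps, s, beta, gamma, curvature, torsion, vector, scalar):
    """Generic accelerator Hamiltonian H(q_x, q_y, q_s, p_x, p_y, p_s; s) (sketch).

    vector(qs, s) -> (a_x, a_y, a_s) and scalar(qs, s) -> phi are the scaled
    potentials; curvature(s) -> h and torsion(s) -> t describe the reference
    trajectory. All of them are plain callables in this sketch.
    """
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    a_x, a_y, a_s = vector(qs, s)
    phi = scalar(qs, s)
    h, t = curvature(s), torsion(s)
    # Kinetic momenta P_s, P_x, P_y as defined above
    P_s = p_s + 1 / beta - phi
    P_x = p_x - a_x
    P_y = p_y - a_y
    root = jnp.sqrt(P_s**2 - P_x**2 - P_y**2 - 1 / (beta**2 * gamma**2))
    return p_s / beta - t * (q_x * p_y - q_y * p_x) - (1 + h * q_x) * (root + a_s)


# Usage: reference particle (all coordinates and momenta zero), no fields, straight orbit
beta = 0.5
gamma = 1 / jnp.sqrt(1 - beta**2)


def H(qs, ps):
    return generic_hamiltonian(qs, ps, 0.0, beta, gamma,
                               lambda s: 0.0, lambda s: 0.0,
                               lambda qs, s: (0.0, 0.0, 0.0), lambda qs, s: 0.0)


qs, ps = jnp.zeros(3), jnp.zeros(3)
h_ref = H(qs, ps)                          # value of H on the reference orbit
dq_s = jax.grad(H, argnums=1)(qs, ps)[2]   # dq_s/ds = dH/dp_s on the reference orbit
```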
The corresponding element can be constructed by passing either the Hamiltonian function or other parameters (e.g. the vector potential). The returned element has the following signature:
def element(qsps:Array, length:Array, start:Array, *args:Array) -> Array:
    # qsps = [q_x, q_y, q_s, p_x, p_y, p_s]
    qs, ps = jax.numpy.reshape(qsps, (2, -1))
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    ...
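To make the signature concrete, here is a hypothetical element named drift (not produced by elementary) that follows the same (qsps, length, start, *args) convention. It is the exact field-free map for \(H = p_s - \sqrt{(1 + p_s)^2 - p_x^2 - p_y^2}\): the momenta are conserved and each coordinate advances by length times \(\partial H / \partial p\). The start and extra arguments are accepted only to match the signature.

```python
import jax
import jax.numpy as jnp
from jax import Array

jax.config.update("jax_enable_x64", True)


def drift(qsps: Array, length: Array, start: Array, *args: Array) -> Array:
    """Hypothetical element: exact field-free map over the given length."""
    qs, ps = jnp.reshape(qsps, (2, -1))
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    p_z = jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2)
    # Momenta are conserved, coordinates advance by length * dH/dp
    q_x = q_x + length * p_x / p_z
    q_y = q_y + length * p_y / p_z
    q_s = q_s + length * (1 - (1 + p_s) / p_z)
    return jnp.stack([q_x, q_y, q_s, p_x, p_y, p_s])


# Usage: particle with p_x = 0.1 through a unit-length drift
out = drift(jnp.array([0.0, 0.0, 0.0, 0.1, 0.0, 0.0]), jnp.float64(1.0), jnp.float64(0.0))
```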
The following explicit Hamiltonian is used as an example:
$ H = p_s - \left(\sqrt{(1 + p_s)^2 - p_x^2 - p_y^2} + a_s\right) $ with \((a_x, a_y, a_s) = \left(0, 0, \frac{1}{2} k_n \left(1 + \sin\left(2 \pi \frac{s}{l}\right)\right) \left(q_x^2 - q_y^2\right)\right)\) and \(\varphi = 0\). This is the generic Hamiltonian above in the limit \(\beta = 1\), \(1/(\beta^2 \gamma^2) \to 0\), with \(h = t = 0\).
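Hamilton's equations for this example follow from automatic differentiation: \(dq/ds = \partial H / \partial p\) and \(dp/ds = -\partial H / \partial q\). The sketch below (function names example_hamiltonian and rhs are illustrative, not part of elementary) evaluates them with jax.grad. For \(p = 0\) and \(s = 0\) the coordinates do not move, and the momentum change is \(k_n (q_x, -q_y, 0)\), coming only from \(a_s\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def example_hamiltonian(qs, ps, s, l, kn):
    """Explicit example: H = p_s - (sqrt((1 + p_s)^2 - p_x^2 - p_y^2) + a_s)."""
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    a_s = 0.5 * kn * (1 + jnp.sin(2 * jnp.pi * s / l)) * (q_x**2 - q_y**2)
    return p_s - (jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + a_s)


def rhs(qs, ps, s, l, kn):
    """Hamilton's equations: dq/ds = dH/dp, dp/ds = -dH/dq."""
    dq = jax.grad(example_hamiltonian, argnums=1)(qs, ps, s, l, kn)
    dp = -jax.grad(example_hamiltonian, argnums=0)(qs, ps, s, l, kn)
    return dq, dp


# Usage: displaced particle with zero momenta at s = 0 (sin term vanishes)
dq, dp = rhs(jnp.array([0.001, -0.005, 0.0]), jnp.zeros(3), 0.0, 1.0, 1.0)
```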
[1]:
# Import
import jax
from jax import Array
from jax import jit
from jax import jacrev
from elementary import fold
from elementary import nest
from elementary import tao
from elementary import sequence
from elementary.hamiltonian import hamiltonian_factory
from elementary.hamiltonian import autonomize
from elementary.element import element_factory
jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Set data type
jax.config.update("jax_enable_x64", True)
[3]:
# Set device
device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set parameters (initial independent parameter value, integration step length, element length and quadrupole amplitude)
si = jax.numpy.float64(0.5)
ds = jax.numpy.float64(0.01)
l = jax.numpy.float64(1.0)
kn = jax.numpy.float64(1.0)
[5]:
# Set initial condition
qs = jax.numpy.array([0.001, -0.005, 0.0])
ps = jax.numpy.array([0.005, -0.001, 0.0001])
qsps = jax.numpy.hstack([qs, ps])
[6]:
# Define non-autonomous and extended hamiltonian (factory)
def vector(qs:Array, s:Array, l:Array, kn:Array, *args:Array) -> tuple[Array, Array, Array]:
    q_x, q_y, q_s = qs
    a_x, a_y, a_s = jax.numpy.zeros_like(qs)
    a_s = 1/2*kn*(1 + jax.numpy.sin(2*jax.numpy.pi*s/l))*(q_x**2 - q_y**2)
    return a_x, a_y, a_s

def scalar(qs:Array, s:Array, l:Array, kn:Array, *args:Array) -> Array:
    q_x, q_y, q_s = qs
    return jax.numpy.zeros_like(s)

hamiltonian = hamiltonian_factory(vector, scalar)
extended = autonomize(hamiltonian)
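hamiltonian_factory turns the potential functions into a Hamiltonian. How elementary does this internally is not shown in this example; a plausible closure-based sketch is below, assuming the same limit as the explicit example (\(\beta = 1\), \(1/(\beta^2\gamma^2) \to 0\), no curvature or torsion). The name hamiltonian_factory_sketch is hypothetical and not part of elementary. The returned function has the same call pattern as the notebook's hamiltonian, hamiltonian(qs, ps, s, l, kn).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def hamiltonian_factory_sketch(vector, scalar):
    """Hypothetical factory: build H(qs, ps, s, *args) from scaled potentials.

    Assumes beta = 1, 1/(beta^2 gamma^2) -> 0 and no curvature or torsion.
    """
    def hamiltonian(qs, ps, s, *args):
        p_x, p_y, p_s = ps
        a_x, a_y, a_s = vector(qs, s, *args)
        phi = scalar(qs, s, *args)
        # Kinetic momenta in the assumed limit
        P_s = p_s + 1 - phi
        P_x = p_x - a_x
        P_y = p_y - a_y
        return p_s - (jnp.sqrt(P_s**2 - P_x**2 - P_y**2) + a_s)
    return hamiltonian


# Potentials equivalent to the ones defined in the notebook
def vector(qs, s, l, kn, *args):
    q_x, q_y, q_s = qs
    a_s = 0.5 * kn * (1 + jnp.sin(2 * jnp.pi * s / l)) * (q_x**2 - q_y**2)
    return jnp.zeros_like(a_s), jnp.zeros_like(a_s), a_s


def scalar(qs, s, l, kn, *args):
    return jnp.zeros_like(s)


hamiltonian = hamiltonian_factory_sketch(vector, scalar)
qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
s, l, kn = jnp.float64(0.5), jnp.float64(1.0), jnp.float64(1.0)
value = hamiltonian(qs, ps, s, l, kn)
```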
[7]:
# Set extended initial condition
Qs = jax.numpy.concat([qs, si.reshape(-1)])
Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])
QsPs = jax.numpy.hstack([Qs, Ps])
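The extended initial condition appends \(s\) to the coordinates and \(-H\) to the momenta. This is consistent with the common construction \(K = H(q, p, s) + P\), where \(P\) is the momentum conjugate to \(s\): Hamilton's equations then give \(ds/d\tau = \partial K / \partial P = 1\), and choosing \(P = -H\) makes \(K = 0\) on the initial condition. The sketch below assumes this form; autonomize_sketch is hypothetical, and elementary's autonomize may use a different argument layout.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def autonomize_sketch(hamiltonian):
    """Hypothetical extension: K(Qs, Ps) = H(qs, ps, s) + P, Qs = (qs, s), Ps = (ps, P)."""
    def extended(Qs, Ps, *args):
        qs, s = Qs[:-1], Qs[-1]
        ps, P = Ps[:-1], Ps[-1]
        return hamiltonian(qs, ps, s, *args) + P
    return extended


def example_hamiltonian(qs, ps, s, l, kn):
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    a_s = 0.5 * kn * (1 + jnp.sin(2 * jnp.pi * s / l)) * (q_x**2 - q_y**2)
    return p_s - (jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + a_s)


extended = autonomize_sketch(example_hamiltonian)
qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
si, l, kn = jnp.float64(0.5), jnp.float64(1.0), jnp.float64(1.0)

# Extended initial condition: s appended to qs, -H appended to ps
Qs = jnp.concatenate([qs, si.reshape(-1)])
Ps = jnp.concatenate([ps, -example_hamiltonian(qs, ps, si, l, kn).reshape(-1)])

k0 = extended(Qs, Ps, l, kn)                                # K on the extended initial condition
ds_rate = jax.grad(extended, argnums=1)(Qs, Ps, l, kn)[-1]  # dK/dP, the rate at which s advances
```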
[8]:
# Set tao integration step
integrator = jit(fold(sequence(0, 2**1, [tao(extended)], merge=False)))
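The name of the tao driver suggests Tao's explicit symplectic scheme for non-separable Hamiltonians (M. Tao, 2016); that correspondence, and the coupling constant omega used below, are assumptions, since the example does not show how elementary configures it. In Tao's method the phase-space point is duplicated into \((q, p)\) and \((x, y)\), and the augmented Hamiltonian \(H(q, y) + H(x, p) + \frac{\omega}{2}(|q - x|^2 + |p - y|^2)\) is split into three exactly solvable flows. A minimal second-order sketch (tao_step_sketch is a hypothetical name) composes them symmetrically:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def tao_step_sketch(hamiltonian, q, p, x, y, delta, omega):
    """One second-order step of Tao's extended phase-space scheme (sketch)."""
    dHdq = jax.grad(hamiltonian, argnums=0)
    dHdp = jax.grad(hamiltonian, argnums=1)

    def flow_a(q, p, x, y, h):
        # Exact flow of H(q, y): q and y are frozen
        return q, p - h * dHdq(q, y), x + h * dHdp(q, y), y

    def flow_b(q, p, x, y, h):
        # Exact flow of H(x, p): x and p are frozen
        return q + h * dHdp(x, p), p, x, y - h * dHdq(x, p)

    def flow_c(q, p, x, y, h):
        # Exact flow of the coupling: sums are kept, differences rotate by 2 * omega * h
        c, s = jnp.cos(2 * omega * h), jnp.sin(2 * omega * h)
        u, v = q - x, p - y
        u, v = c * u + s * v, -s * u + c * v
        return (q + x + u) / 2, (p + y + v) / 2, (q + x - u) / 2, (p + y - v) / 2

    state = (q, p, x, y)
    state = flow_a(*state, delta / 2)
    state = flow_b(*state, delta / 2)
    state = flow_c(*state, delta)
    state = flow_b(*state, delta / 2)
    state = flow_a(*state, delta / 2)
    return state


# Usage: harmonic oscillator, exact solution q = cos(t), p = -sin(t)
def oscillator(q, p):
    return 0.5 * jnp.sum(q**2 + p**2)


step = jax.jit(lambda *state: tao_step_sketch(oscillator, *state, 0.01, 10.0))
state = (jnp.array([1.0]), jnp.array([0.0]), jnp.array([1.0]), jnp.array([0.0]))
for _ in range(100):
    state = step(*state)
q, p, _, _ = state
```

After 100 steps of length 0.01 the first copy should be close to \((\cos 1, -\sin 1)\); the error grows with both the step and omega, so omega is a tuning parameter rather than a fixed constant.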
[9]:
# Set and compile element
element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05 1.500000000000e+00 9.664543767678e-03 3.998219958148e-03 1.000000000000e-04 9.999624216145e-01]
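nest(int(l/ds), integrator) applies the integration step l/ds = 100 times, which is why the fourth component of the output (the extended coordinate \(s\)) equals si + l = 1.5. A minimal sketch of such a repetition with jax.lax.fori_loop is below; nest_sketch is hypothetical, and elementary.nest may be implemented differently. The extra arguments are passed unchanged to every step.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def nest_sketch(n, step):
    """Hypothetical analogue of nest: apply step(state, *args) n times."""
    def nested(state, *args):
        # fori_loop traces the step once, so the program size does not grow with n
        return jax.lax.fori_loop(0, n, lambda _, x: step(x, *args), state)
    return nested


# Usage: a step that only advances s by ds, repeated 100 times from si = 0.5
def shift(s, ds):
    return s + ds


advance = jax.jit(nest_sketch(100, shift))
final = advance(jnp.float64(0.5), jnp.float64(0.01))
```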
[10]:
# Generate element from hamiltonian
# Note: phase space extension is handled internally
element = element_factory(vector=None,
                          scalar=None,
                          curvature=None,
                          torsion=None,
                          hamiltonian=hamiltonian,
                          driver=tao,
                          order=2**1,
                          iterations=int(l/ds),
                          autonomous=False)
element = jit(element)
out = element(qsps, l, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05 9.664543767678e-03 3.998219958148e-03 1.000000000000e-04]
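The generated element returns only the six physical variables, while the manual integration returned the extended state \((q_x, q_y, q_s, s, p_x, p_y, p_s, P)\). Dropping the extension pair from the extended result recovers the element output. A hypothetical helper reduce_extended, applied to the two printed arrays, makes the comparison explicit:

```python
import jax.numpy as jnp


def reduce_extended(QsPs):
    """Hypothetical helper: drop the extension pair (s and its conjugate momentum)."""
    Qs, Ps = jnp.reshape(QsPs, (2, -1))
    return jnp.concatenate([Qs[:-1], Ps[:-1]])


# Printed results of the extended integration and of the generated element
extended_out = jnp.array([7.235126271969e-03, -4.257064841463e-03, -2.199863474965e-05, 1.500000000000e+00,
                          9.664543767678e-03, 3.998219958148e-03, 1.000000000000e-04, 9.999624216145e-01])
element_out = jnp.array([7.235126271969e-03, -4.257064841463e-03, -2.199863474965e-05,
                         9.664543767678e-03, 3.998219958148e-03, 1.000000000000e-04])
reduced = reduce_extended(extended_out)
```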
[11]:
# Generate element from potential
element = element_factory(vector=vector,
                          scalar=None,
                          curvature=None,
                          torsion=None,
                          hamiltonian=None,
                          driver=tao,
                          order=2**1,
                          iterations=int(l/ds),
                          autonomous=False)
element = jit(element)
out = element(qsps, l, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05 9.664543767678e-03 3.998219958148e-03 1.000000000000e-04]
[12]:
%%timeit
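# Note: the first call of a jitted function also performs compilation (here it already happened in the previous cell)
# Subsequent calls reuse the compiled function, but new transformations such as mapping (vmap) or a Jacobian (jacrev) create new functions that are compiled again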
element(qsps, l, si, l, kn).block_until_ready()
1.17 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
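Mapping over initial conditions and computing a Jacobian are plain JAX transformations of an element, and each transformed function is traced and compiled on its first call. Since the elementary element is not reproduced here, the sketch uses a hypothetical field-free drift with the same (qsps, length, start, *args) signature as a stand-in. jax.jacrev gives the 6×6 transfer matrix with respect to the initial condition, jax.vmap evaluates a batch of initial conditions, and, because the drift is an exact Hamiltonian flow, its transfer matrix \(M\) should satisfy the symplectic condition \(M^T J M = J\) up to rounding.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def drift(qsps, length, start, *args):
    """Hypothetical stand-in element: exact field-free map."""
    qs, ps = jnp.reshape(qsps, (2, -1))
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    p_z = jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2)
    return jnp.stack([q_x + length * p_x / p_z,
                      q_y + length * p_y / p_z,
                      q_s + length * (1 - (1 + p_s) / p_z),
                      p_x, p_y, p_s])


qsps = jnp.array([0.001, -0.005, 0.0, 0.005, -0.001, 0.0001])
length, start = jnp.float64(1.0), jnp.float64(0.5)

# Transfer matrix: derivative of the final state with respect to the initial state
matrix = jax.jit(jax.jacrev(drift))(qsps, length, start)

# Several initial conditions at once (length and start are shared)
batch = qsps + jnp.linspace(-1e-3, 1e-3, 5)[:, None]
outs = jax.jit(jax.vmap(drift, in_axes=(0, None, None)))(batch, length, start)

# Symplecticity check M^T J M = J for coordinate order (q_x, q_y, q_s, p_x, p_y, p_s)
n = 3
J = jnp.block([[jnp.zeros((n, n)), jnp.eye(n)], [-jnp.eye(n), jnp.zeros((n, n))]])
error = jnp.max(jnp.abs(matrix.T @ J @ matrix - J))
```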