Example-05: Quadrupole element factory

This example illustrates the quadrupole element factory.

The quadrupole Hamiltonian is:

$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ & \\ & (a_x, a_y, a_s) = (0, 0, -\frac{1}{2} k_n \left(q_x^2 - q_y^2 \right) + k_s q_x q_y)\\ & \varphi = 0 \\ & t = h = 0 \\ \end{align} $

The signature of the constructed element is:

def quadrupole(qsps:Array, length:Array, kn:Array, ks:Array) -> Array:
    ...

Note: both `kn` and `ks` must be passed on invocation (use `0.0` for an absent component).

[1]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.quadrupole import quadrupole_factory

# Print arrays on wide lines (256 characters) with 12 digits of float precision
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Switch JAX to 64-bit (double-precision) dtypes for the rest of the notebook

jax.config.update("jax_enable_x64", True)
[3]:
# Take the first CPU device and make it JAX's default device

[device, *_] = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Initial condition: coordinates qs = (q_x, q_y, q_s) and momenta ps = (p_x, p_y, p_s)

qs = jax.numpy.array([-0.01, 0.005, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
# Flat state vector (q_x, q_y, q_s, p_x, p_y, p_s) consumed by the element
qsps = jax.numpy.concatenate([qs, ps])
[5]:
# Build a generic (kn, ks, length as call arguments) quadrupole element for a fixed gamma,
# then JIT-compile it

gamma = 10**3
element = jit(
    quadrupole_factory(
        beta=beta(gamma),
        gamma=gamma,
        order=2**1,      # passed through to quadrupole_factory (presumably integrator order — confirm)
        iterations=100,  # passed through to quadrupole_factory (presumably solver iterations — confirm)
    )
)
[6]:
# Track the initial condition through the element and through PTC, then compare

length = jax.numpy.float64(1.0)
kn = jax.numpy.float64(-2.0)
ks = jax.numpy.float64(+1.5)

res = element(qsps, length, kn, ks)
print(res)

ref = ptc(
    qsps,
    'quadrupole',
    {'l': float(length), 'k1': float(kn), 'k1s': float(ks)},
    gamma=gamma,
)
print(ref)

print(jax.numpy.allclose(res, ref))
[-0.017335308274 -0.005444917174  0.000881617913 -0.020606852812 -0.020050030434 -0.0001        ]
[-0.017335308274 -0.005444917174  0.000881617913 -0.020606852812 -0.020050030434 -0.0001        ]
True
[7]:
# Differentiability: reverse-mode Jacobian with respect to the state (argument 0),
# then with respect to the element length (argument 1)

for argnum in (0, 1):
    print(jacrev(element, argnum)(qsps, length, kn, ks))
    print()
[[ 2.279575304234e+00  7.632479527605e-01  0.000000000000e+00  1.388655261371e+00  2.520951472455e-01  9.491025360004e-03]
 [ 7.636609190032e-01  2.438324623808e-01  0.000000000000e+00  2.522344406377e-01  7.165998020053e-01  9.548555445061e-03]
 [ 2.540182886809e-02 -1.691652080606e-03  1.000000000000e+00  1.634207518714e-02  9.283955945891e-03  2.542783581185e-04]
 [ 3.154287926464e+00  1.578288131763e+00  0.000000000000e+00  2.279171804490e+00  7.632559539870e-01  9.596399325345e-03]
 [ 1.578258960753e+00 -1.054777136190e+00  0.000000000000e+00  7.631214548048e-01  2.439057831974e-01 -2.972554738619e-03]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00]]

[-0.020617438835 -0.020060330409 -0.000413662569 -0.042837992308 -0.015113128063  0.            ]