Example-04: Drift element factory
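This example illustrates the drift element factory.
The drift Hamiltonian is the general Hamiltonian below with the vector potential, the scalar potential, and the $t$ and $h$ terms all set to zero: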
$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ \\ & (a_x, a_y, a_s) = (0, 0, 0) \\ & \varphi = 0 \\ & t = h = 0 \\ \end{align} $
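As a worked reduction, substituting the drift conditions into the general expression leaves only the kinematic terms. The symbol $D$ is shorthand introduced here for the square root; it is not notation used by the library:
$ \begin{align} & H = \frac{p_s}{\beta} - D \\ & D = \sqrt{\left(p_s + \frac{1}{\beta}\right)^2 - p_x^2 - p_y^2 - \frac{1}{\beta^2 \gamma^2}} \end{align} $
Since this $H$ does not depend on the coordinates, the momenta are conserved and Hamilton's equations give constant slopes, so the coordinates advance linearly with the path length $s$:
$ \begin{align} & \frac{d q_x}{d s} = \frac{\partial H}{\partial p_x} = \frac{p_x}{D} \\ & \frac{d q_y}{d s} = \frac{\partial H}{\partial p_y} = \frac{p_y}{D} \\ & \frac{d q_s}{d s} = \frac{\partial H}{\partial p_s} = \frac{1}{\beta} - \frac{p_s + 1/\beta}{D} \\ & \frac{d p_x}{d s} = \frac{d p_y}{d s} = \frac{d p_s}{d s} = 0 \end{align} $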
The constructed element signature is:
def drift(qsps:Array, length:Array) -> Array:
...
Note that, by default, the exact solution is used rather than the Hamiltonian-based one (the latter is selected with exact=False).
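To make the exact solution concrete, here is a minimal standalone sketch of a closed-form drift map obtained by integrating Hamilton's equations for the drift Hamiltonian over a given length. It is not the library's implementation: the name drift_sketch is hypothetical, and the relation beta = sqrt(1 - 1/gamma^2) is assumed to match what elementary.util.beta returns.

```python
import jax
import jax.numpy as jnp

# Double precision, as in the rest of this example
jax.config.update("jax_enable_x64", True)


def drift_sketch(qsps, length, gamma=10**3):
    """Closed-form drift map (illustrative sketch, not the library code)."""
    q_x, q_y, q_s, p_x, p_y, p_s = qsps
    # Assumption: standard relativistic relation between beta and gamma
    beta = jnp.sqrt(1.0 - 1.0 / gamma**2)
    P_s = p_s + 1.0 / beta
    # Square root appearing in the drift Hamiltonian
    root = jnp.sqrt(P_s**2 - p_x**2 - p_y**2 - 1.0 / (beta * gamma) ** 2)
    # Momenta are conserved; coordinates advance linearly with length
    return jnp.stack([
        q_x + length * p_x / root,
        q_y + length * p_y / root,
        q_s + length * (1.0 / beta - P_s / root),
        p_x,
        p_y,
        p_s,
    ])


qsps = jnp.array([0.0, 0.0, 0.01, 0.001, 0.001, -0.0001])
out = drift_sketch(qsps, 1.0)
print(out)
```

For the initial condition and gamma used in this example, the result should agree with the element and PTC values printed in cell [6] to the displayed precision.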
[1]:
import jax
from jax import jit
from jax import jacrev
from elementary.util import ptc
from elementary.util import beta
from elementary.drift import drift_factory
jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Set data type
jax.config.update("jax_enable_x64", True)
[3]:
# Set device
device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set initial condition
(q_x, q_y, q_s) = qs = jax.numpy.array([0.0, 0.0, 0.01])
(p_x, p_y, p_s) = ps = jax.numpy.array([0.001, 0.001, -0.0001])
qsps = jax.numpy.hstack([qs, ps])
[5]:
# Define generic drift element
gamma = 10**3
element = jit(drift_factory(beta=beta(gamma), gamma=gamma))
[6]:
# Compare with PTC
print(res := element(qsps, 1.0))
print(ref := ptc(qsps, 'drift', {'l': 1.0}, gamma=gamma))
print(jax.numpy.allclose(res, ref))
[ 0.00100010101 0.00100010101 0.009998999698 0.001 0.001 -0.0001 ]
[ 0.00100010101 0.00100010101 0.009998999698 0.001 0.001 -0.0001 ]
True
[7]:
# Define generic drift element using the Hamiltonian (exact=False)
gamma = 10**3
element = jit(drift_factory(exact=False, beta=beta(gamma), gamma=gamma))
[8]:
# Compare with PTC
print(res := element(qsps, 1.0))
print(ref := ptc(qsps, 'drift', {'l': 1.0}, gamma=gamma))
print(jax.numpy.allclose(res, ref))
[ 0.00100010101 0.00100010101 0.009998999698 0.001 0.001 -0.0001 ]
[ 0.00100010101 0.00100010101 0.009998999698 0.001 0.001 -0.0001 ]
True
[9]:
# Differentiability
length = jax.numpy.float64(1.0)
print(jacrev(element)(qsps, length))
print()
print(jacrev(element, -1)(qsps, length))
print()
[[ 1.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000102010656e+00 1.000303061668e-06 -1.000203531514e-03]
[ 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00 1.000303061668e-06 1.000102010656e+00 -1.000203531514e-03]
[ 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 -1.000203531514e-03 -1.000203531514e-03 3.000910185458e-06]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00 0.000000000000e+00]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00]]
[ 1.000101010353e-03 1.000101010353e-03 -1.000302046226e-06 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00]
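The first output above is the 6x6 Jacobian with respect to the phase-space vector, and the second is the derivative with respect to the length. One common sanity check on such a Jacobian is the symplectic condition M^T J M = J. The sketch below applies it to a standalone closed-form drift map rather than to the library element. The function drift_sketch and the relation beta = sqrt(1 - 1/gamma^2) are assumptions made for illustration, and the (q_x, q_y, q_s, p_x, p_y, p_s) ordering is taken from the Hamiltonian's argument list.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def drift_sketch(qsps, length, gamma=10**3):
    """Closed-form drift map (illustrative sketch, not the library code)."""
    q_x, q_y, q_s, p_x, p_y, p_s = qsps
    beta = jnp.sqrt(1.0 - 1.0 / gamma**2)  # assumed beta(gamma) relation
    P_s = p_s + 1.0 / beta
    root = jnp.sqrt(P_s**2 - p_x**2 - p_y**2 - 1.0 / (beta * gamma) ** 2)
    return jnp.stack([
        q_x + length * p_x / root,
        q_y + length * p_y / root,
        q_s + length * (1.0 / beta - P_s / root),
        p_x, p_y, p_s,
    ])


qsps = jnp.array([0.0, 0.0, 0.01, 0.001, 0.001, -0.0001])
length = jnp.float64(1.0)

# Jacobian with respect to the phase-space vector (argument 0)
M = jax.jacrev(drift_sketch)(qsps, length)

# Symplectic form for the ordering (q_x, q_y, q_s, p_x, p_y, p_s)
I3, Z3 = jnp.eye(3), jnp.zeros((3, 3))
J = jnp.block([[Z3, I3], [-I3, Z3]])
is_symplectic = jnp.allclose(M.T @ J @ M, J, rtol=0.0, atol=1e-12)
print(is_symplectic)

# Derivative of the output with respect to the length (argument 1)
dL = jax.jacrev(drift_sketch, 1)(qsps, length)
print(dL)
```

For a drift the Jacobian has the block form [[I, B], [0, I]], so the condition reduces to B being symmetric, which is also visible in the upper-right block of the matrix printed above.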