Example-07: Octupole element factory
This example illustrates the octupole element factory.
The octupole Hamiltonian is:
$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ & \\ & (a_x, a_y, a_s) = \left(0, 0, -\frac{1}{3!} k_n \left(\frac{q_x^4}{4} - \frac{3}{2} q_x^2 q_y^2 + \frac{q_y^4}{4}\right) - \frac{1}{3!} k_s \left(q_x q_y^3 - q_x^3 q_y \right)\right) \\ & \varphi = 0 \\ & t = h = 0 \\ \end{align} $
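A quick consistency check on the potential, written as a standalone sketch (not part of elementary): with the normal part taken as -(k_n/3!)(q_x^4/4 - (3/2) q_x^2 q_y^2 + q_y^4/4), the potential equals -(1/4!) Re[(k_n + i k_s)(q_x + i q_y)^4], so it should satisfy Laplace's equation in (q_x, q_y). Since t = h = 0, the transverse kick is dp/ds = +da_s/dq. The helper names `a_s` and `laplacian` are illustrative only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def a_s(qx, qy, kn, ks):
    # Longitudinal vector potential of the octupole (normal + skew parts)
    normal = -kn / 6.0 * (qx**4 / 4.0 - 1.5 * qx**2 * qy**2 + qy**4 / 4.0)
    skew = -ks / 6.0 * (qx * qy**3 - qx**3 * qy)
    return normal + skew


def laplacian(qx, qy, kn, ks):
    # d2(a_s)/dqx2 + d2(a_s)/dqy2 via nested automatic differentiation
    d2x = jax.grad(jax.grad(a_s, 0), 0)(qx, qy, kn, ks)
    d2y = jax.grad(jax.grad(a_s, 1), 1)(qx, qy, kn, ks)
    return d2x + d2y


qx, qy, kn, ks = -0.01, 0.005, -100.0, 500.0
# Vanishes up to rounding: a_s is harmonic in the transverse plane
print(laplacian(qx, qy, kn, ks))
# Transverse kick per unit length: (dp_x/ds, dp_y/ds) = grad(a_s)
print(jax.grad(a_s, (0, 1))(qx, qy, kn, ks))
```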
The constructed element signature is:
def octupole(qsps: Array, length: Array, kn: Array, ks: Array) -> Array:
...
Note that both kn and ks must be passed on invocation.
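For illustration only, an element with the same call signature can be sketched independently of elementary. The Hamiltonian above (with t = h = 0 and a_x = a_y = 0) splits into a kinetic part, whose flow is an exact drift, and the kick -a_s, whose flow changes only the momenta. Composing them in a second-order drift-kick-drift (Strang) scheme gives a symplectic approximation. The factory name `octupole_sketch_factory` and its `steps` parameter are hypothetical; this is not how elementary's `octupole_factory` (with its `order` and `iterations` arguments) is implemented. The sketch also assumes `beta = sqrt(1 - 1/gamma^2)`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def octupole_sketch_factory(beta, gamma, steps=1000):
    """Hypothetical factory returning octupole(qsps, length, kn, ks).

    Strang splitting of the octupole Hamiltonian (t = h = 0); an illustrative
    stand-in, not the integrator used by elementary.
    """

    def drift(z, ds):
        # Exact flow of p_s/beta - sqrt(P_s^2 - p_x^2 - p_y^2 - 1/(beta*gamma)^2)
        qx, qy, qs, px, py, ps = [z[i] for i in range(6)]
        Ps = ps + 1.0 / beta
        root = jnp.sqrt(Ps**2 - px**2 - py**2 - 1.0 / (beta * gamma) ** 2)
        return jnp.stack([qx + ds * px / root,
                          qy + ds * py / root,
                          qs + ds * (1.0 / beta - Ps / root),
                          px, py, ps])

    def kick(z, ds, kn, ks):
        # Exact flow of -a_s(q_x, q_y): momenta change by ds * grad(a_s)
        qx, qy, qs, px, py, ps = [z[i] for i in range(6)]
        dax = -kn / 6.0 * (qx**3 - 3.0 * qx * qy**2) - ks / 6.0 * (qy**3 - 3.0 * qx**2 * qy)
        day = -kn / 6.0 * (qy**3 - 3.0 * qx**2 * qy) - ks / 6.0 * (3.0 * qx * qy**2 - qx**3)
        return jnp.stack([qx, qy, qs, px + ds * dax, py + ds * day, ps])

    def octupole(qsps, length, kn, ks):
        ds = length / steps

        def step(_, z):
            # One symmetric drift-kick-drift step of length ds
            z = drift(z, 0.5 * ds)
            z = kick(z, ds, kn, ks)
            return drift(z, 0.5 * ds)

        return jax.lax.fori_loop(0, steps, step, qsps)

    return octupole


gamma = 1000.0
beta = jnp.sqrt(1.0 - 1.0 / gamma**2)  # assumed equivalent of elementary.util.beta
element = jax.jit(octupole_sketch_factory(beta, gamma))
qsps = jnp.array([-0.01, 0.005, 0.001, 0.001, 0.001, -0.0001])
res = element(qsps, jnp.float64(0.2), jnp.float64(-100.0), jnp.float64(500.0))
print(res)
```

In float64 with 1000 steps, this sketch should land very close to the PTC reference printed in the comparison cell. Because `jax.lax.fori_loop` runs a fixed, concrete number of iterations here, `jacrev` can also be applied to it.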
[1]:
import jax
from jax import jit
from jax import jacrev
from elementary.util import ptc
from elementary.util import beta
from elementary.octupole import octupole_factory
jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Set data type (enable 64-bit floats)
jax.config.update("jax_enable_x64", True)
[3]:
# Set device
device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set initial condition
(q_x, q_y, q_s) = qs = jax.numpy.array([-0.01, 0.005, 0.001])
(p_x, p_y, p_s) = ps = jax.numpy.array([0.001, 0.001, -0.0001])
qsps = jax.numpy.hstack([qs, ps])
[5]:
# Define generic octupole element
gamma = 10**3
element = jit(octupole_factory(beta=beta(gamma), gamma=gamma, order=2**1, iterations=100))
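In the cell above, `octupole_factory` receives `beta`, `gamma`, `order` and `iterations` once, and the returned element is wrapped in `jit`. The sketch below uses a hypothetical `toy_factory` (a placeholder computation, not physics) to illustrate the general JAX behavior this pattern relies on: values captured at construction are baked into the traced function, while `length`, `kn` and `ks` are traced arguments. Calling the jitted element again with new values of the same shape and dtype therefore reuses the compiled code instead of retracing.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

traces = []  # records every time JAX traces the element


def toy_factory(scale):
    """Hypothetical factory: `scale` is fixed when the element is built."""

    def element(qsps, length, kn, ks):
        traces.append(1)  # Python side effect: runs only while tracing
        return qsps + scale * length * (kn + ks)  # placeholder computation

    return element


element = jax.jit(toy_factory(scale=1e-3))
qsps = jnp.zeros(6)
for kn in (-100.0, 0.0, 50.0):
    # New kn values, same shapes and dtypes: no retracing after the first call
    element(qsps, jnp.float64(0.2), jnp.float64(kn), jnp.float64(500.0))
print(len(traces))
```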
[6]:
# Compare with PTC
length = jax.numpy.float64(0.2)
kn = jax.numpy.float64(-100.0)
ks = jax.numpy.float64(+500.0)
print(res := element(qsps, length, kn, ks))
print(ref := ptc(qsps, 'octupole', {'l': float(length), 'k3': float(kn), 'k3s': float(ks)}, gamma=gamma))
print(jax.numpy.allclose(res, ref))
[-0.009797768862 0.005199205032 0.000999798525 0.001022114248 0.000992148146 -0.0001 ]
[-0.009797768862 0.005199205032 0.000999798525 0.001022114248 0.000992148146 -0.0001 ]
True
[7]:
# Differentiability: Jacobians with respect to the initial condition and to the length
print(jacrev(element)(qsps, length, kn, ks))
print()
print(jacrev(element, 1)(qsps, length, kn, ks))
print()
[[ 9.995697830497e-01 4.656632792258e-04 0.000000000000e+00 1.999915493370e-01 3.093502378623e-05 -2.022535230983e-04]
[ 4.656649771080e-04 1.000430351571e+00 0.000000000000e+00 3.093510854460e-05 2.000492660372e-01 -1.992851672412e-04]
[-2.659897270538e-08 -9.006312695687e-07 1.000000000000e+00 -2.022530833439e-04 -1.992853937288e-04 6.030721251410e-07]
[-4.326806733311e-03 4.608973281245e-03 0.000000000000e+00 9.995646349961e-01 4.562292899969e-04 -1.639276538502e-08]
[ 4.609037834460e-03 4.329471623185e-03 0.000000000000e+00 4.562327538309e-04 1.000435498762e+00 -8.939569416383e-07]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00]]
[ 1.022217507094e-03 9.922483777699e-04 -1.014842745316e-06 1.106308301324e-04 -3.477926503128e-05 0.000000000000e+00]
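Two sanity checks on the printed derivatives can be sketched from the numbers above, copied verbatim into arrays. First, the map comes from a Hamiltonian, so its Jacobian M with respect to (q_x, q_y, q_s, p_x, p_y, p_s) should be symplectic, M^T J M = J. Second, since t = h = 0 and nothing in H depends on s, the flow is autonomous, and the derivative of the final state with respect to the length equals the Hamiltonian vector field (dH/dp, -dH/dq) evaluated at the final state. The `hamiltonian` and `vector_field` helpers are written here for illustration and assume `beta = sqrt(1 - 1/gamma^2)`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Printed outputs from the cells above
final = jnp.array([-0.009797768862, 0.005199205032, 0.000999798525,
                   0.001022114248, 0.000992148146, -0.0001])
M = jnp.array([
    [9.995697830497e-01, 4.656632792258e-04, 0.0, 1.999915493370e-01, 3.093502378623e-05, -2.022535230983e-04],
    [4.656649771080e-04, 1.000430351571e+00, 0.0, 3.093510854460e-05, 2.000492660372e-01, -1.992851672412e-04],
    [-2.659897270538e-08, -9.006312695687e-07, 1.0, -2.022530833439e-04, -1.992853937288e-04, 6.030721251410e-07],
    [-4.326806733311e-03, 4.608973281245e-03, 0.0, 9.995646349961e-01, 4.562292899969e-04, -1.639276538502e-08],
    [4.609037834460e-03, 4.329471623185e-03, 0.0, 4.562327538309e-04, 1.000435498762e+00, -8.939569416383e-07],
    [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
])
dz_dlength = jnp.array([1.022217507094e-03, 9.922483777699e-04, -1.014842745316e-06,
                        1.106308301324e-04, -3.477926503128e-05, 0.0])

# Check 1: symplecticity, M^T J M = J for the ordering (q_x, q_y, q_s, p_x, p_y, p_s)
I3, O3 = jnp.eye(3), jnp.zeros((3, 3))
J = jnp.block([[O3, I3], [-I3, O3]])
print(jnp.max(jnp.abs(M.T @ J @ M - J)))


def hamiltonian(z, kn, ks, beta, gamma):
    # Octupole Hamiltonian with t = h = 0, phi = 0 and a_x = a_y = 0
    qx, qy, qs, px, py, ps = [z[i] for i in range(6)]
    Ps = ps + 1.0 / beta
    root = jnp.sqrt(Ps**2 - px**2 - py**2 - 1.0 / (beta * gamma) ** 2)
    a_s = (-kn / 6.0 * (qx**4 / 4.0 - 1.5 * qx**2 * qy**2 + qy**4 / 4.0)
           - ks / 6.0 * (qx * qy**3 - qx**3 * qy))
    return ps / beta - root - a_s


def vector_field(z, kn, ks, beta, gamma):
    # Hamilton's equations: dq/ds = +dH/dp, dp/ds = -dH/dq
    dH_dq, dH_dp = jnp.split(jax.grad(hamiltonian)(z, kn, ks, beta, gamma), 2)
    return jnp.concatenate([dH_dp, -dH_dq])


# Check 2: d(final state)/d(length) equals the vector field at the final state
gamma = 1000.0
beta = jnp.sqrt(1.0 - 1.0 / gamma**2)  # assumed equivalent of elementary.util.beta
f = vector_field(final, -100.0, 500.0, beta, gamma)
print(jnp.max(jnp.abs(f - dz_dlength)))
```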