Example-08: Multipole element factory
This example illustrates the multipole element factory, which builds a combined quadrupole, sextupole and octupole element.
The multipole Hamiltonian is:
$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ & \\ & (a_x, a_y, a_s) = (0, 0, a_{s, q} + a_{s, s} + a_{s, o})\\ & a_{s, q} = -\frac{1}{2} k_{n, q} \left(q_x^2 - q_y^2 \right) + k_{s, q} q_x q_y \\ & a_{s, s} = -\frac{1}{2!} k_{n, s} \left(\frac{q_x^3}{3} - q_x q_y^2 \right) - \frac{1}{2!} k_{s, s} \left(\frac{q_y^3}{3} - q_x^2 q_y\right) \\ & a_{s, o} = -\frac{1}{3!} k_{n, o} \left(\frac{q_x^4}{4} - \frac{3}{2} q_x^2 q_y^2 + \frac{q_y^4}{4}\right) - \frac{1}{3!} k_{s, o} \left(q_x q_y^3 - q_x^3 q_y \right) \\ & \\ & \varphi = 0 \\ & t = h = 0 \\ \end{align} $
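A vacuum multipole potential must satisfy Laplace's equation, which fixes the normal quadrupole term to be proportional to q_x^2 - q_y^2 and the octupole cross term to -(3/2) q_x^2 q_y^2. All three pieces of a_s can also be written compactly as a_s = -Re sum_n (k_n + i k_s)(q_x + i q_y)^(n+1)/(n+1)! for n = 1, 2, 3 (quadrupole, sextupole, octupole). The sketch below checks, with JAX, that the term-by-term form and this compact form agree and that their transverse Laplacian vanishes. The function names are illustrative only and are not part of elementary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def a_s_expanded(qx, qy, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s):
    """a_{s,q} + a_{s,s} + a_{s,o} written term by term in harmonic form."""
    a_q = -0.5 * kq_n * (qx**2 - qy**2) + kq_s * qx * qy
    a_sx = (-ks_n / 2.0 * (qx**3 / 3.0 - qx * qy**2)
            - ks_s / 2.0 * (qy**3 / 3.0 - qx**2 * qy))
    a_o = (-ko_n / 6.0 * (qx**4 / 4.0 - 1.5 * qx**2 * qy**2 + qy**4 / 4.0)
           - ko_s / 6.0 * (qx * qy**3 - qx**3 * qy))
    return a_q + a_sx + a_o


def a_s_complex(qx, qy, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s):
    """Compact form: -Re sum_n (k_n + i k_s) (qx + i qy)^(n+1) / (n+1)!, n = 1, 2, 3."""
    z = qx + 1j * qy
    return -jnp.real((kq_n + 1j * kq_s) * z**2 / 2.0
                     + (ks_n + 1j * ks_s) * z**3 / 6.0
                     + (ko_n + 1j * ko_s) * z**4 / 24.0)


def laplacian(f, q, *k):
    """Trace of the (q_x, q_y) Hessian of f; zero for a potential obeying Laplace's equation."""
    hess = jax.hessian(lambda v: f(v[0], v[1], *k))(q)
    return jnp.trace(hess)


# Transverse coordinates and strengths taken from the notebook's initial condition and PTC comparison.
q = jnp.array([-0.01, 0.005])
k = (-2.0, 1.5, -50.0, 75.0, -100.0, 500.0)
print(a_s_expanded(q[0], q[1], *k), a_s_complex(q[0], q[1], *k))
print(laplacian(a_s_expanded, q, *k))
```

Replacing q_x^2 - q_y^2 by q_x^2 + q_y^2, or -(3/2) by -3 in the octupole term, makes the Laplacian nonzero, which is a quick way to catch such typos.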
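The element returned by the factory has the signature below. Here qsps is the phase-space vector (q_x, q_y, q_s, p_x, p_y, p_s), length is the element length, and kq_n/kq_s, ks_n/ks_s and ko_n/ko_s are the normal/skew quadrupole, sextupole and octupole strengths (k_{n,q}, k_{s,q}, k_{n,s}, k_{s,s}, k_{n,o}, k_{s,o} in the Hamiltonian above):
def multipole(qsps: Array, length: Array, kq_n: Array, kq_s: Array, ks_n: Array, ks_s: Array, ko_n: Array, ko_s: Array) -> Array: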
...
[1]:
import jax
from jax import jit
from jax import jacrev
from elementary.util import ptc
from elementary.util import beta
from elementary.multipole import multipole_factory
jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Enable double precision (float64)
jax.config.update("jax_enable_x64", True)
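At gamma = 10^3 the beam is ultra-relativistic: beta = sqrt(1 - 1/gamma^2) differs from 1 by about 5e-7, fewer than ten float32 ulps. One plausible reason for enabling float64 is that differences such as 1/beta^2 - 1, which equals the 1/(beta^2 gamma^2) term in the Hamiltonian, lose most of their digits in single precision. The sketch below evaluates that term through beta in both precisions and compares it with the closed form 1/(gamma^2 - 1). The helper name is hypothetical, and the sketch does not assume how elementary.util.beta or the factory actually evaluate these quantities.

```python
import jax
import jax.numpy as jnp

# float64 must be enabled before arrays are created; float32 stays available.
jax.config.update("jax_enable_x64", True)


def inv_beta2_gamma2_via_beta(gamma, dtype):
    """1/(beta^2 gamma^2) evaluated as 1/beta^2 - 1, with beta = sqrt(1 - 1/gamma^2), in `dtype`."""
    g = jnp.asarray(gamma, dtype=dtype)
    beta = jnp.sqrt(1.0 - 1.0 / g**2)
    return 1.0 / beta**2 - 1.0  # cancellation: 1/beta^2 is within ~1e-6 of 1


gamma = 10**3
exact = 1.0 / (gamma**2 - 1)  # closed form, since beta^2 gamma^2 = gamma^2 - 1
rel_err = {}
for dtype in ("float32", "float64"):
    value = float(inv_beta2_gamma2_via_beta(gamma, dtype))
    rel_err[dtype] = abs(value / exact - 1.0)
    print(f"{dtype}: value = {value:.9e}, relative error = {rel_err[dtype]:.1e}")
```

In float32 the result is necessarily a multiple of the float32 spacing just above 1 (about 1.2e-7), so it cannot land closer than a few percent to the exact 1.000001e-6, whereas float64 keeps roughly nine significant digits.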
[3]:
# Set device
device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set initial condition
(q_x, q_y, q_s) = qs = jax.numpy.array([-0.01, 0.005, 0.001])
(p_x, p_y, p_s) = ps = jax.numpy.array([0.001, 0.001, -0.0001])
qsps = jax.numpy.hstack([qs, ps])
[5]:
# Define generic multipole element
gamma = 10**3
element = jit(multipole_factory(beta=beta(gamma), gamma=gamma, order=2**1, iterations=100))
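For readers who want to see what a factory of this shape can look like, the following is a hypothetical toy_multipole_factory. It is not elementary's implementation: what the order and iterations arguments of multipole_factory control is not shown in this notebook. The toy closes over beta and gamma and returns an element with the same argument order, integrating the Hamiltonian above (phi = t = h = 0) with a plain second-order drift-kick-drift split over a fixed number of slices, a swapped-in technique chosen for brevity. The drift is the exact flow of the field-free part, and the kick uses jax.grad of a_s in the compact complex form -Re sum_n (k_n + i k_s)(q_x + i q_y)^(n+1)/(n+1)!. It is not expected to match the PTC output to all printed digits.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_multipole_factory(beta, gamma, slices=50):
    """Hypothetical stand-in for a multipole factory (not elementary's integrator).

    Returns element(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s) that integrates
    the multipole Hamiltonian (phi = t = h = 0) with a second-order
    drift-kick-drift split over `slices` equal steps.
    """
    inv_bg2 = 1.0 / (beta**2 * gamma**2)

    def a_s(qx, qy, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s):
        # a_s = -Re sum_n (k_n + i k_s) (qx + i qy)^(n+1) / (n+1)!, n = 1, 2, 3
        z = qx + 1j * qy
        return -jnp.real((kq_n + 1j * kq_s) * z**2 / 2.0
                         + (ks_n + 1j * ks_s) * z**3 / 6.0
                         + (ko_n + 1j * ko_s) * z**4 / 24.0)

    # Kick: dp/ds = -dH/dq = +grad(a_s); a_s does not depend on q_s, so p_s is untouched.
    grad_a_s = jax.grad(a_s, argnums=(0, 1))

    def drift(state, ds):
        # Exact flow of the field-free part: momenta fixed, coordinates advance linearly.
        qx, qy, qs, px, py, ps = state
        Ps = ps + 1.0 / beta
        root = jnp.sqrt(Ps**2 - px**2 - py**2 - inv_bg2)
        return jnp.stack([qx + ds * px / root,
                          qy + ds * py / root,
                          qs + ds * (1.0 / beta - Ps / root),
                          px, py, ps])

    def kick(state, ds, k):
        qx, qy, qs, px, py, ps = state
        fx, fy = grad_a_s(qx, qy, *k)
        return jnp.stack([qx, qy, qs, px + ds * fx, py + ds * fy, ps])

    def element(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s):
        k = (kq_n, kq_s, ks_n, ks_s, ko_n, ko_s)
        ds = length / slices

        def step(state, _):
            # Symmetric (Strang) composition: half drift, full kick, half drift.
            state = drift(state, 0.5 * ds)
            state = kick(state, ds, k)
            return drift(state, 0.5 * ds), None

        final, _ = jax.lax.scan(step, qsps, None, length=slices)
        return final

    return element


# Usage with the notebook's argument order and values.
gamma = 1.0e3
beta = (1.0 - 1.0 / gamma**2) ** 0.5
element = jax.jit(toy_multipole_factory(beta, gamma))
qsps = jnp.array([-0.01, 0.005, 0.001, 0.001, 0.001, -0.0001])
strengths = (-2.0, 1.5, -50.0, 75.0, -100.0, 500.0)
print(element(qsps, 0.25, *strengths))
```

Because each sub-step is the exact flow of part of the Hamiltonian, the composed map is symplectic and carries p_s through unchanged. The same signature also vectorizes directly: jax.vmap(element, in_axes=(0,) + (None,) * 7) tracks a stack of initial conditions with shared strengths.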
[6]:
# Compare with PTC
length = jax.numpy.float64(0.25)
kq_n = jax.numpy.float64(-2.0)
kq_s = jax.numpy.float64(+1.5)
ks_n = jax.numpy.float64(-50.0)
ks_s = jax.numpy.float64(+75.0)
ko_n = jax.numpy.float64(-100.0)
ko_s = jax.numpy.float64(+500.0)
print(res := element(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s))
# PTC takes integrated strengths (k * length) for the normal (knl) and skew (ksl) components
knl = f'{{0.0,{float(kq_n*length)}, {float(ks_n*length)}, {float(ko_n*length)}}}'
ksl = f'{{0.0,{float(kq_s*length)}, {float(ks_s*length)}, {float(ko_s*length)}}}'
print(ref := ptc(qsps, 'quadrupole', {'l': float(length), 'knl': knl, 'ksl': ksl}, gamma=gamma))
print(jax.numpy.allclose(res, ref))
[-0.010195569406 0.004634631697 0.00099927232 -0.002587384238 -0.003898816771 -0.0001 ]
[-0.010195569406 0.004634631697 0.00099927232 -0.002587384238 -0.003898816771 -0.0001 ]
True
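The True printed above comes from jnp.allclose with its default tolerances, rtol=1e-5 and atol=1e-8, which accept |res - ref| <= atol + rtol * |ref| elementwise. For phase-space coordinates of order 1e-3 this admits relative discrepancies around 1e-5, so an explicit tighter atol, or the maximum absolute difference itself, says more about the agreement with PTC. A minimal illustration with made-up values (not PTC output):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative values only: a coordinate of order 1e-3 and a copy perturbed
# by a relative 5e-6, i.e. an absolute difference of 5e-9.
a = jnp.array([1.0e-3])
b = a * (1.0 + 5.0e-6)

# Default tolerances: passes when |a - b| <= 1e-8 + 1e-5 * |b| elementwise.
loose = bool(jnp.allclose(a, b))
# Purely absolute criterion with a much tighter bound.
strict = bool(jnp.allclose(a, b, rtol=0.0, atol=1.0e-12))
max_abs_diff = float(jnp.max(jnp.abs(a - b)))
print(loose, strict, max_abs_diff)
```

In the notebook, printing jax.numpy.max(jax.numpy.abs(res - ref)) next to the allclose call would report the actual discrepancy instead of a pass/fail flag.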
[7]:
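# Differentiability: Jacobian of the final state with respect to the initial state (argument 0),
# then derivative of the final state with respect to the element length (argument 1)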
print(jacrev(element)(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s))
print()
print(jacrev(element, 1)(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s))
print()
[[ 1.058513409511e+00 1.641931362563e-02 0.000000000000e+00 2.548728245266e-01 1.371326922699e-03 1.953577918270e-04]
[ 1.642278685252e-02 9.426914513883e-01 0.000000000000e+00 1.371531111411e-03 2.452402248722e-01 3.641389680595e-04]
[ 1.183449134442e-04 -1.074855507550e-04 1.000000000000e+00 2.074102528591e-04 3.548724731577e-04 1.701670012363e-06]
[ 4.713919361348e-01 1.315473343451e-01 0.000000000000e+00 1.058139548857e+00 1.647371074410e-02 2.099930109674e-05]
[ 1.316213051530e-01 -4.522013363610e-01 0.000000000000e+00 1.647803990033e-02 9.430571887037e-01 -3.375926930070e-05]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00]]
[-2.587671336735e-03 -3.899249388134e-03 -1.095013990866e-05 -1.481596650450e-02 -1.916336194105e-02 0.000000000000e+00]
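The printed Jacobian already shows one structural property: its last row is (0, 0, 0, 0, 0, 1), so p_s is conserved, as expected because none of the a_s terms depend on q_s (and phi = 0). A further check one may apply to a Hamiltonian map is symplecticity, J^T S J = S, with S the canonical structure matrix for the (q_x, q_y, q_s, p_x, p_y, p_s) ordering. The helper below is a sketch; toy_map is a hypothetical stand-in (a thin normal sextupole kick followed by a paraxial drift) used only to exercise it.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Canonical structure matrix for the ordering (q_x, q_y, q_s, p_x, p_y, p_s).
S = jnp.block([[jnp.zeros((3, 3)), jnp.eye(3)],
               [-jnp.eye(3), jnp.zeros((3, 3))]])


def symplectic_defect(jacobian):
    """Largest entry of |J^T S J - S|; zero up to rounding for a symplectic map."""
    return float(jnp.max(jnp.abs(jacobian.T @ S @ jacobian - S)))


def toy_map(state, k2=-50.0, length=0.25):
    """Hypothetical test map: thin normal sextupole kick, then a paraxial drift."""
    qx, qy, qs, px, py, ps = state
    # Kick from a_s = -k2/6 (qx^3 - 3 qx qy^2), integrated over `length`.
    px = px - 0.5 * k2 * length * (qx**2 - qy**2)
    py = py + k2 * length * qx * qy
    # Paraxial drift: q advances by length * p; q_s and p_s are left unchanged.
    return jnp.stack([qx + length * px, qy + length * py, qs, px, py, ps])


state = jnp.array([-0.01, 0.005, 0.001, 0.001, 0.001, -0.0001])
print(symplectic_defect(jax.jacrev(toy_map)(state)))
# A map that rescales p_x alone is not symplectic.
print(symplectic_defect(jnp.diag(jnp.array([1.0, 1.0, 1.0, 1.1, 1.0, 1.0]))))
```

With the notebook's element one would evaluate symplectic_defect(jacrev(element)(qsps, length, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s)); how small the result is depends on the integrator behind multipole_factory and is not claimed here.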