Example-09: Dipole element factory

This example illustrates the dipole element factory.

The dipole Hamiltonian is:

$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ \\ & (a_x, a_y, a_s) = (0, 0, -\frac{1}{1 + q_x/\rho}\left(\frac{q_x}{ \rho} + \frac{q_x^2}{2 \rho^2} \right))\\ & \varphi = 0 \\ & t = 0 \\ & h = \frac{1}{\rho} = \frac{\alpha}{l} \end{align} $

The constructed element signature is:

def dipole(qsps:Array, length:Array, angle:Array) -> Array:
    ...

Note: no fringe effects are included.

By default, the exact solution is used to transform initial conditions; passing `exact=False` to the factory uses the Hamiltonian instead.

[1]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.dipole import dipole_factory

# Wide lines and 12 digits of floating-point precision so full state vectors
# and Jacobian rows print on a single line
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Switch JAX to double precision (float64); run this before creating any
# arrays so that they are allocated as float64 rather than float32

jax.config.update('jax_enable_x64', True)
[3]:
# Run all computations on the first CPU device

[device, *_] = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Initial condition: coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s),
# joined into the single 6-vector qsps that is passed to the element
qs = jax.numpy.array([-0.01, 0.005, 0.002])
ps = jax.numpy.array([0.001, 0.001, -0.0005])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
qsps = jax.numpy.concatenate([qs, ps])
[5]:
# Generic dipole element using the default (exact) solution, compiled with jit
# NOTE(review): order/iterations are passed straight to dipole_factory; their
# exact role (presumably integrator settings) is defined in elementary.dipole

gamma = 1000
element = jit(dipole_factory(gamma=gamma, beta=beta(gamma), iterations=100, order=2))
[6]:
# Check against the PTC 'sbend' reference with entrance and exit fringe
# fields disabled, since the element itself includes no fringe effects

length = jax.numpy.float64(2.0)
angle = jax.numpy.float64(0.05)

res = element(qsps, length, angle)
print(res)
ref = ptc(qsps, 'sbend', dict(l=float(length), angle=float(angle), kill_ent_fringe='true', kill_exi_fringe='true'), gamma=gamma)
print(ref)
print(jax.numpy.allclose(res, ref))
[-0.008012832402  0.007000552059  0.00244821723   0.000986205451  0.001          -0.0005        ]
[-0.008012832402  0.007000552059  0.00244821723   0.000986205451  0.001          -0.0005        ]
True
[7]:
# Generic dipole element built from the Hamiltonian (exact=False) rather than
# the default exact solution, compiled with jit
# NOTE(review): order/iterations presumably control the integrator used in
# this mode; confirm in elementary.dipole

gamma = 1000
element = jit(dipole_factory(exact=False, gamma=gamma, beta=beta(gamma), iterations=100, order=2))
[8]:
# Check the Hamiltonian-based element against the same PTC 'sbend' reference
# (entrance and exit fringe fields disabled)

length = jax.numpy.float64(2.0)
angle = jax.numpy.float64(0.05)

res = element(qsps, length, angle)
print(res)
ref = ptc(qsps, 'sbend', dict(l=float(length), angle=float(angle), kill_ent_fringe='true', kill_exi_fringe='true'), gamma=gamma)
print(ref)
print(jax.numpy.allclose(res, ref))
[-0.008012832402  0.007000552059  0.002448217234  0.000986205451  0.001          -0.0005        ]
[-0.008012832402  0.007000552059  0.00244821723   0.000986205451  0.001          -0.0005        ]
True
[9]:
# Differentiability: reverse-mode Jacobian of the output with respect to the
# initial condition qsps (argument 0), then with respect to length (argument 1)

print(jacrev(element, argnums=0)(qsps, length, angle))
print()

print(jacrev(element, argnums=1)(qsps, length, angle))
print()
[[ 9.987995748301e-01  0.000000000000e+00  0.000000000000e+00  1.999720113524e+00 -4.804052661206e-05  4.801653036904e-02]
 [ 5.000422073929e-05  1.000000000000e+00  0.000000000000e+00  5.201635774981e-05  2.000554060243e+00 -2.000724168894e-03]
 [-4.997924363105e-02  0.000000000000e+00  1.000000000000e+00 -5.199037557913e-02 -2.000724168894e-03 -8.272513473157e-04]
 [-1.249479231765e-03  0.000000000000e+00  0.000000000000e+00  9.987002561736e-01 -5.000422142308e-05  4.997924431449e-02]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00  0.000000000000e+00]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00]]

[ 9.875816733888e-04  1.000526050363e-03 -2.578760124653e-05 -6.247396158822e-06  0.000000000000e+00  0.000000000000e+00]