Example-11: Alignment errors (straight layout)
This example illustrates alignment errors for a straight layout.
Alignment errors (translations and rotations) are defined with respect to the element entrance frame.
The full transformation sequence is:
# forward translations and rotations
qsps = tx(qsps, +dx)
qsps = ty(qsps, +dy)
qsps = tz(qsps, +dz, beta, constant)
qsps = rx(qsps, +wx, beta, constant)
qsps = ry(qsps, +wy, beta, constant)
qsps = rz(qsps, +wz)
# element body transformation
qsps = element(qsps, length, ...)
# finite length correction
qsps = tz(qsps, -length, beta=beta, constant=constant)
# inverse translations and rotations
qsps = rz(qsps, -wz)
qsps = ry(qsps, -wy, beta, constant)
qsps = rx(qsps, -wx, beta, constant)
qsps = tz(qsps, -dz, beta, constant)
qsps = ty(qsps, -dy)
qsps = tx(qsps, -dx)
# finite length correction
qsps = tz(qsps, +length, beta, constant)
[1]:
import jax
from jax import jit
from jax import jacrev
from elementary.util import ptc
from elementary.util import beta
from elementary.quadrupole import quadrupole_factory
from elementary.alignment import alignment_factory
# Print arrays with 12 significant digits on wide lines so a full state vector fits on one row
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Switch JAX to double precision (it uses 32-bit floats unless x64 is enabled)
jax.config.update('jax_enable_x64', True)
[3]:
# Make the first CPU device the default placement for JAX arrays and computations
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[4]:
# Initial condition: coordinates qs = (q_x, q_y, q_s) and momenta ps = (p_x, p_y, p_s),
# concatenated into the single state vector qsps used as the element input
qs = jax.numpy.array([-0.01, 0.005, 0.001])
ps = jax.numpy.array([0.001, 0.001, -0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps
qsps = jax.numpy.concatenate([qs, ps])
[5]:
# Generic quadrupole element with alignment errors: a body map plus entrance/exit
# alignment maps, all built for the same gamma (presumably the Lorentz factor,
# with beta(gamma) the matching relativistic beta -- see elementary.util)
gamma = 1000
body = quadrupole_factory(
    beta=beta(gamma),
    gamma=gamma,
    order=2,  # NOTE(review): meaning of `order`/`iterations` defined in quadrupole_factory
    iterations=200,
)
xyz_entrance, xyz_exit = alignment_factory(
    beta=beta(gamma),
    gamma=gamma,
    flag=False,
)
@jit
def element(x, length, kn, ks, dx, dy, dz, wx, wy, wz):
    """Track a state vector through a quadrupole with alignment errors.

    The entrance alignment transformation is applied first, then the quadrupole
    body map, then the exit alignment transformation (which also needs the
    element length).

    Args:
        x: state vector, e.g. ``qsps`` = [q_x, q_y, q_s, p_x, p_y, p_s].
        length: element length, passed to ``body`` and ``xyz_exit``.
        kn, ks: quadrupole strengths forwarded to ``body`` (compared against
            PTC's ``k1`` and ``k1s`` in this example).
        dx, dy, dz: alignment translations.
        wx, wy, wz: alignment rotations.

    Returns:
        The transformed state vector.
    """
    errors = (dx, dy, dz, wx, wy, wz)
    shifted = xyz_entrance(x, *errors)
    tracked = body(shifted, length, kn, ks)
    return xyz_exit(tracked, *errors, length)
[6]:
# Alignment errors relative to the element entrance frame:
# first row -> translations (dx, dy, dz), second row -> rotations (wx, wy, wz)
(dx, dy, dz), (wx, wy, wz) = jax.numpy.array([
    [0.05, -0.02, 0.05],
    [0.005, -0.005, 0.1],
])
[7]:
# Cross-check the JAX element against PTC for the same quadrupole and misalignment;
# kn/ks map to PTC's k1/k1s, and dx..wz map to PTC's tx, ty, tz, rx, ry, rz
length, kn, ks = map(jax.numpy.float64, (1.0, -2.0, +1.5))
res = element(qsps, length, kn, ks, dx, dy, dz, wx, wy, wz)
print(res)
ref = ptc(
    qsps,
    'quadrupole',
    {'l': float(length), 'k1': float(kn), 'k1s': float(ks)},
    gamma=gamma,
    tx=float(dx), ty=float(dy), tz=float(dz),
    rx=float(wx), ry=float(wy), rz=float(wz),
)
print(ref)
print(jax.numpy.allclose(res, ref))
[-4.532564066213e-02 -5.683873885073e-02 -2.649160041102e-03 -1.148938967200e-01 -1.263031959206e-01 -1.000000000000e-04]
[-4.532564066213e-02 -5.683873885073e-02 -2.649160041114e-03 -1.148938967200e-01 -1.263031959206e-01 -1.000000000000e-04]
True