Example-11: Alignment errors (straight layout)

This example illustrates alignment errors for a straight layout.

Alignment errors (translations and rotations) are defined with respect to the element entrance frame.

The full transformation sequence is:

# forward translations and rotations
qsps = tx(qsps, +dx)
qsps = ty(qsps, +dy)
qsps = tz(qsps, +dz, beta, constant)
qsps = rx(qsps, +wx, beta, constant)
qsps = ry(qsps, +wy, beta, constant)
qsps = rz(qsps, +wz)

# element body transformation
qsps = element(qsps, length, ...)

# finite length correction
qsps = tz(qsps, -length, beta, constant)

# inverse translations and rotations
qsps = rz(qsps, -wz)
qsps = ry(qsps, -wy, beta, constant)
qsps = rx(qsps, -wx, beta, constant)
qsps = tz(qsps, -dz, beta, constant)
qsps = ty(qsps, -dy)
qsps = tx(qsps, -dx)

# finite length correction
qsps = tz(qsps, +length, beta, constant)
[1]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import beta
from elementary.quadrupole import quadrupole_factory
from elementary.alignment import alignment_factory

jax.numpy.set_printoptions(precision=12, linewidth=256)  # wide lines, 12-digit floats for easy comparison
[2]:
# Set data type: switch JAX to 64-bit (double precision) arrays.
# This must happen before any JAX arrays are created; otherwise JAX defaults to 32-bit.

jax.config.update('jax_enable_x64', True)
[3]:
# Set device: take the first CPU device and make it JAX's default device

device, *_ = jax.devices(backend='cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set initial condition: the q-part (q_x, q_y, q_s) and p-part (p_x, p_y, p_s)
# are concatenated into the single 6-component state vector `qsps`.

qs = jax.numpy.array([-0.01, 0.005, 0.001])
q_x, q_y, q_s = qs

ps = jax.numpy.array([0.001, 0.001, -0.0001])
p_x, p_y, p_s = ps

qsps = jax.numpy.hstack([qs, ps])
[5]:
# Define generic quadrupole element with alignment errors:
#   body                    -- tracks through the quadrupole itself
#   xyz_entrance / xyz_exit -- alignment transformations applied before / after the body
# NOTE(review): the meaning of order / iterations / flag is defined in `elementary` -- confirm there.

gamma = 10**3  # relativistic gamma, shared by both factories and the PTC reference below

body = quadrupole_factory(
    beta=beta(gamma),
    gamma=gamma,
    order=2**1,
    iterations=200,
)
xyz_entrance, xyz_exit = alignment_factory(
    beta=beta(gamma),
    gamma=gamma,
    flag=False,
)

@jit
def element(x, length, kn, ks, dx, dy, dz, wx, wy, wz):
    """Track state ``x`` through a quadrupole with alignment errors.

    The misaligned element is the composition of three maps: ``xyz_entrance``
    (alignment transformation at the entrance), ``body`` (the quadrupole
    itself), and ``xyz_exit`` (the inverse transformation, which also receives
    ``length`` for the finite-length correction).

    Parameters
    ----------
    x : state vector (the notebook passes the 6-component ``qsps``)
    length, kn, ks : quadrupole body parameters, forwarded to ``body``
        (compared against PTC's ``l``, ``k1``, ``k1s`` below)
    dx, dy, dz : alignment translations
    wx, wy, wz : alignment rotations

    Returns
    -------
    The state after the exit alignment transformation.
    """
    errors = (dx, dy, dz, wx, wy, wz)
    misaligned = xyz_entrance(x, *errors)
    tracked = body(misaligned, length, kn, ks)
    return xyz_exit(tracked, *errors, length)
[6]:
# Set alignment errors: translations (dx, dy, dz) followed by rotations (wx, wy, wz)

dx, dy, dz, wx, wy, wz = jax.numpy.array([0.05, -0.02, 0.05, 0.005, -0.005, 0.1])
[7]:
# Compare with PTC: track the same initial condition through the misaligned
# quadrupole with `element` and with the `ptc` reference, then check agreement.

length = jax.numpy.float64(1.0)
kn = jax.numpy.float64(-2.0)
ks = jax.numpy.float64(+1.5)

print(res := element(qsps, length, kn, ks, dx, dy, dz, wx, wy, wz))
print(ref := ptc(qsps, 'quadrupole', {'l': float(length), 'k1': float(kn), 'k1s': float(ks)}, gamma=gamma, tx=float(dx), ty=float(dy), tz=float(dz), rx=float(wx), ry=float(wy), rz=float(wz)))

# The default allclose tolerances (rtol=1e-05, atol=1e-08) resolve only ~5 significant
# digits; for components of size ~1e-4, atol alone admits ~1e-4 relative error, which
# would hide float32-level discrepancies. With x64 enabled the results agree to ~1e-14,
# so check at (near) double precision instead.
print(jax.numpy.allclose(res, ref, rtol=1.0E-12, atol=1.0E-12))
[-4.532564066213e-02 -5.683873885073e-02 -2.649160041102e-03 -1.148938967200e-01 -1.263031959206e-01 -1.000000000000e-04]
[-4.532564066213e-02 -5.683873885073e-02 -2.649160041114e-03 -1.148938967200e-01 -1.263031959206e-01 -1.000000000000e-04]
True