Example-12: Alignment errors (curved layout)
This example illustrates alignment errors for a curved planar layout.
Alignment errors (translations and rotations) are defined with respect to the element entrance frame.
The full transformation sequence:
# forward translations and rotations (errors are defined in the element entrance frame)
qsps = tx(qsps, +dx)
qsps = ty(qsps, +dy)
qsps = tz(qsps, +dz, beta, constant)
qsps = rx(qsps, +wx, beta, constant)
qsps = ry(qsps, +wy, beta, constant)
qsps = rz(qsps, +wz)
# element body transformation
qsps = element(qsps, length, angle, ...)
# finite length correction: move from the exit frame back to the entrance frame
# (rotate by angle/2, step back along the chord 2*(length/angle)*sin(angle/2)
# of the arc with radius length/angle, rotate by angle/2 again)
qsps = ry(qsps, +angle/2, beta, constant)
qsps = tz(qsps, -2.0*length/angle*jax.numpy.sin(angle/2.0), beta, constant)
qsps = ry(qsps, +angle/2, beta, constant)
# inverse translations and rotations (reverse order, opposite signs)
qsps = rz(qsps, -wz)
qsps = ry(qsps, -wy, beta, constant)
qsps = rx(qsps, -wx, beta, constant)
qsps = tz(qsps, -dz, beta, constant)
qsps = ty(qsps, -dy)
qsps = tx(qsps, -dx)
# finite length correction: undo the correction above, returning to the exit frame
qsps = ry(qsps, -angle/2, beta, constant)
qsps = tz(qsps, +2.0*length/angle*jax.numpy.sin(angle/2.0), beta, constant)
qsps = ry(qsps, -angle/2, beta, constant)
[1]:
import jax
from jax import jit
from jax import jacrev
from elementary.util import ptc
from elementary.util import beta
from elementary.dipole import dipole_factory
from elementary.alignment import alignment_factory
# Wide lines and 12 digits after the decimal point, so the JAX and PTC results below can be compared by eye
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Set data type
# Enable 64-bit floats in JAX; the jax.numpy.float64 values defined later are only
# truly double precision when this flag is on.
# NOTE(review): JAX recommends setting this at startup, before any arrays are created;
# the elementary.* imports in the first cell run before this line — confirm they do not
# build module-level arrays.
jax.config.update("jax_enable_x64", True)
[3]:
# Set device
# Take the first CPU device and make it the default placement for JAX computations
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[4]:
# Set initial condition
# Coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s), concatenated into one 6-vector
qs = jax.numpy.array([-0.01, 0.005, 0.001])
q_x, q_y, q_s = qs
ps = jax.numpy.array([0.001, 0.001, -0.0001])
p_x, p_y, p_s = ps
qsps = jax.numpy.concatenate([qs, ps])
[5]:
# Define generic dipole element with alignment errors
# Relativistic gamma; beta(gamma) is derived from it for both factories below
gamma = 10**3
# Body length and total bending angle (the arc radius is length/angle)
length, angle = jax.numpy.float64(1.0), jax.numpy.float64(0.05)
# Normal (_n) and skew (_s) strengths: quadrupole (kq), sextupole (ks), octupole (ko);
# the PTC comparison passes k*length as knl/ksl orders 1, 2 and 3
kq_n, kq_s = jax.numpy.float64(-2.0), jax.numpy.float64(+1.5)
ks_n, ks_s = jax.numpy.float64(-50.0), jax.numpy.float64(+75.0)
ko_n, ko_s = jax.numpy.float64(-100.0), jax.numpy.float64(+500.0)
# Dipole body map (option semantics are defined in elementary.dipole)
# NOTE(review): iterations is passed as a float (1E3); confirm dipole_factory accepts that
body = dipole_factory(
    exact=False,
    multipole=True,
    beta=beta(gamma),
    gamma=gamma,
    order=2**1,
    iterations=1E3,
)
# Pair of maps applied before (entrance) and after (exit) the body in element() below
xyz_entrance, xyz_exit = alignment_factory(beta=beta(gamma), gamma=gamma, flag=True)
@jit
def element(x, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s, dx, dy, dz, wx, wy, wz):
    """Track the phase-space vector ``x`` through a misaligned dipole.

    The alignment errors (translations dx, dy, dz and rotations wx, wy, wz)
    are applied by the entrance map. The dipole body is then tracked with the
    given length, bending angle and multipole strengths. Finally, the exit
    map, which also needs length and angle, removes the errors.
    """
    errors = (dx, dy, dz, wx, wy, wz)
    misaligned = xyz_entrance(x, *errors)
    tracked = body(misaligned, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s)
    return xyz_exit(tracked, *errors, length, angle)
[6]:
# Set alignment errors
# Translations (dx, dy, dz) and rotations (wx, wy, wz) relative to the entrance frame
dx, dy, dz = jax.numpy.asarray((0.05, -0.02, 0.05))
wx, wy, wz = jax.numpy.asarray((0.005, -0.005, 0.1))
[7]:
# Compare with PTC
# JAX tracking through the misaligned element
print(res := element(qsps, length, angle, kq_n, kq_s, ks_n, ks_s, ko_n, ko_s, dx, dy, dz, wx, wy, wz))
# PTC reference: an sbend with integrated strengths k*length as knl/ksl, entrance/exit
# fringe fields killed, and the same alignment errors passed as tx/ty/tz and rx/ry/rz
print(ref := ptc(
    qsps,
    'sbend',
    {
        'l': float(length),
        'angle': float(angle),
        'knl': f'{{0.0,{float(kq_n*length)}, {float(ks_n*length)}, {float(ko_n*length)}}}',
        'ksl': f'{{0.0,{float(kq_s*length)}, {float(ks_s*length)}, {float(ko_s*length)}}}',
        'kill_ent_fringe': 'true',
        'kill_exi_fringe': 'true',
    },
    gamma=gamma,
    tx=float(dx),
    ty=float(dy),
    tz=float(dz),
    rx=float(wx),
    ry=float(wy),
    rz=float(wz),
))
# Agreement within jax.numpy.allclose default tolerances
print(jax.numpy.allclose(res, ref))
[-6.489404977601e-02 2.078529965405e-02 -2.211224419908e-04 -1.450307156643e-01 8.291770017251e-02 -1.000000000000e-04]
[-6.489404978720e-02 2.078529965695e-02 -2.211225001033e-04 -1.450307157164e-01 8.291770017969e-02 -1.000000000000e-04]
True