Example-12: TM010 cavity factory

In this example, the cavity factory is illustrated.

The TM010 cavity Hamiltonian is:

$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \\ & \\ & (a_x, a_y, a_s) = \left(0, 0, \frac{q E}{P \omega} J_0(k r) \cos\left(k\left(\frac{s}{\beta} - q_s \right) + \phi\right)\right) \\ & r^2 = q_x^2 + q_y^2 \\ & \varphi = 0 \\ & t = h = 0 \\ \end{align} $

The constructed element signature is:

# kick
def cavity(qsps:Array, voltage:Array, lag:Array) -> Array:
    ...

# main
def cavity(qsps:Array, length:Array, voltage:Array, frequency:Array, lag:Array) -> Array:
    ...

Note that, by default, only an energy kick is performed and the cavity has zero length.

[1]:
import jax
from jax import jit
from jax import jacrev

from elementary.util import ptc
from elementary.util import CL
from elementary.util import ME
from elementary.util import beta
from elementary.util import gamma
from elementary.util import rigidity

from elementary.cavity import cavity_factory

from matplotlib import pyplot as plt

# Print arrays with 12 digits of precision on lines up to 256 characters wide
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Set data type
# Turn on JAX's 64-bit mode; the parameters below are created with jax.numpy.float64
# NOTE(review): this flag is expected to be set before any array is created -- confirm against the JAX docs

jax.config.update("jax_enable_x64", True)
[3]:
# Set device
# Make the first CPU device reported by JAX the default device for all computations.
# The first entry is indexed directly instead of star-unpacking into `_`,
# so IPython's last-output variable `_` is not overwritten by a list of devices.

device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[4]:
# Define cavity elements

# gamma value shared by beta(), rigidity(), the factory and the PTC calls further down
# NOTE: this assignment rebinds the name `gamma` imported from elementary.util above
gamma = 1000.0

# Build and jit-compile one element per factory kind ('kick' first, then 'main');
# the two factory calls differ only in the `kind` keyword
element_kick, element_main = (
    jit(
        cavity_factory(
            rigidity(beta(gamma), gamma, ME, 1),
            kind=kind,
            beta=beta(gamma),
            gamma=gamma,
            order=2**1,
            iterations=100,
        )
    )
    for kind in ('kick', 'main')
)
[5]:
# Set zero initial condition

# Coordinates (q_x, q_y, q_s) and momenta (p_x, p_y, p_s), all equal to zero
qs = jax.numpy.array([0., 0., 0.])
ps = jax.numpy.array([0., 0., 0.])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps

# Phase-space vector ordered as (q_x, q_y, q_s, p_x, p_y, p_s)
qsps = jax.numpy.concatenate([qs, ps])
[6]:
# Set parameters
# voltage and frequency are entered as raw values and rescaled by 1E-6
# (presumably V -> MV and Hz -> MHz -- TODO confirm against cavity_factory / PTC conventions)

lag = jax.numpy.float64(90*jax.numpy.pi/180)  # 90 degrees expressed in radians
frequency = 1E-6*jax.numpy.float64(2.0E+08)
voltage = 1E-6*jax.numpy.float64(5.0E+5)
length = jax.numpy.float64(1.0)
[7]:
# Compute energy gain
# Expected change of the last phase-space component (p_s): the voltage (scaled back by 1E+6)
# divided by rigidity*CL and multiplied by sin(lag); compare with the last entry
# of the element outputs printed in the next cell

print((1E+6*voltage)/(rigidity(beta(gamma), gamma, ME, 1)*CL)*jax.numpy.sin(lag))
0.0009784760796965218
[8]:
# Compare with PTC

# Outputs of the two constructed elements: the kick element (no length argument) first,
# then the main element
print(element_kick(qsps, voltage, lag))
print(element_main(qsps, length, voltage, frequency, lag))

# Reference result from PTC for an rfcavity with the same settings
ref = ptc(
    qsps,
    'rfcavity',
    {
        'l': float(length),
        'no_cavity_totalpath': 'true',
        'n_bessel': 10,
        'freq': float(frequency),
        'volt': float(voltage),
        'lag': float(lag/(2*jax.numpy.pi)),  # lag is passed divided by 2*pi
    },
    gamma=gamma,
)
print(ref)
print()
[0.            0.            0.            0.            0.            0.00097847608]
[-1.082847985127e-18 -1.082847985127e-18  4.885907984818e-10 -1.490078219212e-21 -1.490078219212e-21  9.784760807539e-04]
[0.000000000000e+00 0.000000000000e+00 4.887602574200e-10 0.000000000000e+00 0.000000000000e+00 9.784760810177e-04]

[9]:
# Differentiability

# Jacobian of the main element output with respect to its first argument (qsps)
matrix = jacrev(element_main)(qsps, length, voltage, frequency, lag)

# Show the matrix followed by its determinant
print(matrix, jax.numpy.linalg.det(matrix), sep='\n')
[[ 9.989171520149e-01 -5.721517742631e-20 -4.295716685495e-18  9.995109978134e-01 -1.059978537520e-20  1.082475290913e-18]
 [-5.721517742631e-20  9.989171520149e-01 -4.295716685495e-18 -1.059978537520e-20  9.995109978134e-01  1.082475290913e-18]
 [-1.493632396068e-21 -1.493632396068e-21  9.999999978367e-01  1.082158489132e-18  1.082158489132e-18  9.985358733929e-07]
 [-1.490078219213e-06  6.105199573872e-20 -8.596042621954e-18  1.001082530852e+00  7.757678047506e-21  1.485053344709e-21]
 [ 6.105199573872e-20 -1.490078219213e-06 -8.596042621954e-18  7.757678047506e-21  1.001082530852e+00  1.485053344710e-21]
 [-8.586740834050e-18 -8.586740834050e-18 -6.769838121622e-12 -4.291472216269e-18 -4.291472216269e-18  1.000000002163e+00]]
1.0
[10]:
# Scan initial lag
# For every lag value only the last output component (p_s) is kept.
# The last column is selected by indexing rather than by `*_, x = stacked.T`,
# so IPython's last-output variable `_` is not overwritten and each `out_*`
# name is bound once, directly to the final array.

# 101 lag values from 0 to 2*pi inclusive
lags = 2*jax.numpy.pi*jax.numpy.linspace(0.0, 1.0, 101)

out_kick = jax.numpy.stack([element_kick(qsps, voltage, lag) for lag in lags])[:, -1]

out_main = jax.numpy.stack([element_main(qsps, length, voltage, frequency, lag) for lag in lags])[:, -1]

# PTC reference; lag is passed divided by 2*pi
out_ptc = jax.numpy.stack([
    ptc(
        qsps,
        'rfcavity',
        {
            'no_cavity_totalpath': 'true',
            'l': float(length),
            'n_bessel': 10,
            'freq': float(frequency),
            'volt': float(voltage),
            'lag': float(lag/(2*jax.numpy.pi)),
        },
        gamma=gamma,
    )
    for lag in lags
])[:, -1]
[11]:
# Plot energy change vs lag
# The PTC reference is drawn first with a thick line so the kick and main curves stay visible on top.

fig, ax = plt.subplots(figsize=(16, 4))
ax.plot(lags, out_ptc, color='black', label='ptc', lw=5)
ax.plot(lags, out_kick, color='red', label='kick')
ax.plot(lags, out_main, color='blue', label='main')

# Guide lines: zero level and every quarter of the scanned lag range (0, pi/2, pi, 3*pi/2, 2*pi)
ax.axhline(0, color='black', alpha=0.5)
for quarter in range(5):
    ax.axvline(quarter*jax.numpy.pi/2, color='black', alpha=0.5)

# Axis labels so the figure can be read on its own
ax.set_xlabel('lag')
ax.set_ylabel('p_s')
ax.legend()
plt.show()
../_images/examples_example-12_12_0.png
[12]:
# Set non-zero initial condition

# Coordinates (q_x, q_y, q_s): only the last one is offset
qs = jax.numpy.array([0., 0., -0.001])
# Momenta (p_x, p_y, p_s)
ps = jax.numpy.array([0.001, 0.001, 0.0001])
q_x, q_y, q_s = qs
p_x, p_y, p_s = ps

# Phase-space vector ordered as (q_x, q_y, q_s, p_x, p_y, p_s)
qsps = jax.numpy.concatenate([qs, ps])
[13]:
# Set parameters
# voltage and frequency are entered as raw values and rescaled by 1E-6
# (presumably V -> MV and Hz -> MHz -- TODO confirm against cavity_factory / PTC conventions)

lag = jax.numpy.float64(90*jax.numpy.pi/180)  # 90 degrees expressed in radians
frequency = 1E-6*jax.numpy.float64(2.0E+08)
voltage = 1E-6*jax.numpy.float64(5.0E+5)
length = jax.numpy.float64(1.0)
[14]:
# Compute energy gain
# Expected change of the last phase-space component (p_s): the voltage (scaled back by 1E+6)
# divided by rigidity*CL and multiplied by sin(lag); compare with the change of the last entry
# between the initial condition and the element outputs printed in the next cell

print((1E+6*voltage)/(rigidity(beta(gamma), gamma, ME, 1)*CL)*jax.numpy.sin(lag))
0.0009784760796965218
[15]:
# Compare with PTC

# Outputs of the two constructed elements: the kick element (no length argument) first,
# then the main element
print(element_kick(qsps, voltage, lag))
print(element_main(qsps, length, voltage, frequency, lag))

# Reference result from PTC for an rfcavity with the same settings
ref = ptc(
    qsps,
    'rfcavity',
    {
        'l': float(length),
        'no_cavity_totalpath': 'true',
        'n_bessel': 10,
        'freq': float(frequency),
        'volt': float(voltage),
        'lag': float(lag/(2*jax.numpy.pi)),  # lag is passed divided by 2*pi
    },
    gamma=gamma,
)
print(ref)
print()
[ 0.             0.            -0.001          0.001          0.001          0.00107847608]
[ 0.00099941211   0.00099941211  -0.001000998234  0.001001086708  0.001001086708  0.001078463202]
[ 0.000999412549  0.000999412549 -0.001000998237  0.001000001075  0.001000001075  0.001078466761]