Example-02: Hamiltonian factory

In this example, generic non-autonomous Hamiltonian symplectic integration is illustrated. The JAX-based `sympint <https://github.com/i-a-morozov/sympint>`__ library is used to perform the integration. elementary.hamiltonian provides hamiltonian_factory, which can be used to construct a generic accelerator element hamiltonian with the following signature:

def hamiltonian(qs: Array, ps: Array, s: Array, *args: Array) -> Array:
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    ...

In the most general case one needs to pass vector (required), scalar, torsion and curvature functions with signatures:

def vector(qs:Array, s:Array, *args:Array) -> tuple[Array, Array, Array]:
  q_x, q_y, q_s = qs
  ...

def scalar(qs:Array, s:Array, *args:Array) -> Array:
  q_x, q_y, q_s = qs
  ...

def curvature(s:Array, *args:Array) -> Array:
  ...

def torsion(s:Array, *args:Array) -> Array:
  ...

Note: the *args are expected to match between all the above functions.

Explicitly, the accelerator hamiltonian is:

$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \end{align} $

where \(\beta\) and \(\gamma\) are the relativistic factors, \(h(s)\) is the reference trajectory curvature and \(t(s)\) is the reference trajectory torsion, \(a_x(q_x, q_y, q_s; s)\), \(a_y(q_x, q_y, q_s; s)\) and \(a_s(q_x, q_y, q_s; s)\) are the scaled vector potential components, and \(\varphi(q_x, q_y, q_s; s)\) is the scaled scalar potential. Additionally, the longitudinal coordinate and momentum are given by:

$ \begin{align} & q_s = \frac{s}{\beta} - c t \\ & p_s = \frac{E}{c P} - \frac{1}{\beta} \end{align} $

The expression for \(q_s\) should be used to replace explicit time dependence.

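As an example, the following non-autonomous hamiltonian is used. It is the general form above in the ultrarelativistic limit \(\beta \to 1\), with zero curvature, torsion and scalar potential, \(a_x = a_y = 0\) and \(a_s = \frac{1}{2} k_n \left(1 + \sin\left(2 \pi \frac{s}{l}\right)\right) \left(q_x^2 - q_y^2\right)\):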

$ \begin{align} & H = p_s - \left(\sqrt{(1 + p_s)^2 - p_x^2 - p_y^2} + \frac{1}{2} k_n \left(1 + \sin\left(2 \pi \frac{s}{l}\right)\right) \left(q_x^2 - q_y^2\right)\right) \end{align} $

[1]:
# Import

import jax
from jax import Array
from jax import jit
from jax import jacrev

from elementary import fold
from elementary import nest
from elementary import tao
from elementary import midpoint
from elementary import sequence

from elementary.hamiltonian import hamiltonian_factory
from elementary.hamiltonian import autonomize

# Print arrays with 12 digits of precision on wide (256 character) lines
jax.numpy.set_printoptions(precision=12, linewidth=256)
[2]:
# Set data type: enable 64-bit (double precision) arrays for all subsequent computations

jax.config.update("jax_enable_x64", True)
[3]:
# Set device: take the first CPU device and make it the default for new arrays

device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[4]:
# Set parameters: initial value of the independent variable (si), integration
# step length (ds), element length (l, also the period of the sinusoidal
# quadrupole modulation) and quadrupole amplitude (kn)

si = jax.numpy.float64(0.5)
ds = jax.numpy.float64(0.01)

l = jax.numpy.float64(1.0)
kn = jax.numpy.float64(1.0)
[5]:
# Set initial condition and pack coordinates followed by momenta into one state vector

qs = jax.numpy.array([0.001, -0.005, 0.0])     # q_x, q_y, q_s
ps = jax.numpy.array([0.005, -0.001, 0.0001])  # p_x, p_y, p_s

qsps = jax.numpy.concatenate([qs, ps])
[6]:
# Define non-autonomous and extended hamiltonian (explicit)

def hamiltonian(qs, ps, s, l, kn, *args):
    """Non-autonomous example hamiltonian with a sinusoidally modulated quadrupole.

    H = p_s - (sqrt((1 + p_s)^2 - p_x^2 - p_y^2)
               + 1/2 kn (1 + sin(2 pi s / l)) (q_x^2 - q_y^2))

    qs: coordinates (q_x, q_y, q_s); q_s does not enter the expression.
    ps: momenta (p_x, p_y, p_s).
    s: independent variable; the explicit s-dependence enters through the modulation.
    l: modulation period; kn: quadrupole amplitude; extra *args are ignored.
    """
    q_x, q_y, _ = qs
    p_x, p_y, p_s = ps
    root = jax.numpy.sqrt((1 + p_s)**2 - p_x**2 - p_y**2)
    modulation = 1 + jax.numpy.sin(2*jax.numpy.pi*s/l)
    potential = 1/2*kn*modulation*(q_x**2 - q_y**2)
    return p_s - (root + potential)

def extended(qs, ps, s, l, kn, *args):
    """Autonomous (extended phase space) form of the example hamiltonian.

    The independent variable is promoted to an extra coordinate q_t with
    conjugate momentum p_t:

    H_ext = p_t + p_s - sqrt((1 + p_s)^2 - p_x^2 - p_y^2)
                - 1/2 kn (1 + sin(2 pi q_t / l)) (q_x^2 - q_y^2)

    qs: (q_x, q_y, q_s, q_t); ps: (p_x, p_y, p_s, p_t).
    The explicit argument s is not used (the dependence is carried by q_t);
    l: modulation period; kn: quadrupole amplitude; extra *args are ignored.
    """
    q_x, q_y, _, q_t = qs
    p_x, p_y, p_s, p_t = ps
    root = jax.numpy.sqrt((1 + p_s)**2 - p_x**2 - p_y**2)
    modulation = 1 + jax.numpy.sin(2*jax.numpy.pi*q_t/l)
    potential = 1/2*kn*modulation*(q_x**2 - q_y**2)
    return p_t + (p_s - root - potential)
[7]:
# Set extended initial condition: append si as the time-like coordinate q_t and
# minus the initial hamiltonian value as its conjugate momentum p_t, so that the
# explicit extended hamiltonian (p_t + H) starts at zero

Qs = jax.numpy.concat([qs, jax.numpy.atleast_1d(si)])
Ps = jax.numpy.concat([ps, jax.numpy.atleast_1d(-hamiltonian(qs, ps, si, l, kn))])
QsPs = jax.numpy.concatenate([Qs, Ps])
[8]:
# Set implicit midpoint integration step for the extended hamiltonian, composed
# with elementary's sequence/fold and compiled with jit
# NOTE(review): meaning of the sequence arguments and of ns is defined by the
# elementary library — confirm there before changing them

integrator = jit(
    fold(
        sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)
    )
)
[9]:
# Set and compile element: apply the integration step l/ds times, so the
# time-like coordinate advances from si to si + l.
# round() rather than int(): a float ratio such as 0.7/0.1 evaluates to
# 6.999999999999999 and int() would silently drop the last step

element = jit(nest(round(float(l/ds)), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271922e-03 -4.257064841347e-03 -2.199863474715e-05  1.500000000000e+00  9.664543767467e-03  3.998219958383e-03  1.000000000000e-04  9.999624216145e-01]
[10]:
# Set tao integration step for the extended hamiltonian, composed with
# elementary's sequence/fold and compiled with jit
# NOTE(review): meaning of the sequence arguments is defined by the elementary
# library — confirm there before changing them

integrator = jit(
    fold(
        sequence(0, 2**1, [tao(extended)], merge=False)
    )
)
[11]:
# Set and compile element: apply the integration step l/ds times, so the
# time-like coordinate advances from si to si + l.
# round() rather than int(): a float ratio such as 0.7/0.1 evaluates to
# 6.999999999999999 and int() would silently drop the last step

element = jit(nest(round(float(l/ds)), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05  1.500000000000e+00  9.664543767678e-03  3.998219958148e-03  1.000000000000e-04  9.999624216145e-01]
[12]:
# Define non-autonomous and extended hamiltonian (factory)

def vector(qs: Array, s: Array, l: Array, kn: Array, *args: Array) -> tuple[Array, Array, Array]:
    """Scaled vector potential (a_x, a_y, a_s) of the example element.

    Only the longitudinal component is nonzero:
    a_s = 1/2 kn (1 + sin(2 pi s / l)) (q_x^2 - q_y^2),
    i.e. a quadrupole whose strength is modulated with period l.

    qs: coordinates (q_x, q_y, q_s); q_s does not enter the expression.
    Returns a_x and a_y as zeros with the dtype of qs, and a_s.
    """
    q_x, q_y, _ = qs
    zeros = jax.numpy.zeros_like(qs)
    modulation = 1 + jax.numpy.sin(2*jax.numpy.pi*s/l)
    return zeros[0], zeros[1], 1/2*kn*modulation*(q_x**2 - q_y**2)

def scalar(qs: Array, s: Array, l: Array, kn: Array, *args: Array) -> Array:
    """Scaled scalar potential phi(q_x, q_y, q_s; s) of the example: identically zero.

    qs is unpacked only so that exactly three coordinates are required; the
    returned zero has the shape and dtype of s.
    """
    (_q_x, _q_y, _q_s) = qs
    return jax.numpy.zeros_like(s)

# Build the generic accelerator hamiltonian from the vector and scalar potentials
# defined above (no curvature or torsion functions are passed), then its
# autonomous form via autonomize.
# NOTE(review): the recorded outputs of the cells below match the explicit
# hand-written `extended` version, which suggests autonomize appends the
# independent variable as an extra coordinate in the same way — confirm in elementary
hamiltonian = hamiltonian_factory(vector, scalar)
extended = autonomize(hamiltonian)
[13]:
# Set extended initial condition: append si as the extra (time-like) coordinate
# and minus the initial value of the factory hamiltonian as its conjugate momentum

Qs = jax.numpy.concat([qs, jax.numpy.atleast_1d(si)])
Ps = jax.numpy.concat([ps, jax.numpy.atleast_1d(-hamiltonian(qs, ps, si, l, kn))])
QsPs = jax.numpy.concatenate([Qs, Ps])
[14]:
# Set implicit midpoint integration step for the factory-built extended
# hamiltonian, composed with elementary's sequence/fold and compiled with jit
# NOTE(review): meaning of the sequence arguments and of ns is defined by the
# elementary library — confirm there before changing them

integrator = jit(
    fold(
        sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)
    )
)
[15]:
# Set and compile element: apply the integration step l/ds times, so the
# time-like coordinate advances from si to si + l.
# round() rather than int(): a float ratio such as 0.7/0.1 evaluates to
# 6.999999999999999 and int() would silently drop the last step

element = jit(nest(round(float(l/ds)), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271922e-03 -4.257064841347e-03 -2.199863474715e-05  1.500000000000e+00  9.664543767467e-03  3.998219958383e-03  1.000000000000e-04  9.999624216145e-01]
[16]:
# Set tao integration step for the factory-built extended hamiltonian,
# composed with elementary's sequence/fold and compiled with jit
# NOTE(review): meaning of the sequence arguments is defined by the elementary
# library — confirm there before changing them

integrator = jit(
    fold(
        sequence(0, 2**1, [tao(extended)], merge=False)
    )
)
[17]:
# Set and compile element: apply the integration step l/ds times, so the
# time-like coordinate advances from si to si + l.
# round() rather than int(): a float ratio such as 0.7/0.1 evaluates to
# 6.999999999999999 and int() would silently drop the last step

element = jit(nest(round(float(l/ds)), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05  1.500000000000e+00  9.664543767678e-03  3.998219958148e-03  1.000000000000e-04  9.999624216145e-01]
[18]:
# Differentiability (initial condition): Jacobian of the element map with respect
# to the initial state QsPs (first argument). For a symplectic map its
# determinant is 1, which is what the second print checks

matrix = jacrev(element)(QsPs, ds, si, l, kn)

print(matrix)
print(jax.numpy.linalg.det(matrix))
[[ 1.366527779203e+00 -5.802223397334e-06  0.000000000000e+00  1.658258296648e-03  1.173782172073e+00  3.024552512394e-06 -7.166765586369e-03  0.000000000000e+00]
 [ 6.891863981574e-06  6.833771361996e-01  0.000000000000e+00  4.804596565370e-04  1.099963935945e-05  8.401702767985e-01 -7.785286136285e-04  0.000000000000e+00]
 [-2.784634943647e-03  6.697396020357e-04  1.000000000000e+00 -1.788345235604e-05 -7.636376226265e-03 -3.158011828099e-04  5.159686903021e-05  0.000000000000e+00]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00]
 [ 1.134284155157e+00 -9.555360071284e-07  0.000000000000e+00 -1.418791282023e-04  1.706077662573e+00 -1.231406521087e-07 -3.911016696259e-03  0.000000000000e+00]
 [-1.121934780525e-06 -8.767306145299e-01  0.000000000000e+00  1.453245809425e-03 -1.267627972649e-06  3.854343128509e-01  1.873697940944e-05  0.000000000000e+00]
 [ 0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  0.000000000000e+00  1.000000000000e+00  0.000000000000e+00]
 [-2.074807326405e-03  1.414351057111e-03  0.000000000000e+00 -1.733134500028e-04 -2.995636035736e-03  1.035788071469e-03  2.424534688640e-05  1.000000000000e+00]]
1.0
[19]:
# Differentiability (parameter): derivative of the final state with respect to
# the quadrupole amplitude kn (argnums=-1 selects the last positional argument)

print(jacrev(element, argnums=-1)(QsPs, ds, si, l, kn))
[ 1.298574427115e-03  1.618721745092e-03 -1.373594876871e-05  0.000000000000e+00  5.045563129602e-03  4.364921756032e-03  0.000000000000e+00 -2.081340401732e-05]