Example-02: Hamiltonian factory
In this example, symplectic integration of a generic non-autonomous hamiltonian is illustrated. The JAX-based `sympint <https://github.com/i-a-morozov/sympint>`__ library is used to perform the integration. The elementary.hamiltonian module provides hamiltonian_factory, which can be used to construct a generic accelerator element hamiltonian with the following signature:
def hamiltonian(qs: Array, ps: Array, s: Array, *args: Array) -> Array:
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    ...
In the most general case, one needs to pass vector (required), scalar, torsion and curvature functions with the following signatures:
def vector(qs: Array, s: Array, *args: Array) -> tuple[Array, Array, Array]:
    q_x, q_y, q_s = qs
    ...
def scalar(qs: Array, s: Array, *args: Array) -> Array:
    q_x, q_y, q_s = qs
    ...
def curvature(s: Array, *args: Array) -> Array:
    ...
def torsion(s: Array, *args: Array) -> Array:
    ...
Note that *args are expected to match between all of the above functions.
Explicitly, the accelerator hamiltonian is:
$ \begin{align} & H(q_x, q_y, q_s, p_x, p_y, p_s; s) = \frac{p_s}{\beta} - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) \left(\sqrt{P_s^2 - P_x^2 - P_y^2 - \frac{1}{\beta^2 \gamma^2}} + a_s(q_x, q_y, q_s; s)\right) \\ & \\ & P_s = p_s + 1/\beta - \varphi(q_x, q_y, q_s; s) \\ & P_x = p_x - a_x(q_x, q_y, q_s; s) \\ & P_y = p_y - a_y(q_x, q_y, q_s; s) \end{align} $
where \(\beta\) and \(\gamma\) are the relativistic factors, \(h(s)\) is the reference trajectory curvature, \(t(s)\) is the reference trajectory torsion, \(a_x(q_x, q_y, q_s; s)\), \(a_y(q_x, q_y, q_s; s)\) and \(a_s(q_x, q_y, q_s; s)\) are the scaled vector potential components, and \(\varphi(q_x, q_y, q_s; s)\) is the scaled scalar potential. Additionally, the longitudinal coordinate and momentum are given by:
$ \begin{align} & q_s = \frac{s}{\beta} - c t \\ & p_s = \frac{E}{c P} - \frac{1}{\beta} \end{align} $
The expression for \(q_s\) should be used to replace explicit time dependence.
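Inverting the definition of \(q_s\) expresses time through \(s\) and \(q_s\). As an illustration, take a potential that oscillates as \(\sin(\omega t + \theta)\); this RF-like form, with angular frequency \(\omega\) and phase \(\theta\), is a hypothetical choice used only to show the substitution:
$ \begin{align} & c t = \frac{s}{\beta} - q_s \\ & \sin(\omega t + \theta) = \sin\left(\frac{\omega}{c}\left(\frac{s}{\beta} - q_s\right) + \theta\right) \end{align} $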
As an example, the following non-autonomous hamiltonian is used:
$ \begin{align} & H = p_s - \left(\sqrt{(1 + p_s)^2 - p_x^2 - p_y^2} + \frac{1}{2} k_n \left(1 + \sin\left(2 \pi \frac{s}{l}\right)\right) \left(q_x^2 - q_y^2\right)\right) \end{align} $
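Setting \(\beta = 1\) with \(1/(\beta^2\gamma^2) = 0\), \(h = t = \varphi = 0\) and \(a_x = a_y = 0\) in the general formula recovers this example, with \(a_s = \frac{1}{2} k_n (1 + \sin(2 \pi s / l))(q_x^2 - q_y^2)\). A minimal pure-JAX builder of the general formula can confirm the reduction numerically. This is a sketch under stated assumptions: the name build_hamiltonian and its beta/gamma keyword arguments are hypothetical, and it is not elementary's hamiltonian_factory, whose handling of the relativistic factors is not shown here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def build_hamiltonian(vector, scalar, curvature, torsion, beta=1.0, gamma=jnp.inf):
    """Hypothetical builder of the general accelerator hamiltonian H(qs, ps, s, *args)."""
    # 1/(beta^2 gamma^2) vanishes in the ultra-relativistic limit gamma -> inf
    mass_term = 1.0 / (beta**2 * gamma**2)

    def hamiltonian(qs, ps, s, *args):
        q_x, q_y, q_s = qs
        p_x, p_y, p_s = ps
        a_x, a_y, a_s = vector(qs, s, *args)
        phi = scalar(qs, s, *args)
        h = curvature(s, *args)
        t = torsion(s, *args)
        # Kinetic momenta: canonical momenta minus the scaled potentials
        P_s = p_s + 1.0 / beta - phi
        P_x = p_x - a_x
        P_y = p_y - a_y
        root = jnp.sqrt(P_s**2 - P_x**2 - P_y**2 - mass_term)
        return p_s / beta - t * (q_x * p_y - q_y * p_x) - (1.0 + h * q_x) * (root + a_s)

    return hamiltonian


# The example element: only a_s is non-zero; l and kn are passed as trailing *args
def vector(qs, s, l, kn):
    q_x, q_y, _ = qs
    a_s = 0.5 * kn * (1 + jnp.sin(2 * jnp.pi * s / l)) * (q_x**2 - q_y**2)
    return jnp.zeros_like(a_s), jnp.zeros_like(a_s), a_s


def zero_scalar(qs, s, *args):
    return jnp.zeros_like(s)


def zero_geometry(s, *args):
    return jnp.zeros_like(s)


H = build_hamiltonian(vector, zero_scalar, zero_geometry, zero_geometry)
qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
s0 = jnp.float64(0.5)
print(H(qs, ps, s0, 1.0, 1.0))
```

As a further sanity check on the relativistic terms, a particle on the reference orbit (all momenta zero, no fields) gives \(H = -1\) for any \(\beta\), because \(\sqrt{1/\beta^2 - 1/(\beta^2\gamma^2)} = 1\).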
[1]:
# Import
import jax
from jax import Array
from jax import jit
from jax import jacrev
from elementary import fold
from elementary import nest
from elementary import tao
from elementary import midpoint
from elementary import sequence
from elementary.hamiltonian import hamiltonian_factory
from elementary.hamiltonian import autonomize
jax.numpy.set_printoptions(linewidth=256, precision=12)
[2]:
# Set data type
jax.config.update("jax_enable_x64", True)
[3]:
# Set device
device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[4]:
# Set parameters (initial value of the independent variable, integration step length, element length and quadrupole amplitude)
si = jax.numpy.float64(0.5)
ds = jax.numpy.float64(0.01)
l = jax.numpy.float64(1.0)
kn = jax.numpy.float64(1.0)
[5]:
# Set initial condition
qs = jax.numpy.array([0.001, -0.005, 0.0])
ps = jax.numpy.array([0.005, -0.001, 0.0001])
qsps = jax.numpy.hstack([qs, ps])
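With this initial condition, the right-hand side of Hamilton's equations for the example hamiltonian, \(dq/ds = \partial H / \partial p\) and \(dp/ds = -\partial H / \partial q\), can be evaluated with jax.grad. The snippet below is plain JAX and self-contained: it redefines the example hamiltonian and evaluates it at \(s = 0.5\) with \(l = k_n = 1\). The helper vector_field is hypothetical and not part of elementary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def hamiltonian(qs, ps, s, l, kn):
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    k = 0.5 * kn * (1 + jnp.sin(2 * jnp.pi * s / l))
    return p_s - (jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + k * (q_x**2 - q_y**2))


def vector_field(qs, ps, s, l, kn):
    """Hypothetical helper: Hamilton's equations (dq/ds, dp/ds) for a non-autonomous H."""
    dq = jax.grad(hamiltonian, argnums=1)(qs, ps, s, l, kn)   # dq/ds = +dH/dp
    dp = -jax.grad(hamiltonian, argnums=0)(qs, ps, s, l, kn)  # dp/ds = -dH/dq
    return dq, dp


qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
dq, dp = vector_field(qs, ps, jnp.float64(0.5), 1.0, 1.0)
print(dq, dp)
```

Only the transverse momenta feel the s-dependent strength, \(dp_x/ds = k_n (1 + \sin(2 \pi s / l)) q_x\) and \(dp_y/ds = -k_n (1 + \sin(2 \pi s / l)) q_y\), while \(p_s\) is conserved because \(H\) does not depend on \(q_s\); this is consistent with the unchanged \(p_s = 10^{-4}\) in the integration outputs.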
[6]:
# Define non-autonomous and extended hamiltonian (explicit)
def hamiltonian(qs, ps, s, l, kn, *args):
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    return p_s - (jax.numpy.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + 1/2*kn*(1 + jax.numpy.sin(2*jax.numpy.pi*s/l))*(q_x**2 - q_y**2))
# Extended (autonomous) hamiltonian: q_t replaces s and p_t is its conjugate momentum
def extended(qs, ps, s, l, kn, *args):
    q_x, q_y, q_s, q_t = qs
    p_x, p_y, p_s, p_t = ps
    return p_t + (p_s - jax.numpy.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) - 1/2*kn*(1 + jax.numpy.sin(2*jax.numpy.pi*q_t/l))*(q_x**2 - q_y**2))
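The explicit extended hamiltonian is the standard autonomization \(H_{ext} = p_t + H(q, p; q_t)\), in which the independent variable is promoted to an extra coordinate \(q_t\) with conjugate momentum \(p_t\). The same wrapper can be written once for any hamiltonian with the signature used above. The sketch below is a hypothetical stand-in (autonomize_sketch is not part of elementary) and is not necessarily how elementary.hamiltonian.autonomize is implemented.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def autonomize_sketch(hamiltonian):
    """Hypothetical: turn H(qs, ps, s, *args) into an s-independent H_ext(Qs, Ps, s, *args)."""
    def extended(Qs, Ps, s, *args):
        # The last coordinate q_t plays the role of s; p_t is its conjugate momentum
        qs, q_t = Qs[:-1], Qs[-1]
        ps, p_t = Ps[:-1], Ps[-1]
        return p_t + hamiltonian(qs, ps, q_t, *args)
    return extended


def hamiltonian(qs, ps, s, l, kn, *args):
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    return p_s - (jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + 1/2*kn*(1 + jnp.sin(2*jnp.pi*s/l))*(q_x**2 - q_y**2))


extended = autonomize_sketch(hamiltonian)
Qs = jnp.array([0.001, -0.005, 0.0, 0.5])
Ps = jnp.array([0.005, -0.001, 0.0001, 0.2])
# The dummy s argument is ignored: the value depends on q_t only
print(extended(Qs, Ps, 0.0, 1.0, 1.0))
```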
[7]:
# Set extended initial condition
Qs = jax.numpy.concat([qs, si.reshape(-1)])
Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])
QsPs = jax.numpy.hstack([Qs, Ps])
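Choosing \(p_t = -H(q, p; s_i)\) places the initial point on the zero level set of the extended hamiltonian, and Hamilton's equations for the added pair give \(dq_t/ds = 1\) and \(dp_t/ds = -\partial H / \partial s\), so \(q_t\) tracks \(s\) and \(p_t\) tracks \(-H\). Both facts can be checked with jax.grad. The functions below are copied from the cells above so that the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

si, l, kn = jnp.float64(0.5), jnp.float64(1.0), jnp.float64(1.0)


def hamiltonian(qs, ps, s, l, kn):
    q_x, q_y, q_s = qs
    p_x, p_y, p_s = ps
    return p_s - (jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) + 1/2*kn*(1 + jnp.sin(2*jnp.pi*s/l))*(q_x**2 - q_y**2))


def extended(qs, ps, s, l, kn):
    q_x, q_y, q_s, q_t = qs
    p_x, p_y, p_s, p_t = ps
    return p_t + (p_s - jnp.sqrt((1 + p_s)**2 - p_x**2 - p_y**2) - 1/2*kn*(1 + jnp.sin(2*jnp.pi*q_t/l))*(q_x**2 - q_y**2))


qs = jnp.array([0.001, -0.005, 0.0])
ps = jnp.array([0.005, -0.001, 0.0001])
Qs = jnp.concatenate([qs, si.reshape(-1)])
Ps = jnp.concatenate([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])

energy = extended(Qs, Ps, si, l, kn)                          # zero level set
dq_t = jax.grad(extended, argnums=1)(Qs, Ps, si, l, kn)[-1]   # dq_t/ds = dH_ext/dp_t
dp_t = -jax.grad(extended, argnums=0)(Qs, Ps, si, l, kn)[-1]  # dp_t/ds = -dH_ext/dq_t
dH_ds = jax.grad(hamiltonian, argnums=2)(qs, ps, si, l, kn)   # explicit s-dependence of H
print(energy, dq_t, dp_t, dH_ds)
```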
[8]:
# Set implicit midpoint integration step
integrator = jit(fold(sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)))
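The implicit midpoint rule advances \(x_{n+1} = x_n + \Delta s \, f((x_n + x_{n+1})/2)\), where \(f\) is the hamiltonian vector field; applied to the autonomous extended hamiltonian it is symplectic. The update is implicit in \(x_{n+1}\) and is commonly solved by fixed-point iteration. The sketch below uses a fixed number of iterations; midpoint_step and its iterations argument are hypothetical names, and how elementary's midpoint (and its ns argument) solves the implicit equation is not documented here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def midpoint_step(hamiltonian, x, ds, *args, iterations=20):
    """Hypothetical implicit midpoint step for an autonomous H(qs, ps, *args).

    x = [qs, ps]; the implicit equation is solved by fixed-point iteration.
    """
    n = x.size // 2

    def field(y):
        qs, ps = y[:n], y[n:]
        dq = jax.grad(hamiltonian, argnums=1)(qs, ps, *args)   # +dH/dp
        dp = -jax.grad(hamiltonian, argnums=0)(qs, ps, *args)  # -dH/dq
        return jnp.concatenate([dq, dp])

    def body(_, y):
        return x + ds * field(0.5 * (x + y))

    # Explicit Euler guess, then a fixed number of fixed-point iterations
    return jax.lax.fori_loop(0, iterations, body, x + ds * field(x))


# Usage: harmonic oscillator with exact solution q(s) = cos(s), p(s) = -sin(s)
def oscillator(qs, ps):
    return 0.5 * jnp.sum(ps**2 + qs**2)


step = jax.jit(lambda x: midpoint_step(oscillator, x, 0.01))
x = jnp.array([1.0, 0.0])
for _ in range(100):
    x = step(x)
print(x, jnp.cos(1.0), -jnp.sin(1.0))
```

For a quadratic hamiltonian the implicit midpoint rule conserves the energy exactly (up to round-off and the convergence of the fixed-point loop), which is a convenient check of any implementation.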
[9]:
# Set and compile element
element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271922e-03 -4.257064841347e-03 -2.199863474715e-05 1.500000000000e+00 9.664543767467e-03 3.998219958383e-03 1.000000000000e-04 9.999624216145e-01]
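Here nest(int(l/ds), integrator) builds the element by repeating the one-step integrator int(l/ds) = 100 times, and the printed \(q_t\) (fourth entry, 1.5 = s_i + l) confirms that the whole element length was covered. A repetition combinator of this kind can be sketched with jax.lax.fori_loop. The sketch assumes the step signature step(x, ds, s, *args) seen in the calls above and that the independent variable advances by ds after each step; nest_sketch is hypothetical and not elementary's nest.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def nest_sketch(n, step):
    """Hypothetical: apply step(x, ds, s, *args) n times, advancing s by ds each time."""
    def element(x, ds, s, *args):
        def body(_, carry):
            x, s = carry
            return step(x, ds, s, *args), s + ds
        x, _ = jax.lax.fori_loop(0, n, body, (x, s))
        return x
    return element


# Usage with a toy step: a drift x -> x + ds * v for a state x = [q, v]
def drift(x, ds, s):
    q, v = x
    return jnp.array([q + ds * v, v])


element = jax.jit(nest_sketch(100, drift))
print(element(jnp.array([0.0, 2.0]), 0.01, 0.5))
```

Because int truncates, int(l/ds) gives the intended step count only when l/ds evaluates to an exact integer in floating point. It does for 1.0/0.01, but int(0.3/0.1) is 2, since 0.3/0.1 evaluates to 2.9999999999999996; round(l/ds) avoids that off-by-one.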
[10]:
# Set tao integration step
integrator = jit(fold(sequence(0, 2**1, [tao(extended)], merge=False)))
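The tao step replaces the implicit solve with an explicit scheme. Assuming it refers to Tao's extended phase space method for non-separable hamiltonians (the name suggests so, but the source does not state it), the idea is to duplicate the phase space into \((q, p, x, y)\), integrate \(H(q, y) + H(x, p) + \frac{\omega}{2}(|q - x|^2 + |p - y|^2)\) by a symmetric splitting into three exactly solvable flows, and read off \((q, p)\). The sketch follows that construction; tao_step and the coupling \(\omega = 20\) are illustrative choices, not elementary's defaults.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def tao_step(hamiltonian, state, ds, omega=20.0):
    """Hypothetical second-order Tao step on the doubled phase space (q, p, x, y)."""
    q, p, x, y = state
    dHdq = jax.grad(hamiltonian, argnums=0)
    dHdp = jax.grad(hamiltonian, argnums=1)

    def flow_a(q, p, x, y, h):  # exact flow of H(q, y): q and y are frozen
        return q, p - h * dHdq(q, y), x + h * dHdp(q, y), y

    def flow_b(q, p, x, y, h):  # exact flow of H(x, p): x and p are frozen
        return q + h * dHdp(x, p), p, x, y - h * dHdq(x, p)

    def flow_c(q, p, x, y, h):  # exact flow of the coupling: rotation of (q - x, p - y)
        c, s = jnp.cos(2 * omega * h), jnp.sin(2 * omega * h)
        u, v = q - x, p - y
        u, v = c * u + s * v, -s * u + c * v
        # q + x and p + y are invariants of this flow
        return (q + x + u) / 2, (p + y + v) / 2, (q + x - u) / 2, (p + y - v) / 2

    # Symmetric (Strang) composition A(ds/2) B(ds/2) C(ds) B(ds/2) A(ds/2)
    state = flow_a(q, p, x, y, ds / 2)
    state = flow_b(*state, ds / 2)
    state = flow_c(*state, ds)
    state = flow_b(*state, ds / 2)
    return flow_a(*state, ds / 2)


# Usage: harmonic oscillator with exact solution q(s) = cos(s), p(s) = -sin(s)
def oscillator(qs, ps):
    return 0.5 * jnp.sum(ps**2 + qs**2)


step = jax.jit(lambda state: tao_step(oscillator, state, 0.01))
q0, p0 = jnp.array([1.0]), jnp.array([0.0])
state = (q0, p0, q0, p0)  # both copies start at the same point
for _ in range(100):
    state = step(state)
q, p, x, y = state
print(q, p, jnp.cos(1.0), -jnp.sin(1.0))
```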
[11]:
# Set and compile element
element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05 1.500000000000e+00 9.664543767678e-03 3.998219958148e-03 1.000000000000e-04 9.999624216145e-01]
[12]:
# Define non-autonomous and extended hamiltonian (factory)
def vector(qs: Array, s: Array, l: Array, kn: Array, *args: Array) -> tuple[Array, Array, Array]:
    q_x, q_y, q_s = qs
    a_x, a_y, a_s = jax.numpy.zeros_like(qs)
    a_s = 1/2*kn*(1 + jax.numpy.sin(2*jax.numpy.pi*s/l))*(q_x**2 - q_y**2)
    return a_x, a_y, a_s
def scalar(qs: Array, s: Array, l: Array, kn: Array, *args: Array) -> Array:
    q_x, q_y, q_s = qs
    return jax.numpy.zeros_like(s)
hamiltonian = hamiltonian_factory(vector, scalar)
extended = autonomize(hamiltonian)
[13]:
# Set extended initial condition
Qs = jax.numpy.concat([qs, si.reshape(-1)])
Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, l, kn).reshape(-1)])
QsPs = jax.numpy.hstack([Qs, Ps])
[14]:
# Set implicit midpoint integration step
integrator = jit(fold(sequence(0, 2**1, [midpoint(extended, ns=2**1)], merge=False)))
[15]:
# Set and compile element
element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271922e-03 -4.257064841347e-03 -2.199863474715e-05 1.500000000000e+00 9.664543767467e-03 3.998219958383e-03 1.000000000000e-04 9.999624216145e-01]
[16]:
# Set tao integration step
integrator = jit(fold(sequence(0, 2**1, [tao(extended)], merge=False)))
[17]:
# Set and compile element
element = jit(nest(int(l/ds), integrator))
out = element(QsPs, ds, si, l, kn)
print(out)
[ 7.235126271969e-03 -4.257064841463e-03 -2.199863474965e-05 1.500000000000e+00 9.664543767678e-03 3.998219958148e-03 1.000000000000e-04 9.999624216145e-01]
[18]:
# Differentiability (initial condition)
matrix = jacrev(element)(QsPs, ds, si, l, kn)
print(matrix)
print(jax.numpy.linalg.det(matrix))
[[ 1.366527779203e+00 -5.802223397334e-06 0.000000000000e+00 1.658258296648e-03 1.173782172073e+00 3.024552512394e-06 -7.166765586369e-03 0.000000000000e+00]
[ 6.891863981574e-06 6.833771361996e-01 0.000000000000e+00 4.804596565370e-04 1.099963935945e-05 8.401702767985e-01 -7.785286136285e-04 0.000000000000e+00]
[-2.784634943647e-03 6.697396020357e-04 1.000000000000e+00 -1.788345235604e-05 -7.636376226265e-03 -3.158011828099e-04 5.159686903021e-05 0.000000000000e+00]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00]
[ 1.134284155157e+00 -9.555360071284e-07 0.000000000000e+00 -1.418791282023e-04 1.706077662573e+00 -1.231406521087e-07 -3.911016696259e-03 0.000000000000e+00]
[-1.121934780525e-06 -8.767306145299e-01 0.000000000000e+00 1.453245809425e-03 -1.267627972649e-06 3.854343128509e-01 1.873697940944e-05 0.000000000000e+00]
[ 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 0.000000000000e+00 1.000000000000e+00 0.000000000000e+00]
[-2.074807326405e-03 1.414351057111e-03 0.000000000000e+00 -1.733134500028e-04 -2.995636035736e-03 1.035788071469e-03 2.424534688640e-05 1.000000000000e+00]]
1.0
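A unit determinant (volume preservation) is necessary but, with more than one degree of freedom, not sufficient for symplecticity. The stronger test is \(M^T J M = J\), with \(J = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix}\) for the [q..., p...] ordering used here. The sketch applies the test to a self-contained stand-in map, one kick-drift-kick step of a 2D quadrupole with a fixed strength; the map leapfrog and the helper symplectic_form are illustrative assumptions, since the elementary element is not reproduced here. It also shows a matrix with unit determinant that fails the test.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def leapfrog(x, ds, k):
    """Illustrative symplectic map: kick-drift-kick step for H = (p_x^2 + p_y^2)/2 + k (q_x^2 - q_y^2)/2."""
    q, p = x[:2], x[2:]
    p = p + 0.5 * ds * (-k * jnp.array([q[0], -q[1]]))  # half kick with -dV/dq
    q = q + ds * p                                      # drift
    p = p + 0.5 * ds * (-k * jnp.array([q[0], -q[1]]))  # half kick
    return jnp.concatenate([q, p])


def symplectic_form(n):
    """J = [[0, I], [-I, 0]] for coordinates ordered as [q_1..q_n, p_1..p_n]."""
    eye, zero = jnp.eye(n), jnp.zeros((n, n))
    return jnp.block([[zero, eye], [-eye, zero]])


x = jnp.array([0.001, -0.005, 0.005, -0.001])
M = jax.jacrev(leapfrog)(x, 0.1, 1.0)
J = symplectic_form(2)
print(jnp.max(jnp.abs(M.T @ J @ M - J)), jnp.linalg.det(M))

# Volume-preserving but not symplectic: q_1 is stretched by 2, p_2 is squeezed by 1/2
D = jnp.diag(jnp.array([2.0, 1.0, 1.0, 0.5]))
print(jnp.linalg.det(D), jnp.max(jnp.abs(D.T @ J @ D - J)))
```

For the 8×8 matrix above, the same check would read J = symplectic_form(4) followed by jax.numpy.max(jax.numpy.abs(matrix.T @ J @ matrix - J)).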
[19]:
# Differentiability (parameter: quadrupole amplitude kn, the last argument)
print(jacrev(element, argnums=-1)(QsPs, ds, si, l, kn))
[ 1.298574427115e-03 1.618721745092e-03 -1.373594876871e-05 0.000000000000e+00 5.045563129602e-03 4.364921756032e-03 0.000000000000e+00 -2.081340401732e-05]
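Derivatives with respect to parameters can be sanity-checked against a central finite difference, \((f(k + \varepsilon) - f(k - \varepsilon)) / 2\varepsilon\). The sketch applies this to a stand-in element (100 kick-drift-kick steps of a quadrupole with strength k); element and k are defined here for illustration only. For the notebook's element, the same comparison would evaluate jacrev(element, argnums=-1) against a difference in kn.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def element(x, ds, k, n=100):
    """Illustrative element: n kick-drift-kick steps of a quadrupole with strength k."""
    def step(_, x):
        q, p = x[:2], x[2:]
        p = p - 0.5 * ds * k * jnp.array([q[0], -q[1]])
        q = q + ds * p
        p = p - 0.5 * ds * k * jnp.array([q[0], -q[1]])
        return jnp.concatenate([q, p])
    # Static trip count, so reverse-mode differentiation through the loop works
    return jax.lax.fori_loop(0, n, step, x)


x = jnp.array([0.001, -0.005, 0.005, -0.001])
ds, k, eps = 0.01, 1.0, 1e-6

exact = jax.jacrev(element, argnums=2)(x, ds, k)  # d(final state)/dk via autodiff
approx = (element(x, ds, k + eps) - element(x, ds, k - eps)) / (2 * eps)
print(exact)
print(approx)
```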