Example-06: Non-autonomous Hamiltonian integration

[1]:
# This example illustrates integration of a non-autonomous Hamiltonian
# Such integration has only limited support, since the function iteration tools do not carry time
# Thus, only a single second-order integration step can be performed, and time must be advanced manually after each step, e.g. using a plain Python loop or a custom scan body

# Supporting the more general case would require modifying the function iteration tools; for example, instead of the following loop:
# for _ in range(n): x = f(x, *args)
# the nesting should correspond to:
# for _ in range(n): x = f(x, dt, t, *args) ; t = t + dt
# Similarly, fold (and other functions) would have to be modified to carry time

# Alternatively, integrate in the extended phase space (time as an additional coordinate) with the midpoint or tao integrators
[2]:
# Import

import jax
from jax import Array
from jax import jit
from jax import vmap

from sympint import fold
from sympint import nest
from sympint import midpoint
from sympint import sequence

# Print 12 digits after the decimal point on lines up to 256 characters wide
jax.numpy.set_printoptions(precision=12, linewidth=256)
[3]:
# Set data type: enable 64-bit floats so arrays created in later cells default to float64
ENABLE_X64 = True
jax.config.update("jax_enable_x64", ENABLE_X64)
[4]:
# Set device: make the first CPU device the default device for JAX operations.
# Index the device list instead of star-unpacking into `_`: in IPython a user binding of `_`
# shadows the output-history variable and stops `_`/`__`/`___` from being updated.
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[5]:
# Set parameters as 0-d arrays: initial time si, time step ds, potential amplitude kn
si, ds, kn = (jax.numpy.array(value) for value in (0.0, 0.01, 1.0))
[6]:
# Set initial condition: q = (0.1, 0.1), p = (0, 0), packed as x = (q_x, q_y, p_x, p_y)
qs = jax.numpy.array([0.1] * 2)
ps = jax.numpy.zeros_like(qs)
x = jax.numpy.concatenate((qs, ps))
[7]:
# Define hamiltonian

def hamiltonian(qs, ps, s, kn, *args):
    """Time-dependent isotropic quadratic Hamiltonian.

    H(q, p, s) = 1/2 |p|^2 + 1/2 kn (1 + cos s) |q|^2

    Works for any number of degrees of freedom; with two coordinates
    (q_x, q_y) it reproduces the 2D form term by term.

    Parameters
    ----------
    qs : array-like
        Coordinates (q_x, q_y, ...).
    ps : array-like
        Conjugate momenta, same length as ``qs``.
    s : scalar
        Independent variable (time); it enters only through ``cos(s)``,
        which is what makes H non-autonomous.
    kn : scalar
        Amplitude of the quadratic potential.
    *args
        Extra arguments, accepted and ignored.

    Returns
    -------
    scalar
        Value of the Hamiltonian.
    """
    # asarray lets plain sequences (e.g. lists) be passed as well as arrays
    qs = jax.numpy.asarray(qs)
    ps = jax.numpy.asarray(ps)
    # Keep the original association order so the 2-DOF value is unchanged
    return 1/2*jax.numpy.sum(ps**2) + 1/2*kn*(1 + jax.numpy.cos(s))*jax.numpy.sum(qs**2)
[8]:
# Set implicit midpoint integration step for the non-autonomous hamiltonian.
# The compiled step is called as integrator(x, dt, t, kn); it does not advance t itself.
# NOTE(review): the meaning of `ns` (here 2**4) and of sequence's (0, 0, ..., merge=False)
# arguments comes from sympint and is not visible here — confirm against sympint docs.
midpoint_step = midpoint(hamiltonian, ns=2**4)
integrator = jit(fold(sequence(0, 0, [midpoint_step], merge=False)))
[9]:
# Perform integration with explicit time update: the integrator does not advance time,
# so the current time is carried in `time` and incremented by ds after every step.
# The loop index is named (not `_`) so IPython's `_` output-history variable is not rebound.
time = si
data = x
for step in range(10**2):
    data = integrator(data, ds, time, kn)
    time = time + ds
print(data)
[ 0.017983795895  0.017983795895 -0.133154567382 -0.133154567382]
[10]:
# Define hamiltonian (extended)

def extended(qs, ps, s, kn, *args):
    """Autonomous extended-phase-space form of the time-dependent hamiltonian.

    Time is promoted to the last coordinate q_t, with conjugate momentum p_t:

        K(q, q_t, p, p_t) = p_t + 1/2 |p|^2 + 1/2 kn (1 + cos q_t) |q|^2

    Since dK/dp_t = 1, Hamilton's equations advance q_t exactly with the
    integration variable, so the integrator no longer needs to carry time.

    Parameters
    ----------
    qs : array-like
        Coordinates (q_x, q_y, ..., q_t); the last entry is the time coordinate.
    ps : array-like
        Momenta (p_x, p_y, ..., p_t); the last entry is conjugate to q_t.
    s : scalar
        Integrator time argument; unused, K does not depend on it.
    kn : scalar
        Amplitude of the quadratic potential.
    *args
        Extra arguments, accepted and ignored.

    Returns
    -------
    scalar
        Value of the extended Hamiltonian.
    """
    # asarray lets plain sequences (e.g. lists) be passed as well as arrays
    qs = jax.numpy.asarray(qs)
    ps = jax.numpy.asarray(ps)
    # Split off the time coordinate and its conjugate momentum (last entries)
    q, q_t = qs[:-1], qs[-1]
    p, p_t = ps[:-1], ps[-1]
    # Same association order as the explicit 2-DOF formula: (p_t + T) + V
    return p_t + 1/2*jax.numpy.sum(p**2) + 1/2*kn*(1 + jax.numpy.cos(q_t))*jax.numpy.sum(q**2)
[11]:
# Set extended initial condition: append the time coordinate q_t = si to the coordinates
# and p_t = -H(qs, ps, si) to the momenta, so the extended hamiltonian starts at zero
Qs = jax.numpy.append(qs, si)
Ps = jax.numpy.append(ps, -hamiltonian(qs, ps, si, kn))
X = jax.numpy.concatenate([Qs, Ps])
[12]:
# Set implicit midpoint integration step using extended hamiltonian.
# Rebinds `integrator`; time now lives in the q_t coordinate of the state.
# NOTE(review): `ns` and the sequence(0, 0, ..., merge=False) arguments are sympint
# parameters whose exact meaning is not visible here — confirm against sympint docs.
extended_step = midpoint(extended, ns=2**4)
integrator = jit(fold(sequence(0, 0, [extended_step], merge=False)))
[13]:
# Set and compile element: repeat the extended step 10**2 times inside one jitted function.
# `si` only fills the time slot of the step signature; the extended hamiltonian ignores
# its `s` argument, and time is carried in the q_t coordinate instead.
n_steps = 10**2
element = jit(nest(n_steps, integrator))
out = element(X, ds, si, kn)
print(out)
[ 0.017983795895  0.017983795895  1.             -0.133154567382 -0.133154567382 -0.018228323463]
[ ]: