Example-06: Non-autonomous hamiltonian integration
[1]:
# This example illustrates integration of a non-autonomous (explicitly time-dependent) hamiltonian
# Such integration has only limited support, since the function iteration tools do not carry time
# Thus, only one second order integration step can be performed at a time, and time should be adjusted manually after each step, i.e. using a normal python loop or a custom scan body
# Support for the more general case would require modifying the function iteration tools, for example, instead of the following loop:
# for _ in range(n): x = f(x, *args)
# nesting should correspond to:
# for _ in range(n): x = f(x, dt, t, *args) ; t = t + dt
# Similarly, fold (and other functions) should be modified to carry time
# Instead, it is possible to use an extended phase space with the midpoint or tao integrators
[2]:
# Import
import jax
from jax import Array
from jax import jit
from jax import vmap
from sympint import fold
from sympint import nest
from sympint import midpoint
from sympint import sequence
# Print arrays with up to 256 characters per line and 12 digits of precision
jax.numpy.set_printoptions(linewidth=256, precision=12)
[3]:
# Set data type
# Enable 64-bit (double precision) types; this should be done before any arrays are created
jax.config.update("jax_enable_x64", True)
[4]:
# Set device
# Take the first CPU device and register it as the default device for subsequent computations
device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[5]:
# Set parameters
# si -- initial value of the independent variable s (the "time" the hamiltonian depends on)
# ds -- integration step length in s
# kn -- strength multiplying the time-dependent quadratic potential term of the hamiltonian
si = jax.numpy.array(0.0)
ds = jax.numpy.array(0.01)
kn = jax.numpy.array(1.0)
[6]:
# Set initial condition
# Coordinates (q_x, q_y) start displaced, momenta (p_x, p_y) start at zero
qs = jax.numpy.array([0.1, 0.1])
ps = jax.numpy.zeros_like(qs)
# Full phase space state is ordered as (q_x, q_y, p_x, p_y)
x = jax.numpy.concatenate([qs, ps])
[7]:
# Define hamiltonian
def hamiltonian(qs: Array, ps: Array, s: Array, kn: Array, *args) -> Array:
    """
    Non-autonomous hamiltonian with a time-modulated quadratic potential

    H(q, p; s) = 1/2 (p_x^2 + p_y^2) + 1/2 kn (1 + cos(s)) (q_x^2 + q_y^2)

    Parameters
    ----------
    qs: Array
        coordinates (q_x, q_y)
    ps: Array
        momenta (p_x, p_y)
    s: Array
        independent variable (time) entering the potential through cos(s)
    kn: Array
        strength of the quadratic potential term
    *args:
        additional arguments (accepted for signature compatibility, not used)

    Returns
    -------
    Array
        hamiltonian value

    """
    q_x, q_y = qs
    p_x, p_y = ps
    kinetic = 1/2*(p_x**2 + p_y**2)
    # The potential strength oscillates between 0 and 2*kn as s advances
    potential = 1/2*kn*(1 + jax.numpy.cos(s))*(q_x**2 + q_y**2)
    return kinetic + potential
[8]:
# Set implicit midpoint integration step
# NOTE(review): fold, sequence and midpoint come from sympint; their exact semantics are not visible here
# midpoint(hamiltonian, ns=2**4) presumably builds an implicit midpoint step for the given hamiltonian -- TODO confirm the meaning of ns
# The jit-compiled result is called as integrator(state, ds, s, kn), with the current time s passed explicitly
integrator = jit(fold(sequence(0, 0, [midpoint(hamiltonian, ns=2**4)], merge=False)))
[9]:
# Perform integration with explicit time update
# The iteration tools do not carry time, so the current value of s is advanced by hand after every step
time, data = si, x
for _ in range(100):
    # One integration step of length ds starting from the current time
    data = integrator(data, ds, time, kn)
    time = time + ds
print(data)
[ 0.017983795895 0.017983795895 -0.133154567382 -0.133154567382]
[10]:
# Define hamiltonian (extended)
def extended(qs: Array, ps: Array, s: Array, kn: Array, *args) -> Array:
    """
    Autonomous hamiltonian on the extended phase space

    Time is promoted to a coordinate q_t with conjugate momentum p_t:

    K(q, q_t, p, p_t) = p_t + 1/2 (p_x^2 + p_y^2) + 1/2 kn (1 + cos(q_t)) (q_x^2 + q_y^2)

    Parameters
    ----------
    qs: Array
        coordinates (q_x, q_y, q_t), where q_t plays the role of time
    ps: Array
        momenta (p_x, p_y, p_t)
    s: Array
        independent variable (kept for signature compatibility, not used since time enters through q_t)
    kn: Array
        strength of the quadratic potential term
    *args:
        additional arguments (accepted for signature compatibility, not used)

    Returns
    -------
    Array
        extended hamiltonian value

    """
    q_x, q_y, q_t = qs
    p_x, p_y, p_t = ps
    kinetic = 1/2*(p_x**2 + p_y**2)
    # Same potential as in the non-autonomous case, with the time argument replaced by q_t
    potential = 1/2*kn*(1 + jax.numpy.cos(q_t))*(q_x**2 + q_y**2)
    return p_t + kinetic + potential
[11]:
# Set extended initial condition
# The time coordinate starts at si and is appended to the original coordinates
Qs = jax.numpy.concatenate([qs, jax.numpy.reshape(si, -1)])
# The momentum conjugate to time starts at minus the initial value of the original hamiltonian
Ps = jax.numpy.concatenate([ps, jax.numpy.reshape(-hamiltonian(qs, ps, si, kn), -1)])
# Extended state is ordered as (q_x, q_y, q_t, p_x, p_y, p_t)
X = jax.numpy.concatenate([Qs, Ps])
[12]:
# Set implicit midpoint integration step using extended hamiltonian
# Same construction as for the original hamiltonian, only the hamiltonian function differs
# NOTE: this rebinds the name `integrator`; re-running the explicit time update loop after this cell would call the extended integrator
integrator = jit(fold(sequence(0, 0, [midpoint(extended, ns=2**4)], merge=False)))
[13]:
# Set and compile element
# NOTE(review): nest(10**2, integrator) presumably composes 100 applications of integrator into one callable -- TODO confirm against sympint documentation
# No manual time update is performed here, since time is carried inside the extended state as q_t
# si is still passed in the time slot of the call, but the extended hamiltonian does not read its s argument
element = jit(nest(10**2, integrator))
out = element(X, ds, si, kn)
print(out)
[ 0.017983795895 0.017983795895 1. -0.133154567382 -0.133154567382 -0.018228323463]
[ ]: