Example-02: Yoshida composition

[1]:
# Given a time-reversible integration step of difference order 2n,
# the Yoshida composition procedure can be used to construct an integration step of difference order 2(n+1)
# Using the Yoshida coefficients x1, x2, the new integration step is w(2(n+1))(dt) = w(2n)(x1 dt) o w(2n)(x2 dt) o w(2n)(x1 dt)

# If a Hamiltonian vector field can be split into several solvable parts,
# a second-order time-reversible (symmetric) integration step can be constructed as follows
# w1(dt/2) o w2(dt/2) o ... o wn(dt/2) o wn(dt/2) o ... o w2(dt/2) o w1(dt/2)
# where each wi is the mapping (flow) of the corresponding Hamiltonian part
# The Yoshida composition procedure can then be applied repeatedly to obtain higher order integration steps
[2]:
# Import

import jax

# Function iterations

from sympint import fold
from sympint import nest
from sympint import nest_list

# Yoshida composition

from sympint import weights
from sympint import coefficients
from sympint import table
from sympint import sequence
[3]:
# Set data type: enable 64-bit (double precision) floating point arrays in JAX

jax.config.update("jax_enable_x64", True)
[4]:
# Set device: pin JAX computations to the first CPU device
# Index the device list instead of star-unpacking into `_`: assigning `_` in a notebook
# shadows IPython's last-output variable and stops `_`, `__`, `___` from being updated

device = jax.devices('cpu')[0]
jax.config.update('jax_default_device', device)
[5]:
# Yoshida weights lift a step of difference order 2n to difference order 2(n+1)
# Shown for n = 1..4, i.e. 2 -> 4, 4 -> 6, 6 -> 8 and 8 -> 10
# Note, the weights of each triple sum to one

for n in range(1, 5):
    print([f'{weight:.3f}' for weight in weights(n)])
['1.351', '-1.702', '1.351']
['1.175', '-1.349', '1.175']
['1.116', '-1.232', '1.116']
['1.087', '-1.174', '1.087']
[6]:
# coefficients(n, m) gives the Yoshida coefficients that lift a step of difference order 2n
# to difference order 2(m+1): (1, 1) is 2 -> 4, (1, 2) is 2 -> 6, (2, 2) is 4 -> 6
# Note, the coefficients sum to one

for n, m in [(1, 1), (1, 2), (2, 2)]:
    print([f'{coefficient:.3f}' for coefficient in coefficients(n, m)])
['1.351', '-1.702', '1.351']
['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']
['1.175', '-1.349', '1.175']
[7]:
# table(k, n, m, merge) computes the Yoshida table for a collection of k mappings:
# ns holds the index of the mapping applied at each substep and cs the matching time step fraction
# The final difference order is 2(m+1); n = 0 starts from the symmetric 2nd order composition
# of the mappings, n >= 1 starts from a given integration step of difference order 2n
# Note, a mapping can be an integration step; in that case the last argument (merge) should be False

for args in [(1, 1, 1, False), (1, 1, 2, False), (1, 2, 2, False)]:  # 2 -> 4, 2 -> 6, 4 -> 6
    ns, cs = table(*args)
    print([ns, [f'{c:.3f}' for c in cs]])
print()

# Construct tables from two mappings (the number of mappings can be arbitrary),
# first without, then with merging of adjacent identical mappings; m = 0 is 2 -> 2, m = 1 is 2 -> 4

for merge in (False, True):
    for m in (0, 1):
        ns, cs = table(2, 0, m, merge)
        print([ns, [f'{c:.3f}' for c in cs]])
    print()
[[0, 0, 0], ['1.351', '-1.702', '1.351']]
[[0, 0, 0, 0, 0, 0, 0, 0, 0], ['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']]
[[0, 0, 0], ['1.175', '-1.349', '1.175']]

[[0, 1, 0], ['0.500', '1.000', '0.500']]
[[0, 1, 0, 0, 1, 0, 0, 1, 0], ['0.676', '1.351', '0.676', '-0.851', '-1.702', '-0.851', '0.676', '1.351', '0.676']]

[[0, 1, 0], ['0.500', '1.000', '0.500']]
[[0, 1, 0, 1, 0, 1, 0], ['0.676', '1.351', '-0.176', '-1.702', '-0.176', '1.351', '0.676']]

[8]:
# Construct an integration step for a simple rotation Hamiltonian

# H = H1 + H2
# H1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]
# H2 = 1/2 p**2 -> [q, p] -> [q + t*p, p]

# Set mappings

def fn(x, dt):
    """Flow of H1 = q**2/2 over time dt: kick the momentum by -dt*q, keep the position."""
    position, momentum = x
    kicked = momentum - dt*position
    return jax.numpy.stack([position, kicked])

def gn(x, dt):
    """Flow of H2 = p**2/2 over time dt: drift the position by dt*p, keep the momentum."""
    position, momentum = x
    drifted = position + dt*momentum
    return jax.numpy.stack([drifted, momentum])

# Set time step and initial condition

dt = jax.numpy.array(0.25)
x = jax.numpy.array([0.1, 0.1])

# For final Yoshida orders 0..4, generate the sequence of transformations and print its length
# together with the folded one-step result: first without merging adjacent identical mappings,
# then with merging (fewer transformations, same result)

for merge in (False, True):
    for i in range(5):
        fns = sequence(0, i, [fn, gn], merge=merge)
        print(f'{len(fns):>3} {fold(fns)(x, dt)}')
    print()
  3 [0.121875   0.07226562]
  9 [0.12162308 0.07215494]
 27 [0.12163202 0.07215096]
 81 [0.12163163 0.07215085]
243 [0.12163164 0.07215085]

  3 [0.121875   0.07226562]
  7 [0.12162308 0.07215494]
 19 [0.12163202 0.07215096]
 55 [0.12163163 0.07215085]
163 [0.12163164 0.07215085]

[9]:
%%timeit

# Even with merging, scanning through a large number of mappings is slow without compilation
# Here fns is the last sequence built in the previous cell (merged, final Yoshida order 4)
# Note, fold is a wrapper around jax.lax.scan

fold(fns)(x, dt)
1.25 s ± 69.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[10]:
# This can be remedied by JAX jit compilation
# The first execution will still be slow (it includes the compilation step)

fj = jax.jit(fold(fns))
fj(x, dt)
[10]:
Array([0.12163164, 0.07215085], dtype=float64)
[11]:
%%timeit

# The compiled transformation is expected to be much faster
# If the intent is to reuse it with many different initial conditions, jit compilation is the way to go

fj(x, dt)
28.1 µs ± 256 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
[12]:
%%timeit

# The compiled step is compatible with JAX transformations
# For example, it is possible to compute the jacobian with respect to the initial condition
# Note, jacrev itself is not wrapped in jax.jit here, and this might trigger a recompile

jax.jacrev(fj)(x, dt)
1.86 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
[13]:
# Add one more layer of JIT: compile the reverse-mode jacobian of the compiled step
# This first call triggers the compilation and displays the 2x2 jacobian with respect to x

jacobian = jax.jit(jax.jacrev(fj))
jacobian(x, dt)
[13]:
Array([[ 0.96891242,  0.24740395],
       [-0.24740396,  0.96891242]], dtype=float64)
[14]:
%%timeit

# Time the compiled jacobian (it was already compiled by the call in the previous cell)

jacobian(x, dt)
81.5 µs ± 2.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
[15]:
%%timeit

# Perform several integration steps (native Python loop over steps and step parts)
# 64 steps of the merged sequence with final Yoshida order 5; nothing here is jitted
# Note, building the sequence is also part of the timed code

fs = sequence(0, 5, [fn, gn], merge=True)

dt = jax.numpy.array(0.25)
x = jax.numpy.array([0.1, 0.1])

for _ in range(64):
    for f in fs:
        x = f(x, dt)
4.68 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[16]:
# Compile a single step, then 64 nested steps

dt = jax.numpy.array(0.25)
x = jax.numpy.array([0.1, 0.1])

# One step: merged sequence with final Yoshida order 5, folded and jitted
# The call below only triggers compilation; its output is not needed
fs = sequence(0, 5, [fn, gn], merge=True)
fj = jax.jit(fold(fs))
fj(x, dt) ;

# Repeat the compiled step 64 times (presumably nest composes fj with itself; verify in sympint.nest)
# and compile the whole composition; the warm-up call keeps compilation out of the next cell's timing
fj = nest(64, fj)
fj = jax.jit(fj)
fj(x, dt) ;
[17]:
%%timeit

# Test 64 compiled steps (compilation time is excluded: the previous cell already called fj once)

fj(x, dt)
3.97 ms ± 51.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[18]:
# Set up an integration step from a Hamiltonian split into several solvable parts

# H = H1 + H2
# H1 = 1/2 q**2 + 1/3 q**3 -> [q, p] -> [q, p - t*q - t*q**2]
# H2 = 1/2 p**2            -> [q, p] -> [q + t*p, p]

# Time step and initial condition [q, p]
dt, x = jax.numpy.array(0.1), jax.numpy.array([0.1, 0.1])

def fn(x, t):
    """Flow of H1 = q**2/2 + q**3/3 over time t: kick the momentum by -t*q - t*q**2."""
    position, momentum = x
    kicked = momentum - t*position - t*position**2
    return jax.numpy.stack([position, kicked])

def gn(x, t):
    """Flow of H2 = p**2/2 over time t: drift the position by t*p."""
    position, momentum = x
    drifted = position + t*momentum
    return jax.numpy.stack([drifted, momentum])

# 4th order step (final Yoshida order 1) composed from the two parts; print the result of one step
fs = sequence(0, 1, [fn, gn], merge=True)
print(fold(fs)(x, dt))

# H = H1 + H2 + H3
# H1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]
# H2 = 1/3 q**3 -> [q, p] -> [q, p - t*q**2]
# H3 = 1/2 p**2 -> [q, p] -> [q + t*p, p]

def fn(x, t):
    """Flow of H1 = q**2/2 over time t: kick the momentum by -t*q."""
    position, momentum = x
    kicked = momentum - t*position
    return jax.numpy.stack([position, kicked])

def gn(x, t):
    """Flow of H2 = q**3/3 over time t: kick the momentum by -t*q**2."""
    position, momentum = x
    kicked = momentum - t*position**2
    return jax.numpy.stack([position, kicked])

def hn(x, t):
    """Flow of H3 = p**2/2 over time t: drift the position by t*p."""
    position, momentum = x
    drifted = position + t*momentum
    return jax.numpy.stack([drifted, momentum])

# Note, with merging, the last mapping in the list has the smallest number of evaluations

fs = sequence(0, 1, [fn, gn, hn], merge=True)
print(fold(fs)(x, dt))

# Note, the result is identical to the two-part split above since the two kick parts (fn and gn) commute
[0.10943036 0.08841961]
[0.10943036 0.08841961]
[19]:
# Increase the order of an existing integration step

# Set time step and initial condition

dt, x = jax.numpy.array(0.1), jax.numpy.array([0.1, 0.1])

# Set transformations for the solvable parts

def fn(x, t):
    """Flow of H1 = q**2/2 + q**3/3 over time t: kick the momentum by -t*q - t*q**2."""
    position, momentum = x
    kicked = momentum - t*position - t*position**2
    return jax.numpy.stack([position, kicked])

def gn(x, t):
    """Flow of H2 = p**2/2 over time t: drift the position by t*p."""
    position, momentum = x
    drifted = position + t*momentum
    return jax.numpy.stack([drifted, momentum])

# Build 2nd, 4th and 6th order integration steps directly from the parts
# (final Yoshida orders 0, 1 and 2)

s2, s4, s6 = (fold(sequence(0, order, [fn, gn], merge=True)) for order in (0, 1, 2))

# Lift the 2nd order step to 4th order and compare with the 4th order step built from parts
# (the mapping is an integration step, so merging is disabled)

w4 = fold(sequence(1, 1, [s2], merge=False))
print(jax.numpy.allclose(s4(x, dt), w4(x, dt)))

# Lift the 4th order step to 6th order and compare

w6 = fold(sequence(2, 2, [s4], merge=False))
print(jax.numpy.allclose(s6(x, dt), w6(x, dt)))

# Lift the 2nd order step directly to 6th order and compare (w6 is rebound)

w6 = fold(sequence(1, 2, [s2], merge=False))
print(jax.numpy.allclose(s6(x, dt), w6(x, dt)))
True
True
True
[20]:
# Pass fixed parameters to individual mappings through the `parameters` argument of sequence (one list per mapping)

def fn(x, t, a, b):
    """Kick over time t with force -a*q - b*q**2 (H1 = a*q**2/2 + b*q**3/3); a, b are parameters."""
    position, momentum = x
    kicked = momentum - a*t*position - b*t*position**2
    return jax.numpy.stack([position, kicked])

def gn(x, t):
    """Flow of H2 = p**2/2 over time t: drift the position by t*p."""
    position, momentum = x
    drifted = position + t*momentum
    return jax.numpy.stack([drifted, momentum])

# Time step, initial condition and fixed parameter values
t, x = jax.numpy.array(0.1), jax.numpy.array([0.1, 0.1])
a, b = jax.numpy.array(1.0), jax.numpy.array(1.0)

# fn receives [a, b] as extra arguments and gn receives none, so the compiled step takes only (x, t)
fj = jax.jit(fold(sequence(0, 1, [fn, gn], merge=True, parameters=[[a, b], []])))

print(fj(x, t))
print()
[0.10943036 0.08841961]

[21]:
# Integration step with parameters passed at call time (all mappings share the signature (x, t, a, b))

def fn(x, t, a, b):
    """Kick over time t with force -a*q - b*q**2 (H1 = a*q**2/2 + b*q**3/3); a, b are parameters."""
    position, momentum = x
    kicked = momentum - a*t*position - b*t*position**2
    return jax.numpy.stack([position, kicked])

def gn(x, t, a, b):
    """Drift over time t (H2 = p**2/2); a and b are unused and only match fn's signature."""
    position, momentum = x
    drifted = position + t*momentum
    return jax.numpy.stack([drifted, momentum])

# Time step, initial condition and parameter values (passed explicitly on every call)
t, x = jax.numpy.array(0.1), jax.numpy.array([0.1, 0.1])
a, b = jax.numpy.array(1.0), jax.numpy.array(1.0)

# No parameters argument: the extra call arguments (a, b) go to the mappings, hence the matching signatures
fj = jax.jit(fold(sequence(0, 1, [fn, gn], merge=True)))

print(fj(x, t, a, b))
print()
[0.10943036 0.08841961]

[22]:
# Matched signatures allow computing derivatives with respect to any argument of the step:
# argnums 0, 1, 2 and 3 select x, t, a and b respectively

for argnum in range(4):
    print(jax.jacrev(fj, argnums=argnum)(x, t, a, b))
    print()
[[ 0.99397345  0.09979712]
 [-0.12071788  0.99394275]]

[ 0.08841321 -0.12140165]

[-0.0005159 -0.0104603]

[-5.32985126e-05 -1.09712073e-03]