Example-01: Functional iteration

[1]:
# In this example, the usage of the nest and fold function factories is illustrated
[2]:
# Import

import jax

from sympint import nest
from sympint import nest_list

from sympint import fold
from sympint import fold_list
[3]:
# Set data type
# Enable double precision (float64) globally; JAX defaults to float32
# NOTE(review): this flag should be set before any arrays are created so it applies everywhere

jax.config.update("jax_enable_x64", True)
[4]:
# Set device
# Pin all computations to the first available CPU device

device, *_ = jax.devices('cpu')
jax.config.update('jax_default_device', device)
[5]:
# Define a simple symplectic mapping

a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

def fn(x, a, b):
    q, p, *_ = x
    return jax.numpy.stack([p, -q + a*p + b*p**2])
[6]:
%%timeit
fn(x, a, b)
193 µs ± 6.88 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
[7]:
# The above mapping is compatible (composable) with JAX functions
# In particular, jit can be used to speed it up; this is useful for more complicated mappings
[8]:
# Wrap and compile
# Once compiled, the resulting function can be used efficiently with different inputs
# jax.jit traces fn on the first call and caches the compiled version;
# subsequent calls with matching shapes/dtypes reuse the cached executable

fj = jax.jit(fn)
fj(x, a, b)
[8]:
Array([ 0. , -0.1], dtype=float64)
[9]:
%%timeit

fj(x, a, b)
5.16 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[10]:
# A common task is to iterate a given mapping repeatedly
# Normally, this can be done with a regular Python loop
# But Python loops are known to be slow and can't be compiled (without unrolling)
[11]:
%%timeit

# Baseline: apply the (uncompiled) mapping 2**6 = 64 times with a plain Python loop
a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

for _ in range(2**6):
    x = fn(x, a, b)
12.7 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[12]:
# jax.lax provides several constructs for efficient compilable looping, e.g. jax.lax.scan
# While a regular for loop will be unrolled, resulting in impractical compilation times for a large number of iterations, scan avoids this
# For repeated mapping application, the nest function can be used, which is a wrapper around jax.lax.scan
[13]:
# Wrap and compile
# Once compiled, the resulting function can be used efficiently with different inputs
# nest(n, fn) returns a function that applies fn n times (a jax.lax.scan wrapper, per sympint)

a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

fj = jax.jit(nest(2**6, fn))
fj(x, a, b)
[13]:
Array([-0.09687625, -0.05042709], dtype=float64)
[14]:
%%timeit
fj(x, a, b)
9.18 µs ± 87.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[15]:
# jax.lax.scan also allows accumulating intermediate results
# For mappings, the nest_list function accumulates the output at each iteration (excluding the initial value)
[16]:
# nest_list(n, fn) returns a function that applies fn n times and stacks the per-iteration outputs
fj = jax.jit(nest_list(2**6, fn))

a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

# Expected shape: (iterations, state dimension) = (64, 2)
fj(x, a, b).shape
[16]:
(64, 2)
[17]:
%%timeit
fj(x, a, b)
11.3 µs ± 72.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[18]:
# The fold function allows applying a sequence of mappings
# While the mappings can differ, an identical call signature is assumed
# Here the sequence is 64 copies of the same mapping, equivalent to nest(2**6, fn)

fj = jax.jit(fold(2**6*[fn]))

a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

fj(x, a, b)
[18]:
Array([-0.09687625, -0.05042709], dtype=float64)
[19]:
%%timeit
fj(x, a, b)
13.4 µs ± 46.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[20]:
# fold with accumulation is also available
# fold_list stacks the output after each mapping in the sequence (cf. nest_list)

fj = jax.jit(fold_list(2**6*[fn]))

a = jax.numpy.array(0.5)
b = jax.numpy.array(1.0)
x = jax.numpy.array([0.1, 0.0])

# Expected shape: (number of mappings, state dimension) = (64, 2)
fj(x, a, b).shape
[20]:
(64, 2)
[21]:
%%timeit
fj(x, a, b)
15.6 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[22]:
# Other JAX functions, like vmap and grad can be applied to the results of nest and fold functions
[23]:
# Vectorized map
# in_axes=(0, None, None): map over the leading axis of xs, broadcast a and b
# Result shape: (batch of initials, iterations, state dimension) = (6, 64, 2)

xs = jax.numpy.array([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [0.3, 0.0], [0.4, 0.0], [0.5, 0.0]])
jax.vmap(fj, (0, None, None))(xs, a, b).shape
[23]:
(6, 64, 2)
[24]:
# Jacobian
# Reverse-mode Jacobian of the accumulated trajectory with respect to the initial state x
# Result shape: (iterations, state dimension, state dimension) = (64, 2, 2)

jax.jacrev(fj)(x, a, b).shape
[24]:
(64, 2, 2)
[ ]: