Example-11: FLI

[1]:
# This example demonstrates a basic application of the FLI (fast Lyapunov indicator)
[2]:
# Import

import numpy

from tqdm import tqdm

import jax
from jax import jit
from jax import vmap

# Test symplectic mapping

from tohubohu.util import forward4D

# FLI factory

from tohubohu import fli

# Plotting

from matplotlib import pyplot as plt
from matplotlib import colormaps

# Colormap used by all figures below; invalid (masked/NaN) pixels are drawn in light gray

cmap = colormaps['viridis']
cmap.set_bad(color='lightgray')
[3]:
# Set data type
# Turn on the JAX 'jax_enable_x64' flag so that 64-bit types can be used in the cells below
# NOTE(review): this flag is expected to be set before any JAX arrays are created -- confirm

jax.config.update("jax_enable_x64", True)
[4]:
# Set device
# Prefer the first CUDA device; fall back to the first device of the default backend
# so that the notebook still runs (slower) on machines without a usable CUDA backend

try:
    device, *_ = jax.devices('cuda')
except RuntimeError:
    # jax.devices raises RuntimeError if the requested backend is unknown or fails to initialize
    device, *_ = jax.devices()
    print(f'CUDA backend is not available, using {device} instead')

# Place subsequent computations on the selected device by default

jax.config.update('jax_default_device', device)
[5]:
# Set mapping parameters

nux, nuy = 0.168, 0.201

# Angles mu = 2*pi*nu for each plane

mux = 2*jax.numpy.pi*nux
muy = 2*jax.numpy.pi*nuy

# Cosines and sines of the angles

cx, sx = jax.numpy.cos(mux), jax.numpy.sin(mux)
cy, sy = jax.numpy.cos(muy), jax.numpy.sin(muy)

mu = 0.0

# Parameter vector passed to the mapping: [cos(mux), sin(mux), cos(muy), sin(muy), mu]

k = jax.numpy.asarray([cx, sx, cy, sy, mu])
[6]:
# Set and compile indicator (without normalization)

# Test initial condition and the vector passed as the second indicator argument
# NOTE(review): v is presumably the initial deviation (tangent) vector -- confirm against fli

x = jax.numpy.array([0.25, 0.25, 0.0, 0.0])
v = jax.numpy.array([1.0, 0.0, 0.0, 0.0])

@jit
def fn(x):
    """Evaluate the indicator (normalize=False) for a single initial condition x.

    The indicator is built with fli(2**12, forward4D, normalize=False) and applied
    to (x, v, k), where v and k (mapping parameters) are read from the notebook namespace.
    """
    indicator = fli(2**12, forward4D, normalize=False)
    return indicator(x, v, k)

# The first call triggers tracing and compilation

out = fn(x)
[7]:
# Set initial grid in (qx, qy) plane

n = 1001

qx = jax.numpy.linspace(0.0, 0.6, n)
qy = jax.numpy.linspace(0.0, 0.6, n)

# Flattened list of (qx, qy) pairs with qx varying fastest, i.e. row j*n + i holds (qx[i], qy[j])

qs = jax.numpy.stack(jax.numpy.meshgrid(qx, qy, indexing='xy'), axis=-1).reshape(-1, 2)

# Remaining two coordinates are set to a small constant value

ps = jax.numpy.full_like(qs, 1.0E-12)

# Initial conditions as rows [qx, qy, 1.0E-12, 1.0E-12], split into n batches of n rows each

xs = jax.numpy.concatenate([qs, ps], axis=1)
xs = jax.numpy.array_split(xs, n)
[8]:
# Evaluate indicator

# Vectorize the indicator over a batch of initial conditions

fj = jit(vmap(fn))

# The first batch is evaluated on its own, outside of the progress bar
# (presumably so that compilation time is not counted in the reported rate)

xb, *xr = xs
out = [fj(xb)]

# Remaining batches

for xb in tqdm(xr):
    out += [fj(xb)]

# Merge batch results into a single flat array

out = jax.numpy.concatenate(out)

# Winsorize data
# Values are limited to the [0, 16] interval (NaN entries pass through unchanged)
# and the flat result is arranged on the (n, n) grid

data = numpy.clip(numpy.array(out), 0.0, 16.0).reshape(n, n)

# Plot
# Rows of data are drawn along the vertical axis (origin at the bottom)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(
    data,
    cmap=cmap,
    vmin=0.0,
    vmax=16.0,
    origin='lower',
    aspect='equal',
    interpolation='nearest',
    extent=(0., 0.6, 0., 0.6),
)
ax.set_xlabel('qx')
ax.set_ylabel('qy')
fig.tight_layout()
plt.show()
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:32<00:00, 10.79it/s]
../_images/examples_example-10_8_1.png
[9]:
# Set and compile indicator (with normalization)

# Test initial condition and the vector passed as the second indicator argument
# NOTE(review): v is presumably the initial deviation (tangent) vector -- confirm against fli

x = jax.numpy.array([0.25, 0.25, 0.0, 0.0])
v = jax.numpy.array([1.0, 0.0, 0.0, 0.0])

@jit
def fn(x):
    """Evaluate the indicator (normalize=True) for a single initial condition x.

    The indicator is built with fli(2**12, forward4D, normalize=True) and applied
    to (x, v, k), where v and k (mapping parameters) are read from the notebook namespace.
    """
    indicator = fli(2**12, forward4D, normalize=True)
    return indicator(x, v, k)

# The first call triggers tracing and compilation

out = fn(x)
[10]:
# Set initial grid in (qx, qy) plane

n = 1001

qx = jax.numpy.linspace(0.0, 0.6, n)
qy = jax.numpy.linspace(0.0, 0.6, n)

# Flattened list of (qx, qy) pairs with qx varying fastest, i.e. row j*n + i holds (qx[i], qy[j])

qs = jax.numpy.stack(jax.numpy.meshgrid(qx, qy, indexing='xy'), axis=-1).reshape(-1, 2)

# Remaining two coordinates are set to a small constant value

ps = jax.numpy.full_like(qs, 1.0E-12)

# Initial conditions as rows [qx, qy, 1.0E-12, 1.0E-12], split into n batches of n rows each

xs = jax.numpy.concatenate([qs, ps], axis=1)
xs = jax.numpy.array_split(xs, n)
[11]:
# Evaluate indicator

# Vectorize the indicator over a batch of initial conditions

fj = jit(vmap(fn))

# The first batch is evaluated on its own, outside of the progress bar
# (presumably so that compilation time is not counted in the reported rate)

xb, *xr = xs
out = [fj(xb)]

# Remaining batches

for xb in tqdm(xr):
    out += [fj(xb)]

# Merge batch results into a single flat array

out = jax.numpy.concatenate(out)

# Convert to numpy and arrange on the (n, n) grid
# No winsorization is applied here; the displayed range is set by vmin/vmax in the plot

data = numpy.array(out).reshape(n, n)

# Plot
# Rows of data are drawn along the vertical axis (origin at the bottom)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(
    data,
    cmap=cmap,
    vmin=0.0,
    vmax=0.005,
    origin='lower',
    aspect='equal',
    interpolation='nearest',
    extent=(0., 0.6, 0., 0.6),
)
ax.set_xlabel('qx')
ax.set_ylabel('qy')
fig.tight_layout()
plt.show()
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:44<00:00,  9.57it/s]
../_images/examples_example-10_11_1.png