Example-01: Derivative

[1]:
# Given an input function, its higher order (partial) derivatives with respect to one or several tensor arguments can be computed using forward or reverse mode automatic differentiation
# Derivative orders can be different for each tensor argument
# Input function is expected to return a tensor or a (nested) list of tensors

# Derivatives are computed by nesting torch jacobian functions (a sketch of the nesting idea is given at the end of this cell)
# For higher order derivatives, nesting results in exponentially growing redundant computations
# Note, forward mode is more memory efficient in this case

# If the input function returns a tensor, the output is referred to as the derivative table representation
# This representation can be evaluated near a given evaluation point (at a given deviation) if the input function returns a scalar or a vector
# Table representation is a (nested) list of tensors; it can be used as a redundant function representation near a given evaluation point (Taylor series)
# Table structure for f(x), f(x, y) and f(x, y, z) is shown below (a similar structure holds for functions with more arguments)

# f(x)
# t(f, x)
# [f, Dx f, Dxx f, ...]

# f(x, y)
# t(f, x, y)
# [
#     [    f,     Dy f,     Dyy f, ...],
#     [ Dx f,  Dx Dy f,  Dx Dyy f, ...],
#     [Dxx f, Dxx Dy f, Dxx Dyy f, ...],
#     ...
# ]

# f(x, y, z)
# t(f, x, y, z)
# [
#     [
#         [         f,          Dz f,          Dzz f, ...],
#         [      Dy f,       Dy Dz f,       Dy Dzz f, ...],
#         [     Dyy f,      Dyy Dz f,      Dyy Dzz f, ...],
#         ...
#     ],
#     [
#         [      Dx f,       Dx Dz f,       Dx Dzz f, ...],
#         [   Dx Dy f,    Dx Dy Dz f,    Dx Dy Dzz f, ...],
#         [  Dx Dyy f,   Dx Dyy Dz f,   Dx Dyy Dzz f, ...],
#         ...
#     ],
#     [
#         [    Dxx f,     Dxx Dz f,     Dxx Dzz f, ...],
#         [ Dxx Dy f,  Dxx Dy Dz f,  Dxx Dy Dzz f, ...],
#         [Dxx Dyy f, Dxx Dyy Dz f, Dxx Dyy Dzz f, ...],
#         ...
#     ],
#     ...
# ]
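
# As an illustration of the nesting idea (a sketch only, not the ndmap implementation),
# the top-left 2 x 2 corner of the f(x, y) table can be built directly with torch

import torch

def fn(x, y):
    return (x*y).sum()

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])

Dx = torch.func.jacfwd(fn, argnums=0)    # Dx f
Dy = torch.func.jacfwd(fn, argnums=1)    # Dy f
DxDy = torch.func.jacfwd(Dy, argnums=0)  # Dx Dy f

table = [[fn(x, y), Dy(x, y)],
         [Dx(x, y), DxDy(x, y)]]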
[2]:
# Import

import torch

from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.series import series

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Basic derivative interface

# derivative(
#     order:int,                             # derivative order
#     function:Callable,                     # input function
#     *args,                                 # function(*args) = function(x:Tensor, ...)
#     intermediate:bool = True,              # flag to return all intermediate derivatives
#     jacobian:Callable = torch.func.jacfwd  # torch.func.jacfwd or torch.func.jacrev
# )

# derivative(
#     order:tuple[int, ...],                 # derivative orders
#     function:Callable,                     # input function
#     *args,                                 # function(*args) = function(x:Tensor, y:Tensor, z:Tensor, ...)
#     intermediate:bool = True,              # flag to return all intermediate derivatives
#     jacobian:Callable = torch.func.jacfwd  # torch.func.jacfwd or torch.func.jacrev
# )
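
# For example (schematic; working calls are shown in the cells below):
#
# derivative(2, fn, x)          -> [f, Dx f, Dxx f]
# derivative((1, 1), fn, x, y)  -> [[f, Dy f], [Dx f, Dx Dy f]]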
[5]:
# Derivative

# Input:  scalar
# Output: scalar

# Set test function

# Note, the first function argument is a scalar tensor
# The input function can have additional arguments
# Derivatives are not computed with respect to these additional arguments

def fn(x, a, b, c, d, e, f):
    return a + b*x + c*x**2 + d*x**3 + e*x**4 + f*x**5

# Set derivative order

n = 5

# Set evaluation point

x = torch.tensor(0.0, dtype=dtype, device=device)

# Set fixed parameters

a, b, c, d, e, f = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=dtype, device=device)

# Compute n'th derivative

value = derivative(n, fn, x, a, b, c, d, e, f, intermediate=False, jacobian=torch.func.jacfwd)
print(value.cpu().numpy().tolist())

# Compute all derivatives up to the given order

# Note, the function value itself is referred to as the zero order derivative
# Since the function returns a tensor, the output is a list of tensors

values = derivative(n, fn, x, a, b, c, d, e, f, intermediate=True, jacobian=torch.func.jacfwd)
print(*[value.cpu().numpy().tolist() for value in values], sep=', ')

# Note, intermediate flag (default=True) can be used to return all derivatives
# For jacobian parameter, torch.func.jacfwd or torch.func.jacrev functions can be passed

# Evaluate derivative table representation for a given deviation from the evaluation point

dx = torch.tensor(1.0, dtype=dtype, device=device)
print(evaluate(derivative(n, fn, x, a, b, c, d, e, f), [dx]).cpu().numpy().tolist())
print(fn(x + dx, a, b, c, d, e, f).cpu().numpy().tolist())
120.0
1.0, 1.0, 2.0, 6.0, 24.0, 120.0
6.0
6.0
[6]:
# Derivative

# Input:  vector
# Output: scalar

# Set test function

# Note, the first function argument is a vector tensor
# The input function can have additional arguments
# Derivatives are not computed with respect to these additional arguments

def fn(x, a, b, c):
    x1, x2 = x
    return a + b*(x1 - 1)**2 + c*(x2 + 1)**2

# Set derivative order

n = 2

# Set evaluation point

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)

# Set fixed parameters

a, b, c = torch.tensor([1.0, 1.0, 1.0], dtype=dtype, device=device)

# Compute only n'th derivative

# Note, for the given input & output the result is the Hessian

value = derivative(n, fn, x, a, b, c, intermediate=False, jacobian=torch.func.jacfwd)
print(value.cpu().numpy().tolist())

# Compute all derivatives up to the given order

# Note, the function value itself is referred to as the zero order derivative
# Output is a list of tensors (value, jacobian, hessian, ...)

values = derivative(n, fn, x, a, b, c, intermediate=True, jacobian=torch.func.jacfwd)
print(*[value.cpu().numpy().tolist() for value in values], sep=', ')

# Compute jacobian and hessian with torch

print(fn(x, a, b, c).cpu().numpy().tolist(),
      torch.func.jacfwd(lambda x: fn(x, a, b, c))(x).cpu().numpy().tolist(),
      torch.func.hessian(lambda x: fn(x, a, b, c))(x).cpu().numpy().tolist(),
      sep=', ')

# Evaluate derivative table representation for a given deviation from the evaluation point

dx = torch.tensor([+1.0, -1.0], dtype=dtype, device=device)
print(evaluate(values, [dx]).cpu().numpy())
print(fn(x + dx, a, b, c).cpu().numpy())

# Evaluate can be mapped over a set of deviation values

print(torch.func.vmap(lambda x: evaluate(values, [x]))(torch.stack(5*[dx])).cpu().numpy().tolist())

# Derivative can be mapped over a set of evaluation points

# Note, the input function is expected to return a tensor

print(torch.func.vmap(lambda x: derivative(1, fn, x, a, b, c, intermediate=False))(torch.stack(5*[x])).cpu().numpy().tolist())
[[2.0, 0.0], [0.0, 2.0]]
3.0, [-2.0, 2.0], [[2.0, 0.0], [0.0, 2.0]]
3.0, [-2.0, 2.0], [[2.0, 0.0], [0.0, 2.0]]
1.0
1.0
[1.0, 1.0, 1.0, 1.0, 1.0]
[[-2.0, 2.0], [-2.0, 2.0], [-2.0, 2.0], [-2.0, 2.0], [-2.0, 2.0]]
[7]:
# Derivative

# Input:  vector
# Output: vector

# Set test function

# Note, the first function argument is a vector tensor
# The input function can have additional arguments
# Derivatives are not computed with respect to these additional arguments (if any)

def fn(x):
    x1, x2 = x
    X1 = 1.0*x1 + 2.0*x2
    X2 = 3.0*x1 + 4.0*x2
    X3 = 5.0*x1 + 6.0*x2
    return torch.stack([X1, X2, X3])

# Set derivative order

n = 1

# Set evaluation point

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)

# Compute derivatives

values = derivative(n, fn, x)
print(*[value.cpu().numpy().tolist() for value in values], sep=', ')
print()

# Evaluate derivative table representation for a given deviation from the evaluation point

dx = torch.tensor([+1, -1], dtype=dtype, device=device)
print(evaluate(values, [dx]).cpu().numpy().tolist())
print(fn(x + dx).cpu().numpy().tolist())
[0.0, 0.0, 0.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

[-1.0, -1.0, -1.0]
[-1.0, -1.0, -1.0]
[8]:
# Derivative

# Input:  tensor
# Output: tensor

# Set test function

def fn(x):
    return 1 + x + x**2 + x**3

# Set derivative order

n = 3

# Set evaluation point

x = torch.zeros((1, 2, 3), dtype=dtype, device=device)

# Compute derivatives

# Note, output is a list of tensors

values = derivative(n, fn, x)
print(*[list(value.shape) for value in values], sep='\n')

# Evaluate derivative table representation for a given deviation from the evaluation point

# Note, evaluate function works with scalar or vector tensor input
# One should compute derivatives of a wrapped function and reshape the result of evaluate

# Set wrapped function

def gn(x, shape):
    return fn(x.reshape(shape)).flatten()

print(fn(x).cpu().numpy().tolist())
print(gn(x.flatten(), x.shape).reshape(x.shape).cpu().numpy().tolist())

# Compute derivatives

values = derivative(n, gn, x.flatten(), x.shape)

# Set deviation value

dx = torch.ones_like(x)

# Evaluate

print(evaluate(values, [dx.flatten()]).reshape(x.shape).cpu().numpy().tolist())
print(gn((x + dx).flatten(), x.shape).reshape(x.shape).cpu().numpy().tolist())
print(fn(x + dx).cpu().numpy().tolist())
[1, 2, 3]
[1, 2, 3, 1, 2, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]
[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
[[[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]]
[[[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]]
[[[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]]
[9]:
# Derivative

# Input:  vector
# Output: nested list of tensors

# Set test function

def fn(x):
    x1, x2, x3, x4, x5, x6 = x
    X1 = 1.0*x1 + 2.0*x2 + 3.0*x3
    X2 = 4.0*x4 + 5.0*x5 + 6.0*x6
    return [torch.stack([X1]), [torch.stack([X2])]]

# Set derivative order

n = 1

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

# Compute derivatives

values = derivative(n, fn, x, intermediate=False)
[10]:
# Derivative

# Input:  vector, vector, vector
# Output: vector

# Set test function

def fn(x, y, z):
    x1, x2 = x
    y1, y2 = y
    z1, z2 = z
    return torch.stack([(x1 + x2)*(y1 + y2)*(z1 + z2)])

# Set derivative orders for x, y and z

nx, ny, nz = 1, 1, 1

# Set evaluation point
# Note, evaluation point is a list of tensors

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
y = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
z = torch.tensor([0.0, 0.0], dtype=dtype, device=device)

# Compute n'th derivative

value = derivative((nx, ny, nz), fn, x, y, z, intermediate=False)
print(value.cpu().numpy().tolist())

# Compute all derivatives up to the given order

values = derivative((nx, ny, nz), fn, x, y, z, intermediate=True)

# Evaluate derivative table representation for a given deviation from the evaluation point

dx = torch.tensor([1.0, 1.0], dtype=dtype, device=device)
dy = torch.tensor([1.0, 1.0], dtype=dtype, device=device)
dz = torch.tensor([1.0, 1.0], dtype=dtype, device=device)
print(evaluate(values, [dx, dy, dz]).cpu().numpy().tolist())
print(fn(x + dx, y + dy, z + dz).cpu().numpy().tolist())

# Note, if the input function has vector arguments and returns a tensor, it can be represented as a series

for key, value in series(tuple(map(len, (x, y, z))), (nx, ny, nz), values).items():
    print(f'{key}: {value.cpu().numpy().tolist()}')
[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]
[8.0]
[8.0]
(0, 0, 0, 0, 0, 0): [0.0]
(0, 0, 0, 0, 1, 0): [0.0]
(0, 0, 0, 0, 0, 1): [0.0]
(0, 0, 1, 0, 0, 0): [0.0]
(0, 0, 0, 1, 0, 0): [0.0]
(0, 0, 1, 0, 1, 0): [0.0]
(0, 0, 1, 0, 0, 1): [0.0]
(0, 0, 0, 1, 1, 0): [0.0]
(0, 0, 0, 1, 0, 1): [0.0]
(1, 0, 0, 0, 0, 0): [0.0]
(0, 1, 0, 0, 0, 0): [0.0]
(1, 0, 0, 0, 1, 0): [0.0]
(1, 0, 0, 0, 0, 1): [0.0]
(0, 1, 0, 0, 1, 0): [0.0]
(0, 1, 0, 0, 0, 1): [0.0]
(1, 0, 1, 0, 0, 0): [0.0]
(1, 0, 0, 1, 0, 0): [0.0]
(0, 1, 1, 0, 0, 0): [0.0]
(0, 1, 0, 1, 0, 0): [0.0]
(1, 0, 1, 0, 1, 0): [1.0]
(1, 0, 1, 0, 0, 1): [1.0]
(1, 0, 0, 1, 1, 0): [1.0]
(1, 0, 0, 1, 0, 1): [1.0]
(0, 1, 1, 0, 1, 0): [1.0]
(0, 1, 1, 0, 0, 1): [1.0]
(0, 1, 0, 1, 1, 0): [1.0]
(0, 1, 0, 1, 0, 1): [1.0]
[11]:
# Redundancy free computation

# Set test function

def fn(x):
    x1, x2 = x
    return torch.stack([1.0*x1 + 2.0*x2 + 3.0*x1**2 + 4.0*x1*x2 + 5.0*x2**2])

# Set derivative order

n = 2

# Set evaluation point

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)

# Compute n'th derivative

value = derivative(n, fn, x, intermediate=False)
print(value.cpu().numpy().tolist())

# Since derivatives are computed by nesting jacobian functions, redundant computations appear starting from the second order
# Redundant computations can be avoided if all input arguments are scalar tensors

def gn(x1, x2):
    return fn(torch.stack([x1, x2]))

print(derivative((2, 0), gn, *x, intermediate=False).cpu().numpy().tolist())
print(derivative((1, 1), gn, *x, intermediate=False).cpu().numpy().tolist())
print(derivative((0, 2), gn, *x, intermediate=False).cpu().numpy().tolist())
[[[6.0, 4.0], [4.0, 10.0]]]
[6.0]
[4.0]
[10.0]

Example-02: Derivative table representation

[1]:
# Input function f: R^n x R^m x ... -> R^n is referred to as a mapping
# The first function argument is the state, the other arguments (used in the computation of derivatives) are knobs
# State and all knobs are vector-like tensors
# Note, functions of this form can be used to model transformations through accelerator magnets

# In this case, derivatives can be used to generate a (parametric) model of the input function
# Function model can be represented as a derivative table or coefficients of monomials (series representation)

# In this example, table representation is used to model the transformation through a sextupole accelerator magnet
# Table is computed with respect to state variables (phase space variables) and knobs (magnet strength and length)
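
# As an illustration (a sketch only, the actual sextupole mapping is defined below),
# a minimal mapping of the expected form, with a two dimensional state and one knob

import torch

def mapping(state, knob):
    q, p = state          # state is a vector-like tensor
    (k, ) = knob          # each knob is a vector-like tensor
    return torch.stack([q, p + k*q**2])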
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import signature
from ndmap.signature import get
from ndmap.index import index
from ndmap.index import reduce
from ndmap.index import build
from ndmap.series import series
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Mapping (sextupole accelerator magnet transformation)
# Given initial state, magnet strength and length, state is propagated using explicit symplectic integration
# Number of integration steps is set by count parameter, integration step length is length/count

def mapping(x, k, l, count=10):
    (qx, px, qy, py), (k, ), (l, ) = x, k, l/(2.0*count)
    for _ in range(count):
        qx, qy = qx + l*px, qy + l*py
        px, py = px - 2.0*l*k*(qx**2 - qy**2), py + 2.0*l*k*qx*qy
        qx, qy = qx + l*px, qy + l*py
    return torch.stack([qx, px, qy, py])
[5]:
# Table representation (state)

# Set evaluation point & parameters

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
k = torch.tensor([10.0], dtype=dtype, device=device)
l = torch.tensor([0.1], dtype=dtype, device=device)

# Compute derivatives (table representation)
# Since derivatives are computed only with respect to the state, output table is a list of tensors

t = derivative(6, mapping, x, k, l)

print(*[element.shape for element in t], sep='\n')
torch.Size([4])
torch.Size([4, 4])
torch.Size([4, 4, 4])
torch.Size([4, 4, 4, 4])
torch.Size([4, 4, 4, 4, 4])
torch.Size([4, 4, 4, 4, 4, 4])
torch.Size([4, 4, 4, 4, 4, 4, 4])
[6]:
# Compare table and exact mapping near the evaluation point (change order to observe convergence)
# Note, table transformation is not symplectic

dx = torch.tensor([0.0, 0.001, 0.0001, 0.0], dtype=dtype, device=device)

print(evaluate(t, [dx]).cpu().tolist())
print(mapping(x + dx, k, l).cpu().tolist())
[0.00010000041624970626, 0.0010000066749862018, 0.00010000016750044096, 5.000018166514047e-09]
[0.0001000004162497062, 0.0010000066749862018, 0.00010000016750044096, 5.000018166514046e-09]
[7]:
# Each bottom element (tensor) in the (flattened) derivative table is associated with a signature
# Signature is a tuple of derivative orders

print(signature(t))
[(0,), (1,), (2,), (3,), (4,), (5,), (6,)]
[8]:
# For a given signature, corresponding element can be extracted or changed with get/set functions
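
# For example, an element could be replaced in place (a sketch; a companion set
# function in ndmap.signature is assumed here, as the get/set remark above suggests):
#
# from ndmap.signature import set
# set(t, (1, ), torch.zeros_like(get(t, (1, ))))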

print(get(t, (1, )).cpu().numpy())
[[1.  0.1 0.  0. ]
 [0.  1.  0.  0. ]
 [0.  0.  1.  0.1]
 [0.  0.  0.  1. ]]
[9]:
# Each bottom element is related to monomials
# For a given order, monomial indices with repetitions can be computed
# These repetitions account for evaluation of the same partial derivative in different orders, e.g. df/dxdy vs df/dydx
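
# As an illustration (a sketch, not the ndmap implementation), each second order
# entry corresponds to an ordered pair of variables, and its row below is the
# exponent vector of that pair

from itertools import product

rows = []
for pair in product(range(4), repeat=2):
    row = 4*[0]
    for variable in pair:
        row[variable] += 1
    rows.append(row)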

print(index(4, 2))
[[2 0 0 0]
 [1 1 0 0]
 [1 0 1 0]
 [1 0 0 1]
 [1 1 0 0]
 [0 2 0 0]
 [0 1 1 0]
 [0 1 0 1]
 [1 0 1 0]
 [0 1 1 0]
 [0 0 2 0]
 [0 0 1 1]
 [1 0 0 1]
 [0 1 0 1]
 [0 0 1 1]
 [0 0 0 2]]
[10]:
# Explicit evaluation

print(evaluate(t, [dx]).cpu().numpy())
print((t[0] + t[1] @ dx + 1/2 * t[2] @ dx @ dx + 1/2 * 1/3 * t[3] @ dx @ dx @ dx + 1/2 * 1/3 * 1/4 * t[4] @ dx @ dx @ dx @ dx + 1/2 * 1/3 * 1/4 * 1/5 * t[5] @ dx @ dx @ dx @ dx @ dx + 1/2 * 1/3 * 1/4 * 1/5 * 1/6 * t[6] @ dx @ dx @ dx @ dx @ dx @ dx).cpu().numpy())
print((t[0] + (t[1] + 1/2 * (t[2] + 1/3 * (t[3] + 1/4 * (t[4] + 1/5 * (t[5] + 1/6 * t[6] @ dx) @ dx) @ dx) @ dx) @ dx) @ dx).cpu().numpy())
[1.00000416e-04 1.00000667e-03 1.00000168e-04 5.00001817e-09]
[1.00000416e-04 1.00000667e-03 1.00000168e-04 5.00001817e-09]
[1.00000416e-04 1.00000667e-03 1.00000168e-04 5.00001817e-09]
[11]:
# Series representation can be generated from a given table
# This representation stores monomial powers and corresponding coefficients

s = series((4, ), (6, ), t)
print(torch.stack([s[(1, 0, 0, 0)], s[(0, 1, 0, 0)], s[(0, 0, 1, 0)], s[(0, 0, 0, 1)]]).cpu().numpy())
[[1.  0.  0.  0. ]
 [0.1 1.  0.  0. ]
 [0.  0.  1.  0. ]
 [0.  0.  0.1 1. ]]
[12]:
# Evaluate series

print(evaluate(t, [dx]).cpu().numpy())
print(evaluate(s, [dx]).cpu().numpy())
[1.00000416e-04 1.00000667e-03 1.00000168e-04 5.00001817e-09]
[1.00000416e-04 1.00000667e-03 1.00000168e-04 5.00001817e-09]
[13]:
# Table representation (state & knobs)

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
k = torch.tensor([10.0], dtype=dtype, device=device)
l = torch.tensor([0.1], dtype=dtype, device=device)

# Compute derivatives (table representation)
# Since derivatives are computed with respect to state and knobs, output table is a nested list of tensors

t = derivative((6, 1, 1), mapping, x, k, l)
[14]:
# In this case, bottom table element signature is a tuple with several integers

print(get(t, (1, 0, 0)).cpu().numpy())
[[1.  0.1 0.  0. ]
 [0.  1.  0.  0. ]
 [0.  0.  1.  0.1]
 [0.  0.  0.  1. ]]
[15]:
# Compare table and exact mapping near evaluation point (change order to observe convergence)
# Note, table transformation is not symplectic

dx = torch.tensor([0.0, 0.001, 0.0001, 0.0], dtype=dtype, device=device)
dk = torch.tensor([0.1], dtype=dtype, device=device)
dl = torch.tensor([0.001], dtype=dtype, device=device)

print(evaluate(t, [dx, 0.0*dk, 0.0*dl]).cpu().tolist())
print(evaluate(t, [dx, 1.0*dk, 1.0*dl]).cpu().tolist())
print(mapping(x + dx, k + dk, l + dl).cpu().tolist())
[0.00010000041624970626, 0.0010000066749862018, 0.00010000016750044096, 5.000018166514047e-09]
[0.0001010004271286862, 0.001000006741987918, 0.00010000017425071835, 5.1510191197368185e-09]
[0.00010100042712809422, 0.0010000067409770773, 0.00010000017430164039, 5.151524128394736e-09]
[16]:
# Each bottom element (tensor) in the (flattened) derivative table is associated with a signature
# Signature is a tuple of derivative orders

print(*[index for index in signature(t)], sep='\n')
(0, 0, 0)
(0, 0, 1)
(0, 1, 0)
(0, 1, 1)
(1, 0, 0)
(1, 0, 1)
(1, 1, 0)
(1, 1, 1)
(2, 0, 0)
(2, 0, 1)
(2, 1, 0)
(2, 1, 1)
(3, 0, 0)
(3, 0, 1)
(3, 1, 0)
(3, 1, 1)
(4, 0, 0)
(4, 0, 1)
(4, 1, 0)
(4, 1, 1)
(5, 0, 0)
(5, 0, 1)
(5, 1, 0)
(5, 1, 1)
(6, 0, 0)
(6, 0, 1)
(6, 1, 0)
(6, 1, 1)
[17]:
# Compute series

s = series((4, 1, 1), (6, 1, 1), t)

# Keys are generalized monomials

print(s[(1, 1, 1, 1, 1, 1)].cpu().numpy())
print()

# Evaluate series

print(evaluate(t, [dx, dk, dl]).cpu().numpy())
print(evaluate(s, [dx, dk, dl]).cpu().numpy())
print()
[1.447578e-06 9.756747e-05 0.000000e+00 0.000000e+00]

[1.01000427e-04 1.00000674e-03 1.00000174e-04 5.15101912e-09]
[1.01000427e-04 1.00000674e-03 1.00000174e-04 5.15101912e-09]

[18]:
# Reduced table representation

sequence, shape, unique = reduce((4, 1, 1), t)
out = derivative((6, 1, 1), lambda x, k, l: x, x, k, l)
build(out, sequence, shape, unique)
compare(t, out)
[18]:
True

Example-03: Derivative table propagation

[1]:
# Given a mapping f(state, *knobs, ...) and a derivative table t, derivatives of f(t, *knobs, ...) are computed
# This can be used to propagate a derivative table through a given mapping (computation of parametric fixed points and other applications)
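
# Schematically (a sketch; see the cells below for working calls):
#
# t = identity(orders, [x, *knobs])                 # derivative table of the identity map
# t = propagate(dimensions, orders, t, knobs, f)    # table of f
# t = propagate(dimensions, orders, t, knobs, g)    # table of g o f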
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Define mappings

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=100):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def ring(x, w):
    x = quad(x, w, +0.25, 0.5)
    x = drif(x, w, 5.0)
    x = quad(x, w, -0.20, 0.5)
    x = quad(x, w, -0.20, 0.5)
    x = drif(x, w, 5.0)
    x = quad(x, w, +0.25, 0.5)
    return x
[5]:
# Direct

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
w = torch.tensor([0.0], dtype=dtype, device=device)

# Compute derivatives

t = derivative((1, 4), ring, x, w)

# Evaluate for a given deviation

dx = torch.tensor([0.001, 0.0, 0.001, 0.0], dtype=dtype, device=device)
dw = torch.tensor([0.001], dtype=dtype, device=device)

print(ring(x + dx, w + dw).cpu().numpy())
print(evaluate(t, [dx, dw]).cpu().numpy())
[-8.88568650e-05 -5.43957672e-05  4.97569694e-04 -1.40349102e-04]
[-8.88568650e-05 -5.43957672e-05  4.97569694e-04 -1.40349102e-04]
[6]:
# Propagation

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
w = torch.tensor([0.0], dtype=dtype, device=device)

# Set identity table

t = identity((1, 4), [x, w])

# Propagate table

t = propagate((4, 1), (1, 4), t, [w], quad, +0.25, 0.5)
t = propagate((4, 1), (1, 4), t, [w], drif, 5.0)
t = propagate((4, 1), (1, 4), t, [w], quad, -0.20, 0.5)
t = propagate((4, 1), (1, 4), t, [w], quad, -0.20, 0.5)
t = propagate((4, 1), (1, 4), t, [w], drif, 5.0)
t = propagate((4, 1), (1, 4), t, [w], quad, +0.25, 0.5)

# Evaluate for a given deviation

dx = torch.tensor([0.001, 0.0, 0.001, 0.0], dtype=dtype, device=device)
dw = torch.tensor([0.001], dtype=dtype, device=device)

print(ring(x + dx, w + dw).cpu().numpy())
print(evaluate(t, [dx, dw]).cpu().numpy())
[-8.88568650e-05 -5.43957672e-05  4.97569694e-04 -1.40349102e-04]
[-8.88568650e-05 -5.43957672e-05  4.97569694e-04 -1.40349102e-04]
[7]:
# Series representation

s = clean(series((4, 1), (1, 4), t))
for key, value in s.items():
    print(f'{key}: {value.cpu().numpy()}')
(1, 0, 0, 0, 0): [-0.09072843 -0.05430498  0.          0.        ]
(0, 1, 0, 0, 0): [18.26293553 -0.09072843  0.          0.        ]
(0, 0, 1, 0, 0): [ 0.          0.          0.49625858 -0.14063034]
(0, 0, 0, 1, 0): [0.         0.         5.35963583 0.49625858]
(1, 0, 0, 0, 1): [ 1.87420951 -0.09097396  0.          0.        ]
(0, 1, 0, 0, 1): [-24.33227046   1.87420951   0.           0.        ]
(0, 0, 1, 0, 1): [0.         0.         1.31324305 0.28161167]
(0, 0, 0, 1, 1): [0.         0.         1.46426265 1.31324305]
(1, 0, 0, 0, 2): [-2.65007558  0.18727838  0.          0.        ]
(0, 1, 0, 0, 2): [30.20570906 -2.65007558  0.          0.        ]
(0, 0, 1, 0, 2): [ 0.          0.         -2.12812904 -0.37146989]
(0, 0, 0, 1, 2): [ 0.          0.         -8.46895909 -2.12812904]
(1, 0, 0, 0, 3): [ 3.41796043 -0.28459189  0.          0.        ]
(0, 1, 0, 0, 3): [-35.88102956   3.41796043   0.           0.        ]
(0, 0, 1, 0, 3): [0.         0.         2.94802229 0.46018886]
(0, 0, 0, 1, 3): [ 0.          0.         15.6516443   2.94802229]
(1, 0, 0, 0, 4): [-4.17749922  0.38289808  0.          0.        ]
(0, 1, 0, 0, 4): [41.3560843  -4.17749922  0.          0.        ]
(0, 0, 1, 0, 4): [ 0.          0.         -3.77254434 -0.54775243]
(0, 0, 0, 1, 4): [  0.           0.         -23.00943634  -3.77254434]
[8]:
# Check invariant
# Note, ring has two quadratic invariants (actions), zeros are padded to match state length

# Define invariant

matrix = torch.tensor([[4.282355639365032, 0.0, 0.0, 0.0], [0.0, 0.23351633638449415, 0.0, 0.0], [0.0, 0.0, 2.484643367729646, 0.0], [0.0, 0.0, 0.0, 0.40247224732044934]], dtype=dtype, device=device)
def invariant(x):
    qx, px, qy, py = matrix.inverse() @ x
    return torch.stack([0.5*(qx**2 + px**2), 0.5*(qy**2 + py**2), *torch.tensor(2*[0.0], dtype=dtype, device=device)])

# Set evaluation point

x = torch.tensor([0.001, 0.0, 0.001, 0.0], dtype=dtype, device=device)
w = torch.tensor([0.0], dtype=dtype, device=device)

# Evaluate invariant for a given state and transformed state

print(invariant(x).cpu().numpy())
print(invariant(ring(x, w)).cpu().numpy())
[2.72649397e-08 8.09919549e-08 0.00000000e+00 0.00000000e+00]
[2.72649397e-08 8.09919549e-08 0.00000000e+00 0.00000000e+00]
[9]:
# Invariant propagation

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
w = torch.tensor([0.0], dtype=dtype, device=device)

# Compute table and series representations of invariant

t = derivative((2, ), invariant, x)
s = series((4, ), (2, ), t)

print(*[f'{key}: {value.cpu().numpy()}' for key, value in clean(s, epsilon=1.0E-14).items()], sep='\n')
print()

# Compute table and series representations of transformed invariant

t = derivative((2, ), lambda x: invariant(ring(x, w)), x)
s = series((4, ), (2, ), t)

print(*[f'{key}: {value.cpu().numpy()}' for key, value in clean(s, epsilon=1.0E-14).items()], sep='\n')
print()

# Propagate invariant

t = derivative((2, ), ring, x, w)
t = propagate((4, ), (2, ), t, [], invariant)
s = series((4, ), (2, ), t)

print(*[f'{key}: {value.cpu().numpy()}' for key, value in clean(s, epsilon=1.0E-14).items()], sep='\n')
print()
(2, 0, 0, 0): [0.02726494 0.         0.         0.        ]
(0, 2, 0, 0): [9.16928491 0.         0.         0.        ]
(0, 0, 2, 0): [0.         0.08099195 0.         0.        ]
(0, 0, 0, 2): [0.         3.08672633 0.         0.        ]

(2, 0, 0, 0): [0.02726494 0.         0.         0.        ]
(0, 2, 0, 0): [9.16928491 0.         0.         0.        ]
(0, 0, 2, 0): [0.         0.08099195 0.         0.        ]
(0, 0, 0, 2): [0.         3.08672633 0.         0.        ]

(2, 0, 0, 0): [0.02726494 0.         0.         0.        ]
(0, 2, 0, 0): [9.16928491 0.         0.         0.        ]
(0, 0, 2, 0): [0.         0.08099195 0.         0.        ]
(0, 0, 0, 2): [0.         3.08672633 0.         0.        ]

Example-04: Jet class

[1]:
# Jet is a convenience class to work with jets (evaluation point & derivative table)
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.jet import Jet

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Define mappings

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=100):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
w = torch.tensor([0.0], dtype=dtype, device=device)

# Compute table representation

t = derivative((1, 4), lambda x, w: quad(drif(x, w, 1.0), w, 1.0, 1.0, 1), x, w)
[5]:
# Set jet

j = Jet((4, 1), (1, 4), point=[x, w], dtype=dtype, device=device)
j = j.propagate(drif, 1.0)
j = j.propagate(quad, 1.0, 1.0, 1)
[6]:
# Evaluate at given deviation

dx = torch.tensor([0.001, 0.0, 0.001, 0.0], dtype=dtype, device=device)
dw = torch.tensor([0.001], dtype=dtype, device=device)

print(evaluate(t, [dx, dw]).cpu().numpy())
print(j([dx, dw]).cpu().numpy())
[ 0.0005005 -0.001      0.0014995  0.001    ]
[ 0.0005005 -0.001      0.0014995  0.001    ]
[7]:
# Composition

j1 = Jet.from_mapping((4, 1), (1, 4), [x, w], drif, 1.0, dtype=dtype, device=device)
j2 = Jet.from_mapping((4, 1), (1, 4), [x, w], quad, 1.0, 1.0, 1, dtype=dtype, device=device)

print((j1 @ j2)([dx, dw]).cpu().numpy())
[ 0.0005005 -0.001      0.0014995  0.001    ]

Example-05: Nonlinear mapping approximation

[1]:
# A composition of several nonlinear mappings can be approximated by its table representation
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set test mapping
# Rotation with two sextupoles separated by a negative identity linear transformation
# Note, the result is expected to have vanishing degree two coefficients due to the negative identity linear transformation between the sextupoles

def spin(x, mux, muy):
    (qx, px, qy, py), mux, muy = x, mux, muy
    return torch.stack([qx*mux.cos() + px*mux.sin(), px*mux.cos() - qx*mux.sin(), qy*muy.cos() + py*muy.sin(), py*muy.cos() - qy*muy.sin()])

def drif(x, l):
    (qx, px, qy, py), l = x, l
    return torch.stack([qx + l*px, px, qy + l*py, py])

def sext(x, ks, l, n=1):
    (qx, px, qy, py), ks, l = x, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px, qy + l*py
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px, qy + l*py
    return torch.stack([qx, px, qy, py])

def ring(x):
    mux, muy = 2.0*numpy.pi*torch.tensor([1/3 + 0.01, 1/4 + 0.01], dtype=dtype, device=device)
    x = spin(x, mux, muy)
    x = drif(x, -0.05)
    x = sext(x, 10.0, 0.1, 100)
    x = drif(x, -0.05)
    mux, muy = 2.0*numpy.pi*torch.tensor([0.50, 0.50], dtype=dtype, device=device)
    x = spin(x, mux, muy)
    x = drif(x, -0.05)
    x = sext(x, 10.0, 0.1, 100)
    x = drif(x, -0.05)
    return x
[5]:
# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

# Compute derivative table

n = 4
t = derivative(n, ring, x)

# Compute and print series

s = clean(series((4, ), (n, ), t), epsilon=1.0E-12)
print(*[f'{key}: {value.cpu().numpy()}' for key, value in clean(s, epsilon=1.0E-14).items()], sep='\n')
(1, 0, 0, 0): [0.55339155 0.83292124 0.         0.        ]
(0, 1, 0, 0): [-0.83292124  0.55339155  0.          0.        ]
(0, 0, 1, 0): [0.         0.         0.06279052 0.99802673]
(0, 0, 0, 1): [ 0.          0.         -0.99802673  0.06279052]
(3, 0, 0, 0): [-7.53257307e-09  2.82424677e-03  0.00000000e+00  0.00000000e+00]
(2, 1, 0, 0): [-1.96250238e-08 -1.27525063e-02  0.00000000e+00  0.00000000e+00]
(2, 0, 1, 0): [ 0.00000000e+00 -0.00000000e+00  9.21186111e-06  3.34331441e-04]
(2, 0, 0, 1): [ 0.00000000e+00 -0.00000000e+00  1.59704766e-05 -5.06941449e-03]
(1, 2, 0, 0): [-1.11004920e-08  1.91940679e-02  0.00000000e+00  0.00000000e+00]
(1, 1, 1, 0): [ 0.00000000e+00 -0.00000000e+00 -2.98671134e-05 -9.79459005e-04]
(1, 1, 0, 1): [ 0.00000000e+00 -0.00000000e+00 -1.48185697e-05  1.53623878e-02]
(1, 0, 2, 0): [-1.05857397e-06  1.97282603e-05  0.00000000e+00  0.00000000e+00]
(1, 0, 1, 1): [ 1.48154798e-05 -1.18570682e-03  0.00000000e+00  0.00000000e+00]
(1, 0, 0, 2): [2.88067554e-05 9.18409783e-03 0.00000000e+00 0.00000000e+00]
(0, 3, 0, 0): [-9.21338044e-10 -9.62979589e-03  0.00000000e+00  0.00000000e+00]
(0, 2, 1, 0): [ 0.00000000e+00 -0.00000000e+00  2.40366928e-05  7.09979811e-04]
(0, 2, 0, 1): [ 0.00000000e+00 -0.00000000e+00 -1.38786549e-05 -1.15294375e-02]
(0, 1, 2, 0): [ 1.80396305e-06 -2.59196622e-05  0.00000000e+00  0.00000000e+00]
(0, 1, 1, 1): [-2.98509155e-05  1.72488752e-03  0.00000000e+00  0.00000000e+00]
(0, 1, 0, 2): [ 1.66318846e-05 -1.38269522e-02  0.00000000e+00  0.00000000e+00]
(0, 0, 3, 0): [ 0.00000000e+00 -0.00000000e+00 -1.47704279e-08  4.12534350e-06]
(0, 0, 2, 1): [ 0.00000000e+00 -0.00000000e+00 -3.31103168e-09 -1.96719688e-04]
(0, 0, 1, 2): [ 0.00000000e+00 -0.00000000e+00  3.93325786e-09  3.12683573e-03]
(0, 0, 0, 3): [ 0.00000000e+00 -0.00000000e+00  2.56887204e-10 -1.65665408e-02]
(4, 0, 0, 0): [ 6.48869023e-07 -3.91844462e-06  0.00000000e+00  0.00000000e+00]
(3, 1, 0, 0): [-3.91685700e-06  1.51001133e-05  0.00000000e+00  0.00000000e+00]
(3, 0, 1, 0): [ 0.00000000e+00 -0.00000000e+00 -4.36165465e-07  5.76316391e-06]
(3, 0, 0, 1): [ 0.00000000e+00 -0.00000000e+00  6.56410859e-06  1.31668166e-05]
(2, 2, 0, 0): [ 8.85501662e-06 -1.48880894e-05  0.00000000e+00  0.00000000e+00]
(2, 1, 1, 0): [ 0.00000000e+00 -0.00000000e+00  1.92232731e-06 -2.78183189e-05]
(2, 1, 0, 1): [ 0.00000000e+00 -0.00000000e+00 -2.96894787e-05 -3.16635607e-05]
(2, 0, 2, 0): [1.81838643e-08 2.18816315e-06 0.00000000e+00 0.00000000e+00]
(2, 0, 1, 1): [-9.93314202e-07 -3.25277215e-05  0.00000000e+00  0.00000000e+00]
(2, 0, 0, 2): [ 7.67316422e-06 -2.51871510e-05  0.00000000e+00  0.00000000e+00]
(1, 3, 0, 0): [-8.88618620e-06 -4.33921249e-06  0.00000000e+00  0.00000000e+00]
(1, 2, 1, 0): [ 0.00000000e+00 -0.00000000e+00 -2.81910856e-06  4.44963121e-05]
(1, 2, 0, 1): [ 0.00000000e+00 -0.00000000e+00  4.47107608e-05  6.22767407e-06]
(1, 1, 2, 0): [-5.02653030e-08 -6.69892371e-06  0.00000000e+00  0.00000000e+00]
(1, 1, 1, 1): [2.90625131e-06 1.04277202e-04 0.00000000e+00 0.00000000e+00]
(1, 1, 0, 2): [-2.28808176e-05  2.60346923e-05  0.00000000e+00  0.00000000e+00]
(1, 0, 3, 0): [ 0.00000000e+00 -0.00000000e+00  2.09236592e-09  3.51323891e-08]
(1, 0, 2, 1): [ 0.00000000e+00 -0.00000000e+00 -6.41657941e-09 -1.78350571e-06]
(1, 0, 1, 2): [ 0.00000000e+00 -0.00000000e+00 -6.10954936e-07  1.86435774e-05]
(1, 0, 0, 3): [ 0.00000000e+00 -0.00000000e+00  3.05503037e-06 -5.00150901e-05]
(0, 4, 0, 0): [3.33982092e-06 8.88114110e-06 0.00000000e+00 0.00000000e+00]
(0, 3, 1, 0): [ 0.00000000e+00 -0.00000000e+00  1.37553970e-06 -2.36217768e-05]
(0, 3, 0, 1): [ 0.00000000e+00 -0.00000000e+00 -2.24184174e-05  1.77513110e-05]
(0, 2, 2, 0): [3.54703897e-08 5.11986525e-06 0.00000000e+00 0.00000000e+00]
(0, 2, 1, 1): [-2.15414781e-06 -8.31702815e-05  0.00000000e+00  0.00000000e+00]
(0, 2, 0, 2): [1.72952174e-05 1.78791226e-05 0.00000000e+00 0.00000000e+00]
(0, 1, 3, 0): [ 0.00000000e+00 -0.00000000e+00 -3.32576455e-09 -2.16775859e-08]
(0, 1, 2, 1): [ 0.00000000e+00 -0.00000000e+00  1.73891518e-08  1.32003603e-06]
(0, 1, 1, 2): [ 0.00000000e+00 -0.00000000e+00  8.58583303e-07 -7.35197022e-06]
(0, 1, 0, 3): [ 0.00000000e+00 -0.00000000e+00 -4.60205741e-06 -3.44817289e-05]
(0, 0, 4, 0): [-2.95410616e-10 -3.91436795e-10  0.00000000e+00  0.00000000e+00]
(0, 0, 3, 1): [ 1.72261161e-08 -3.76703368e-08  0.00000000e+00  0.00000000e+00]
(0, 0, 2, 2): [-3.76632244e-07  1.04173914e-06  0.00000000e+00  0.00000000e+00]
(0, 0, 1, 3): [ 3.81179356e-06 -5.44741177e-06  0.00000000e+00  0.00000000e+00]
(0, 0, 0, 4): [-1.51552013e-05 -3.46854944e-07  0.00000000e+00  0.00000000e+00]
[6]:
# Compare phase space trajectories
# Note, change order to observe convergence

plt.figure(figsize=(10, 10))

# Direct tracking

x = torch.linspace(0.0, 5.0, 10, dtype=dtype, device=device)
x = torch.stack([x, *3*[torch.zeros_like(x)]]).T

count = 512
table = []

for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: ring(x))(x)

table = torch.stack(table).swapaxes(0, -1)
qx, px, *_ = table

for q, p in zip(qx.cpu().numpy(), px.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

# Table tracking
# Note, table representation is not symplectic

x = torch.linspace(0.0, 5.0, 10, dtype=dtype, device=device)
x = torch.stack([x, *3*[torch.zeros_like(x)]]).T

count = 512
table = []

for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: evaluate(t, [x]))(x)

table = torch.stack(table).swapaxes(0, -1)
qx, px, *_ = table

for q, p in zip(qx.cpu().numpy(), px.cpu().numpy()):
    plt.scatter(q, p, color='red', marker='x', s=1)

plt.show()
../_images/examples_ndmap_55_0.png

Example-06: Fixed point

[1]:
# In this example fixed points are computed for a simple symplectic nonlinear transformation
# Fixed points are computed with Newton root search
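
# A minimal Newton step for a period-n fixed point (an illustration only,
# not the ndmap.pfp implementation)

import torch

def newton_step(mapping, x, power=1):
    def residual(z):
        y = z
        for _ in range(power):
            y = mapping(y)
        return y - z
    m = torch.func.jacfwd(residual)(x)
    return x - torch.linalg.solve(m, residual(x))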
[2]:
# Import

import numpy
import torch

from ndmap.pfp import fixed_point
from ndmap.pfp import clean_point
from ndmap.pfp import chain_point
from ndmap.pfp import matrix

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set forward & inverse mappings

mu = 2.0*numpy.pi*torch.tensor(1/3 - 0.01, dtype=dtype)
kq, ks, ko = torch.tensor([0.0, 0.25, -0.25], dtype=dtype)

def forward(x):
    q, p = x
    q, p = q*mu.cos() + p*mu.sin(), p*mu.cos() - q*mu.sin()
    q, p = q, p + (kq*q + ks*q**2 + ko*q**3)
    return torch.stack([q, p])

def inverse(x):
    q, p = x
    q, p = q, p - (kq*q + ks*q**2 + ko*q**3)
    q, p = q*mu.cos() - p*mu.sin(), p*mu.cos() + q*mu.sin()
    return torch.stack([q, p])
[5]:
# Compute period three fixed points

# Set fixed point period

period = 3

# Set tolerance epsilon

epsilon = 1.0E-12

# Set random initial points

points = 4.0*torch.rand((128, 2), dtype=dtype, device=device) - 2.0

# Perform 512 root search iterations for each initial point

points = torch.func.vmap(lambda point: fixed_point(512, forward, point, power=period))(points)

# Clean points (remove nans, duplicates, points from the same chain)

points = clean_point(period, forward, points, epsilon=epsilon)

# Generate fixed point chains

chains = torch.func.vmap(lambda point: chain_point(period, forward, point))(points)

# Classify fixed point chains (elliptic vs hyperbolic)
# Generate initials for hyperbolic fixed points using corresponding eigenvectors

kinds = []
for chain in chains:
    point, *_ = chain
    values, vectors = torch.linalg.eig(matrix(period, forward, point))
    kind = all(values.log().real < epsilon)
    kinds.append(kind)
    if not kind:
        lines = [point + vector*torch.linspace(-epsilon, +epsilon, 1024, dtype=dtype).reshape(-1, 1) for vector in vectors.real.T]
        lines = torch.stack(lines)

# Plot phase space

x = torch.linspace(0.0, 1.5, 21, dtype=dtype, device=device)
x = torch.stack([x, torch.zeros_like(x)]).T

count = 1024
table = []

for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: forward(x))(x)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table

plt.figure(figsize=(10, 10))
plt.xlim(-2.0, 2.0)
plt.ylim(-2.0, 2.0)

for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

# Plot (approximated) stable and unstable manifolds of hyperbolic fixed points

count = 310

for line in lines:

    x = torch.clone(line)
    table = []
    for _ in range(count):
        table.append(x)
        x = torch.func.vmap(lambda x: forward(x))(x)
    table = torch.stack(table).swapaxes(0, -1)
    qs, ps = table
    for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
        plt.scatter(q, p, color='gray', marker='o', s=1)

    x = torch.clone(line)
    table = []
    for _ in range(count):
        table.append(x)
        x = torch.func.vmap(lambda x: inverse(x))(x)
    table = torch.stack(table).swapaxes(0, -1)
    qs, ps = table
    for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
        plt.scatter(q, p, color='gray', marker='o', s=1)

# Plot chains

for chain, kind in zip(chains, kinds):
    plt.scatter(*chain.T, color = {True:'blue', False:'red'}[kind], marker='o')
../_images/examples_ndmap_61_0.png
[6]:
# Set mapping around elliptic fixed point

point, *_ = chains[kinds].squeeze()

def mapping(x):
    x = x + point
    for _ in range(period):
        x = forward(x)
    x = x - point
    return x

# Test mapping

x = torch.zeros_like(point)
print(x)
print(mapping(x))

# Plot phase space

x = torch.linspace(0.0, 1.5, 21, dtype=dtype, device=device)
x = torch.stack([x, torch.zeros_like(x)]).T

count = 1024
table = []

for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: forward(x))(x)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table

plt.figure(figsize=(10, 10))
plt.xlim(-2.0, 2.0)
plt.ylim(-2.0, 2.0)

for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

x = torch.linspace(0.0, 0.5, 11, dtype=dtype, device=device)
x = torch.stack([x, torch.zeros_like(x)]).T

count = 1024
table = []

for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: mapping(x))(x)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table + point.reshape(2, 1, 1)

for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='red', marker='o', s=1)
tensor([0., 0.], dtype=torch.float64)
tensor([-1.110223024625e-16,  0.000000000000e+00], dtype=torch.float64)
../_images/examples_ndmap_62_1.png

Example-07: Parametric fixed point

[1]:
# Given a mapping depending on a set of knobs (parameters), parametric fixed points can be computed (position of a fixed point as a function of parameters)
# Parametric fixed points can be used to construct response matrices, e.g. closed orbit response
# In this case only first order derivatives of the fixed point(s) with respect to parameters are computed
# Higher order expansions can also be computed
# In this example parametric fixed points of a symplectic mapping are computed
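
# The first order dependence of a fixed point on the knobs follows from the
# implicit function theorem (a sketch only, not the ndmap implementation):
# x* = f(x*, k)  =>  dx*/dk = (I - df/dx)^(-1) df/dk
# For a period-n point, mapping below stands for the n-fold composition

import torch

def pfp_first_order(mapping, point, knobs):
    dfdx = torch.func.jacfwd(mapping, argnums=0)(point, knobs)
    dfdk = torch.func.jacfwd(mapping, argnums=1)(point, knobs)
    eye = torch.eye(len(point), dtype=point.dtype, device=point.device)
    return torch.linalg.solve(eye - dfdx, dfdk)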
[2]:
# Import

import numpy
import torch

from ndmap.util import flatten
from ndmap.evaluate import evaluate
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import clean_point
from ndmap.pfp import chain_point
from ndmap.pfp import matrix
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set mapping

def mapping(x, k):
    q, p = x
    a, b = k
    q, p = q*mu.cos() + p*mu.sin(), p*mu.cos() - q*mu.sin()
    return torch.stack([q, p + a*q**2 + b*q**3])
[5]:
# Compute dynamical fixed points
# Note, fixed point search might fail due to escape to large values

# Set parameters

mu = 2.0*numpy.pi*torch.tensor(1/5 - 0.01, dtype=dtype, device=device)
k = torch.tensor([0.25, -0.25], dtype=dtype, device=device)

# Compute and plot phase space trajectories

x = torch.linspace(0.0, 1.5, 21, dtype=dtype)
x = torch.stack([x, torch.zeros_like(x)]).T

count = 1024
table = []
for _ in range(count):
    table.append(x)
    x = torch.func.vmap(lambda x: mapping(x, k))(x)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table

plt.figure(figsize=(10, 10))
plt.xlim(-2.0, 2.0)
plt.ylim(-2.0, 2.0)
for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

# Set tolerance epsilon

epsilon = 1.0E-12

# Compute chains

period = 5
points = torch.rand((32, 2), dtype=dtype, device=device)
points = torch.func.vmap(lambda point: fixed_point(16, mapping, point, k, power=period))(points)
points = clean_point(period, mapping, points, k, epsilon=epsilon)
chains = torch.func.vmap(lambda point: chain_point(period, mapping, point, k))(points)

# Plot chains

for chain in chains:
    point, *_ = chain
    value, vector = torch.linalg.eig(matrix(period, mapping, point, k))
    color = 'blue' if all(value.log().real < epsilon) else 'red'
    plt.scatter(*chain.T, color=color, marker='o')
    if color == 'blue':
        ep, *_ = chain
    else:
        hp, *_ = chain

plt.show()
../_images/examples_ndmap_68_0.png
[6]:
# Compute hyperbolic fixed point for a set of knobs

dks = torch.stack(2*[torch.linspace(0.0, 0.01, 101, dtype=dtype, device=device)]).T

fps = [hp]
for dk in dks:
    *_, initial = fps
    fps.append(fixed_point(16, mapping, initial, k + dk, power=period))

fps = torch.stack(fps)
[7]:
# Compute parametric fixed point

# Set computation order
# Note, change order to observe convergence

order = 4
pfp = parametric_fixed_point((order, ), hp, [k], mapping, power=period)

# Set period mapping and check fixed point propagation

def function(x, k):
    for _ in range(period):
        x = mapping(x, k)
    return x

out = propagate((2, 2), (0, order), pfp, [k], function)
for x, y in zip(flatten(pfp, target=list), flatten(out, target=list)):
    print(torch.allclose(x, y))
True
True
True
True
True
[8]:
# Plot parametric fixed point position for a given set of knobs

out = torch.func.vmap(lambda dk: evaluate(pfp, [hp, dk]))(dks)

plt.figure(figsize=(20, 5))
plt.scatter(*fps.T.cpu().numpy(), color='blue', marker='o')
plt.scatter(*out.T.cpu().numpy(), color='red', marker='x')
plt.show()
../_images/examples_ndmap_71_0.png

Example-08: Fixed point manipulation (collision)

[1]:
# In this example the distance between a pair of hyperbolic and elliptic fixed points is minimized
# First, using a set of initial guesses within a region, a pair is obtained
# For a given pair, first order parametric dependence of fixed point positions is computed
# The gradient of the distance between the points is used to update the knobs (gradient descent minimization)
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.pfp import fixed_point
from ndmap.pfp import clean_point
from ndmap.pfp import chain_point
from ndmap.pfp import matrix
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set mapping

limit = 8
phase = 2.0*numpy.pi*(1/4 + 0.005)
phase = torch.tensor(phase/(limit + 1), dtype=dtype, device=device)

def mapping(state, knobs):
    q, p = state
    for index in range(limit):
        q, p = q*phase.cos() + p*phase.sin(), p*phase.cos() - q*phase.sin()
        q, p = q, p + knobs[index]*q**2
    q, p = q*phase.cos() + p*phase.sin(), p*phase.cos() - q*phase.sin()
    q, p = q, p + q**2
    return torch.stack([q, p])
[5]:
# Locate fixed points and select a pair

# Set initial knobs

knobs = torch.tensor(limit*[0.0], dtype=dtype, device=device)

# Compute and plot phase space trajectories

state = torch.linspace(0.0, 1.5, 21, dtype=dtype)
state = torch.stack([state, torch.zeros_like(state)]).T

count = 1024
table = []
for _ in range(count):
    table.append(state)
    state = torch.func.vmap(lambda state: mapping(state, knobs))(state)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table

plt.figure(figsize=(8, 8))
plt.xlim(-1., 1.)
plt.ylim(-1., 1.)
for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

# Set tolerance epsilon

epsilon = 1.0E-12

# Compute chains

period = 4
points = 4.0*torch.rand((512, 2), dtype=dtype, device=device) - 2.0
points = torch.func.vmap(lambda point: fixed_point(64, mapping, point, knobs, power=period))(points)
points = clean_point(period, mapping, points, knobs, epsilon=epsilon)
chains = torch.func.vmap(lambda point: chain_point(period, mapping, point, knobs))(points)

# Plot chains

for chain in chains:
    point, *_ = chain
    value, vector = torch.linalg.eig(matrix(period, mapping, point, knobs))
    color = 'blue' if all(value.log().real < epsilon) else 'red'
    plt.scatter(*chain.T, color=color, marker='o')
    if color == 'blue':
        ep, *_ = chain
    else:
        hp, *_ = chain

ep_chain, *_ = [chain for chain in chains if ep in chain]
hp_chain, *_ = [chain for chain in chains if hp in chain]

ep, *_ = ep_chain
hp, *_ = hp_chain[(ep - hp_chain).norm(dim=-1) == (ep - hp_chain).norm(dim=-1).min()]

plt.scatter(*ep.cpu().numpy(), color='black', marker='x')
plt.scatter(*hp.cpu().numpy(), color='black', marker='x')
plt.plot(*torch.stack([ep, hp]).T.cpu().numpy(), color='gray')

plt.show()
../_images/examples_ndmap_77_0.png
[6]:
# Compute first order parametric fixed points

order = 1

php = parametric_fixed_point((order, ), hp, [knobs], mapping, power=period)
pep = parametric_fixed_point((order, ), ep, [knobs], mapping, power=period)
[7]:
# Set objective function

def objective(knobs, php, pep):
    dhp = evaluate(php, [torch.zeros_like(knobs), knobs])
    dep = evaluate(pep, [torch.zeros_like(knobs), knobs])
    return (dep - dhp).norm()
[8]:
# Set learning rate and update knobs

lr = 0.0025
gradient = derivative(1, objective, knobs, php, pep, intermediate=False)
knobs -= lr*gradient
[9]:
# Iterate

# Set number of iterations

nitr = 5

# Loop

for intr in range(nitr):

    # Compute and plot phase space trajectories

    state = torch.linspace(0.0, 1.5, 21, dtype=dtype)
    state = torch.stack([state, torch.zeros_like(state)]).T

    table = []
    for _ in range(count):
        table.append(state)
        state = torch.func.vmap(lambda state: mapping(state, knobs))(state)

    table = torch.stack(table).swapaxes(0, -1)
    qs, ps = table

    plt.figure(figsize=(8, 8))
    plt.xlim(-1., 1.)
    plt.ylim(-1., 1.)
    for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
        plt.scatter(q, p, color='black', marker='o', s=1)

    # Find fixed points near previous values

    points = torch.stack([hp, ep])
    points = torch.func.vmap(lambda point: fixed_point(64, mapping, point, knobs, power=period))(points)
    points = clean_point(period, mapping, points, knobs, epsilon=epsilon)
    chains = torch.func.vmap(lambda point: chain_point(period, mapping, point, knobs))(points)

    # Plot chains and selected pair

    for chain in chains:
        point, *_ = chain
        value, vector = torch.linalg.eig(matrix(period, mapping, point, knobs))
        color = 'blue' if all(value.log().real < epsilon) else 'red'
        plt.scatter(*chain.T, color=color, marker='o')
        if color == 'blue':
            ep, *_ = chain
        else:
            hp, *_ = chain

    ep_chain, *_ = [chain for chain in chains if ep in chain]
    hp_chain, *_ = [chain for chain in chains if hp in chain]

    ep, *_ = ep_chain
    hp, *_ = hp_chain[(ep - hp_chain).norm(dim=-1) == (ep - hp_chain).norm(dim=-1).min()]

    plt.scatter(*ep.cpu().numpy(), color='black', marker='x')
    plt.scatter(*hp.cpu().numpy(), color='black', marker='x')
    plt.plot(*torch.stack([ep, hp]).T.cpu().numpy(), color='gray')

    plt.show()

    # Recompute parametric fixed points
# Note, this is not strictly necessary at each iteration

    php = parametric_fixed_point((order, ), hp, [knobs], mapping, power=period)
    pep = parametric_fixed_point((order, ), ep, [knobs], mapping, power=period)

    # Update

    lr *= 2.0
    gradient = derivative(1, objective, knobs, php, pep, intermediate=False)
    knobs -= lr*gradient
../_images/examples_ndmap_81_0.png
../_images/examples_ndmap_81_1.png
../_images/examples_ndmap_81_2.png
../_images/examples_ndmap_81_3.png
../_images/examples_ndmap_81_4.png

Example-09: Fixed point manipulation (change point type)

[1]:
# In this example the real parts of the eigenvalue logarithms of a hyperbolic fixed point are minimized
# First, using a set of initial guesses within a region, a hyperbolic point is located
# Parametric fixed point is computed and propagated
# Propagated table is used as a surrogate model to generate a differentiable objective
[2]:
# Import

import numpy
import torch

from ndmap.util import nest
from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import clean_point
from ndmap.pfp import chain_point
from ndmap.pfp import matrix
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set mapping

limit = 8
phase = 2.0*numpy.pi*(1/4 + 0.005)
phase = torch.tensor(phase/(limit + 1), dtype=dtype, device=device)

def mapping(state, knobs):
    q, p = state
    for index in range(limit):
        q, p = q*phase.cos() + p*phase.sin(), p*phase.cos() - q*phase.sin()
        q, p = q, p + knobs[index]*q**2
    q, p = q*phase.cos() + p*phase.sin(), p*phase.cos() - q*phase.sin()
    q, p = q, p + q**2
    return torch.stack([q, p])
[5]:
# Locate fixed points and select a pair

# Set initial knobs

knobs = torch.tensor(limit*[0.0], dtype=dtype, device=device)

# Compute and plot phase space trajectories

state = torch.linspace(0.0, 1.5, 21, dtype=dtype)
state = torch.stack([state, torch.zeros_like(state)]).T

count = 1024
table = []
for _ in range(count):
    table.append(state)
    state = torch.func.vmap(lambda state: mapping(state, knobs))(state)

table = torch.stack(table).swapaxes(0, -1)
qs, ps = table

plt.figure(figsize=(8, 8))
plt.xlim(-1., 1.)
plt.ylim(-1., 1.)
for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
    plt.scatter(q, p, color='black', marker='o', s=1)

# Set tolerance epsilon

epsilon = 1.0E-12

# Compute chains

period = 4
points = 4.0*torch.rand((512, 2), dtype=dtype, device=device) - 2.0
points = torch.func.vmap(lambda point: fixed_point(64, mapping, point, knobs, power=period))(points)
points = clean_point(period, mapping, points, knobs, epsilon=epsilon)
chains = torch.func.vmap(lambda point: chain_point(period, mapping, point, knobs))(points)

# Plot chains

for chain in chains:
    point, *_ = chain
    value, vector = torch.linalg.eig(matrix(period, mapping, point, knobs))
    color = 'blue' if all(value.log().real < epsilon) else 'red'
    plt.scatter(*chain.T, color=color, marker='o')
    if color == 'blue':
        ep, *_ = chain
    else:
        hp, *_ = chain

ep_chain, *_ = [chain for chain in chains if ep in chain]
hp_chain, *_ = [chain for chain in chains if hp in chain]

ep, *_ = ep_chain
hp, *_ = hp_chain[(ep - hp_chain).norm(dim=-1) == (ep - hp_chain).norm(dim=-1).min()]

plt.scatter(*ep.cpu().numpy(), color='black', marker='x')
plt.scatter(*hp.cpu().numpy(), color='black', marker='x')

plt.show()
../_images/examples_ndmap_87_0.png
[6]:
# Matrix around (dynamical) fixed points
# Note, eigenvalues of a hyperbolic fixed point are not on the unit circle

em = matrix(period, mapping, ep, knobs)
print(em)
print(torch.linalg.eigvals(em).log().real)
print()

hm = matrix(period, mapping, hp, knobs)
print(hm)
print(torch.linalg.eigvals(hm).log().real)
print()
tensor([[2.739351528140e-01, -1.149307994931e+00],
        [5.971467208895e-01, 1.145141455240e+00]], dtype=torch.float64)
tensor([-1.784726318725e-15, -1.784726318725e-15], dtype=torch.float64)

tensor([[1.251182357765e+00, 1.288812994113e-01],
        [1.428760927086e-01, 8.139613303868e-01]], dtype=torch.float64)
tensor([2.545448611841e-01, -2.545448611841e-01], dtype=torch.float64)

[7]:
# Compute first order parametric fixed points

order = 1

php = parametric_fixed_point((order, ), hp, [knobs], mapping, power=period)
pep = parametric_fixed_point((order, ), ep, [knobs], mapping, power=period)
[8]:
# Propagate parametric identity table
# Note, the propagated table can be used as a surrogate model around the (parametric) fixed point
# Here it is used to compute the parametric matrix around the fixed point and its eigenvalues

t = identity((1, 1), [hp, knobs], parametric=php)
t = propagate((2, limit), (1, 1), t, [knobs], nest(period, mapping, knobs))
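
# A quick check of the surrogate (a sketch, not part of the original notebook): the matrix
# recovered from the propagated table at the nominal knobs should match the direct computation

print(derivative(1, lambda x, k: evaluate(t, [x, k]), hp, knobs, intermediate=False))
print(matrix(period, mapping, hp, knobs))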
[9]:
# Set objective function
# Note, the objective is the sum of absolute real parts of the matrix eigenvalue logarithms
# It vanishes when the eigenvalues of the hyperbolic point reach the unit circle

def objective(knobs):
    hm = derivative(1, lambda x, k: evaluate(t, [x, k]), hp, knobs, intermediate=False)
    return torch.linalg.eigvals(hm).log().real.abs().sum()
[10]:
# Initial objective value

print(objective(knobs))
tensor(5.090897223682e-01, dtype=torch.float64)
[11]:
# Objective gradient

print(derivative(1, objective, knobs, intermediate=False))
tensor([-1.974190541982e-01, -4.137846718749e-01, -5.970747283789e-01, -7.018818743275e-01,
        -7.018818743275e-01, -5.970747283789e-01, -4.137846718749e-01, -1.974190541982e-01],
       dtype=torch.float64)
[12]:
# Set learning rate and update knobs

lr = 0.01
gradient = derivative(1, objective, knobs, intermediate=False)
knobs -= lr*gradient
[13]:
# Iterate

# Set number of iterations

nitr = 5

# Loop

for _ in range(nitr):

    state = torch.linspace(0.0, 1.5, 21, dtype=dtype)
    state = torch.stack([state, torch.zeros_like(state)]).T

    count = 1024
    table = []
    for _ in range(count):
        table.append(state)
        state = torch.func.vmap(lambda state: mapping(state, knobs))(state)

    table = torch.stack(table).swapaxes(0, -1)
    qs, ps = table

    plt.figure(figsize=(8, 8))
    plt.xlim(-1., 1.)
    plt.ylim(-1., 1.)
    for q, p in zip(qs.cpu().numpy(), ps.cpu().numpy()):
        plt.scatter(q, p, color='black', marker='o', s=1)

    # Set tolerance epsilon

    epsilon = 1.0E-12

    # Compute chains

    period = 4
    points = torch.stack([hp, ep])
    points = torch.func.vmap(lambda point: fixed_point(64, mapping, point, knobs, power=period))(points)
    points = clean_point(period, mapping, points, knobs, epsilon=epsilon)
    chains = torch.func.vmap(lambda point: chain_point(period, mapping, point, knobs))(points)

    # Plot chains

    for chain in chains:
        point, *_ = chain
        value, vector = torch.linalg.eig(matrix(period, mapping, point, knobs))
        color = 'blue' if all(value.log().real < epsilon) else 'red'
        plt.scatter(*chain.T, color=color, marker='o')
        if color == 'blue':
            ep, *_ = chain
        else:
            hp, *_ = chain

    ep_chain, *_ = [chain for chain in chains if ep in chain]
    hp_chain, *_ = [chain for chain in chains if hp in chain]

    ep, *_ = ep_chain
    hp, *_ = hp_chain[(ep - hp_chain).norm(dim=-1) == (ep - hp_chain).norm(dim=-1).min()]

    plt.scatter(*ep.cpu().numpy(), color='black', marker='x')
    plt.scatter(*hp.cpu().numpy(), color='black', marker='x')

    plt.show()
    print(objective(knobs).item())

    # Compute parametric fixed points

    php = parametric_fixed_point((order, ), hp, [knobs], mapping, power=period)
    pep = parametric_fixed_point((order, ), ep, [knobs], mapping, power=period)

    # Propagate parametric fixed points

    t = identity((1, 1), [hp, knobs], parametric=php)
    t = propagate((2, limit), (1, 1), t, [knobs], nest(period, mapping, knobs))

    # Update

    lr += 0.005
    gradient = derivative(1, objective, knobs, intermediate=False)
    knobs = knobs - lr*gradient
../_images/examples_ndmap_95_0.png
0.4875424149522949
../_images/examples_ndmap_95_2.png
0.44253915640171704
../_images/examples_ndmap_95_4.png
0.3947019262885917
../_images/examples_ndmap_95_6.png
0.34869831355705994
../_images/examples_ndmap_95_8.png
0.3054166976781556

Example-10: Alignment indices (chaos indicators)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative

torch.set_printoptions(precision=8, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set fixed parameters

a1, b1 = 0, 1
a2, b2 = 0, 1

f1 = torch.tensor(2.0*numpy.pi*0.38, dtype=dtype, device=device)
f2 = torch.tensor(2.0*numpy.pi*0.41, dtype=dtype, device=device)

cf1, sf1 = f1.cos(), f1.sin()
cf2, sf2 = f2.cos(), f2.sin()
[4]:
# Set 4D symplectic mapping

def mapping(x):
    q1, p1, q2, p2 = x
    return torch.stack([
        b1*(p1 + (q1**2 - q2**2))*sf1 + q1*(cf1 + a1*sf1),
        -((q1*(1 + a1**2)*sf1)/b1) + (p1 + (q1**2 - q2**2))*(cf1 - a1*sf1),
        q2*cf2 + (p2*b2 + q2*(a2 - 2*q1*b2))*sf2,
        -((q2*(1 + a2**2)*sf2)/b2) + (p2 - 2*q1*q2)*(cf2 - a2*sf2)
    ])
[5]:
# Set 4D symplectic mapping with tangent dynamics

def tangent(x, vs):
    x, m = derivative(1, mapping, x)
    vs = torch.func.vmap(lambda v: m @ v)(vs)
    return x, vs/vs.norm(dim=-1, keepdim=True)
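
# A minimal usage check (sketch): one application advances the state and returns unit-norm tangent vectors

x = torch.tensor([0.1, 0.0, 0.1, 0.0], dtype=dtype, device=device)
vs = torch.eye(4, dtype=dtype, device=device)
x, vs = tangent(x, vs)
print(vs.norm(dim=-1))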
[6]:
# Set generalized alignment indices computation

# Note, if the number of vectors is equal to two, the index tends towards zero for regular orbits
# And tends towards a constant value for chaotic motion
# If the number of vectors is greater than two, the index tends towards zero for all cases
# But for chaotic orbits, zero is reached (exponentially) faster

def gali(vs, threshold=1.0E-12):
    return (threshold + torch.linalg.svdvals(vs).prod()).log10()
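
# A small synthetic check (sketch, not part of the original notebook): for an orthonormal set
# all singular values are one and the index is near zero, while a nearly aligned pair collapses
# the product of singular values and drives the index to large negative values

vs = torch.eye(4, dtype=dtype, device=device)
print(gali(vs))

vs[1] = vs[0] + 1.0E-6*vs[1]
vs = vs/vs.norm(dim=-1, keepdim=True)
print(gali(vs))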
[7]:
# First, consider two initial conditions (regular and chaotic)

count = 1024

plt.figure(figsize=(8, 8))

x = torch.tensor([0.50000, 0.0, 0.05, 0.0], dtype=dtype)
orbit = []
for _ in range(count):
    x = mapping(x)
    orbit.append(x)
q, p, *_ = torch.stack(orbit).T
plt.scatter(q, p, s=1, color='blue')

x = torch.tensor([0.68925, 0.0, 0.10, 0.0], dtype=dtype)
orbit = []
for _ in range(count):
    x = mapping(x)
    orbit.append(x)
q, p, *_ = torch.stack(orbit).T
plt.scatter(q, p, s=1, color='red')

plt.show()
../_images/examples_ndmap_103_0.png
[8]:
# Compute and plot the GALI index at each iteration
# Note, the running minimum is appended at each iteration

plt.figure(figsize=(20, 5))

x = torch.tensor([0.50000, 0.0, 0.05, 0.0], dtype=dtype, device=device)
vs = torch.eye(4, dtype=dtype, device=device)
out = []
for _ in range(count):
    x, vs = tangent(x, vs)
    res = gali(vs)
    out.append(res if not out else min(res, out[-1]))
out = torch.stack(out)
plt.scatter(range(count), out, color='blue', marker='o')

x = torch.tensor([0.68925, 0.0, 0.10, 0.0], dtype=dtype, device=device)
vs = torch.eye(4, dtype=dtype, device=device)
out = []
for _ in range(count):
    x, vs = tangent(x, vs)
    res = gali(vs)
    out.append(res if not out else min(res, out[-1]))
out = torch.stack(out)
plt.scatter(range(count), out, color='red', marker='o')

plt.show()
../_images/examples_ndmap_104_0.png
[9]:
# Compute the indicator using all available vectors for a grid of initial conditions

def gali(vs):
    return torch.linalg.svdvals(vs.nan_to_num()).prod()

# Set grid

n1 = 501
n2 = 501

q1 = torch.linspace(-1.0, +1.0, n1, dtype=dtype, device=device)
q2 = torch.linspace(+0.0, +1.0, n2, dtype=dtype, device=device)

qs = torch.stack(torch.meshgrid(q1, q2, indexing='ij')).swapaxes(-1, 0).reshape(n1*n2, -1)
ps = torch.full_like(qs, 1.0E-10)

q1, q2, p1, p2 = torch.hstack([qs, ps]).T

vs = torch.tensor(n1*n2*[torch.eye(4).tolist()], dtype=dtype, device=device)
qs = torch.stack([q1, p1, q2, p2]).T

# Set task
# Perform 512 iterations, then compute the minimum of the indicator over the next 64 iterations

def task(qs, vs, count=512, total=64, level=1.0E-10):
    for _ in range(count):
        qs, vs = tangent(qs, vs)
    out = []
    for _ in range(total):
        qs, vs = tangent(qs, vs)
        out.append(gali(vs))
    return (torch.stack(out).min() + level*torch.sign(qs.norm())).log10()

# Compute and clean data

out = torch.vmap(task)(qs, vs)
out = out.nan_to_num(neginf=0.0)
out[(out >= -2.0)*(out != 0.0)] = -2.0
out[out == 0.0] = torch.nan
out = out.reshape(n1, n2)

# Plot

plt.figure(figsize=(8, 8))
plt.imshow(
    out.cpu().numpy(),
    vmin=-10.0,
    vmax=-2.0,
    aspect='equal',
    origin='lower',
    cmap='hot',
    interpolation='nearest')
plt.colorbar()
plt.axis('off')
plt.show()
../_images/examples_ndmap_105_0.png

Example-11: Closed orbit (dispersion)

[1]:
# In this example derivatives of the closed orbit with respect to the momentum deviation are computed
[2]:
# Import

import numpy
import torch

from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.series import series
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=8, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points
# Note, here observation points are locations between elements, the lattice start and the lattice end
# An observable (closed orbit) is computed at the observation points

# All maps are expected to have an identical signature of differentiable parameters
# State and momentum deviation in this example
# But each map can have any number of additional args and kwargs after the required parameters, as illustrated below
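
# For illustration only (map_example is hypothetical and not part of the transport line below):

def map_example(x, w, l=0.45):
    return drif(x, w, l)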

def map_01_02(x, w): return quad(x, w, 0.19, 0.50)
def map_02_03(x, w): return drif(x, w, 0.45)
def map_03_04(x, w): return sext(x, w, 0.00, 0.10)
def map_04_05(x, w): return drif(x, w, 0.45)
def map_05_06(x, w): return bend(x, w, 22.92, 0.015, 0.00, 3.0)
def map_06_07(x, w): return drif(x, w, 0.45)
def map_07_08(x, w): return sext(x, w, 0.00, 0.10)
def map_08_09(x, w): return drif(x, w, 0.45)
def map_09_10(x, w): return quad(x, w, -0.21, 0.50)
def map_10_11(x, w): return quad(x, w, -0.21, 0.50)
def map_11_12(x, w): return drif(x, w, 0.45)
def map_12_13(x, w): return sext(x, w, 0.00, 0.10)
def map_13_14(x, w): return drif(x, w, 0.45)
def map_14_15(x, w): return bend(x, w, 22.92, 0.015, 0.00, 3.0)
def map_15_16(x, w): return drif(x, w, 0.45)
def map_16_17(x, w): return sext(x, w, 0.00, 0.10)
def map_17_18(x, w): return drif(x, w, 0.45)
def map_18_19(x, w): return quad(x, w, 0.19, 0.50)

transport = [
    map_01_02,
    map_02_03,
    map_03_04,
    map_04_05,
    map_05_06,
    map_06_07,
    map_07_08,
    map_08_09,
    map_09_10,
    map_10_11,
    map_11_12,
    map_12_13,
    map_13_14,
    map_14_15,
    map_15_16,
    map_16_17,
    map_17_18,
    map_18_19
]

# Define one-turn transport

def fodo(x, w):
    for mapping in transport:
        x = mapping(x, w)
    return x
[6]:
# The first step is to compute dynamical fixed point

# Set initial guess
# Note, in this example zero is a fixed point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

# Set knobs

w = torch.tensor([0.0], dtype=dtype, device=device)

# Find fixed point

fp = fixed_point(16, fodo, x, w, power=1)

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric closed orbit

pfp = parametric_fixed_point((2, ), fp, [w], fodo)
chop(pfp)

# Print series representation
# Note, each key lists monomial exponents in (qx, px, qy, py, w)

for key, value in series((4, 1), (0, 2), pfp).items():
    print(f'{key}: {value.cpu().numpy()}')
(0, 0, 0, 0, 0): [0. 0. 0. 0.]
(0, 0, 0, 0, 1): [1.81613351 0.         0.         0.        ]
(0, 0, 0, 0, 2): [0.56855511 0.         0.         0.        ]
[8]:
# Check convergence

print(evaluate(series((4, 1), (0, 0), pfp), [x, w + 1.0E-3]))
print(evaluate(series((4, 1), (0, 1), pfp), [x, w + 1.0E-3]))
print(evaluate(series((4, 1), (0, 2), pfp), [x, w + 1.0E-3]))
print()

out = fixed_point(16, fodo, x, w + 1.0E-3, power=1)
chop([out])

print(out)
print()
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([1.81613351e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], dtype=torch.float64)
tensor([1.81670206e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], dtype=torch.float64)

tensor([  1.81670185e-03,   1.13882757e-19, -5.50391130e-235,   0.00000000e+00], dtype=torch.float64)

[9]:
# Propagate closed orbit

out = []

jet = identity((0, 1), fp, parametric=pfp)
out.append(jet)

for mapping in transport:
    jet = propagate((4, 1), (0, 2), jet, [w], mapping)
    out.append(jet)
[10]:
# Check periodicity

print(compare(pfp, jet))
True
[11]:
# Plot 1st order dispersion

pos = [0.00, 0.50, 0.95, 1.05, 1.50, 4.50, 4.95, 5.05, 5.50, 6.00, 6.50, 6.95, 7.05, 7.50, 10.50, 10.95, 11.05, 11.50, 12.00]
res = torch.stack([series((4, 1), (0, 2), jet)[(0, 0, 0, 0, 1)][0] for jet in out])

plt.figure(figsize=(20, 5))
plt.plot(pos, res.cpu().numpy(), marker='x', color='blue')
plt.show()
../_images/examples_ndmap_117_0.png
[12]:
# Plot 2nd order dispersion

pos = [0.00, 0.50, 0.95, 1.05, 1.50, 4.50, 4.95, 5.05, 5.50, 6.00, 6.50, 6.95, 7.05, 7.50, 10.50, 10.95, 11.05, 11.50, 12.00]
res = torch.stack([series((4, 1), (0, 2), jet)[(0, 0, 0, 0, 2)][0] for jet in out])

plt.figure(figsize=(20, 5))
plt.plot(pos, res.cpu().numpy(), marker='x', color='blue')
plt.show()
../_images/examples_ndmap_118_0.png

Example-12: Closed orbit (quadrupole shift)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, d):
    dxf, dyf, dxd, dyd = d
    x = quad(x, [0.0], 0.19, 0.50)
    x = slip(x, -dxf, -dyf)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxd, +dyd)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = slip(x, -dxd, -dyd)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxf, +dyf)
    x = quad(x, [0.0], 0.19, 0.50)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, d):
    for mapping in transport:
        x = mapping(x, d)
    return x
[5]:
# Compute fixed point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
fp = fixed_point(16, fodo, x, d, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [d], fodo)
chop(pfp)
pfp
[6]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[-4.621551e-01, 0.000000e+00, 1.165780e+00, 0.000000e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [0.000000e+00, 3.344042e+00, 0.000000e+00, -4.891066e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]], dtype=torch.float64)]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 4), (0, 1), pfp, [d], fodo)
chop(out)
out
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[-4.621551e-01, 0.000000e+00, 1.165780e+00, 0.000000e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [0.000000e+00, 3.344042e+00, 0.000000e+00, -4.891066e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]], dtype=torch.float64)]]
[8]:
# Test single random shift

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn_like(x)

fp = fixed_point(64, fodo, x, d, power=1, epsilon=1.0E-9)
chop([fp])

print(fp)
print(evaluate(pfp, [x, d]))
tensor([8.482363e-04, 8.125977e-21, 3.353913e-03, 6.171311e-20], dtype=torch.float64)
tensor([8.482363e-04, 0.000000e+00, 3.353913e-03, 0.000000e+00], dtype=torch.float64)
[9]:
# Estimate center & spread (tracking)

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn(8192, 4, dtype=dtype, device=device)

fp = torch.func.vmap(lambda d: fixed_point(64, fodo, x, d, power=1))(d)
chop([fp])

print(fp.T.mean(-1))
print(fp.T.std(-1))
tensor([-9.932218e-06,  4.372133e-22, -1.483549e-06,  1.693071e-21], dtype=torch.float64)
tensor([1.253079e-03, 9.554150e-20, 5.881071e-03, 3.289974e-19], dtype=torch.float64)
[10]:
# Estimate center & spread (surrogate)

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn(8192, 4, dtype=dtype, device=device)

fp = torch.func.vmap(lambda d: evaluate(pfp, [x, d]))(d)
chop([fp])

print(fp.T.mean(-1))
print(fp.T.std(-1))
tensor([ 3.969686e-06,  0.000000e+00, -2.718688e-05,  0.000000e+00], dtype=torch.float64)
tensor([1.240640e-03, 0.000000e+00, 5.905480e-03, 0.000000e+00], dtype=torch.float64)
[11]:
# Estimate spread (error propagation)

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
s = derivative(1, lambda d: evaluate(pfp, [x, d]), d, intermediate=False)

print((s @ (1.0E-3*torch.eye(4,  dtype=dtype, device=device))**2 @ s.T).diag().sqrt())
tensor([1.254045e-03, 0.000000e+00, 5.924960e-03, 0.000000e+00], dtype=torch.float64)
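
# Since the shifts are independent with a common standard deviation of 1.0E-3, the matrix
# expression above reduces to scaled row norms of the sensitivity s (a sketch of the same identity)

print((1.0E-3*s).pow(2).sum(-1).sqrt())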

Example-13: Closed orbit (sextupole shift)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, d):
    dxsf1, dysf1, dxsd1, dysd1, dxsf2, dysf2, dxsd2, dysd2 = d
    x = quad(x, [0.0], 0.19, 0.50)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxsf1, +dysf1)
    x = sext(x, [0.0], 0.50, 0.10)
    x = slip(x, -dxsf1, -dysf1)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 3.0)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxsd1, +dysd1)
    x = sext(x, [0.0], -0.50, 0.10)
    x = slip(x, -dxsd1, -dysd1)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxsd2, +dysd2)
    x = sext(x, [0.0], -0.50, 0.10)
    x = slip(x, -dxsd2, -dysd2)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 3.0)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxsf2, +dysf2)
    x = sext(x, [0.0], 0.50, 0.10)
    x = slip(x, -dxsf2, -dysf2)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19, 0.50)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, d):
    for mapping in transport:
        x = mapping(x, d)
    return x
[5]:
# Compute fixed point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
fp = fixed_point(16, fodo, x, d, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute & check parametric fixed point
# Note, there is no first order contribution from sextupole shifts

pfp = parametric_fixed_point((2, ), fp, [d], fodo)
out = propagate((4, 8), (0, 2), pfp, [d], fodo)
print(compare(pfp, out))
True
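
# Quick check (a sketch, relying on the table layout shown in the other examples):
# the first order term with respect to the shifts is numerically zero

value, first, second = pfp[0]
print(first.abs().max())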
[7]:
# Test for a random shift

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn(8, dtype=dtype, device=device)

fp = fixed_point(64, fodo, x, d, power=1, epsilon=1.0E-9)

print(fp)
print(evaluate(pfp, [x, d]))
tensor([ 2.446138e-07, -1.144473e-08, -8.221335e-07,  6.156318e-09], dtype=torch.float64)
tensor([ 2.442916e-07, -1.142839e-08, -8.240374e-07,  6.170032e-09], dtype=torch.float64)
[8]:
# Estimate center & spread (tracking)

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn(8192, 8, dtype=dtype, device=device)

fp = torch.func.vmap(lambda d: fixed_point(64, fodo, x, d, power=1))(d)
chop([fp])

print(fp.T.mean(-1))
print(fp.T.std(-1))
tensor([-1.018030e-09, -2.975896e-10, -1.314171e-08, -2.126598e-11], dtype=torch.float64)
tensor([6.895581e-07, 3.187716e-08, 1.813973e-06, 2.939394e-08], dtype=torch.float64)
[9]:
# Estimate center & spread (surrogate)

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
d = 1.0E-3*torch.randn(8192, 8, dtype=dtype, device=device)

fp = torch.func.vmap(lambda d: evaluate(pfp, [x, d]))(d)
chop([fp])

print(fp.T.mean(-1))
print(fp.T.std(-1))
tensor([-4.887936e-09, -1.410492e-10, -3.069499e-09, -4.740614e-10], dtype=torch.float64)
tensor([6.769223e-07, 3.104300e-08, 1.819394e-06, 2.833964e-08], dtype=torch.float64)

Example-14: Closed orbit (response matrix & correction)

[1]:
# In this example the orbit response matrix is computed
# Quadrupole shifts are introduced and the response matrix is used to correct the orbit at the observation locations

# Correctors are at sextupole centers
# Observation points are at bend centers
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=25):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points

def map_01_02(x, c, d):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = c
    dxf, dyf, dxd, dyd = d
    x = quad(x, [0.0], 0.19, 0.50)
    x = slip(x, -dxf, -dyf)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf1, cysf1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def map_02_03(x, c, d):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = c
    dxf, dyf, dxd, dyd = d
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd1, cysd1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxd, +dyd)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = slip(x, -dxd, -dyd)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd2, cysd2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def map_03_04(x, c, d):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = c
    dxf, dyf, dxd, dyd = d
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf2, cysf2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = slip(x, +dxf, +dyf)
    x = quad(x, [0.0], 0.19, 0.50)
    return x

transport = [
    map_01_02,
    map_02_03,
    map_03_04
]

# Define one-turn transport

def fodo(x, c, d):
    for mapping in transport:
        x = mapping(x, c, d)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
c = torch.tensor(8*[0.0], dtype=dtype, device=device)
d = torch.tensor(4*[0.0], dtype=dtype, device=device)
print(fodo(x, c, d))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
c = torch.tensor(8*[0.0], dtype=dtype, device=device)
d = torch.tensor(4*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, c, d, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [c], fodo, d)
chop(pfp)
pfp
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[7.583253e+00, 5.939607e+00, 7.583253e+00, 5.939607e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [-4.334732e-01, -9.086786e-02, 4.334732e-01, 9.086786e-02, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.349234e+01, 2.165892e+01, 1.349234e+01, 2.165892e+01],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, -4.019148e-01, -7.725758e-02, 4.019148e-01, 7.725758e-02]],
         dtype=torch.float64)]]
[8]:
# Propagate parametric fixed point

out = propagate((4, 8), (0, 1), pfp, [c], fodo, d)
chop(out)
out
[8]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[7.583253e+00, 5.939607e+00, 7.583253e+00, 5.939607e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [-4.334732e-01, -9.086786e-02, 4.334732e-01, 9.086786e-02, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.349234e+01, 2.165892e+01, 1.349234e+01, 2.165892e+01],
          [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, -4.019148e-01, -7.725758e-02, 4.019148e-01, 7.725758e-02]],
         dtype=torch.float64)]]
[9]:
# Orbit derivatives at observation locations

bag = []

pfp = propagate((4, 8), (0, 1), pfp, [c], map_01_02, d)
chop(pfp)
bag.append(pfp)

pfp = propagate((4, 8), (0, 1), pfp, [c], map_02_03, d)
chop(pfp)
bag.append(pfp)
[10]:
# Compute response matrix

def orbit(c, pfp):
    qx, _, qy, _ = evaluate(pfp, [x, c])
    return torch.stack([qx, qy])

pfp1, pfp2 = bag

rx1, ry1 = derivative(1, orbit, c, pfp1, intermediate=False)
rx2, ry2 = derivative(1, orbit, c, pfp2, intermediate=False)

rm = torch.stack([rx1, rx2, ry1, ry2])
print(rm)
tensor([[6.221115e+00, 4.042085e+00, 6.753998e+00, 4.569069e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [6.753998e+00, 4.569069e+00, 6.221115e+00, 4.042085e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.808222e+01, 2.754871e+01, 1.855563e+01, 2.802741e+01],
        [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.855563e+01, 2.802741e+01, 1.808222e+01, 2.754871e+01]],
       dtype=torch.float64)
[11]:
# Generate perturbed orbit

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
c = torch.tensor(8*[0.0], dtype=dtype, device=device)
d = torch.tensor([0.001, 0.001, -0.001, 0.001], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, c, d, power=1)

qx1, _, qy1, _ = map_01_02(fp, c, d)
qx2, _, qy2, _ = map_02_03(map_01_02(fp, c, d), c, d)

o = torch.stack([qx1, qx2, qy1, qy2])
chop([o])
print(o)
tensor([-2.161122e-03, -2.161122e-03, -3.001730e-03, -3.001730e-03], dtype=torch.float64)
[12]:
# Correct orbit
# Note, the corrector settings solve rm @ c = -o in the least-squares sense

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
c = - (torch.linalg.pinv(rm) @ o)
d = torch.tensor([0.001, 0.001, -0.001, 0.001], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, c, d, power=1)

qx1, _, qy1, _ = map_01_02(fp, c, d)
qx2, _, qy2, _ = map_02_03(map_01_02(fp, c, d), c, d)

o = torch.stack([qx1, qx2, qy1, qy2])
chop([o])
print(o)
tensor([-1.070014e-18,  3.763685e-18, -1.733072e-17, -1.661201e-17], dtype=torch.float64)

Example-15: Tune (chromaticity)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, w, k):
    ksf, ksd, ksb = k
    x = quad(x, w, 0.19, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, ksf, 0.10)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, ksb, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, ksd, 0.10)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21, 0.50)
    x = quad(x, w, -0.21, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, ksd, 0.10)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, ksb, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, ksf, 0.10)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19, 0.50)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, w, k):
    for mapping in transport:
        x = mapping(x, w, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

print(fodo(x, w, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, w, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, 1), fp, [w, k], fodo)
chop(pfp)
pfp
[6]:
[[[tensor([0., 0., 0., 0.], dtype=torch.float64),
   tensor([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=torch.float64)],
  [tensor([[1.816134e+00],
           [0.000000e+00],
           [0.000000e+00],
           [0.000000e+00]], dtype=torch.float64),
   tensor([[[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]]], dtype=torch.float64)]]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 1, 3), (0, 1, 1), pfp, [w, k], fodo)
chop(out)
out
[7]:
[[[tensor([0., 0., 0., 0.], dtype=torch.float64),
   tensor([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=torch.float64)],
  [tensor([[1.816134e+00],
           [0.000000e+00],
           [0.000000e+00],
           [0.000000e+00]], dtype=torch.float64),
   tensor([[[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]]], dtype=torch.float64)]]]
[8]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1, 1), fp, parametric=pfp)
jet = propagate((4, 1, 3), (1, 1, 1), jet, [w, k], fodo)
[9]:
# Compute tune

def tune(w, k):
    m = derivative(1, lambda x: evaluate(jet, [x, w, k]), fp, intermediate=False)
    t, *_ = twiss(m)
    return t

print(tune(w, k))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
[10]:
# Compute parametric tune

t = derivative((1, 1), tune, w, k)

print(t)
[[tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64), tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)], [tensor([[-2.343079e-01],
        [-2.416176e-01]], dtype=torch.float64), tensor([[[3.618782e-01],
         [8.705079e-02],
         [5.873074e+00]],

        [[-2.986097e-01],
         [-4.727344e-01],
         [-1.106560e+01]]], dtype=torch.float64)]]
[11]:
# Check convergence

print(evaluate(t, [w, k]))
print(evaluate(t, [w + 1.0E-3, k]))
print()

print(tune(w + 1.0E-3, k))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
tensor([2.525452e-01, 1.196688e-01], dtype=torch.float64)

tensor([2.525452e-01, 1.196687e-01], dtype=torch.float64)
[12]:
# Check convergence

print(evaluate(t, [w, k]))
print(evaluate(t, [w + 1.0E-3, k]))
print(evaluate(t, [w + 1.0E-3, k + 1.0E-2]))
print()

print(tune(w + 1.0E-3, k + 1.0E-2))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
tensor([2.525452e-01, 1.196688e-01], dtype=torch.float64)
tensor([2.526084e-01, 1.195504e-01], dtype=torch.float64)

tensor([2.526084e-01, 1.195502e-01], dtype=torch.float64)
[13]:
# Series representation

for key, value in clean(series((1, 3), (1, 1), t)).items():
    print(f'{key}: {value}')
(0, 0, 0, 0): tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
(1, 0, 0, 0): tensor([-2.343079e-01, -2.416176e-01], dtype=torch.float64)
(1, 1, 0, 0): tensor([3.618782e-01, -2.986097e-01], dtype=torch.float64)
(1, 0, 1, 0): tensor([8.705079e-02, -4.727344e-01], dtype=torch.float64)
(1, 0, 0, 1): tensor([5.873074e+00, -1.106560e+01], dtype=torch.float64)
[14]:
# Set chromaticity to zero
# Note, A is the response of the chromaticity to the sextupole strengths and b is the natural chromaticity
# The sextupole settings solve A @ k = -b

A = derivative((1, 1), lambda w, k: evaluate(t, [w, k]), w, k, intermediate=False)
b = derivative(1, lambda w, k: evaluate(t, [w, k]), w, k, intermediate=False).flatten()

print(derivative(1, lambda w: evaluate(t, [w, k]), w, intermediate=False).flatten())
print(derivative(1, lambda w: evaluate(t, [w, - (torch.linalg.pinv(A.squeeze()) @ b)]), w, intermediate=False).flatten())
tensor([-2.343079e-01, -2.416176e-01], dtype=torch.float64)
tensor([-2.914335e-15,  1.165734e-15], dtype=torch.float64)
[15]:
# Set chromaticity to one

A = derivative((1, 1), lambda w, k: evaluate(t, [w, k]), w, k, intermediate=False)
b = -1.0 + derivative(1, lambda w, k: evaluate(t, [w, k]), w, k, intermediate=False).flatten()

print(derivative(1, lambda w: evaluate(t, [w, k]), w, intermediate=False).flatten())
print(derivative(1, lambda w: evaluate(t, [w, - (torch.linalg.pinv(A.squeeze()) @ b)]), w, intermediate=False).flatten())
tensor([-2.343079e-01, -2.416176e-01], dtype=torch.float64)
tensor([1.000000e+00, 1.000000e+00], dtype=torch.float64)

Example-16: Tune (response matrix & correction)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, k):
    kqf, kqd, kqb = k
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.0, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.0, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[6]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 3), (0, 1), pfp, [k], fodo)
chop(out)
out
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 3), (1, 1), jet, [k], fodo)
[9]:
# Compute tune

def tune(k):
    m = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)
    t, *_ = twiss(m)
    return t

print(tune(k))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
[10]:
# Compute parametric tune

t = derivative((1, ), tune, k)
[11]:
# Check convergence

print(evaluate(t, [k]))
print(evaluate(t, [k + torch.tensor([5.0E-3, 5.0E-3, 1.0E-3], dtype=dtype, device=device)]))
print()

print(tune(k + torch.tensor([5.0E-3, 5.0E-3, 1.0E-3], dtype=dtype, device=device)))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
tensor([2.646746e-01, 9.564416e-02], dtype=torch.float64)

tensor([2.646564e-01, 9.339156e-02], dtype=torch.float64)
[12]:
# Response matrix

print(derivative(1, lambda k: evaluate(t, [k]), k, intermediate=False))
tensor([[1.217003e+00, 3.230152e-01, 4.195000e+00],
        [-7.757018e-01, -2.437956e+00, -8.197983e+00]], dtype=torch.float64)
[13]:
# Correct tunes (increase the horizontal tune by 0.01)
# Note, the knob change solves A @ dk = (0.01, 0.0) in the least-squares sense

A = derivative(1, lambda k: evaluate(t, [k]), k, intermediate=False)
b = -torch.tensor([0.01, 0.0], dtype=dtype, device=device)

print(tune(k))
print(tune(-(torch.linalg.pinv(A) @ b)))
print()
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
tensor([2.627666e-01, 1.199040e-01], dtype=torch.float64)

Example-17: Tune (spread)

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, k):
    kqf, kqd, kqb = k
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.0, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.0, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((2, ), fp, [k], fodo)
chop(pfp)
pfp
[6]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64),
  tensor([[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]], dtype=torch.float64)]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 3), (0, 2), pfp, [k], fodo)
chop(out)
out
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64),
  tensor([[[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]],

          [[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]]], dtype=torch.float64)]]
[8]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 2), fp, parametric=pfp)
jet = propagate((4, 3), (1, 2), jet, [k], fodo)
[9]:
# Compute tune

def tune(k):
    m = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)
    t, *_ = twiss(m)
    return t

print(tune(k))
tensor([2.527795e-01, 1.199104e-01], dtype=torch.float64)
[10]:
# Compute tune spread (direct)

sf = 1.0E-3
sd = 1.0E-3
sb = 1.0E-4

def wrapper(k):
    t, *_ = twiss(derivative(1, fodo, x, k, intermediate=False))
    return t

err = torch.tensor([sf, sd, sb])*torch.randn(8192).unsqueeze(1).to(dtype).to(device)
out = torch.func.vmap(wrapper)(err).T

print(out.mean(-1))
print(out.std(-1))
print()
tensor([2.527650e-01, 1.198656e-01], dtype=torch.float64)
tensor([1.958515e-03, 4.040516e-03], dtype=torch.float64)

[11]:
# Compute parametric tune

t = derivative((1, ), tune, k)
[12]:
# Compute tune spread (surrogate)

sf = 1.0E-3
sd = 1.0E-3
sb = 1.0E-4

err = torch.tensor([sf, sd, sb])*torch.randn(8192).unsqueeze(1).to(dtype).to(device)
out = torch.func.vmap(lambda k: evaluate(t, [k]))(err).T

print(out.mean(-1))
print(out.std(-1))
print()
tensor([2.528075e-01, 1.198527e-01], dtype=torch.float64)
tensor([1.973939e-03, 4.063140e-03], dtype=torch.float64)
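
# Note, both estimates above scale a single random factor per sample by [sf, sd, sb]
# so that the knob errors are fully correlated
# For independent errors each knob can be sampled separately (a sketch; the spread differs from the correlated case)

err = torch.tensor([sf, sd, sb], dtype=dtype, device=device)*torch.randn(8192, 3, dtype=dtype, device=device)
out = torch.func.vmap(lambda k: evaluate(t, [k]))(err).T

print(out.mean(-1))
print(out.std(-1))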

Example-18: Parametric twiss

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.wolski import propagate as propagate_twiss
from twiss.convert import wolski_to_cs

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, k):
    kqf, kqd, kqb = k
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    return x

def map_02_03(x, k):
    kqf, kqd, kqb = k
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.1)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    return x

def map_03_04(x, k):
    kqf, kqd, kqb = k
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.1)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    return x

transport = [
    map_01_02,
    map_02_03,
    map_03_04
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[6]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 3), (0, 1), pfp, [k], fodo)
chop(out)
out
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 3), (1, 1), jet, [k], fodo)
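
# As a hedged sanity check (not in the original notebook), the jet surrogate
# should reproduce the exact one-turn map for small deviations around the
# fixed point (dx and dk0 below are illustrative values):

dx = torch.tensor([1.0E-6, 0.0, 0.0, 0.0], dtype=dtype, device=device)
dk0 = torch.tensor([1.0E-4, 0.0, 0.0], dtype=dtype, device=device)

print(fodo(fp + dx, k + dk0))
print(evaluate(jet, [dx, dk0]))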
[9]:
# Compute twiss

# Note, exact or jet one-turn transport can be used
# Other maps can be replaced with jets too

def fn(k, fodo):

    bxs = []
    bys = []

    m = derivative(1, fodo, fp, intermediate=False)

    t, _, w = twiss(m)
    _, bx, _, by = wolski_to_cs(w)

    for mapping in transport:
        w = propagate_twiss(w, derivative(1, mapping, x, k, intermediate=False))
        _, bx, _, by = wolski_to_cs(w)
        bxs.append(bx)
        bys.append(by)

    return torch.stack([*t, *bxs, *bys])

print(fn(k, fodo=lambda x: evaluate(jet, [x, k])))
tensor([2.528e-01, 1.199e-01, 8.703e+00, 8.703e+00, 1.553e+01, 1.678e+01, 1.678e+01, 9.586e+00], dtype=torch.float64)
[10]:
# Compute parametric derivative using exact map (tune & beta)

d = derivative((1, ), lambda k: fn(k, fodo=lambda x: fodo(x, k)), k)

print(d)
[tensor([2.528e-01, 1.199e-01, 8.703e+00, 8.703e+00, 1.553e+01, 1.678e+01, 1.678e+01, 9.586e+00], dtype=torch.float64), tensor([[ 1.217e+00,  3.230e-01,  4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-1.749e+00, -3.120e+01, -1.799e+02],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 5.215e+01,  2.146e+02,  6.957e+02]], dtype=torch.float64)]
[11]:
# Compute parametric derivative using jet map (tune & beta)

d = derivative((1, ), lambda k: fn(k, fodo=lambda x: evaluate(jet, [x, k])), k)

print(d)
[tensor([2.528e-01, 1.199e-01, 8.703e+00, 8.703e+00, 1.553e+01, 1.678e+01, 1.678e+01, 9.586e+00], dtype=torch.float64), tensor([[ 1.217e+00,  3.230e-01,  4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-1.749e+00, -3.120e+01, -1.799e+02],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 5.215e+01,  2.146e+02,  6.957e+02]], dtype=torch.float64)]
[12]:
# Check convergence

dk = torch.tensor([1.0E-3, 1.0E-3, 1.0E-4], dtype=dtype, device=device)

values = fn(k, fodo=lambda x: fodo(x, k))
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))

values = evaluate(d, [dk])
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))

values = fn(k + dk, fodo=lambda x: fodo(x, k + dk))
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))
0.253 0.120 8.703 8.703 15.532 16.780 16.780 9.586
0.255 0.116 8.645 8.645 15.481 17.337 17.337 9.922
0.255 0.116 8.646 8.646 15.482 17.366 17.366 9.940
[13]:
# Response matrix

m = derivative((1, ), lambda k: fn(k, fodo=lambda x: evaluate(jet, [x, k])), k, intermediate=False)

print(m)
tensor([[ 1.217e+00,  3.230e-01,  4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-3.111e+01, -1.531e+01, -1.101e+02],
        [-1.749e+00, -3.120e+01, -1.799e+02],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 1.153e+02,  3.308e+02,  1.106e+03],
        [ 5.215e+01,  2.146e+02,  6.957e+02]], dtype=torch.float64)
[14]:
# Correction

# The target values (tunes and beta functions) are associated with the model response matrix
# Given measured values, the goal is to adjust the knobs so that the target values are recovered

# Set target values

vf = fn(k, fodo=lambda x: fodo(x, k))

# Set initial solution

sol = torch.zeros_like(dk)

# Iterate

for _ in range(4):

    # Compute current values and set difference

    vi = fn(k + dk + sol, fodo=lambda x: evaluate(jet, [x, k + dk + sol]))

    # Set difference

    dv = vf - vi

    # Update solution

    sol += torch.linalg.pinv(m) @ dv

    # Verbose

    print(-dk)
    print(sol)
    print(dv.norm())
    print()

    # Continue
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.064e-03, -1.258e-03, -4.094e-05], dtype=torch.float64)
tensor(9.036e-01, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -9.990e-05], dtype=torch.float64)
tensor(4.251e-02, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor(8.589e-05, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor(3.559e-10, dtype=torch.float64)

[15]:
# Note, similar to the tune spread example, it is possible to compute the twiss spread
# First order computation can be performed using error propagation
# Or higher order jets can be sampled
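
# A minimal sketch of the first order error propagation mentioned above
# (assuming independent knob errors with standard deviations sigma_k; the
# values below are illustrative): the covariance of the twiss values is
# m @ diag(sigma_k**2) @ m.T, with m the response matrix from above

sigma_k = torch.tensor([1.0E-4, 1.0E-4, 1.0E-5], dtype=dtype, device=device)
covariance = m @ torch.diag(sigma_k**2) @ m.T
print(covariance.diag().sqrt())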

Example-19: Parametric phase advance

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.wolski import propagate as propagate_twiss
from twiss.wolski import advance

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, k):
    kqf, kqd, kqb = k
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    return x

def map_02_03(x, k):
    kqf, kqd, kqb = k
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = quad(x, [0.0], -0.21 + kqd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.1)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    return x

def map_03_04(x, k):
    kqf, kqd, kqb = k
    x = bend(x, [0.0], 22.92, 0.015 + kqb, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.1)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kqf, 0.50)
    return x

transport = [
    map_01_02,
    map_02_03,
    map_03_04
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[6]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[7]:
# Propagate parametric fixed point

out = propagate((4, 3), (0, 1), pfp, [k], fodo)
chop(out)
out
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 3), (1, 1), jet, [k], fodo)
[9]:
# Compute phase advance

# Note, exact or jet one-turn transport can be used
# Other maps can be replaced with jets too

def fn(k, fodo):

    mus = []

    m = derivative(1, fodo, fp, intermediate=False)

    t, n, _ = twiss(m)

    for mapping in transport:
        mu, n = advance(n, derivative(1, mapping, x, k, intermediate=False))
        mus.append(mu)

    mus = torch.stack(mus).T

    return torch.stack([*t, *mus.flatten()])

print(fn(k, fodo=lambda x: evaluate(jet, [x, k])))
tensor([2.528e-01, 1.199e-01, 2.521e-01, 1.084e+00, 2.521e-01, 2.468e-01, 2.599e-01, 2.468e-01], dtype=torch.float64)
[10]:
# Compute parametric derivative using exact map (tune & advance)

d = derivative((1, ), lambda k: fn(k, fodo=lambda x: fodo(x, k)), k)

print(d)
[tensor([2.528e-01, 1.199e-01, 2.521e-01, 1.084e+00, 2.521e-01, 2.468e-01, 2.599e-01, 2.468e-01], dtype=torch.float64), tensor([[1.217e+00, 3.230e-01, 4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [6.755e+00, 1.059e+00, 2.051e+01],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [-1.523e+00, -5.303e+00, -1.726e+01],
        [-1.827e+00, -4.712e+00, -1.699e+01],
        [-1.523e+00, -5.303e+00, -1.726e+01]], dtype=torch.float64)]
[11]:
# Compute parametric derivative using jet map (tune & advance)

d = derivative((1, ), lambda k: fn(k, fodo=lambda x: evaluate(jet, [x, k])), k)

print(d)
[tensor([2.528e-01, 1.199e-01, 2.521e-01, 1.084e+00, 2.521e-01, 2.468e-01, 2.599e-01, 2.468e-01], dtype=torch.float64), tensor([[1.217e+00, 3.230e-01, 4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [6.755e+00, 1.059e+00, 2.051e+01],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [-1.523e+00, -5.303e+00, -1.726e+01],
        [-1.827e+00, -4.712e+00, -1.699e+01],
        [-1.523e+00, -5.303e+00, -1.726e+01]], dtype=torch.float64)]
[12]:
# Check convergence

dk = torch.tensor([1.0E-3, 1.0E-3, 1.0E-4], dtype=dtype, device=device)

values = fn(k, fodo=lambda x: fodo(x, k))
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))

values = evaluate(d, [dk])
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))

values = fn(k + dk, fodo=lambda x: fodo(x, k + dk))
print(' '.join([f'{value:.3f}' for value in values.cpu().tolist()]))
0.253 0.120 0.252 1.084 0.252 0.247 0.260 0.247
0.255 0.116 0.253 1.094 0.253 0.238 0.252 0.238
0.255 0.116 0.253 1.094 0.253 0.238 0.251 0.238
[13]:
# Response matrix

m = derivative((1, ), lambda k: fn(k, fodo=lambda x: evaluate(jet, [x, k])), k, intermediate=False)

print(m)
tensor([[1.217e+00, 3.230e-01, 4.195e+00],
        [-7.757e-01, -2.438e+00, -8.198e+00],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [6.755e+00, 1.059e+00, 2.051e+01],
        [4.458e-01, 4.852e-01, 2.926e+00],
        [-1.523e+00, -5.303e+00, -1.726e+01],
        [-1.827e+00, -4.712e+00, -1.699e+01],
        [-1.523e+00, -5.303e+00, -1.726e+01]], dtype=torch.float64)
[14]:
# Correction

# The target values (tunes and phase advances) are associated with the model response matrix
# Given measured values, the goal is to adjust the knobs so that the target values are recovered

# Set target values

vf = fn(k, fodo=lambda x: fodo(x, k))

# Set initial solution

sol = torch.zeros_like(dk)

# Iterate

for _ in range(4):

    # Compute current values and set difference

    vi = fn(k + dk + sol, fodo=lambda x: evaluate(jet, [x, k + dk + sol]))

    # Set difference

    dv = vf - vi

    # Update solution

    sol += torch.linalg.pinv(m) @ dv

    # Verbose

    print(-dk)
    print(sol)
    print(dv.norm())
    print()

    # Continue
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.043e-03, -1.068e-03, -8.224e-05], dtype=torch.float64)
tensor(1.846e-02, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor(2.001e-04, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor(2.786e-08, dtype=torch.float64)

tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor([-1.000e-03, -1.000e-03, -1.000e-04], dtype=torch.float64)
tensor(2.344e-15, dtype=torch.float64)

[15]:
# Note, similar to the tune spread example, it is possible to compute the advance spread
# First order computation can be performed using error propagation
# Or higher order jets can be sampled
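
# A hedged sketch of the sampling approach mentioned above: evaluate the
# parametric jet d from the cells above on random knob errors (the spread
# used below is illustrative) and estimate the advance spread from samples

dks = 1.0E-4*torch.randn((4096, 3), dtype=dtype, device=device)
values = torch.func.vmap(lambda dk: evaluate(d, [dk]))(dks)
print(values.std(0))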

Example-20: Poisson bracket

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.bracket import bracket

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Poisson bracket between observable/mapping or mapping/observable or mapping/mapping

# [f, g]                   -> [f, g]
# [[f1, f2], g]            -> [[f1, g], [f2, g]]
# [f, [g1, g2]]            -> [[f, g1], [f, g2]]
# [[f1, f2], [g1, g2]]     -> [[f1, g1], [f2, g2]]


def f(x): q, p = x ; return q
def g(x): q, p = x ; return p
x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
print(bracket(f, g)(x))

def f(x): q, p = x ; return torch.stack([q, p])
def g(x): q, p = x ; return p
x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
print(bracket(f, g)(x))

def f(x): q, p = x ; return q
def g(x): q, p = x ; return torch.stack([q, p])
x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
print(bracket(f, g)(x))

def f(x): q, p = x ; return torch.stack([q, p])
def g(x): q, p = x ; return torch.stack([q, p])
x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
print(bracket(f, g)(x))
tensor(1.000e+00, dtype=torch.float64)
tensor([1.000e+00, 0.000e+00], dtype=torch.float64)
tensor([0.000e+00, 1.000e+00], dtype=torch.float64)
tensor([0., 0.], dtype=torch.float64)
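
# A hedged sketch of what the scalar/scalar bracket computes in terms of
# gradients, assuming the canonical symplectic matrix J in (q, p) ordering
# (pb below is a local helper for illustration, not part of ndmap):

J = torch.tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=dtype, device=device)

def pb(f, g):
    return lambda x: torch.func.grad(f)(x) @ J @ torch.func.grad(g)(x)

def f(x): q, p = x ; return q**2
def g(x): q, p = x ; return p
x = torch.tensor([0.3, 0.0], dtype=dtype, device=device)
print(pb(f, g)(x))
print(bracket(f, g)(x))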
[4]:
# Returns a function that can be differentiated

def f(x): q, p = x ; return q**2
def g(x): q, p = x ; return p
x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
print(derivative(1, bracket(f, g), x, intermediate=False))
tensor([2.000e+00, 0.000e+00], dtype=torch.float64)
[5]:
# Accepts Table input
# Note, evaluation is at deviation point

tf = propagate((2, ), (2, ), identity((2, ), [x]), [], f)
tg = propagate((2, ), (2, ), identity((2, ), [x]), [], g)
print(derivative(1, bracket(tf, tg), x, intermediate=False))
tensor([2.000e+00, 0.000e+00], dtype=torch.float64)
[6]:
# Accepts Series input
# Note, evaluation is at deviation point

sf = propagate((2, ), (2, ), identity((2, ), [x], flag=True), [], f)
sg = propagate((2, ), (2, ), identity((2, ), [x], flag=True), [], g)
print(derivative(1, bracket(sf, sg), x, intermediate=False))
tensor([2.000e+00, 0.000e+00], dtype=torch.float64)
[7]:
# Propagate identity

t = propagate((2, ), (1, ), identity((1, ), [x]), [], bracket(f, g))
print(t)
[tensor(0., dtype=torch.float64), tensor([2.000e+00, 0.000e+00], dtype=torch.float64)]

Example-21: Taylor integrator

[1]:
# Given an autonomous hamiltonian function H
# The exact solution is x(t) = exp([-t H]) x with the Poisson bracket operator [f] g := [f, g]
# The truncated solution is x(t) = x + t [-H] x + 1/2! t**2 [-H]**2 x + ... up to a given order
# Such a series solution doesn't preserve the symplectic structure in general
# Taylor integration is differentiable with respect to time step, initial value and parameters

# Note, generation of derivatives can be extremely slow
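
# A hedged, self-contained illustration of the truncated series above using
# ndmap.bracket (h0, x0 and the other names below are local to this sketch);
# the order 2 truncation is compared against taylor, assuming its order
# argument counts the highest power of the time step

import torch

from ndmap.bracket import bracket
from ndmap.taylor import taylor

def h0(x):
    q, p = x
    return p**2/2 + q**2/2

def mh(x):
    return -h0(x)

def xy(x):
    q, p = x
    return torch.stack([q, p])

x0 = torch.tensor([0.1, 0.0], dtype=torch.float64)

term1 = bracket(mh, xy)      # [-H] x
term2 = bracket(mh, term1)   # [-H]**2 x

print(x0 + 0.05*term1(x0) + 0.05**2/2*term2(x0))
print(taylor(2, 0.05, h0, x0))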
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.taylor import taylor

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float32
device = torch.device('cpu')
[4]:
# Integrate for a given time step and initial condition
# Note, the hamiltonian is not exactly preserved; the time scale of the hamiltonian drift depends on the truncation order

# Set nonlinear oscillator hamiltonian function

def h(x):
    q, p = x
    return p**2/2 + q**2/2 + q**3/3

# Set time step

dt = torch.tensor(0.15, dtype=dtype, device=device)

# Set initial condition

xi = torch.tensor([0.4, 0.0], dtype=dtype, device=device)

# Integrate and plot orbits for several values of truncation order

count = 1024

plt.figure(figsize=(8, 8))
plt.xlim(-1.0, 1.0)
plt.ylim(-1.0, 1.0)

for order, color in zip([1, 2, 4], ['black', 'red', 'blue']):
    orbit = []
    x = torch.clone(xi)
    for _ in range(count):
        x = taylor(order, dt, h, x)
        orbit.append(x)
    orbit = torch.stack(orbit)
    plt.scatter(*orbit.T.cpu().numpy(), color=color, s=1)

plt.show()
[figure: phase space orbits for truncation orders 1, 2 and 4]
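
# A hedged sketch measuring the hamiltonian drift along the order 2 orbit
# (h, dt, xi and count are reused from the cell above):

x = torch.clone(xi)
hs = []
for _ in range(count):
    x = taylor(2, dt, h, x)
    hs.append(h(x))
hs = torch.stack(hs)
print(hs.max() - hs.min())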
[5]:
# Generate derivative table representation at zero
# Note, here a smaller time step is used

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
t = derivative(4, lambda x: taylor(6, 0.05, h, x), x)

# Compute series representation

s = series((2, ), (4, ), t)

# Print series

for key, value in s.items():
    print(f'{key}: {value.cpu()}')
(0, 0): tensor([0., 0.])
(1, 0): tensor([9.987502e-01, -4.997917e-02])
(0, 1): tensor([4.997917e-02, 9.987502e-01])
(2, 0): tensor([-1.249219e-03, -4.993753e-02])
(1, 1): tensor([-4.164063e-05, -2.497397e-03])
(0, 2): tensor([-5.206163e-07, -4.164063e-05])
(3, 0): tensor([5.203993e-07, 4.161458e-05])
(2, 1): tensor([2.604167e-08, 2.601563e-06])
(1, 2): tensor([4.340278e-10, 5.208334e-08])
(0, 3): tensor([0.000000e+00, 4.340278e-10])
(4, 0): tensor([-2.170139e-10, -2.604167e-08])
(3, 1): tensor([ 0.000000e+00, -1.736111e-09])
(2, 2): tensor([0., 0.])
(1, 3): tensor([0., 0.])
(0, 4): tensor([0., 0.])
[6]:
# Integrate using derivative table and a batch of initials

q = torch.linspace(0.1, 0.6, 21)
p = torch.zeros_like(q)
x = torch.stack([q, p]).T

orbits = []

for _ in range(count):
    x = torch.func.vmap(lambda x: evaluate(t, [x]))(x)
    orbits.append(x)

orbits = torch.stack(orbits).swapaxes(0, 1)

plt.figure(figsize=(8, 8))
plt.xlim(-1.5, 1)
plt.ylim(-1, 1)

for orbit in orbits:
    plt.scatter(*orbit.T.cpu().numpy(), color='black', s=1)

plt.show()
[figure: phase space orbits for a batch of initial conditions]
[7]:
# Derivatives with respect to time step and other parameters

def h(x, kq, ks):
    q, p = x
    return p**2/2 + kq**2/2 + ks*q**3/3

dt = torch.tensor(0.01, dtype=dtype, device=device)
xi = torch.tensor([0.4, 0.0], dtype=dtype, device=device)
kq = torch.tensor(1.0, dtype=dtype, device=device)
ks = torch.tensor(1.0, dtype=dtype, device=device)

t = derivative((2, 1, 1, 1), lambda xi, dt, kq, ks: taylor(4, dt, h, xi, kq, ks), xi, dt, kq, ks)
[8]:
# Check table

dxi = torch.tensor([0.1, 0.0], dtype=dtype, device=device)
ddt = torch.tensor(0.005, dtype=dtype, device=device)
dkq = torch.tensor(0.01, dtype=dtype, device=device)
dks = torch.tensor(0.01, dtype=dtype, device=device)

print(taylor(4, dt + ddt, h, xi + dxi, kq + dkq, ks + dks))
print(evaluate(t, [dxi, ddt, dkq, dks]))
tensor([4.999716e-01, -3.787356e-03])
tensor([4.999748e-01, -3.787395e-03])

Example-22: Yoshida integrator

[1]:
# Given a time-reversible integration step of difference order 2n
# Yoshida coefficients can be used to construct an integration step of difference order 2(n+1)
# s(2(n+1))(dt) = s(2n)(x1 dt) o s(2n)(x2 dt) o s(2n)(x1 dt)

# If a hamiltonian vector field can be split into several solvable parts
# A second order time-reversible symmetric integrator can be easily constructed
# s1(dt/2) o s2(dt/2) o ... o sn(dt/2) o sn(dt/2) o ... o s2(dt/2) o s1(dt/2)
# The Yoshida procedure can then be applied repeatedly to obtain higher order integration steps
[2]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.util import nest
from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.yoshida import coefficients
from ndmap.yoshida import yoshida

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Given an integration step of Yoshida order n
# The Yoshida coefficients for the next order can be computed

# Note, sum of coefficients is equal to one

print(coefficients(1)) # 2 -> 4
print(coefficients(2)) # 4 -> 6
print(coefficients(3)) # 6 -> 8
print(coefficients(4)) # 8 -> 10
[1.3512071919596578, -1.7024143839193153, 1.3512071919596578]
[1.1746717580893635, -1.349343516178727, 1.1746717580893635]
[1.1161829393253857, -1.2323658786507714, 1.1161829393253857]
[1.0870271062991708, -1.1740542125983413, 1.0870271062991708]
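
# The printed values agree with the known closed form (a hedged check, not
# part of the original notebook): for a step of difference order 2n
# x1 = 1/(2 - 2**(1/(2*n + 1))), x2 = -2**(1/(2*n + 1))/(2 - 2**(1/(2*n + 1)))

n = 1
x1 = 1.0/(2.0 - 2.0**(1.0/(2*n + 1)))
x2 = -2.0**(1.0/(2*n + 1))/(2.0 - 2.0**(1.0/(2*n + 1)))
print([x1, x2, x1])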
[4]:
# Given an integration step of Yoshida order n
# The Yoshida coefficients for order m can be computed

# Note, sum of coefficients is equal to one

print([f'{coefficient:.3f}' for coefficient in coefficients(1, 1)]) # 2 -> 4
print([f'{coefficient:.3f}' for coefficient in coefficients(1, 2)]) # 2 -> 6
print([f'{coefficient:.3f}' for coefficient in coefficients(2, 2)]) # 4 -> 6
['1.351', '-1.702', '1.351']
['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']
['1.175', '-1.349', '1.175']
[5]:
# Given a set of mappings and (start, final) Yoshida orders
# The corresponding Yoshida coefficients can be computed
# Note, a mapping can itself be an integration step

# Mapping is a step (last argument should be False)

ns, cs = coefficients(1, 1, 1, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 4
ns, cs = coefficients(1, 1, 2, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 6
ns, cs = coefficients(1, 2, 2, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 4 -> 6
print()

# Two mappings (merge edge mappings)
# Note, number of mappings can be arbitrary

ns, cs = coefficients(2, 0, 0, True)  ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 2
ns, cs = coefficients(2, 0, 1, True)  ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 4
print()
[[0, 0, 0], ['1.351', '-1.702', '1.351']]
[[0, 0, 0, 0, 0, 0, 0, 0, 0], ['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']]
[[0, 0, 0], ['1.175', '-1.349', '1.175']]

[[0, 1, 0], ['0.500', '1.000', '0.500']]
[[0, 1, 0, 1, 0, 1, 0], ['0.676', '1.351', '-0.176', '-1.702', '-0.176', '1.351', '0.676']]

[6]:
# Integrate rotation

# h = h1 + h2
# h1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]
# h2 = 1/2 p**2 -> [q, p] -> [q + t*p, p]

# Set mappings

def fn(x, t):
    q, p = x
    return torch.stack([q, p - t*q])

def gn(x, t):
    q, p = x
    return torch.stack([q + t*p, p])

# Set time step

t = torch.tensor(0.5, dtype=torch.float64)

# Set initial condition

x = torch.tensor([0.1, 0.1], dtype=torch.float64)

# Without merging

for i in range(10):
    print(yoshida(0, i, False, [fn, gn])(x, t), len(first(yoshida(0, i, False, [fn, gn]).table)))
print()

# With merging

for i in range(10):
    print(yoshida(0, i, True, [fn, gn])(x, t), len(first(yoshida(0, i, True, [fn, gn]).table)))
print()
tensor([1.375000e-01, 4.062500e-02], dtype=torch.float64) 3
tensor([1.354787e-01, 3.997254e-02], dtype=torch.float64) 9
tensor([1.357446e-01, 3.982467e-02], dtype=torch.float64) 27
tensor([1.357004e-01, 3.981894e-02], dtype=torch.float64) 81
tensor([1.356984e-01, 3.981700e-02], dtype=torch.float64) 243
tensor([1.357016e-01, 3.981564e-02], dtype=torch.float64) 729
tensor([1.357009e-01, 3.981567e-02], dtype=torch.float64) 2187
tensor([1.357007e-01, 3.981572e-02], dtype=torch.float64) 6561
tensor([1.357009e-01, 3.981567e-02], dtype=torch.float64) 19683
tensor([1.357007e-01, 3.981573e-02], dtype=torch.float64) 59049

tensor([1.375000e-01, 4.062500e-02], dtype=torch.float64) 3
tensor([1.354787e-01, 3.997254e-02], dtype=torch.float64) 7
tensor([1.357446e-01, 3.982467e-02], dtype=torch.float64) 19
tensor([1.357004e-01, 3.981894e-02], dtype=torch.float64) 55
tensor([1.356984e-01, 3.981700e-02], dtype=torch.float64) 163
tensor([1.357016e-01, 3.981564e-02], dtype=torch.float64) 487
tensor([1.357009e-01, 3.981567e-02], dtype=torch.float64) 1459
tensor([1.357007e-01, 3.981572e-02], dtype=torch.float64) 4375
tensor([1.357009e-01, 3.981567e-02], dtype=torch.float64) 13123
tensor([1.357007e-01, 3.981573e-02], dtype=torch.float64) 39367
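
# A hedged manual check of the merged second order sequence [0, 1, 0] with
# coefficients [0.5, 1.0, 0.5] from the previous cell: fn with a half step,
# gn with a full step, then fn with a half step

y = fn(x, t/2)
y = gn(y, t)
y = fn(y, t/2)
print(y)
print(yoshida(0, 0, True, [fn, gn])(x, t))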

[7]:
# Several steps

count = 100
t = torch.tensor(0.5, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
for _ in range(count):
    x = yoshida(0, 1, True, [fn, gn])(x, t/count)
print(x)

count = 100
t = torch.tensor(0.5, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
x = nest(count, yoshida(0, 1, True, [fn, gn]))(x,  t/count)
print(x)
tensor([1.357008e-01, 3.981570e-02], dtype=torch.float64)
tensor([1.357008e-01, 3.981570e-02], dtype=torch.float64)
[8]:
# Multistep

# h = h1 + h2
# h1 = 1/2 q**2 + 1/3 q**3 -> [q, p] -> [q, p - t*q - t*q**2]
# h2 = 1/2 p**2            -> [q, p] -> [q + t*p, p]

def fn(x, t):
    q, p = x
    return torch.stack([q, p - t*q - t*q**2])

def gn(x, t):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)

print(yoshida(0, 1, True, [fn, gn])(x, t))
print()

# h = h1 + h2 + h3
# h1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]
# h2 = 1/3 q**3 -> [q, p] -> [q, p - t*q**2]
# h3 = 1/2 p**2 -> [q, p] -> [q + t*p, p]

def fn(x, t):
    q, p = x
    return torch.stack([q, p - t*q])

def gn(x, t):
    q, p = x
    return torch.stack([q, p - t*q**2])

def hn(x, t):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)

print(yoshida(0, 1, True, [fn, gn, hn])(x, t))
print()

# Note, the last mapping has the smallest number of evaluations

sequence, _ = yoshida(0, 1, True, [fn, gn, hn]).table

print([*map(sequence.count, sorted(set(sequence)))])
print()
tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)

tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)

[4, 6, 3]

[9]:
# Increase order

def fn(x, t):
    q, p = x
    return torch.stack([q, p - t*q - t*q**2])

def gn(x, t):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)

s2 = yoshida(0, 0, True, [fn, gn])
print(torch.allclose(s2(x, t), yoshida(0, 0, True, [fn, gn])(x, t)))

s4 = yoshida(1, 1, False, [s2])
print(torch.allclose(s4(x, t), yoshida(0, 1, True, [fn, gn])(x, t)))

s6 = yoshida(1, 2, False, [s2])
print(torch.allclose(s6(x, t), yoshida(0, 2, True, [fn, gn])(x, t)))

s6 = yoshida(2, 2, False, [s4])
print(torch.allclose(s6(x, t), yoshida(0, 2, True, [fn, gn])(x, t)))
True
True
True
True
[10]:
# Step with parameters (matched signatures)

def fn(x, t, a, b):
    q, p = x
    return torch.stack([q, p - a*t*q - b*t*q**2])

def gn(x, t, a, b):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
a = torch.tensor(1.0, dtype=torch.float64)
b = torch.tensor(1.0, dtype=torch.float64)

print(yoshida(0, 1, True, [fn, gn])(x, t, a, b))
tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)
[11]:
# Step with parameters (pass fixed parameters)

def fn(x, t, a, b):
    q, p = x
    return torch.stack([q, p - a*t*q - b*t*q**2])

def gn(x, t):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
a = torch.tensor(1.0, dtype=torch.float64)
b = torch.tensor(1.0, dtype=torch.float64)

print(yoshida(0, 1, True, [fn, gn], parameters=[[a, b], []])(x, t))
tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)
[12]:
# Step can be differentiated with respect to initial condition, time step and/or parameters

def fn(x, t, a, b):
    q, p = x
    return torch.stack([q, p - a*t*q - b*t*q**2])

def gn(x, t, a, b):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
a = torch.tensor(1.0, dtype=torch.float64)
b = torch.tensor(1.0, dtype=torch.float64)

# Derivative with respect to initial

print(derivative(1, yoshida(0, 1, True, [fn, gn]), x, t, a, b))
print()

# Derivative with respect to time step

print(derivative(1, lambda t, x, a, b: yoshida(0, 1, True, [fn, gn])(x, t, a, b), t, x, a, b))
print(derivative((0, 1), yoshida(0, 1, True, [fn, gn]), x, t, a, b))
print()
[tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64), tensor([[9.939735e-01, 9.979712e-02],
        [-1.207179e-01, 9.939427e-01]], dtype=torch.float64)]

[tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64), tensor([8.841321e-02, -1.214017e-01], dtype=torch.float64)]
[[tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64), tensor([8.841321e-02, -1.214017e-01], dtype=torch.float64)]]

[13]:
# For derivative table propagation all knobs should be vectors

def fn(x, t, a, b):
    q, p = x
    return torch.stack([q, p - a*t*q - b*t*q**2])

def gn(x, t, a, b):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor([0.1], dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
a = torch.tensor([1.0], dtype=torch.float64)
b = torch.tensor([1.0], dtype=torch.float64)

step = yoshida(0, 1, True, [fn, gn])

def wrapper(x, t, a, b):
    (t, ), (a, ), (b, ) = t, a, b
    return step(x, t, a, b)

print(step(x, *t, *a, *b))
print(wrapper(x, t, a, b))
print()

out = propagate((2, 1, 1, 1),
                4*(1, ),
                identity(4*(1, ), [x, t, a, b]),
                [t, a, b],
                wrapper)

dt = torch.tensor([+0.001], dtype=torch.float64)
dx = torch.tensor([+0.005, -0.005], dtype=torch.float64)
da = torch.tensor([-0.001], dtype=torch.float64)
db = torch.tensor([+0.001], dtype=torch.float64)

print(wrapper(x + dx, t + dt, a + da, b + db))
print(evaluate(out, [dx, dt, da, db]))
tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)
tensor([1.094304e-01, 8.841961e-02], dtype=torch.float64)

tensor([1.139844e-01, 8.272697e-02], dtype=torch.float64)
tensor([1.139846e-01, 8.272929e-02], dtype=torch.float64)
[14]:
# Given a step, its derivatives can be used as a taylor model
# Note, taylor model is not symplectic in general

torch.set_printoptions(precision=12, sci_mode=True, linewidth=128)

def fn(x, t, a, b):
    q, p = x
    return torch.stack([q, p - a*t*q - b*t*q**2])

def gn(x, t, a, b):
    q, p = x
    return torch.stack([q + t*p, p])

t = torch.tensor(0.1, dtype=torch.float64)
x = torch.tensor([0.1, 0.1], dtype=torch.float64)
a = torch.tensor(1.0, dtype=torch.float64)
b = torch.tensor(1.0, dtype=torch.float64)

dx = torch.tensor([+0.1, -0.1], dtype=torch.float64)

print(yoshida(0, 1, True, [fn, gn])(x, t, a, b))
print(yoshida(0, 1, True, [fn, gn])(x + dx, t, a, b))
print()

print(evaluate(derivative(1, yoshida(0, 1, True, [fn, gn]), x, t, a, b), [dx]))
print(evaluate(derivative(2, yoshida(0, 1, True, [fn, gn]), x, t, a, b), [dx]))
print(evaluate(derivative(3, yoshida(0, 1, True, [fn, gn]), x, t, a, b), [dx]))
print(evaluate(derivative(4, yoshida(0, 1, True, [fn, gn]), x, t, a, b), [dx]))
print(evaluate(derivative(5, yoshida(0, 1, True, [fn, gn]), x, t, a, b), [dx]))
print()
tensor([1.094303556041e-01, 8.841960703680e-02], dtype=torch.float64)
tensor([1.988014274981e-01, -2.394392236666e-02], dtype=torch.float64)

tensor([1.988479888736e-01, -2.304645602610e-02], dtype=torch.float64)
tensor([1.988014166220e-01, -2.394421962623e-02], dtype=torch.float64)
tensor([1.988014274804e-01, -2.394392240255e-02], dtype=torch.float64)
tensor([1.988014274981e-01, -2.394392236643e-02], dtype=torch.float64)
tensor([1.988014274981e-01, -2.394392236666e-02], dtype=torch.float64)

Example-23: Direct invariant

[1]:
# In this example Taylor invariants are constructed by solving I(f(x)) = I(x) order-by-order
# The mapping f(x) can be replaced with its derivative table representation of a given order n
# Or the mapping can be used directly
# An invariant of order n+1 can be computed
# Note, to avoid the trivial solution, an initial invariant guess should be provided, e.g. the linear part invariant
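
# A hedged toy illustration of the defining relation I(f(x)) = I(x), with f a
# linear rotation whose invariant is known exactly (all names below are local
# to this sketch):

import math
import torch

def rotation(x):
    q, p = x
    c, s = math.cos(0.1), math.sin(0.1)
    return torch.stack([c*q + s*p, -s*q + c*p])

def inv(x):
    q, p = x
    return (q**2 + p**2)/2

x0 = torch.tensor([0.3, 0.4], dtype=torch.float64)
print(inv(rotation(x0)) - inv(x0))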
[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.series import series
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.invariant import invariant

from twiss.wolski import twiss

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
[3]:
# Data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])

[5]:
# Set transport maps between observation points

def map_01_02(x):
    x = quad(x, [0.0], 0.19, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19, 0.50)
    return x

transport = [
    map_01_02
]
[6]:
# Define one-turn transport

def fodo(x):
    for mapping in transport:
        x = mapping(x)
    return x
[7]:
# Set evaluation point

# Note, zero is a fixed point and derivatives with respect to parameters are not used, i.e. there is no need to compute a parametric fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
fodo(x)
[7]:
tensor([0., 0., 0., 0.], dtype=torch.float64)
[8]:
# Generate derivative table representation for given order

t = identity(4, x, jacobian=torch.func.jacfwd)
t = propagate((4, ), (4, ), t, [], fodo, jacobian=torch.func.jacfwd)
[9]:
# Compute linear normalization matrix

_, n, _ = twiss(derivative(1, lambda x: evaluate(t, [x]), x, intermediate=False))
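
# A hedged check, assuming the Wolski/Courant-Snyder convention m = n r n^-1
# with r a block rotation: normalizing the one-turn matrix should yield a
# rotation in each plane

m = derivative(1, lambda x: evaluate(t, [x]), x, intermediate=False)
print(torch.linalg.inv(n) @ m @ n)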
[10]:
# Set initial invariants

def ix(x):
    qx, px, qy, py = torch.linalg.inv(n) @ x
    return 1/2*(qx**2 + px**2)

def iy(x):
    qx, px, qy, py = torch.linalg.inv(n) @ x
    return 1/2*(qy**2 + py**2)
[11]:
# Check conservation of linear invariants

print(derivative(2, ix, x, intermediate=False))
print(propagate((4, ), (2, ), t, [], ix, intermediate=False))
print()

print(derivative(2, iy, x, intermediate=False))
print(propagate((4, ), (2, ), t, [], iy, intermediate=False))
print()
tensor([[ 6.775970e-02, -1.148888e-15,  0.000000e+00,  0.000000e+00],
        [-1.148888e-15,  1.475804e+01,  0.000000e+00,  0.000000e+00],
        [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
        [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00]], dtype=torch.float64)
tensor([[ 6.775970e-02, -1.151856e-15,  0.000000e+00,  0.000000e+00],
        [-1.151856e-15,  1.475804e+01,  0.000000e+00,  0.000000e+00],
        [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
        [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00]], dtype=torch.float64)

tensor([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [0.000000e+00, 0.000000e+00, 8.242765e-02, 1.472561e-15],
        [0.000000e+00, 0.000000e+00, 1.472561e-15, 1.213185e+01]], dtype=torch.float64)
tensor([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
        [0.000000e+00, 0.000000e+00, 8.242765e-02, 1.554312e-15],
        [0.000000e+00, 0.000000e+00, 1.554312e-15, 1.213185e+01]], dtype=torch.float64)

[12]:
# Compute nonlinear invariants
# Note, computation is not optimized and requires a lot of memory

tx, _ = invariant((4, ), x, [], ix, t, jacobian=torch.func.jacfwd, threshold=1.0E-3)
ty, _ = invariant((4, ), x, [], iy, t, jacobian=torch.func.jacfwd, threshold=1.0E-3)
[13]:
# Generate and plot trajectory

x = torch.tensor([0.005, 0.0, 0.005, 0.0], dtype=dtype, device=device)
bag = []
for _ in range(2048):
    x = fodo(x)
    bag.append(x)
bag = torch.stack(bag)
qx, px, qy, py = bag.T

plt.figure(figsize=(2*8, 8))
ax = plt.subplot(121)
ax.scatter(qx.cpu().numpy(), px.cpu().numpy(), marker='x', s=1, color='red')
ax = plt.subplot(122)
ax.scatter(qy.cpu().numpy(), py.cpu().numpy(), marker='x', s=1, color='red')
plt.show()
[figure: trajectory in the qx-px and qy-py planes]
[14]:
# 1st invariant conservation
# Note, for different initial conditions higher order invariants can give worse results

plt.figure(figsize=(20, 5))

for order, color in zip([2, 3, 4], ['black', 'red', 'blue']):
    sx = series((4, ), (order, ), tx)
    vx = torch.func.vmap(lambda x: evaluate(sx, [x]))(bag)
    print(vx.mean())
    print(vx.std())
    print()
    plt.scatter(range(len(vx)), vx.cpu().numpy(), color=color, marker='x')

plt.show()
tensor(6.939136e-07, dtype=torch.float64)
tensor(8.133688e-08, dtype=torch.float64)

tensor(7.033768e-07, dtype=torch.float64)
tensor(2.919483e-09, dtype=torch.float64)

tensor(7.008737e-07, dtype=torch.float64)
tensor(1.419347e-09, dtype=torch.float64)

[figure: first invariant along the trajectory for orders 2, 3 and 4]
[15]:
# 2nd invariant conservation
# Note, for different initial conditions higher order invariants can give worse results

plt.figure(figsize=(20, 5))

for order, color in zip([2, 3, 4], ['black', 'red', 'blue']):
    sy = series((4, ), (order, ), ty)
    vy = torch.func.vmap(lambda x: evaluate(sy, [x]))(bag)
    print(vy.mean())
    print(vy.std())
    print()
    plt.scatter(range(len(vy)), vy.cpu().numpy(), color=color, marker='x')

plt.show()
tensor(1.222613e-06, dtype=torch.float64)
tensor(1.538842e-07, dtype=torch.float64)

tensor(1.200414e-06, dtype=torch.float64)
tensor(7.017794e-09, dtype=torch.float64)

tensor(1.206937e-06, dtype=torch.float64)
tensor(3.553736e-09, dtype=torch.float64)

[figure: second invariant along the trajectory for orders 2, 3 and 4]

Example-24: Direct invariant (parametric)

[1]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.series import series
from ndmap.series import clean
from ndmap.series import split
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point
from ndmap.invariant import invariant

from twiss.wolski import twiss

torch.set_printoptions(precision=6, sci_mode=True, linewidth=128)

import warnings
warnings.filterwarnings("ignore")
[2]:
# Data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])

[4]:
# Set transport maps between observation points

def map_01_02(x, w, k):
    ksf, ksd, ksb = k
    x = quad(x, w, 0.19, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5 + ksf)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25 + ksb, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5 + ksd)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21, 0.50)
    x = quad(x, w, -0.21, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5 + ksd)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25 + ksb, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5 + ksf)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19, 0.50)
    return x

transport = [
    map_01_02
]
[5]:
# Define one-turn transport

def fodo(x, w, k):
    for mapping in transport:
        x = mapping(x, w, k)
    return x
[6]:
# Find fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
k = torch.tensor(3*[0.0], dtype=dtype, device=device)

fp = fixed_point(32, fodo, x, w, k, power=1)

print(fp)
print(fodo(fp, w, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Set computation orders

(nx, nw, nk) = (3, 1, 1)
[8]:
# Find parametric fixed point

pfp = parametric_fixed_point((nw, nk), fp, [w, k], fodo)
[9]:
# Test parametric fixed point
# Note, if fp is not zero, redefine one-turn transport to map zero to zero

print(compare(pfp, propagate((4, 1, 3), (0, nw, nk), pfp, [w, k], fodo)))
True
[10]:
# Define a fodo variant around parametric fixed point

def mapping(x, w, k):
    x = x + evaluate(first(pfp), [w, k])
    x = fodo(x, w, k)
    x = x - evaluate(first(pfp), [w, k])
    return x

print(mapping(x, w, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[11]:
# Compute derivative table representation

# Note, no parametric part should be passed, parametric zero is transformed to parametric zero

t = identity((nx, nw, nk), x)
t = propagate((4, 1, 3), (nx - 1, nw, nk), t, [w, k], mapping)
chop(t)
print(first(t))
[[tensor([0., 0., 0., 0.], dtype=torch.float64), tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)], [tensor([[0.],
        [0.],
        [0.],
        [0.]], dtype=torch.float64), tensor([[[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]], dtype=torch.float64)]]
[12]:
# Compute parametric normalization matrix

def fn(w, k):
    m = derivative(1, lambda x: evaluate(t, [x, w, k]), fp, intermediate=False)
    _, n, _ = twiss(m)
    return n

tn = derivative((nw, nk), fn, w, k)
[13]:
# Set initial invariants

def ix(x, w, k):
    qx, px, qy, py = torch.inverse(evaluate(tn, [w, k])) @ x
    return 1/2*(qx**2 + px**2)

def iy(x, w, k):
    qx, px, qy, py = torch.inverse(evaluate(tn, [w, k])) @ x
    return 1/2*(qy**2 + py**2)
[14]:
# 1st invariant

tx, _ = invariant((nx, nw, nk), x, [w, k], ix, t, jacobian=torch.func.jacrev, threshold=0.01)
_
[14]:
[]
[15]:
# 2nd invariant

ty, _ = invariant((nx, nw, nk), x, [w, k], iy, t, jacobian=torch.func.jacrev, threshold=0.01)
_
[15]:
[]
[16]:
# Test conservation

# Note, here propagate is used as composition

print(compare(propagate((4, 1, 3), (nx-1, nw, nk), t, [w, k], tx), tx))
print(compare(propagate((4, 1, 3), (nx-1, nw, nk), t, [w, k], ty), ty))
True
True
[17]:
# Series representation
# Note, the generalized monomial exponents are ordered as qx px qy py w ksf ksd ksb

s, *_ = split(clean(series((4, 1, 3), (nx, nw, nk), tx)))
s
[17]:
{(2, 0, 0, 0, 0, 0, 0, 0): tensor(3.387985e-02, dtype=torch.float64),
 (0, 2, 0, 0, 0, 0, 0, 0): tensor(7.379018e+00, dtype=torch.float64),
 (2, 0, 0, 0, 0, 1, 0, 0): tensor(-6.047879e-03, dtype=torch.float64),
 (2, 0, 0, 0, 0, 0, 1, 0): tensor(-6.517039e-03, dtype=torch.float64),
 (0, 2, 0, 0, 0, 1, 0, 0): tensor(1.317226e+00, dtype=torch.float64),
 (0, 2, 0, 0, 0, 0, 1, 0): tensor(1.419409e+00, dtype=torch.float64),
 (2, 0, 0, 0, 1, 0, 0, 0): tensor(1.994743e-01, dtype=torch.float64),
 (0, 2, 0, 0, 1, 0, 0, 0): tensor(-4.344542e+01, dtype=torch.float64),
 (2, 0, 0, 0, 1, 1, 0, 0): tensor(-6.188913e-02, dtype=torch.float64),
 (2, 0, 0, 0, 1, 0, 1, 0): tensor(-8.813844e-02, dtype=torch.float64),
 (2, 0, 0, 0, 1, 0, 0, 1): tensor(6.534032e-01, dtype=torch.float64),
 (0, 2, 0, 0, 1, 1, 0, 0): tensor(-2.031422e+00, dtype=torch.float64),
 (0, 2, 0, 0, 1, 0, 1, 0): tensor(2.482422e+00, dtype=torch.float64),
 (0, 2, 0, 0, 1, 0, 0, 1): tensor(-1.423110e+02, dtype=torch.float64),
 (3, 0, 0, 0, 0, 0, 0, 0): tensor(5.742237e-02, dtype=torch.float64),
 (1, 2, 0, 0, 0, 0, 0, 0): tensor(7.908874e+00, dtype=torch.float64),
 (1, 0, 2, 0, 0, 0, 0, 0): tensor(-1.249413e+00, dtype=torch.float64),
 (1, 0, 0, 2, 0, 0, 0, 0): tensor(8.887557e+01, dtype=torch.float64),
 (0, 1, 1, 1, 0, 0, 0, 0): tensor(-2.613173e+02, dtype=torch.float64),
 (3, 0, 0, 0, 0, 1, 0, 0): tensor(-2.180292e-02, dtype=torch.float64),
 (3, 0, 0, 0, 0, 0, 1, 0): tensor(-3.543561e-02, dtype=torch.float64),
 (3, 0, 0, 0, 0, 0, 0, 1): tensor(2.292041e-01, dtype=torch.float64),
 (1, 2, 0, 0, 0, 1, 0, 0): tensor(7.152281e+00, dtype=torch.float64),
 (1, 2, 0, 0, 0, 0, 1, 0): tensor(1.584743e+01, dtype=torch.float64),
 (1, 2, 0, 0, 0, 0, 0, 1): tensor(1.861577e+01, dtype=torch.float64),
 (1, 0, 2, 0, 0, 1, 0, 0): tensor(-7.509155e-01, dtype=torch.float64),
 (1, 0, 2, 0, 0, 0, 1, 0): tensor(-1.639441e+00, dtype=torch.float64),
 (1, 0, 2, 0, 0, 0, 0, 1): tensor(-5.148461e+00, dtype=torch.float64),
 (1, 0, 0, 2, 0, 1, 0, 0): tensor(6.392325e+01, dtype=torch.float64),
 (1, 0, 0, 2, 0, 0, 1, 0): tensor(1.341707e+02, dtype=torch.float64),
 (1, 0, 0, 2, 0, 0, 0, 1): tensor(3.638539e+02, dtype=torch.float64),
 (0, 1, 1, 1, 0, 1, 0, 0): tensor(-2.452915e+02, dtype=torch.float64),
 (0, 1, 1, 1, 0, 0, 1, 0): tensor(-4.911196e+02, dtype=torch.float64),
 (0, 1, 1, 1, 0, 0, 0, 1): tensor(-1.064586e+03, dtype=torch.float64),
 (3, 0, 0, 0, 1, 0, 0, 0): tensor(-5.835809e-01, dtype=torch.float64),
 (1, 2, 0, 0, 1, 0, 0, 0): tensor(6.243185e+01, dtype=torch.float64),
 (1, 0, 2, 0, 1, 0, 0, 0): tensor(1.617990e+02, dtype=torch.float64),
 (1, 0, 0, 2, 1, 0, 0, 0): tensor(-1.601207e+04, dtype=torch.float64),
 (0, 1, 1, 1, 1, 0, 0, 0): tensor(4.860999e+04, dtype=torch.float64),
 (3, 0, 0, 0, 1, 1, 0, 0): tensor(-8.770463e-02, dtype=torch.float64),
 (3, 0, 0, 0, 1, 0, 1, 0): tensor(-4.631656e-01, dtype=torch.float64),
 (3, 0, 0, 0, 1, 0, 0, 1): tensor(-4.557887e+00, dtype=torch.float64),
 (1, 2, 0, 0, 1, 1, 0, 0): tensor(9.372351e+01, dtype=torch.float64),
 (1, 2, 0, 0, 1, 0, 1, 0): tensor(1.993612e+02, dtype=torch.float64),
 (1, 2, 0, 0, 1, 0, 0, 1): tensor(5.704250e+02, dtype=torch.float64),
 (1, 0, 2, 0, 1, 1, 0, 0): tensor(2.404182e+02, dtype=torch.float64),
 (1, 0, 2, 0, 1, 0, 1, 0): tensor(5.584620e+02, dtype=torch.float64),
 (1, 0, 2, 0, 1, 0, 0, 1): tensor(1.295549e+03, dtype=torch.float64),
 (1, 0, 0, 2, 1, 1, 0, 0): tensor(-2.229816e+04, dtype=torch.float64),
 (1, 0, 0, 2, 1, 0, 1, 0): tensor(-5.022342e+04, dtype=torch.float64),
 (1, 0, 0, 2, 1, 0, 0, 1): tensor( -1.279140e+05, dtype=torch.float64),
 (0, 1, 1, 1, 1, 1, 0, 0): tensor(7.791695e+04, dtype=torch.float64),
 (0, 1, 1, 1, 1, 0, 1, 0): tensor( 1.680246e+05, dtype=torch.float64),
 (0, 1, 1, 1, 1, 0, 0, 1): tensor( 3.876573e+05, dtype=torch.float64)}

Example-25: Composition

[1]:
# In this example the following homomorphism is illustrated
# t(f o g) = t(f) o t(g)
# Meaning, the table of a composition is equal to the composition of tables
# Mappings f and g are assumed to map zero to zero, i.e. to have no constant part
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points

def map_01_02(x):
    x = quad(x, [0.0], 0.19, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21, 0.50)
    return x

def map_02_03(x):
    x = quad(x, [0.0], -0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19, 0.50)
    return x
[6]:
# Set computation order & evaluation point

n = 4
x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)
[7]:
# Note, both mappings map zero to zero
# Also, parameters that affect the closed orbit are not used

print(map_01_02(x))
print(map_02_03(x))
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[8]:
# Direct table generation

T = derivative(n, lambda x: map_02_03(map_01_02(x)), x)
[9]:
# Propagation of identity (equivalent to direct table generation)

t = identity((n, ), [x])
t = propagate((4, ), (n, ), t, [], lambda x: map_02_03(map_01_02(x)))

print(compare(T, t))
True
[10]:
# Propagation (split)

# Table is propagated through mappings
# Note, this evaluation is inherently sequential

t = identity((n, ), [x])
t = propagate((4, ), (n, ), t, [], map_01_02)
t = propagate((4, ), (n, ), t, [], map_02_03)

print(compare(T, t))
True
[11]:
# Composition

# Note, given an evaluation point (e.g. the closed orbit at each element)
# Identity can be propagated (or the table computed directly) for each element
# This can significantly reduce the computation cost

# Note, evaluations of t_01_02 and t_02_03 are independent and can be performed in parallel
# Also, propagation of a table through a table might be computationally less expensive

t_01_02 = identity((n, ), [x])
t_01_02 = propagate((4, ), (n, ), t_01_02, [], map_01_02)

t_02_03 = identity((n, ), [x])
t_02_03 = propagate((4, ), (n, ), t_02_03, [], map_02_03)

t = propagate((4, ), (n, ), t_01_02, [], t_02_03)

print(compare(T, t))
True

Example-26: Composition (closed orbit)

[1]:
# Composition illustration with non-zero closed orbit
[2]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from matplotlib import pyplot as plt

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points

# Note, kick is used to generate a non-zero (dynamical) closed orbit
# Momentum deviation is used as a parameter (coupled to the closed orbit via dispersion)

def map_01_02(x, w):
    x = kick(x, +1.0E-4, -1.0E-4)
    x = quad(x, w, 0.19, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21, 0.50)
    return x

def map_02_03(x, w):
    x = quad(x, w, -0.21, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19, 0.50)
    return x
[6]:
# Set evaluation point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
[7]:
# Find (dynamical) fixed point

fp = fixed_point(32, lambda x, w: map_02_03(map_01_02(x, w), w), x, w, power=1)

# Check fixed point

print(fp)
print(map_02_03(map_01_02(fp, w), w))
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
[8]:
# Set computation orders for state and each knob group

(nx, nw) = (4, 2)
[9]:
# Find parametric fixed point

pfp = parametric_fixed_point((nw, ), fp, [w], lambda x, w: map_02_03(map_01_02(x, w), w))

# Check

print(compare(pfp, propagate((4, 1), (0, nw), pfp, [w], lambda x, w: map_02_03(map_01_02(x, w), w))))
True
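
# Sketch: the parametric fixed point is a Taylor model of the closed orbit in w
# Evaluating it at a small momentum deviation should match a direct fixed point search up to order nw

dw = torch.tensor([1.0E-3], dtype=dtype, device=device)
print(evaluate(first(pfp), [dw]))
print(fixed_point(32, lambda x, w: map_02_03(map_01_02(x, w), w), x, dw, power=1))
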
[10]:
# Set parametric fixed points at each map entrance

pfp_01 = identity((0, nw), [x, w], parametric=pfp)
pfp_02 = propagate((4, 1), (0, nw), pfp_01, [w], map_01_02)

# Check

print(compare(pfp_01, propagate((4, 1), (0, nw), pfp_01, [w], lambda x, w: map_02_03(map_01_02(x, w), w))))
print(compare(pfp_02, propagate((4, 1), (0, nw), pfp_02, [w], lambda x, w: map_01_02(map_02_03(x, w), w))))
True
True
[11]:
# Define transformations around parametric fixed points

# Note, these transformations map the zero (parametric) state to zero (up to the given order)
# This holds by construction

def fn_01_02(x, w):
    return map_01_02(x + evaluate(first(pfp_01), [w]), w) - evaluate(first(pfp_02), [w])

def fn_02_03(x, w):
    return map_02_03(x + evaluate(first(pfp_02), [w]), w) - evaluate(first(pfp_01), [w])

print(propagate((4, 1), (0, nw), identity((0, nw), [x, w]), [w], fn_01_02))
[[tensor([0., 0., 0., 0.], dtype=torch.float64), tensor([[0.],
        [0.],
        [0.],
        [0.]], dtype=torch.float64), tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]]], dtype=torch.float64)]]
[12]:
# Propagate identity sequentially

T = identity((nx, nw), [x, w], parametric=pfp)
T = propagate((4, 1), (nx, nw), T, [w], lambda x, w: map_01_02(x, w))
T = propagate((4, 1), (nx, nw), T, [w], lambda x, w: map_02_03(x, w))
[13]:
# Composition

# Note, the parametric part of t is zero; its other elements should equal the corresponding elements of T (see the spot check after this cell)

t_01_02 = identity((nx, nw), [x, w])
t_01_02 = propagate((4, 1), (nx, nw), t_01_02, [w], fn_01_02)

t_02_03 = identity((nx, nw), [x, w])
t_02_03 = propagate((4, 1), (nx, nw), t_02_03, [w], fn_02_03)

t = propagate((4, 1), (nx, nw), t_01_02, [w], t_02_03)
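
# Spot check (a sketch relying on the table layout described in Example-01, where the outer index is the state order):
# the first-order state parts of t and T are expected to coincide, while the parametric part of t is zero

print(torch.allclose(first(t[1]), first(T[1])))
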
[14]:
# Compare phase space trajectories

plt.figure(figsize=(10, 10))

qx = torch.linspace(0.0, 0.01, 10, dtype=dtype, device=device)
px = torch.zeros_like(qx) + 1.0E-12
qy = torch.zeros_like(qx) + 1.0E-03
py = torch.zeros_like(qx) + 1.0E-12

x = torch.stack([qx, px, qy, py]).T

w = torch.tensor([1.0E-3], dtype=dtype, device=device)

count = 256
table = []
y = torch.clone(x)
for _ in range(count):
    table.append(y)
    y = torch.func.vmap(lambda x: map_02_03(map_01_02(x, w), w))(y)
table = torch.stack(table).swapaxes(0, -1)
qx, px, *_ = table
for q, p in zip(qx.cpu().numpy(), px.cpu().numpy()):
    plt.scatter(q, p, color='gray', marker='o')


count = 256
table = []
y = torch.clone(x)
for _ in range(count):
    table.append(y)
    y = y - evaluate(first(pfp), [w])
    y = torch.func.vmap(lambda x: evaluate(T, [x, w]))(y)
table = torch.stack(table).swapaxes(0, -1)
qx, px, *_ = table
for q, p in zip(qx.cpu().numpy(), px.cpu().numpy()):
    plt.scatter(q, p, color='red', marker='o')


count = 256
table = []
y = torch.clone(x)
for _ in range(count):
    table.append(y)
    y = y - evaluate(first(pfp), [w])
    y = torch.func.vmap(lambda x: evaluate(t, [x, w]))(y)
    y = y + evaluate(first(pfp), [w])
table = torch.stack(table).swapaxes(0, -1)
qx, px, *_ = table
for q, p in zip(qx.cpu().numpy(), px.cpu().numpy()):
    plt.scatter(q, p, color='blue', marker='x')

plt.show()
../_images/examples_ndmap_321_0.png

Example-27: Inverse

[1]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.inverse import inverse

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x):
    x = quad(x, [0.0], 0.19, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19, 0.50)
    return x
[5]:
# Set computation order & evaluation point

n = 3
x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

print(map_01_02(x))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute derivative table

t = identity((n, ), [x])
t = propagate((4, ), (n, ), t, [], lambda x: map_01_02(x))
[7]:
# Compute inverse

t_inv = inverse((n, ), x, [], t)
[8]:
# Check

out = propagate((4, ), (n, ), t_inv, [], t)
chop(out, replace=True)
out
[8]:
[[],
 tensor([[1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]],
        dtype=torch.float64),
 [],
 []]
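
# Sketch: the inverse table undoes the map; the round trip should recover dx up to the truncation order n

dx = torch.tensor([1.0E-3, 1.0E-4, -1.0E-3, -1.0E-4], dtype=dtype, device=device)
print(evaluate(t_inv, [evaluate(t, [dx])]))
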

Example-28: Inverse (closed orbit)

[1]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point
from ndmap.inverse import inverse

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set transport maps between observation points

def map_01_02(x, w):
    x = kick(x, +1.0E-4, -1.0E-4)
    x = quad(x, w, 0.19, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21, 0.50)
    x = quad(x, w, -0.21, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19, 0.50)
    return x
[5]:
# Set evaluation point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
[6]:
# Find (dynamical) fixed point

fp = fixed_point(32, lambda x, w: map_01_02(x, w), x, w, power=1)

# Check fixed point

print(fp)
print(map_01_02(fp, w))
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
[7]:
# Set computation orders for state and each knob group

(nx, nw) = (3, 2)
[8]:
# Find parametric fixed point

pfp = parametric_fixed_point((nw, ), fp, [w], lambda x, w: map_01_02(x, w))

# Check

print(compare(pfp, propagate((4, 1), (0, nw), pfp, [w], lambda x, w: map_01_02(x, w))))
True
[9]:
# Define transformations around parametric fixed points

# Note, this transformation maps the zero (parametric) state to zero (up to the given order)
# This holds by construction

def fn_01_02(x, w):
    return map_01_02(x + evaluate(first(pfp), [w]), w) - evaluate(first(pfp), [w])

out = propagate((4, 1), (0, nw), identity((0, nw), [x, w]), [w], fn_01_02)
chop(out, replace=True)
out
[9]:
[[[], [], []]]
[10]:
# Compute derivative table

t = identity((nx, nw), [x, w])
t = propagate((4, 1), (nx, nw), t, [w], lambda x, w: fn_01_02(x, w))
[11]:
# Compute inverse

t_inv = inverse((nx, nw), x, [w], t)
[12]:
# Check

out = propagate((4, 1), (nx, nw), t_inv, [w], t)
chop(out, replace=True)
out
[12]:
[[[], [], []],
 [tensor([[1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
          [0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
          [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00],
          [0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]],
         dtype=torch.float64),
  [],
  []],
 [[], [], []],
 [[], [], []]]
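
# Sketch: round trip with a parameter; the result should be close to dx up to the truncation orders (nx, nw)

dx = torch.tensor([1.0E-3, 1.0E-4, -1.0E-3, -1.0E-4], dtype=dtype, device=device)
dw = torch.tensor([1.0E-3], dtype=dtype, device=device)
print(evaluate(t_inv, [evaluate(t, [dx, dw]), dw]))
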

Example-29: Momenta generator

[1]:
# In this example initial and final momenta are computed from given initial and final coordinates
# Given an origin-preserving mapping, its table representation can be computed up to some order
# This representation can be treated as exact
# Next, given initial and final coordinates, the corresponding momenta can be computed with a momenta generator
[2]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point
from ndmap.momenta import momenta

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=10):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=5):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=20):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points

def map_01_02(x, w):
    x = kick(x, +1.0E-4, -1.0E-4)
    x = quad(x, w, 0.19, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21, 0.50)
    x = quad(x, w, -0.21, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, -0.5)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.25, 3.0)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.1, +0.5)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19, 0.50)
    return x
[6]:
# Set evaluation point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
[7]:
# Find (dynamical) fixed point

fp = fixed_point(32, lambda x, w: map_01_02(x, w), x, w, power=1)

# Check fixed point

print(fp)
print(map_01_02(fp, w))
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
tensor([ 8.418072943377e-04, -5.000000000000e-05, -2.043309959087e-03,
         5.000000000000e-05], dtype=torch.float64)
[8]:
# Set computation orders for state and each knob group

(nx, nw) = (4, 2)
[9]:
# Find parametric fixed point

pfp = parametric_fixed_point((nw, ), fp, [w], lambda x, w: map_01_02(x, w))

# Check

print(compare(pfp, propagate((4, 1), (0, nw), pfp, [w], lambda x, w: map_01_02(x, w))))
True
[10]:
# Define transformations around parametric fixed points

# Note, this transformation maps the zero (parametric) state to zero (up to the given order)
# This holds by construction

def fn_01_02(x, w):
    return map_01_02(x + evaluate(first(pfp), [w]), w) - evaluate(first(pfp), [w])

out = propagate((4, 1), (0, nw), identity((0, nw), [x, w]), [w], fn_01_02)
chop(out)
out
[10]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0.],
          [0.],
          [0.],
          [0.]], dtype=torch.float64),
  tensor([[[0.]],

          [[0.]],

          [[0.]],

          [[0.]]], dtype=torch.float64)]]
[11]:
# Compute derivative table

t = identity((nx, nw), [x, w])
t = propagate((4, 1), (nx, nw), t, [w], lambda x, w: fn_01_02(x, w))
[12]:
# Compute momenta generator

# Note, the generator order can differ from that of the input table
# Accuracy depends strongly on the computation order and the magnitude of the initial coordinates (see the sketch after the next cell)

m = momenta((nx, nw), x, [w], t)
[13]:
# Recover momenta from coordinates

# Set deviations

xi = torch.tensor([0.0005, 0.0001, -0.0005, -0.0001], dtype=dtype, device=device)
dw = torch.tensor(1*[0.0001], dtype=dtype, device=device)

# Evaluate final state using table representation

xf = evaluate(t, [xi, dw])

print(xf)
print()

# Set initial and final coordinates

qi, _ = xi.reshape(-1, 2).T
qf, _ = xf.reshape(-1, 2).T
qs = torch.cat([qf, qi])

print(qs)
print()

# Set initial and final momenta

_, pi, = xi.reshape(-1, 2).T
_, pf, = xf.reshape(-1, 2).T
ps = torch.cat([pf, pi])

print(ps)
print()

# Evaluate generator using coordinates

print(evaluate(m, [qs, 0*dw]))
print(evaluate(m, [qs, 1*dw]))
tensor([ 1.536782226285e-03, -2.360460895977e-05, -1.128298475344e-03,
        -6.800430927953e-05], dtype=torch.float64)

tensor([1.536782226285e-03, -1.128298475344e-03, 5.000000000000e-04, -5.000000000000e-04],
       dtype=torch.float64)

tensor([-2.360460895977e-05, -6.800430927953e-05,  1.000000000000e-04,
        -1.000000000000e-04], dtype=torch.float64)

tensor([-2.353065307569e-05, -6.780339520304e-05,  9.993585406462e-05,
        -1.001716171999e-04], dtype=torch.float64)
tensor([-2.360460892419e-05, -6.800430928312e-05,  9.999999996933e-05,
        -1.000000000050e-04], dtype=torch.float64)
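
# Accuracy sketch: a lower-order generator (m2 is a hypothetical name) is cheaper but less accurate
# Its output is expected to deviate more from the exact momenta printed above

m2 = momenta((2, nw), x, [w], t)
print(evaluate(m2, [qs, 1*dw]))
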

Example-30: Factorization (kick)

[1]:
# Given a derivative table representation (Taylor series) of an origin-preserving mapping
# it is possible to factorize it into a composition of linear and nonlinear table representations
# If the original mapping represents a single turn, the linear part can be brought to its normal form using linear theory
# The nonlinear part is a near-identity transformation and can be represented with a Lie transformation
# t = tl o tn and tn = exp(-[h])
# In this example the generator h is computed for a single kick
# It can be used to construct nonlinear normal forms or as a redundancy-free representation
[2]:
# Import

import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.series import series
from ndmap.series import clean
from ndmap.taylor import taylor
from ndmap.inverse import inverse
from ndmap.factorization import hamiltonian

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set transformation

# Note, this transformation can be exactly represented with a Taylor series (i.e. it is polynomial)

def mapping(x, k, l):
    (qx, px, qy, py), (k, ), l = x, k, l/2
    qx, qy = qx + l*px, qy + l*py
    px, py = px - 1.0*l*k*(qx**2 - qy**2), py + 2.0*l*k*qx*qy
    qx, qy = qx + l*px, qy + l*py
    return torch.stack([qx, px, qy, py])

# Test origin propagation

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(1*[0.0], dtype=dtype, device=device)

print(mapping(x, k, 0.1))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[5]:
# Compute table representation

t = identity((2, 1), [x, k])
t = propagate((4, 1), (2, 1), t, [k], mapping, 0.1)
[6]:
# Compare for some deviation

dx = torch.tensor([0.1, 0.01, 0.05, 0.01], dtype=dtype, device=device)
dk = torch.tensor([1.0], dtype=dtype, device=device)

print(mapping(x + dx, k + dk, 0.1))
print(evaluate(t, [dx, dk]))
tensor([1.009811250000e-01, 9.622500000000e-03, 5.102537625000e-02, 1.050752500000e-02],
       dtype=torch.float64)
tensor([1.009811250000e-01, 9.622500000000e-03, 5.102537625000e-02, 1.050752500000e-02],
       dtype=torch.float64)
[7]:
# Linear part is not near identity

print(derivative(1, lambda x, k: evaluate(t, [x, k]), x, k, intermediate=False))
tensor([[1.000000000000e+00, 1.000000000000e-01, 0.000000000000e+00, 0.000000000000e+00],
        [0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
        [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00, 1.000000000000e-01],
        [0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]],
       dtype=torch.float64)
[8]:
# Set linear part and compose with its inverse

l = derivative(1, lambda x, k: evaluate(t, [x, k]), x, k)
t = propagate((4, 1), (2, 1), inverse(1, x, [k], l), [k], t)
chop(t)
[9]:
# Now the table represents a near-identity transformation

print(derivative(1, lambda x, k: evaluate(t, [x, k]), x, k, intermediate=False))
tensor([[1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
        [0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
        [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00],
        [0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]],
       dtype=torch.float64)
[10]:
# Compute single exponent generator

h = hamiltonian((2, 1), x, [k], t)
chop(h)

# Compute series representation

# Note, each coefficient is proportional to the strength k (the last exponent is one)

s = clean(series((4, 1), (3, 1), h))
for index, value in s.items():
    print(index, value.squeeze())
(3, 0, 0, 0, 1) tensor(1.666666666667e-02, dtype=torch.float64)
(2, 1, 0, 0, 1) tensor(-2.500000000000e-03, dtype=torch.float64)
(1, 2, 0, 0, 1) tensor(1.250000000000e-04, dtype=torch.float64)
(1, 0, 2, 0, 1) tensor(-5.000000000000e-02, dtype=torch.float64)
(1, 0, 1, 1, 1) tensor(5.000000000000e-03, dtype=torch.float64)
(1, 0, 0, 2, 1) tensor(-1.250000000000e-04, dtype=torch.float64)
(0, 3, 0, 0, 1) tensor(-2.083333333333e-06, dtype=torch.float64)
(0, 1, 2, 0, 1) tensor(2.500000000000e-03, dtype=torch.float64)
(0, 1, 1, 1, 1) tensor(-2.500000000000e-04, dtype=torch.float64)
(0, 1, 0, 2, 1) tensor(6.250000000000e-06, dtype=torch.float64)
[11]:
# In this example, this hamiltonian generates the exact solution with the taylor integrator

print(mapping(x + dx, k + dk, 0.1))
print(taylor(1, 1.0, lambda x, k: evaluate(h, [x, k]), evaluate(l, [dx, dk]), dk))
tensor([1.009811250000e-01, 9.622500000000e-03, 5.102537625000e-02, 1.050752500000e-02],
       dtype=torch.float64)
tensor([1.009811250000e-01, 9.622500000000e-03, 5.102537625000e-02, 1.050752500000e-02],
       dtype=torch.float64)

Example-31: Factorization (fodo)

[1]:
# Import

import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.series import series
from ndmap.series import clean
from ndmap.taylor import taylor
from ndmap.inverse import inverse
from ndmap.factorization import hamiltonian
from ndmap.factorization import solution

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=5):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=1):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=10):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[4]:
# Set fodo

def fodo(x):
    x = quad(x, [0.0], 0.19, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21, 0.50)
    x = quad(x, [0.0], -0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, -0.5)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.25, 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.1, +0.5)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19, 0.50)
    return x
[5]:
# Set computation order & evaluation point

n = 4
x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

# Note, origin is preserved

fodo(x)
[5]:
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute derivative table

t = identity((n, ), [x])
t = propagate((4, ), (n, ), t, [], lambda x: fodo(x))
[7]:
# Set linear part

l = derivative(1, lambda x: evaluate(t, [x]), x)
l
[7]:
[tensor([0., 0., 0., 0.], dtype=torch.float64),
 tensor([[8.259928915375e-02, 1.470387298165e+01, 0.000000000000e+00, 0.000000000000e+00],
         [-6.754528950779e-02, 8.259928915375e-02, 0.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 0.000000000000e+00, 8.239443265215e-01, 6.857644998910e+00],
         [0.000000000000e+00, 0.000000000000e+00, -4.682595072274e-02, 8.239443265215e-01]],
        dtype=torch.float64)]
[8]:
# Set nonlinear part

u = propagate((4, ), (n, ), inverse(1, x, [], l), [], t)
chop(u)
derivative(1, lambda x: evaluate(u, [x]), x)
[8]:
[tensor([0., 0., 0., 0.], dtype=torch.float64),
 tensor([[1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00, 0.000000000000e+00],
         [0.000000000000e+00, 0.000000000000e+00, 0.000000000000e+00, 1.000000000000e+00]],
        dtype=torch.float64)]
[9]:
# Compute hamiltonian

h = hamiltonian((n, ), x, [], u)
chop(h)
[10]:
# Compute series representation

s = clean(series((4, ), (n + 1, ), h), epsilon=1.0E-9)
for index, value in s.items():
    print(index, value.squeeze())
(3, 0, 0, 0) tensor(5.024988945080e-02, dtype=torch.float64)
(2, 1, 0, 0) tensor(-8.027849817894e-01, dtype=torch.float64)
(1, 2, 0, 0) tensor(1.242554019904e+01, dtype=torch.float64)
(1, 0, 2, 0) tensor(-5.976214487541e-01, dtype=torch.float64)
(1, 0, 1, 1) tensor(3.719621673904e+00, dtype=torch.float64)
(1, 0, 0, 2) tensor(-6.659517059512e+00, dtype=torch.float64)
(0, 3, 0, 0) tensor(-1.465736828313e+02, dtype=torch.float64)
(0, 1, 2, 0) tensor(7.616595688746e+00, dtype=torch.float64)
(0, 1, 1, 1) tensor(-6.812652464723e+01, dtype=torch.float64)
(0, 1, 0, 2) tensor(1.637189795764e+02, dtype=torch.float64)
(4, 0, 0, 0) tensor(-2.670453079972e-02, dtype=torch.float64)
(3, 1, 0, 0) tensor(1.592046859279e+00, dtype=torch.float64)
(2, 2, 0, 0) tensor(-4.015290100155e+01, dtype=torch.float64)
(2, 0, 2, 0) tensor(1.098921583040e-02, dtype=torch.float64)
(2, 0, 1, 1) tensor(-1.150214563563e+00, dtype=torch.float64)
(2, 0, 0, 2) tensor(5.596016324538e+00, dtype=torch.float64)
(1, 3, 0, 0) tensor(2.720415502394e+02, dtype=torch.float64)
(1, 1, 2, 0) tensor(9.092563694884e+00, dtype=torch.float64)
(1, 1, 1, 1) tensor(-7.311435065924e+01, dtype=torch.float64)
(1, 1, 0, 2) tensor(1.043406910282e+02, dtype=torch.float64)
(0, 4, 0, 0) tensor(-8.872464473277e+02, dtype=torch.float64)
(0, 2, 2, 0) tensor(2.515413272498e+01, dtype=torch.float64)
(0, 2, 1, 1) tensor(6.993401944277e+01, dtype=torch.float64)
(0, 2, 0, 2) tensor(-3.593047920669e+02, dtype=torch.float64)
(0, 0, 4, 0) tensor(-1.258999636146e+00, dtype=torch.float64)
(0, 0, 3, 1) tensor(1.898846058062e+01, dtype=torch.float64)
(0, 0, 2, 2) tensor(-1.062500223737e+02, dtype=torch.float64)
(0, 0, 1, 3) tensor(2.608144282259e+02, dtype=torch.float64)
(0, 0, 0, 4) tensor(-2.419706859670e+02, dtype=torch.float64)
(5, 0, 0, 0) tensor(3.720153703396e-02, dtype=torch.float64)
(4, 1, 0, 0) tensor(-2.470046738149e+00, dtype=torch.float64)
(3, 2, 0, 0) tensor(5.465213020046e+01, dtype=torch.float64)
(3, 0, 2, 0) tensor(6.994399435987e-02, dtype=torch.float64)
(3, 0, 1, 1) tensor(-9.734770779462e-01, dtype=torch.float64)
(3, 0, 0, 2) tensor(5.509140066922e-01, dtype=torch.float64)
(2, 3, 0, 0) tensor(-7.728896235979e+02, dtype=torch.float64)
(2, 1, 2, 0) tensor(1.614402519455e+01, dtype=torch.float64)
(2, 1, 1, 1) tensor(-1.014114400601e+02, dtype=torch.float64)
(2, 1, 0, 2) tensor(2.099498728183e+02, dtype=torch.float64)
(1, 4, 0, 0) tensor(6.599464974629e+03, dtype=torch.float64)
(1, 2, 2, 0) tensor(-1.831472433171e+02, dtype=torch.float64)
(1, 2, 1, 1) tensor(1.658609587412e+03, dtype=torch.float64)
(1, 2, 0, 2) tensor(-4.455648121228e+03, dtype=torch.float64)
(1, 0, 4, 0) tensor(-2.372154036991e+00, dtype=torch.float64)
(1, 0, 3, 1) tensor(3.120364861986e+01, dtype=torch.float64)
(1, 0, 2, 2) tensor(-1.392873544053e+02, dtype=torch.float64)
(1, 0, 1, 3) tensor(2.280297561995e+02, dtype=torch.float64)
(1, 0, 0, 4) tensor(-5.830536198793e+01, dtype=torch.float64)
(0, 5, 0, 0) tensor(-1.712807218852e+04, dtype=torch.float64)
(0, 3, 2, 0) tensor(-5.278629070820e+02, dtype=torch.float64)
(0, 3, 1, 1) tensor(3.474912442117e+03, dtype=torch.float64)
(0, 3, 0, 2) tensor(1.999757814650e+02, dtype=torch.float64)
(0, 1, 4, 0) tensor(2.874466602579e+01, dtype=torch.float64)
(0, 1, 3, 1) tensor(-5.199961828427e+02, dtype=torch.float64)
(0, 1, 2, 2) tensor(3.267795295152e+03, dtype=torch.float64)
(0, 1, 1, 3) tensor(-8.677457594504e+03, dtype=torch.float64)
(0, 1, 0, 4) tensor(8.128231575922e+03, dtype=torch.float64)
[11]:
# Restore the nonlinear part of the mapping table from the hamiltonian

print(compare(u, solution((n, ), x, [], h)))
print(compare(t, propagate((4, ), (n, ), l, [], solution((n, ), x, [], h))))
True
True
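
# Sketch: together with the checks above, the linear part l and the generator h reproduce the map
# The two printed values should agree up to the truncation order n

dx = torch.tensor([1.0E-3, 1.0E-4, -1.0E-3, -1.0E-4], dtype=dtype, device=device)
print(fodo(x + dx))
print(evaluate(t, [dx]))
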

Example-32: Factorization (factorize)

[1]:
# Import

import torch
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.series import series
from ndmap.series import clean
from ndmap.series import split
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.bracket import bracket
from ndmap.factorization import hamiltonian
from ndmap.factorization import solution
from ndmap.factorization import factorize

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set test hamiltonian function and compute the corresponding table representation

# Note, the test hamiltonian starts at degree three, so the generated transformation is near identity

def h(x, u, v):
    q, p = x
    u, = u
    v, = v
    h1 = (1 + u + u**2)*q**3 + q**2*p + q*p**2 + p**3
    h2 = (1 + v)*q**4 + q**3*p + q**2*p**2 + q*p**3 + p**4
    h3 = q**5 + q**4*p + q**3*p**2 + q**2*p**3 + q*p**4 + p**5
    return h1 + h2 + h3

x = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
u = torch.tensor([0.0], dtype=dtype, device=device)
v = torch.tensor([0.0], dtype=dtype, device=device)

h = derivative((5, 2, 1), h, x, u, v)
chop(h, replace=True)

s, *_ = split(clean(series((2, 1, 1), (5, 2, 1), h)))
s
[3]:
{(3, 0, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (2, 1, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (1, 2, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (0, 3, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 0, 1, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 0, 2, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (4, 0, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 1, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (2, 2, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (1, 3, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (0, 4, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (4, 0, 0, 1): tensor(1.000000000000e+00, dtype=torch.float64),
 (5, 0, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (4, 1, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 2, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (2, 3, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (1, 4, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (0, 5, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64)}
[4]:
# Compute solution

t = solution((4, 2, 1), x, [u, v], h)
[5]:
# Compute hamiltonian from solution and compare

compare(h, hamiltonian((4, 2, 1), x, [u, v], t))
[5]:
True
[6]:
# Perform factorization

h1, h2, h3, *_ = factorize((4, 2, 1), x, [u, v], t)
[7]:
# Examine series representation

s, *_ = split(clean(series((2, 1, 1), (5, 1), h1)))
s
[7]:
{(3, 0, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (2, 1, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (1, 2, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (0, 3, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 0, 1, 0): tensor(1.000000000000e+00, dtype=torch.float64)}
[8]:
# Examine series representation

s, *_ = split(clean(series((2, 1, 1), (5, 1), h2)))
s
[8]:
{(4, 0, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (3, 1, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (2, 2, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (1, 3, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (0, 4, 0, 0): tensor(1.000000000000e+00, dtype=torch.float64),
 (4, 0, 0, 1): tensor(1.000000000000e+00, dtype=torch.float64)}
[9]:
# Examine series representation

s, *_ = split(clean(series((2, 1, 1), (5, 1), h3)))
s
[9]:
{(5, 0, 0, 0): tensor(5.000000000000e-01, dtype=torch.float64),
 (4, 1, 0, 0): tensor(-5.000000000000e-01, dtype=torch.float64),
 (3, 2, 0, 0): tensor(-2.000000000000e+00, dtype=torch.float64),
 (2, 3, 0, 0): tensor(4.000000000000e+00, dtype=torch.float64),
 (1, 4, 0, 0): tensor(2.500000000000e+00, dtype=torch.float64),
 (0, 5, 0, 0): tensor(1.500000000000e+00, dtype=torch.float64),
 (5, 0, 0, 1): tensor(-2.000000000000e+00, dtype=torch.float64),
 (4, 1, 0, 1): tensor(-4.000000000000e+00, dtype=torch.float64),
 (3, 2, 0, 1): tensor(-6.000000000000e+00, dtype=torch.float64),
 (5, 0, 1, 0): tensor(1.500000000000e+00, dtype=torch.float64),
 (4, 1, 1, 0): tensor(3.000000000000e+00, dtype=torch.float64),
 (3, 2, 1, 0): tensor(4.500000000000e+00, dtype=torch.float64),
 (2, 3, 1, 0): tensor(6.000000000000e+00, dtype=torch.float64)}
[10]:
# Compute solutions

t1 = solution((4, 2, 1), x, [u, v], h1)
t2 = solution((4, 2, 1), x, [u, v], h2)
t3 = solution((4, 2, 1), x, [u, v], h3)
[11]:
# Compose individual solutions and compare with initial solution

T = identity((4, 2, 1), [x, u, v])
T = propagate((2, 1, 1), (4, 2, 1), T, [u, v], t1)
T = propagate((2, 1, 1), (4, 2, 1), T, [u, v], t2)
T = propagate((2, 1, 1), (4, 2, 1), T, [u, v], t3)
compare(t, T)
[11]:
True
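
# Sketch: since compare(t, T) holds, evaluating both tables at a deviation should print matching tensors

dx = torch.tensor([0.1, 0.01], dtype=dtype, device=device)
du = torch.tensor([0.1], dtype=dtype, device=device)
dv = torch.tensor([0.1], dtype=dtype, device=device)
print(evaluate(t, [dx, du, dv]))
print(evaluate(T, [dx, du, dv]))
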

Example-33: Nonlinear mapping approximation (gradient)

[1]:
# Import

import numpy
import torch

from ndmap.gradient import series
from ndmap.series import clean

torch.set_printoptions(precision=12, sci_mode=True)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
True
[2]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[3]:
# Set test mapping
# Rotation with two sextupoles separated by a negative identity linear transformation
# Note, the result is expected to have zero degree-two coefficients
# due to the negative identity transformation between the sextupoles (see the check after the series output)

def spin(x, mux, muy):
    (qx, px, qy, py), mux, muy = x, mux, muy
    return torch.stack([qx*mux.cos() + px*mux.sin(), px*mux.cos() - qx*mux.sin(), qy*muy.cos() + py*muy.sin(), py*muy.cos() - qy*muy.sin()])

def drif(x, l):
    (qx, px, qy, py), l = x, l
    return torch.stack([qx + l*px, px, qy + l*py, py])

def sext(x, ks, l, n=1):
    (qx, px, qy, py), ks, l = x, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px, qy + l*py
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px, qy + l*py
    return torch.stack([qx, px, qy, py])

def ring(x):
    mux, muy = 2.0*numpy.pi*torch.tensor([1/3 + 0.01, 1/4 + 0.01], dtype=dtype, device=device)
    x = spin(x, mux, muy)
    x = drif(x, -0.05)
    x = sext(x, 10.0, 0.1, 100)
    x = drif(x, -0.05)
    mux, muy = 2.0*numpy.pi*torch.tensor([0.50, 0.50], dtype=dtype, device=device)
    x = spin(x, mux, muy)
    x = drif(x, -0.05)
    x = sext(x, 10.0, 0.1, 100)
    x = drif(x, -0.05)
    return x
[4]:
# Set evaluation point

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=dtype, device=device)

# Compute and print series

n = 4
s = clean(series((n, ), ring, x, retain=False), epsilon=1.0E-12)
print(*[f'{key}: {value.cpu().numpy()}' for key, value in clean(s, epsilon=1.0E-14).items()], sep='\n')
(1, 0, 0, 0): [0.55339155 0.83292124 0.         0.        ]
(0, 1, 0, 0): [-0.83292124  0.55339155  0.          0.        ]
(0, 0, 1, 0): [0.         0.         0.06279052 0.99802673]
(0, 0, 0, 1): [ 0.          0.         -0.99802673  0.06279052]
(3, 0, 0, 0): [-7.53257307e-09  2.82424677e-03 -0.00000000e+00 -0.00000000e+00]
(2, 1, 0, 0): [-1.96250238e-08 -1.27525063e-02 -0.00000000e+00 -0.00000000e+00]
(2, 0, 1, 0): [-0.00000000e+00 -0.00000000e+00  9.21186111e-06  3.34331441e-04]
(2, 0, 0, 1): [-0.00000000e+00 -0.00000000e+00  1.59704766e-05 -5.06941449e-03]
(1, 2, 0, 0): [-1.11004920e-08  1.91940679e-02 -0.00000000e+00 -0.00000000e+00]
(1, 1, 1, 0): [-0.00000000e+00 -0.00000000e+00 -2.98671134e-05 -9.79459005e-04]
(1, 1, 0, 1): [-0.00000000e+00 -0.00000000e+00 -1.48185697e-05  1.53623878e-02]
(1, 0, 2, 0): [-1.05857397e-06  1.97282603e-05 -0.00000000e+00 -0.00000000e+00]
(1, 0, 1, 1): [ 1.48154798e-05 -1.18570682e-03 -0.00000000e+00 -0.00000000e+00]
(1, 0, 0, 2): [ 2.88067554e-05  9.18409783e-03 -0.00000000e+00 -0.00000000e+00]
(0, 3, 0, 0): [-9.21338045e-10 -9.62979589e-03  0.00000000e+00  0.00000000e+00]
(0, 2, 1, 0): [0.00000000e+00 0.00000000e+00 2.40366928e-05 7.09979811e-04]
(0, 2, 0, 1): [ 0.00000000e+00  0.00000000e+00 -1.38786549e-05 -1.15294375e-02]
(0, 1, 2, 0): [ 1.80396305e-06 -2.59196622e-05  0.00000000e+00  0.00000000e+00]
(0, 1, 1, 1): [-2.98509155e-05  1.72488752e-03  0.00000000e+00  0.00000000e+00]
(0, 1, 0, 2): [ 1.66318846e-05 -1.38269522e-02  0.00000000e+00  0.00000000e+00]
(0, 0, 3, 0): [-0.00000000e+00 -0.00000000e+00 -1.47704279e-08  4.12534350e-06]
(0, 0, 2, 1): [-0.00000000e+00 -0.00000000e+00 -3.31103168e-09 -1.96719688e-04]
(0, 0, 1, 2): [-0.00000000e+00 -0.00000000e+00  3.93325786e-09  3.12683573e-03]
(0, 0, 0, 3): [ 0.00000000e+00  0.00000000e+00  2.56887204e-10 -1.65665408e-02]
(4, 0, 0, 0): [ 6.48869023e-07 -3.91844462e-06  0.00000000e+00  0.00000000e+00]
(3, 1, 0, 0): [-3.91685700e-06  1.51001133e-05  0.00000000e+00  0.00000000e+00]
(3, 0, 1, 0): [ 0.00000000e+00  0.00000000e+00 -4.36165465e-07  5.76316391e-06]
(3, 0, 0, 1): [0.00000000e+00 0.00000000e+00 6.56410859e-06 1.31668166e-05]
(2, 2, 0, 0): [ 8.85501662e-06 -1.48880894e-05  0.00000000e+00  0.00000000e+00]
(2, 1, 1, 0): [ 0.00000000e+00  0.00000000e+00  1.92232731e-06 -2.78183189e-05]
(2, 1, 0, 1): [ 0.00000000e+00  0.00000000e+00 -2.96894787e-05 -3.16635607e-05]
(2, 0, 2, 0): [1.81838643e-08 2.18816315e-06 0.00000000e+00 0.00000000e+00]
(2, 0, 1, 1): [-9.93314202e-07 -3.25277215e-05  0.00000000e+00  0.00000000e+00]
(2, 0, 0, 2): [ 7.67316422e-06 -2.51871510e-05  0.00000000e+00  0.00000000e+00]
(1, 3, 0, 0): [-8.88618620e-06 -4.33921249e-06  0.00000000e+00  0.00000000e+00]
(1, 2, 1, 0): [ 0.00000000e+00  0.00000000e+00 -2.81910856e-06  4.44963121e-05]
(1, 2, 0, 1): [0.00000000e+00 0.00000000e+00 4.47107608e-05 6.22767407e-06]
(1, 1, 2, 0): [-5.02653030e-08 -6.69892371e-06  0.00000000e+00  0.00000000e+00]
(1, 1, 1, 1): [2.90625131e-06 1.04277202e-04 0.00000000e+00 0.00000000e+00]
(1, 1, 0, 2): [-2.28808176e-05  2.60346923e-05  0.00000000e+00  0.00000000e+00]
(1, 0, 3, 0): [0.00000000e+00 0.00000000e+00 2.09236592e-09 3.51323891e-08]
(1, 0, 2, 1): [ 0.00000000e+00  0.00000000e+00 -6.41657941e-09 -1.78350571e-06]
(1, 0, 1, 2): [ 0.00000000e+00  0.00000000e+00 -6.10954936e-07  1.86435774e-05]
(1, 0, 0, 3): [ 0.00000000e+00  0.00000000e+00  3.05503037e-06 -5.00150901e-05]
(0, 4, 0, 0): [3.33982092e-06 8.88114110e-06 0.00000000e+00 0.00000000e+00]
(0, 3, 1, 0): [ 0.00000000e+00  0.00000000e+00  1.37553970e-06 -2.36217768e-05]
(0, 3, 0, 1): [ 0.00000000e+00  0.00000000e+00 -2.24184174e-05  1.77513110e-05]
(0, 2, 2, 0): [3.54703897e-08 5.11986525e-06 0.00000000e+00 0.00000000e+00]
(0, 2, 1, 1): [-2.15414781e-06 -8.31702815e-05  0.00000000e+00  0.00000000e+00]
(0, 2, 0, 2): [1.72952174e-05 1.78791226e-05 0.00000000e+00 0.00000000e+00]
(0, 1, 3, 0): [ 0.00000000e+00  0.00000000e+00 -3.32576455e-09 -2.16775859e-08]
(0, 1, 2, 1): [0.00000000e+00 0.00000000e+00 1.73891518e-08 1.32003603e-06]
(0, 1, 1, 2): [ 0.00000000e+00  0.00000000e+00  8.58583303e-07 -7.35197022e-06]
(0, 1, 0, 3): [ 0.00000000e+00  0.00000000e+00 -4.60205741e-06 -3.44817289e-05]
(0, 0, 4, 0): [-2.95410616e-10 -3.91436795e-10  0.00000000e+00  0.00000000e+00]
(0, 0, 3, 1): [ 1.72261161e-08 -3.76703368e-08  0.00000000e+00  0.00000000e+00]
(0, 0, 2, 2): [-3.76632244e-07  1.04173914e-06  0.00000000e+00  0.00000000e+00]
(0, 0, 1, 3): [ 3.81179356e-06 -5.44741177e-06  0.00000000e+00  0.00000000e+00]
(0, 0, 0, 4): [-1.51552013e-05 -3.46854944e-07  0.00000000e+00  0.00000000e+00]
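
# Check: no degree-two monomials survive in the cleaned series
# Each key is a tuple of exponents, so a degree-two term would have exponents summing to two
# Expected to print False

print(any(sum(key) == 2 for key in s))
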

Example-34: ORM optics correction

[1]:
# In this example the orbit response matrix (ORM) is used to correct linear optics in a simple FODO cell
# Two gradient errors are introduced into the cell quadrupoles

# This example illustrates one optimization step
# Given a measured ORM, the model knobs are fitted to reproduce it
# Next, the corrections should be applied and the matrix remeasured
[2]:
# Import

import numpy
import torch

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

import matplotlib.pyplot as plt
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=5):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=1):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=10):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points
# Note, transport maps are expected to have identical (differentiable) signatures

def t_01_02(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf1, cysf1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_02_03(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd1, cysd1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    return x

def t_03_04(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd2, cysd2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_04_05(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf2, cysf2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    return x

ts = [t_01_02, t_02_03, t_03_04, t_04_05]
[6]:
# Set deviation variables

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
cs = torch.tensor(8*[0.0], dtype=dtype, device=device)
dk = torch.tensor(2*[0.0], dtype=dtype, device=device)
[7]:
# Define one-turn transport at the lattice entrance

def fodo(x, cs, dk):
    for t in ts:
        x = t(x, cs, dk)
    return x
[8]:
# Test one-turn transport

print(fodo(x, cs, dk))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Compute (dynamical) fixed point
# Note, dynamical part is assumed to be fixed during optimization

fp = fixed_point(16, fodo, x, cs, dk, power=1, jacobian=torch.func.jacrev)

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[10]:
# Define the parametric response matrix

def rm(dk):

    pfp = parametric_fixed_point((1, ), fp, [cs], lambda x, cs: fodo(x, cs, dk), jacobian=torch.func.jacrev)
    chop(pfp)
    _, (dqx, _, dqy, _) = first(pfp)

    out = [torch.stack([dqx, dqy])]
    for t in ts:
        pfp = propagate((4, 8), (0, 1), pfp, [cs], lambda x, cs: t(x, cs, dk),  jacobian=torch.func.jacrev)
        chop(pfp)
        _, (dqx, _, dqy, _) = first(pfp)
        out.append(torch.stack([dqx, dqy]))

    return torch.stack(out).swapaxes(0, 1).reshape(-1, len(cs))

print(2*(len(ts) + 1), len(cs))
print(rm(dk).shape)
print()

print(rm(dk))
10 8
torch.Size([10, 8])

tensor([[7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.216e+00, 4.039e+00, 6.749e+00, 4.566e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.110e+00, 2.611e+00, 5.110e+00, 2.611e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.749e+00, 4.566e+00, 6.216e+00, 4.039e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.801e+01, 2.744e+01, 1.849e+01, 2.792e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.509e+01, 3.667e+01, 2.509e+01, 3.667e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.849e+01, 2.792e+01, 1.801e+01, 2.744e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01]], dtype=torch.float64)
[11]:
# Test the response matrix

dc = 1.0E-3*torch.ones_like(cs)

o = fixed_point(16, fodo, x, cs + dc, dk, power=1, jacobian=torch.func.jacrev)

os = []
qx, _, qy, _ = o
os.append(torch.stack([qx, qy]))

for t in ts:
    o = t(o, dc, dk)
    qx, _, qy, _ = o
    os.append(torch.stack([qx, qy]))

print(torch.allclose(torch.stack(os).T.flatten(), rm(dk) @ dc))
True
[12]:
# Set quadrupole gradient errors

ek = torch.tensor([-0.010, 0.005], dtype=dtype, device=device)
[13]:
# Measure ORM

erm = rm(ek)

print(erm)
tensor([[8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.658e+00, 4.415e+00, 7.190e+00, 4.942e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.488e+00, 2.923e+00, 5.488e+00, 2.923e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.190e+00, 4.942e+00, 6.658e+00, 4.415e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.943e+01, 2.915e+01, 1.990e+01, 2.963e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.678e+01, 3.865e+01, 2.678e+01, 3.865e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.990e+01, 2.963e+01, 1.943e+01, 2.915e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01]], dtype=torch.float64)
[14]:
# Define objective to minimize

def objective(dk):
    return ((erm - rm(dk))**2).sum()

print(objective(dk))
tensor(5.298e+01, dtype=torch.float64)
[15]:
# Set model class

class Model(torch.nn.Module):

    def __init__(self, knobs):
        super().__init__()
        self.knobs = torch.nn.Parameter(torch.clone(knobs))

    def forward(self):
        return objective(self.knobs)
[16]:
# Set model instance
# Note, initial knobs are set to zero

model = Model(torch.zeros_like(dk))

print(model())
tensor(5.298e+01, dtype=torch.float64, grad_fn=<SumBackward0>)
[17]:
# Set optimizer

lr = 2.5E-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
[18]:
# Fit model

epochs = 256

print(ek)
print(model.knobs)
print(model.forward())
print()

knobs, errors = [], []

for epoch in range(epochs):
    error = model.forward()
    with torch.no_grad():
        knobs.append(model.knobs.clone().detach())
        errors.append(error.clone().detach())
    error.backward()
    optimizer.step()
    optimizer.zero_grad()
    if epoch % 10 == 0:
        print(f'epoch: {epoch}, error: {error.item()}')

print(ek)
print(model.knobs)
print(model.forward())
print()
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([0., 0.], dtype=torch.float64, requires_grad=True)
tensor(5.298e+01, dtype=torch.float64, grad_fn=<SumBackward0>)

epoch: 0, error: 52.97690929514796
epoch: 10, error: 9.061476544019179
epoch: 20, error: 4.116026193086178
epoch: 30, error: 4.4716138801849885
epoch: 40, error: 2.5934201513661606
epoch: 50, error: 1.320221084167608
epoch: 60, error: 0.7104114404644974
epoch: 70, error: 0.39073287332555584
epoch: 80, error: 0.17765207040232386
epoch: 90, error: 0.07092862791751303
epoch: 100, error: 0.024482773974344098
epoch: 110, error: 0.006731421279765544
epoch: 120, error: 0.001272143234908415
epoch: 130, error: 0.00019181854968944713
epoch: 140, error: 2.3799691903263396e-05
epoch: 150, error: 2.155612473667951e-05
epoch: 160, error: 2.4284316948196817e-05
epoch: 170, error: 1.359976274459628e-05
epoch: 180, error: 4.543935327150297e-06
epoch: 190, error: 1.1354899250982208e-06
epoch: 200, error: 1.718342123892043e-07
epoch: 210, error: 2.551350446132945e-08
epoch: 220, error: 2.8003866265404305e-08
epoch: 230, error: 2.1762818183897804e-08
epoch: 240, error: 9.432013407780834e-09
epoch: 250, error: 2.3809020247549156e-09
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64, requires_grad=True)
tensor(5.351e-10, dtype=torch.float64, grad_fn=<SumBackward0>)

[19]:
# Plot error vs iteration

plt.figure(figsize=(20, 5))
plt.plot(range(len(errors)), torch.stack(errors).cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_418_0.png
[20]:
# Plot knobs vs iteration

plt.figure(figsize=(20, 5))
plt.hlines(ek.cpu().numpy(), 0, epochs, linestyles='dashed', color='gray', alpha=0.75)
for knob in torch.stack(knobs).T:
    plt.plot(range(len(knob)), knob.cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_419_0.png
[21]:
# Higher order derivatives can be used to create a surrogate Taylor model of the ORM
# The Taylor model can be expensive to compute, but the computation is performed only once
# Note, the model can also be updated during optimization

n = 5
t = derivative(n, rm, dk, jacobian=torch.func.jacfwd)
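
# Quick accuracy sketch: compare the surrogate with the directly computed ORM at the error point
# The printed value is the largest absolute deviation of the Taylor model

print((evaluate(t, [ek]) - rm(ek)).abs().max())
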
[22]:
# Redefine the objective function to use the Taylor model

def objective(dk):
    return ((erm - evaluate(t, [dk]))**2).sum()

print(objective(dk))
tensor(5.298e+01, dtype=torch.float64)
[23]:
# Set model instance
# Note, initial knobs are set to zero

model = Model(torch.zeros_like(dk))

print(model())
tensor(5.298e+01, dtype=torch.float64, grad_fn=<SumBackward0>)
[24]:
# Set optimizer

lr = 2.5E-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
[25]:
# Fit model

epochs = 256

print(ek)
print(model.knobs)
print(model.forward())
print()

knobs, errors = [], []

for epoch in range(epochs):
    error = model.forward()
    with torch.no_grad():
        knobs.append(model.knobs.clone().detach())
        errors.append(error.clone().detach())
    error.backward()
    optimizer.step()
    optimizer.zero_grad()
    if epoch % 10 == 0:
        print(f'epoch: {epoch}, error: {error.item()}')

print(ek)
print(model.knobs)
print(model.forward())
print()
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([0., 0.], dtype=torch.float64, requires_grad=True)
tensor(5.298e+01, dtype=torch.float64, grad_fn=<SumBackward0>)

epoch: 0, error: 52.976909295147884
epoch: 10, error: 9.059850191301422
epoch: 20, error: 4.116647404613835
epoch: 30, error: 4.472951117480212
epoch: 40, error: 2.5922560001567434
epoch: 50, error: 1.3204603071831524
epoch: 60, error: 0.7100392325724316
epoch: 70, error: 0.39053200207952343
epoch: 80, error: 0.17757259750915444
epoch: 90, error: 0.07088795313036353
epoch: 100, error: 0.024459899278350294
epoch: 110, error: 0.006720239692606007
epoch: 120, error: 0.001269417957035261
epoch: 130, error: 0.00019187232318368966
epoch: 140, error: 2.3597651642216307e-05
epoch: 150, error: 2.1673080207299973e-05
epoch: 160, error: 2.426302907045522e-05
epoch: 170, error: 1.3607180530103824e-05
epoch: 180, error: 4.539802198450803e-06
epoch: 190, error: 1.1314656087725447e-06
epoch: 200, error: 1.710558210516473e-07
epoch: 210, error: 2.5644830807954306e-08
epoch: 220, error: 2.8132779788222537e-08
epoch: 230, error: 2.1802735069022923e-08
epoch: 240, error: 9.435223715979277e-09
epoch: 250, error: 2.377689322793694e-09
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64, requires_grad=True)
tensor(5.331e-10, dtype=torch.float64, grad_fn=<SumBackward0>)

[26]:
# Plot error vs iteration

plt.figure(figsize=(20, 5))
plt.plot(range(len(errors)), torch.stack(errors).cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_425_0.png
[27]:
# Plot knobs vs iteration

plt.figure(figsize=(20, 5))
plt.hlines(ek.cpu().numpy(), 0, epochs, linestyles='dashed', color='gray', alpha=0.75)
for knob in torch.stack(knobs).T:
    plt.plot(range(len(knob)), knob.cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_426_0.png

Example-35: ORM optics correction (training loop)

[1]:
# In this example the orbit response matrix (ORM) is used to correct linear optics in a simple FODO cell
# Two gradient errors are introduced into the cell quadrupoles

# This example illustrates one optimization step
# Given a measured ORM, the model knobs are fitted to reproduce it
# Next, the corrections should be applied and the matrix should be remeasured

# The fitting step mirrors a neural net training loop
# Elements of the measured response matrix are used as targets

# Note, the full ORM is computed for each batch
# It is also possible to define elementwise computation (see the next example)
[2]:
# Import

import numpy
import torch

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

import matplotlib.pyplot as plt
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=5):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=1):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=10):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points
# Note, transport maps are expected to have identical (differentiable) signatures

def t_01_02(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf1, cysf1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_02_03(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd1, cysd1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    return x

def t_03_04(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd2, cysd2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_04_05(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf2, cysf2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    return x

ts = [t_01_02, t_02_03, t_03_04, t_04_05]
[6]:
# Set deviation variables

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
cs = torch.tensor(8*[0.0], dtype=dtype, device=device)
dk = torch.tensor(2*[0.0], dtype=dtype, device=device)
[7]:
# Define one-turn transport at the lattice entrance

def fodo(x, cs, kq):
    for t in ts:
        x = t(x, cs, kq)
    return x
[8]:
# Test one-turn transport

print(fodo(x, cs, dk))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Compute (dynamical) fixed point
# Note, dynamical part is assumed to be fixed during optimization

fp = fixed_point(16, fodo, x, cs, dk, power=1, jacobian=torch.func.jacrev)

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[10]:
# Define parametric response matrix

def rm(dk):

    pfp = parametric_fixed_point((1, ), fp, [cs], lambda x, cs: fodo(x, cs, dk), jacobian=torch.func.jacrev)
    chop(pfp)
    _, (dqx, _, dqy, _) = first(pfp)

    out = [torch.stack([dqx, dqy])]
    for t in ts:
        pfp = propagate((4, 8), (0, 1), pfp, [cs], lambda x, cs: t(x, cs, dk),  jacobian=torch.func.jacrev)
        chop(pfp)
        _, (dqx, _, dqy, _) = first(pfp)
        out.append(torch.stack([dqx, dqy]))

    return torch.stack(out).swapaxes(0, 1).reshape(-1, len(cs))

print(2*(len(ts) + 1), len(cs))
print(rm(dk).shape)
print()

print(rm(dk))
10 8
torch.Size([10, 8])

tensor([[7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.216e+00, 4.039e+00, 6.749e+00, 4.566e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.110e+00, 2.611e+00, 5.110e+00, 2.611e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.749e+00, 4.566e+00, 6.216e+00, 4.039e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.801e+01, 2.744e+01, 1.849e+01, 2.792e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.509e+01, 3.667e+01, 2.509e+01, 3.667e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.849e+01, 2.792e+01, 1.801e+01, 2.744e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01]], dtype=torch.float64)
[11]:
# Test response matrix

dc = 1.0E-3*torch.ones_like(cs)

o = fixed_point(16, fodo, x, cs + dc, dk, power=1, jacobian=torch.func.jacrev)

os = []
qx, _, qy, _ = o
os.append(torch.stack([qx, qy]))

for t in ts:
    o = t(o, dc, dk)
    qx, _, qy, _ = o
    os.append(torch.stack([qx, qy]))

print(torch.allclose(torch.stack(os).T.flatten(), rm(dk) @ dc))
True
[12]:
# Set quadrupole gradient errors

ek = torch.tensor([-0.010, 0.005], dtype=dtype, device=device)
[13]:
# Measure ORM

erm = rm(ek)

print(erm)
tensor([[8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.658e+00, 4.415e+00, 7.190e+00, 4.942e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.488e+00, 2.923e+00, 5.488e+00, 2.923e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.190e+00, 4.942e+00, 6.658e+00, 4.415e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.943e+01, 2.915e+01, 1.990e+01, 2.963e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.678e+01, 3.865e+01, 2.678e+01, 3.865e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.990e+01, 2.963e+01, 1.943e+01, 2.915e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01]], dtype=torch.float64)
[14]:
# Set data

i_max, j_max = erm.shape
i_val, j_val = torch.arange(i_max), torch.arange(j_max)
X = torch.vstack([*torch.stack(torch.meshgrid(i_val, j_val, indexing='xy')).swapaxes(0, -1)])
y = erm.clone().flatten()

batch_size = 16
dataset = TensorDataset(X.clone(), y.clone())
dataset, validation = random_split(dataset, [0.80, 0.20])

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
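
# Sketch (illustration, not in the original notebook): inspect one batch to see the data layout
# Each sample pairs a matrix index (i, j) with the corresponding measured ORM element

Xb, yb = next(iter(dataloader))
print(Xb.shape, yb.shape)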
[15]:
# Set model

class Model(torch.nn.Module):

    def __init__(self, knobs):
        super().__init__()
        self.knobs = torch.nn.Parameter(torch.clone(knobs))

    def forward(self, x):
        i, j = x.unsqueeze(0).swapaxes(0, -1)
        return (rm(self.knobs)[i, j]).squeeze()
[16]:
# Set model instance
# Note, initial knobs are set to zero

model = Model(torch.zeros_like(dk))
[17]:
# Test model

i, j = 0, 0

print(rm(dk)[i, j])
print(model(torch.tensor([[i, j]])))
print()
tensor(7.577e+00, dtype=torch.float64)
tensor(7.577e+00, dtype=torch.float64, grad_fn=<SqueezeBackward0>)
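
# Sketch (illustration, not in the original notebook): the forward method also accepts
# a batch of (i, j) indices, which is how it is called in the training loop below

print(model(torch.tensor([[0, 0], [1, 2], [5, 4]])))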

[18]:
# Set optimizer

lr = 1.0E-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
[19]:
# Set loss function

lf = torch.nn.MSELoss()
[20]:
# Fit model
# Note, at each epoch the loss is computed for the full validation set

epochs = 128

print()
print(ek)
print(model.knobs)
print()

knobs, errors = [], []

for epoch in range(epochs):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        y_hat = model(X)
        error = lf(y_hat, y)
        with torch.no_grad():
            knobs.append(model.knobs.clone().detach())
            errors.append(error.clone().detach())
        error.backward()
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
    X, y = validation.dataset.tensors
    test = lf(model(X[validation.indices]), y[validation.indices])
    if epoch % 10 == 0:
        print(f'epoch: {epoch}, error: {error.item()} / {test.item()}')

print()
print(ek)
print(model.knobs)
print()

tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([0., 0.], dtype=torch.float64, requires_grad=True)

epoch: 0, error: 0.37529107626509794 / 0.72679523282213
epoch: 10, error: 0.019561891190916357 / 0.08595406190271199
epoch: 20, error: 0.028442131733101884 / 0.059218224485797756
epoch: 30, error: 0.024756037639704742 / 0.03382898705876869
epoch: 40, error: 0.007808066790505313 / 0.014454231605437087
epoch: 50, error: 0.003117435961114121 / 0.006542069654623646
epoch: 60, error: 0.0012404708769704496 / 0.00196001107485135
epoch: 70, error: 0.0002826792552836003 / 0.000523081505929046
epoch: 80, error: 6.0312874463179874e-05 / 0.0001471941670079705
epoch: 90, error: 2.3739166186638856e-05 / 5.9827396259044656e-05
epoch: 100, error: 6.057427185883081e-06 / 7.80155998736023e-06
epoch: 110, error: 1.3473974612042764e-06 / 1.5022630951121078e-06
epoch: 120, error: 7.86204138851082e-08 / 1.297239564506323e-07

tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([-9.994e-03, 4.998e-03], dtype=torch.float64, requires_grad=True)

[21]:
# Plot error vs iteration

plt.figure(figsize=(20, 5))
plt.plot(range(len(errors)), torch.stack(errors).cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_448_0.png
[22]:
# Plot knobs vs iteration

plt.figure(figsize=(20, 5))
plt.hlines(ek.cpu().numpy(), 0, len(errors), linestyles='dashed', color='gray', alpha=0.75)
for knob in torch.stack(knobs).T:
    plt.plot(range(len(knob)), knob.cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_449_0.png

Example-36: ORM optics correction (training loop + elementwise computation)

[1]:
# In this example the orbit response matrix (ORM) is used to correct linear optics in a simple FODO cell
# Two gradient errors are introduced into the cell quadrupoles

# This example illustrates one optimization step
# Given a measured ORM, the model knobs are fitted to reproduce it
# Next, the corrections should be applied and the matrix should be remeasured

# The fitting step mirrors a neural net training loop
# Elements of the measured response matrix are used as targets

# Note, elements of the ORM are computed in the forward method, but the computation is sequential
[2]:
# Import

from functools import partial

import numpy
import torch

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from ndmap.util import flatten
from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

import matplotlib.pyplot as plt
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=5):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=1):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=10):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points
# Note, transport maps are expected to have identical (differentiable) signatures

def t_01_02(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf1, cysf1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_02_03(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd1, cysd1)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    return x

def t_03_04(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = quad(x, [0.0], -0.21 + kd, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsd2, cysd2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    return x

def t_04_05(x, cs, dk):
    cxsf1, cxsd1, cxsf2, cxsd2, cysf1, cysd1, cysf2, cysd2 = cs
    kf, kd = dk
    x = bend(x, [0.0], 22.92, 0.015, 0.00, 1.5)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.00, 0.05)
    x = kick(x, cxsf2, cysf2)
    x = sext(x, [0.0], 0.00, 0.05)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.19 + kf, 0.50)
    return x
[6]:
# Set deviation variables

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
cs = torch.tensor(8*[0.0], dtype=dtype, device=device)
dk = torch.tensor(2*[0.0], dtype=dtype, device=device)
[7]:
# Define one-turn transport at the lattice entrance

def fodo(x, cs, kq):
    for t in [t_01_02, t_02_03, t_03_04, t_04_05]:
        x = t(x, cs, kq)
    return x
[8]:
# Test one-turn transport

print(fodo(x, cs, dk))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Set rings at observation points

def ring(x, cs, dk, s=0):
    ts = [t_01_02, t_02_03, t_03_04, t_04_05]
    ts = ts[s:] + ts[:s]
    for t in ts:
        x = t(x, cs, dk)
    return x

rs = [partial(ring, s=s) for s in range(5)]
[10]:
# Compute fixed point
# Note, dynamical part is assumed to be fixed during optimization

fp = fixed_point(16, fodo, x, cs, dk, power=1, jacobian=torch.func.jacrev)

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[11]:
# Compute fixed points for all rings
# Note, dynamical part is assumed to be fixed during optimization

fps = torch.stack([fixed_point(16, r, x, cs, dk, power=1, jacobian=torch.func.jacrev) for r in rs])

print(fps)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float64)
[12]:
# Define parametric response matrix

def rm(dk):

    pfp = parametric_fixed_point((1, ), fp, [cs], lambda x, cs: fodo(x, cs, dk), jacobian=torch.func.jacrev)
    chop(pfp)
    _, (dqx, _, dqy, _) = first(pfp)

    out = [torch.stack([dqx, dqy])]
    for t in [t_01_02, t_02_03, t_03_04, t_04_05]:
        pfp = propagate((4, 8), (0, 1), pfp, [cs], lambda x, cs: t(x, cs, dk),  jacobian=torch.func.jacrev)
        chop(pfp)
        _, (dqx, _, dqy, _) = first(pfp)
        out.append(torch.stack([dqx, dqy]))

    return torch.stack(out).swapaxes(0, 1).reshape(-1, len(cs))

print(2*(len([t_01_02, t_02_03, t_03_04, t_04_05]) + 1), len(cs))
print(rm(dk).shape)
print()

print(rm(dk))
10 8
torch.Size([10, 8])

tensor([[7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.216e+00, 4.039e+00, 6.749e+00, 4.566e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.110e+00, 2.611e+00, 5.110e+00, 2.611e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.749e+00, 4.566e+00, 6.216e+00, 4.039e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.577e+00, 5.936e+00, 7.577e+00, 5.936e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.801e+01, 2.744e+01, 1.849e+01, 2.792e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.509e+01, 3.667e+01, 2.509e+01, 3.667e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.849e+01, 2.792e+01, 1.801e+01, 2.744e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.344e+01, 2.158e+01, 1.344e+01, 2.158e+01]], dtype=torch.float64)
[13]:
# Test response matrix

dc = 1.0E-3*torch.ones_like(cs)

o = fixed_point(16, fodo, x, cs + dc, dk, power=1, jacobian=torch.func.jacrev)

os = []
qx, _, qy, _ = o
os.append(torch.stack([qx, qy]))

for t in [t_01_02, t_02_03, t_03_04, t_04_05]:
    o = t(o, dc, dk)
    qx, _, qy, _ = o
    os.append(torch.stack([qx, qy]))

print(torch.allclose(torch.stack(os).T.flatten(), rm(dk) @ dc))
True
[14]:
# Define parametric response matrix element

def rm_ijk(dk, ijk):
    i, j, k = ijk
    t = lambda x, *cs: rs[i](x, torch.stack(cs).flatten(), dk)
    v = tuple(torch.eye(len(cs), dtype=torch.int)[j].tolist())
    fp = fps[i]
    pfp = parametric_fixed_point(v, fp, list(cs.reshape(-1, 1)), t, jacobian=torch.func.jacrev)
    _, (dqx, _, dqy, _) = [*flatten(pfp, target=list)]
    return torch.stack([dqx, dqy])[k].squeeze()
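
# Sketch (illustration, not in the original notebook): cross-check one element against the full matrix
# The row mapping k*5 + i is an assumption inferred from how rm stacks and reshapes its output

i, j, k = 1, 2, 0
print(torch.allclose(rm_ijk(dk, (i, j, k)), rm(dk)[k*5 + i, j]))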
[15]:
# Set quadrupole gradient errors

ek = torch.tensor([-0.010, 0.005], dtype=dtype, device=device)
[16]:
# Measure ORM

erm = rm(ek)

print(erm)
tensor([[8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [6.658e+00, 4.415e+00, 7.190e+00, 4.942e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [5.488e+00, 2.923e+00, 5.488e+00, 2.923e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [7.190e+00, 4.942e+00, 6.658e+00, 4.415e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [8.038e+00, 6.338e+00, 8.038e+00, 6.338e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.943e+01, 2.915e+01, 1.990e+01, 2.963e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.678e+01, 3.865e+01, 2.678e+01, 3.865e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.990e+01, 2.963e+01, 1.943e+01, 2.915e+01],
        [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.470e+01, 2.316e+01, 1.470e+01, 2.316e+01]], dtype=torch.float64)
[17]:
# Set data

i_max, j_max, k_max = 5, 8, 2
i_val, j_val, k_val = torch.arange(i_max), torch.arange(j_max), torch.arange(k_max)
X = torch.vstack([*torch.vstack([*torch.stack(torch.meshgrid(i_val, j_val, k_val, indexing='xy')).swapaxes(0, -1)])])
y = torch.stack([rm_ijk(ek, ijk) for ijk in X])

batch_size = 16
dataset = TensorDataset(X.clone(), y.clone())
dataset, validation = random_split(dataset, [0.80, 0.20])

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
[18]:
# Set model

class Model(torch.nn.Module):

    def __init__(self, knobs):
        super().__init__()
        self.knobs = torch.nn.Parameter(torch.clone(knobs))

    def forward(self, x):
        return torch.stack([rm_ijk(self.knobs, ijk) for ijk in x]).squeeze()
[19]:
# Set model instance
# Note, initial knobs are set to zero

model = Model(torch.zeros_like(dk))
[20]:
# Set optimizer

lr = 1.0E-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
[21]:
# Set loss function

lf = torch.nn.MSELoss()
[22]:
# Fit model
# Note, at each epoch the loss is computed for the full validation set

epochs = 128

print()
print(ek)
print(model.knobs)
print()

knobs, errors = [], []

for epoch in range(epochs):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        y_hat = model(X)
        error = lf(y_hat, y)
        with torch.no_grad():
            knobs.append(model.knobs.clone().detach())
            errors.append(error.clone().detach())
        error.backward()
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
    X, y = validation.dataset.tensors
    test = lf(model(X[validation.indices]), y[validation.indices])
    if epoch % 10 == 0:
        print(f'epoch: {epoch}, error: {error.item()} / {test.item()}')

print()
print(ek)
print(model.knobs)
print()

tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([0., 0.], dtype=torch.float64, requires_grad=True)

epoch: 0, error: 0.3207854421804211 / 0.4834669895527817
epoch: 10, error: 0.06071210114479893 / 0.04106920029560601
epoch: 20, error: 0.017744182732083627 / 0.01457174468546367
epoch: 30, error: 0.003963890419609968 / 0.0034633084736078864
epoch: 40, error: 0.0009804767999389498 / 0.0005130453666292978
epoch: 50, error: 8.948304821381796e-05 / 0.00014339307948138367
epoch: 60, error: 4.0143085845154395e-06 / 2.272965287854233e-05
epoch: 70, error: 3.4453470618958206e-06 / 3.3214685216885225e-06
epoch: 80, error: 1.387860215946706e-07 / 9.034775495586135e-08
epoch: 90, error: 6.611469669425412e-08 / 4.719136099019071e-09
epoch: 100, error: 3.0221850571449367e-10 / 2.1032181485124297e-09
epoch: 110, error: 1.9253040326146467e-11 / 2.887542634043959e-13
epoch: 120, error: 1.9606914010193334e-13 / 5.526546647035777e-14

tensor([-1.000e-02, 5.000e-03], dtype=torch.float64)
Parameter containing:
tensor([-1.000e-02, 5.000e-03], dtype=torch.float64, requires_grad=True)

[23]:
# Plot error vs iteration

plt.figure(figsize=(20, 5))
plt.plot(range(len(errors)), torch.stack(errors).cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_473_0.png
[24]:
# Plot knobs vs iteration

plt.figure(figsize=(20, 5))
plt.hlines(ek.cpu().numpy(), 0, len(errors), linestyles='dashed', color='gray', alpha=0.75)
for knob in torch.stack(knobs).T:
    plt.plot(range(len(knob)), knob.cpu().numpy(), color='black', marker='x')
plt.show()
../_images/examples_ndmap_474_0.png

Example-37: Normalized dispersion

[1]:
# Normalized dispersion can be used for calibration-independent correction (10.1109/PAC.2007.4440536)
# In this example derivatives of the normalized dispersion with respect to quadrupole amplitudes are computed
[2]:
# Import

from functools import partial

import numpy
import torch

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from ndmap.util import flatten
from ndmap.util import first
from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.evaluate import compare
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.wolski import propagate as propagate_twiss
from twiss.convert import wolski_to_cs

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

import matplotlib.pyplot as plt
True
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=5):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=1):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=10):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def kick(x, cx, cy):
    (qx, px, qy, py) = x
    return torch.stack([qx, px + cx, qy, py + cy])

def slip(x, dx, dy):
    (qx, px, qy, py) = x
    return torch.stack([qx + dx, px, qy + dy, py])
[5]:
# Set transport maps between observation points
# Note, transport maps are expected to have identical (differentiable) signatures

def t_01_02(x, w, dk):
    kf, kd = dk
    x = quad(x, w, 0.19 + kf, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.00, 0.1)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.00, 1.5)
    return x

def t_02_03(x, w, dk):
    kf, kd = dk
    x = bend(x, w, 22.92, 0.015, 0.00, 1.5)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.00, 0.1)
    x = drif(x, w, 0.45)
    x = quad(x, w, -0.21 + kd, 0.50)
    return x

def t_03_04(x, w, dk):
    kf, kd = dk
    x = quad(x, w, -0.21 + kd, 0.50)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.00, 0.1)
    x = drif(x, w, 0.45)
    x = bend(x, w, 22.92, 0.015, 0.00, 1.5)
    return x

def t_04_05(x, w, dk):
    kf, kd = dk
    x = bend(x, w, 22.92, 0.015, 0.00, 1.5)
    x = drif(x, w, 0.45)
    x = sext(x, w, 0.00, 0.1)
    x = drif(x, w, 0.45)
    x = quad(x, w, 0.19 + kf, 0.50)
    return x

ts = [t_01_02, t_02_03, t_03_04, t_04_05]
[6]:
# Set deviation variables

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
w = torch.tensor(1*[0.0], dtype=dtype, device=device)
dk = torch.tensor(2*[0.0], dtype=dtype, device=device)
[7]:
# Define one-turn transport at the lattice entrance

def fodo(x, w, dk):
    for t in ts:
        x = t(x, w, dk)
    return x
[8]:
# Test one-turn transport

print(fodo(x, w, dk))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[9]:
# Compute (dynamical) fixed point
# Note, dynamical part is assumed to be fixed during optimization

fp = fixed_point(16, fodo, x, w, dk, power=1, jacobian=torch.func.jacrev)

print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[10]:
# Normalized (horizontal) dispersion

def dispersion(dk):

    # Set container for parametric fixed points

    pfps = []

    # Set container for x and y dispersions

    etas = []

    # Compute parametric fixed point and set dispersion at the lattice entrance

    pfp = parametric_fixed_point((1, ), fp, [w], lambda x, w: fodo(x, w, dk), jacobian=torch.func.jacrev)
    chop(pfp)
    pfps.append(pfp)

    _, (etax, _, etay, _) = first(pfp)
    etas.append(torch.stack([etax, etay]))

    # Propagate fixed point and set dispersion values

    for t in ts:
        pfp = propagate((4, 1), (0, 1), pfp, [w], lambda x, w: t(x, w, dk), jacobian=torch.func.jacrev)
        chop(pfp)
        pfps.append(pfp)
        _, (etax, _, etay, _) = first(pfp)
        etas.append(torch.stack([etax, etay]))

    # Set dispersion at all observation points

    etaxs, etays = torch.hstack(etas)

    # Define wrapper for transport maps

    def wrapper(x, w, dk, transport, pfp_in, pfp_out):
        x = x + evaluate(first(pfp_in), [w])
        x = transport(x, w, dk)
        x = x - evaluate(first(pfp_out), [w])
        return x

    # Set containers for beta functions

    bxs = []
    bys = []

    # Compute beta functions at the lattice entrance

    pfp_in = first(pfps)
    pfp_out = first(pfps)

    matrix = derivative(1, lambda x: wrapper(x, w, dk, fodo, pfp_in, pfp_out), fp, intermediate=False, jacobian=torch.func.jacrev)

    *_, m = twiss(matrix)
    _, bx, _, by = wolski_to_cs(m)
    bxs.append(bx)
    bys.append(by)

    # Propagate twiss

    for i, t in enumerate(ts):
        pfp_in = pfps[i]
        pfp_out = pfps[i + 1]
        m = propagate_twiss(m, derivative(1, lambda x: wrapper(x, w, dk, t, pfp_in, pfp_out), x, intermediate=False, jacobian=torch.func.jacrev))
        _, bx, _, by = wolski_to_cs(m)
        bxs.append(bx)
        bys.append(by)

    bxs = torch.stack(bxs)
    bys = torch.stack(bys)

    # Set normalized dispersions (exclude lattice exit)

    *etaxs, _ = etaxs/bxs
    *etays, _ = etays/bys

    return torch.stack([*etaxs, *etays])
[11]:
# Compute normalized dispersion derivatives (response matrix)

rm = derivative(1, dispersion, dk, intermediate=False, jacobian=torch.func.jacrev)

print(rm)
tensor([[-8.795e-01, -1.390e-01],
        [-8.732e-01, -2.077e-01],
        [-4.797e-01, -4.941e-01],
        [-8.732e-01, -2.077e-01],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00]], dtype=torch.float64)
[12]:
# Check convergence

ek = torch.tensor([-0.001, 0.001], dtype=dtype, device=device)

print(dispersion(dk))
print(dispersion(ek))
print(dispersion(dk) + rm @ ek)
tensor([1.169e-01, 1.568e-01, 2.615e-01, 1.568e-01, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00], dtype=torch.float64)
tensor([1.176e-01, 1.575e-01, 2.615e-01, 1.575e-01, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00], dtype=torch.float64)
tensor([1.176e-01, 1.575e-01, 2.615e-01, 1.575e-01, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00], dtype=torch.float64)
[13]:
# Correction

# The target values (normalized dispersion) are associated with the model response matrix
# Given measured values, the goal is to alter the knobs to reach the target values

# Set target values

vf = dispersion(dk)

# Set initial solution

sol = torch.zeros_like(dk)


# Iterate

for _ in range(4):

    # Compute current values

    vi = dispersion(ek + sol)

    # Set difference

    dv = vf - vi

    # Update solution

    sol += torch.linalg.pinv(rm) @ dv

    # Verbose

    print(-ek)
    print(sol)
    print(dv.norm())
    print()

    # Continue
tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor([9.987e-04, -9.924e-04], dtype=torch.float64)
tensor(1.198e-03, dtype=torch.float64)

tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor(3.185e-06, dtype=torch.float64)

tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor(1.955e-11, dtype=torch.float64)

tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor([1.000e-03, -1.000e-03], dtype=torch.float64)
tensor(2.021e-16, dtype=torch.float64)

Example-38: Coupling (minimal tune distance)

[1]:
# In this example, the minimal tune distance is computed using the TEAPOT expression
# The value is compared with some approximate analytical expressions
# Computation of the dQmin derivative (gradient) is illustrated
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.convert import wolski_to_cs
from twiss.matrix import symplectic_conjugate

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
False
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def roll(x, a):
    (qx, px, qy, py), cn, sn = x, a.cos(), a.sin()
    return torch.stack([qx*cn + qy*sn, px*cn + py*sn, qy*cn - qx*sn, py*cn - px*sn])

def kick(x, kn, ks):
    (qx, px, qy, py), kn, ks = x, kn, ks
    return torch.stack([qx, px - kn*qx + ks*qy, qy, py + ks*qx + kn*qy])
[5]:
# Set transport maps between observation points

def map_01_02(x, k):
    kf, kd = k
    x = kick(x, 0.0, kf/2.0)
    x = quad(x, [0.0], 0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.18, 0.50)
    x = kick(x, 0.0, kd/2.0)
    x = kick(x, 0.0, kd/2.0)
    x = quad(x, [0.0], -0.18, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.21, 0.50)
    x = kick(x, 0.0, kf/2.0)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric fixed point

out = propagate((4, 2), (0, 1), pfp, [k], fodo)
chop(out)
out
[8]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[9]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 2), (1, 1), jet, [k], fodo)
[10]:
# Compute uncoupled one-turn matrix (zero skew quadrupole amplitudes)

m = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)
m
[10]:
tensor([[2.192e-01, 1.772e+01, 0.000e+00, 0.000e+00],
        [-5.372e-02, 2.192e-01, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, 5.733e-01, 6.018e+00],
        [0.000e+00, 0.000e+00, -1.116e-01, 5.733e-01]], dtype=torch.float64)
[11]:
# Compute (uncoupled) CS twiss parameters

(nux, nuy), _, w = twiss(m)

mux, muy = 2.0*torch.pi*nux, 2.0*torch.pi*nuy

ax, bx, ay, by = wolski_to_cs(w)

torch.stack([ax, bx, ay, by])
[11]:
tensor([7.539e-16, 1.816e+01, 1.423e-15, 7.345e+00], dtype=torch.float64)
[12]:
# Compute coupled one-turn matrix

dm = derivative(1, kick, x, 0.0, 0.5E-3, intermediate=False)
dm @ m @ dm
[12]:
tensor([[ 2.192e-01,  1.772e+01,  8.859e-03,  0.000e+00],
        [-5.372e-02,  2.192e-01,  3.963e-04,  3.009e-03],
        [ 3.009e-03,  0.000e+00,  5.733e-01,  6.018e+00],
        [ 3.963e-04,  8.859e-03, -1.115e-01,  5.733e-01]], dtype=torch.float64)
[13]:
# Coupled one-turn matrix from jet
# Note, jet is first order in skew quadrupole strength

dkf = 1.0E-3
dkd = 0.0

dk = torch.tensor([dkf, dkd], dtype=dtype, device=device)

m = derivative(1, lambda x: evaluate(jet, [x, k + dk]), fp, intermediate=False)
m
[13]:
tensor([[ 2.192e-01,  1.772e+01,  8.859e-03,  0.000e+00],
        [-5.372e-02,  2.192e-01,  3.963e-04,  3.009e-03],
        [ 3.009e-03,  0.000e+00,  5.733e-01,  6.018e+00],
        [ 3.963e-04,  8.859e-03, -1.116e-01,  5.733e-01]], dtype=torch.float64)
[14]:
# |dQmin| (Edwards & Syphers, first order in amplitude and unperturbed tune difference)

f'{abs(dkf)/(2.0*torch.pi)*(bx*by).sqrt().item():.6e}'
[14]:
'1.838149e-03'
[15]:
# dQmin (first order in amplitude)
# Note, tunes in [0, 1/2] are assumed

f'{abs(dkf)/(torch.pi)*(bx*by).sqrt()*(mux.sin()*muy.sin()).abs().sqrt()/(mux.sin() + muy.sin()).item():.6e}'
[15]:
'1.831163e-03'
[16]:
# dQmin (TEAPOT manual, appendix G, 1996)
# Note, the coupled one-turn matrix m from the previous cell is used

(NUX, NUY), *_ = twiss(m)

mux, muy = 2.0*torch.pi*NUX, 2.0*torch.pi*NUY

B = m[:2, 2:]
C = m[2:, :2]

f'{torch.linalg.det(C + symplectic_conjugate(B)).abs().sqrt()/(torch.pi*(mux.sin() + muy.sin())).item():.6e}'
[16]:
'1.831193e-03'
[17]:
# Effect of skew quadrupole on tunes

torch.stack([nux - NUX, nuy - NUY]).abs()
[17]:
tensor([1.275e-05, 1.314e-05], dtype=torch.float64)
[18]:
# dQmin derivative (TEAPOT manual, appendix G, 1996)

def dQmin(k):
    m = derivative(1, lambda x: fodo(x, k), fp, intermediate=False)
    (nux, nuy), *_ = twiss(m)
    mux, muy = 2.0*torch.pi*nux, 2.0*torch.pi*nuy
    B = m[:2, 2:]
    C = m[2:, :2]
    return (C + symplectic_conjugate(B)).diag().prod().abs().sqrt()/(mux.sin() + muy.sin())/torch.pi

print(dQmin(k + dk))
print()

# Derivative at zero is not defined

print(derivative(1, dQmin, k, intermediate=False))
print()

# Derivatives at points near zero are valid (note the sign flip)

print(derivative(1, dQmin, k + 1.0E-16, intermediate=False))
print(derivative(1, dQmin, k - 1.0E-16, intermediate=False))
print()

# Derivative at a point

print(derivative(1, dQmin, k + dk, intermediate=False))
print()
tensor(1.831e-03, dtype=torch.float64)

tensor([nan, nan], dtype=torch.float64)

tensor([1.835e+00, 1.715e+00], dtype=torch.float64)
tensor([-1.835e+00, -1.715e+00], dtype=torch.float64)

tensor([1.831e+00, 1.725e+00], dtype=torch.float64)
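
# Sketch (illustration, not in the original notebook): the analytic gradient at k + dk can be
# cross-checked with a central finite difference; the step size h is an assumption

h = 1.0E-6
fd = torch.stack([(dQmin(k + dk + h*e) - dQmin(k + dk - h*e))/(2.0*h) for e in torch.eye(2, dtype=dtype, device=device)])
print(fd)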

[19]:
# Plot dQmin and its derivative norm on a grid

dkf = torch.linspace(-5.0E-3, +5.0E-3, 128, dtype=dtype, device=device)
dkd = torch.linspace(-5.0E-3, +5.0E-3, 128, dtype=dtype, device=device)

dk = torch.stack(torch.meshgrid(dkf, dkd, indexing='ij')).swapaxes(-1, 0).reshape(128*128, -1)
dQ = torch.vmap(dQmin)(dk).reshape(128, 128)

plt.figure(figsize=(8, 8))
plt.imshow(dQ.cpu().numpy(), cmap='plasma', interpolation='bilinear', origin='upper', extent=(-5.0E-3, +5.0E-3, -5.0E-3, +5.0E-3))
plt.colorbar(fraction=0.045, pad=0.05)
plt.contour(dQ.cpu().numpy(), origin='lower', extend='both', linewidths=1, extent=(-5.0E-3, +5.0E-3, -5.0E-3, +5.0E-3), colors='black')
plt.tight_layout()
plt.show()
../_images/examples_ndmap_508_0.png

Example-39: Coupling (minimal tune distance correction)

[1]:
# In this example the point derivative of dQmin (minimal tune distance) is used to illustrate coupling correction
# A pair of skew quadrupole errors is added and the correction is performed with gradient descent (GD)
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.convert import wolski_to_cs
from twiss.matrix import symplectic_conjugate

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
False
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def roll(x, a):
    (qx, px, qy, py), cn, sn = x, a.cos(), a.sin()
    return torch.stack([qx*cn + qy*sn, px*cn + py*sn, qy*cn - qx*sn, py*cn - px*sn])

def kick(x, kn, ks):
    (qx, px, qy, py), kn, ks = x, kn, ks
    return torch.stack([qx, px - kn*qx + ks*qy, qy, py + ks*qx + kn*qy])
[5]:
# Set transport maps between observation points (close tunes, bends are replaced with drifts)

def map_01_02(x, k):
    kf, kd = k
    x = kick(x, 0.0, kf/2.0)
    x = quad(x, [0.0], 0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.18, 0.50)
    x = kick(x, 0.0, kd/2.0)
    x = kick(x, 0.0, kd/2.0)
    x = quad(x, [0.0], -0.18, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.21, 0.50)
    x = kick(x, 0.0, kf/2.0)
    return x

transport = [
    map_01_02
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[6]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric fixed point

out = propagate((4, 2), (0, 1), pfp, [k], fodo)
chop(out)
out
[8]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[9]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 2), (1, 1), jet, [k], fodo)
[10]:
# Minimal tune distance (using 1st order parametric matrix around closed orbit)

def dQmin(k):
    m = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)
    (nux, nuy), *_ = twiss(m)
    mux, muy = 2.0*torch.pi*nux, 2.0*torch.pi*nuy
    B = m[:2, 2:]
    C = m[2:, :2]
    (m11, m12), (m21, m22) = C + symplectic_conjugate(B)
    return 1.0/torch.pi * (m11*m22 - m12*m21).abs().sqrt()/(mux.sin() + muy.sin()).abs()
[11]:
# Set skew errors

dkf = 1.0E-2
dkd = 1.0E-2

dk = torch.tensor([dkf, dkd], dtype=dtype, device=device)
[12]:
# Note, the sign flip doesn't change the dQmin value

dQmin(+dk) - dQmin(-dk)
[12]:
tensor(0., dtype=torch.float64)
[13]:
# Compute dQmin and gradient

derivative(1, dQmin, k + dk, intermediate=True)
[13]:
[tensor(3.574e-02, dtype=torch.float64),
 tensor([1.871e+00, 1.751e+00], dtype=torch.float64)]
[14]:
# Correction setup (minimize dQmin)

# Set objective
# Note, this objective represents a model with errors; the task is to find knob values that minimize it

error = dk
objective = lambda knobs: dQmin(knobs + error)

# Exact solution

print(objective(-error))

# Set initial guess (zero knob values in this example)

solution = torch.zeros_like(error)

# Evaluate objective for initial guess

print(objective(solution))
tensor(0., dtype=torch.float64)
tensor(3.574e-02, dtype=torch.float64)
[15]:
# Correction loop (gradient descent)

# Note, a small learning rate is required here for convergence

ni = 2048
lr = 2.5E-6
xs = []

for i in range(ni):
    solution -= lr*derivative(1, objective, solution, intermediate=False)
    xs.append(objective(solution))

xs = torch.stack(xs)

# Evaluate objective for final solution

print(dk)
print(-solution)
print(objective(solution))

plt.figure(figsize=(16, 4))
plt.plot(xs.cpu().numpy(), color='blue')
plt.tight_layout()
plt.show()
tensor([1.000e-02, 1.000e-02], dtype=torch.float64)
tensor([9.462e-03, 8.845e-03], dtype=torch.float64)
tensor(2.966e-03, dtype=torch.float64)
../_images/examples_ndmap_524_1.png
[16]:
# Another (preferred) approach is to fit our model to the observation
# In this case, the observed value is computed first; the task is to find knob values that fit our model to the observation
# The next step is to interact with the experimental model by applying the fitted knob values with a negative sign
[17]:
# Set objective

dQmin_observed = dQmin(error)
objective = lambda knobs: (dQmin(knobs) - dQmin_observed)**2
solution = torch.zeros_like(error)

print(dQmin(error))
print(dQmin(error - error))

print(objective(error))
tensor(3.574e-02, dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
[18]:
# The derivative at a zero initial guess is not defined

derivative(1, objective, solution, intermediate=False)
[18]:
tensor([nan, nan], dtype=torch.float64)
[19]:
# Set random initial guess

solution = 1.0E-3*torch.randn_like(error)
derivative(1, objective, solution, intermediate=False)
[19]:
tensor([-1.113e-01, -1.049e-01], dtype=torch.float64)
[20]:
# Correction loop (gradient descent)

# Note, in this case the learning rate is not as small as in the previous case
# Thus, the number of iterations is also reduced

ni = 256
lr = 5.0E-3
xs = []

for i in range(ni):
    solution -= lr*derivative(1, objective, solution, intermediate=False)
    xs.append(objective(solution))

k1 = solution.clone()
xs = torch.stack(xs)

# Evaluate objective for final solution

print(dk)
print(-k1)
print(objective(k1))

plt.figure(figsize=(16, 4))
plt.plot(xs.cpu().numpy(), color='blue')
plt.tight_layout()
plt.show()
tensor([1.000e-02, 1.000e-02], dtype=torch.float64)
tensor([-1.170e-02, -8.184e-03], dtype=torch.float64)
tensor(8.324e-19, dtype=torch.float64)
../_images/examples_ndmap_529_1.png
[21]:
# While the objective is minimized, the knob values do not correspond to the set errors
# Note, the objective values with '+' and '-' signs are the same here

print(objective(-error))
print(objective(+k1))
print(objective(-k1))
print()
tensor(0., dtype=torch.float64)
tensor(8.324e-19, dtype=torch.float64)
tensor(8.324e-19, dtype=torch.float64)

[22]:
# Now, if we apply the obtained negative solution to the observed system, the dQmin value will be reduced, but not as much as if the exact errors were recovered

# Note, a fraction of the fitted solution can be applied
# Also, a new correction step can be performed using the new observed value

print(dQmin(error - 0.00*k1))
print(dQmin(error - 0.25*k1))
print(dQmin(error - 0.50*k1))
print(dQmin(error - 1.00*k1))
print()
tensor(3.574e-02, dtype=torch.float64)
tensor(2.672e-02, dtype=torch.float64)
tensor(1.777e-02, dtype=torch.float64)
tensor(3.969e-04, dtype=torch.float64)

[23]:
# Perform next correction step

dQmin_observed = dQmin(error - 0.5*k1)
objective = lambda knobs: (dQmin(knobs) - dQmin_observed)**2
[24]:
# Set random initial guess

solution = 1.0E-3*torch.randn_like(error)
derivative(1, objective, solution, intermediate=False)
[24]:
tensor([-5.780e-02, -5.374e-02], dtype=torch.float64)
[25]:
# Correction loop

ni = 256
lr = 5.0E-3
xs = []

for i in range(ni):
    solution -= lr*derivative(1, objective, solution, intermediate=False)
    xs.append(objective(solution))

k2 = solution.clone()
xs = torch.stack(xs)

plt.figure(figsize=(16, 4))
plt.plot(xs.cpu().numpy(), color='blue')
plt.tight_layout()
plt.show()
../_images/examples_ndmap_534_0.png
[26]:
print(dQmin(error))
print(dQmin(error - 0.5*k1))
print(dQmin(error - 0.5*k1 - 0.5*k2))
tensor(3.574e-02, dtype=torch.float64)
tensor(1.777e-02, dtype=torch.float64)
tensor(8.875e-03, dtype=torch.float64)
[27]:
# Note, there seems to be no guarantee that the above procedure converges

Example-40: Coupling (amplitude ratio correction)

[1]:
# In this example an objective function constructed from the ratio of coupled and uncoupled amplitudes is used to minimize coupling
# Amplitudes can be computed from simulated TbT data at one or several locations
[2]:
# Import

import numpy
import torch

from ndmap.derivative import derivative
from ndmap.signature import chop
from ndmap.series import series
from ndmap.series import clean
from ndmap.evaluate import evaluate
from ndmap.propagate import identity
from ndmap.propagate import propagate
from ndmap.pfp import fixed_point
from ndmap.pfp import parametric_fixed_point

from twiss.wolski import twiss
from twiss.convert import wolski_to_cs
from twiss.matrix import symplectic_conjugate

torch.set_printoptions(precision=3, sci_mode=True, linewidth=128)
print(torch.cuda.is_available())

from matplotlib import pyplot as plt

import warnings
warnings.filterwarnings("ignore")
False
[3]:
# Set data type and device

dtype = torch.float64
device = torch.device('cpu')
[4]:
# Set elements

def drif(x, w, l):
    (qx, px, qy, py), (w, ), l = x, w, l
    return torch.stack([qx + l*px/(1 + w), px, qy + l*py/(1 + w), py])

def quad(x, w, kq, l, n=50):
    (qx, px, qy, py), (w, ), kq, l = x, w, kq, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx, py + 2.0*l*kq*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def sext(x, w, ks, l, n=10):
    (qx, px, qy, py), (w, ), ks, l = x, w, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 1.0*l*ks*(qx**2 - qy**2), py + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def bend(x, w, r, kq, ks, l, n=50):
    (qx, px, qy, py), (w, ), r, kq, ks, l = x, w, r, kq, ks, l/(2.0*n)
    for _ in range(n):
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
        px, py = px - 2.0*l*kq*qx - 1.0*l*ks*(qx**2 - qy**2) + 2.0*l/r**2*(w*r - qx), py + 2.0*l*kq*qy + 2.0*l*ks*qx*qy
        qx, qy = qx + l*px/(1 + w), qy + l*py/(1 + w)
    return torch.stack([qx, px, qy, py])

def roll(x, a):
    (qx, px, qy, py), cn, sn = x, a.cos(), a.sin()
    return torch.stack([qx*cn + qy*sn, px*cn + py*sn, qy*cn - qx*sn, py*cn - px*sn])

def kick(x, kn, ks):
    (qx, px, qy, py), kn, ks = x, kn, ks
    return torch.stack([qx, px - kn*qx + ks*qy, qy, py + ks*qx + kn*qy])
[5]:
# Set transport maps between observation points

def map_01_02(x, k):
    kf, kd = k
    x = kick(x, 0.0, kf/2.0)
    x = quad(x, [0.0], 0.21, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], -0.18, 0.50)
    x = kick(x, 0.0, kd/2.0)
    return x

def map_02_03(x, k):
    kf, kd = k
    x = kick(x, 0.0, kd/2.0)
    x = quad(x, [0.0], -0.18, 0.50)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = drif(x, [0.0], 3.0)
    x = drif(x, [0.0], 0.45)
    x = sext(x, [0.0], 0.0, 0.10)
    x = drif(x, [0.0], 0.45)
    x = quad(x, [0.0], 0.21, 0.50)
    x = kick(x, 0.0, kf/2.0)
    return x

transport = [
    map_01_02,
    map_02_03
]

# Define one-turn transport

def fodo(x, k):
    for mapping in transport:
        x = mapping(x, k)
    return x

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

print(fodo(x, k))
tensor([0., 0., 0., 0.], dtype=torch.float64)
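
# Sketch (reusing the twiss function imported above): the one-turn matrix at the origin
# and the corresponding fractional tunes

m = torch.func.jacfwd(lambda x: fodo(x, k))(x)
(nux, nuy), *_ = twiss(m)
print(nux, nuy)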
[6]:
# Compute fixed point

x = torch.tensor(4*[0.0], dtype=dtype, device=device)
k = torch.tensor(2*[0.0], dtype=dtype, device=device)

fp = fixed_point(16, fodo, x, k, power=1)
print(fp)
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute parametric fixed point

pfp = parametric_fixed_point((1, ), fp, [k], fodo)
chop(pfp)
pfp
[7]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[8]:
# Propagate parametric fixed point

out = propagate((4, 2), (0, 1), pfp, [k], fodo)
chop(out)
out
[8]:
[[tensor([0., 0., 0., 0.], dtype=torch.float64),
  tensor([[0., 0.],
          [0., 0.],
          [0., 0.],
          [0., 0.]], dtype=torch.float64)]]
[9]:
# Propagate parametric identity (surrogate model for linear dynamics)

jet = identity((1, 1), fp, parametric=pfp)
jet = propagate((4, 2), (1, 1), jet, [k], fodo)
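
# Consistency check (sketch): with the sextupoles off, the lattice is linear in x and the
# fixed point is at the origin, so the (1, 1) surrogate should match fodo up to
# higher-order knob terms

dx = torch.tensor([1.0E-3, 0.0, -1.0E-3, 0.0], dtype=dtype, device=device)
dknob = torch.tensor([1.0E-3, -1.0E-3], dtype=dtype, device=device)
print((fodo(fp + dx, dknob) - evaluate(jet, [dx, dknob])).norm())  # small residual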
[10]:
# dQmin (TEAPOT manual, appendix G, 1996)

def dQmin(k):

    # One-turn matrix at the fixed point for given skew knobs
    m = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)

    # Fractional tunes and corresponding phase advances
    (nux, nuy), *_ = twiss(m)
    mux, muy = 2.0*torch.pi*nux, 2.0*torch.pi*nuy

    # Off-diagonal (coupling) blocks of the one-turn matrix
    B = m[:2, 2:]
    C = m[2:, :2]

    # The determinant of C plus the symplectic conjugate of B measures the coupling strength
    (m11, m12), (m21, m22) = C + symplectic_conjugate(B)

    return 1.0/torch.pi*(m11*m22 - m12*m21).abs().sqrt()/(mux.sin() + muy.sin()).abs()
[11]:
# Set skew errors

dkf = +1.5E-3
dkd = -0.5E-3

dk = torch.tensor([dkf, dkd], dtype=dtype, device=device)

dQmin(dk)
[11]:
tensor(1.881e-03, dtype=torch.float64)
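
# Since the coupling blocks of the one-turn matrix are approximately linear in the skew
# strengths, dQmin is expected to scale roughly linearly with the error amplitude (sketch)

print(dQmin(0.5*dk))
print(dQmin(2.0*dk))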
[12]:
# Given qx TbT data, in the linear approximation qx(n) = (cxx cos(mux n) + sxx sin(mux n)) + (cxy cos(muy n) + sxy sin(muy n))
# The amplitudes axx^2 = cxx^2 + sxx^2 and axy^2 = cxy^2 + sxy^2 are used to compute the ratio axy/axx
# Similarly, ayx/ayy can be computed using qy
# These ratios are computed at each observation point (see the sketch below)
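
# A standalone illustration (synthetic signal, not the model above): projecting a
# two-frequency signal onto cos/sin at known tunes recovers the line amplitude ratio
# Tunes are chosen on exact DFT bins so that leakage vanishes; for generic tunes a
# window (as defined in the next cell) is used to suppress leakage

n = 512
nux, nuy = 128.0/n, 160.0/n
t = torch.linspace(0.0, n - 1.0, n, dtype=dtype, device=device)
qx = 1.0*(2.0*torch.pi*nux*t).cos() + 0.1*(2.0*torch.pi*nuy*t + 0.5).sin()

cxx, sxx = qx @ (2.0*torch.pi*nux*t).cos()/n, qx @ (2.0*torch.pi*nux*t).sin()/n
cxy, sxy = qx @ (2.0*torch.pi*nuy*t).cos()/n, qx @ (2.0*torch.pi*nuy*t).sin()/n

axx = (cxx**2 + sxx**2).sqrt()
axy = (cxy**2 + sxy**2).sqrt()

print(axy/axx)  # equals the true ratio 0.1 up to rounding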
[13]:
# Observation function

# Fit 'experiment' to 'model' ('experiment' is 'model' with errors)
# The goal is to fit the knobs so that the 'experiment' observation matches the 'model' observation (ratios are zero)

# Fit 'model' to 'experiment'
# The goal is to fit the knobs so that the 'model' observation matches the 'experiment' observation (final ratios are non-zero)

# Set initial condition for TbT
# The exact value is not important, since the underlying model is linear
# Also, this initial condition should be set relative to the (parametric) closed orbit

initial = torch.tensor([1.0, 0.0, 1.0, 0.0], dtype=dtype, device=device)

# Normalized window for computation of amplitudes
# Smooth bump that vanishes (with all its derivatives) at both ends and sums to one

def window(n, *, s=1.0, dtype=dtype, device=device):
    t = torch.linspace(0.0, (n - 1.0)/n, n, dtype=dtype, device=device)
    f = torch.exp(-1.0/((1.0 - t)**s*t**s))
    return f/torch.sum(f)

def fn(k, n, x, error):

    if error:
        k = k + dk

    w = window(n)
    t = torch.linspace(0.0, n - 1, n, dtype=dtype, device=device)

    matrix = derivative(1, lambda x: evaluate(jet, [x, k]), fp, intermediate=False)
    (nux, nuy), *_ = twiss(matrix)

    mux = 2.0*torch.pi*nux
    muy = 2.0*torch.pi*nuy

    # Track n turns, recording the state at each observation point
    xs = []
    for _ in range(n):
        for mapping in transport:
            x = mapping(x, k)
            xs.append(x)

    # Rearrange TbT data to (location, coordinate, turn)
    xs = torch.stack(xs).reshape(n, len(transport), -1).swapaxes(1, 0).swapaxes(1, -1)

    cxx, cxy, cyx, cyy = [], [], [], []
    sxx, sxy, syx, syy = [], [], [], []

    for x in xs:

        qx, _, qy, _ = x

        cxx.append(w*qx @ (mux*t).cos())
        cxy.append(w*qx @ (muy*t).cos())
        cyx.append(w*qy @ (mux*t).cos())
        cyy.append(w*qy @ (muy*t).cos())

        sxx.append(w*qx @ (mux*t).sin())
        sxy.append(w*qx @ (muy*t).sin())
        syx.append(w*qy @ (mux*t).sin())
        syy.append(w*qy @ (muy*t).sin())

    cxx = torch.stack(cxx)
    cxy = torch.stack(cxy)
    cyx = torch.stack(cyx)
    cyy = torch.stack(cyy)

    sxx = torch.stack(sxx)
    sxy = torch.stack(sxy)
    syx = torch.stack(syx)
    syy = torch.stack(syy)

    axx = (cxx**2 + sxx**2).sqrt()
    axy = (cxy**2 + sxy**2).sqrt()
    ayx = (cyx**2 + syx**2).sqrt()
    ayy = (cyy**2 + syy**2).sqrt()

    return torch.stack([axy/axx, ayx/ayy]).T.flatten()
[14]:
# The above function returns ratios (as a vector, so that the derivative is a matrix)
# Due to the transpose and flatten, the components are interleaved by observation point: rx_1, ry_1, rx_2, ry_2, ...
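
# Shape check (sketch, mirroring the derivative call used in the correction loops below):
# with two observation points there are four ratios and two knobs, so the Jacobian of fn
# with respect to the knobs is a 4 x 2 matrix

jacobian = derivative(1, fn, k, 64, initial, False, intermediate=False)
print(jacobian.shape)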
[15]:
# Without errors, the ratios should vanish
# The residual values are due to spectral leakage and decrease with the TbT data length

print(fn(k, 128, initial, False))
print(fn(k, 256, initial, False))
print(fn(k, 512, initial, False))
print()

print(fn(k, 128, initial, True))
print(fn(k, 256, initial, True))
print(fn(k, 512, initial, True))
print()
tensor([1.041e-04, 1.041e-04, 1.041e-04, 1.041e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([7.097e-09, 7.097e-09, 7.097e-09, 7.097e-09], dtype=torch.float64)

tensor([3.210e-02, 6.907e-03, 2.517e-02, 8.844e-03], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.200e-02, 7.007e-03, 2.508e-02, 8.942e-03], dtype=torch.float64)

[16]:
# Correction ('experiment' to 'model')
# Target system is 'model'
# Knobs are updated iteratively using the pseudo-inverse of the observation Jacobian

# Learning rate

lr = 0.75

# Target values

vf = fn(k, 256, initial, False)

# Initial solution

solution = torch.zeros_like(dk)

# Correction loop

for _ in range(16):

    vi, jacobian = derivative(1, fn, solution, 256, initial, True, intermediate=True)
    dv = vf - vi
    solution += lr * torch.linalg.pinv(jacobian) @ dv

    print(solution)
    print(vf)
    print(vi)
    print(dv.norm())
    print()
tensor([-1.094e-03, 3.652e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor(4.220e-02, dtype=torch.float64)

tensor([-1.396e-03, 4.657e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([8.467e-03, 1.954e-03, 6.640e-03, 2.491e-03], dtype=torch.float64)
tensor(1.121e-02, dtype=torch.float64)

tensor([-1.474e-03, 4.915e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([2.153e-03, 5.078e-04, 1.689e-03, 6.460e-04], dtype=torch.float64)
tensor(2.850e-03, dtype=torch.float64)

tensor([-1.493e-03, 4.980e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([5.430e-04, 1.326e-04, 4.266e-04, 1.674e-04], dtype=torch.float64)
tensor(7.156e-04, dtype=torch.float64)

tensor([-1.498e-03, 4.996e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([1.385e-04, 3.790e-05, 1.093e-04, 4.665e-05], dtype=torch.float64)
tensor(1.791e-04, dtype=torch.float64)

tensor([-1.499e-03, 5.000e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.719e-05, 1.418e-05, 2.991e-05, 1.639e-05], dtype=torch.float64)
tensor(4.491e-05, dtype=torch.float64)

tensor([-1.500e-03, 5.001e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([1.188e-05, 8.259e-06, 1.005e-05, 8.816e-06], dtype=torch.float64)
tensor(1.167e-05, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([5.555e-06, 6.783e-06, 5.081e-06, 6.915e-06], dtype=torch.float64)
tensor(4.302e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.961e-06, 6.415e-06, 3.839e-06, 6.430e-06], dtype=torch.float64)
tensor(3.323e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.551e-06, 6.324e-06, 3.529e-06, 6.305e-06], dtype=torch.float64)
tensor(3.249e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.445e-06, 6.301e-06, 3.452e-06, 6.272e-06], dtype=torch.float64)
tensor(3.245e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.418e-06, 6.295e-06, 3.432e-06, 6.263e-06], dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.411e-06, 6.293e-06, 3.427e-06, 6.261e-06], dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.409e-06, 6.293e-06, 3.426e-06, 6.261e-06], dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.408e-06, 6.293e-06, 3.426e-06, 6.261e-06], dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

tensor([-1.500e-03, 5.002e-04], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor([3.408e-06, 6.293e-06, 3.426e-06, 6.260e-06], dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

[17]:
# Test final solution

print((vf - fn(0.0*solution, 256, initial, True)).norm())
print((vf - fn(1.0*solution, 256, initial, True)).norm())
print()

print(dQmin(dk))
print(dQmin(dk + solution))
print()
tensor(4.220e-02, dtype=torch.float64)
tensor(3.244e-06, dtype=torch.float64)

tensor(1.881e-03, dtype=torch.float64)
tensor(5.079e-07, dtype=torch.float64)

[18]:
# Correction ('model' to 'experiment')
# Target system is 'experiment'

# Learning rate

lr = 0.75

# Target values

vf = fn(k, 256, initial, True)

# Initial solution

solution = torch.zeros_like(dk)

# Correction loop

for _ in range(16):

    vi, jacobian = derivative(1, fn, solution, 256, initial, False, intermediate=True)
    dv = vf - vi
    solution += lr * torch.linalg.pinv(jacobian) @ dv

    print(solution)
    print(vf)
    print(vi)
    print(dv.norm())
    print()
tensor([-1.019e-03, 3.612e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([4.081e-06, 4.081e-06, 4.081e-06, 4.081e-06], dtype=torch.float64)
tensor(4.220e-02, dtype=torch.float64)

tensor([-1.443e-03, 4.881e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([2.008e-02, 4.856e-03, 1.556e-02, 6.267e-03], dtype=torch.float64)
tensor(1.563e-02, dtype=torch.float64)

tensor([-1.553e-03, 5.221e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([2.867e-02, 7.133e-03, 2.241e-02, 9.127e-03], dtype=torch.float64)
tensor(4.267e-03, dtype=torch.float64)

tensor([-1.582e-03, 5.310e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.087e-02, 7.729e-03, 2.416e-02, 9.878e-03], dtype=torch.float64)
tensor(1.869e-03, dtype=torch.float64)

tensor([-1.589e-03, 5.332e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.142e-02, 7.878e-03, 2.459e-02, 1.007e-02], dtype=torch.float64)
tensor(1.603e-03, dtype=torch.float64)

tensor([-1.590e-03, 5.338e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.156e-02, 7.916e-03, 2.470e-02, 1.011e-02], dtype=torch.float64)
tensor(1.585e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.159e-02, 7.925e-03, 2.473e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.927e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor([-1.591e-03, 5.340e-04], dtype=torch.float64)
tensor([3.200e-02, 7.011e-03, 2.507e-02, 8.946e-03], dtype=torch.float64)
tensor([3.160e-02, 7.928e-03, 2.474e-02, 1.013e-02], dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

[19]:
# Test final solution

print((vf - fn(0.0*solution, 256, initial, False)).norm())
print((vf - fn(1.0*solution, 256, initial, False)).norm())
print()

print(dQmin(dk))
print(dQmin(solution))
print(dQmin(dk + solution))
print(dQmin(dk - solution))
print()
tensor(4.220e-02, dtype=torch.float64)
tensor(1.584e-03, dtype=torch.float64)

tensor(1.881e-03, dtype=torch.float64)
tensor(1.989e-03, dtype=torch.float64)
tensor(1.078e-04, dtype=torch.float64)
tensor(3.870e-03, dtype=torch.float64)

[20]:
# Plot

factors = torch.linspace(-1.5, 1.5, 128, dtype=dtype, device=device)
plt.figure(figsize=(16, 4))
plt.scatter(factors, torch.stack([dQmin(dk + factor*solution) for factor in factors]), color='blue')
plt.tight_layout()
plt.show()
[Figure: dQmin(dk + factor*solution) as a function of factor]