Example-01: Coupled twiss parameters

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element. k is always unpacked in full, so a knob vector
# without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)

def m23(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m56(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m67(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Compute one-turn transport matrix

# Phase-space origin (qx, px, qy, py) and knobs (kn1, ks1, kn2, ks2);
# the nonzero ks1/ks2 add skew terms that couple the x and y planes.
x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, -0.01, -0.25, 0.01], dtype=torch.float64)

# Linear one-turn matrix = Jacobian of the cell map with respect to x.
m = torch.func.jacrev(fodo)(x, k)
print(m)

# Stability check of the one-turn matrix provided by the twiss package.
from twiss import is_stable
print(is_stable(m))
tensor([[ 4.946e-01,  2.032e+01, -8.454e-03,  3.297e-01],
        [-3.725e-02,  4.946e-01,  4.386e-03,  8.052e-03],
        [ 8.052e-03,  3.297e-01, -9.236e-02,  3.539e+00],
        [ 4.386e-03, -8.454e-03, -2.806e-01, -9.236e-02]], dtype=torch.float64)
True
[6]:
# Compute coupled twiss parameters

from twiss import twiss
# t: tunes, one per mode (2*pi*t are the rotation angles checked below);
# n: normalization matrix; w: per-mode twiss (Wolski) matrices, here (2, 4, 4).
t, n, w = twiss(m)

print(t)
print(n)
print(w)
tensor([1.677e-01, 2.647e-01], dtype=torch.float64)
tensor([[ 4.835e+00,  0.000e+00,  2.714e-02, -2.788e-16],
        [-3.407e-17,  2.069e-01,  6.135e-18, -7.283e-03],
        [ 6.634e-02,  3.851e-17,  1.884e+00,  0.000e+00],
        [-1.459e-17, -2.980e-03,  1.841e-16,  5.308e-01]], dtype=torch.float64)
tensor([[[ 2.338e+01, -1.647e-16,  3.208e-01, -7.055e-17],
         [-1.647e-16,  4.279e-02,  5.706e-18, -6.165e-04],
         [ 3.208e-01,  5.706e-18,  4.401e-03, -1.083e-18],
         [-7.055e-17, -6.165e-04, -1.083e-18,  8.880e-06]],

        [[ 7.367e-04,  2.197e-18,  5.114e-02, -1.430e-16],
         [ 2.197e-18,  5.305e-05,  1.156e-17, -3.866e-03],
         [ 5.114e-02,  1.156e-17,  3.550e+00,  3.469e-16],
         [-1.430e-16, -3.866e-03,  3.469e-16,  2.818e-01]]],
       dtype=torch.float64)
[7]:
# Check normalization matrix
# n^-1 m n should equal the block-diagonal rotation by 2*pi*t (both printed).

from math import pi
from twiss import rotation

print(rotation(*2*pi*t))
print(n.inverse() @ m @ n)
tensor([[4.945e-01, 8.692e-01, 0.000e+00, 0.000e+00],
        [-8.692e-01, 4.945e-01, 0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00, -9.224e-02, 9.957e-01],
        [0.000e+00, 0.000e+00, -9.957e-01, -9.224e-02]], dtype=torch.float64)
tensor([[ 4.945e-01,  8.692e-01, -4.585e-18,  5.551e-17],
        [-8.692e-01,  4.945e-01,  8.179e-17, -3.221e-17],
        [ 9.636e-18,  9.741e-18, -9.224e-02,  9.957e-01],
        [ 1.133e-17, -2.888e-17, -9.957e-01, -9.224e-02]], dtype=torch.float64)
[8]:
# Check twiss matrices
# Rebuild the one-turn matrix from the per-mode twiss matrices:
#   u = sum_i [ (w_i S) sin(2 pi t_i) - (w_i S)^2 cos(2 pi t_i) ]
# with S the symplectic identity; u should reproduce m.

from twiss.matrix import symplectic_identity

# len(w) is the number of modes (2 here).
s = symplectic_identity(len(w), dtype=torch.float64)
u = torch.zeros_like(m)
for ti, wi in zip(t, w):
    u += (wi @ s) * (2*pi*ti).sin() - (wi @ s) @ (wi @ s) * (2*pi*ti).cos()

print(m)
print(u)
tensor([[ 4.946e-01,  2.032e+01, -8.454e-03,  3.297e-01],
        [-3.725e-02,  4.946e-01,  4.386e-03,  8.052e-03],
        [ 8.052e-03,  3.297e-01, -9.236e-02,  3.539e+00],
        [ 4.386e-03, -8.454e-03, -2.806e-01, -9.236e-02]], dtype=torch.float64)
tensor([[ 4.946e-01,  2.032e+01, -8.454e-03,  3.297e-01],
        [-3.725e-02,  4.946e-01,  4.386e-03,  8.052e-03],
        [ 8.052e-03,  3.297e-01, -9.236e-02,  3.539e+00],
        [ 4.386e-03, -8.454e-03, -2.806e-01, -9.236e-02]], dtype=torch.float64)
[9]:
# Compute twiss at each location
# Propagate the periodic twiss matrices element by element. A separate loop
# variable keeps `w` from the twiss computation intact, so re-running this
# cell does not depend on its own previous output.

from twiss import propagate

w_i = w
out = []
for mapping in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
    w_i = propagate(w_i, torch.func.jacrev(mapping)(x, k))
    out.append(w_i)
out = torch.stack(out)

print(out.shape)
torch.Size([9, 2, 4, 4])
[10]:
# Convert to CS (Courant-Snyder) and LB (Lebedev-Bogacz) representation
# The converters act on one (2, 4, 4) stack of twiss matrices; vmap maps them
# over all 9 locations, and .T turns each result column into a length-9 series.

from twiss import  wolski_to_cs
from twiss import  wolski_to_lb

ax, bx, ay, by = torch.vmap(wolski_to_cs)(out).T
a1x, b1x, a2x, b2x, a1y, b1y, a2y, b2y, *_ = torch.vmap(wolski_to_lb)(out).T
[11]:
# Plot in-plane twiss parameters
# CS betas (circles) against the LB in-plane mode betas (crosses). Each curve
# gets its own colour and legend entry so the x and y planes can be told apart.

from matplotlib import pyplot as plt

locations = range(len(bx))

plt.figure(figsize=(20, 5))
plt.plot(locations, bx, marker='o', color='blue', label='bx (CS)')
plt.plot(locations, b1x, marker='x', color='red', label='b1x (LB)')
plt.plot(locations, by, marker='o', color='green', label='by (CS)')
plt.plot(locations, b2y, marker='x', color='orange', label='b2y (LB)')
plt.xlabel('location index')
plt.ylabel('beta')
plt.legend()
plt.show()
../_images/examples_twiss_11_0.png
[12]:
# Plot cross-plane (coupled) twiss parameters
# b2x / b1y are the LB betas of each mode in the plane it is not paired with
# in the in-plane plot (b1x with bx, b2y with by).

from matplotlib import pyplot as plt

locations = range(len(b2x))

plt.figure(figsize=(20, 5))
plt.plot(locations, b2x, marker='x', color='red', label='b2x (LB)')
plt.plot(locations, b1y, marker='x', color='blue', label='b1y (LB)')
plt.xlabel('location index')
plt.ylabel('beta')
plt.legend()
plt.show()
../_images/examples_twiss_12_0.png
[13]:
# Compute phase advances
# Advance the normalization matrix element by element. A separate loop
# variable keeps `n` from the twiss computation intact for re-runs.

from twiss import advance

n_i = n
out = []
for mapping in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
    mu, n_i = advance(n_i, torch.func.jacrev(mapping)(x, k))
    out.append(mu)
out = torch.stack(out)

# The per-element advances (radians) summed over the cell, divided by 2*pi,
# should reproduce the tunes t.
print(t)
print(1/(2*pi)*out.T.sum(-1))
tensor([1.677e-01, 2.647e-01], dtype=torch.float64)
tensor([1.677e-01, 2.647e-01], dtype=torch.float64)
[14]:
# Plot accumulated phase advance
# Running sum of the per-element advances for each mode, in radians.

mux, muy = out.T.cumsum(-1)

plt.figure(figsize=(20, 5))
plt.plot(range(len(mux)), mux, marker='x', color='red', label='mux')
plt.plot(range(len(muy)), muy, marker='x', color='blue', label='muy')
plt.xlabel('location index')
plt.ylabel('accumulated phase advance [rad]')
plt.legend()
plt.show()
../_images/examples_twiss_14_0.png

Example-02: Tune derivatives

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element. k is always unpacked in full, so a knob vector
# without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)

def m23(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m56(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m67(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Compute tune derivatives

from twiss import twiss

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, 0.0, -0.25, 0.0], dtype=torch.float64)

def fn(k):
    """Return the tunes of the cell for the knob vector ``k``."""
    tunes, *_ = twiss(torch.func.jacrev(fodo)(x, k))
    return tunes

# First derivatives (2 x 4) and second derivatives (2 x 4 x 4) of the tunes.
d1 = torch.func.jacrev(fn)(k)
d2 = torch.func.jacrev(torch.func.jacrev(fn))(k)

print(d1)
print(d2)
tensor([[1.832e+00, 0.000e+00, 5.027e-01, 0.000e+00],
        [-2.894e-01, 0.000e+00, -1.430e+00, 0.000e+00]], dtype=torch.float64)
tensor([[[-1.265e+01, 0.000e+00, -6.634e+00, 0.000e+00],
         [0.000e+00, -1.186e+01, 0.000e+00, -1.682e+01],
         [-6.634e+00, 0.000e+00, -1.069e+00, 0.000e+00],
         [0.000e+00, -1.682e+01, 0.000e+00, -1.547e+01]],

        [[-4.998e-02, 0.000e+00, -2.576e+00, 0.000e+00],
         [0.000e+00, 9.716e+00, 0.000e+00, 8.931e+00],
         [-2.576e+00, 0.000e+00, 7.011e-01, 0.000e+00],
         [0.000e+00, 8.931e+00, 0.000e+00, 1.286e+01]]], dtype=torch.float64)
[6]:
# Test
# Compare the first- and second-order Taylor predictions of the tunes at
# k + dk with a direct twiss evaluation at the perturbed knobs.

dk = torch.tensor([0.005, -0.005, -0.005, 0.005], dtype=torch.float64)

print(fn(k))
print(fn(k) + d1 @ dk)
print(fn(k) + d1 @ dk + 1/2 * d2 @ dk @ dk)
print()

t, *_ = twiss(torch.func.jacrev(fodo)(x, k + dk))
print(t)
print()
tensor([1.674e-01, 2.645e-01], dtype=torch.float64)
tensor([1.740e-01, 2.702e-01], dtype=torch.float64)
tensor([1.741e-01, 2.703e-01], dtype=torch.float64)

tensor([1.741e-01, 2.703e-01], dtype=torch.float64)

Example-03: Twiss derivatives

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element. k is always unpacked in full, so a knob vector
# without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)

def m23(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m56(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m67(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Set task function

from twiss import twiss
from twiss import propagate
from twiss import wolski_to_cs

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, 0.0, -0.25, 0.0], dtype=torch.float64)

def fn(k):
    """Return the CS betas (bx, by) at each of the nine cell locations.

    The periodic twiss matrices of the cell are propagated through every
    element map; the result has shape (9, 2) with columns (bx, by).
    """
    *_, wolski = twiss(torch.func.jacrev(fodo)(x, k))
    elements = (m11, m12, m23, m34, m45, m56, m67, m78, m89)
    history = []
    for element in elements:
        transport = torch.func.jacrev(element)(x, k)
        wolski = propagate(wolski, transport)
        history.append(wolski)
    _, beta_x, _, beta_y = torch.vmap(wolski_to_cs)(torch.stack(history)).T
    return torch.stack([beta_x, beta_y]).T
[6]:
# Compute and plot twiss
# CS betas at each location for the nominal knobs.

bx, by = fn(k).T

from matplotlib import pyplot as plt

plt.figure(figsize=(20, 5))
plt.plot(range(len(bx)), bx, marker='o', color='red', label='bx')
plt.plot(range(len(by)), by, marker='o', color='blue', label='by')
plt.xlabel('location index')
plt.ylabel('beta')
plt.legend()
plt.show()
../_images/examples_twiss_28_0.png
[7]:
# Compute twiss derivatives
# d1[i, j, l] = derivative of beta j at location i with respect to knob l.

d1 = torch.func.jacrev(fn)(k)
print(d1.shape)
torch.Size([9, 2, 4])
[8]:
# Test
# Print the betas at k, the first-order prediction at k + dk, and the exact
# betas at k + dk (each transposed to rows bx, by).

dk = torch.tensor([0.005, -0.005, -0.005, 0.005], dtype=torch.float64)

print(fn(k).T)
print((fn(k) + d1 @ dk).T)
print()

print(fn(k + dk).T)
print()
tensor([[2.341e+01, 2.226e+01, 1.274e+01, 6.608e+00, 6.173e+00, 6.608e+00, 1.274e+01, 2.226e+01, 2.341e+01],
        [3.554e+00, 3.806e+00, 8.444e+00, 1.723e+01, 1.834e+01, 1.723e+01, 8.444e+00, 3.806e+00, 3.554e+00]],
       dtype=torch.float64)
tensor([[2.303e+01, 2.188e+01, 1.232e+01, 6.251e+00, 5.827e+00, 6.251e+00, 1.232e+01, 2.188e+01, 2.303e+01],
        [3.395e+00, 3.647e+00, 8.370e+00, 1.741e+01, 1.856e+01, 1.741e+01, 8.370e+00, 3.647e+00, 3.395e+00]],
       dtype=torch.float64)

tensor([[2.305e+01, 2.190e+01, 1.234e+01, 6.262e+00, 5.838e+00, 6.262e+00, 1.234e+01, 2.190e+01, 2.305e+01],
        [3.396e+00, 3.648e+00, 8.373e+00, 1.742e+01, 1.857e+01, 1.742e+01, 8.373e+00, 3.648e+00, 3.396e+00]],
       dtype=torch.float64)

Example-04: Phase advance derivatives

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element. k is always unpacked in full, so a knob vector
# without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)

def m23(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m56(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m67(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Set task function

from twiss import twiss
from twiss import advance

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, 0.0, -0.25, 0.0], dtype=torch.float64)

def fn(k):
    """Return per-element phase advances for elements m12 ... m89.

    The normalization matrix from the periodic twiss solution is advanced
    through each element; the result has shape (8, 2), one row per element
    and one column per mode.
    """
    _, normal, _ = twiss(torch.func.jacrev(fodo)(x, k))
    advances = []
    for element in (m12, m23, m34, m45, m56, m67, m78, m89):
        mu, normal = advance(normal, torch.func.jacrev(element)(x, k))
        advances.append(mu)
    return torch.stack(advances)
[6]:
# Compute and plot phase advance
# Phase advance of each element (radians) for the nominal knobs.

mux, muy = fn(k).T

from matplotlib import pyplot as plt

plt.figure(figsize=(20, 5))
plt.plot(range(len(mux)), mux, marker='o', color='red', label='mux')
plt.plot(range(len(muy)), muy, marker='o', color='blue', label='muy')
plt.xlabel('element index')
plt.ylabel('phase advance [rad]')
plt.legend()
plt.show()
../_images/examples_twiss_37_0.png
[7]:
# Compute phase advance derivatives
# d1[i, j, l] = derivative of the phase advance of mode j in element i
# with respect to knob l.

d1 = torch.func.jacrev(fn)(k)
print(d1.shape)
torch.Size([8, 2, 4])
[8]:
# Test
# Print the advances at k, the first-order prediction at k + dk, and the
# exact advances at k + dk (each transposed to one row per mode).

dk = torch.tensor([0.005, -0.005, -0.005, 0.005], dtype=torch.float64)

print(fn(k).T)
print((fn(k) + d1 @ dk).T)
print()

print(fn(k + dk).T)
print()
tensor([[2.172e-02, 1.490e-01, 2.759e-01, 7.918e-02, 7.918e-02, 2.759e-01, 1.490e-01, 2.172e-02],
        [1.375e-01, 4.567e-01, 2.088e-01, 2.784e-02, 2.784e-02, 2.088e-01, 4.567e-01, 1.375e-01]],
       dtype=torch.float64)
tensor([[2.208e-02, 1.528e-01, 2.883e-01, 8.357e-02, 8.357e-02, 2.883e-01, 1.528e-01, 2.208e-02],
        [1.435e-01, 4.691e-01, 2.086e-01, 2.752e-02, 2.752e-02, 2.086e-01, 4.691e-01, 1.435e-01]],
       dtype=torch.float64)

tensor([[2.206e-02, 1.527e-01, 2.884e-01, 8.368e-02, 8.368e-02, 2.884e-01, 1.527e-01, 2.206e-02],
        [1.438e-01, 4.694e-01, 2.085e-01, 2.751e-02, 2.751e-02, 2.085e-01, 4.694e-01, 1.438e-01]],
       dtype=torch.float64)

Example-05: Tune uncertainty from systematic errors

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element. k is always unpacked in full, so a knob vector
# without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)

def m23(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m56(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 0.5)

def m67(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _, _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Compute tune derivatives

from twiss import twiss

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, 0.0, -0.25, 0.0], dtype=torch.float64)

def fn(k):
    """Return the tunes of the cell for the knob vector ``k``."""
    tunes, *_ = twiss(torch.func.jacrev(fodo)(x, k))
    return tunes

# Jacobian of the tunes with respect to the four knobs, reused for error propagation.
grad = torch.func.jacrev(fn)(k)
[6]:
# Compute tune uncertainty for given knob uncertainties
# sk holds the standard deviation of each knob; errors are treated as independent.

sk = torch.tensor([0.01, 0.005, 0.01, 0.005], dtype=torch.float64)

# Linear error propagation: sigma_t^2 = sum_j (dt/dk_j)^2 * sigma_k_j^2
st = (grad**2 @ sk**2).sqrt()
print(st)

# Monte-Carlo cross-check; a seeded generator makes the estimate reproducible
# across Restart & Run All.
samples = 1024
generator = torch.Generator().manual_seed(0)
noise = torch.randn((samples, 4), dtype=torch.float64, generator=generator)
st = torch.func.vmap(fn)(k + sk * noise).T.std(-1)
print(st)
tensor([1.900e-02, 1.459e-02], dtype=torch.float64)
tensor([1.940e-02, 1.358e-02], dtype=torch.float64)

Example-06: Matching (composable)

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, kn2), the
# normal gradients of the two quadrupole families (skew strength fixed at 0.0).
# k is always unpacked in full, so a knob vector without exactly two entries raises.

def m11(x, k):
    _, _ = k
    return x

def m12(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)

def m23(x, k):
    _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m56(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m67(x, k):
    _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Set objective function (bx & by at focusing quadrupole center)

from twiss import twiss
from twiss import wolski_to_cs

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([+0.20, -0.20], dtype=torch.float64)

BX, BY = torch.tensor([25.0, 5.0], dtype=torch.float64)

bag = []

def objective(k, flag=False):
    """Squared distance of the periodic betas (bx, by) at the cell start from (BX, BY).

    When ``flag`` is true, the current (bx, by) pair is also appended to the
    module-level ``bag`` list so the optimization history can be plotted.
    """
    one_turn = torch.func.jacrev(fodo)(x, k)
    *_, wolski = twiss(one_turn)
    _, beta_x, _, beta_y = wolski_to_cs(wolski)
    if flag:
        bag.append(torch.stack([beta_x, beta_y]))
    return (beta_x - BX)**2 + (beta_y - BY)**2
[6]:
# Set optimizer

def adam(objective, knobs, count=1, lr=0.005, betas=(0.900, 0.999), epsilon=1.0E-9):
    b1, b2 = betas
    history_knobs = []
    history_value = []
    m1 = torch.zeros_like(knobs)
    m2 = torch.zeros_like(knobs)
    for i in range(count):
        grad = torch.func.jacrev(objective)(knobs)
        m1 = b1 * m1 + (1.0 - b1) * grad
        m2 = b2 * m2 + (1.0 - b2) * grad ** 2
        f1 = 1/(1 - b1 ** (i + 1))
        f2 = 1/(1 - b2 ** (i + 1))
        knobs = knobs -  lr * m1 / f1 / (torch.sqrt(m2 / f2) + epsilon)
        value = objective(knobs, flag=True)
        history_knobs.append(knobs)
        history_value.append(value)
    return tuple(map(torch.stack, [history_knobs, history_value]))
[7]:
# Optimize
# 128 Adam steps; each step also appends the current (bx, by) to `bag`
# through objective(..., flag=True).

knobs, value = adam(objective, k, count=128, lr=0.001)
[8]:
# Plot beta values
# Deviation of bx and by from their targets after each optimizer step. The
# x axis follows len(bag) rather than a hardcoded 128, so the plot still
# works if `bag` holds more entries (e.g. after re-running the optimizer).

from matplotlib import pyplot as plt

bxs, bys = torch.stack(bag).T
steps = range(len(bxs))

plt.figure(figsize=(20, 5))
plt.plot(steps, (bxs - BX).cpu().numpy(), color='blue', marker='x', label='bx - BX')
plt.plot(steps, (bys - BY).cpu().numpy(), color='red', marker='x', label='by - BY')
plt.xlabel('iteration')
plt.ylabel('beta deviation from target')
plt.legend()
plt.show()
../_images/examples_twiss_55_0.png

Example-07: Matching (autograd)

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, kn2), the
# normal gradients of the two quadrupole families (skew strength fixed at 0.0).
# k is always unpacked in full, so a knob vector without exactly two entries raises.

def m11(x, k):
    _, _ = k
    return x

def m12(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)

def m23(x, k):
    _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m56(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m67(x, k):
    _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Set objective function (bx & by at focusing quadrupole center)

from twiss import twiss
from twiss import wolski_to_cs

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([+0.20, -0.20], dtype=torch.float64)

BX, BY = torch.tensor([25.0, 5.0], dtype=torch.float64)

def objective(k):
    """Squared distance of the periodic betas (bx, by) at the cell start from (BX, BY)."""
    *_, wolski = twiss(torch.func.jacrev(fodo)(x, k))
    _, beta_x, _, beta_y = wolski_to_cs(wolski)
    return (beta_x - BX)**2 + (beta_y - BY)**2
[6]:
# Define model

class Model(torch.nn.Module):
    """Expose the knobs of an objective as a single trainable parameter."""

    def __init__(self, objective, knobs):
        super().__init__()
        self.objective = objective
        # Clone so that optimizing the parameter never touches the caller's tensor.
        self.knobs = torch.nn.Parameter(knobs.clone())

    def forward(self):
        """Evaluate the objective at the current knob values."""
        current = self.knobs
        return self.objective(current)

# The trainable knobs start from a copy of the initial knobs k.
model = Model(objective, k)
[7]:
# Set optimizer
# torch's built-in Adam over the model's single knob parameter.

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
[8]:
# Train
# Note, here forward method is a loss function
# (model() evaluates the objective, which already returns the scalar loss).

model.train()

for _ in range(256):
    loss = model()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
[9]:
# Check result
# Recompute the betas from the one-turn matrix at the fitted knobs. A new
# name keeps the initial knobs `k` intact, so re-running the "Define model"
# cell still starts from the original guess rather than the fitted values.

k_fit = model.knobs.detach()
m = torch.func.jacrev(fodo)(x, k_fit)
*_, w = twiss(m)
ax, bx, ay, by = wolski_to_cs(w)
print(torch.stack([bx, by]))
tensor([2.500e+01, 5.000e+00], dtype=torch.float64)

Example-08: Matching (autograd & taylor model)

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, kn2), the
# normal gradients of the two quadrupole families (skew strength fixed at 0.0).
# k is always unpacked in full, so a knob vector without exactly two entries raises.

def m11(x, k):
    _, _ = k
    return x

def m12(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)

def m23(x, k):
    _, _ = k
    return drif(x, 2.5)

def m34(x, k):
    _, _ = k
    return drif(x, 2.5)

def m45(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m56(x, k):
    _, kn2 = k
    return quad(x, kn2, 0.0, 0.5)

def m67(x, k):
    _, _ = k
    return drif(x, 2.5)

def m78(x, k):
    _, _ = k
    return drif(x, 2.5)

def m89(x, k):
    kn1, _ = k
    return quad(x, kn1, 0.0, 0.5)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, ..., m89 in order and returns the
    final coordinates.
    """
    for element in (m11, m12, m23, m34, m45, m56, m67, m78, m89):
        x = element(x, k)
    return x
[5]:
# Compute parametric transport matrix
# t is an ndmap Taylor model of the one-turn matrix as a function of the knobs,
# expanded around k (the first argument, 4, is presumably the expansion order),
# built with forward-mode Jacobians.

from ndmap.derivative import derivative
from ndmap.evaluate import evaluate

x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([+0.20, -0.20], dtype=torch.float64)
t = derivative(4, lambda k: torch.func.jacrev(fodo)(x, k), k, jacobian=torch.func.jacfwd)
[6]:
# Set objective function (bx & by at focusing quadrupole center)

from twiss import twiss
from twiss import wolski_to_cs

dk = torch.tensor([0.0, 0.0], dtype=torch.float64)

BX, BY = torch.tensor([25.0, 5.0], dtype=torch.float64)

def objective(k):
    """Return the periodic betas (bx, by) at the cell start from the Taylor model ``t``.

    The argument is a knob deviation from the expansion point: the model
    starts from dk = 0, and the result check later evaluates the exact
    matrix at k + dk.
    """
    one_turn = evaluate(t, [k])
    *_, wolski = twiss(one_turn)
    _, beta_x, _, beta_y = wolski_to_cs(wolski)
    return torch.stack([beta_x, beta_y])
[7]:
# Define model

class Model(torch.nn.Module):
    """Expose the knobs of an objective as a single trainable parameter."""

    def __init__(self, objective, knobs):
        super().__init__()
        self.objective = objective
        # Clone so that optimizing the parameter never touches the caller's tensor.
        self.knobs = torch.nn.Parameter(knobs.clone())

    def forward(self):
        """Evaluate the objective at the current knob values."""
        current = self.knobs
        return self.objective(current)

# Start from a copy of dk (all zeros), i.e. at the Taylor expansion point k.
model = Model(objective, dk)
[8]:
# Set optimizer
# torch's built-in Adam over the single knob-deviation parameter.

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
[9]:
# Set loss function
# Mean squared error between the model betas and the targets (used in training).

mse = torch.nn.MSELoss()
[10]:
# Train
# Fit the knob deviation so the Taylor-model betas reach (BX, BY); the loss
# after every step is recorded in `data`.

model.train()

target = torch.stack([BX, BY])
data = []

for _ in range(256):
    loss = mse(model(), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    data.append(loss.item())
[11]:
# Plot loss
# Training loss per step; the x axis follows the recorded history length.

from matplotlib import pyplot as plt

plt.figure(figsize=(20, 5))
plt.plot(range(len(data)), data, color='black', marker='x')
plt.xlabel('iteration')
plt.ylabel('MSE loss')
plt.show()
../_images/examples_twiss_77_0.png
[12]:
# Check result
# Evaluate the exact (not Taylor) one-turn matrix at the fitted knobs
# k + dk_fit. A new name keeps the initial `dk` (zeros) intact, so re-running
# the "Define model" cell still starts from the expansion point.

dk_fit = model.knobs.detach()
m = torch.func.jacrev(fodo)(x, k + dk_fit)
*_, w = twiss(m)
ax, bx, ay, by = wolski_to_cs(w)
print(torch.stack([bx, by]))
tensor([2.500e+01, 5.000e+00], dtype=torch.float64)

Example-09: Matched distribution

[1]:
import torch
torch.set_printoptions(precision=3, sci_mode=True)
[2]:
# Set elements

def drif(x, l):
    """Drift of length ``l``.

    Unpacks ``x`` as ``(qx, px, qy, py)``, shifts each position by ``l``
    times its momentum, leaves the momenta unchanged and returns the four
    coordinates restacked in the same order.
    """
    qx, px, qy, py = x
    qx_new = qx + l*px
    qy_new = qy + l*py
    return torch.stack([qx_new, px, qy_new, py])

def quad(x, kn, ks, l, n=25):
    """Thick quadrupole of length ``l`` with normal (``kn``) and skew (``ks``) strength.

    Integrated with ``n`` drift-kick-drift steps: each step drifts by
    ``l/(2n)``, applies a momentum kick over ``l/n``, then drifts by
    ``l/(2n)`` again. ``ks`` couples the planes (``px`` is kicked by ``qy``
    and ``py`` by ``qx``). Returns ``(qx, px, qy, py)`` stacked.
    """
    qx, px, qy, py = x
    half = l/(2.0*n)
    full = 2.0*half
    for _ in range(n):
        qx = qx + half*px
        qy = qy + half*py
        px, py = px - full*(kn*qx - ks*qy), py + full*(ks*qx + kn*qy)
        qx = qx + half*px
        qy = qy + half*py
    return torch.stack([qx, px, qy, py])
[3]:
# Set transport maps

# Each map takes the phase-space vector x and knobs k = (kn1, ks1, kn2, ks2)
# and applies one element of the drift-quad-drift-quad cell. k is always
# unpacked in full, so a knob vector without exactly four entries raises.

def m11(x, k):
    _, _, _, _ = k
    return x

def m12(x, k):
    _, _, _, _ = k
    return drif(x, 5.0)

def m23(x, k):
    _, _, kn2, ks2 = k
    return quad(x, kn2, ks2, 1.0)

def m34(x, k):
    _, _, _, _ = k
    return drif(x, 5.0)

def m45(x, k):
    kn1, ks1, _, _ = k
    return quad(x, kn1, ks1, 1.0)
[4]:
# Set fodo cell

def fodo(x, k):
    """Track the phase-space vector ``x`` through one cell for knobs ``k``.

    Applies the element maps m11, m12, m23, m34, m45 in order and returns
    the final coordinates.
    """
    for element in (m11, m12, m23, m34, m45):
        x = element(x, k)
    return x
[5]:
# Compute one-turn transport matrix

# Phase-space origin (qx, px, qy, py) and knobs (kn1, ks1, kn2, ks2);
# the nonzero ks1/ks2 add skew terms that couple the x and y planes.
x = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
k = torch.tensor([0.20, -0.01, -0.25, 0.01], dtype=torch.float64)

# Linear one-turn matrix = Jacobian of the cell map with respect to x.
m = torch.func.jacrev(fodo)(x, k)
print(m)

# Stability check; note it is imported from twiss.wolski here, while
# Example-01 imports is_stable from the top-level twiss package.
from twiss.wolski import is_stable
print(is_stable(m))
tensor([[ 2.443e+00,  1.933e+01,  6.129e-02,  3.072e-01],
        [-2.357e-01, -1.454e+00,  2.789e-03, -4.458e-02],
        [ 5.838e-02,  2.911e-01, -6.017e-01,  3.789e+00],
        [-3.216e-03, -7.752e-02, -3.310e-01,  4.167e-01]], dtype=torch.float64)
True
[6]:
# Compute coupled twiss parameters
# Only the per-mode twiss matrices (the last return value) are kept.

from twiss import twiss
*_, ws = twiss(m)
[7]:
# Generate beam
# Gaussian beam whose covariance is the emittance-weighted sum of the
# per-mode twiss matrices: Sigma = ex * wx + ey * wy.

from torch.distributions.multivariate_normal import MultivariateNormal

# Seed the global RNG so the sampled beams (here and below) are reproducible.
torch.manual_seed(0)

ex = 1.0E-6
ey = 1.0E-8

wx, wy = ws

mean = torch.zeros(4, dtype=torch.float64)
covariance_matrix = ex*wx + ey*wy

distribution = MultivariateNormal(mean, covariance_matrix)

# Samples are drawn as (N, 4) and transposed so each row is one coordinate.
beam = distribution.sample((2**14, )).T

print(covariance_matrix)
print()

print(distribution.covariance_matrix)
print()

print(beam.cov())
print()
tensor([[ 2.224e-05, -2.241e-06,  2.924e-07, -8.274e-08],
        [-2.241e-06,  2.709e-07, -3.030e-08,  7.656e-09],
        [ 2.924e-07, -3.030e-08,  4.186e-08,  4.045e-09],
        [-8.274e-08,  7.656e-09,  4.045e-09,  3.639e-09]], dtype=torch.float64)

tensor([[ 2.224e-05, -2.241e-06,  2.924e-07, -8.274e-08],
        [-2.241e-06,  2.709e-07, -3.030e-08,  7.656e-09],
        [ 2.924e-07, -3.030e-08,  4.186e-08,  4.045e-09],
        [-8.274e-08,  7.656e-09,  4.045e-09,  3.639e-09]], dtype=torch.float64)

tensor([[ 2.223e-05, -2.246e-06,  2.874e-07, -8.269e-08],
        [-2.246e-06,  2.726e-07, -3.022e-08,  7.830e-09],
        [ 2.874e-07, -3.022e-08,  4.217e-08,  4.078e-09],
        [-8.269e-08,  7.830e-09,  4.078e-09,  3.673e-09]], dtype=torch.float64)

[8]:
# Same with distribution function
# twiss.normal builds the distribution directly from the mean, the emittances
# (ex, ey) and the twiss matrices; its covariance should match the one above.

from twiss import normal

bd = normal(mean, torch.tensor([ex, ey], dtype=torch.float64), ws)

beam = bd.sample((2**16, )).T

print(bd.covariance_matrix)
print()

print(beam.cov())
print()
tensor([[ 2.224e-05, -2.241e-06,  2.924e-07, -8.274e-08],
        [-2.241e-06,  2.709e-07, -3.030e-08,  7.656e-09],
        [ 2.924e-07, -3.030e-08,  4.186e-08,  4.045e-09],
        [-8.274e-08,  7.656e-09,  4.045e-09,  3.639e-09]], dtype=torch.float64)

tensor([[ 2.215e-05, -2.234e-06,  2.895e-07, -8.329e-08],
        [-2.234e-06,  2.699e-07, -2.991e-08,  7.692e-09],
        [ 2.895e-07, -2.991e-08,  4.185e-08,  4.045e-09],
        [-8.329e-08,  7.692e-09,  4.045e-09,  3.650e-09]], dtype=torch.float64)

[9]:
# Transformed beam covariance matrix (matrix)
# Apply the linear one-turn matrix to every particle (beam columns).

print((m @ beam).cov())
tensor([[ 2.209e-05, -2.227e-06,  2.863e-07, -8.293e-08],
        [-2.227e-06,  2.694e-07, -2.968e-08,  7.647e-09],
        [ 2.863e-07, -2.968e-08,  4.182e-08,  4.077e-09],
        [-8.293e-08,  7.647e-09,  4.077e-09,  3.650e-09]], dtype=torch.float64)
[10]:
# Transformed beam covariance matrix (map)
# Track every particle through the cell map with vmap.
# NOTE(review): this reassigns `beam`, so re-running the cell tracks the beam
# through one more cell; the plots below use the tracked beam.

beam = torch.func.vmap(lambda x: fodo(x, k))(beam.T).T

print(beam.cov())
tensor([[ 2.209e-05, -2.227e-06,  2.863e-07, -8.293e-08],
        [-2.227e-06,  2.694e-07, -2.968e-08,  7.647e-09],
        [ 2.863e-07, -2.968e-08,  4.182e-08,  4.077e-09],
        [-8.293e-08,  7.647e-09,  4.077e-09,  3.650e-09]], dtype=torch.float64)
[11]:
# Prepare data frame for plotting
# One row per particle, one column per coordinate.

import pandas

df = pandas.DataFrame(beam.T.cpu().numpy())
df.columns = ['qx', 'px', 'qy', 'py']
df
[11]:
qx px qy py
0 -0.005472 0.000417 0.000104 0.000064
1 -0.002620 0.000006 0.000144 0.000058
2 -0.006187 0.000686 -0.000392 -0.000012
3 0.004921 -0.000598 0.000285 -0.000008
4 0.000305 -0.000059 -0.000060 -0.000067
... ... ... ... ...
65531 0.004505 -0.000694 0.000119 0.000035
65532 -0.006830 0.000694 -0.000085 0.000006
65533 -0.006022 0.000519 -0.000034 -0.000057
65534 -0.004400 0.000133 -0.000101 -0.000076
65535 0.001186 -0.000023 0.000396 0.000022

65536 rows × 4 columns

[12]:
# Plot beam
# Pairwise kernel-density plots of all coordinate pairs.

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

sns.pairplot(df, kind='kde')
plt.show()
../_images/examples_twiss_91_0.png
[ ]: