Example-29: Orbit (fixed point computation)
[1]:
# This example illustrates the computation of fixed points
# Fixed points are computed from a given initial guess using the Newton root search method
# The closed orbit is computed first; it is a special case of a period-one stable (elliptic) fixed point and corresponds to the center manifold
# A period-five fixed point is also computed (restricted to the horizontal plane)
[2]:
# Import
import torch
torch.set_printoptions(linewidth=128)
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True
from twiss import twiss
from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.pfp import parametric_fixed_point
from ndmap.pfp import clean_point
from ndmap.pfp import chain_point
from ndmap.pfp import matrix
from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.sextupole import Sextupole
from model.library.dipole import Dipole
from model.library.line import Line
from model.command.wrapper import group
from model.command.orbit import orbit
[3]:
# Define simple FODO based lattice using nested lines

# Drift and dipole instances are shared by all four cells
DR = Drift('DR', 0.25)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# One focusing/defocusing quadrupole pair per cell
(QF_A, QD_A), (QF_B, QD_B), (QF_C, QD_C), (QF_D, QD_D) = [
    (Quadrupole(f'QF_{tag}', 0.5, +0.20), Quadrupole(f'QD_{tag}', 0.5, -0.19))
    for tag in 'ABCD'
]

# One sextupole pair per cell (zero strength here, values are assigned later)
(SF_A, SD_A), (SF_B, SD_B), (SF_C, SD_C), (SF_D, SD_D) = [
    (Sextupole(f'SF_{tag}', 0.25, 0.00), Sextupole(f'SD_{tag}', 0.25, 0.00))
    for tag in 'ABCD'
]

# Keyword options common to every line in this example
LINE_FLAGS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)


def _fodo(tag, qf, qd, sf, sd):
    """Build the mirror-symmetric cell 'FODO_<tag>' from its quadrupole and sextupole pairs."""
    half = [qf, DR, sf, DR, BM, DR, sd, DR, qd]
    # The second half of the cell is the first half traversed in reverse order
    return Line(f'FODO_{tag}', [*half, *reversed(half)], **LINE_FLAGS)


FODO_A = _fodo('A', QF_A, QD_A, SF_A, SD_A)
FODO_B = _fodo('B', QF_B, QD_B, SF_B, SD_B)
FODO_C = _fodo('C', QF_C, QD_C, SF_C, SD_C)
FODO_D = _fodo('D', QF_D, QD_D, SF_D, SD_D)

# The ring is a line of the four cells
RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], **LINE_FLAGS)
[4]:
# Correct chromaticity

# Set parametric mapping
# The returned callable is used below as ring(state, ms, dp)
chromatic_groups = (
    ('ms', ['Sextupole'], None, None),
    ('dp', None, None, None),
)
ring, *_ = group(RING, 'FODO_A', 'FODO_D', *chromatic_groups, root=True, alignment=False)

# Set deviation parameters (all zero)
# fp is the initial fixed point, ms has one entry per sextupole, dp is the momentum deviation
fp = torch.zeros(4, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
ms = torch.zeros(8, dtype=torch.float64)

# Compute first order parametric fixed point with respect to momentum deviation
pfp, *_ = parametric_fixed_point((0, 1), fp, [ms, dp], ring)
chop(pfp)
# Define ring around fixed point
def mapping(state, ms, dp):
    """Apply the parametric ring to a state given relative to the parametric fixed point.

    The fixed point offset evaluate(pfp, [ms, dp]) is added to the state before
    the ring is applied and subtracted from the result afterwards.
    """
    origin = evaluate(pfp, [ms, dp])
    return ring(state + origin, ms, dp) - origin
# Tune
def tune(ms, dp):
    """Return the tunes for the given sextupole (ms) and momentum (dp) deviations.

    The Jacobian of `mapping` with respect to the state is evaluated at `fp`
    and passed to `twiss`; the first returned item is the result.
    """
    jacobian = torch.func.jacrev(mapping)(fp, ms, dp)
    nus, *_ = twiss(jacobian)
    return nus
# Chromaticity
def chromaticity(ms):
    """Return the derivative of `tune` with respect to its second argument (dp).

    The derivative is evaluated at the given `ms` and at the module level `dp`.
    """
    tune_slope = torch.func.jacrev(tune, argnums=1)
    return tune_slope(ms, dp)
# Initial chromaticity values
psix, psiy = chromaticity(ms).squeeze()

# Define target chromaticity values
psix_target = torch.tensor(5.0, dtype=torch.float64)
psiy_target = torch.tensor(5.0, dtype=torch.float64)

# Perform correction
dpsix = psix - psix_target
dpsiy = psiy - psiy_target

# Set solution
# The chromaticity response to the sextupole knobs is inverted with a pseudo-inverse
response = torch.func.jacrev(chromaticity)(ms).squeeze()
solution = - torch.linalg.pinv(response) @ torch.stack([dpsix, dpsiy])

# Set sextupoles
# Note, the ring function is not affected
(
    SF_A.ms, SD_A.ms,
    SF_B.ms, SD_B.ms,
    SF_C.ms, SD_C.ms,
    SF_D.ms, SD_D.ms,
) = solution.tolist()

# Check chromaticity
print(chromaticity(solution).squeeze())

# Plot tunes vs momentum deviation
# Dashed lines show the linear model (tune at dp plus target chromaticity times deviation)
# Crosses show the tunes evaluated at each momentum deviation
nux, nuy = tune(solution, dp)
dps = torch.linspace(-5.0E-3, 5.0E-3, 16, dtype=torch.float64)
nuxs, nuys = torch.stack([tune(solution, delta) for delta in dps.reshape(-1, 1)]).T
grid = dps.cpu().numpy()
fig, ax = plt.subplots(figsize=(16, 4))
for nu, nus, psi, color in ((nux, nuxs, psix_target, 'red'), (nuy, nuys, psiy_target, 'blue')):
    ax.plot(grid, (nu + psi*dps).cpu().numpy(), color=color, linestyle='dashed')
    ax.scatter(grid, nus.cpu().numpy(), color='black', marker='x')
fig.tight_layout()
plt.show()
tensor([5.0000, 5.0000], dtype=torch.float64)
![../_images/examples_model-28_4_1.png](../_images/examples_model-28_4_1.png)
[5]:
# Generate and plot phase space trajectories

# Initial conditions: 16 horizontal offsets, all other coordinates are zero
qx = torch.linspace(0.10, 0.4, 16, dtype=torch.float64)
px = torch.zeros_like(qx)
qy = torch.zeros_like(qx)
py = torch.zeros_like(qx)
state = torch.stack([qx, px, qy, py], dim=-1)

# Track all initial conditions for 2**10 turns, storing the state after each turn
turn = torch.vmap(RING)
trjs = []
for _ in range(2**10):
    state = turn(state)
    trjs.append(state)

# Horizontal phase space portrait
qx, px, *_ = torch.stack(trjs).swapaxes(0, -1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(qx.cpu().numpy(), px.cpu().numpy(), s=1, color='black')
ax.set_xlim(-1.0, 0.5)
ax.set_ylim(-0.075, 0.075)
fig.tight_layout()
plt.show()
![../_images/examples_model-28_5_0.png](../_images/examples_model-28_5_0.png)
[6]:
# Compute closed orbit (period one fixed point)

# Set initial guess
guess = 1.0E-3*torch.tensor([1.0, -1.0, 1.0, -1.0], dtype=torch.float64)

# Compute without deviation parameters and groups
point, table = orbit(RING, guess, [], limit=8, epsilon=1.0E-6)
print(table)
print(point)
print(RING(point))
print()

# Compute matrix around closed orbit
# First directly from the ring, then with the parametric ring (computed once and reused below)
print(torch.func.jacrev(RING)(point))
print()
monodromy = matrix(1, ring, point, ms, dp, jacobian=torch.func.jacrev)
print(monodromy)
print()

# Classify fixed point
# Printed values are the real parts of the eigenvalue logarithms
values, _ = torch.linalg.eig(monodromy)
print(values.log().real)
print()

# Deviation parameters are passed after the initial guess, followed by deviation groups
point, table = orbit(RING, guess, [ms, dp], ('ms', ['Sextupole'], None, None), ('dp', None, None, None), limit=8, epsilon=1.0E-6)
print(table)
print(point)
print(RING(point))
print()

# Track closed orbit
# Note, number of points is equal to the number of lines plus one (full=True, default) or number of lines (full=False)
points, table = orbit(RING, guess, [ms, dp], ('ms', ['Sextupole'], None, None), ('dp', None, None, None), advance=True, limit=8, epsilon=1.0E-6)
print(table)
print(points)
print(points.shape)
print(len(RING))
print()

# Closed orbit with non-zero deviation parameters
# Note, alignment flag should be explicitly passed
# The misaligned ring and the momentum offset use their own names, so that `ring` and `dp`
# (used with the (ms, dp) signature earlier in this cell) are not overwritten and the cell can be re-run
fp = torch.tensor(4*[0.0], dtype=torch.float64)
dx = torch.tensor([-0.001], dtype=torch.float64)
dy = torch.tensor([+0.001], dtype=torch.float64)
dp_offset = torch.tensor([0.0005], dtype=torch.float64)
ring_misaligned, *_ = group(RING, 'FODO_A', 'FODO_D', ('dx', None, ['QD_A'], None), ('dy', None, ['QD_A'], None), ('dp', None, None, None), root=True, alignment=True)
print(fp)
print(ring_misaligned(fp, dx, dy, dp_offset))
print()
point, _ = orbit(RING, guess, [dx, dy, dp_offset], ('dx', None, ['QD_A'], None), ('dy', None, ['QD_A'], None), ('dp', None, None, None), alignment=True, limit=8, epsilon=1.0E-6)
print(point)
print(ring_misaligned(point, dx, dy, dp_offset))
print()
[]
tensor([ 2.9857e-18, -2.0620e-19, -2.9122e-19, -3.0706e-20], dtype=torch.float64)
tensor([ 2.6010e-18, 2.2073e-19, 2.7221e-19, -3.4789e-20], dtype=torch.float64)
tensor([[-3.3823e-01, -1.7512e+01, 4.0542e-19, 7.6920e-18],
[ 5.0572e-02, -3.3823e-01, -7.0284e-20, 1.3939e-19],
[-1.4957e-19, -8.2174e-18, -2.9764e-01, -6.0422e+00],
[ 6.4815e-20, -4.3138e-19, 1.5084e-01, -2.9764e-01]], dtype=torch.float64)
tensor([[-3.3823e-01, -1.7512e+01, -3.8455e-33, -5.8351e-32],
[ 5.0572e-02, -3.3823e-01, 4.9958e-34, -5.6691e-34],
[ 7.8130e-34, 5.7800e-32, -2.9764e-01, -6.0422e+00],
[-4.8163e-34, 4.4705e-33, 1.5084e-01, -2.9764e-01]], dtype=torch.float64)
tensor([-1.6398e-15, -1.6398e-15, -1.2069e-16, -1.2069e-16], dtype=torch.float64)
[(None, ['SF_A', 'SD_A', 'SF_B', 'SD_B', 'SF_C', 'SD_C', 'SF_D', 'SD_D'], 'ms'), (None, None, 'dp')]
tensor([ 2.9857e-18, -2.0620e-19, -2.9123e-19, -3.0706e-20], dtype=torch.float64)
tensor([ 2.6010e-18, 2.2073e-19, 2.7221e-19, -3.4789e-20], dtype=torch.float64)
[(None, ['SF_A', 'SD_A', 'SF_B', 'SD_B', 'SF_C', 'SD_C', 'SF_D', 'SD_D'], 'ms'), (None, None, 'dp')]
tensor([[ 2.9857e-18, -2.0620e-19, -2.9123e-19, -3.0706e-20],
[-4.4180e-18, 1.0905e-19, -3.0486e-19, 2.7202e-20],
[ 4.8561e-18, 1.2635e-20, 1.6041e-20, 5.5260e-20],
[-4.2013e-18, -1.3148e-19, 3.1934e-19, 2.2679e-20],
[ 2.6010e-18, 2.2073e-19, 2.7221e-19, -3.4789e-20]], dtype=torch.float64)
torch.Size([5, 4])
4
tensor([0., 0., 0., 0.], dtype=torch.float64)
tensor([ 0.0030, -0.0002, -0.0014, -0.0003], dtype=torch.float64)
tensor([ 2.5908e-03, -2.2607e-05, -1.1890e-04, -2.1272e-04], dtype=torch.float64)
tensor([ 2.5908e-03, -2.2607e-05, -1.1890e-04, -2.1272e-04], dtype=torch.float64)
[7]:
# Locate period five fixed points
dp = torch.tensor([0.0], dtype=torch.float64)
ms = torch.tensor(8*[0.0], dtype=torch.float64)

# Set fixed point period
power = 5

# Set tolerance epsilon
epsilon = 1.0E-9

# Set random initial points
# qx is uniform in [-0.5, 0.5), px is uniform in [-0.05, 0.05), the vertical plane is zero
# A dedicated seeded generator makes the guesses (and hence the cell output) reproducible
generator = torch.Generator().manual_seed(0)
qx = 1.0*torch.rand(256, generator=generator, dtype=torch.float64) - 0.50
px = 0.1*torch.rand(256, generator=generator, dtype=torch.float64) - 0.05
qy = torch.zeros_like(qx)
py = torch.zeros_like(px)
points = torch.stack([qx, px, qy, py]).T
def task(guess):
    """Run the fixed point search of period `power` from one initial guess and return the point.

    The second value returned by `orbit` is discarded.
    """
    groups = (
        ('ms', ['Sextupole'], None, None),
        ('dp', None, None, None),
    )
    # NOTE(review): epsilon=None presumably disables early termination so that all
    # `limit` iterations are performed - confirm against the orbit documentation
    fixed, _ = orbit(RING, guess, [ms, dp], *groups, limit=128, power=power, epsilon=None)
    return fixed
# Perform root search iterations for each initial point
points = torch.func.vmap(task)(points)

# Set parametric ring
ring, *_ = group(RING, 'FODO_A', 'FODO_D', ('ms', ['Sextupole'], None, None), ('dp', None, None, None), root=True, alignment=True)

# Apply the ring once to every candidate
# The result depends only on the (unchanged) candidates, so a single evaluation is sufficient
images = torch.vmap(lambda state: ring(state, ms, dp))(points)

# Remove solutions with large norms
# Comparisons with nan are False, hence candidates with nan images are dropped as well
points = points[images.norm(1, dim=-1) < 0.5]

# Remove unconverged solutions
# A candidate is kept if it returns to itself (within epsilon) after `power` turns
mask = []
for point in points:
    local = point.clone()
    for _ in range(power):
        local = ring(local, ms, dp)
    mask.append(bool((local - point).norm() < epsilon))
# Index with an explicit boolean tensor: a short Python list of 0-d tensors
# can be interpreted as a tuple of indices instead of a mask
mask = torch.tensor(mask, dtype=torch.bool)
points = points[mask]

# Clean points (remove nans, duplicates, points from the same chain)
points = clean_point(power, ring, points, ms, dp, epsilon=epsilon)
# Generate fixed point chains
chains = torch.func.vmap(lambda point: chain_point(power, ring, point, ms, dp))(points)

# Classify fixed point chains (elliptic vs hyperbolic)
# Generate initials for hyperbolic fixed points using corresponding eigenvectors
# Note, `lines` is overwritten for each hyperbolic chain, only the last one seeds the manifold
# `lines` stays None if no chain is classified as hyperbolic
# NOTE(review): a product of real parts is close to zero whenever any single eigenvalue
# sits on the unit circle - confirm this criterion separates elliptic from hyperbolic chains
kinds = []
lines = None
offsets = torch.linspace(-epsilon, +epsilon, 128, dtype=torch.float64).reshape(-1, 1)
for chain in chains:
    point, *_ = chain
    values, vectors = torch.linalg.eig(matrix(power, ring, point, ms, dp))
    kind = values.log().real.prod() < epsilon
    kinds.append(bool(kind))
    if not kind:
        # Short segments through the fixed point along the real part of each eigenvector
        lines = [point + vector*offsets for vector in vectors.real.T]
        lines = torch.stack(lines).reshape(-1, 4)

# Remove vertical plane in chains
qx, px, *_ = chains.swapaxes(0, -1)
chains = torch.stack([qx, px]).swapaxes(0, -1)

# Iterate lines and remove vertical plane
if lines is None:
    # No hyperbolic chain was found, keep an empty (2, 0) manifold so that plotting still works
    manifold = torch.empty(2, 0, dtype=torch.float64)
else:
    manifold = []
    for _ in range(64):
        manifold.append(lines)
        lines = torch.func.vmap(lambda point: ring(point, ms, dp))(lines)
    manifold = torch.stack(manifold)
    # Remove vertical plane in lines (including nonlinear leaking)
    qx, px, qy, py = manifold.swapaxes(0, -1)
    planar = qy.abs() + py.abs() < epsilon
    qx = qx[planar]
    px = px[planar]
    manifold = torch.stack([qx, px])
# Plot
fig, ax = plt.subplots(figsize=(6, 6))

# Tracked trajectories (black)
qx, px, *_ = torch.stack(trjs).swapaxes(0, -1)
ax.scatter(qx.cpu().numpy(), px.cpu().numpy(), s=1, color='black')

# Iterated lines (grey)
qx, px = manifold
ax.scatter(qx.flatten().cpu().numpy(), px.flatten().cpu().numpy(), s=1, color='grey', alpha=0.5)

# Fixed point chains (blue if flagged True by the classification, red otherwise)
for chain, kind in zip(chains, kinds):
    ax.scatter(*chain.T, color='blue' if kind else 'red', marker='o')

ax.set_xlim(-1.0, 0.5)
ax.set_ylim(-0.075, 0.075)
fig.tight_layout()
plt.show()
![../_images/examples_model-28_7_0.png](../_images/examples_model-28_7_0.png)