Example-58: Frequency (parametric derivatives for linear system)
[1]:
# In this example, frequencies (tunes) of the linear model are computed from trajectory data
# The frequency values are compared with those obtained from the one-turn matrix
# Derivatives with respect to parameters (here, the momentum deviation) are also computed and compared
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
from model.library.line import Line
from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.tune import chromaticity
from model.command.trajectory import trajectory
from model.command.frequency import filter
from model.command.frequency import frequency_factory
[3]:
# Reference tunes from the ELEGANT twiss file (read with load_sdds)
# Only the fractional parts of 'nux' and 'nuy' are kept, since those are what the
# one-turn-matrix and trajectory computations below are compared against
path = Path('ic.twiss')
parameters, columns = load_sdds(path)
NUX: Tensor
NUY: Tensor
NUX, NUY = (torch.tensor(parameters[key] % 1, dtype=torch.float64) for key in ('nux', 'nuy'))
[4]:
# Build and set up the lattice
# Load ELEGANT lattice table
path = Path('ic.lte')
data = load_lattice(path)
# Build the 'RING' line from the ELEGANT table
ring:Line = build('RING', 'ELEGANT', data)
ring.flatten()
# Merge drifts
ring.merge()
# Linear model: turn off sextupoles (ms = 0) and set dipoles to linear
# An element has exactly one class name, so the two checks are mutually exclusive
for element in ring:
    if element.__class__.__name__ == 'Sextupole':
        element.ms = 0.0
    elif element.__class__.__name__ == 'Dipole':
        element.linear = True
# Split BPMs
ring.split((None, ['BPM'], None, None))
# Roll lattice start
ring.roll(1)
# Split lattice into lines by BPMs
ring.splice()
# Number of elements of different kinds
# NOTE(review): nb, nq and ns are not used in the cells shown here
nb = ring.describe['BPM']
nq = ring.describe['Quadrupole']
ns = ring.describe['Sextupole']
[5]:
# Tunes from the one-turn matrix and their absolute deviation from the ELEGANT reference
nux, nuy = tune(ring, [], matched=True, alignment=False, limit=8, epsilon=1.0E-12)
print(torch.abs(NUX - nux))
print(torch.abs(NUY - nuy))
tensor(3.1086e-15, dtype=torch.float64)
tensor(5.5511e-16, dtype=torch.float64)
[6]:
# Tunes extracted from tracked trajectory data, compared against the ELEGANT reference
# Small nonzero initial condition (presumably ordered [qx, px, qy, py] — TODO confirm)
state = torch.tensor([+1.0E-9, 0.0, -1.0E-9, 0.0], dtype=torch.float64)
# Trajectory generator for the matched ring
generator = trajectory(ring, [0], matched=True)
# Window data created with the lattice dtype/device (2**10 is presumably the window length — verify)
window = filter(2**10, 1.0, dtype=ring.dtype, device=ring.device)
# Frequency estimator built on top of the trajectory generator
frequency = frequency_factory(generator)
nux, nuy = frequency(window, state)
print(torch.abs(NUX - nux))
print(torch.abs(NUY - nuy))
tensor(3.3061e-12, dtype=torch.float64)
tensor(3.1350e-10, dtype=torch.float64)
[7]:
# Jacobian of the one-turn-matrix tunes with respect to momentum deviation, evaluated at dp = 0
dp = torch.zeros(1, dtype=torch.float64)
print(torch.func.jacrev(lambda delta: tune(ring, [delta], ('dp', None, None, None), matched=True, limit=1, epsilon=None))(dp))
tensor([[-7.8819],
[-3.9483]], dtype=torch.float64)
[8]:
# Jacobian of the trajectory-based tunes with respect to momentum deviation
# Note: the state must not be zero. Zero is a fixed point, so the tracked trajectory
# would carry no oscillation from which a frequency could be extracted
state = torch.tensor([+1.0E-9, 0.0, -1.0E-9, 0.0], dtype=torch.float64)
# Momentum deviation at which the derivative is evaluated
dp = torch.zeros(1, dtype=torch.float64)
# Trajectory generator that is parametric in the momentum deviation dp
generator = trajectory(ring, [0], ('dp', None, None, None), matched=True)
window = filter(2**10, 1.0, dtype=ring.dtype, device=ring.device)
frequency = frequency_factory(generator)
# Reverse-mode Jacobian of the computed frequencies with respect to dp
print(torch.func.jacrev(lambda delta: frequency(window, state, delta), chunk_size=256)(dp))
tensor([[-7.8819],
[-3.9486]], dtype=torch.float64)