Example-58: Frequency (parametric derivatives for linear system)

[1]:
# In this example, the frequencies of a linear model are computed from trajectory data
# The frequency values are compared with those obtained from the one-turn matrix
# Derivatives with respect to parameters are also computed and compared
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

from model.library.line import Line

from model.command.external import load_sdds
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.tune import chromaticity
from model.command.trajectory import trajectory
from model.command.frequency import filter
from model.command.frequency import frequency_factory
[3]:
# Load ELEGANT twiss
# The fractional parts of the ELEGANT tunes serve as reference values below

path = Path('ic.twiss')
parameters, columns = load_sdds(path)

NUX: Tensor
NUY: Tensor
NUX, NUY = (torch.tensor(parameters[key] % 1, dtype=torch.float64) for key in ('nux', 'nuy'))
[4]:
# Build and setup lattice

# Load ELEGANT table and build the ring from it

path = Path('ic.lte')
data = load_lattice(path)
ring:Line = build('RING', 'ELEGANT', data)

# Flatten the ring, then merge drifts

ring.flatten()
ring.merge()

# Linearize the model: sextupoles are switched off, dipoles are set to linear mode

for element in ring:
    if type(element).__name__ == 'Sextupole':
        element.ms = 0.0
    elif type(element).__name__ == 'Dipole':
        element.linear = True

# Split BPMs (NOTE(review): exact meaning of the 4-tuple argument is defined by Line.split — confirm)

ring.split((None, ['BPM'], None, None))

# Roll lattice start by one position

ring.roll(1)

# Split lattice into lines by BPMs

ring.splice()

# Number of BPMs, quadrupoles and sextupoles in the spliced ring

nb, nq, ns = (ring.describe[kind] for kind in ('BPM', 'Quadrupole', 'Sextupole'))
[5]:
# Compute tunes (one-turn matrix)

nux, nuy = tune(ring, [], alignment=False, matched=True, limit=8, epsilon=1.0E-12)

# Compare with elegant: absolute deviation for each plane

for reference, computed in ((NUX, nux), (NUY, nuy)):
    print((reference - computed).abs())
tensor(3.1086e-15, dtype=torch.float64)
tensor(5.5511e-16, dtype=torch.float64)
[6]:
# Compute tunes (trajectory)

# Set trajectory generator

generator = trajectory(ring, [0], matched=True)

# Set initial condition
# A small nonzero amplitude is used, since zero is a fixed point of the linear map
# The state is created with the lattice dtype/device, consistent with the window below

state = torch.tensor([+1.0E-9, 0.0, -1.0E-9, 0.0], dtype=ring.dtype, device=ring.device)

# Set window data
# NOTE(review): the arguments are presumably the number of turns (2**10) and the window order (1.0) — confirm in model.command.frequency.filter

window = filter(2**10, 1.0, dtype=ring.dtype, device=ring.device)

# Set frequency generator

frequency = frequency_factory(generator)

# Compute frequencies

nux, nuy = frequency(window, state)

# Compare with elegant

print((NUX - nux).abs())
print((NUY - nuy).abs())
tensor(3.3061e-12, dtype=torch.float64)
tensor(3.1350e-10, dtype=torch.float64)
[7]:
# Derivative with respect to momentum deviation (one-turn matrix)

dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes as a function of momentum deviation (single 'dp' parameter group)

def tune_dp(dp):
    return tune(ring, [dp], ('dp', None, None, None), matched=True, limit=1, epsilon=None)

# Jacobian of the tunes with respect to dp, evaluated at dp = 0

print(torch.func.jacrev(tune_dp)(dp))
tensor([[-7.8819],
        [-3.9483]], dtype=torch.float64)
[8]:
# Derivative with respect to momentum deviation (trajectory)

# Set parametric trajectory generator (momentum deviation 'dp' is the parameter)

generator = trajectory(ring, [0], ('dp', None, None, None), matched=True)

# Set initial state and momentum deviation
# Note, state should not be equal to zero, since zero is a fixed point
# of the linear map and the trajectory would not oscillate
# Both tensors are created with the lattice dtype/device, consistent with the window below

state = torch.tensor([+1.0E-9, 0.0, -1.0E-9, 0.0], dtype=ring.dtype, device=ring.device)
dp = torch.tensor([0.0], dtype=ring.dtype, device=ring.device)

# Set window data
# NOTE(review): the arguments are presumably the number of turns (2**10) and the window order (1.0) — confirm in model.command.frequency.filter

window = filter(2**10, 1.0, dtype=ring.dtype, device=ring.device)

# Set frequency generator

frequency = frequency_factory(generator)

# Compute derivative of the frequencies with respect to dp (evaluated at dp = 0)

print(torch.func.jacrev(lambda dp: frequency(window, state, dp), chunk_size=256)(dp))
tensor([[-7.8819],
        [-3.9486]], dtype=torch.float64)