Example-27: Normalize

[1]:
# This example illustrates how to construct a normalized objective function
[2]:
# Import

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
torch.set_printoptions(linewidth=128)

import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams['text.usetex'] = True

from twiss import twiss

from ndmap.signature import chop
from ndmap.evaluate import evaluate
from ndmap.pfp import parametric_fixed_point

from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.sextupole import Sextupole
from model.library.dipole import Dipole
from model.library.line import Line

from model.command.wrapper import group
from model.command.wrapper import forward
from model.command.wrapper import inverse
from model.command.wrapper import normalize
from model.command.wrapper import Wrapper
[3]:
# Define a simple FODO-based ring using nested lines

# Elements shared by every cell (the same objects appear in all four FODO lines)

DR = Drift('DR', 0.25)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# Per-cell quadrupoles: QF_* with positive strength, QD_* with negative strength.
# Each cell gets its own objects so they can be addressed by name.

QF_A = Quadrupole('QF_A', 0.5, +0.20)
QD_A = Quadrupole('QD_A', 0.5, -0.19)
QF_B = Quadrupole('QF_B', 0.5, +0.20)
QD_B = Quadrupole('QD_B', 0.5, -0.19)
QF_C = Quadrupole('QF_C', 0.5, +0.20)
QD_C = Quadrupole('QD_C', 0.5, -0.19)
QF_D = Quadrupole('QF_D', 0.5, +0.20)
QD_D = Quadrupole('QD_D', 0.5, -0.19)

# Per-cell sextupoles, all created with zero strength

SF_A = Sextupole('SF_A', 0.25, 0.00)
SD_A = Sextupole('SD_A', 0.25, 0.00)
SF_B = Sextupole('SF_B', 0.25, 0.00)
SD_B = Sextupole('SD_B', 0.25, 0.00)
SF_C = Sextupole('SF_C', 0.25, 0.00)
SD_C = Sextupole('SD_C', 0.25, 0.00)
SF_D = Sextupole('SF_D', 0.25, 0.00)
SD_D = Sextupole('SD_D', 0.25, 0.00)

# Keyword options shared by every Line (the four cells and the ring)

LINE_OPTIONS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)


def fodo(name, qf, qd, sf, sd):
    """Build one FODO cell as a Line.

    The sequence is symmetric around the doubled defocusing quadrupole:
    qf, DR, sf, DR, BM, DR, sd, DR, qd | qd, DR, sd, DR, BM, DR, sf, DR, qf.
    The shared DR and BM objects are used in every cell.

    Parameters
    ----------
    name: str
        line name
    qf, qd: Quadrupole
        cell quadrupoles (positive / negative strength)
    sf, sd: Sextupole
        cell sextupoles

    Returns
    -------
    Line

    """
    sequence = [qf, DR, sf, DR, BM, DR, sd, DR, qd, qd, DR, sd, DR, BM, DR, sf, DR, qf]
    return Line(name, sequence, **LINE_OPTIONS)


FODO_A = fodo('FODO_A', QF_A, QD_A, SF_A, SD_A)
FODO_B = fodo('FODO_B', QF_B, QD_B, SF_B, SD_B)
FODO_C = fodo('FODO_C', QF_C, QD_C, SF_C, SD_C)
FODO_D = fodo('FODO_D', QF_D, QD_D, SF_D, SD_D)

RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], **LINE_OPTIONS)
[4]:
# Set parametric mapping
#
# Build a callable `ring` over the lines from FODO_A through FODO_D of RING.
# It has two deviation groups:
#   'ms' -- attached to elements of kind 'Sextupole'
#   'dp' -- no element kind given
# Later cells evaluate the result as ring(state, ms, dp).
# NOTE(review): the meaning of the trailing None fields and of root=True is
# defined by model.command.wrapper.group -- confirm there.

ring, *_ = group(
    RING,
    'FODO_A',
    'FODO_D',
    ('ms', ['Sextupole'], None, None),
    ('dp', None, None, None),
    root=True,
)
[5]:
# Construct normalized function

# One bounds list per argument group of ring(state, ms, dp), shared by `normalize`
# and `forward` so the two can never drift apart:
#   state -> (None, None), presumably left unnormalized (TODO confirm in model.command.wrapper)
#   ms    -> [-10, 10]
#   dp    -> [-0.01, 0.01]
bounds = [(None, None), (-10.0, 10.0), (-0.01, 0.01)]

fn = normalize(ring, bounds)

# Compare with original: fn applied to forward-transformed arguments must reproduce ring

fp = torch.tensor([0.001, 0.0005, -0.010, 0.0025], dtype=torch.float64)
ms = torch.tensor([1.0, -1.0, 0.5, 2.0, 4.0, -5.0, -1.0, 3.0], dtype=torch.float64)
dp = torch.tensor([0.005], dtype=torch.float64)

result = ring(fp, ms, dp)
result_normalized = fn(*forward([fp, ms, dp], bounds))

print(result)
print(result_normalized)

assert torch.allclose(result, result_normalized), 'normalized function disagrees with the original'
tensor([ 0.0157, -0.0006, -0.0189, -0.0032], dtype=torch.float64)
tensor([ 0.0157, -0.0006, -0.0189, -0.0032], dtype=torch.float64)
[6]:
# Reset deviation parameters to zero:
# phase-space state (4), sextupole knobs (8) and momentum deviation (1)

fp = torch.zeros(4, dtype=torch.float64)
ms = torch.zeros(8, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)
[7]:
# Define parametric chromaticity function

# Compute parametric fixed point (first order dispersion)
# NOTE(review): the order tuple (0, 1) presumably means order 0 in ms and order 1
# in dp, i.e. the fixed point is expanded to first order in momentum deviation --
# confirm with the ndmap.pfp documentation. fp (all zeros here) is passed as the
# state argument, presumably the initial guess.

pfp, *_ = parametric_fixed_point((0, 1), fp, [ms, dp], ring)
# The return value of chop is discarded; presumably it trims negligible
# coefficients of pfp in place -- TODO confirm in ndmap.signature
chop(pfp)

# Define ring around parametric fixed point

def mapping(state, ms, dp):
    """Ring map expressed relative to the parametric fixed point.

    The fixed-point expansion `pfp` is evaluated at (ms, dp). The state is shifted
    by that value, propagated through `ring`, and the shift is subtracted again.
    The origin therefore (approximately) maps to itself.

    Parameters
    ----------
    state: torch.Tensor
        deviation from the fixed point
    ms, dp: torch.Tensor
        knobs forwarded to both `evaluate` and `ring`

    Returns
    -------
    torch.Tensor

    """
    # Evaluate the fixed-point expansion once and reuse it for both shifts
    shift = evaluate(pfp, [ms, dp])
    return ring(state + shift, ms, dp) - shift

# Define tunes

def tune(ms, dp):
    """Tunes of the ring linearized around the parametric fixed point.

    The one-turn matrix is the Jacobian of `mapping` with respect to the state,
    taken at the origin (i.e. at the fixed point), and is passed to `twiss`.

    Parameters
    ----------
    ms, dp: torch.Tensor
        knobs forwarded to `mapping`

    Returns
    -------
    torch.Tensor
        first value returned by `twiss`

    """
    # Linearize explicitly at the origin rather than at whatever values the global
    # `fp` currently holds (an earlier cell assigns it a non-zero test vector).
    # zeros_like keeps the dimension and dtype of `fp`.
    origin = torch.zeros_like(fp)
    matrix = torch.func.jacrev(mapping)(origin, ms, dp)
    tunes, *_ = twiss(matrix)
    return tunes

# Define chromaticity

def chromaticity(ms):
    """Derivative of the tunes with respect to the momentum deviation.

    `tune` is differentiated in its second argument at the global `dp`. Size-one
    dimensions of the Jacobian are squeezed away, since `dp` has a single element.
    """
    dtune_ddp = torch.func.jacrev(tune, argnums=1)(ms, dp)
    return dtune_ddp.squeeze()

# Compute natural chromaticity (all sextupole knobs at zero)

print(chromaticity(ms))
tensor([-2.0649, -0.8260], dtype=torch.float64)
[8]:
# Chromaticity can be corrected in a single step
# (one pseudo-inverse Newton step; presumably exact because the chromaticity is
# linear in the sextupole knobs -- the assertion below checks this)

# Compute starting values

psix, psiy = chromaticity(ms)

# Set target values

psix_target = torch.tensor(5.0, dtype=torch.float64)
psiy_target = torch.tensor(5.0, dtype=torch.float64)

# Perform correction: minimum-norm least-squares solution of J @ delta = -(psi - psi_target)

dpsix = psix - psix_target
dpsiy = psiy - psiy_target

jacobian = torch.func.jacrev(chromaticity)(ms).squeeze()
# Step from the current knobs, so the update stays valid for a non-zero starting point
solution = ms - torch.linalg.pinv(jacobian) @ torch.stack([dpsix, dpsiy])
print(solution)

# Test solution: the corrected chromaticity must hit the targets

corrected = chromaticity(solution)
print(corrected)

assert torch.allclose(corrected, torch.stack([psix_target, psiy_target])), 'single-step correction missed the target chromaticity'
tensor([ 0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084], dtype=torch.float64)
tensor([5.0000, 5.0000], dtype=torch.float64)
[9]:
# Optimization (wrapping objective function and normalization)

# Knob bounds shared by every forward/inverse transform in this cell

knob_bounds = [(-10.0, 10.0)]

# Set model parameters (knobs in normalized coordinates)
# Parameters are not cloned inside the module on initialization, values will change during optimization!

ms = torch.tensor(8*[0.0], dtype=torch.float64)
ms, *_ = forward([ms], knob_bounds)

# Define scalar objective function: Euclidean distance to the target chromaticity

def chromaticity_error(ms):
    """Distance between the chromaticity at (unnormalized) knobs `ms` and the targets.

    Parameters
    ----------
    ms: torch.Tensor
        sextupole knobs

    Returns
    -------
    torch.Tensor
        scalar sqrt((psix - psix_target)**2 + (psiy - psiy_target)**2)

    """
    psix, psiy = chromaticity(ms)
    return ((psix - psix_target)**2 + (psiy - psiy_target)**2).sqrt()

print(chromaticity_error(solution))

# Define normalized objective (takes normalized knobs). Keeping the raw function under
# its own name avoids shadowing it with the wrapper.

objective = normalize(chromaticity_error, knob_bounds)

print(objective(*forward([solution], knob_bounds)))

# Set model (forward returns evaluated objective)

model = Wrapper(objective, ms)

# Set optimizer

optimizer = torch.optim.Adam(model.parameters(), lr=1.0E-3)

# Perform optimization

epochs = 128
for epoch in range(epochs):

    # Evaluate model
    error = model()

    # Compute derivatives
    error.backward()

    # Perform optimization step
    optimizer.step()

    # Set gradient to zero
    optimizer.zero_grad()

    # Verbose: objective value and distance of the unnormalized knobs to the single-step
    # solution. This is diagnostics only, so no autograd graph is built.
    with torch.no_grad():
        knobs, *_ = inverse([next(model.parameters())], knob_bounds)
    print(error.detach(), (knobs - solution).norm())
tensor(5.6871e-15, dtype=torch.float64)
tensor(1.1580e-14, dtype=torch.float64)
tensor(9.1573, dtype=torch.float64) tensor(2.7830, dtype=torch.float64)
tensor(8.9651, dtype=torch.float64) tensor(2.7280, dtype=torch.float64)
tensor(8.7737, dtype=torch.float64) tensor(2.6732, dtype=torch.float64)
tensor(8.5832, dtype=torch.float64) tensor(2.6184, dtype=torch.float64)
tensor(8.3937, dtype=torch.float64) tensor(2.5637, dtype=torch.float64)
tensor(8.2052, dtype=torch.float64) tensor(2.5090, dtype=torch.float64)
tensor(8.0177, dtype=torch.float64) tensor(2.4545, dtype=torch.float64)
tensor(7.8314, dtype=torch.float64) tensor(2.4000, dtype=torch.float64)
tensor(7.6464, dtype=torch.float64) tensor(2.3456, dtype=torch.float64)
tensor(7.4627, dtype=torch.float64) tensor(2.2914, dtype=torch.float64)
tensor(7.2804, dtype=torch.float64) tensor(2.2373, dtype=torch.float64)
tensor(7.0995, dtype=torch.float64) tensor(2.1833, dtype=torch.float64)
tensor(6.9202, dtype=torch.float64) tensor(2.1294, dtype=torch.float64)
tensor(6.7426, dtype=torch.float64) tensor(2.0758, dtype=torch.float64)
tensor(6.5666, dtype=torch.float64) tensor(2.0222, dtype=torch.float64)
tensor(6.3924, dtype=torch.float64) tensor(1.9689, dtype=torch.float64)
tensor(6.2200, dtype=torch.float64) tensor(1.9158, dtype=torch.float64)
tensor(6.0495, dtype=torch.float64) tensor(1.8628, dtype=torch.float64)
tensor(5.8808, dtype=torch.float64) tensor(1.8101, dtype=torch.float64)
tensor(5.7141, dtype=torch.float64) tensor(1.7577, dtype=torch.float64)
tensor(5.5492, dtype=torch.float64) tensor(1.7055, dtype=torch.float64)
tensor(5.3862, dtype=torch.float64) tensor(1.6536, dtype=torch.float64)
tensor(5.2250, dtype=torch.float64) tensor(1.6019, dtype=torch.float64)
tensor(5.0655, dtype=torch.float64) tensor(1.5506, dtype=torch.float64)
tensor(4.9077, dtype=torch.float64) tensor(1.4996, dtype=torch.float64)
tensor(4.7514, dtype=torch.float64) tensor(1.4489, dtype=torch.float64)
tensor(4.5965, dtype=torch.float64) tensor(1.3986, dtype=torch.float64)
tensor(4.4428, dtype=torch.float64) tensor(1.3486, dtype=torch.float64)
tensor(4.2902, dtype=torch.float64) tensor(1.2990, dtype=torch.float64)
tensor(4.1384, dtype=torch.float64) tensor(1.2498, dtype=torch.float64)
tensor(3.9873, dtype=torch.float64) tensor(1.2009, dtype=torch.float64)
tensor(3.8366, dtype=torch.float64) tensor(1.1523, dtype=torch.float64)
tensor(3.6862, dtype=torch.float64) tensor(1.1041, dtype=torch.float64)
tensor(3.5359, dtype=torch.float64) tensor(1.0563, dtype=torch.float64)
tensor(3.3857, dtype=torch.float64) tensor(1.0088, dtype=torch.float64)
tensor(3.2353, dtype=torch.float64) tensor(0.9616, dtype=torch.float64)
tensor(3.0848, dtype=torch.float64) tensor(0.9148, dtype=torch.float64)
tensor(2.9343, dtype=torch.float64) tensor(0.8682, dtype=torch.float64)
tensor(2.7836, dtype=torch.float64) tensor(0.8220, dtype=torch.float64)
tensor(2.6331, dtype=torch.float64) tensor(0.7759, dtype=torch.float64)
tensor(2.4828, dtype=torch.float64) tensor(0.7301, dtype=torch.float64)
tensor(2.3329, dtype=torch.float64) tensor(0.6845, dtype=torch.float64)
tensor(2.1835, dtype=torch.float64) tensor(0.6390, dtype=torch.float64)
tensor(2.0349, dtype=torch.float64) tensor(0.5936, dtype=torch.float64)
tensor(1.8870, dtype=torch.float64) tensor(0.5482, dtype=torch.float64)
tensor(1.7399, dtype=torch.float64) tensor(0.5028, dtype=torch.float64)
tensor(1.5934, dtype=torch.float64) tensor(0.4573, dtype=torch.float64)
tensor(1.4474, dtype=torch.float64) tensor(0.4116, dtype=torch.float64)
tensor(1.3016, dtype=torch.float64) tensor(0.3657, dtype=torch.float64)
tensor(1.1557, dtype=torch.float64) tensor(0.3195, dtype=torch.float64)
tensor(1.0093, dtype=torch.float64) tensor(0.2729, dtype=torch.float64)
tensor(0.8622, dtype=torch.float64) tensor(0.2261, dtype=torch.float64)
tensor(0.7146, dtype=torch.float64) tensor(0.1791, dtype=torch.float64)
tensor(0.5667, dtype=torch.float64) tensor(0.1320, dtype=torch.float64)
tensor(0.4196, dtype=torch.float64) tensor(0.0852, dtype=torch.float64)
tensor(0.2748, dtype=torch.float64) tensor(0.0393, dtype=torch.float64)
tensor(0.1340, dtype=torch.float64) tensor(0.0068, dtype=torch.float64)
tensor(0.0262, dtype=torch.float64) tensor(0.0432, dtype=torch.float64)
tensor(0.1620, dtype=torch.float64) tensor(0.0727, dtype=torch.float64)
tensor(0.2647, dtype=torch.float64) tensor(0.0948, dtype=torch.float64)
tensor(0.3263, dtype=torch.float64) tensor(0.1111, dtype=torch.float64)
tensor(0.3638, dtype=torch.float64) tensor(0.1224, dtype=torch.float64)
tensor(0.3894, dtype=torch.float64) tensor(0.1289, dtype=torch.float64)
tensor(0.4067, dtype=torch.float64) tensor(0.1301, dtype=torch.float64)
tensor(0.4121, dtype=torch.float64) tensor(0.1261, dtype=torch.float64)
tensor(0.4008, dtype=torch.float64) tensor(0.1169, dtype=torch.float64)
tensor(0.3711, dtype=torch.float64) tensor(0.1032, dtype=torch.float64)
tensor(0.3262, dtype=torch.float64) tensor(0.0861, dtype=torch.float64)
tensor(0.2728, dtype=torch.float64) tensor(0.0667, dtype=torch.float64)
tensor(0.2177, dtype=torch.float64) tensor(0.0458, dtype=torch.float64)
tensor(0.1600, dtype=torch.float64) tensor(0.0233, dtype=torch.float64)
tensor(0.0867, dtype=torch.float64) tensor(0.0035, dtype=torch.float64)
tensor(0.0227, dtype=torch.float64) tensor(0.0224, dtype=torch.float64)
tensor(0.0708, dtype=torch.float64) tensor(0.0378, dtype=torch.float64)
tensor(0.1198, dtype=torch.float64) tensor(0.0465, dtype=torch.float64)
tensor(0.1470, dtype=torch.float64) tensor(0.0496, dtype=torch.float64)
tensor(0.1569, dtype=torch.float64) tensor(0.0479, dtype=torch.float64)
tensor(0.1535, dtype=torch.float64) tensor(0.0422, dtype=torch.float64)
tensor(0.1370, dtype=torch.float64) tensor(0.0329, dtype=torch.float64)
tensor(0.1064, dtype=torch.float64) tensor(0.0205, dtype=torch.float64)
tensor(0.0647, dtype=torch.float64) tensor(0.0052, dtype=torch.float64)
tensor(0.0201, dtype=torch.float64) tensor(0.0155, dtype=torch.float64)
tensor(0.0578, dtype=torch.float64) tensor(0.0268, dtype=torch.float64)
tensor(0.0859, dtype=torch.float64) tensor(0.0327, dtype=torch.float64)
tensor(0.1052, dtype=torch.float64) tensor(0.0344, dtype=torch.float64)
tensor(0.1163, dtype=torch.float64) tensor(0.0319, dtype=torch.float64)
tensor(0.1068, dtype=torch.float64) tensor(0.0257, dtype=torch.float64)
tensor(0.0814, dtype=torch.float64) tensor(0.0165, dtype=torch.float64)
tensor(0.0557, dtype=torch.float64) tensor(0.0028, dtype=torch.float64)
tensor(0.0120, dtype=torch.float64) tensor(0.0180, dtype=torch.float64)
tensor(0.0825, dtype=torch.float64) tensor(0.0273, dtype=torch.float64)
tensor(0.1034, dtype=torch.float64) tensor(0.0305, dtype=torch.float64)
tensor(0.0967, dtype=torch.float64) tensor(0.0315, dtype=torch.float64)
tensor(0.1101, dtype=torch.float64) tensor(0.0296, dtype=torch.float64)
tensor(0.1152, dtype=torch.float64) tensor(0.0230, dtype=torch.float64)
tensor(0.0848, dtype=torch.float64) tensor(0.0143, dtype=torch.float64)
tensor(0.0472, dtype=torch.float64) tensor(0.0074, dtype=torch.float64)
tensor(0.0454, dtype=torch.float64) tensor(0.0103, dtype=torch.float64)
tensor(0.0325, dtype=torch.float64) tensor(0.0179, dtype=torch.float64)
tensor(0.0641, dtype=torch.float64) tensor(0.0188, dtype=torch.float64)
tensor(0.0619, dtype=torch.float64) tensor(0.0155, dtype=torch.float64)
tensor(0.0520, dtype=torch.float64) tensor(0.0098, dtype=torch.float64)
tensor(0.0419, dtype=torch.float64) tensor(0.0017, dtype=torch.float64)
tensor(0.0103, dtype=torch.float64) tensor(0.0091, dtype=torch.float64)
tensor(0.0345, dtype=torch.float64) tensor(0.0123, dtype=torch.float64)
tensor(0.0390, dtype=torch.float64) tensor(0.0117, dtype=torch.float64)
tensor(0.0417, dtype=torch.float64) tensor(0.0071, dtype=torch.float64)
tensor(0.0244, dtype=torch.float64) tensor(0.0044, dtype=torch.float64)
tensor(0.0279, dtype=torch.float64) tensor(0.0079, dtype=torch.float64)
tensor(0.0253, dtype=torch.float64) tensor(0.0103, dtype=torch.float64)
tensor(0.0361, dtype=torch.float64) tensor(0.0074, dtype=torch.float64)
tensor(0.0241, dtype=torch.float64) tensor(0.0023, dtype=torch.float64)
tensor(0.0138, dtype=torch.float64) tensor(0.0086, dtype=torch.float64)
tensor(0.0475, dtype=torch.float64) tensor(0.0110, dtype=torch.float64)
tensor(0.0466, dtype=torch.float64) tensor(0.0118, dtype=torch.float64)
tensor(0.0417, dtype=torch.float64) tensor(0.0105, dtype=torch.float64)
tensor(0.0462, dtype=torch.float64) tensor(0.0033, dtype=torch.float64)
tensor(0.0105, dtype=torch.float64) tensor(0.0076, dtype=torch.float64)
tensor(0.0357, dtype=torch.float64) tensor(0.0115, dtype=torch.float64)
tensor(0.0363, dtype=torch.float64) tensor(0.0128, dtype=torch.float64)
tensor(0.0477, dtype=torch.float64) tensor(0.0097, dtype=torch.float64)
tensor(0.0349, dtype=torch.float64) tensor(0.0050, dtype=torch.float64)
tensor(0.0254, dtype=torch.float64) tensor(0.0042, dtype=torch.float64)
tensor(0.0137, dtype=torch.float64) tensor(0.0082, dtype=torch.float64)
tensor(0.0353, dtype=torch.float64) tensor(0.0063, dtype=torch.float64)
tensor(0.0205, dtype=torch.float64) tensor(0.0052, dtype=torch.float64)
tensor(0.0327, dtype=torch.float64) tensor(0.0035, dtype=torch.float64)
tensor(0.0119, dtype=torch.float64) tensor(0.0082, dtype=torch.float64)
[10]:
# Compare the single-step solution with the optimized knobs
# (the optimized knobs are mapped back from normalized coordinates)

print(solution)
optimized = inverse([ms], [(-10, 10)])
print(*optimized)
tensor([ 0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084,  0.7439, -1.2084], dtype=torch.float64)
tensor([ 0.7412, -1.2115,  0.7412, -1.2115,  0.7412, -1.2115,  0.7412, -1.2115], dtype=torch.float64)