Example-25: Group
[1]:
# In this example another wrapper construction procedure (group) is illustrated
[2]:
# Import
import torch
torch.set_printoptions(linewidth=128)
from twiss import twiss
from twiss import propagate
from twiss import wolski_to_cs
from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.dipole import Dipole
from model.library.line import Line
from model.command.wrapper import group
[3]:
# Define a simple FODO-based ring using nested lines.
# The drift DR and dipole BM objects are shared by every cell; each cell (A..D) gets its own
# focusing (QF_*) and defocusing (QD_*) quadrupole, all with identical settings.
DR = Drift('DR', 0.75)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# Quadrupole pairs, one per FODO cell: length 0.5, kn = +0.20 (QF) / -0.19 (QD)
QF_A, QD_A = Quadrupole('QF_A', 0.5, +0.20), Quadrupole('QD_A', 0.5, -0.19)
QF_B, QD_B = Quadrupole('QF_B', 0.5, +0.20), Quadrupole('QD_B', 0.5, -0.19)
QF_C, QD_C = Quadrupole('QF_C', 0.5, +0.20), Quadrupole('QD_C', 0.5, -0.19)
QF_D, QD_D = Quadrupole('QF_D', 0.5, +0.20), Quadrupole('QD_D', 0.5, -0.19)

# Keyword options shared by every Line below (the four cells and the ring itself)
LINE_OPTIONS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)


def make_fodo(name, qf, qd):
    """Build one FODO cell with layout QF DR BM DR QD QD DR BM DR QF."""
    return Line(name, [qf, DR, BM, DR, qd, qd, DR, BM, DR, qf], **LINE_OPTIONS)


FODO_A = make_fodo('FODO_A', QF_A, QD_A)
FODO_B = make_fodo('FODO_B', QF_B, QD_B)
FODO_C = make_fodo('FODO_C', QF_C, QD_C)
FODO_D = make_fodo('FODO_D', QF_D, QD_D)

# The ring is the sequence of the four cells
RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], **LINE_OPTIONS)
[4]:
# A full ring or any sub-line of it can be wrapped by element kind.
# group returns the wrapped function, a table describing its deviation variables,
# and the constructed sub-line.
fn, table, line = group(
    RING,                                  # source line
    'FODO_A',                              # start (name or position in the source line sequence)
    'FODO_B',                              # end (name or position in the source line sequence)
    ('kn', ['Quadrupole'], None, None),    # group: (key: str, kinds: list[str] | None, names: list[str] | None, clean: list[str] | None)
)

# Information about the deviation variables is returned in wrapper format
print(table, end='\n\n')

# The wrapped function fn is called with a 4-component state and the deviation variables
# (here: one zero deviation per name listed in the first table entry)
[(_, names, _), *_] = table
knobs = torch.zeros(len(names), dtype=torch.float64)
state = torch.zeros(4, dtype=torch.float64)
print(fn(state, knobs), end='\n\n')

# The constructed line is also returned
print(line, end='\n\n')
[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B'], 'kn')]
tensor([0., 0., 0., 0.], dtype=torch.float64)
Quadrupole(name="QF_A", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_A", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_A", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_A", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_B", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_B", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_B", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_B", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
[5]:
# By default variable names are extracted from the constructed sub-line (FODO_A..FODO_B only)
_, table, _ = group(RING, 'FODO_A', 'FODO_B',
                    ('kn', ['Quadrupole'], None, None))
print(table, end='\n\n')

# With the root flag, names are extracted from the root line instead (all ring quadrupoles)
_, table, _ = group(RING, 'FODO_A', 'FODO_B',
                    ('kn', ['Quadrupole'], None, None),
                    root=True)
print(table, end='\n\n')
[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B'], 'kn')]
[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B', 'QF_C', 'QD_C', 'QF_D', 'QD_D'], 'kn')]
[6]:
# Wrap the transport between consecutive observation points, one wrapper per FODO cell:
# 0--A--1--B--2--C--3--D--4
# With root=True the deviation variable names come from the root line (as in the previous
# cell), so all four wrappers accept the same knobs tensor.
line01, table01, _ = group(RING, 'FODO_A', 'FODO_A', ('kn', ['Quadrupole'], None, None), root=True)
line12, *_ = group(RING, 'FODO_B', 'FODO_B', ('kn', ['Quadrupole'], None, None), root=True)
line23, *_ = group(RING, 'FODO_C', 'FODO_C', ('kn', ['Quadrupole'], None, None), root=True)
line34, *_ = group(RING, 'FODO_D', 'FODO_D', ('kn', ['Quadrupole'], None, None), root=True)
lines = [
    line01,
    line12,
    line23,
    line34,
]


def ring(state, knobs):
    """Transport state through every wrapped cell in order (point 0 -> point 4).

    Parameters
    ----------
    state : tensor with the initial state (4 components in this notebook)
    knobs : tensor of deviation variables, shared by all wrappers

    Returns
    -------
    The state after the last cell.
    """
    for segment in lines:
        state = segment(state, knobs)
    return state


# Size the knob vector from the wrapper table instead of hard-coding the quadrupole count
(_, knob_names, _), *_ = table01
state = torch.zeros(4, dtype=torch.float64)
knobs = torch.zeros(len(knob_names), dtype=torch.float64)
print(ring(state, knobs))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Tunes and their derivatives with respect to the deviation parameters
def fn(knobs):
    """Return the tunes (first output of twiss) of the linearized one-turn map.

    The one-turn matrix is the Jacobian of ring with respect to the notebook-level state.
    """
    one_turn = torch.func.jacfwd(ring)(state, knobs)
    tunes, *_ = twiss(one_turn)
    return tunes


print(fn(knobs))
print(torch.func.jacrev(fn)(knobs))
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([[ 1.4567, 0.1055, 1.4567, 0.1055, 1.4567, 0.1055, 1.4567, 0.1055],
[-0.5132, -1.6271, -0.5132, -1.6271, -0.5132, -1.6271, -0.5132, -1.6271]], dtype=torch.float64)
[8]:
# Beta functions at the observation points and their derivatives with respect to the deviation parameters
def fn(knobs):
    """Return (bx, by): stacked beta functions at observation points 1, 2, 3, 4.

    The last output of twiss for the one-turn matrix (the Jacobian of ring) is taken as the
    starting point (point 0). It is propagated through each wrapped cell in turn with that cell's
    transport matrix and converted with wolski_to_cs after every cell; the beta values are its
    2nd and 4th outputs. Since the start comes from the one-turn map, the entry for point 4 is
    expected to coincide with point 0.
    """
    one_turn = torch.func.jacfwd(ring)(state, knobs)
    *_, w = twiss(one_turn)
    bxs, bys = [], []
    for segment in lines:
        # Linear transport matrix of this cell at the reference state
        w = propagate(w, torch.func.jacrev(segment)(state, knobs))
        _, bx, _, by = wolski_to_cs(w)
        bxs.append(bx)
        bys.append(by)
    return torch.stack(bxs), torch.stack(bys)


bx, by = fn(knobs)
dbxdk, dbydk = torch.func.jacrev(fn)(knobs)
print(bx)
print(dbxdk)
print()
print(by)
print(dbydk)
print()
tensor([18.6083, 18.6083, 18.6083, 18.6083], dtype=torch.float64)
tensor([[ 21.1373, -1.5714, 21.1373, -1.5714, 140.4889, -10.4446, 140.4889, -10.4446],
[140.4889, -10.4446, 21.1373, -1.5714, 21.1373, -1.5714, 140.4889, -10.4446],
[140.4889, -10.4446, 140.4889, -10.4446, 21.1373, -1.5714, 21.1373, -1.5714],
[ 21.1373, -1.5714, 140.4889, -10.4446, 140.4889, -10.4446, 21.1373, -1.5714]], dtype=torch.float64)
tensor([6.3291, 6.3291, 6.3291, 6.3291], dtype=torch.float64)
tensor([[ 10.9592, 66.8187, 10.9592, 66.8187, -5.0152, -30.5777, -5.0152, -30.5777],
[ -5.0152, -30.5777, 10.9592, 66.8187, 10.9592, 66.8187, -5.0152, -30.5777],
[ -5.0152, -30.5777, -5.0152, -30.5777, 10.9592, 66.8187, 10.9592, 66.8187],
[ 10.9592, 66.8187, -5.0152, -30.5777, -5.0152, -30.5777, 10.9592, 66.8187]], dtype=torch.float64)