Example-25: Group

[1]:
# This example illustrates another wrapper construction procedure: `group`
[2]:
# Import

import torch
torch.set_printoptions(linewidth=128)

from twiss import twiss
from twiss import propagate
from twiss import wolski_to_cs

from model.library.drift import Drift
from model.library.quadrupole import Quadrupole
from model.library.dipole import Dipole
from model.library.line import Line

from model.command.wrapper import group
[3]:
# Define a simple FODO-based lattice using nested lines

# Keyword options shared by every Line constructed below
LINE_OPTIONS = dict(propagate=True, dp=0.0, exact=False, output=False, matrix=False)

# Drift and bending magnet objects are reused in every cell
DR = Drift('DR', 0.75)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# Each cell gets its own pair of quadrupoles: QF_* (kn > 0) and QD_* (kn < 0)
QF_A, QD_A = Quadrupole('QF_A', 0.5, +0.20), Quadrupole('QD_A', 0.5, -0.19)
QF_B, QD_B = Quadrupole('QF_B', 0.5, +0.20), Quadrupole('QD_B', 0.5, -0.19)
QF_C, QD_C = Quadrupole('QF_C', 0.5, +0.20), Quadrupole('QD_C', 0.5, -0.19)
QF_D, QD_D = Quadrupole('QF_D', 0.5, +0.20), Quadrupole('QD_D', 0.5, -0.19)

# Cell layout: QF DR BM DR QD QD DR BM DR QF
FODO_A = Line('FODO_A', [QF_A, DR, BM, DR, QD_A, QD_A, DR, BM, DR, QF_A], **LINE_OPTIONS)
FODO_B = Line('FODO_B', [QF_B, DR, BM, DR, QD_B, QD_B, DR, BM, DR, QF_B], **LINE_OPTIONS)
FODO_C = Line('FODO_C', [QF_C, DR, BM, DR, QD_C, QD_C, DR, BM, DR, QF_C], **LINE_OPTIONS)
FODO_D = Line('FODO_D', [QF_D, DR, BM, DR, QD_D, QD_D, DR, BM, DR, QF_D], **LINE_OPTIONS)

# The ring is a line of lines
RING = Line('RING', [FODO_A, FODO_B, FODO_C, FODO_D], **LINE_OPTIONS)
[4]:
# The full ring or any of its sublines can be wrapped by element kind

fn, table, line = group(
    RING,                                   # -- source line
    'FODO_A',                               # -- start (name or position in source line sequence)
    'FODO_B',                               # -- end (name or position in source line sequence)
    ('kn', ['Quadrupole'], None, None),     # -- groups (key:str, kinds:list[str]|None, names:list[str]|None, clean:list[str]|None)
)

# Information about the deviation variables is returned in wrapper format

print(table, end='\n\n')

# The wrapped function fn is called with a state and a vector of deviation variables
# (one entry per name listed in the table)

(_, names, _), *_ = table
knobs = torch.zeros(len(names), dtype=torch.float64)
state = torch.zeros(4, dtype=torch.float64)
print(fn(state, knobs), end='\n\n')

# The constructed line is also returned

print(line, end='\n\n')
[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B'], 'kn')]

tensor([0., 0., 0., 0.], dtype=torch.float64)

Quadrupole(name="QF_A", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_A", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_A", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_A", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_B", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_B", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QD_B", length=0.5, kn=-0.189999999999999, ks=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Dipole(name="BM", length=3.5, angle=0.7853981633974493, e1=0.0, e2=0.0, kn=1e-15, ks=0.0, ms=0.0, mo=0.0, dp=0.0, exact=False, ns=1, order=0)
Drift(name="DR", length=0.75, dp=0.0, exact=False, ns=1, order=0)
Quadrupole(name="QF_B", length=0.5, kn=0.200000000000001, ks=0.0, dp=0.0, exact=False, ns=1, order=0)

[5]:
# By default, names are extracted from the created subline ({} -> no extra options);
# with root=True, names are extracted from the root line instead

for options in ({}, {'root': True}):
    _, table, _ = group(RING, 'FODO_A', 'FODO_B', ('kn', ['Quadrupole'], None, None), **options)
    print(table)
    print()
[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B'], 'kn')]

[(None, ['QF_A', 'QD_A', 'QF_B', 'QD_B', 'QF_C', 'QD_C', 'QF_D', 'QD_D'], 'kn')]

[6]:
# Set transport between observation points
#
#   0--A--1--B--2--C--3--D--4
#
# Each FODO cell is wrapped on its own. With root=True the deviation variables are taken
# from the root line (RING), so every wrapper accepts the same knob vector.

wrapped = [
    group(RING, segment, segment, ('kn', ['Quadrupole'], None, None), root=True)
    for segment in ('FODO_A', 'FODO_B', 'FODO_C', 'FODO_D')
]

# Keep only the wrapped functions (first element returned by group)
lines = [wrapper for wrapper, *_ in wrapped]
line01, line12, line23, line34 = lines

# `ring` feeds one knob vector to every wrapper, so their knob layouts must agree;
# derive the knob count from that layout instead of hard-coding it
tables = [table for _, table, *_ in wrapped]
assert all(table == tables[0] for table in tables), 'wrappers must share the same knob layout'
(_, names, _), *_ = tables[0]


def ring(state, knobs):
    """Transport `state` once around RING by chaining the per-cell wrappers in `lines`.

    state -- state vector passed to the first wrapper; each wrapper's output feeds the next
    knobs -- deviation variables shared by all wrappers (root-line layout)

    Returns the state after the last wrapper.
    """
    for line in lines:
        state = line(state, knobs)
    return state


state = torch.zeros(4, dtype=torch.float64)
knobs = torch.zeros(len(names), dtype=torch.float64)

print(ring(state, knobs))
tensor([0., 0., 0., 0.], dtype=torch.float64)
[7]:
# Compute tunes and corresponding derivatives with respect to deviation parameters

def fn(knobs):
    """Return the tunes of RING for the given deviation variables.

    The one-turn matrix is the Jacobian of `ring` with respect to its first argument,
    evaluated at the global `state`; only the first value returned by `twiss` is kept.
    """
    one_turn = torch.func.jacfwd(ring, argnums=0)(state, knobs)
    tunes, *_ = twiss(one_turn)
    return tunes

tunes = fn(knobs)
dtunes_dknobs = torch.func.jacrev(fn)(knobs)
print(tunes)
print(dtunes_dknobs)
tensor([0.6951, 0.7019], dtype=torch.float64)
tensor([[ 1.4567,  0.1055,  1.4567,  0.1055,  1.4567,  0.1055,  1.4567,  0.1055],
        [-0.5132, -1.6271, -0.5132, -1.6271, -0.5132, -1.6271, -0.5132, -1.6271]], dtype=torch.float64)
[8]:
# Compute beta functions at observation points and corresponding derivatives with respect to deviation parameters

def fn(knobs):
    """Return beta functions at the end of every segment in `lines` for the given deviation variables.

    The one-turn matrix of `ring` (Jacobian w.r.t. the global `state`) is passed to `twiss`;
    its last output (a Wolski matrix, presumably the periodic solution at point 0 — confirm)
    is propagated through each wrapper's transport matrix, and after every segment the
    beta values `bx`, `by` are extracted with `wolski_to_cs`.

    Returns
    -------
    (bxs, bys) : two tensors with one entry per segment in `lines`
    """
    m = torch.func.jacfwd(ring)(state, knobs)
    *_, w = twiss(m)

    bxs, bys = [], []
    for line in lines:
        # Transport matrix of this segment: Jacobian of the wrapper w.r.t. the state
        w = propagate(w, torch.func.jacrev(line)(state, knobs))
        _, bx, _, by = wolski_to_cs(w)
        bxs.append(bx)
        bys.append(by)

    return torch.stack(bxs), torch.stack(bys)

bx, by = fn(knobs)
dbxdk, dbydk = torch.func.jacrev(fn)(knobs)

# Print each beta vector followed by its Jacobian with respect to the knobs
for value, derivative in ((bx, dbxdk), (by, dbydk)):
    print(value)
    print(derivative)
    print()
tensor([18.6083, 18.6083, 18.6083, 18.6083], dtype=torch.float64)
tensor([[ 21.1373,  -1.5714,  21.1373,  -1.5714, 140.4889, -10.4446, 140.4889, -10.4446],
        [140.4889, -10.4446,  21.1373,  -1.5714,  21.1373,  -1.5714, 140.4889, -10.4446],
        [140.4889, -10.4446, 140.4889, -10.4446,  21.1373,  -1.5714,  21.1373,  -1.5714],
        [ 21.1373,  -1.5714, 140.4889, -10.4446, 140.4889, -10.4446,  21.1373,  -1.5714]], dtype=torch.float64)

tensor([6.3291, 6.3291, 6.3291, 6.3291], dtype=torch.float64)
tensor([[ 10.9592,  66.8187,  10.9592,  66.8187,  -5.0152, -30.5777,  -5.0152, -30.5777],
        [ -5.0152, -30.5777,  10.9592,  66.8187,  10.9592,  66.8187,  -5.0152, -30.5777],
        [ -5.0152, -30.5777,  -5.0152, -30.5777,  10.9592,  66.8187,  10.9592,  66.8187],
        [ 10.9592,  66.8187,  -5.0152, -30.5777,  -5.0152, -30.5777,  10.9592,  66.8187]], dtype=torch.float64)