# Example-24: Wrapper

# In this example, the construction of parametric call wrappers is illustrated
# For elements, all deviation parameters are passed as dictionaries
# Wrapped elements are invoked using positional arguments
# Import

from pprint import pprint

import torch

from twiss import twiss

from ndmap.pfp import parametric_fixed_point
from ndmap.evaluate import evaluate
from ndmap.signature import chop

from model.library.drift import Drift
from model.library.multipole import Multipole
from model.library.dipole import Dipole
from model.library.line import Line

from model.command.wrapper import wrapper
# Define simple FODO based lattice using nested lines

QF = Multipole('QF', 0.5, +0.20)
QD = Multipole('QD', 0.5, -0.19)
DR = Drift('DR', 0.75)
BM = Dipole('BM', 3.50, torch.pi/4.0)

# One FODO cell: QF - DR - BM - DR - QD - DR - BM - DR
# Two bends of pi/4 per cell, so the four cells below close the ring (2*pi total bend)
# NOTE(review): FODO was referenced but never defined; this element order is
# reconstructed from the printed deviation table (QF, DR, BM, QD) -- confirm
FODO = [QF, DR, BM, DR, QD, DR, BM, DR]

FODO_A = Line('FODO_A', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)
FODO_B = Line('FODO_B', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)
FODO_C = Line('FODO_C', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)
FODO_D = Line('FODO_D', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)

LINE_AB = Line('LINE_AB', [FODO_A, FODO_B], propagate=True, dp=0.0, exact=False, output=False, matrix=False)
LINE_CD = Line('LINE_CD', [FODO_C, FODO_D], propagate=True, dp=0.0, exact=False, output=False, matrix=False)

RING = Line('RING', [LINE_AB, LINE_CD], propagate=True, dp=0.0, exact=False, output=False, matrix=False)
# Deviation variables are passed to elements/lines as dictionaries
# In order to compute derivatives with respect to a deviation variable,
# a tensor should be bound to the corresponding leaf deviation dictionary value

pprint(RING.data(alignment=False), sort_dicts=False)
{'LINE_AB': {'FODO_A': {'QF': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'DR': {'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'BM': {'dw': tensor(0., dtype=torch.float64),
                               'e1': tensor(0., dtype=torch.float64),
                               'e2': tensor(0., dtype=torch.float64),
                               'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'QD': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)}},
             'FODO_B': {'QF': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'DR': {'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'BM': {'dw': tensor(0., dtype=torch.float64),
                               'e1': tensor(0., dtype=torch.float64),
                               'e2': tensor(0., dtype=torch.float64),
                               'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'QD': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)}}},
 'LINE_CD': {'FODO_C': {'QF': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'DR': {'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'BM': {'dw': tensor(0., dtype=torch.float64),
                               'e1': tensor(0., dtype=torch.float64),
                               'e2': tensor(0., dtype=torch.float64),
                               'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'QD': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)}},
             'FODO_D': {'QF': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'DR': {'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'BM': {'dw': tensor(0., dtype=torch.float64),
                               'e1': tensor(0., dtype=torch.float64),
                               'e2': tensor(0., dtype=torch.float64),
                               'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)},
                        'QD': {'kn': tensor(0., dtype=torch.float64),
                               'ks': tensor(0., dtype=torch.float64),
                               'ms': tensor(0., dtype=torch.float64),
                               'mo': tensor(0., dtype=torch.float64),
                               'dp': tensor(0., dtype=torch.float64),
                               'dl': tensor(0., dtype=torch.float64)}}}}
# Compute parametric closed orbit (first order with respect to momentum deviation)

# Without wrapping, all momentum deviation occurrences should be bound to a single tensor
# Hence, the deviation table should be traversed recursively down to all leaves

def scan(data, name, target):
    """Bind ``target`` to every leaf entry named ``name``, in place.

    Nested dictionaries are descended into recursively; any non-dictionary
    value stored under the key ``name`` is replaced by ``target``.
    Returns None.
    """
    for key in data:
        entry = data[key]
        if isinstance(entry, dict):
            scan(entry, name, target)
            continue
        if key == name:
            data[key] = target

# Set ring function

def ring(state, dp):
    """Track ``state`` through RING with one shared momentum deviation.

    The first element of ``dp`` is bound to every leaf ``'dp'`` entry of a
    fresh RING deviation table, so a derivative with respect to ``dp`` sees
    all occurrences at once.
    """
    delta, *_ = dp
    table = RING.data()
    scan(table, 'dp', delta)
    return RING(state, data=table)

# Set deviations

fp = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Compute pfp (first order with respect to dp) and display it

pfp, *_ = parametric_fixed_point((1, ), fp, [dp], ring)
pprint(pfp)
# Output (the middle of the second tensor appears truncated):
# [tensor([0., 0., 0., 0.], dtype=torch.float64),
#          [0.0000]], dtype=torch.float64)]
# Using wrapper, we can define the above ring function as follows
# (None, None, 'dp') binds one tensor to every 'dp' leaf, as scan does manually

fn = wrapper(RING, (None, None, 'dp'))

# Set deviations

fp = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Compute pfp (first order with respect to dp) and display it

pfp, *_ = parametric_fixed_point((1, ), fp, [dp], fn)
pprint(pfp)
# Output (the middle of the second tensor appears truncated):
# [tensor([0., 0., 0., 0.], dtype=torch.float64),
#          [0.0000]], dtype=torch.float64)]
# Compute chromaticity (without wrapping)

def scan(data, name, target):
    """Bind ``target`` to every leaf entry named ``name``, in place.

    Nested dictionaries are descended into recursively; any non-dictionary
    value stored under the key ``name`` is replaced by ``target``.
    Returns None.
    """
    for key in data:
        entry = data[key]
        if isinstance(entry, dict):
            scan(entry, name, target)
            continue
        if key == name:
            data[key] = target

# Set ring function

def ring(state, dp):
    """Track ``state`` through RING with one shared momentum deviation.

    The first element of ``dp`` is bound to every leaf ``'dp'`` entry of a
    fresh RING deviation table before tracking.
    """
    delta, *_ = dp
    table = RING.data()
    scan(table, 'dp', delta)
    return RING(state, data=table)

# Set ring function around pfp

def pfp_ring(state, dp):
    """One-turn map written in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` is evaluated once at ``dp``; the state is
    shifted onto it, tracked with ``ring``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp) - shift

# Set tune function

def tune(dp):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, then chromaticity (derivative of the tunes with respect to dp)
print(tune(dp))
print(torch.func.jacrev(tune)(dp).squeeze())

# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# Compute chromaticity (with wrapping)

# Set ring function
# The wrapper binds one tensor to every 'dp' leaf and is called as fn(state, dp)

fn = wrapper(RING, (None, None, 'dp'))

# Set ring function around pfp

def pfp_ring(state, dp):
    """One-turn map (wrapped ``fn``) in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` is evaluated once at ``dp``; the state is
    shifted onto it, tracked with ``fn``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return fn(state + shift, dp) - shift

# Set tune function

def tune(dp):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, then chromaticity (derivative of the tunes with respect to dp)
print(tune(dp))
print(torch.func.jacrev(tune)(dp).squeeze())

# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# Compute chromaticity derivative with respect to sextupole amplitudes (without wrapping)

def scan(data, name, target):
    """Bind ``target`` to every leaf entry named ``name``, in place.

    Nested dictionaries are descended into recursively; any non-dictionary
    value stored under the key ``name`` is replaced by ``target``.
    Returns None.
    """
    for key in data:
        entry = data[key]
        if isinstance(entry, dict):
            scan(entry, name, target)
            continue
        if key == name:
            data[key] = target

def ring(state, dp, dms):
    """Track ``state`` through RING with shared 'dp' and sextupole deviations.

    The first element of ``dp`` is bound to every 'dp' leaf. The first two
    entries of ``dms`` are bound to the 'ms' leaf of every QF and of every
    QD, respectively, in all four FODO cells.
    """
    delta, *_ = dp
    focusing, defocusing, *_ = dms
    table = RING.data()
    scan(table, 'dp', delta)
    cells = {'LINE_AB': ('FODO_A', 'FODO_B'), 'LINE_CD': ('FODO_C', 'FODO_D')}
    for line, names in cells.items():
        for cell in names:
            table[line][cell]['QF']['ms'] = focusing
            table[line][cell]['QD']['ms'] = defocusing
    return RING(state, data=table)

def pfp_ring(state, dp, dms):
    """One-turn map in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` (a function of ``dp`` only) is evaluated
    once; the state is shifted onto it, tracked with ``ring``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp, dms) - shift

def tune(dp, dms):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp, dms)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

def chromaticity(dms):
    """Return d(tunes)/d(dp) for the given sextupole deviations ``dms``.

    ``dp`` is read from module scope at call time.
    """
    d_tune_d_dp = torch.func.jacrev(tune, argnums=0)
    return d_tune_d_dp(dp, dms)

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dms = torch.tensor([0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, chromaticity, and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# tensor([[ 25.8500,   1.0470],
#         [ -9.0271, -16.4821]], dtype=torch.float64)
# Compute chromaticity derivative with respect to sextupole amplitudes (with wrapping)

# No scan helper is needed in this variant: wrapper binds the tensors to the
# requested deviation leaves itself (the identical scan defined for the
# non-wrapped examples remains available)

# Bind one tensor to every 'dp' leaf and one to the 'ms' leaves of QF and QD in
# all lines; the wrapped ring is called as ring(state, dp, dms)
# (presumably dms holds one entry per listed name, QF then QD, as in the manual version)
ring = wrapper(RING, (None, None, 'dp'), (None, ['QF', 'QD'], 'ms'))

def pfp_ring(state, dp, dms):
    """One-turn map in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` (a function of ``dp`` only) is evaluated
    once; the state is shifted onto it, tracked with ``ring``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp, dms) - shift

def tune(dp, dms):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp, dms)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

def chromaticity(dms):
    """Return d(tunes)/d(dp) for the given sextupole deviations ``dms``.

    ``dp`` is read from module scope at call time.
    """
    d_tune_d_dp = torch.func.jacrev(tune, argnums=0)
    return d_tune_d_dp(dp, dms)

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dms = torch.tensor([0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, chromaticity, and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# tensor([[ 25.8500,   1.0470],
#         [ -9.0271, -16.4821]], dtype=torch.float64)
# The above examples demonstrate how to bind tensors to all leaves or to given elements (in all lines)

# Binding specification forms:
# (None,           None,            parameter:str) -- bind tensor to all leaf parameters
# (None,           names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements
# (path:list[str], names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements
#                                                   within the given path (path to a specific line)
# Bind QF and QD in all sublines of a given line (1/2 of the sextupoles)

ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB'], ['QF', 'QD'], 'ms'))

def pfp_ring(state, dp, dms):
    """One-turn map in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` (a function of ``dp`` only) is evaluated
    once; the state is shifted onto it, tracked with ``ring``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp, dms) - shift

def tune(dp, dms):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp, dms)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

def chromaticity(dms):
    """Return d(tunes)/d(dp) for the given sextupole deviations ``dms``.

    ``dp`` is read from module scope at call time.
    """
    d_tune_d_dp = torch.func.jacrev(tune, argnums=0)
    return d_tune_d_dp(dp, dms)

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dms = torch.tensor([0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, chromaticity, and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
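# NOTE(review): only the LINE_AB sextupoles are bound here, so the derivative
# should match the LINE_AB group result printed at the end of the example
# ([[12.9250, 0.5235], [-4.5135, -8.2411]]), not the all-sextupole values
# below; this output looks stale -- re-run to confirm.
# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# tensor([[ 25.8500,   1.0470],
#         [ -9.0271, -16.4821]], dtype=torch.float64)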
# Bind QF and QD in a given leaf line (1/4 of the sextupoles)
# The path ['LINE_AB', 'FODO_A'] selects a single FODO cell

ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB', 'FODO_A'], ['QF', 'QD'], 'ms'))

def pfp_ring(state, dp, dms):
    """One-turn map in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` (a function of ``dp`` only) is evaluated
    once; the state is shifted onto it, tracked with ``ring``, and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp, dms) - shift

def tune(dp, dms):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp, dms)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

def chromaticity(dms):
    """Return d(tunes)/d(dp) for the given sextupole deviations ``dms``.

    ``dp`` is read from module scope at call time.
    """
    d_tune_d_dp = torch.func.jacrev(tune, argnums=0)
    return d_tune_d_dp(dp, dms)

state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)
dms = torch.tensor([0.0, 0.0], dtype=torch.float64)
dp = torch.tensor([0.0], dtype=torch.float64)

# Tunes, chromaticity, and chromaticity derivative with respect to dms
print(tune(dp, dms))
print(chromaticity(dms).squeeze())
print(torch.func.jacrev(chromaticity)(dms).squeeze())
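# NOTE(review): only the FODO_A sextupoles are bound here, yet the derivative
# below equals the all-sextupole result; this output looks stale -- re-run to confirm.
# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# tensor([[ 25.8500,   1.0470],
#         [ -9.0271, -16.4821]], dtype=torch.float64)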
# Several sextupole groups
# Each binding after the first adds one positional argument:
# ring(state, dp, dms_ab, dms_cd)

ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB'], ['QF', 'QD'], 'ms'), (['LINE_CD'], ['QF', 'QD'], 'ms'))

def pfp_ring(state, dp, dms_ab, dms_cd):
    """One-turn map in deviations from the parametric fixed point.

    The fixed-point series ``pfp`` (a function of ``dp`` only) is evaluated
    once; the state is shifted onto it, tracked with ``ring`` (with both
    sextupole groups), and shifted back.
    """
    # Evaluate the series once instead of twice (same value both times)
    shift = evaluate(pfp, [dp])
    return ring(state + shift, dp, dms_ab, dms_cd) - shift

def tune(dp, dms_ab, dms_cd):
    """Return the tunes of the one-turn map around the parametric fixed point.

    The Jacobian of ``pfp_ring`` with respect to the phase-space state (the
    module-level ``state``) is passed to ``twiss``; its first output is returned.
    """
    jacobian = torch.func.jacrev(pfp_ring, argnums=0)(state, dp, dms_ab, dms_cd)
    result = twiss(jacobian)
    tunes, *_ = result
    return tunes

def chromaticity(dms_ab, dms_cd):
    """Return d(tunes)/d(dp) for the two sextupole-group deviations.

    ``dp`` is read from module scope at call time.
    """
    d_tune_d_dp = torch.func.jacrev(tune, argnums=0)
    return d_tune_d_dp(dp, dms_ab, dms_cd)

# Zero orbit, zero sextupole deviations for both groups, zero momentum deviation
state = torch.zeros(4, dtype=torch.float64)
dms_ab = torch.zeros(2, dtype=torch.float64)
dms_cd = torch.zeros(2, dtype=torch.float64)
dp = torch.zeros(1, dtype=torch.float64)

# Tunes and chromaticity
print(tune(dp, dms_ab, dms_cd))
print(chromaticity(dms_ab, dms_cd).squeeze())
# Chromaticity derivatives with respect to each group separately
print(torch.func.jacrev(chromaticity, argnums=0)(dms_ab, dms_cd).squeeze())
print(torch.func.jacrev(chromaticity, argnums=1)(dms_ab, dms_cd).squeeze())

def fn(dms):
    """Chromaticity as a function of both group deviations stacked together.

    ``dms`` must unpack into exactly two parts (e.g. ``torch.stack([dms_ab,
    dms_cd])``), which are forwarded to ``chromaticity`` in that order.
    """
    first, second = dms
    return chromaticity(first, second)

# Derivative with respect to both groups at once (groups stacked along dim 0)
print(torch.func.jacrev(fn)(torch.stack([dms_ab, dms_cd])).squeeze())
# tensor([0.6951, 0.7019], dtype=torch.float64)
# tensor([-2.0649, -0.8260], dtype=torch.float64)
# tensor([[12.9250,  0.5235],
#         [-4.5135, -8.2411]], dtype=torch.float64)
# tensor([[12.9250,  0.5235],
#         [-4.5135, -8.2411]], dtype=torch.float64)
# tensor([[[12.9250,  0.5235],
#          [12.9250,  0.5235]],
#
#         [[-4.5135, -8.2411],
#          [-4.5135, -8.2411]]], dtype=torch.float64)