ELETTRA-38: ID linear optics distortion (analytical & single uncoupled perturbation)

[1]:
# In this example, expressions for various linear optics observables are given for a ring with a single uncoupled perturbation
# The ring starts with the perturbation, followed by the rest of the lattice
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True

from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix
from model.library.drift import Drift

from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling
[3]:
# Set data type and device
# Kept both as module-level names and as Element class attributes
# (presumably the defaults used by elements built afterwards).

dtype = torch.float64
device = torch.device('cpu')
Element.dtype, Element.device = dtype, device
[4]:
# Load lattice (ELEGANT table); the table may contain repeated elements

path = Path('elettra.lte')  # lattice file, relative to the working directory
data = load_lattice(path)
[5]:
# Build and setup lattice from the loaded ELEGANT table

ring:Line = build('RING', 'ELEGANT', data)

# Inline all sublines into a single flat sequence
ring.flatten()

# Drop every Marker whose name does NOT start with MLL_ (MLL_* mark long straight section centers)
ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles

def factory(element:Element) -> Quadrupole:
    """Build a Quadrupole from the serialized parameters of `element`.

    The `ms` entry (presumably the sextupole strength, which Quadrupole does not
    accept) is dropped before construction. A copy of the serialized table is
    used so that popping the key can never modify state owned by `element`.
    """
    table = dict(element.serialize)
    table.pop('ms', None)
    return Quadrupole(**table)

ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles: every Dipole gets its `linear` flag switched on

def apply(element:Element) -> None:
    """Set `element.linear = True` in place; returns nothing."""
    setattr(element, 'linear', True)

ring.apply(apply, kinds=['Dipole'])

# Merge drifts
ring.merge()

# Use the MLL_S01 marker (a long straight section center) as the lattice start
ring.start = "MLL_S01"

# Describe: element counts per kind
ring.describe
[5]:
{'Marker': 12, 'Drift': 708, 'BPM': 168, 'Quadrupole': 360, 'Dipole': 156}
[6]:
# Define ID
#
# The ID transport matrix is exp(S A), where A is symmetric (Aij = Aji).
# A is first built with only main-diagonal entries. In the general uncoupled case A is
# block diagonal (one 2x2 block per plane). To obtain such general blocks, the diagonal
# A is conjugated with T, which has the form of a drift matrix in each plane, i.e. the
# ID is shifted along the longitudinal axis.

ca, cb, cc, cd = -0.034441907232402175, -0.04458009513208418, 0.056279356423643276, 0.08037110220505986
A = torch.diag(torch.tensor([ca, cb, cc, cd], dtype=dtype))
print(A)

# Boolean mask of the upper triangle (diagonal included) of a 4x4 matrix
mask = torch.triu(torch.ones_like(A, dtype=torch.bool))

# Identity with 0.2 in the (0, 1) and (2, 3) slots
T = torch.eye(4, dtype=dtype)
T[0, 1] = 0.2
T[2, 3] = 0.2

A = T.T @ A @ T
print(A)

# Upper-triangle entries in row-major order:
# (0,0) (0,1) (0,2) (0,3) (1,1) (1,2) (1,3) (2,2) (2,3) (3,3)
# -> (Ax, Bx, Cx) = (A00, A01, A11) and (Ay, By, Cy) = (A22, A23, A33); complex for the formulas below
Ax, Bx, _, _, Cx, _, _, Ay, By, Cy = A[mask].to(torch.complex128)

ID = Matrix('ID', length=0.0, A=(A[mask]).tolist())
tensor([[-0.0344,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0446,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0563,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0804]], dtype=torch.float64)
tensor([[-0.0344, -0.0069,  0.0000,  0.0000],
        [-0.0069, -0.0460,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0563,  0.0113],
        [ 0.0000,  0.0000,  0.0113,  0.0826]], dtype=torch.float64)
[7]:
# Fractional tunes of the unperturbed ring

(nux, nuy) = tune(ring, [], limit=1, matched=True)
[8]:
# Twiss parameters (alpha, beta per plane) of the unperturbed ring at every location

ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T

# Values at the first location, i.e. at the lattice start where the ID will be placed
(axi, *_), (bxi, *_), (ayi, *_), (byi, *_) = ax, bx, ay, by
[9]:
# Phase advances of the unperturbed ring, measured from the lattice start

mux, muy = advance(ring, [], alignment=False, matched=True).T

# Accumulate the per-element advances, wrap by the largest cumulative value and shift
# right by one so the first entry is zero (assumes non-negative advances, so max is the total)
mux, muy = mux.cumsum(-1), muy.cumsum(-1)
mux, muy = (mux % mux.max()).roll(1), (muy % muy.max()).roll(1)
[10]:
# Replace marker by ID
# The MLL_S01 marker (the lattice start) is swapped for the ID matrix element

ring.replace('MLL_S01', ID)

# Describe: a Matrix element appears and the Marker count drops by one
ring.describe
[10]:
{'Matrix': 1,
 'Drift': 708,
 'BPM': 168,
 'Quadrupole': 360,
 'Dipole': 156,
 'Marker': 11}
[11]:
# Fractional tunes of the ring with the ID installed

(nux_id, nuy_id) = tune(ring, [], limit=1, matched=True)
[12]:
# Twiss parameters of the ring with the ID installed, at every location

ax_id, bx_id, ay_id, by_id = twiss(ring, [], matched=True, advance=True, full=False).T

# Perturbed values at the first location (the ID position)
(axf, *_), (bxf, *_), (ayf, *_), (byf, *_) = ax_id, bx_id, ay_id, by_id
[13]:
# Phase advances of the ring with the ID installed, measured from the lattice start

mux_id, muy_id = advance(ring, [], alignment=False, matched=True).T

# Cumulative phase, wrapped by its maximum and shifted right by one (first entry becomes zero)
mux_id, muy_id = mux_id.cumsum(-1), muy_id.cumsum(-1)
mux_id, muy_id = (mux_id % mux_id.max()).roll(1), (muy_id % muy_id.max()).roll(1)
[14]:
# Tune shifts (unperturbed minus perturbed), horizontal then vertical

print(nux - nux_id, nuy - nuy_id, sep='\n')
tensor(0.0257, dtype=torch.float64)
tensor(-0.0113, dtype=torch.float64)
[15]:
# Tune shifts (exact)

def dnux(Ax, Bx, Cx, ax, bx, nux):
    """Exact horizontal tune shift nux - nux_id caused by a single uncoupled ID.

    Ax, Bx, Cx -- horizontal 2x2 block [[Ax, Bx], [Bx, Cx]] of the ID matrix A; pass them
                  as complex tensors so sqrt(Bx**2 - Ax*Cx) is defined for either sign
    ax, bx     -- unperturbed Twiss parameters at the ID location
    nux        -- unperturbed fractional tune

    Returns a (complex) tensor; take `.real`. The perturbed phase comes from arccos
    (range [0, pi]), so the result is only meaningful for fractional tunes in (0, 1/2).
    """
    mu = 2*nux*torch.pi
    w = torch.sqrt(Bx**2 - Ax*Cx)
    # sinh(w)/w -> 1 as w -> 0 (e.g. Bx = Cx = 0); avoid the 0/0 that produced nan
    shc = torch.where(w == 0, torch.ones_like(w), torch.sinh(w)/w)
    g = Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx
    # cosine of the perturbed phase advance (half trace of the perturbed one-turn matrix)
    cos_new = torch.cos(mu)*torch.cosh(w) - g*torch.sin(mu)*shc/(2*bx)
    return nux - torch.arccos(cos_new)/(2*torch.pi)

def dnuy(Ay, By, Cy, ay, by, nuy):
    """Exact vertical tune shift nuy - nuy_id caused by a single uncoupled ID.

    Ay, By, Cy -- vertical 2x2 block [[Ay, By], [By, Cy]] of the ID matrix A; pass them
                  as complex tensors so sqrt(By**2 - Ay*Cy) is defined for either sign
    ay, by     -- unperturbed Twiss parameters at the ID location
    nuy        -- unperturbed fractional tune

    Returns a (complex) tensor; take `.real`. The perturbed phase comes from arccos
    (range [0, pi]), so the result is only meaningful for fractional tunes in (0, 1/2).
    """
    mu = 2*nuy*torch.pi
    w = torch.sqrt(By**2 - Ay*Cy)
    # sinh(w)/w -> 1 as w -> 0 (e.g. By = Cy = 0); avoid the 0/0 that produced nan
    shc = torch.where(w == 0, torch.ones_like(w), torch.sinh(w)/w)
    g = Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy
    # cosine of the perturbed phase advance (half trace of the perturbed one-turn matrix)
    cos_new = torch.cos(mu)*torch.cosh(w) - g*torch.sin(mu)*shc/(2*by)
    return nuy - torch.arccos(cos_new)/(2*torch.pi)

# Model tune shift vs exact analytical expression, per plane (vertical uses dnuy)
print((nux - nux_id))
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()

print((nuy - nuy_id))
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0257, dtype=torch.float64)
tensor(0.0257, dtype=torch.float64)

tensor(-0.0113, dtype=torch.float64)
tensor(-0.0113, dtype=torch.float64)

[16]:
# Tune shifts (approximate)

def dnux(Ax, Bx, Cx, ax, bx, nux):
    """Approximate horizontal tune shift nux - nux_id, up to second order in (Ax, Bx, Cx).

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A
    ax, bx     -- unperturbed Twiss parameters at the ID location
    nux        -- unperturbed fractional tune
    """
    cx = (1 + ax**2)/bx
    phase = 2*nux*torch.pi
    cot = torch.cos(phase)/torch.sin(phase)
    # first-order term (linear in the ID coefficients)
    linear = (Ax*bx - 2*ax*Bx + cx*Cx)/(4*torch.pi)
    # second-order term (quadratic in the ID coefficients), weighted by cot(2*pi*nux)
    quadratic = Ax**2*bx**2 - 4*ax*Ax*bx*Bx + 4*bx*Bx**2*cx + 2*Ax*(-2 + bx*cx)*Cx + cx*Cx*(-4*ax*Bx + cx*Cx)
    return quadratic*cot/(16*torch.pi) - linear

def dnuy(Ay, By, Cy, ay, by, nuy):
    """Approximate vertical tune shift nuy - nuy_id, up to second order in (Ay, By, Cy).

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A
    ay, by     -- unperturbed Twiss parameters at the ID location
    nuy        -- unperturbed fractional tune
    """
    cy = (1 + ay**2)/by
    phase = 2*nuy*torch.pi
    cot = torch.cos(phase)/torch.sin(phase)
    # first-order term (linear in the ID coefficients)
    linear = (Ay*by - 2*ay*By + cy*Cy)/(4*torch.pi)
    # second-order term (quadratic in the ID coefficients), weighted by cot(2*pi*nuy)
    quadratic = Ay**2*by**2 - 4*ay*Ay*by*By + 4*by*By**2*cy + 2*Ay*(-2 + by*cy)*Cy + cy*Cy*(-4*ay*By + cy*Cy)
    return quadratic*cot/(16*torch.pi) - linear

# Model tune shift vs approximate analytical expression, per plane (vertical uses dnuy)
print((nux - nux_id))
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()

print((nuy - nuy_id))
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0257, dtype=torch.float64)
tensor(0.0255, dtype=torch.float64)

tensor(-0.0113, dtype=torch.float64)
tensor(-0.0113, dtype=torch.float64)

[17]:
# Twiss at the observation point (exact)

def csx(Ax, Bx, Cx, ax, bx, nux):
    """Exact horizontal Twiss (alpha, beta) at the ID location with the ID installed.

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A (complex tensors)
    ax, bx     -- unperturbed Twiss parameters at the ID location
    nux        -- unperturbed fractional tune

    Returns torch.stack([alpha, beta]) (complex; take `.real`). The sine of the perturbed
    phase advance is taken as +sqrt(1 - cos**2), so the perturbed fractional tune is
    assumed to lie in (0, 1/2).
    """
    mu = 2*nux*torch.pi
    cos_mu, sin_mu = torch.cos(mu), torch.sin(mu)
    w = torch.sqrt(Bx**2 - Ax*Cx)
    ch = torch.cosh(w)
    # sinh(w)/w -> 1 as w -> 0 (e.g. Bx = Cx = 0); avoid the 0/0 that produced nan
    shc = torch.where(w == 0, torch.ones_like(w), torch.sinh(w)/w)
    # cosine and sine of the perturbed phase advance
    g = Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx
    cos_new = cos_mu*ch - g*sin_mu*shc/(2*bx)
    sin_new = torch.sqrt(1 - cos_new**2)
    hax = (2*ax*ch*sin_mu + (2*bx*Bx*cos_mu + (-(Ax*bx**2) + Cx + ax**2*Cx)*sin_mu)*shc/bx)/(2*sin_new)
    hbx = (bx*ch*sin_mu + (Cx*cos_mu + (-(bx*Bx) + ax*Cx)*sin_mu)*shc)/sin_new
    return torch.stack([hax, hbx])

def csy(Ay, By, Cy, ay, by, nuy):
    """Exact vertical Twiss (alpha, beta) at the ID location with the ID installed.

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A (complex tensors)
    ay, by     -- unperturbed Twiss parameters at the ID location
    nuy        -- unperturbed fractional tune

    Returns torch.stack([alpha, beta]) (complex; take `.real`). The sine of the perturbed
    phase advance is taken as +sqrt(1 - cos**2), so the perturbed fractional tune is
    assumed to lie in (0, 1/2).
    """
    mu = 2*nuy*torch.pi
    cos_mu, sin_mu = torch.cos(mu), torch.sin(mu)
    w = torch.sqrt(By**2 - Ay*Cy)
    ch = torch.cosh(w)
    # sinh(w)/w -> 1 as w -> 0 (e.g. By = Cy = 0); avoid the 0/0 that produced nan
    shc = torch.where(w == 0, torch.ones_like(w), torch.sinh(w)/w)
    # cosine and sine of the perturbed phase advance
    g = Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy
    cos_new = cos_mu*ch - g*sin_mu*shc/(2*by)
    sin_new = torch.sqrt(1 - cos_new**2)
    hay = (2*ay*ch*sin_mu + (2*by*By*cos_mu + (-(Ay*by**2) + Cy + ay**2*Cy)*sin_mu)*shc/by)/(2*sin_new)
    hby = (by*ch*sin_mu + (Cy*cos_mu + (-(by*By) + ay*Cy)*sin_mu)*shc)/sin_new
    return torch.stack([hay, hby])

# Perturbed Twiss (alpha, beta) at the ID: model values, analytical values, blank line
print(torch.stack([axf, bxf]), csx(Ax, Bx, Cx, axi, bxi, nux).real, '', sep='\n')
print(torch.stack([ayf, byf]), csy(Ay, By, Cy, ayi, byi, nuy).real, '', sep='\n')
tensor([0.1559, 9.1353], dtype=torch.float64)
tensor([0.1559, 9.1353], dtype=torch.float64)

tensor([-0.0133,  1.6021], dtype=torch.float64)
tensor([-0.0133,  1.6021], dtype=torch.float64)

[18]:
# Twiss at the observation point (approximate)

def csx(Ax, Bx, Cx, ax, bx, nux):
    """Approximate (second-order) horizontal Twiss (alpha, beta) at the ID with the ID installed.

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A
    ax, bx     -- unperturbed Twiss parameters at the ID location
    nux        -- unperturbed fractional tune
    Returns torch.stack([alpha, beta]).
    """
    cx = (1 + ax**2)/bx
    phase = 2*nux*torch.pi
    s, c = torch.sin(phase), torch.cos(phase)
    s4 = torch.sin(4*nux*torch.pi)
    cot = c/s
    k = Ax*bx - 2*ax*Bx + cx*Cx   # k/(4*pi) is the first-order tune shift
    d = Bx**2 - Ax*Cx             # minus the determinant of the horizontal block
    hax = (
        ax
        + (-(Ax*bx) + cx*Cx - (ax*Ax*bx - 2*bx*Bx*cx + ax*cx*Cx)*cot)/2.0
        + (16*ax*d*s**3
           + 4*ax*s*(4*d*c**2 + k**2*s**2)
           + 4*k*(-2*Bx*c + (Ax*bx - cx*Cx)*s)*s4
           + 3*ax*k**2*s4**2/s)/(32.0*s**3)
    )
    hbx = (
        bx - bx*Bx + ax*Cx
        + (-0.5*(Ax*bx**2) + ax*bx*Bx + Cx - 0.5*(bx*cx*Cx))*cot
        + (4*bx*d + bx*k**2
           + 4*bx*d*cot**2 + 3*bx*k**2*cot**2
           + 2*k*(bx*Bx - ax*Cx - Cx*cot)*s4/s**2)/8.0
    )
    return torch.stack([hax, hbx])

def csy(Ay, By, Cy, ay, by, nuy):
    """Approximate (second-order) vertical Twiss (alpha, beta) at the ID with the ID installed.

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A
    ay, by     -- unperturbed Twiss parameters at the ID location
    nuy        -- unperturbed fractional tune
    Returns torch.stack([alpha, beta]).
    """
    cy = (1 + ay**2)/by
    phase = 2*nuy*torch.pi
    s, c = torch.sin(phase), torch.cos(phase)
    s4 = torch.sin(4*nuy*torch.pi)
    cot = c/s
    k = Ay*by - 2*ay*By + cy*Cy   # k/(4*pi) is the first-order tune shift
    d = By**2 - Ay*Cy             # minus the determinant of the vertical block
    hay = (
        ay
        + (-(Ay*by) + cy*Cy - (ay*Ay*by - 2*by*By*cy + ay*cy*Cy)*cot)/2.0
        + (16*ay*d*s**3
           + 4*ay*s*(4*d*c**2 + k**2*s**2)
           + 4*k*(-2*By*c + (Ay*by - cy*Cy)*s)*s4
           + 3*ay*k**2*s4**2/s)/(32.0*s**3)
    )
    hby = (
        by - by*By + ay*Cy
        + (-0.5*(Ay*by**2) + ay*by*By + Cy - 0.5*(by*cy*Cy))*cot
        + (4*by*d + by*k**2
           + 4*by*d*cot**2 + 3*by*k**2*cot**2
           + 2*k*(by*By - ay*Cy - Cy*cot)*s4/s**2)/8.0
    )
    return torch.stack([hay, hby])

# Perturbed Twiss (alpha, beta) at the ID: model values, approximate analytical values, blank line
print(torch.stack([axf, bxf]), csx(Ax, Bx, Cx, axi, bxi, nux).real, '', sep='\n')
print(torch.stack([ayf, byf]), csy(Ay, By, Cy, ayi, byi, nuy).real, '', sep='\n')
tensor([0.1559, 9.1353], dtype=torch.float64)
tensor([0.1534, 9.1523], dtype=torch.float64)

tensor([-0.0133,  1.6021], dtype=torch.float64)
tensor([-0.0133,  1.6023], dtype=torch.float64)

[19]:
# Beta beating (exact)

def bbx(Ax, Bx, Cx, nux, ax, bx, axs, bxs, mux):
    """Exact analytical horizontal beta-beating (in percent) caused by the ID.

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A (complex, so
                  that sqrt(Bx**2 - Ax*Cx) is defined for either sign)
    nux        -- unperturbed fractional horizontal tune
    ax, bx     -- unperturbed Twiss parameters at the ID location
    axs, bxs   -- unperturbed Twiss parameters at the observation points (axs is not used)
    mux        -- unperturbed phase advance from the ID (lattice start) to the observation points

    Returns the real part of 100*(bxf - bxs)/bxs, where bxf is the perturbed beta.
    NOTE(review): the expression divides by sqrt(Bx**2 - Ax*Cx), so it presumably gives
    nan when that quantity is exactly zero (e.g. Bx = Cx = 0) -- confirm before relying on it.
    """
    bxf = (bxs*(8*bx**2*(Bx**2 - Ax*Cx)*(torch.cos(2*mux) - torch.cos(4*nux*torch.pi))*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))**4 + 16*bx**2*(Bx**2 - Ax*Cx)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))**2*torch.sin(mux)**2 + 8*bx*torch.sqrt(Bx**2 - Ax*Cx)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))**3*(-4*Cx*torch.cos(mux)*torch.cos(2*nux*torch.pi) + 2*torch.sin(mux)*(2*(bx*Bx - ax*Cx)*torch.cos(2*nux*torch.pi) + (Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi)))*torch.sin(mux - 2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)) + 16*(Cx*torch.cos(mux) + (-(bx*Bx) + ax*Cx)*torch.sin(mux))**2*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx))**2 - 8*bx*torch.sqrt(Bx**2 - Ax*Cx)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))*(-4*Cx*torch.cos(mux)*torch.cos(2*nux*torch.pi) + 2*torch.sin(mux)*(2*(bx*Bx - ax*Cx)*torch.cos(2*nux*torch.pi) + (Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi)))*torch.sin(mux - 2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx))**3 - 16*(-(bx**2*(Bx**2 - Ax*Cx)*torch.cos(nux*torch.pi)**2*torch.sin(mux)) + torch.sin(mux)*torch.sin(nux*torch.pi)*(-2*(bx*Bx - ax*Cx)*(Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.cos(nux*torch.pi) + bx**2*(Bx**2 - Ax*Cx)*torch.sin(nux*torch.pi)) + (bx**2*Bx**2 - 2*ax*bx*Bx*Cx + (1 + ax**2)*Cx**2)*torch.cos(mux)*torch.sin(2*nux*torch.pi))*torch.sin(mux - 2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx))**4 + torch.sinh(2*torch.sqrt(Bx**2 - Ax*Cx))*(-16*bx*torch.sqrt(Bx**2 - Ax*Cx)*torch.sin(mux)*(-(Cx*torch.cos(mux)) + (bx*Bx - ax*Cx)*torch.sin(mux)) - 4*(-((-2*ax*bx*Bx*Cx + (1 + ax**2)*Cx**2 + bx**2*(2*Bx**2 - Ax*Cx))*torch.cos(mux)) + (bx*Bx - ax*Cx)*(Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(mux))*torch.sin(2*nux*torch.pi)*torch.sin(mux - 2*nux*torch.pi)*torch.sinh(2*torch.sqrt(Bx**2 - Ax*Cx)))))/(8*bx*torch.sqrt(Bx**2 - Ax*Cx)*(bx*torch.sqrt(Bx**2 - Ax*Cx)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))*torch.sin(2*nux*torch.pi) + (Cx*torch.cos(2*nux*torch.pi) + (-(bx*Bx) + 
ax*Cx)*torch.sin(2*nux*torch.pi))*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))*torch.sqrt((-4*bx**2*(Bx**2 - Ax*Cx)*torch.cos(2*nux*torch.pi)**2*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))**2 - (Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)**2*torch.sin(2*nux*torch.pi)**2*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx))**2 + bx*(4*bx*(Bx**2 - Ax*Cx) - (-(Ax*bx**2) + 2*ax*bx*Bx - Cx - ax**2*Cx)*torch.sqrt(Bx**2 - Ax*Cx)*torch.sin(4*nux*torch.pi)*torch.sinh(2*torch.sqrt(Bx**2 - Ax*Cx))))/(bx**2*(Bx**2 - Ax*Cx))))
    # percent change relative to the unperturbed beta at the observation points
    return (100*(bxf - bxs)/bxs).real

def bby(Ay, By, Cy, nuy, ay, by, ays, bys, muy):
    """Exact analytical vertical beta-beating (in percent) caused by the ID.

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A (complex, so
                  that sqrt(By**2 - Ay*Cy) is defined for either sign)
    nuy        -- unperturbed fractional vertical tune
    ay, by     -- unperturbed Twiss parameters at the ID location
    ays, bys   -- unperturbed Twiss parameters at the observation points (ays is not used)
    muy        -- unperturbed phase advance from the ID (lattice start) to the observation points

    Returns the real part of 100*(byf - bys)/bys, where byf is the perturbed beta.
    NOTE(review): the expression divides by sqrt(By**2 - Ay*Cy), so it presumably gives
    nan when that quantity is exactly zero (e.g. By = Cy = 0) -- confirm before relying on it.
    """
    byf = (bys*(8*by**2*(By**2 - Ay*Cy)*(torch.cos(2*muy) - torch.cos(4*nuy*torch.pi))*torch.cosh(torch.sqrt(By**2 - Ay*Cy))**4 + 16*by**2*(By**2 - Ay*Cy)*torch.cosh(torch.sqrt(By**2 - Ay*Cy))**2*torch.sin(muy)**2 + 8*by*torch.sqrt(By**2 - Ay*Cy)*torch.cosh(torch.sqrt(By**2 - Ay*Cy))**3*(-4*Cy*torch.cos(muy)*torch.cos(2*nuy*torch.pi) + 2*torch.sin(muy)*(2*(by*By - ay*Cy)*torch.cos(2*nuy*torch.pi) + (Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi)))*torch.sin(muy - 2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy)) + 16*(Cy*torch.cos(muy) + (-(by*By) + ay*Cy)*torch.sin(muy))**2*torch.sinh(torch.sqrt(By**2 - Ay*Cy))**2 - 8*by*torch.sqrt(By**2 - Ay*Cy)*torch.cosh(torch.sqrt(By**2 - Ay*Cy))*(-4*Cy*torch.cos(muy)*torch.cos(2*nuy*torch.pi) + 2*torch.sin(muy)*(2*(by*By - ay*Cy)*torch.cos(2*nuy*torch.pi) + (Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi)))*torch.sin(muy - 2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy))**3 - 16*(-(by**2*(By**2 - Ay*Cy)*torch.cos(nuy*torch.pi)**2*torch.sin(muy)) + torch.sin(muy)*torch.sin(nuy*torch.pi)*(-2*(by*By - ay*Cy)*(Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.cos(nuy*torch.pi) + by**2*(By**2 - Ay*Cy)*torch.sin(nuy*torch.pi)) + (by**2*By**2 - 2*ay*by*By*Cy + (1 + ay**2)*Cy**2)*torch.cos(muy)*torch.sin(2*nuy*torch.pi))*torch.sin(muy - 2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy))**4 + torch.sinh(2*torch.sqrt(By**2 - Ay*Cy))*(-16*by*torch.sqrt(By**2 - Ay*Cy)*torch.sin(muy)*(-(Cy*torch.cos(muy)) + (by*By - ay*Cy)*torch.sin(muy)) - 4*(-((-2*ay*by*By*Cy + (1 + ay**2)*Cy**2 + by**2*(2*By**2 - Ay*Cy))*torch.cos(muy)) + (by*By - ay*Cy)*(Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(muy))*torch.sin(2*nuy*torch.pi)*torch.sin(muy - 2*nuy*torch.pi)*torch.sinh(2*torch.sqrt(By**2 - Ay*Cy)))))/(8*by*torch.sqrt(By**2 - Ay*Cy)*(by*torch.sqrt(By**2 - Ay*Cy)*torch.cosh(torch.sqrt(By**2 - Ay*Cy))*torch.sin(2*nuy*torch.pi) + (Cy*torch.cos(2*nuy*torch.pi) + (-(by*By) + 
ay*Cy)*torch.sin(2*nuy*torch.pi))*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))*torch.sqrt((-4*by**2*(By**2 - Ay*Cy)*torch.cos(2*nuy*torch.pi)**2*torch.cosh(torch.sqrt(By**2 - Ay*Cy))**2 - (Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)**2*torch.sin(2*nuy*torch.pi)**2*torch.sinh(torch.sqrt(By**2 - Ay*Cy))**2 + by*(4*by*(By**2 - Ay*Cy) - (-(Ay*by**2) + 2*ay*by*By - Cy - ay**2*Cy)*torch.sqrt(By**2 - Ay*Cy)*torch.sin(4*nuy*torch.pi)*torch.sinh(2*torch.sqrt(By**2 - Ay*Cy))))/(by**2*(By**2 - Ay*Cy))))
    # percent change relative to the unperturbed beta at the observation points
    return (100*(byf - bys)/bys).real

# For each plane: plot the model beta-beating (blue line) against the exact analytical
# expression (red crosses), then print the RMS of both in percent.
# The analytical functions return 100*(perturbed - unperturbed)/unperturbed, so they are
# negated to match the model curve's sign convention, 100*(unperturbed - perturbed)/unperturbed.
# The Axes object is named `axis` so the alpha array `ax` is not overwritten.

s_plot = ring.locations().cpu().numpy()
for beta, beta_id, beating, plane, limit in (
    (bx, bx_id, bbx(Ax, Bx, Cx, nux, axi, bxi, ax, bx, mux), 'x', 25),
    (by, by_id, bby(Ay, By, Cy, nuy, ayi, byi, ay, by, muy), 'y', 5),
):
    ratio = (beta - beta_id)/beta
    fig, axis = plt.subplots(figsize=(12, 4))
    axis.errorbar(s_plot, 100*ratio.cpu().numpy(), fmt='-', marker='x', ms=0, color='blue', alpha=0.75, label='model')
    axis.errorbar(s_plot, -beating, fmt=' ', marker='x', ms=5, color='red', alpha=0.75, label='analytical (exact)')
    axis.tick_params(axis='x', length=6, width=1.5, direction='in', labelsize=12, bottom=True, top=False, labelbottom=True, labeltop=False)
    axis.tick_params(axis='y', length=0, width=0, labelsize=12)
    axis.set_xlabel(r'$s$', fontsize=18)
    axis.set_ylabel(rf'$\Delta \beta_{plane} / \beta_{plane}$', fontsize=18)
    axis.legend(loc='upper left', frameon=False, fontsize=14, ncol=2)
    axis.set_ylim(-limit, limit)
    fig.tight_layout()
    plt.show()

    print(100*(ratio**2).mean().sqrt())
    print((beating**2).mean().sqrt())
../_images/examples_elettra-37_19_0.png
tensor(11.5261, dtype=torch.float64)
tensor(11.5264, dtype=torch.float64)
../_images/examples_elettra-37_19_2.png
tensor(1.9667, dtype=torch.float64)
tensor(1.9658, dtype=torch.float64)
[20]:
# Beta beating (approximate)

def bbx(Ax, Bx, Cx, nux, ax, bx, axs, bxs, mux):
    """Approximate (second-order) analytical horizontal beta-beating (in percent).

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A
    nux        -- unperturbed fractional horizontal tune
    ax, bx     -- unperturbed Twiss parameters at the ID location
    axs, bxs   -- unperturbed Twiss parameters at the observation points (axs is not used)
    mux        -- unperturbed phase advance from the ID (lattice start) to the observation points

    Returns the real part of 100*(bxf - bxs)/bxs, with bxf the approximate perturbed beta.
    """
    # Expansion-order tag: t multiplies the first-order term and t**2 the second-order
    # term; t = 1.0 evaluates the full second-order expression
    t = 1.0
    cx = (1 + ax**2)/bx
    bxf = bxs + (bxs*t**2*(1.0/torch.sin(2*nux*torch.pi)**2)*(4*ax*Bx*Cx*(torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) - 2*cx*Cx**2*(torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) + Ax**2*bx**3*(1 + torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) - 4*ax*Ax*bx**2*Bx*(1 + torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) - 4*ax*bx*Bx*cx*Cx*(1 + torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) + bx*cx**2*Cx**2*(1 + torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) + 2*bx**2*cx*(2*Bx**2 + Ax*Cx)*(1 + torch.cos(2*mux) + torch.cos(2*mux - 4*nux*torch.pi)) - 2*bx*(2*Ax*Cx + (2*Bx**2 + Ax*Cx)*torch.cos(2*mux) + (2*Bx**2 + Ax*Cx)*torch.cos(2*mux - 4*nux*torch.pi)) - 4*Bx*Cx*torch.sin(2*mux) - 4*Bx*Cx*torch.sin(2*mux - 4*nux*torch.pi) + 2*Ax*bx**2*Bx*(torch.sin(2*mux) + torch.sin(2*mux - 4*nux*torch.pi)) + 6*bx*Bx*cx*Cx*(torch.sin(2*mux) + torch.sin(2*mux - 4*nux*torch.pi)) - 2*ax*cx*Cx**2*(torch.sin(2*mux) + torch.sin(2*mux - 4*nux*torch.pi)) - 2*ax*bx*(2*Bx**2 + Ax*Cx)*(torch.sin(2*mux) + torch.sin(2*mux - 4*nux*torch.pi))))/(8*bx) - (bxs*t*(1.0/torch.sin(2*nux*torch.pi))*((Ax*bx**2 - 2*ax*bx*Bx - 2*Cx + bx*cx*Cx)*torch.cos(2*mux - 2*nux*torch.pi) + 2*(bx*Bx - ax*Cx)*torch.sin(2*mux - 2*nux*torch.pi)))/(2*bx)
    # percent change relative to the unperturbed beta at the observation points
    return (100*(bxf - bxs)/bxs).real

def bby(Ay, By, Cy, nuy, ay, by, ays, bys, muy):
    """Approximate (second-order) analytical vertical beta-beating (in percent).

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A
    nuy        -- unperturbed fractional vertical tune
    ay, by     -- unperturbed Twiss parameters at the ID location
    ays, bys   -- unperturbed Twiss parameters at the observation points (ays is not used)
    muy        -- unperturbed phase advance from the ID (lattice start) to the observation points

    Returns the real part of 100*(byf - bys)/bys, with byf the approximate perturbed beta.
    """
    # Expansion-order tag: t multiplies the first-order term and t**2 the second-order
    # term; t = 1.0 evaluates the full second-order expression
    t = 1.0
    cy = (1 + ay**2)/by
    byf = bys + (bys*t**2*(1.0/torch.sin(2*nuy*torch.pi)**2)*(4*ay*By*Cy*(torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) - 2*cy*Cy**2*(torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) + Ay**2*by**3*(1 + torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) - 4*ay*Ay*by**2*By*(1 + torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) - 4*ay*by*By*cy*Cy*(1 + torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) + by*cy**2*Cy**2*(1 + torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) + 2*by**2*cy*(2*By**2 + Ay*Cy)*(1 + torch.cos(2*muy) + torch.cos(2*muy - 4*nuy*torch.pi)) - 2*by*(2*Ay*Cy + (2*By**2 + Ay*Cy)*torch.cos(2*muy) + (2*By**2 + Ay*Cy)*torch.cos(2*muy - 4*nuy*torch.pi)) - 4*By*Cy*torch.sin(2*muy) - 4*By*Cy*torch.sin(2*muy - 4*nuy*torch.pi) + 2*Ay*by**2*By*(torch.sin(2*muy) + torch.sin(2*muy - 4*nuy*torch.pi)) + 6*by*By*cy*Cy*(torch.sin(2*muy) + torch.sin(2*muy - 4*nuy*torch.pi)) - 2*ay*cy*Cy**2*(torch.sin(2*muy) + torch.sin(2*muy - 4*nuy*torch.pi)) - 2*ay*by*(2*By**2 + Ay*Cy)*(torch.sin(2*muy) + torch.sin(2*muy - 4*nuy*torch.pi))))/(8*by) - (bys*t*(1.0/torch.sin(2*nuy*torch.pi))*((Ay*by**2 - 2*ay*by*By - 2*Cy + by*cy*Cy)*torch.cos(2*muy - 2*nuy*torch.pi) + 2*(by*By - ay*Cy)*torch.sin(2*muy - 2*nuy*torch.pi)))/(2*by)
    # percent change relative to the unperturbed beta at the observation points
    return (100*(byf - bys)/bys).real

# For each plane: plot the model beta-beating (blue line) against the approximate analytical
# expression (red crosses), then print the RMS of both in percent.
# The analytical functions return 100*(perturbed - unperturbed)/unperturbed, so they are
# negated to match the model curve's sign convention, 100*(unperturbed - perturbed)/unperturbed.
# The Axes object is named `axis` so the alpha array `ax` is not overwritten.

s_plot = ring.locations().cpu().numpy()
for beta, beta_id, beating, plane, limit in (
    (bx, bx_id, bbx(Ax, Bx, Cx, nux, axi, bxi, ax, bx, mux), 'x', 25),
    (by, by_id, bby(Ay, By, Cy, nuy, ayi, byi, ay, by, muy), 'y', 5),
):
    ratio = (beta - beta_id)/beta
    fig, axis = plt.subplots(figsize=(12, 4))
    axis.errorbar(s_plot, 100*ratio.cpu().numpy(), fmt='-', marker='x', ms=0, color='blue', alpha=0.75, label='model')
    axis.errorbar(s_plot, -beating, fmt=' ', marker='x', ms=5, color='red', alpha=0.75, label='analytical (approximate)')
    axis.tick_params(axis='x', length=6, width=1.5, direction='in', labelsize=12, bottom=True, top=False, labelbottom=True, labeltop=False)
    axis.tick_params(axis='y', length=0, width=0, labelsize=12)
    axis.set_xlabel(r'$s$', fontsize=18)
    axis.set_ylabel(rf'$\Delta \beta_{plane} / \beta_{plane}$', fontsize=18)
    axis.legend(loc='upper left', frameon=False, fontsize=14, ncol=2)
    axis.set_ylim(-limit, limit)
    fig.tight_layout()
    plt.show()

    print(100*(ratio**2).mean().sqrt())
    print((beating**2).mean().sqrt())
../_images/examples_elettra-37_20_0.png
tensor(11.5261, dtype=torch.float64)
tensor(11.3582, dtype=torch.float64)
../_images/examples_elettra-37_20_2.png
tensor(1.9667, dtype=torch.float64)
tensor(1.9583, dtype=torch.float64)
[21]:
# Analytical RMS beta-beating

def rmsx(Ax, Bx, Cx, ax, bx, nux):
    """Analytical RMS horizontal beta-beating (in percent) caused by the ID.

    Ax, Bx, Cx -- horizontal block [[Ax, Bx], [Bx, Cx]] of the ID matrix A
    ax, bx     -- unperturbed Twiss parameters at the ID location
    nux        -- unperturbed fractional horizontal tune

    Returns 100*sqrt(...) (complex if the inputs are complex; callers take `.real`).
    The notebook compares it with the RMS of the model beta-beating over all locations.
    NOTE(review): the expression does not depend on the phase advance, so it is presumably
    the approximate beta-beating squared and averaged over phase -- confirm the derivation.
    """
    cx = (1 + ax**2)/bx
    return 100*(((1.0/torch.sin(2*nux*torch.pi)**4)*((Ax**4*bx**6 + 16*bx**4*Bx**4*cx**2 - 4*Ax**3*bx**4*(2*ax*bx*Bx + Cx - bx*cx*Cx) - 8*bx**3*Bx**2*cx*(4*Bx**2 + 4*ax*Bx*cx*Cx - cx**2*Cx**2) + 4*(1 + ax**2)*Cx**2*(-4 + 4*Bx**2 + cx**2*Cx**2) - 4*bx*cx*Cx**2*(-4 + 4*(2 + ax**2)*Bx**2 + cx**2*Cx**2) + 2*Ax**2*bx**2*(4*bx**3*Bx**2*cx + 2*(1 + ax**2)*Cx**2 + 2*bx*Cx*(4*ax*Bx - 3*cx*Cx) + bx**2*(-2 + (-2 + 8*ax**2)*Bx**2 - 12*ax*Bx*cx*Cx + 3*cx**2*Cx**2)) + bx**2*(16*(1 + ax**2)*Bx**4 + 16*ax*Bx**3*cx*Cx + cx**2*Cx**2*(-4 + cx**2*Cx**2) - 8*ax*Bx*cx*Cx*(-2 + cx**2*Cx**2) + 4*Bx**2*(-4 + 3*cx**2*Cx**2 + 4*ax**2*(-1 + cx**2*Cx**2))) + 4*Ax*bx*(2*ax**2*Cx*(-2*bx*Bx**2 + 4*bx**2*Bx**2*cx + cx*Cx**2) + 2*ax*bx*Bx*(-4*bx**2*Bx**2*cx + 2*cx*Cx**2 + bx*(2 + 2*Bx**2 - 3*cx**2*Cx**2)) + Cx*(4*bx**3*Bx**2*cx**2 + 2*cx*Cx**2 + bx*(4 - 3*cx**2*Cx**2) + bx**2*cx*(-2 - 2*Bx**2 + cx**2*Cx**2))))*torch.cos(4*nux*torch.pi) + 2*(Ax**4*bx**6 + 8*bx**2*Bx**2 + 8*ax**2*bx**2*Bx**2 + 8*bx**2*Bx**4 + 8*ax**2*bx**2*Bx**4 - 16*bx**3*Bx**4*cx + 16*bx**4*Bx**4*cx**2 - 8*ax*bx**2*Bx*cx*Cx + 8*ax*bx**2*Bx**3*cx*Cx - 32*ax*bx**3*Bx**3*cx**2*Cx + 8*Cx**2 + 8*ax**2*Cx**2 + 8*Bx**2*Cx**2 + 8*ax**2*Bx**2*Cx**2 - 8*bx*cx*Cx**2 - 16*bx*Bx**2*cx*Cx**2 - 8*ax**2*bx*Bx**2*cx*Cx**2 + 2*bx**2*cx**2*Cx**2 + 6*bx**2*Bx**2*cx**2*Cx**2 + 16*ax**2*bx**2*Bx**2*cx**2*Cx**2 + 8*bx**3*Bx**2*cx**3*Cx**2 - 8*ax*bx**2*Bx*cx**3*Cx**3 + 2*cx**2*Cx**4 + 2*ax**2*cx**2*Cx**4 - 2*bx*cx**3*Cx**4 + bx**2*cx**4*Cx**4 - 2*Ax**3*bx**4*(4*ax*bx*Bx + (3 - 2*bx*cx)*Cx) + 2*Ax**2*bx**2*(4*bx**3*Bx**2*cx + (5 + ax**2)*Cx**2 + bx*Cx*(12*ax*Bx - 7*cx*Cx) + bx**2*(1 + (-1 + 8*ax**2)*Bx**2 - 12*ax*Bx*cx*Cx + 3*cx**2*Cx**2)) + 2*Ax*bx*(2*ax**2*Cx*(-2*bx*Bx**2 + 8*bx**2*Bx**2*cx + cx*Cx**2) - 4*ax*bx*Bx*(bx - bx*Bx**2 + 4*bx**2*Bx**2*cx - 3*cx*Cx**2 + 3*bx*cx**2*Cx**2) + Cx*(8*bx**3*Bx**2*cx**2 + 2*cx*Cx**2 - bx*(4 + 5*cx**2*Cx**2) + 2*bx**2*(cx - 5*Bx**2*cx + cx**3*Cx**2))) - 2*(Ax**3*bx**5 + Ax**2*bx**3*(-6*ax*bx*Bx + (-4 + 3*bx*cx)*Cx) + 
Ax*bx*(4*bx**3*Bx**2*cx + 4*ax*bx*Bx*(2 - 3*bx*cx)*Cx + 4*Cx**2 - 8*bx*cx*Cx**2 + 3*bx**2*cx**2*Cx**2 + 4*ax**2*(2*bx**2*Bx**2 + Cx**2)) + cx*(4*bx**3*Bx**2*cx*Cx + 4*Cx**3 - 4*bx*cx*Cx**3 + bx**2*cx**2*Cx**3 - 2*ax*bx**2*Bx*(4*bx*Bx**2 + 3*cx*Cx**2) + 4*ax**2*(2*bx**2*Bx**2*Cx + Cx**3)))*torch.sin(4*nux*torch.pi))))/(64*bx**2)).sqrt()

def rmsy(Ay, By, Cy, ay, by, nuy):
    """Analytical RMS vertical beta-beating (in percent) caused by the ID.

    Ay, By, Cy -- vertical block [[Ay, By], [By, Cy]] of the ID matrix A
    ay, by     -- unperturbed Twiss parameters at the ID location
    nuy        -- unperturbed fractional vertical tune

    Returns 100*sqrt(...) (complex if the inputs are complex; callers take `.real`).
    The notebook compares it with the RMS of the model beta-beating over all locations.
    NOTE(review): the expression does not depend on the phase advance, so it is presumably
    the approximate beta-beating squared and averaged over phase -- confirm the derivation.
    """
    cy = (1 + ay**2)/by
    return 100*(((1.0/torch.sin(2*nuy*torch.pi)**4)*((Ay**4*by**6 + 16*by**4*By**4*cy**2 - 4*Ay**3*by**4*(2*ay*by*By + Cy - by*cy*Cy) - 8*by**3*By**2*cy*(4*By**2 + 4*ay*By*cy*Cy - cy**2*Cy**2) + 4*(1 + ay**2)*Cy**2*(-4 + 4*By**2 + cy**2*Cy**2) - 4*by*cy*Cy**2*(-4 + 4*(2 + ay**2)*By**2 + cy**2*Cy**2) + 2*Ay**2*by**2*(4*by**3*By**2*cy + 2*(1 + ay**2)*Cy**2 + 2*by*Cy*(4*ay*By - 3*cy*Cy) + by**2*(-2 + (-2 + 8*ay**2)*By**2 - 12*ay*By*cy*Cy + 3*cy**2*Cy**2)) + by**2*(16*(1 + ay**2)*By**4 + 16*ay*By**3*cy*Cy + cy**2*Cy**2*(-4 + cy**2*Cy**2) - 8*ay*By*cy*Cy*(-2 + cy**2*Cy**2) + 4*By**2*(-4 + 3*cy**2*Cy**2 + 4*ay**2*(-1 + cy**2*Cy**2))) + 4*Ay*by*(2*ay**2*Cy*(-2*by*By**2 + 4*by**2*By**2*cy + cy*Cy**2) + 2*ay*by*By*(-4*by**2*By**2*cy + 2*cy*Cy**2 + by*(2 + 2*By**2 - 3*cy**2*Cy**2)) + Cy*(4*by**3*By**2*cy**2 + 2*cy*Cy**2 + by*(4 - 3*cy**2*Cy**2) + by**2*cy*(-2 - 2*By**2 + cy**2*Cy**2))))*torch.cos(4*nuy*torch.pi) + 2*(Ay**4*by**6 + 8*by**2*By**2 + 8*ay**2*by**2*By**2 + 8*by**2*By**4 + 8*ay**2*by**2*By**4 - 16*by**3*By**4*cy + 16*by**4*By**4*cy**2 - 8*ay*by**2*By*cy*Cy + 8*ay*by**2*By**3*cy*Cy - 32*ay*by**3*By**3*cy**2*Cy + 8*Cy**2 + 8*ay**2*Cy**2 + 8*By**2*Cy**2 + 8*ay**2*By**2*Cy**2 - 8*by*cy*Cy**2 - 16*by*By**2*cy*Cy**2 - 8*ay**2*by*By**2*cy*Cy**2 + 2*by**2*cy**2*Cy**2 + 6*by**2*By**2*cy**2*Cy**2 + 16*ay**2*by**2*By**2*cy**2*Cy**2 + 8*by**3*By**2*cy**3*Cy**2 - 8*ay*by**2*By*cy**3*Cy**3 + 2*cy**2*Cy**4 + 2*ay**2*cy**2*Cy**4 - 2*by*cy**3*Cy**4 + by**2*cy**4*Cy**4 - 2*Ay**3*by**4*(4*ay*by*By + (3 - 2*by*cy)*Cy) + 2*Ay**2*by**2*(4*by**3*By**2*cy + (5 + ay**2)*Cy**2 + by*Cy*(12*ay*By - 7*cy*Cy) + by**2*(1 + (-1 + 8*ay**2)*By**2 - 12*ay*By*cy*Cy + 3*cy**2*Cy**2)) + 2*Ay*by*(2*ay**2*Cy*(-2*by*By**2 + 8*by**2*By**2*cy + cy*Cy**2) - 4*ay*by*By*(by - by*By**2 + 4*by**2*By**2*cy - 3*cy*Cy**2 + 3*by*cy**2*Cy**2) + Cy*(8*by**3*By**2*cy**2 + 2*cy*Cy**2 - by*(4 + 5*cy**2*Cy**2) + 2*by**2*(cy - 5*By**2*cy + cy**3*Cy**2))) - 2*(Ay**3*by**5 + Ay**2*by**3*(-6*ay*by*By + (-4 + 3*by*cy)*Cy) + 
Ay*by*(4*by**3*By**2*cy + 4*ay*by*By*(2 - 3*by*cy)*Cy + 4*Cy**2 - 8*by*cy*Cy**2 + 3*by**2*cy**2*Cy**2 + 4*ay**2*(2*by**2*By**2 + Cy**2)) + cy*(4*by**3*By**2*cy*Cy + 4*Cy**3 - 4*by*cy*Cy**3 + by**2*cy**2*Cy**3 - 2*ay*by**2*By*(4*by*By**2 + 3*cy*Cy**2) + 4*ay**2*(2*by**2*By**2*Cy + Cy**3)))*torch.sin(4*nuy*torch.pi))))/(64*by**2)).sqrt()

# Per plane: RMS of the model beta-beating, analytical RMS estimate, blank line
for _beta, _beta_id, _estimate in (
    (bx, bx_id, rmsx(Ax, Bx, Cx, axi, bxi, nux)),
    (by, by_id, rmsy(Ay, By, Cy, ayi, byi, nuy)),
):
    print(100*(((_beta - _beta_id)/_beta)**2).mean().sqrt(), _estimate.real, '', sep='\n')
tensor(11.5261, dtype=torch.float64)
tensor(11.3315, dtype=torch.float64)

tensor(1.9667, dtype=torch.float64)
tensor(1.8970, dtype=torch.float64)