ELETTRA-39: ID linear optics distortion (analytical & multiple perturbations)

[1]:
# In this example, a factorization of the one-turn matrix is illustrated
# The perturbed one-turn matrix can be expressed as M exp(S A), where M is the unperturbed one-turn matrix and A is a symmetric matrix (Aij = Aji) that describes the cumulative effect of the perturbations (A is block-diagonal if there is no coupling)
# Here the matrix A is constructed perturbatively using the Magnus (Baker-Campbell-Hausdorff) expansion
[2]:
# Import

import torch
from torch import Tensor

from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True

from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix

from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling
[3]:
# Set data type and device

# Chained assignment: set the class attributes on Element (presumably the dtype/device
# used by the lattice elements — TODO confirm) and keep local aliases `dtype` and
# `device` for the tensors created in this notebook
Element.dtype = dtype = torch.float64
Element.device = device = torch.device('cpu')
[4]:
# Load lattice (ELEGANT table)
# Note, lattice is allowed to have repeated elements

# Relative path: 'elettra.lte' is resolved against the current working directory
path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice

# Build the 'RING' line from the loaded ELEGANT lattice data
ring:Line = build('RING', 'ELEGANT', data)

# Flatten sublines

ring.flatten()

# Remove all marker elements but the ones starting with MLL (long straight section centers)
# The negative lookahead matches every name that does NOT start with 'MLL_'

ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])

# Replace all sextupoles with quadrupoles

def factory(element:Element) -> Quadrupole:
    """Build a quadrupole from a sextupole, keeping its other serialized parameters.

    Parameters
    ----------
    element: Element
        element to replace (the call below selects kinds=['Sextupole'])

    Returns
    -------
    Quadrupole
        new element built from the serialized table without the 'ms' entry

    """
    # Copy the serialized table so popping a key can never modify a table
    # that the original element might still own
    table = dict(element.serialize)
    # Drop 'ms' (presumably the sextupole strength — confirm) before building the quadrupole
    table.pop('ms', None)
    return Quadrupole(**table)

# Empty pattern: the selection is done by kind (presumably every sextupole matches — confirm)
ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])

# Set linear dipoles

def apply(element:Element) -> None:
    """Set element.linear = True (presumably switches the dipole to a linear model — confirm)."""
    element.linear = True

ring.apply(apply, kinds=['Dipole'])

# Merge drifts

ring.merge()

# Change lattice start
# The first ID below is inserted after this marker, so it sits at the lattice start

ring.start = "MLL_S01"

# Describe
# Last expression of the cell: element counts per kind

ring.describe
[5]:
{'Marker': 12, 'Drift': 708, 'BPM': 168, 'Quadrupole': 360, 'Dipole': 156}
[6]:
# Unperturbed one-turn matrix

# The Jacobian of the one-turn map at the zero state is the linear one-turn
# matrix (4x4, since the state has 4 components)
state = torch.tensor(4*[0.0], dtype=dtype)
M0 = torch.func.jacrev(ring)(state)
print(M0)
tensor([[-0.3055,  8.9649,  0.0000,  0.0000],
        [-0.1011, -0.3055,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.5315,  1.3891],
        [ 0.0000,  0.0000, -0.5166,  0.5315]], dtype=torch.float64)
[7]:
# Define IDs

# Three identical thin IDs, each given by the symmetric exponent 0.25*A
# A is diagonal (no coupling); none of the IDs is an identity
# ID1 is inserted at the lattice start (marker MLL_S01)

ca, cb, cc, cd = -0.034441907232402175, -0.04458009513208418, 0.056279356423643276, 0.08037110220505986
A = torch.diag(torch.tensor([ca, cb, cc, cd], dtype=dtype))

# Upper triangle (diagonal included) = the 10 independent entries of a symmetric 4x4 matrix
mask = torch.triu(torch.ones_like(A, dtype=torch.bool))

# Same scaled entries for every ID; .tolist() gives each element its own list
A_scaled = 0.25*A[mask]
ID1 = Matrix('ID1', length=0.0, A=A_scaled.tolist())
ID2 = Matrix('ID2', length=0.0, A=A_scaled.tolist())
ID3 = Matrix('ID3', length=0.0, A=A_scaled.tolist())
[8]:
# Insert IDs

# Each ID (or other perturbation) is a thin insertion matrix
# The matrices are inserted after given markers

elements = [ID1, ID2, ID3]
markers = ['MLL_S01', 'MLL_S02', 'MLL_S03']

# Insert into a clone: `ring` remains the unperturbed reference used below
error = ring.clone()
for element, marker in zip(elements, markers):
    error.insert(element, marker)

# Describe
# Last expression of the cell: element counts per kind (now including 'Matrix')

error.describe
[8]:
{'Marker': 12,
 'Matrix': 3,
 'Drift': 708,
 'BPM': 168,
 'Quadrupole': 360,
 'Dipole': 156}
[9]:
# Perturbed one-turn matrix

# Same linearization as for M0, now for the lattice with IDs
state = torch.tensor(4*[0.0], dtype=dtype)
M1 = torch.func.jacrev(error)(state)

# Compare unperturbed and perturbed one-turn matrices
print(M0)
print(M1)
print()
tensor([[-0.3055,  8.9649,  0.0000,  0.0000],
        [-0.1011, -0.3055,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.5315,  1.3891],
        [ 0.0000,  0.0000, -0.5166,  0.5315]], dtype=torch.float64)
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)

[10]:
# Skew identity matrix

# Symplectic form for the 4D state: block-diagonal [[0, 1], [-1, 0]], S @ S = -I
S = torch.tensor([[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]], dtype=dtype)

# Compute perturbations

# Each Matrix element holds the upper triangle (row-major, see `mask`) of a symmetric
# exponent Ai (presumably the 10 values passed at construction)
# The strict upper triangle is mirrored so that Ai is the full symmetric matrix
# (filling only the upper triangle is correct only for a diagonal Ai)

Ais = []
Kis = []
for element in elements:
    Ai = torch.zeros((4, 4), dtype=dtype)
    Ai[mask] = element.A
    Ai = Ai + torch.triu(Ai, diagonal=1).T
    # Thin insertion transport matrix
    Ki = torch.linalg.matrix_exp(S @ Ai)
    Ais.append(Ai)
    Kis.append(Ki)

Ais = torch.stack(Ais)
Kis = torch.stack(Kis)

# Check: exp(S Ai) equals the Jacobian of each inserted element
print([torch.allclose(Ki, torch.func.jacrev(element)(state)) for Ki, element in zip(Kis, elements)])
print()

# Compute matrices between IDs (using markers in the unperturbed lattice)
# Ti: transport from marker i to marker i + 1 (the last one wraps back to the first)
# The bounding first and last items of the slice are dropped
# (presumably the two markers — confirm)

Tis = []
for i in range(len(markers)):
    _, *sequence, _ = ring[markers[i] : markers[(i + 1) % len(markers)]]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)

# Construct one-turn matrix
# Product of transports only: reproduces the unperturbed M0

hM = torch.eye(4, dtype=dtype)
for Ti in Tis:
    hM = Ti @ hM

print(M0)
print(hM)
print(torch.allclose(M0, hM))
print()

# Construct one-turn matrix
# At each marker: apply the ID kick Ki, then transport to the next marker

hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hM = Ti @ Ki @ hM

print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
[True, True, True]

tensor([[-0.3055,  8.9649,  0.0000,  0.0000],
        [-0.1011, -0.3055,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.5315,  1.3891],
        [ 0.0000,  0.0000, -0.5166,  0.5315]], dtype=torch.float64)
tensor([[-0.3055,  8.9649,  0.0000,  0.0000],
        [-0.1011, -0.3055,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.5315,  1.3891],
        [ 0.0000,  0.0000, -0.5166,  0.5315]], dtype=torch.float64)
True

tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
True

[11]:
# Compute unperturbed transport matrices from lattice start to each ID
# NOTE: this redefines Tis (previously the marker-to-marker transports)
# For the first marker (= ring.start) the slice presumably yields an identity transport — confirm

Tis = []
for marker in markers:
    *sequence, _ = ring[ring.start : marker]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)

# Construct one-turn matrix
# Each kick is moved to the lattice start by a similarity transform,
# Ti^{-1} Ki Ti; the perturbed matrix is then M0 times the product of moved kicks

hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hKi = Ti.inverse() @ Ki @ Ti
    hM =  hKi @ hM
hM = M0 @ hM

print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
True

[12]:
# Construct shifted exponents
# For a symplectic Ti (Ti^T S Ti = S), Ti^{-1} exp(S Ai) Ti = exp(S Ti^T Ai Ti),
# so moving a kick to the lattice start only transforms its symmetric exponent

hAis = []
for Ti, Ai in zip(Tis, Ais):
    hAi = Ti.T @ Ai @ Ti
    hAis.append(hAi)
hAis = torch.stack(hAis)

# Construct one-turn matrix
# Later IDs multiply from the left: exp(S hA_n) ... exp(S hA_1), then M0

hM = torch.eye(4, dtype=dtype)
for hAi in hAis:
    hM =  torch.linalg.matrix_exp(S @ hAi) @ hM
hM = M0 @ hM

print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
tensor([[-0.1780,  9.1328,  0.0000,  0.0000],
        [-0.1057, -0.1928,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4820,  1.4276],
        [ 0.0000,  0.0000, -0.5351,  0.4897]], dtype=torch.float64)
True

[13]:
# Define bracket (commutator)

def bracket(X, Y):
    """Return the matrix commutator [X, Y] = X @ Y - Y @ X."""
    forward = X @ Y
    backward = Y @ X
    return forward - backward
[14]:
# Construct product approximation
# With generators Xi = S @ hAis[i], the product of kicks is
# exp(X_{n-1}) ... exp(X_1) exp(X_0) (later IDs on the left, as in the one-turn
# construction with hAis). hA1..hA4 are the 1st..4th order terms of its logarithm
# (Baker-Campbell-Hausdorff / Magnus series). Each term is mapped back from X to A
# with A = -S @ X, because S @ S = -I.
# For two factors the terms reduce to the standard BCH ones:
# log(exp(Y) exp(X)) = X + Y + 1/2 [Y, X] + 1/12 ([Y, [Y, X]] + [X, [X, Y]]) + ...
# The 1/(1 + ...) factors weight ordered index tuples with coinciding indices.

# 1st order: plain sum of the shifted exponents
hA1 = torch.zeros((4, 4), dtype=dtype)
for hAi in hAis:
    hA1 += hAi

# 2nd order: 1/2 sum_{i > j} [Xi, Xj] (the i == j terms vanish since [Xi, Xi] = 0)
hA2 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        dij = (i == j)
        factor = 1/(1 + dij)
        Xi = S @ hAis[i]
        Xj = S @ hAis[j]
        hA2 += factor*1/2*bracket(Xi, Xj)
hA2 = - S @ hA2

# 3rd order: nested brackets over ordered triples i >= j >= k
hA3 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        for k in range(j + 1):
            dij = (i == j)
            dik = (i == k)
            djk = (j == k)
            factor = (1.0/(1.0 + dij))*(1.0/(1.0 + dik + djk))
            Xi = S @ hAis[i]
            Xj = S @ hAis[j]
            Xk = S @ hAis[k]
            hA3 += factor*(
                + 1/4*bracket(bracket(Xi, Xj), Xk)
                - (1/12)*bracket(bracket(Xi, Xk), Xj)
                - (1/12)*bracket(bracket(Xj, Xk), Xi)
            )
hA3 = - S @ hA3

# 4th order: nested brackets over ordered quadruples i >= j >= k >= l
hA4 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        for k in range(j + 1):
            for l in range(k + 1):
                dij = (i == j)
                dik = (i == k)
                djk = (j == k)
                dil = (i == l)
                djl = (j == l)
                dkl = (k == l)
                factor = (1.0 / (1.0 + dij)) * (1.0 / (1.0 + dik + djk)) * (1.0 / (1.0 + dil + djl + dkl))
                Xi = S @ hAis[i]
                Xj = S @ hAis[j]
                Xk = S @ hAis[k]
                Xl = S @ hAis[l]
                hA4 += factor*(
                    1/12*bracket(bracket(bracket(Xi, Xj), Xk), Xl) +
                    1/12*bracket(Xi, bracket(bracket(Xj, Xk), Xl)) +
                    1/12*bracket(Xi, bracket(Xj, bracket(Xk, Xl))) +
                    1/12*bracket(Xj, bracket(Xk, bracket(Xl, Xi)))
                )
hA4 = - S @ hA4
[15]:
# Set elements

# Total exponent truncated at 4th order: sum of the 1st to 4th order terms
# The 10 upper-triangle entries (row-major) of the symmetric 4x4 exponent are
# (0,0) (0,1) (0,2) (0,3) (1,1) (1,2) (1,3) (2,2) (2,3) (3,3)
# x block: Ax = A[0,0], Bx = A[0,1], Cx = A[1,1]; y block: Ay = A[2,2], By = A[2,3], Cy = A[3,3]
# The x-y coupling entries are discarded (expected to vanish for these uncoupled IDs)
# Complex dtype: sqrt(B**2 - A*C) in the formulas below may have a negative argument
Ax, Bx, _, _, Cx, _, _, Ay, By, Cy = (hA1 + hA2 + hA3 + hA4)[mask].to(torch.complex128)
[16]:
# Compute tunes (fractional part)
# Unperturbed ring: reference tunes for the tune shifts below

nux, nuy = tune(ring, [], matched=True, limit=1)
[17]:
# Compute twiss parameters
# Unperturbed ring; the transposed table is unpacked into the ax, bx, ay, by rows

ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T

# Keep the first value of each (presumably at the lattice start MLL_S01, the
# observation point used below — confirm)
axi, *_ = ax
bxi, *_ = bx
ayi, *_ = ay
byi, *_ = by
[18]:
# Compute tunes (fractional part)
# Perturbed lattice (IDs inserted)

nux_id, nuy_id = tune(error, [], matched=True, limit=1)
[19]:
# Compute twiss parameters

ax_id, bx_id, ay_id, by_id = twiss(error, [], matched=True, advance=True, full=False).T

axf, *_ = ax_id
bxf, *_ = bx_id
ayf, *_ = ay_id
byf, *_ = by_id
[20]:
# Tune shifts
# Convention: unperturbed minus perturbed (the same as dnux/dnuy below)

print(nux - nux_id)
print(nuy - nuy_id)
tensor(0.0197, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
[21]:
# Tune shifts (exact)

def _dnu_exact(A, B, C, a, b, nu):
    """Exact tune shift nu - nu' for one plane.

    The perturbed one-turn matrix is M0 @ exp(S2 @ [[A, B], [B, C]]), where M0 has
    Twiss parameters (a, b) and tune nu and S2 = [[0, 1], [-1, 0]]. With
    w = sqrt(B**2 - A*C), half its trace is
    cos(mu) cosh(w) - (A b**2 - 2 a b B + (1 + a**2) C) sin(mu) sinh(w)/(2 b w).
    arccos maps this to [0, pi], i.e. the fractional tune is taken in [0, 0.5].
    Complex inputs allow B**2 < A*C.
    """
    mu = 2*nu*torch.pi
    w = torch.sqrt(B**2 - A*C)
    cos_mu = torch.cos(mu)*torch.cosh(w) - ((A*b**2 - 2*a*b*B + C + a**2*C)*torch.sin(mu)*torch.sinh(w))/(2*b*w)
    return nu - torch.arccos(cos_mu)/(2*torch.pi)

def dnux(Ax, Bx, Cx, ax, bx, nux):
    """Horizontal tune shift (exact)."""
    return _dnu_exact(Ax, Bx, Cx, ax, bx, nux)

def dnuy(Ay, By, Cy, ay, by, nuy):
    """Vertical tune shift (exact)."""
    return _dnu_exact(Ay, By, Cy, ay, by, nuy)

# Tracked vs exact analytical tune shifts, horizontal plane
print((nux - nux_id))
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()

# Vertical plane: use the vertical exponent entries, Twiss and tune with dnuy
print((nuy - nuy_id))
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0197, dtype=torch.float64)
tensor(0.0198, dtype=torch.float64)

tensor(-0.0084, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)

[22]:
# Tune shifts (approximate)

def _dnu_approx(A, B, C, a, b, nu):
    """Second-order expansion of the exact tune shift nu - nu' for one plane.

    With gamma = (1 + a**2)/b and X = A b - 2 a B + gamma C:
    nu - nu' ~ -X/(4 pi) + cot(2 pi nu) (X**2 + 4 (B**2 - A C))/(16 pi).
    """
    gamma = (1 + a**2)/b
    mu = 2*nu*torch.pi
    first = (A*b - 2*a*B + gamma*C)/(4*torch.pi)
    quadratic = A**2*b**2 - 4*a*A*b*B + 4*b*B**2*gamma + 2*A*(-2 + b*gamma)*C + gamma*C*(-4*a*B + gamma*C)
    second = (quadratic*torch.cos(mu)/torch.sin(mu))/(16*torch.pi)
    return - (first - second)

def dnux(Ax, Bx, Cx, ax, bx, nux):
    """Horizontal tune shift (second-order approximation)."""
    return _dnu_approx(Ax, Bx, Cx, ax, bx, nux)

def dnuy(Ay, By, Cy, ay, by, nuy):
    """Vertical tune shift (second-order approximation)."""
    return _dnu_approx(Ay, By, Cy, ay, by, nuy)

# Tracked vs approximate analytical tune shifts, horizontal plane
print((nux - nux_id))
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()

# Vertical plane: use the vertical exponent entries, Twiss and tune with dnuy
print((nuy - nuy_id))
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0197, dtype=torch.float64)
tensor(0.0198, dtype=torch.float64)

tensor(-0.0084, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)

[23]:
# Twiss at the observation point (exact)

def _cs_exact(A, B, C, a, b, nu):
    """Exact Twiss parameters (alpha, beta) of the perturbed one-turn matrix for one plane.

    The one-turn matrix is M = M0 @ exp(S2 @ [[A, B], [B, C]]), where M0 has Twiss
    parameters (a, b) and tune nu and S2 = [[0, 1], [-1, 0]]. With w = sqrt(B**2 - A*C),
    the exponential equals cosh(w) I + sinh(w)/w S2 @ [[A, B], [B, C]], which gives
    closed forms for cos(mu') = trace(M)/2, alpha' = (M11 - M22)/(2 sin(mu')) and
    beta' = M12/sin(mu').

    Parameters
    ----------
    A, B, C: Tensor
        exponent entries (a complex dtype allows B**2 < A*C)
    a, b: Tensor
        unperturbed alpha and beta at the observation point
    nu: Tensor
        unperturbed fractional tune

    Returns
    -------
    Tensor
        stacked (alpha, beta); sin(mu') is taken as the positive square root

    """
    mu = 2*nu*torch.pi
    w = torch.sqrt(B**2 - A*C)
    cos_mu, sin_mu = torch.cos(mu), torch.sin(mu)
    cosh_w, sinh_w = torch.cosh(w), torch.sinh(w)
    cos_new = cos_mu*cosh_w - ((A*b**2 - 2*a*b*B + C + a**2*C)*sin_mu*sinh_w)/(2*b*w)
    sin_new = torch.sqrt(1 - cos_new**2)
    alpha = (2*a*cosh_w*sin_mu + ((2*b*B*cos_mu + (-(A*b**2) + C + a**2*C)*sin_mu)*sinh_w)/(b*w))/(2*sin_new)
    beta = (b*cosh_w*sin_mu + ((C*cos_mu + (-(b*B) + a*C)*sin_mu)*sinh_w)/w)/sin_new
    return torch.stack([alpha, beta])

def csx(Ax, Bx, Cx, ax, bx, nux):
    """Horizontal (alpha, beta) at the observation point (exact)."""
    return _cs_exact(Ax, Bx, Cx, ax, bx, nux)

def csy(Ay, By, Cy, ay, by, nuy):
    """Vertical (alpha, beta) at the observation point (exact)."""
    return _cs_exact(Ay, By, Cy, ay, by, nuy)

# Tracked (alpha, beta) of the perturbed lattice vs the exact analytical values, horizontal plane
print(torch.stack([axf, bxf]))
print(csx(Ax, Bx, Cx, axi, bxi, nux).real)
print()

# Same comparison for the vertical plane
print(torch.stack([ayf, byf]))
print(csy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor([7.5487e-03, 9.2940e+00], dtype=torch.float64)
tensor([5.3985e-03, 9.3283e+00], dtype=torch.float64)

tensor([-0.0044,  1.6334], dtype=torch.float64)
tensor([-0.0044,  1.6335], dtype=torch.float64)

[24]:
# Twiss at the observation point (approximate)
# NOTE: these definitions shadow the exact csx/csy of the previous cell

def csx(Ax, Bx, Cx, ax, bx, nux):
    """Horizontal (alpha, beta) at the observation point, expanded in (Ax, Bx, Cx).

    The first-order terms agree with the expansion of the exact expressions; the higher
    order is presumably second order — TODO confirm. Returns torch.stack([alpha, beta]).
    """
    # cx is the unperturbed gamma = (1 + alpha**2)/beta
    cx = (1 + ax**2)/bx
    hax = ax + (-(Ax*bx) + cx*Cx - (ax*Ax*bx - 2*bx*Bx*cx + ax*cx*Cx)*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)))/2.0 + ((1.0/torch.sin(2*nux*torch.pi)**3)*(16*ax*(Bx**2 - Ax*Cx)*torch.sin(2*nux*torch.pi)**3 + 4*ax*torch.sin(2*nux*torch.pi)*(4*(Bx**2 - Ax*Cx)*torch.cos(2*nux*torch.pi)**2 + (Ax*bx - 2*ax*Bx + cx*Cx)**2 * torch.sin(2*nux*torch.pi)**2) + 4*(Ax*bx - 2*ax*Bx + cx*Cx)*(-2*Bx*torch.cos(2*nux*torch.pi) + (Ax*bx - cx*Cx)*torch.sin(2*nux*torch.pi))*torch.sin(4*nux*torch.pi) + 3*ax*(Ax*bx - 2*ax*Bx + cx*Cx)**2*(1.0/torch.sin(2*nux*torch.pi)) * torch.sin(4*nux*torch.pi)**2))/32.0
    hbx = bx - bx*Bx + ax*Cx + (-0.5*(Ax*bx**2) + ax*bx*Bx + Cx - 0.5*(bx*cx*Cx)) * (torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)) + (4*bx*(Bx**2 - Ax*Cx) + bx*(Ax*bx - 2*ax*Bx + cx*Cx)**2 + 4*bx*(Bx**2 - Ax*Cx)*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi))**2 + 3*bx*(Ax*bx - 2*ax*Bx + cx*Cx)**2*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi))**2 + 2*(Ax*bx - 2*ax*Bx + cx*Cx)*(bx*Bx - ax*Cx - Cx*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)))*(1.0/torch.sin(2*nux*torch.pi)**2) * torch.sin(4*nux*torch.pi))/8.0
    return torch.stack([hax, hbx])

def csy(Ay, By, Cy, ay, by, nuy):
    """Vertical (alpha, beta) at the observation point; same expansion as csx with y quantities."""
    # cy is the unperturbed gamma = (1 + alpha**2)/beta
    cy = (1 + ay**2)/by
    hay = ay + (-(Ay*by) + cy*Cy - (ay*Ay*by - 2*by*By*cy + ay*cy*Cy)*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)))/2.0 + ((1.0/torch.sin(2*nuy*torch.pi)**3)*(16*ay*(By**2 - Ay*Cy)*torch.sin(2*nuy*torch.pi)**3 + 4*ay*torch.sin(2*nuy*torch.pi)*(4*(By**2 - Ay*Cy)*torch.cos(2*nuy*torch.pi)**2 + (Ay*by - 2*ay*By + cy*Cy)**2 * torch.sin(2*nuy*torch.pi)**2) + 4*(Ay*by - 2*ay*By + cy*Cy)*(-2*By*torch.cos(2*nuy*torch.pi) + (Ay*by - cy*Cy)*torch.sin(2*nuy*torch.pi))*torch.sin(4*nuy*torch.pi) + 3*ay*(Ay*by - 2*ay*By + cy*Cy)**2*(1.0/torch.sin(2*nuy*torch.pi)) * torch.sin(4*nuy*torch.pi)**2))/32.0
    hby = by - by*By + ay*Cy + (-0.5*(Ay*by**2) + ay*by*By + Cy - 0.5*(by*cy*Cy)) * (torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)) + (4*by*(By**2 - Ay*Cy) + by*(Ay*by - 2*ay*By + cy*Cy)**2 + 4*by*(By**2 - Ay*Cy)*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi))**2 + 3*by*(Ay*by - 2*ay*By + cy*Cy)**2*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi))**2 + 2*(Ay*by - 2*ay*By + cy*Cy)*(by*By - ay*Cy - Cy*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)))*(1.0/torch.sin(2*nuy*torch.pi)**2) * torch.sin(4*nuy*torch.pi))/8.0
    return torch.stack([hay, hby])

# Tracked (alpha, beta) vs the approximate analytical values (csx/csy as last defined), horizontal plane
print(torch.stack([axf, bxf]))
print(csx(Ax, Bx, Cx, axi, bxi, nux).real)
print()

# Same comparison for the vertical plane
print(torch.stack([ayf, byf]))
print(csy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor([7.5487e-03, 9.2940e+00], dtype=torch.float64)
tensor([5.3630e-03, 9.3289e+00], dtype=torch.float64)

tensor([-0.0044,  1.6334], dtype=torch.float64)
tensor([-0.0044,  1.6335], dtype=torch.float64)