ELETTRA-39: ID linear optics distortion (analytical & multiple perturbations)
[1]:
# In this example, factorization of the one-turn matrix is illustrated
# The perturbed one-turn matrix can be expressed as M exp(S A), where M is the unperturbed matrix and A with Aij = Aji is a symmetric matrix (block-diagonal in the absence of coupling) that describes the cumulative effect of the perturbations
# Here the matrix A is constructed perturbatively using the Magnus expansion
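# A minimal standalone sketch (added; 2x2 toy example, the names S2/A2 are illustrative):
# for a skew-symmetric form S with S @ S = -I and a symmetric A, exp(S A) is symplectic,
# so multiplying a one-turn matrix by exp(S A) keeps it symplectic as well
import torch
S2 = torch.tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=torch.float64)
A2 = torch.tensor([[0.3, 0.1], [0.1, 0.2]], dtype=torch.float64)
K2 = torch.linalg.matrix_exp(S2 @ A2)
print(torch.allclose(K2.T @ S2 @ K2, S2))  # expected: True (the factor is symplectic)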
[2]:
# Import
import torch
from torch import Tensor
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
matplotlib.rcParams['text.usetex'] = True
from model.library.element import Element
from model.library.line import Line
from model.library.quadrupole import Quadrupole
from model.library.matrix import Matrix
from model.command.external import load_lattice
from model.command.build import build
from model.command.tune import tune
from model.command.orbit import dispersion
from model.command.twiss import twiss
from model.command.advance import advance
from model.command.coupling import coupling
[3]:
# Set data type and device
Element.dtype = dtype = torch.float64
Element.device = device = torch.device('cpu')
[4]:
# Load lattice (ELEGANT table)
# Note, the lattice is allowed to have repeated elements
path = Path('elettra.lte')
data = load_lattice(path)
[5]:
# Build and setup lattice
ring:Line = build('RING', 'ELEGANT', data)
# Flatten sublines
ring.flatten()
# Remove all marker elements except the ones starting with MLL_ (long straight section centers)
ring.remove_group(pattern=r'^(?!MLL_).*', kinds=['Marker'])
# Replace all sextupoles with quadrupoles
def factory(element:Element) -> Quadrupole:
    table = element.serialize
    table.pop('ms', None)
    return Quadrupole(**table)
ring.replace_group(pattern=r'', factory=factory, kinds=['Sextupole'])
# Set linear dipoles
def apply(element:Element) -> None:
    element.linear = True
ring.apply(apply, kinds=['Dipole'])
# Merge drifts
ring.merge()
# Change lattice start
ring.start = "MLL_S01"
# Describe
ring.describe
[5]:
{'Marker': 12, 'Drift': 708, 'BPM': 168, 'Quadrupole': 360, 'Dipole': 156}
[6]:
# Unperturbed one-turn matrix
state = torch.tensor(4*[0.0], dtype=dtype)
M0 = torch.func.jacrev(ring)(state)
print(M0)
tensor([[-0.3055, 8.9649, 0.0000, 0.0000],
[-0.1011, -0.3055, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.5315, 1.3891],
[ 0.0000, 0.0000, -0.5166, 0.5315]], dtype=torch.float64)
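# Sanity check (added sketch): the one-turn matrix should be symplectic; S4 below is
# the same symplectic form that is defined as S further down in this notebook
S4 = torch.tensor([[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]], dtype=dtype)
print(torch.allclose(M0.T @ S4 @ M0, S4))  # expected: True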
[7]:
# Define IDs
# Each ID is described by a symmetric matrix A; here all three share the same diagonal A, scaled by 0.25
# The first ID is located at the lattice start
ca, cb, cc, cd = -0.034441907232402175, -0.04458009513208418, 0.056279356423643276, 0.08037110220505986
A = torch.tensor([[ca, 0.0, 0.0, 0.0], [0.0, cb, 0.0, 0.0], [0.0, 0.0, cc, 0.0], [0.0, 0.0, 0.0, cd]], dtype=dtype)
mask = torch.triu(torch.ones_like(A, dtype=torch.bool))
ID1 = Matrix('ID1', length=0.0, A=(0.25*A[mask]).tolist())
ID2 = Matrix('ID2', length=0.0, A=(0.25*A[mask]).tolist())
ID3 = Matrix('ID3', length=0.0, A=(0.25*A[mask]).tolist())
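# Note (added): `mask` selects the upper triangle in row-major order, so the list
# passed as A packs the 10 independent entries of the symmetric matrix in the order
# (0,0), (0,1), (0,2), (0,3), (1,1), (1,2), (1,3), (2,2), (2,3), (3,3)
print(mask.nonzero().tolist())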
[8]:
# Insert IDs
# Each ID (or other perturbation) is a thin insertion matrix
# The matrices are inserted after given markers
elements = [ID1, ID2, ID3]
markers = ['MLL_S01', 'MLL_S02', 'MLL_S03']
error = ring.clone()
for element, marker in zip(elements, markers):
    error.insert(element, marker)
# Describe
error.describe
[8]:
{'Marker': 12,
'Matrix': 3,
'Drift': 708,
'BPM': 168,
'Quadrupole': 360,
'Dipole': 156}
[9]:
# Perturbed one-turn matrix
state = torch.tensor(4*[0.0], dtype=dtype)
M1 = torch.func.jacrev(error)(state)
print(M0)
print(M1)
print()
tensor([[-0.3055, 8.9649, 0.0000, 0.0000],
[-0.1011, -0.3055, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.5315, 1.3891],
[ 0.0000, 0.0000, -0.5166, 0.5315]], dtype=torch.float64)
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
[10]:
# Symplectic form matrix S
S = torch.tensor([[0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]], dtype=dtype)
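# Sanity check (added): S is skew-symmetric and squares to minus identity
assert torch.equal(S.T, -S) and torch.allclose(S @ S, -torch.eye(4, dtype=dtype))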
# Compute perturbations
Ais = []
Kis = []
for element in elements:
    Ai = torch.zeros((4, 4), dtype=dtype)
    Ai[mask] = element.A
    Ki = torch.linalg.matrix_exp(S @ Ai)
    Ais.append(Ai)
    Kis.append(Ki)
Ais = torch.stack(Ais)
Kis = torch.stack(Kis)
print([torch.allclose(Ki, torch.func.jacrev(element)(state)) for Ki, element in zip(Kis, elements)])
print()
# Compute matrices between IDs (using markers in the unperturbed lattice)
Tis = []
for i in range(len(markers)):
    _, *sequence, _ = ring[markers[i] : markers[(i + 1) % len(markers)]]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)
# Construct unperturbed one-turn matrix from the segment maps
hM = torch.eye(4, dtype=dtype)
for Ti in Tis:
    hM = Ti @ hM
print(M0)
print(hM)
print(torch.allclose(M0, hM))
print()
# Construct perturbed one-turn matrix with ID kicks between segments
hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hM = Ti @ Ki @ hM
print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
[True, True, True]
tensor([[-0.3055, 8.9649, 0.0000, 0.0000],
[-0.1011, -0.3055, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.5315, 1.3891],
[ 0.0000, 0.0000, -0.5166, 0.5315]], dtype=torch.float64)
tensor([[-0.3055, 8.9649, 0.0000, 0.0000],
[-0.1011, -0.3055, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.5315, 1.3891],
[ 0.0000, 0.0000, -0.5166, 0.5315]], dtype=torch.float64)
True
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
True
[11]:
# Compute unperturbed transport matrices from lattice start to each ID
Tis = []
for marker in markers:
    *sequence, _ = ring[ring.start : marker]
    line = Line('', sequence=sequence)
    Tis.append(torch.func.jacrev(line)(state))
Tis = torch.stack(Tis)
# Construct perturbed one-turn matrix by propagating each ID kick to the lattice start
hM = torch.eye(4, dtype=dtype)
for Ti, Ki in zip(Tis, Kis):
    hKi = Ti.inverse() @ Ki @ Ti
    hM = hKi @ hM
hM = M0 @ hM
print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
True
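# Why the shifted exponents in the next cell work (added sketch): each transport
# matrix Ti is symplectic, hence Ti^-1 S = S Ti.T, and for a symmetric Ai this gives
# Ti^-1 exp(S Ai) Ti = exp(S (Ti.T Ai Ti)); verified numerically with the second ID
Ti, Ai = Tis[1], Ais[1]
lhs = Ti.inverse() @ torch.linalg.matrix_exp(S @ Ai) @ Ti
rhs = torch.linalg.matrix_exp(S @ (Ti.T @ Ai @ Ti))
print(torch.allclose(lhs, rhs))  # expected: True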
[12]:
# Construct shifted exponents
hAis = []
for Ti, Ai in zip(Tis, Ais):
    hAi = Ti.T @ Ai @ Ti
    hAis.append(hAi)
hAis = torch.stack(hAis)
# Construct perturbed one-turn matrix from the shifted exponents
hM = torch.eye(4, dtype=dtype)
for hAi in hAis:
    hM = torch.linalg.matrix_exp(S @ hAi) @ hM
hM = M0 @ hM
print(M1)
print(hM)
print(torch.allclose(M1, hM))
print()
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
tensor([[-0.1780, 9.1328, 0.0000, 0.0000],
[-0.1057, -0.1928, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4820, 1.4276],
[ 0.0000, 0.0000, -0.5351, 0.4897]], dtype=torch.float64)
True
[13]:
# Define bracket (commutator)
def bracket(X: Tensor, Y: Tensor) -> Tensor:
    return X @ Y - Y @ X
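# Context (added sketch): the next cell merges the factored exponentials into a single
# exponent with a BCH/Magnus-type product expansion; to second order
# exp(Y) exp(X) ~ exp(X + Y + bracket(Y, X)/2), with a residual of third order in the kicks
X, Y = S @ hAis[0], S @ hAis[1]
lhs = torch.linalg.matrix_exp(Y) @ torch.linalg.matrix_exp(X)
rhs = torch.linalg.matrix_exp(X + Y + bracket(Y, X)/2)
print((lhs - rhs).abs().max())  # expected: small (third order) residual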
[14]:
# Construct product approximation
hA1 = torch.zeros((4, 4), dtype=dtype)
for hAi in hAis:
    hA1 += hAi
hA2 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        dij = (i == j)
        factor = 1/(1 + dij)
        Xi = S @ hAis[i]
        Xj = S @ hAis[j]
        hA2 += factor*1/2*bracket(Xi, Xj)
hA2 = - S @ hA2
hA3 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        for k in range(j + 1):
            dij = (i == j)
            dik = (i == k)
            djk = (j == k)
            factor = (1.0/(1.0 + dij))*(1.0/(1.0 + dik + djk))
            Xi = S @ hAis[i]
            Xj = S @ hAis[j]
            Xk = S @ hAis[k]
            hA3 += factor*(
                + 1/4*bracket(bracket(Xi, Xj), Xk)
                - 1/12*bracket(bracket(Xi, Xk), Xj)
                - 1/12*bracket(bracket(Xj, Xk), Xi)
            )
hA3 = - S @ hA3
hA4 = torch.zeros((4, 4), dtype=dtype)
for i in range(len(elements)):
    for j in range(i + 1):
        for k in range(j + 1):
            for l in range(k + 1):
                dij = (i == j)
                dik = (i == k)
                djk = (j == k)
                dil = (i == l)
                djl = (j == l)
                dkl = (k == l)
                factor = (1.0/(1.0 + dij))*(1.0/(1.0 + dik + djk))*(1.0/(1.0 + dil + djl + dkl))
                Xi = S @ hAis[i]
                Xj = S @ hAis[j]
                Xk = S @ hAis[k]
                Xl = S @ hAis[l]
                hA4 += factor*(
                    + 1/12*bracket(bracket(bracket(Xi, Xj), Xk), Xl)
                    + 1/12*bracket(Xi, bracket(bracket(Xj, Xk), Xl))
                    + 1/12*bracket(Xi, bracket(Xj, bracket(Xk, Xl)))
                    + 1/12*bracket(Xj, bracket(Xk, bracket(Xl, Xi)))
                )
hA4 = - S @ hA4
[15]:
# Extract independent elements of the combined exponent (upper-triangular, row-major order)
Ax, Bx, _, _, Cx, _, _, Ay, By, Cy = (hA1 + hA2 + hA3 + hA4)[mask].to(torch.complex128)
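# Sanity check (added): confirm the unpacking against explicit matrix entries
hA = (hA1 + hA2 + hA3 + hA4).to(torch.complex128)
print(torch.allclose(torch.stack([Ax, Bx, Cx, Ay, By, Cy]),
                     torch.stack([hA[0, 0], hA[0, 1], hA[1, 1], hA[2, 2], hA[2, 3], hA[3, 3]])))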
[16]:
# Compute tunes (fractional part)
nux, nuy = tune(ring, [], matched=True, limit=1)
[17]:
# Compute twiss parameters
ax, bx, ay, by = twiss(ring, [], matched=True, advance=True, full=False).T
axi, *_ = ax
bxi, *_ = bx
ayi, *_ = ay
byi, *_ = by
[18]:
# Compute tunes of the perturbed lattice (fractional part)
nux_id, nuy_id = tune(error, [], matched=True, limit=1)
[19]:
# Compute twiss parameters of the perturbed lattice
ax_id, bx_id, ay_id, by_id = twiss(error, [], matched=True, advance=True, full=False).T
axf, *_ = ax_id
bxf, *_ = bx_id
ayf, *_ = ay_id
byf, *_ = by_id
[20]:
# Tune shifts
print(nux - nux_id)
print(nuy - nuy_id)
tensor(0.0197, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
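# Cross-check (added sketch): for the uncoupled case, the fractional tune of each
# plane follows from the trace of the corresponding 2x2 block of the one-turn
# matrix, cos(2 pi nu) = trace/2 (this assumes fractional tunes below 0.5)
def block_tune(M: Tensor, i: int) -> Tensor:
    return torch.arccos(M[i:i + 2, i:i + 2].trace()/2)/(2*torch.pi)
print(block_tune(M0, 0) - block_tune(M1, 0))  # horizontal tune shift
print(block_tune(M0, 2) - block_tune(M1, 2))  # vertical tune shift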
[21]:
# Tune shifts (exact)
def dnux(Ax, Bx, Cx, ax, bx, nux):
    return nux - torch.arccos(torch.cos(2*nux*torch.pi)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx)) - ((Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))/(2*bx*torch.sqrt(Bx**2 - Ax*Cx)))/(2*torch.pi)
def dnuy(Ay, By, Cy, ay, by, nuy):
    return nuy - torch.arccos(torch.cos(2*nuy*torch.pi)*torch.cosh(torch.sqrt(By**2 - Ay*Cy)) - ((Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))/(2*by*torch.sqrt(By**2 - Ay*Cy)))/(2*torch.pi)
print(nux - nux_id)
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()
print(nuy - nuy_id)
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0197, dtype=torch.float64)
tensor(0.0198, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
[22]:
# Tune shifts (approximate)
def dnux(Ax, Bx, Cx, ax, bx, nux):
    cx = (1 + ax**2)/bx
    return - ((Ax*bx - 2*ax*Bx + cx*Cx)/(4*torch.pi) - ((Ax**2*bx**2 - 4*ax*Ax*bx*Bx + 4*bx*Bx**2*cx + 2*Ax*(-2 + bx*cx)*Cx + cx*Cx*(-4*ax*Bx + cx*Cx))*torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi))/(16*torch.pi))
def dnuy(Ay, By, Cy, ay, by, nuy):
    cy = (1 + ay**2)/by
    return - ((Ay*by - 2*ay*By + cy*Cy)/(4*torch.pi) - ((Ay**2*by**2 - 4*ay*Ay*by*By + 4*by*By**2*cy + 2*Ay*(-2 + by*cy)*Cy + cy*Cy*(-4*ay*By + cy*Cy))*torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi))/(16*torch.pi))
print(nux - nux_id)
print(dnux(Ax, Bx, Cx, axi, bxi, nux).real)
print()
print(nuy - nuy_id)
print(dnuy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor(0.0197, dtype=torch.float64)
tensor(0.0198, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
tensor(-0.0084, dtype=torch.float64)
[23]:
# Twiss at the observation point (exact)
def csx(Ax, Bx, Cx, ax, bx, nux):
    cx = (1 + ax**2)/bx
    hax = (2*ax*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))*torch.sin(2*nux*torch.pi) + ((2*bx*Bx*torch.cos(2*nux*torch.pi) + (-(Ax*bx**2) + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi))*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))/(bx*torch.sqrt(Bx**2 - Ax*Cx)))/(2*torch.sqrt(1 - (torch.cos(2*nux*torch.pi)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx)) - ((Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))/(2*bx*torch.sqrt(Bx**2 - Ax*Cx)))**2))
    hbx = (bx*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx))*torch.sin(2*nux*torch.pi) + ((Cx*torch.cos(2*nux*torch.pi) + (-(bx*Bx) + ax*Cx)*torch.sin(2*nux*torch.pi))*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))/torch.sqrt(Bx**2 - Ax*Cx))/torch.sqrt(1 - (torch.cos(2*nux*torch.pi)*torch.cosh(torch.sqrt(Bx**2 - Ax*Cx)) - ((Ax*bx**2 - 2*ax*bx*Bx + Cx + ax**2*Cx)*torch.sin(2*nux*torch.pi)*torch.sinh(torch.sqrt(Bx**2 - Ax*Cx)))/(2*bx*torch.sqrt(Bx**2 - Ax*Cx)))**2)
    return torch.stack([hax, hbx])
def csy(Ay, By, Cy, ay, by, nuy):
    cy = (1 + ay**2)/by
    hay = (2*ay*torch.cosh(torch.sqrt(By**2 - Ay*Cy))*torch.sin(2*nuy*torch.pi) + ((2*by*By*torch.cos(2*nuy*torch.pi) + (-(Ay*by**2) + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi))*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))/(by*torch.sqrt(By**2 - Ay*Cy)))/(2*torch.sqrt(1 - (torch.cos(2*nuy*torch.pi)*torch.cosh(torch.sqrt(By**2 - Ay*Cy)) - ((Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))/(2*by*torch.sqrt(By**2 - Ay*Cy)))**2))
    hby = (by*torch.cosh(torch.sqrt(By**2 - Ay*Cy))*torch.sin(2*nuy*torch.pi) + ((Cy*torch.cos(2*nuy*torch.pi) + (-(by*By) + ay*Cy)*torch.sin(2*nuy*torch.pi))*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))/torch.sqrt(By**2 - Ay*Cy))/torch.sqrt(1 - (torch.cos(2*nuy*torch.pi)*torch.cosh(torch.sqrt(By**2 - Ay*Cy)) - ((Ay*by**2 - 2*ay*by*By + Cy + ay**2*Cy)*torch.sin(2*nuy*torch.pi)*torch.sinh(torch.sqrt(By**2 - Ay*Cy)))/(2*by*torch.sqrt(By**2 - Ay*Cy)))**2)
    return torch.stack([hay, hby])
print(torch.stack([axf, bxf]))
print(csx(Ax, Bx, Cx, axi, bxi, nux).real)
print()
print(torch.stack([ayf, byf]))
print(csy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor([7.5487e-03, 9.2940e+00], dtype=torch.float64)
tensor([5.3985e-03, 9.3283e+00], dtype=torch.float64)
tensor([-0.0044, 1.6334], dtype=torch.float64)
tensor([-0.0044, 1.6335], dtype=torch.float64)
[24]:
# Twiss at the observation point (approximate)
def csx(Ax, Bx, Cx, ax, bx, nux):
    cx = (1 + ax**2)/bx
    hax = ax + (-(Ax*bx) + cx*Cx - (ax*Ax*bx - 2*bx*Bx*cx + ax*cx*Cx)*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)))/2.0 + ((1.0/torch.sin(2*nux*torch.pi)**3)*(16*ax*(Bx**2 - Ax*Cx)*torch.sin(2*nux*torch.pi)**3 + 4*ax*torch.sin(2*nux*torch.pi)*(4*(Bx**2 - Ax*Cx)*torch.cos(2*nux*torch.pi)**2 + (Ax*bx - 2*ax*Bx + cx*Cx)**2 * torch.sin(2*nux*torch.pi)**2) + 4*(Ax*bx - 2*ax*Bx + cx*Cx)*(-2*Bx*torch.cos(2*nux*torch.pi) + (Ax*bx - cx*Cx)*torch.sin(2*nux*torch.pi))*torch.sin(4*nux*torch.pi) + 3*ax*(Ax*bx - 2*ax*Bx + cx*Cx)**2*(1.0/torch.sin(2*nux*torch.pi)) * torch.sin(4*nux*torch.pi)**2))/32.0
    hbx = bx - bx*Bx + ax*Cx + (-0.5*(Ax*bx**2) + ax*bx*Bx + Cx - 0.5*(bx*cx*Cx)) * (torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)) + (4*bx*(Bx**2 - Ax*Cx) + bx*(Ax*bx - 2*ax*Bx + cx*Cx)**2 + 4*bx*(Bx**2 - Ax*Cx)*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi))**2 + 3*bx*(Ax*bx - 2*ax*Bx + cx*Cx)**2*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi))**2 + 2*(Ax*bx - 2*ax*Bx + cx*Cx)*(bx*Bx - ax*Cx - Cx*(torch.cos(2*nux*torch.pi)/torch.sin(2*nux*torch.pi)))*(1.0/torch.sin(2*nux*torch.pi)**2) * torch.sin(4*nux*torch.pi))/8.0
    return torch.stack([hax, hbx])
def csy(Ay, By, Cy, ay, by, nuy):
    cy = (1 + ay**2)/by
    hay = ay + (-(Ay*by) + cy*Cy - (ay*Ay*by - 2*by*By*cy + ay*cy*Cy)*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)))/2.0 + ((1.0/torch.sin(2*nuy*torch.pi)**3)*(16*ay*(By**2 - Ay*Cy)*torch.sin(2*nuy*torch.pi)**3 + 4*ay*torch.sin(2*nuy*torch.pi)*(4*(By**2 - Ay*Cy)*torch.cos(2*nuy*torch.pi)**2 + (Ay*by - 2*ay*By + cy*Cy)**2 * torch.sin(2*nuy*torch.pi)**2) + 4*(Ay*by - 2*ay*By + cy*Cy)*(-2*By*torch.cos(2*nuy*torch.pi) + (Ay*by - cy*Cy)*torch.sin(2*nuy*torch.pi))*torch.sin(4*nuy*torch.pi) + 3*ay*(Ay*by - 2*ay*By + cy*Cy)**2*(1.0/torch.sin(2*nuy*torch.pi)) * torch.sin(4*nuy*torch.pi)**2))/32.0
    hby = by - by*By + ay*Cy + (-0.5*(Ay*by**2) + ay*by*By + Cy - 0.5*(by*cy*Cy)) * (torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)) + (4*by*(By**2 - Ay*Cy) + by*(Ay*by - 2*ay*By + cy*Cy)**2 + 4*by*(By**2 - Ay*Cy)*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi))**2 + 3*by*(Ay*by - 2*ay*By + cy*Cy)**2*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi))**2 + 2*(Ay*by - 2*ay*By + cy*Cy)*(by*By - ay*Cy - Cy*(torch.cos(2*nuy*torch.pi)/torch.sin(2*nuy*torch.pi)))*(1.0/torch.sin(2*nuy*torch.pi)**2) * torch.sin(4*nuy*torch.pi))/8.0
    return torch.stack([hay, hby])
print(torch.stack([axf, bxf]))
print(csx(Ax, Bx, Cx, axi, bxi, nux).real)
print()
print(torch.stack([ayf, byf]))
print(csy(Ay, By, Cy, ayi, byi, nuy).real)
print()
tensor([7.5487e-03, 9.2940e+00], dtype=torch.float64)
tensor([5.3630e-03, 9.3289e+00], dtype=torch.float64)
tensor([-0.0044, 1.6334], dtype=torch.float64)
tensor([-0.0044, 1.6335], dtype=torch.float64)