{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "ed1c26f4-6b1f-4333-8a47-e8964fb9b279", "metadata": {}, "source": [ "# Example-07: Sextupole (element)" ] }, { "cell_type": "code", "execution_count": 1, "id": "8a8f1fd1-7695-4af2-a56a-db6a1bcb121d", "metadata": {}, "outputs": [], "source": [ "# Comparison of sextupole element with MADX-PTC and other features" ] }, { "cell_type": "code", "execution_count": 2, "id": "be7f42de-9948-4efc-af35-701b14fda472", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from os import system\n", "\n", "import torch\n", "from model.library.drift import Drift\n", "from model.library.quadrupole import Quadrupole\n", "from model.library.sextupole import Sextupole" ] }, { "cell_type": "code", "execution_count": 3, "id": "0f63654e-e265-4e86-86bd-3b342df72459", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.008745689261382875, -0.005080234821325765, -0.00476590910682766, 0.0008855562050471031]\n", "[0.008745689261382871, -0.00508023482132578, -0.00476590910682768, 0.0008855562050471008]\n", "[3.469446951953614e-18, 1.5612511283791264e-17, 1.9949319973733282e-17, 2.2768245622195593e-18]\n" ] } ], "source": [ "# Tracking (paraxial)\n", "\n", "ptc = Path('ptc')\n", "obs = Path('track.obs0001.p0001')\n", "\n", "exact = False\n", "align = False\n", "\n", "ms = 10.0\n", "dp = 0.005\n", "length = 0.25\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "qx, px, qy, py = state.tolist()\n", "\n", "dx = align*torch.tensor(0.05, dtype=torch.float64)\n", "dy = align*torch.tensor(-0.02, dtype=torch.float64)\n", "dz = align*torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = align*torch.tensor(0.005, dtype=torch.float64)\n", "wy = align*torch.tensor(-0.005, dtype=torch.float64)\n", "wz = align*torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "code = f\"\"\"\n", 
"mag: sextupole, l={length},k2={ms} ;\n", "map:line=(mag) ;\n", "beam,energy=1.0E+6,particle=electron ;\n", "set,format=\"20.20f\",\"-20s\" ;\n", "use,period=map ;\n", "select,flag=error,pattern=\"mag\" ;\n", "ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;\n", "ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;\n", "ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;\n", "ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;\n", "ptc_align ;\n", "ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;\n", "ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;\n", "ptc_track_end ;\n", "ptc_end ;\n", "\"\"\" \n", "\n", "with ptc.open('w') as stream:\n", " stream.write(code)\n", " \n", "system(f'madx < {str(ptc)} > /dev/null')\n", "\n", "with obs.open('r') as stream:\n", " for line in stream:\n", " continue\n", " _, _, qx, px, qy, py, *_ = line.split()\n", " \n", "ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)\n", "\n", "S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)\n", "res = S(state, alignment=align, data={**S.data(), **error})\n", "\n", "print(ref.tolist())\n", "print(res.tolist())\n", "print((ref - res).tolist())\n", "\n", "ptc.unlink()\n", "obs.unlink()" ] }, { "cell_type": "code", "execution_count": 4, "id": "1770d4dc-1d8a-4ab9-ab99-fd95bd2ed371", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.008745672936611987, -0.005080234654232574, -0.004765906047078759, 0.000885556338827788]\n", "[0.00874567293661202, -0.0050802346542325495, -0.0047659060470788064, 0.0008855563388277869]\n", "[-3.2959746043559335e-17, -2.42861286636753e-17, 4.7704895589362195e-17, 1.0842021724855044e-18]\n" ] } ], "source": [ "# Tracking (exact)\n", "\n", "ptc = Path('ptc')\n", "obs = Path('track.obs0001.p0001')\n", "\n", "exact = True\n", "align = False\n", "\n", 
"ms = 10.0\n", "dp = 0.005\n", "length = 0.25\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "qx, px, qy, py = state.tolist()\n", "\n", "dx = align*torch.tensor(0.05, dtype=torch.float64)\n", "dy = align*torch.tensor(-0.02, dtype=torch.float64)\n", "dz = align*torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = align*torch.tensor(0.005, dtype=torch.float64)\n", "wy = align*torch.tensor(-0.005, dtype=torch.float64)\n", "wz = align*torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "code = f\"\"\"\n", "mag: sextupole, l={length},k2={ms} ;\n", "map:line=(mag) ;\n", "beam,energy=1.0E+6,particle=electron ;\n", "set,format=\"20.20f\",\"-20s\" ;\n", "use,period=map ;\n", "select,flag=error,pattern=\"mag\" ;\n", "ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;\n", "ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;\n", "ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;\n", "ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;\n", "ptc_align ;\n", "ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;\n", "ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;\n", "ptc_track_end ;\n", "ptc_end ;\n", "\"\"\" \n", "\n", "with ptc.open('w') as stream:\n", " stream.write(code)\n", " \n", "system(f'madx < {str(ptc)} > /dev/null')\n", "\n", "with obs.open('r') as stream:\n", " for line in stream:\n", " continue\n", " _, _, qx, px, qy, py, *_ = line.split()\n", " \n", "ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)\n", "\n", "S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)\n", "res = S(state, alignment=align, data={**S.data(), **error})\n", "\n", "print(ref.tolist())\n", "print(res.tolist())\n", "print((ref - res).tolist())\n", "\n", "ptc.unlink()\n", "obs.unlink()" ] }, { 
"cell_type": "code", "execution_count": 5, "id": "2e0156cd-d879-4760-8e36-f9e7e89151b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.008663885569968804, -0.006258052120536049, -0.004896687680053297, -0.000915022709372755]\n", "[0.008663885569968922, -0.006258052120536033, -0.004896687680053253, -0.000915022709372733]\n", "[-1.1796119636642288e-16, -1.6479873021779667e-17, -4.423544863740858e-17, -2.200930410145574e-17]\n" ] } ], "source": [ "# Tracking (exact, alignment)\n", "\n", "ptc = Path('ptc')\n", "obs = Path('track.obs0001.p0001')\n", "\n", "exact = True\n", "align = True\n", "\n", "ms = 10.0\n", "dp = 0.005\n", "length = 0.25\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "qx, px, qy, py = state.tolist()\n", "\n", "dx = align*torch.tensor(0.05, dtype=torch.float64)\n", "dy = align*torch.tensor(-0.02, dtype=torch.float64)\n", "dz = align*torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = align*torch.tensor(0.005, dtype=torch.float64)\n", "wy = align*torch.tensor(-0.005, dtype=torch.float64)\n", "wz = align*torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "code = f\"\"\"\n", "mag: sextupole, l={length},k2={ms} ;\n", "map:line=(mag) ;\n", "beam,energy=1.0E+6,particle=electron ;\n", "set,format=\"20.20f\",\"-20s\" ;\n", "use,period=map ;\n", "select,flag=error,pattern=\"mag\" ;\n", "ealign,dx={dx.item()},dy={dy.item()},ds={dz.item()},dphi={wx.item()},dtheta={wy.item()},dpsi={wz.item()} ;\n", "ptc_create_universe,sector_nmul_max=10,sector_nmul=10 ;\n", "ptc_create_layout,model=1,method=6,nst=1000,exact={str(exact).lower()} ;\n", "ptc_setswitch,fringe=false,time=true,totalpath=true,exact_mis=true ;\n", "ptc_align ;\n", "ptc_start,x={qx},px={px},y={qy},py={py},pt={dp},t=0.0 ;\n", "ptc_track,icase=5,deltap=0.,turns=1,file=track,maxaper={{1.,1.,1.,1.,1.,1.}} ;\n", "ptc_track_end ;\n", "ptc_end ;\n", 
"\"\"\" \n", "\n", "with ptc.open('w') as stream:\n", " stream.write(code)\n", " \n", "system(f'madx < {str(ptc)} > /dev/null')\n", "\n", "with obs.open('r') as stream:\n", " for line in stream:\n", " continue\n", " _, _, qx, px, qy, py, *_ = line.split()\n", " \n", "ref = torch.tensor([float(x) for x in (qx, px, qy, py)], dtype=torch.float64)\n", "\n", "S = Sextupole('S', length=length, ms=ms, dp=dp, exact=exact, order=5, ns=10)\n", "res = S(state, alignment=align, data={**S.data(), **error})\n", "\n", "print(ref.tolist())\n", "print(res.tolist())\n", "print((ref - res).tolist())\n", "\n", "ptc.unlink()\n", "obs.unlink()" ] }, { "cell_type": "code", "execution_count": 6, "id": "c797301f-2f9d-4424-a535-3c3ebddfa453", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0087, -0.0051, -0.0048, 0.0009], dtype=torch.float64)\n", "\n", "tensor([ 0.0100, -0.0050, -0.0050, 0.0010], dtype=torch.float64)\n", "\n", "tensor([ 0.0087, -0.0062, -0.0049, -0.0009], dtype=torch.float64)\n", "\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n" ] } ], "source": [ "# Deviation/error variables\n", "\n", "ms = 10.0\n", "dp = 0.005\n", "length = 0.25\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "S = Sextupole('S', length, ms, dp)\n", "\n", "# Each element has two variants of the call method\n", "# In the first case, only the state is passed and it is transformed using parameters specified on initialization\n", "\n", "print(S(state))\n", "print()\n", "\n", "# Deviation errors can also be passed to the call method\n", "# These 
variables are added to the corresponding parameters specified on initialization\n", "# For example, element length can be changed\n", "\n", "print(S(state, data={**S.data(), **{'dl': -S.length}}))\n", "print()\n", "\n", "# In the above, S.data() creates the default deviation dictionary (with zero values for each deviation)\n", "# {**S.data(), **{'dl': -S.length}} replaces the 'dl' key value \n", "\n", "# Additionally, alignment errors are passed as deviation variables\n", "# They are used if the alignment flag is raised\n", "\n", "print(S(state, data={**S.data(), **error}, alignment=True))\n", "print()\n", "\n", "# The following elements can be made equivalent using deviation variables\n", "\n", "SA = Sextupole('SA', length, ms, dp)\n", "SB = Sextupole('SB', length - 0.1, ms, dp)\n", "\n", "print(SA(state) - SB(state, data={**SB.data(), **{'dl': torch.tensor(+0.1, dtype=SB.dtype)}}))\n", "\n", "# Note, while in some cases float values can be passed as values to deviation variables\n", "# The correct behaviour is guaranteed only for tensors" ] }, { "cell_type": "code", "execution_count": 7, "id": "5f4340f0-1082-4f86-a22c-e022c0575934", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0000, -0.0006, 0.0000, -0.0008], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([-5.2560e-04, -5.6465e-04, 6.1078e-05, -7.7234e-04],\n", " dtype=torch.float64)\n", "tensor([-1.3875e-04, -5.5306e-04, -9.3764e-05, -7.7778e-04],\n", " dtype=torch.float64)\n" ] } ], "source": [ "# Insertion element\n", "\n", "# In this mode elements are treated as thin insertions (at the center)\n", "# Using parameters specified on initialization, two transport matrices are computed\n", "# These matrices are used to insert the element\n", "# Input state is transformed from the element center to its entrance\n", "# Next, transformation from the entrance frame to the exit frame is performed\n", "# This transformation can contain errors\n", "# 
The final step is to transform state from the exit frame back to the element center\n", "# Without errors, this results in an identity transformation for linear elements\n", "\n", "ms = 10.0\n", "dp = 0.005\n", "length = 1.5\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "S = Sextupole('S', length, ms, dp, exact=False, insertion=True)\n", "\n", "# Since sextupole is a nonlinear element, insertion is an identity transformation only for zero strength\n", "\n", "print(S(state) - state)\n", "print(S(state, data={**S.data(), **{'ms': -ms}}) - state)\n", "\n", "# This represents the effect of an error (any nonzero value of strength or a change in any other parameter)\n", "\n", "print(S(state, data={**S.data(), **{'dl': 0.1}}) - state)\n", "\n", "# Exact tracking corresponds to the inclusion of the kinematic term as an error\n", "\n", "S = Sextupole('S', length, ms, dp, exact=True, insertion=True, ns=100, order=1)\n", "\n", "print(S(state) - state)" ] }, { "cell_type": "code", "execution_count": 8, "id": "ab74f465-c887-40c6-bb5c-ddfd553d73d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([512, 4])\n", "torch.Size([512, 4])\n" ] } ], "source": [ "# Mapping over a set of initial conditions\n", "\n", "# Call method can be used to map over a set of initial conditions\n", "# Note, device can be set to cpu or gpu via base element class variables\n", "\n", "ms = 10.0\n", "dp = 0.0\n", "length = 1.5\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = 
torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "S = Sextupole('S', length, ms, dp, exact=True)\n", "\n", "state = 1.0E-3*torch.randn((512, 4), dtype=S.dtype, device=S.device)\n", "\n", "print(torch.vmap(S)(state).shape)\n", "\n", "# To map over deviations parameters a wrapper function (or a lambda expression) can be used\n", "\n", "def wrapper(state, dp):\n", " return S(state, data={**S.data(), **{'dp': dp}})\n", "\n", "dp = 1.0E-3*torch.randn(512, dtype=S.dtype, device=S.device)\n", "\n", "print(torch.vmap(wrapper)(state, dp).shape)" ] }, { "cell_type": "code", "execution_count": 9, "id": "997dae49-c873-4e95-9b42-f0160eb9bee5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.9297, 1.4473, -0.0478, -0.0359],\n", " [-0.0938, 0.9297, -0.0638, -0.0478],\n", " [-0.0478, -0.0359, 1.0703, 1.5527],\n", " [-0.0638, -0.0478, 0.0938, 1.0703]], dtype=torch.float64)\n", "\n", "tensor([-1.1813e-05, -1.5750e-05, -2.9883e-05, -3.9844e-05],\n", " dtype=torch.float64)\n", "\n" ] } ], "source": [ "# Differentiability\n", "\n", "# Both call methods are differentiable\n", "# Derivative with respect to state can be computed directly\n", "# For deviation variables, wrapping is required\n", "\n", "ms = 10.0\n", "dp = 0.0\n", "length = 1.5\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': 
wz}\n", "\n", "S = Sextupole('S', length, ms, dp, exact=False)\n", "\n", "# Compute derivative with respect to state\n", "\n", "print(torch.func.jacrev(S)(state))\n", "print()\n", "\n", "# Compute derivative with respect to a deviation variable\n", "\n", "ms = torch.tensor(0.0, dtype=torch.float64)\n", "\n", "def wrapper(state, ms):\n", " return S(state, data={**S.data(), **{'ms': ms}})\n", "\n", "print(torch.func.jacrev(wrapper, 1)(state, ms))\n", "print()" ] }, { "cell_type": "code", "execution_count": 10, "id": "6c3a722b-5737-4aa3-8cfc-e1106eeb49e9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0023, -0.0052, -0.0039, 0.0006], dtype=torch.float64)\n", "torch.Size([10, 4])\n", "torch.Size([10, 4, 4])\n", "torch.Size([100, 4])\n", "torch.Size([100, 4, 4])\n" ] } ], "source": [ "# Output at each step\n", "\n", "# It is possible to collect output of state or tangent matrix at each integration step\n", "# Number of integration steps is controlled by ns parameter on initialization\n", "# Alternatively, desired integration step length can be passed\n", "# Number of integration steps is computed as ceil(length/ds)\n", "\n", "ms = 10.0\n", "dp = 0.0\n", "length = 1.5\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "S = Sextupole('S', length, ms, dp, exact=False, ns=10, output=True, matrix=True)\n", "\n", "# Final state is still returned\n", "\n", "print(S(state))\n", "\n", "# Data is added to special attributes (state and tangent matrix)\n", "\n", "print(S.container_output.shape)\n", 
"print(S.container_matrix.shape)\n", "\n", "# Number of integration steps can be changed\n", "\n", "S.ns = 100\n", "\n", "S(state)\n", "print(S.container_output.shape)\n", "print(S.container_matrix.shape)" ] }, { "cell_type": "code", "execution_count": 11, "id": "58abd0bc-7a66-494a-b8a8-ee1d690a4f29", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.0022880403040043407, -0.005176627884407675, -0.0038882248469424585, 0.0005834231060909752]\n", "[0.0022871610802911667, -0.005176794799148338, -0.0038891562083014203, 0.0005832353574511327]\n", "[8.792237131739246e-07, 1.6691474066278522e-07, 9.313613589618727e-07, 1.877486398425181e-07]\n", "\n", "[0, 1, 2, 1, 0]\n", "[0.5, 0.5, 1.0, 0.5, 0.5]\n" ] } ], "source": [ "# Integration order is set on initialization (default value is zero)\n", "# This order is related to difference order as 2n + 2\n", "# Thus, zero corresponds to second order difference method\n", "\n", "ms = 10.0\n", "dp = 0.0\n", "length = 1.5\n", "state = torch.tensor([0.01, -0.005, -0.005, 0.001], dtype=torch.float64)\n", "\n", "dx = torch.tensor(0.05, dtype=torch.float64)\n", "dy = torch.tensor(-0.02, dtype=torch.float64)\n", "dz = torch.tensor(0.05, dtype=torch.float64)\n", "\n", "wx = torch.tensor(0.005, dtype=torch.float64)\n", "wy = torch.tensor(-0.005, dtype=torch.float64)\n", "wz = torch.tensor(0.1, dtype=torch.float64)\n", "\n", "error = {'dx': dx, 'dy': dy, 'dz': dz, 'wx': wx, 'wy': wy, 'wz': wz}\n", "\n", "S = Sextupole('S', length, ms, dp, order=0, exact=True)\n", "\n", "# For sextupole integration is always performed\n", "# In exact case, kinematic term error is added\n", "\n", "S.ns = 10\n", "ref = S(state)\n", "\n", "S.ns = 100\n", "res = S(state)\n", "\n", "print(ref.tolist())\n", "print(res.tolist())\n", "print((ref - res).tolist())\n", "print()\n", "\n", "# Integrator parameters are stored in data attribute (if integration is actually performed)\n", "\n", "maps, weights = S._data\n", 
"print(maps)\n", "print(weights)" ] }, { "cell_type": "code", "execution_count": 12, "id": "9c53baf6-e0ad-478d-a252-232a20b11b9a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.2107, 0.1703], dtype=torch.float64)\n", "tensor([-0.2279, -0.2107], dtype=torch.float64)\n", "tensor([[0., 0.],\n", " [0., 0.]], dtype=torch.float64)\n" ] } ], "source": [ "# Derivatives of twiss parameters (chromaticity)\n", "\n", "# pip install git+https://github.com/i-a-morozov/twiss.git@main\n", "# pip install git+https://github.com/i-a-morozov/ndmap.git@main\n", "\n", "from twiss import twiss\n", "\n", "from ndmap.pfp import parametric_fixed_point\n", "from ndmap.evaluate import evaluate\n", "\n", "# Define elements\n", "\n", "QF = Quadrupole('QF', 0.5, +0.21)\n", "QD = Quadrupole('QD', 0.5, -0.19)\n", "SF = Sextupole('SF', 0.25)\n", "SD = Sextupole('SD', 0.25)\n", "DA = Drift('DR', 0.25)\n", "DB = Drift('DR', 4.00)\n", "\n", "# Define one-turn transformation\n", "\n", "def fodo(state, dp, ms):\n", " dp, *_ = dp\n", " msf, msd, *_ = ms\n", " state = QF(state, data={**QF.data(), **{'dp': dp}})\n", " state = DA(state, data={**DA.data(), **{'dp': dp}})\n", " state = SF(state, data={**SF.data(), **{'dp': dp, 'ms': msf}})\n", " state = DB(state, data={**DB.data(), **{'dp': dp}})\n", " state = SD(state, data={**SD.data(), **{'dp': dp, 'ms': msd}})\n", " state = DA(state, data={**DA.data(), **{'dp': dp}})\n", " state = QD(state, data={**QD.data(), **{'dp': dp}})\n", " state = QD(state, data={**QD.data(), **{'dp': dp}})\n", " state = DA(state, data={**DA.data(), **{'dp': dp}})\n", " state = SD(state, data={**SD.data(), **{'dp': dp, 'ms': msd}})\n", " state = DB(state, data={**DB.data(), **{'dp': dp}})\n", " state = SF(state, data={**SF.data(), **{'dp': dp, 'ms': msf}})\n", " state = DA(state, data={**DA.data(), **{'dp': dp}})\n", " state = QF(state, data={**QF.data(), **{'dp': dp}})\n", " return state\n", "\n", "# Set deviation parameters\n", 
"\n", "msf = torch.tensor(0.0, dtype=torch.float64)\n", "msd = torch.tensor(0.0, dtype=torch.float64)\n", "ms = torch.stack([msf, msd])\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "# Set fixed point\n", "\n", "fp = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "\n", "\n", "# Compute parametric fixed point (first order in momentum deviation)\n", "# Note, all parameters must be vectors\n", "\n", "pfp, *_ = parametric_fixed_point((1, ), fp, [dp], fodo, ms)\n", "\n", "# Define transformation around fixed point\n", "\n", "def pfp_fodo(state, dp, ms):\n", " return fodo(state + evaluate(pfp, [dp]), dp, ms) - evaluate(pfp, [dp])\n", "\n", "# Tune\n", "\n", "def tune(dp, ms):\n", " matrix = torch.func.jacrev(pfp_fodo)(fp, dp, ms)\n", " tune, *_ = twiss(matrix)\n", " return tune\n", "\n", "# Chromaticity\n", "\n", "def chromaticity(ms):\n", " return torch.func.jacrev(tune)(dp, ms)\n", "\n", "# Compute tunes\n", "\n", "tunes = tune(dp, ms)\n", "print(tunes)\n", "\n", "# Compute chromaticity\n", "\n", "chromaticities = chromaticity(ms)\n", "print(chromaticities.squeeze())\n", "\n", "# Compute derivative of chromaticities \n", "# The result is zero, since without dispersion the sextupoles do not feed down to the linear optics\n", "\n", "print(torch.func.jacrev(chromaticity)(ms).squeeze())" ] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { 
"equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }