{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "262a5ec8-2553-4237-ab62-319b6ca22089", "metadata": {}, "source": [ "# Example-39: Mapping (Transformations around closed orbit)" ] }, { "cell_type": "code", "execution_count": 1, "id": "e1c8190a-45c4-4aff-8b3c-b73b703167e2", "metadata": {}, "outputs": [], "source": [ "# In this example the steps to define mappings between elements are illustrated\n", "# These mappings are differentiable with respect to state and different deviation groups\n", "# Additionally, mappings can be defined around closed orbit" ] }, { "cell_type": "code", "execution_count": 2, "id": "d8c227fe-3008-4bd0-9404-976103178dd9", "metadata": {}, "outputs": [], "source": [ "# Import\n", "\n", "from random import random\n", "from pprint import pprint\n", "\n", "import torch\n", "\n", "from pathlib import Path\n", "\n", "from model.library.line import Line\n", "from model.library.corrector import Corrector\n", "\n", "from model.command.external import load_lattice\n", "from model.command.build import build\n", "from model.command.wrapper import group\n", "from model.command.orbit import orbit\n", "from model.command.mapping import mapping\n", "from model.command.mapping import matrix" ] }, { "cell_type": "code", "execution_count": 3, "id": "2b669eb7-ccdf-4222-a95d-709da30c0094", "metadata": {}, "outputs": [], "source": [ "# Build and setup lattice\n", "\n", "# Quadrupoles are split into 2**2 parts, dipoles into 2**4 parts\n", "# Correctors are inserted between parts\n", "# Random errors are assigned to correctors, so that the origin is not preserved\n", "\n", "# Load ELEGANT table\n", "\n", "path = Path('ic.lte')\n", "data = load_lattice(path)\n", "\n", "# Build ELEGANT table\n", "\n", "ring:Line = build('RING', 'ELEGANT', data)\n", "ring.flatten()\n", "\n", "# Merge drifts\n", "\n", "ring.merge()\n", "\n", "# Split quadrupoles and insert correctors\n", "\n", "nq = 2**2\n", "\n", "for name in [name for name, kind, *_ in ring.layout() if 
kind == 'Quadrupole']:\n", "    corrector = Corrector(f'{name}_CXY', factor=1/(nq - 1))\n", "    corrector.cx = 1.0E-3*(random() - 0.5)\n", "    corrector.cy = 1.0E-3*(random() - 0.5)\n", "    ring.split((nq, None, [name], None), paste=[corrector])\n", "\n", "# Split dipoles and insert correctors\n", "\n", "nd = 2**4\n", "\n", "for name in [name for name, kind, *_ in ring.layout() if kind == 'Dipole']:\n", "    corrector = Corrector(f'{name}_CXY', factor=1/(nd - 1))\n", "    corrector.cx = 1.0E-3*(random() - 0.5)\n", "    corrector.cy = 1.0E-3*(random() - 0.5)\n", "    ring.split((nd, None, [name], None), paste=[corrector])\n", "\n", "# Set linear flag in dipoles\n", "\n", "for element in ring:\n", "    if element.__class__.__name__ == 'Dipole':\n", "        element.linear = True\n", "\n", "# Set number of elements of different kinds\n", "# Note: nq is rebound here from the split count used above to the value reported by ring.describe\n", "\n", "nb = ring.describe['BPM']\n", "nc = ring.describe['Corrector']\n", "nq = ring.describe['Quadrupole']\n", "ns = ring.describe['Sextupole']" ] }, { "cell_type": "code", "execution_count": 4, "id": "0ecf68e0-2fa0-49d9-a538-dbab96221e4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)\n", "tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)\n", "tensor([-0.0013, 0.0007, 0.0002, -0.0010], dtype=torch.float64)\n" ] } ], "source": [ "# model.command.wrapper.group can be used to define parametric and differentiable transformations\n", "# These transformations propagate initial state from given (probe) element start to given (other) element end\n", "# Start and end elements can be specified by names (match the first occurrence in line sequence)\n", "# Or they can be specified by integers (can be negative, mod number of elements in sequence is used to define specified transformation)\n", "\n", "# Since correctors have non-zero angles, zero is not mapped to zero\n", "\n", "state = torch.tensor(4*[0.0], dtype=torch.float64)\n", "print(ring(state))\n", 
"\n", "# Define transformation using names (assumed to be different)\n", "\n", "probe, *_, other = ring.names\n", "transformation, *_ = group(ring, probe, other)\n", "print(transformation(state))\n", "\n", "# Define transformation using element positions in lattice sequence\n", "\n", "probe = 0\n", "other = len(ring) - 1\n", "transformation, *_ = group(ring, probe, other)\n", "print(transformation(state))" ] }, { "cell_type": "code", "execution_count": 5, "id": "36d84c70-9036-439f-80f7-77eaaa4feda1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.0024, -0.0054, 0.0036, 0.0051], dtype=torch.float64)\n", "tensor([ 0.0024, -0.0054, 0.0036, 0.0051], dtype=torch.float64)\n", "\n", "tensor([-1.7347e-18, -2.6021e-18, 6.9389e-18, 6.9389e-18],\n", " dtype=torch.float64)\n" ] } ], "source": [ "# Compute closed orbit and test transformation around it\n", "\n", "fp = torch.tensor(4*[0.0], dtype=torch.float64)\n", "fps, _ = orbit(ring, fp, [], advance=True)\n", "fp, *_ = fps\n", "\n", "print(fp)\n", "print(ring(fp))\n", "print()\n", "\n", "print(transformation(state + fp) - fp)" ] }, { "cell_type": "code", "execution_count": 6, "id": "dd403691-565b-459f-83e6-d6829b5f7b2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-1.7347e-18, -2.6021e-18, 6.9389e-18, 6.9389e-18],\n", " dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n" ] } ], "source": [ "# model.command.mapping.mapping can be used as an alias to model.command.wrapper.group\n", "# Additionally, it can be used to construct parametric and differentiable transformations from one element to the other that are built around closed orbit\n", "\n", "transformation, _ = mapping(ring, probe, other, matched=False)\n", "print(transformation(state + fp) - fp)\n", "\n", "# Transformation around closed orbit\n", "\n", "transformation, _ = mapping(ring, probe, other, 
matched=True, limit=8, epsilon=1.0E-9)\n", "print(transformation(state))\n", "\n", "# With matched flag, closed orbit will be computed on each invocation\n", "# To speed up computations, known fixed point can be passed and number of iterations set to zero\n", "# In this case probe is assumed to be the lattice start\n", "\n", "transformation, _ = mapping(ring, probe, other, matched=True, guess=fp, limit=0, epsilon=None)\n", "print(transformation(state))\n", "\n", "# Also, to compute derivatives, limit can be set to one\n", "# Set epsilon to None for vmap computations" ] }, { "cell_type": "code", "execution_count": 7, "id": "583690a9-6e4e-493c-a5cf-0a4d68c1e3ba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)\n", "tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)\n", "tensor([-0.0001, 0.0009, 0.0005, 0.0004], dtype=torch.float64)\n" ] } ], "source": [ "# Transformation between given elements\n", "\n", "probe = 'BPM07'\n", "other = 'BPM10'\n", "\n", "# Propagate element by element\n", "\n", "local = state.clone()\n", "for element in ring[ring.position(probe):ring.position(other)]:\n", " local = element(local)\n", "print(local)\n", "\n", "# Propagate using group\n", "\n", "transformation, *_ = group(ring, probe, other)\n", "print(transformation(state))\n", "\n", "# Propagate using mapping\n", "\n", "transformation, *_ = mapping(ring, probe, other)\n", "print(transformation(state))" ] }, { "cell_type": "code", "execution_count": 8, "id": "b52beec7-c45b-45e8-b6f8-f9b15008e1aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n" ] } ], "source": [ "# Transformation between elements around closed orbit\n", "\n", "# Set closed 
orbit values at probe and other\n", "\n", "fp_probe = fps[ring.position(probe)]\n", "fp_other = fps[ring.position(other)]\n", "\n", "# Propagate element by element\n", "\n", "local = fp_probe + state.clone()\n", "for element in ring[ring.position(probe):ring.position(other)]:\n", " local = element(local)\n", "print(local - fp_other)\n", "\n", "# Propagate using group\n", "\n", "transformation, *_ = group(ring, probe, other)\n", "print(transformation(state + fp_probe) - fp_other)\n", "\n", "# Propagate using mapping\n", "\n", "transformation, *_ = mapping(ring, probe, other, matched=False)\n", "print(transformation(state + fp_probe) - fp_other)\n", "\n", "# Propagate using mapping (matched)\n", "\n", "transformation, *_ = mapping(ring, probe, other, matched=True)\n", "print(transformation(state))" ] }, { "cell_type": "code", "execution_count": 9, "id": "f5b35c3a-53ac-4bf6-808e-738ad131a3dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([ 1.3384e-05, -1.2968e-05, 1.0788e-05, 1.8095e-05],\n", " dtype=torch.float64)\n", "\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([ 1.3384e-05, -1.2968e-05, 1.0788e-05, 1.8095e-05],\n", " dtype=torch.float64)\n", "\n" ] } ], "source": [ "# Parametric mapping\n", "\n", "# Set initial value (relative to closed orbit)\n", "\n", "state = torch.tensor(4*[0.0], dtype=torch.float64)\n", "\n", "# Without root and matched flag, tensor elements are bound only to matched elements (unique names) within selected range of elements\n", "\n", "transformation, ((_, names_line, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=False, matched=False)\n", "\n", "# Set random quadrupole errors within the line\n", "\n", "kn_line = 0.01*torch.randn(len(names_line), dtype=torch.float64)\n", "\n", "\n", "# Since non-zero deviations are passed, closed orbit has been changed\n", "\n", "print(transformation(state + 
fp_probe, 0*kn_line) - fp_other)\n", "print(transformation(state + fp_probe, 1*kn_line) - fp_other)\n", "print()\n", "\n", "# With root flag, tensor elements are bound to all matched elements (unique names)\n", "\n", "transformation, ((_, names_ring, _), *_) = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), root=True, matched=False)\n", "\n", "# To match the previous result, line knobs are placed at the positions of their names in the ring knob vector\n", "# Each name is looked up individually, so matched names may start at index zero and need not form a contiguous slice\n", "\n", "index = {name: i for i, name in enumerate(names_ring)}\n", "select = torch.tensor([index[name] for name in names_line], dtype=torch.int64)\n", "\n", "kn_ring = torch.zeros(len(names_ring), dtype=torch.float64)\n", "kn_ring[select] = kn_line\n", "\n", "# Propagate state\n", "\n", "print(transformation(state + fp_probe, 0*kn_ring) - fp_other)\n", "print(transformation(state + fp_probe, 1*kn_ring) - fp_other)\n", "print()" ] }, { "cell_type": "code", "execution_count": 10, "id": "e22dd92e-a1ca-4fc1-91bf-9a7f2bc32c81", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-3.0358e-18, 1.3878e-17, 3.4694e-18, 7.8063e-18],\n", " dtype=torch.float64)\n", "tensor([ 1.8648e-17, -3.2960e-17, 5.5565e-18, 8.2399e-18],\n", " dtype=torch.float64)\n", "tensor([ 1.9516e-17, -2.6888e-17, -3.9031e-18, -4.3368e-18],\n", " dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n", "tensor([0., 0., 0., 0.], dtype=torch.float64)\n" ] } ], "source": [ "# Parametric mappings around closed orbit\n", "# In this case root parameter is ignored\n", "# Elements are ordered according to their appearance in the input lattice (similar to changing start in orbit function with respect flag)\n", "\n", "# Compute closed orbit with quadrupole errors\n", "\n", "fp = torch.tensor(4*[0.0], dtype=torch.float64)\n", "fps, _ = orbit(ring, fp, [kn_ring], ('kn', ['Quadrupole'], None, None), 
advance=True, limit=8, epsilon=1.0E-9)\n", "fp, *_ = fps\n", "\n", "# Set closed orbit at probe and other\n", "\n", "fp_probe = fps[ring.position(probe)]\n", "fp_other = fps[ring.position(other)]\n", "\n", "\n", "# Test closed orbit at lattice start\n", "\n", "line = ring.clone()\n", "transformation, ((_, names, _), *_) = mapping(line, 0, len(line) - 1, ('kn', ['Quadrupole'], None, None), matched=False)\n", "print(fp - transformation(fp, kn_ring))\n", "\n", "# Test closed orbit at probe and other\n", "# Note, groups are set up using returned matched names\n", "\n", "line = ring.clone()\n", "line.start = probe\n", "transformation, _ = mapping(line, 0, len(line) - 1, ('kn', None, names, None), matched=False)\n", "print(fp_probe - transformation(fp_probe, kn_ring))\n", "\n", "line = ring.clone()\n", "line.start = other\n", "transformation, _ = mapping(line, 0, len(line) - 1, ('kn', None, names, None), matched=False)\n", "print(fp_other - transformation(fp_other, kn_ring))\n", "\n", "# Test mapping\n", "\n", "transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=False, limit=8, epsilon=1.0E-9)\n", "print(transformation(0*state + fp_probe, kn_ring) - fp_other)\n", "print(transformation(1*state + fp_probe, kn_ring) - fp_other)\n", "\n", "# Test mapping around closed orbit\n", "\n", "transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=1.0E-9)\n", "print(transformation(0*state, kn_ring))\n", "print(transformation(1*state, kn_ring))" ] }, { "cell_type": "code", "execution_count": 11, "id": "996c39d5-9388-424a-b92f-9bb641ee45b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128, 4])\n", "torch.Size([128, 4])\n" ] } ], "source": [ "# Vectorized mapping over states / knobs\n", "# In this case, epsilon should be set to None (relevant only for the case around closed orbit)\n", "\n", "transformation, _ = mapping(ring, probe, 
other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=None)\n", "\n", "states = 1.0E-3*torch.randn((128, *state.shape), dtype=torch.float64)\n", "knobs = 1.0E-3*torch.randn((128, *kn_ring.shape), dtype=torch.float64)\n", "\n", "print(torch.vmap(lambda state: transformation(state, kn_ring))(states).shape)\n", "print(torch.vmap(lambda knob: transformation(state, knob))(knobs).shape)" ] }, { "cell_type": "code", "execution_count": 12, "id": "342756f3-cde3-4c4e-ac7b-8ebb792295d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-1.7095e+00, -1.5994e+00, 4.3872e-03, -2.3450e-03],\n", " [ 4.0853e+00, 3.2372e+00, -1.5448e-02, 1.0439e-02],\n", " [-6.1373e-03, -7.3648e-03, -2.0667e-01, -5.2196e-01],\n", " [-3.3938e-03, -5.6320e-03, 1.2684e+00, -1.6353e+00]],\n", " dtype=torch.float64)\n", "tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.4504e-05, 2.6462e-05,\n", " 9.3523e-06, 1.6029e-05, 7.2119e-06, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -7.6792e-05, -3.0233e-05,\n", " -9.6717e-06, 1.4552e-05, 4.2945e-05, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9984e-05, 1.2823e-04,\n", " 1.7506e-04, 6.9073e-05, 1.3380e-05, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [ 
0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3031e-06, 2.0438e-04,\n", " 3.2004e-04, 2.1256e-04, 8.4305e-05, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00]], dtype=torch.float64)\n", "\n", "tensor([[-1.7055, -1.5948, -0.0292, 0.0098],\n", " [ 4.0919, 3.2404, 0.0995, -0.0680],\n", " [ 0.0399, 0.0477, -0.1600, -0.5528],\n", " [-0.0165, 0.0056, 1.3239, -1.6700]], dtype=torch.float64)\n", "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.]], dtype=torch.float64)\n", "\n" ] } ], "source": [ "# Differentiability with respect to state and knobs\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "\n", "# Mapping\n", "\n", "transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=False, limit=8, epsilon=1.0E-9)\n", "\n", "pprint(torch.func.jacrev(transformation, 0)(state, kn_ring))\n", "pprint(torch.func.jacrev(transformation, 1)(state, kn_ring))\n", "print()\n", "\n", "# Mapping around closed orbit\n", "\n", "transformation, _ = mapping(ring, probe, other, ('kn', ['Quadrupole'], None, None), matched=True, limit=8, epsilon=1.0E-9)\n", "\n", "pprint(torch.func.jacrev(transformation, 0)(state, kn_ring))\n", "pprint(torch.func.jacrev(transformation, 1)(state, kn_ring))\n", "print()" ] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" 
], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }