{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "262a5ec8-2553-4237-ab62-319b6ca22089",
   "metadata": {},
   "source": [
    "# Example-54: Coupling (coupling correction based on minimal tune distance)\n",
    "\n",
    "Random errors are added to the quadrupole `ks` values, and the resulting minimal tune distance $\\Delta Q_{\\min}$ is used as the target. The model `ks` deviations are fitted to reproduce it and then subtracted from the error lattice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "affe7136-4467-4ce5-9ce3-4d1bc35c4d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In this example the minimal tune distance is used for coupling correction\n",
    "# Given a measured value, fit the lattice model to reproduce it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "89033aec-f92c-4108-affb-2251570dd6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import\n",
    "\n",
    "from pprint import pprint\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from model.library.line import Line\n",
    "\n",
    "from model.command.external import load_sdds\n",
    "from model.command.external import load_lattice\n",
    "\n",
    "from model.command.wrapper import group\n",
    "from model.command.wrapper import forward\n",
    "from model.command.wrapper import inverse\n",
    "from model.command.wrapper import normalize\n",
    "from model.command.wrapper import Wrapper\n",
    "\n",
    "from model.command.build import build\n",
    "from model.command.coupling import coupling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8bb6e890-fa35-46d7-85e5-830d3922b0e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build and setup lattice\n",
    "\n",
    "# Load ELEGANT table\n",
    "\n",
    "path = Path('ic.lte')\n",
    "data = load_lattice(path)\n",
    "\n",
    "# Build lattice from ELEGANT table\n",
    "\n",
    "ring:Line = build('RING', 'ELEGANT', data)\n",
    "ring.flatten()\n",
    "\n",
    "# Merge drifts\n",
    "\n",
    "ring.merge()\n",
    "\n",
    "# Set linear dipoles\n",
    "\n",
    "for element in ring:\n",
    "    if element.__class__.__name__ == 'Dipole':\n",
    "        element.linear = True\n",
    "\n",
    "# Set number of elements of different kinds\n",
    "\n",
    "nb = ring.describe['BPM']\n",
    "nq = ring.describe['Quadrupole']\n",
    "ns = ring.describe['Sextupole']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a1c4f458-d5b9-42bd-b7fc-ed4f143fab08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set lattice with errors\n",
    "\n",
    "error:Line = ring.clone()\n",
    "\n",
    "nq = error.describe['Quadrupole']\n",
    "\n",
    "# Fix random seed, so that the errors (and the random initial guess used below) are reproducible\n",
    "\n",
    "torch.manual_seed(1)\n",
    "\n",
    "error_ks = 0.1*torch.randn(nq, dtype=torch.float64)\n",
    "\n",
    "# One error value per group of consecutive same-named quadrupoles\n",
    "\n",
    "index = 0\n",
    "label = ''\n",
    "\n",
    "for element in error.sequence:\n",
    "    if element.__class__.__name__ == 'Quadrupole':\n",
    "        if label != element.name:\n",
    "            index +=1\n",
    "            label = element.name\n",
    "        element.ks = (element.ks + error_ks[index - 1]).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22bdadaf-3b59-4ca6-b538-4dc92630ca08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0., dtype=torch.float64)\n",
      "tensor(0.0109, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Compute delta Q min\n",
    "\n",
    "print(coupling(ring, []))\n",
    "print(coupling(error, []))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "13d47ead-dc96-45fc-9e7e-fd3d24ecd9de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0109, dtype=torch.float64)\n",
      "tensor(1.8041e-16, dtype=torch.float64)\n",
      "\n",
      "tensor(0.0109, dtype=torch.float64)\n",
      "tensor(1.4051e-16, dtype=torch.float64)\n",
      "\n",
      "tensor(0.0117, dtype=torch.float64)\n",
      "tensor(0.0104, dtype=torch.float64)\n",
      "tensor(0.0092, dtype=torch.float64)\n",
      "tensor(0.0079, dtype=torch.float64)\n",
      "tensor(0.0067, dtype=torch.float64)\n",
      "tensor(0.0054, dtype=torch.float64)\n",
      "tensor(0.0041, dtype=torch.float64)\n",
      "tensor(0.0027, dtype=torch.float64)\n",
      "tensor(0.0013, dtype=torch.float64)\n",
      "tensor(0.0002, dtype=torch.float64)\n",
      "tensor(0.0012, dtype=torch.float64)\n",
      "tensor(0.0018, dtype=torch.float64)\n",
      "tensor(0.0020, dtype=torch.float64)\n",
      "tensor(0.0019, dtype=torch.float64)\n",
      "tensor(0.0016, dtype=torch.float64)\n", 
"tensor(0.0011, dtype=torch.float64)\n", "tensor(0.0005, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0009, dtype=torch.float64)\n", "tensor(0.0009, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0003, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(5.3763e-05, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0005, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0005, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0001, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0003, dtype=torch.float64)\n", "tensor(1.9189e-05, dtype=torch.float64)\n", "tensor(0.0003, dtype=torch.float64)\n", "tensor(0.0005, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0003, dtype=torch.float64)\n", "tensor(0.0001, dtype=torch.float64)\n", "tensor(0.0002, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0003, dtype=torch.float64)\n", "tensor(9.4410e-06, dtype=torch.float64)\n", "tensor(0.0004, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, 
dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Correction (fit model to experiment)\n",
    "\n",
    "# Set target delta Q min\n",
    "\n",
    "coupling_error = coupling(error, [])\n",
    "\n",
    "# Set learning rate, number of epochs and normalization range for ks deviations\n",
    "\n",
    "lr = 0.001\n",
    "epochs = 64\n",
    "bounds = [(-0.5, 0.5)]\n",
    "\n",
    "# Set parametric coupling (small value is added to avoid nan values)\n",
    "\n",
    "def coupling_model(ks):\n",
    "    return coupling(ring, [ks + 2.5E-16], ('ks', ['Quadrupole'], None, None))\n",
    "\n",
    "# Set objective function\n",
    "\n",
    "def objective(ks):\n",
    "    return (coupling_error - coupling_model(ks)).norm()\n",
    "\n",
    "# Test objective function\n",
    "\n",
    "print(objective(0.0*error_ks))\n",
    "print(objective(1.0*error_ks))\n",
    "print()\n",
    "\n",
    "# Set normalized objective (takes ks in normalized space)\n",
    "\n",
    "objective_normalized = normalize(objective, bounds)\n",
    "\n",
    "# Test normalized objective\n",
    "\n",
    "print(objective_normalized(*forward([0.0*error_ks], bounds)))\n",
    "print(objective_normalized(*forward([1.0*error_ks], bounds)))\n",
    "print()\n",
    "\n",
    "# Initial settings (normalized space)\n",
    "# Note, it is better to use random initial along with multi-start\n",
    "\n",
    "ks_initial = torch.rand(nq, dtype=torch.float64)\n",
    "\n",
    "# Set model (forward returns evaluated objective)\n",
    "\n",
    "model = Wrapper(objective_normalized, ks_initial)\n",
    "\n",
    "# Set optimizer\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "\n",
    "# Perform optimization\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    value = model()\n",
    "    value.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "    print(value.detach())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98cabce5-49e5-40ec-9ee8-862142b95593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply corrections\n",
    "\n",
    "# Extract fitted ks deviations from the model (normalized space) and map them back with inverse\n",
    "# NOTE: assumes the wrapper holds a single parameter tensor, as created from ks_initial above\n",
    "\n",
    "ks_fit, *_ = [parameter.detach() for parameter in model.parameters()]\n",
    "ks_fit, *_ = inverse([ks_fit], bounds)\n",
    "\n",
    "# Subtract fitted deviations from the error lattice (one value per group of consecutive same-named quadrupoles)\n",
    "\n",
    "lattice:Line = error.clone()\n",
    "\n",
    "index = 0\n",
    "label = ''\n",
    "\n",
    "for element in lattice.sequence:\n",
    "    if element.__class__.__name__ == 'Quadrupole':\n",
    "        if label != element.name:\n",
    "            index +=1\n",
    "            label = element.name\n",
    "        element.ks = (element.ks - ks_fit[index - 1]).item()\n",
    "\n",
    "print(coupling(error, []))\n",
    "print(coupling(lattice, []))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "myt0_gMIOq7b",
    "5d97819c"
   ],
   "name": "03_frequency.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}