{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "262a5ec8-2553-4237-ab62-319b6ca22089", "metadata": {}, "source": [ "# Example-46: Twiss (Optics correction)" ] }, { "cell_type": "code", "execution_count": 1, "id": "89f1b2d0-d62c-40f0-aa24-98a15e78976c", "metadata": {}, "outputs": [], "source": [ "# In this example, model response matrices of normal and chromatic Twiss parameters are used for optics correction\n", "# ML-style optimization is also performed for optics correction" ] }, { "cell_type": "code", "execution_count": 2, "id": "465011b9-bb27-4fd5-919a-3923ae103a71", "metadata": {}, "outputs": [], "source": [ "# Import\n", "\n", "from pprint import pprint\n", "\n", "import torch\n", "from torch import Tensor\n", "from torch.utils.data import TensorDataset\n", "from torch.utils.data import DataLoader\n", "\n", "from pathlib import Path\n", "\n", "import matplotlib\n", "from matplotlib import pyplot as plt\n", "matplotlib.rcParams['text.usetex'] = True\n", "\n", "from model.library.line import Line\n", "\n", "from model.command.util import select\n", "\n", "from model.command.external import load_sdds\n", "from model.command.external import load_lattice\n", "\n", "from model.command.build import build\n", "\n", "from model.command.wrapper import group\n", "from model.command.wrapper import forward\n", "from model.command.wrapper import inverse\n", "from model.command.wrapper import normalize\n", "from model.command.wrapper import Wrapper\n", "\n", "from model.command.tune import tune\n", "from model.command.twiss import twiss\n", "from model.command.twiss import chromatic_twiss\n", "\n", "# Fix the random seed so that the quadrupole errors drawn later with torch.randn (and all printed results) are reproducible\n", "\n", "torch.manual_seed(0);" ] }, { "cell_type": "code", "execution_count": 3, "id": "aa613d87-34ee-4349-b29e-9ce7603b2e29", "metadata": {}, "outputs": [], "source": [ "# Load ELEGANT twiss\n", "\n", "path = Path('ic.twiss')\n", "parameters, columns = load_sdds(path)\n", "\n", "nu_qx:Tensor = torch.tensor(parameters['nux'] % 1, dtype=torch.float64)\n", "nu_qy:Tensor = torch.tensor(parameters['nuy'] % 1, 
dtype=torch.float64)\n", "\n", "# Set twiss parameters at BPMs\n", "\n", "kinds = select(columns, 'ElementType', keep=False)\n", "\n", "a_qx = select(columns, 'alphax', keep=False)\n", "b_qx = select(columns, 'betax' , keep=False)\n", "a_qy = select(columns, 'alphay', keep=False)\n", "b_qy = select(columns, 'betay' , keep=False)\n", "\n", "a_qx:Tensor = torch.tensor([value for (key, value), kind in zip(a_qx.items(), kinds.values()) if kind == 'MONI'], dtype=torch.float64)\n", "b_qx:Tensor = torch.tensor([value for (key, value), kind in zip(b_qx.items(), kinds.values()) if kind == 'MONI'], dtype=torch.float64)\n", "a_qy:Tensor = torch.tensor([value for (key, value), kind in zip(a_qy.items(), kinds.values()) if kind == 'MONI'], dtype=torch.float64)\n", "b_qy:Tensor = torch.tensor([value for (key, value), kind in zip(b_qy.items(), kinds.values()) if kind == 'MONI'], dtype=torch.float64)\n", "\n", "positions = select(columns, 's', keep=False).items()\n", "positions = [value for (key, value), kind in zip(positions, kinds.values()) if kind == 'MONI']" ] }, { "cell_type": "code", "execution_count": 4, "id": "c948e8c4-a0bb-48ac-a3fa-6677b0a5e02d", "metadata": {}, "outputs": [], "source": [ "# Build and set up lattice\n", "\n", "# Load ELEGANT table\n", "\n", "path = Path('ic.lte')\n", "data = load_lattice(path)\n", "\n", "# Build ELEGANT table\n", "\n", "ring:Line = build('RING', 'ELEGANT', data)\n", "ring.flatten()\n", "\n", "# Merge drifts\n", "\n", "ring.merge()\n", "\n", "# Split BPMs\n", "\n", "ring.split((None, ['BPM'], None, None))\n", "\n", "# Roll lattice start\n", "\n", "ring.roll(1)\n", "\n", "# Set linear dipoles\n", "\n", "for element in ring:\n", "    if element.__class__.__name__ == 'Dipole':\n", "        element.linear = True\n", "\n", "# Split lattice into lines by BPMs\n", "\n", "ring.splice()\n", "\n", "# Set number of elements of different kinds\n", "\n", "nb = ring.describe['BPM']\n", "nq = ring.describe['Quadrupole']\n", "ns = ring.describe['Sextupole']" ] }, { 
"cell_type": "code", "execution_count": 5, "id": "c01bdb6a-3831-4696-8781-adeacb8c418a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "# Compare tunes\n", "\n", "nuqx, nuqy = tune(ring, [], alignment=False, matched=True)\n", "\n", "print(torch.allclose(nu_qx, nuqx))\n", "print(torch.allclose(nu_qy, nuqy))" ] }, { "cell_type": "code", "execution_count": 6, "id": "ace743e7-4133-4b67-a98d-ebab26964da8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "True\n", "True\n" ] } ], "source": [ "# Compare twiss\n", "\n", "aqx, bqx, aqy, bqy = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "print(torch.allclose(a_qx, aqx))\n", "print(torch.allclose(b_qx, bqx))\n", "print(torch.allclose(a_qy, aqy))\n", "print(torch.allclose(b_qy, bqy))" ] }, { "cell_type": "code", "execution_count": 7, "id": "cd4930e5-e0c1-4346-a1c4-c0561aeaf9cf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.0893, 0.4014, 1.2554, -0.9068, -1.5491, -0.9866, -0.6147, -0.7071,\n", " -1.9186, 0.2045, -0.1659, 0.4221, 1.9239, 2.2147, 0.5854, -0.4487,\n", " -0.4684, -1.9595, -0.0112, -0.2204, -0.9209, -1.5967, -0.0541, 1.5081,\n", " 0.5988, -0.3222, -0.4638, 0.8415],\n", " [ 0.0182, 0.1496, 0.5733, -0.6730, -0.7756, -0.5178, -0.4104, -0.5325,\n", " -0.9651, 0.2699, -0.0145, 0.1734, 0.9035, 1.1954, 0.3374, -0.2984,\n", " -0.4055, -1.0062, 0.1660, -0.0147, -0.4518, -0.8235, 0.1829, 0.8281,\n", " 0.3562, -0.2319, -0.4848, 0.2358],\n", " [ 1.6183, -0.0219, -0.2993, -0.0049, 0.3526, 0.2005, -0.5750, -0.4693,\n", " 0.1832, 0.0065, -0.5948, -2.2667, -0.8472, 0.8119, 2.2819, 0.7206,\n", " 0.0759, -0.2127, 0.4684, 0.6358, -0.0792, -0.2893, -0.0394, 0.3584,\n", " 0.2001, -1.2685, -0.6598, 0.2486],\n", " [-0.7202, 0.1496, 0.2258, -0.0342, -0.1608, -0.0228, 0.3921, 0.2789,\n", " -0.1348, 0.0622, 
0.4565, 1.3632, 0.4758, -0.5180, -1.3618, -0.3315,\n", " 0.0206, 0.1024, -0.2828, -0.3332, 0.1447, 0.2237, -0.0135, -0.1681,\n", " 0.0184, 1.0088, 0.4647, -0.2653]], dtype=torch.float64)\n", "\n", "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.]], dtype=torch.float64)\n", "\n" ] } ], "source": [ "# Test derivatives with respect to kn and ks at the lattice start\n", "\n", "kn = torch.zeros(nq, dtype=torch.float64)\n", "ks = torch.zeros(nq, dtype=torch.float64)\n", "\n", "pprint(torch.func.jacrev(lambda kn: twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), matched=True))(kn))\n", "print()\n", "\n", "pprint(torch.func.jacrev(lambda ks: twiss(ring, [ks], ('ks', ['Quadrupole'], None, None), matched=True))(ks))\n", "print()\n", "\n", "# Note: first-order derivatives with respect to ks are identically equal to zero, as expected\n", "# Second-order derivatives are not identically equal to zero in general\n", "# In the following, only first-order derivatives are used for optics corrections (lattice without coupling)" ] }, { "cell_type": "code", "execution_count": 8, "id": "7afc08d7-2bc1-4dca-b0fe-08632d4167c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([16, 4, 28])\n", "torch.Size([16, 4, 28])\n" ] } ], "source": [ "# Compute twiss derivatives with respect to quadrupole settings (normal and chromatic)\n", "\n", "def fn_dtwiss_dkn(kn):\n", "    return twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", 
"\n", "def fn_dtwiss_dp_dkn(kn):\n", " return chromatic_twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "kn = torch.zeros(nq, dtype=torch.float64)\n", "\n", "dtwiss_dkn = torch.func.jacrev(fn_dtwiss_dkn)(kn)\n", "dtwiss_dp_dkn = torch.func.jacrev(fn_dtwiss_dp_dkn)(kn)\n", "\n", "print(dtwiss_dkn.shape)\n", "print(dtwiss_dp_dkn.shape)" ] }, { "cell_type": "code", "execution_count": 9, "id": "75be4cfc-010d-4770-a6cf-ce8d24324070", "metadata": {}, "outputs": [], "source": [ "# Set lattice with focusing errors (no coupling)\n", "\n", "error:Line = ring.clone()\n", "\n", "nq = error.describe['Quadrupole']\n", "\n", "error_kn = 0.1*torch.randn(nq, dtype=torch.float64)\n", "\n", "index = 0\n", "label = ''\n", "\n", "for line in error.sequence:\n", " for element in line:\n", " if element.__class__.__name__ == 'Quadrupole':\n", " if label != element.name:\n", " index +=1\n", " label = element.name\n", " element.kn = (element.kn + error_kn[index - 1]).item()" ] }, { "cell_type": "code", "execution_count": 10, "id": "cdfc0eaf-01ae-4461-9c32-c8e11babc763", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.3708, dtype=torch.float64)\n", "tensor(0.8085, dtype=torch.float64)\n", "tensor(0.5866, dtype=torch.float64)\n", "tensor(0.3774, dtype=torch.float64)\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute twiss and plot beta beating\n", "\n", "ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "# Compare twiss\n", "\n", "print((ax_model - ax_error).norm())\n", "print((bx_model - bx_error).norm())\n", "print((ay_model - ay_error).norm())\n", "print((by_model - by_error).norm())\n", "print()\n", "\n", "# Plot beta beating\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')\n", "plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "id": "77137cd1-1a6e-4334-ad4e-c5c8a0f39069", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.7376, dtype=torch.float64)\n", "tensor(0.2810, dtype=torch.float64)\n" ] } ], "source": [ "# Test Twiss response\n", "\n", "twiss_error = torch.stack([ax_error, bx_error, ay_error, by_error])\n", "twiss_model = torch.stack([ax_model, bx_model, ay_model, by_model])\n", "\n", "print((twiss_error - (twiss_model + 0.0*(dtwiss_dkn @ error_kn).T)).norm())\n", "print((twiss_error - (twiss_model + 1.0*(dtwiss_dkn @ error_kn).T)).norm())" ] }, { "cell_type": "code", "execution_count": 12, "id": "f56fc0fa-2227-4222-9386-10106f9a367b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.7376, 
dtype=torch.float64)\n", "tensor(1.5638, dtype=torch.float64)\n", "tensor(1.4067, dtype=torch.float64)\n", "tensor(1.2649, dtype=torch.float64)\n", "tensor(1.1371, dtype=torch.float64)\n", "tensor(1.0220, dtype=torch.float64)\n", "tensor(0.9184, dtype=torch.float64)\n", "tensor(0.8252, dtype=torch.float64)\n", "tensor(0.7415, dtype=torch.float64)\n", "tensor(0.6663, dtype=torch.float64)\n", "tensor(0.5989, dtype=torch.float64)\n", "tensor(0.5384, dtype=torch.float64)\n", "tensor(0.4842, dtype=torch.float64)\n", "tensor(0.4357, dtype=torch.float64)\n", "tensor(0.3922, dtype=torch.float64)\n", "tensor(0.3533, dtype=torch.float64)\n", "tensor(0.3185, dtype=torch.float64)\n", "tensor(0.2873, dtype=torch.float64)\n", "tensor(0.2595, dtype=torch.float64)\n", "tensor(0.2345, dtype=torch.float64)\n", "tensor(0.2122, dtype=torch.float64)\n", "tensor(0.1922, dtype=torch.float64)\n", "tensor(0.1743, dtype=torch.float64)\n", "tensor(0.1583, dtype=torch.float64)\n", "tensor(0.1440, dtype=torch.float64)\n", "tensor(0.1311, dtype=torch.float64)\n", "tensor(0.1195, dtype=torch.float64)\n", "tensor(0.1091, dtype=torch.float64)\n", "tensor(0.0998, dtype=torch.float64)\n", "tensor(0.0914, dtype=torch.float64)\n", "tensor(0.0838, dtype=torch.float64)\n", "tensor(0.0770, dtype=torch.float64)\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABjYAAAC+CAYAAACWEzYrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAARsElEQVR4nO3dQU4b2RYG4JMoEqOQag9bzaSygYg2C0BxdmB6B/DmDECMWj1CYdBzkhV02ztISV4AotQboHrAU4aOQ4+YNG8Q4RcS7DjYxr7F90lIqfLFOZZyc7F/zr2Prq6urgIAAAAAACABjxddAAAAAAAAwKQEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDKeLLqAafz777/x/v37ePr0aTx69GjR5QAAAAAAAHdwdXUV//zzT/z444/x+PH4noykg43379/H2traossAAAAAAABm4Pz8PH766aexY5IONp4+fRoRn17o6urqgqsBAAAAAADu4uLiItbW1oaf+4+TdLBxvf3U6uqqYAMAAAAAABI3ybETDg8HAAAAAACSIdgAAAAAAACSIdgAAAAAAACSIdgAAAAAAACSkfTh4QAAtbO5uegKZqPXW3QFAAAA1JRgAwAgAZt//b7oEm7Ve7G76BIAAAB4YGxFBQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJOPJogsAAAAAAICJbW7efvuv3++5kG/rvdgd8UDvfgupGR0bAAAAAABAMgQbAAAAAABAMgQbAAAAAABAMgQbAAAAAABAMuZ2eHhVVdHtdiPP86iqKnZ2diLLspHjy7KM7e3tOD09nVdJD8uIA3SS4xAdAAAAAAA+M7dgY2traxhSVFUV29vb0el0bh17HYCUZTmvcgAAAAAAgBqYS7BRVdWN6zzPoyiKkePb7fY8ygAAAAAAAGpmLmdsFEURjUbjxr1Go6EjAwAAAAAAmMpcOjYGg8Gt9/v9/lTPe3l5GZeXl8Pri4uLqZ4PAAAAAABIy1w6NkYZFXhM6vDwMJ49ezb8Wltbm01hAAAAAABAEuYSbGRZ9lV3Rr/fjyzLpnreg4OD+Pjx4/Dr/Px8qucDAAAAAADSMpdgo9Vq3Xq/2WxO9bwrKyuxurp64wsAAAAAAHg45nLGRp7nN66rqopmszns2CjLMrIs+2pcxKftqqbt7GC0zb9+X3QJX+m92F10CQAAAADMw+bmoiuYjV5v0RUAn5nbGRudTif29/ej2+3G8fFxdDqd4WOHh4fR7XaH10VRxP7+/q2PAQAAAAAAXJtLx0bEp66N169fR0REu92+8djnIUfEp62rWq3WcDwAAAAAUG92FgHuam4dGwAAAAAAALM2t44NAAAAAIBacFYILBUdGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDKeLLoAuIvNzUVX8LVeb9EVAAAAs+D9BgDAchNssNxGvaP46/f7rWMSm7u33/cOBAAAls+49ML7DQCApSbYAGAxlvFXIe/ChwkAAAAA98oZGwAAAAAAQDJ0bAAAAAAA3MHmMm5fGBG9FyO2MISa0LEBAAAAAAAkQ7ABAAAAAAAkw1ZUAEA91OVAegAAAGAswQZAjSzj57q93qIrgOXc99aetwAAAHA3gg3gYVnGT/6/l6QAAAAAgAfMGRsAAAAAAEAyBBsAAAAAAEAyBBsAAAAAAEAynLEBEA4WBgAAAIBU6NgAAAAAAACSoWMDIDWbm6MfW8LOk9j8vs6TZeyeidBBAwDA3Y37EX5Rer1FVwAAd6djAwAAAAAASIZgAwAAAAAASIatqGDelrHn+C70KQMAAAAAS0CwAQviHAEAAAAAgO9nKyoAAAAAACAZc+vYqKo
qut1u5HkeVVXFzs5OZFk29VgAABJgK0YAAADmZG7BxtbWVpyenkbEp+Bie3s7Op3O1GMBAEjXMm7FaBtGAGZiXKi/hOtfbI5Y/4T6ACRgLsFGVVU3rvM8j6Ioph4LAMzJiDfiPoQG4HssY7OWz2gBAOpnLmdsFEURjUbjxr1GoxFlWU41FgAAAAAAeNjm0rExGAxuvd/v96cae3l5GZeXl8Pri4uLO9UHAAAAAACkaW5nbNxmVIgx6djDw8P47bffZldQnY3ot17OLuzvr6oOr2M5X0PEyMoS2qYmYsxWNebGApgby2LsFk7mxgLUfG7MZPR9GVNVXeb4CHV4Hcv4GiLq8Tru8m9qKef4qO2xxuxRtZSvowY/i0TUfG74N7UQtf43NUZSr6Muc2Mmo+9T+nP8Ib+H5dvmshVVlmVfdVz0+/3IsmyqsQcHB/Hx48fh1/n5+SzLBgAAAAAAltxcgo1Wq3Xr/WazOdXYlZWVWF1dvfEFAAAAAAA8HHPZiirP8xvXVVVFs9kcdmGUZRlZlkWe598cCwAAwJIZs60IkLhR83vUtm6LNrLeZS0YgFmY2xkbnU4n9vf3Y2NjI05OTqLT6QwfOzw8jI2Njdjb2/vmWAAAAAAAgGtzCzbyPI/Xr19HRES73b7x2JfBxbixAAAAAAAA1+YWbAAAAABQD3agA2CZzOXwcAAAAAAAgHkQbAAAAAAAAMkQbAAAAAAAAMkQbAAAAAAAAMkQbAAAAAAAAMl4sugCAAAAgAeu17v9/ub9ljGRUbUCAPdGsAFMpfdid9ElAAAAAAAPiGADmIzfSgIAAAAAloAzNgAAAAAAgGQINgAAAAAAgGTYigoAAB4oZ2UBQJqs4cBDJ9gAgDnwRgNYGnU5J2tzc9EVAMD9q8s6DjBjtqICAAAAAACSoWMDAAAAakb3KABQZ4INAAAAYCnZhQcAuI2tqAAAAAAAgGQINgAAAAAAgGTYigoApmF/BAAAAIB7pWMDAAAAAABIho4NAAAASJXuUQDgAdKxAQAAAAAAJEPHBgAAkKzei91FlwAAANwzHRsAAAAAAEAydGwAAADLry7nCGxuLroCAABIno4NAAAAAAAgGTo2AAAAAHgQnM0EUA9z6dioqiqOjo6i2+3G0dFRDAaDsePLsoyff/55HqUAAAAAAAA1MpeOja2trTg9PY2ITyHH9vZ2dDqdW8d2u93I8zzKspxHKQAAAAA8NHU5mwmAW8082Kiq6sZ1nudRFMXI8e12e9YlAAAAAAAANTXzraiKoohGo3HjXqPR0JEBAAAAAABMbeYdG6PO0+j3+1M/9+XlZVxeXg6vLy4upn5OAAAAAAAgHXM5Y+M23zpAfBKHh4fx22+/TV8MAABMofdid9ElAAAAPFgTBxtv3ryJs7OzkY+/evUqWq1WZFn2VXdGv9+PLMvuXOS1g4OD2N39/5vIi4uLWFtbm/p5AQAAAACANEwcbOzs7Ew0rtVqxfHx8Vf3m83m5FWNsLKyEisrK1M/DwAATKTXW3QFAAAAfGHmh4fneX7juqqqaDabw46Nsiyjqqpbv3cW21UBAAAAAAD1NfNgIyKi0+nE/v5+dLvdOD4+jk6nM3zs8PAwut3u8Looitjf37/1MQAAAAAAgM/N5fDwPM/j9evXERHRbrdvPPZ5yBHxaeuqVqs1HA8AAAAAADDKXDo2AAAAAAAA5kGwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJOPJogsAAAAAAGBxei92F10CfBcdGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDIEGwAAAAAAQDKeLLoAAAAAAADuQa+36ApgJnRsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAA
AyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAyXgyjyetqiq63W7keR5VVcXOzk5kWXbr2LIsoyiKiIg4OTmJt2/fjhwLAAAAAAA8bHMJNra2tuL09DQiPoUc29vb0el0bh1bFEXs7e1FRMTR0VG8fPly+L0AAAAPQe/F7qJLAACAZDy6urq6muUTVlV1I9iIiPjhhx/iw4cPX40tyzJevnw5fKyqqnj+/HmcnZ1Fnuff/LsuLi7i2bNn8fHjx1hdXZ3diwAAAAAAAO7N93zeP/MzNoqiiEajceNeo9GIsiy/Gru+vh5v374dXg8Gg+F4AAAAAACAL818K6rrcOJL/X7/1vvtdnv45z/++CNardbIMzYuLy/j8vJyeH1xcXHnOgEAAAAAgPTMvGNjlFGBx+ePd7vdkWdxREQcHh7Gs2fPhl9ra2szrhIAAAAAAFhmE3dsvHnzJs7OzkY+/urVq2G3xZfdGf1+f2QXxrX9/f149+7d2HEHBwexu/v/Q/UuLi6EGwAAAAAA8IDc2+Hhf//998jQ4ujoKNrtduR5Puzs+FYQEuHwcAAAAAAAqIPv+bx/5mds5Hl+47qqqmg2m8OgoizLyLJsOK7b7cb6+vow1Pjzzz9jZ2dnor/rOpNx1gYAAAAAAKTr+nP+SXoxZt6xEfEpzDg+Po6NjY04OTmJg4ODYbCxtbUVGxsbsbe3F1VVxfPnz298b5Zl8eHDh4n+nv/+97+2ogIAAAAAgJo4Pz+Pn376aeyYuQQb9+Xff/+N9+/fx9OnT+PRo0eLLudBuT7f5Pz83DZgUDPmN9SbOQ71Zo5DfZnfUG/mOHzq1Pjnn3/ixx9/jMePH48dO/OtqO7T48ePv5ncMF+rq6v+s4WaMr+h3sxxqDdzHOrL/IZ6M8d56J49ezbRuPGxBwAAAAAAwBIRbAAAAAAAAMkQbHAnKysr8euvv8bKysqiSwFmzPyGejPHod7Mcagv8xvqzRyH75P04eEAAAAAAMDDomMDAAAAAABIhmADAAAAAABIxpNFF0BaqqqKbrcbeZ5HVVWxs7MTWZYtuixgRsqyjIiI9fX1qKoqBoNBrK+vL7gq4K7Ksozt7e04PT29cd96DvUwao5bzyF9ZVlGURQREXFychJv374drtXWcUjfuDluHYfJCDb4LltbW8M3TlVVxfb2dnQ6nQVXBczK8fFxvHnzJiIiWq2W+Q0Ju/7A4/qN0ees55C+cXPceg7pK4oi9vb2IiLi6OgoXr58OVy7reOQvnFz3DoOk3F4OBOrqurGD1ARET/88EN8+PBhgVUBs/TmzZv45ZdfIiL81hfUxKNHj+LzH/es51AvX87xCOs5pK4sy3j58uVwba6qKp4/fx5nZ2cREdZxSNy4OZ7nuXUcJuSMDSZWFEU0Go0b9xqNxq2/JQakK8syPzxBjVnP4WGwnkO61tfX4+3bt8PrwWAQEZ/Wa+s4pG/cHL9mHYdvsxUVE7v+j/ZL/X7/fgsB5mYwGES3242IT/t8/uc//4k8zxdcFTBL1nOoP+s5pK/dbg///Mcff0Sr1Yosy6zjUBOj5niEdRwmJdhgaqN+sALS8/nBg3mex6tXr4Yt70C9Wc+hPqznUB/XH3B+vvXUqHFAem6b49ZxmIytqJhYlmVf/RZIv9/XGgc1UlXV8M95nkdVVTfuAemznkP9Wc+hPvb39+Pdu3fDddo6DvXy5RyPsI7DpAQbTKzVat16v9ls3nMlwDxcH2D2pS/38AXSZj2HerOeQ30cHR3F/v5+5Hkeg8EgBoOBdRxq5LY5bh2HyQk2mNiX+/lVVRXNZtNvhkBN5Hker1+/Hl4XRRHtdtschxr4fHsK6znUz5dz3HoO6et2u7G+vj78wPPPP/+MLMus41AT4+a4dRwm8+jq6upq0UWQjqqq4vj4ODY2NuLk5CQODg785wo1UpZlFEURWZbF2dnZjR+ogLQURRHv3r2Lo6Oj2Nvbi42NjeEhhdZ
zSN+4OW49h7RVVRXPnz+/cS/Lsvjw4cPwces4pOtbc9w6DpMRbAAAAAAAAMmwFRUAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJAMwQYAAAAAAJCM/wFgtrrE6g7iFgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Perform correction (model to experiment)\n", "\n", "# Set response matrix\n", "\n", "matrix = dtwiss_dkn.reshape(-1, nq)\n", "\n", "# Set target twiss parameters\n", "\n", "twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "# Set learning rate\n", "\n", "lr = 0.1\n", "\n", "# Set initial values\n", "\n", "kn = torch.zeros_like(error_kn)\n", "\n", "# Fit\n", "\n", "for _ in range(32):\n", " twiss_model = twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", " dkn = - lr*torch.linalg.lstsq(matrix, (twiss_model - twiss_error).flatten(), driver='gelsd').solution\n", " kn += dkn\n", " print((twiss_model - twiss_error).norm())\n", "\n", "# Plot final quadrupole settings\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)\n", "plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "id": "f0832cee-54f1-4365-a0a5-d97f7a4b7efe", "metadata": {}, "outputs": [], "source": [ "# Apply corrections\n", "\n", "lattice:Line = error.clone()\n", "\n", "index = 0\n", "label = ''\n", "\n", "for line in lattice.sequence:\n", " for element in line:\n", " if element.__class__.__name__ == 'Quadrupole':\n", " if label != element.name:\n", " index +=1\n", " label = element.name\n", " element.kn = (element.kn - kn[index - 1]).item()" ] }, { "cell_type": "code", "execution_count": 14, "id": "ea12e689-2408-40c3-afe6-b59ddda48155", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute twiss and plot beta beating\n", "\n", "ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_final, bx_final, ay_final, by_final = twiss(lattice, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "# Plot beta beating\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_final)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='x')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_final)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='x')\n", "plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 15, "id": "c877b0c4-fe73-4e40-bff8-1f61f3f0e562", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(69.5123, dtype=torch.float64)\n", "tensor(5.5803, dtype=torch.float64)\n" ] } ], "source": [ "# Test Twiss response (chromatic)\n", "\n", "twiss_error = chromatic_twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "twiss_model = chromatic_twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "print((twiss_error - (twiss_model + 0.0*(dtwiss_dp_dkn @ 
error_kn))).norm())\n", "print((twiss_error - (twiss_model + 1.0*(dtwiss_dp_dkn @ error_kn))).norm())" ] }, { "cell_type": "code", "execution_count": 16, "id": "77135f2d-8032-482d-b562-281d0fcc83c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(69.5340, dtype=torch.float64)\n", "tensor(62.5709, dtype=torch.float64)\n", "tensor(56.2831, dtype=torch.float64)\n", "tensor(50.6092, dtype=torch.float64)\n", "tensor(45.4934, dtype=torch.float64)\n", "tensor(40.8845, dtype=torch.float64)\n", "tensor(36.7358, dtype=torch.float64)\n", "tensor(33.0042, dtype=torch.float64)\n", "tensor(29.6499, dtype=torch.float64)\n", "tensor(26.6367, dtype=torch.float64)\n", "tensor(23.9311, dtype=torch.float64)\n", "tensor(21.5026, dtype=torch.float64)\n", "tensor(19.3235, dtype=torch.float64)\n", "tensor(17.3686, dtype=torch.float64)\n", "tensor(15.6148, dtype=torch.float64)\n", "tensor(14.0417, dtype=torch.float64)\n", "tensor(12.6304, dtype=torch.float64)\n", "tensor(11.3641, dtype=torch.float64)\n", "tensor(10.2279, dtype=torch.float64)\n", "tensor(9.2081, dtype=torch.float64)\n", "tensor(8.2926, dtype=torch.float64)\n", "tensor(7.4706, dtype=torch.float64)\n", "tensor(6.7322, dtype=torch.float64)\n", "tensor(6.0687, dtype=torch.float64)\n", "tensor(5.4725, dtype=torch.float64)\n", "tensor(4.9364, dtype=torch.float64)\n", "tensor(4.4543, dtype=torch.float64)\n", "tensor(4.0206, dtype=torch.float64)\n", "tensor(3.6303, dtype=torch.float64)\n", "tensor(3.2790, dtype=torch.float64)\n", "tensor(2.9626, dtype=torch.float64)\n", "tensor(2.6776, dtype=torch.float64)\n", "tensor(2.4208, dtype=torch.float64)\n", "tensor(2.1893, dtype=torch.float64)\n", "tensor(1.9806, dtype=torch.float64)\n", "tensor(1.7924, dtype=torch.float64)\n", "tensor(1.6225, dtype=torch.float64)\n", "tensor(1.4692, dtype=torch.float64)\n", "tensor(1.3308, dtype=torch.float64)\n", "tensor(1.2059, dtype=torch.float64)\n", "tensor(1.0930, dtype=torch.float64)\n", "tensor(0.9909, 
dtype=torch.float64)\n", "tensor(0.8987, dtype=torch.float64)\n", "tensor(0.8153, dtype=torch.float64)\n", "tensor(0.7399, dtype=torch.float64)\n", "tensor(0.6717, dtype=torch.float64)\n", "tensor(0.6099, dtype=torch.float64)\n", "tensor(0.5540, dtype=torch.float64)\n", "tensor(0.5034, dtype=torch.float64)\n", "tensor(0.4575, dtype=torch.float64)\n", "tensor(0.4159, dtype=torch.float64)\n", "tensor(0.3783, dtype=torch.float64)\n", "tensor(0.3441, dtype=torch.float64)\n", "tensor(0.3131, dtype=torch.float64)\n", "tensor(0.2850, dtype=torch.float64)\n", "tensor(0.2595, dtype=torch.float64)\n", "tensor(0.2364, dtype=torch.float64)\n", "tensor(0.2154, dtype=torch.float64)\n", "tensor(0.1963, dtype=torch.float64)\n", "tensor(0.1789, dtype=torch.float64)\n", "tensor(0.1632, dtype=torch.float64)\n", "tensor(0.1489, dtype=torch.float64)\n", "tensor(0.1358, dtype=torch.float64)\n", "tensor(0.1240, dtype=torch.float64)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABjYAAAC+CAYAAACWEzYrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAARn0lEQVR4nO3dQU4b1x8H8F+iSKwSpl5WZTO5QETNAVCcG5jeAP57FiBWVVcoLLonOUFr3yAjcQDEqBfIdEGVpePQFZvyX0RYIbEdJ3iw3/D5SEjMzGPys8Lj2f76vffg6urqKgAAAAAAABLwcNEFAAAAAAAAzEqwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJEOwAQAAAAAAJOPRogu4jf/++y/evXsXjx8/jgcPHiy6HAAAAAAA4DtcXV3Fv//+Gz/++GM8fDh9TkbSwca7d+9ibW1t0WUAAAAAAABzcH5+Hj/99NPUNkkHG48fP46Ijw/0yZMnC64GAAAAAAD4HhcXF7G2tjZ633+apION6+Wnnjx5ItgAAAAAAIDEzbLthM3DAQAAAACAZAg2AAAAAACAZAg2AAAAAACAZAg2AAAAAACAZCS9eTgAQONsbo4//dfvd1zIbE6e7U64cHK3hQAAAHBvmLEBAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAk49GiCwAAAAAAgJltbo4//dfvd1zI1508251w4eRuC2kYMzYAAAAAAIBkCDYAAAAAAIBkCDYAAAAAAIBkCDYAAAAAAIBk1LZ5eFVV0e/3I8/zqKoqdnZ2Isuyie3Lsozt7e04Ozurq6T7pQkb6ETYRAcAAAAAgBtqCza2trZGIUVVVbG9vR29Xm9s2+sApCzLusoBAAAAAAAaoJZgo6qqG8d5nkdRFBPbd7vdOsoAAAAAAAAappY
9NoqiiFardeNcq9UyIwMAAAAAALiVWmZsDIfDsecHg8Gt7nt5eRmXl5ej44uLi1vdDwAAAAAASEstMzYmmRR4zOrw8DBWV1dHX2tra/MpDAAAAAAASEItwUaWZV/MzhgMBpFl2a3ue3BwEB8+fBh9nZ+f3+p+AAAAAABAWmpZiqrT6cTx8fEX59vt9q3uu7KyEisrK7e6BwAAAACweJubi67gSycni64AmEUtwUae5zeOq6qKdrs9mrFRlmVkWfZFu4iPy1XddmYHAAAAALAEpqUXf/1+d3XManN3/HmJByyV2vbY6PV6sb+/H/1+P46Pj6PX642uHR4eRr/fHx0XRRH7+/tjrwEAAAAAAFyrZcZGxMdZGy9fvoyIiG63e+PapyFHxMelqzqdzqg9AAAAAADAOLXN2AAAAAAAAJi32mZsAAAAAAA0woS9QjaXcZ+QiDh5Zq8Qms2MDQAAAAAAIBmCDQAAAAAAIBmCDQAAAAAAIBmCDQAAAAAAIBmCDQAAAAAAIBmCDQAAAAAAIBmPFl0ATLW5uegKbu/kZNEVAAAAn5vyWmPzr9/vsJDZnDzbnXDB6w0A4P4RbJCkpF5oAOM1IbiM8GYCAAAAwB0TbACwVJYxuIwQXgIAAAAsC3tsAAAAAAAAyRBsAAAAAAAAyRBsAAAAAAAAybDHBgDQDFM2pF/GvVvs2wIAAADfx4wNAAAAAAAgGWZsAPfLhE90J/Vp7pOTuy0EAAAAAJaIGRsAAAAAAEAyBBsAAAAAAEAyBBsAAAAAAEAyBBsAAAAAAEAyBBsAAAAAAEAyHi26AAC+0ebmoisAAAAAgIURbAA0yOZfvy+6hC+cPNtddAkAAAAANIilqAAAAAAAgGSYsQF1m7Bs0DJ+sj5iyqfrT07uthAAAAAAgDHM2AAAAAAAAJJR24yNqqqi3+9HnudRVVXs7OxElmW3bgsAAAAAANxftQUbW1tbcXZ2FhEfg4vt7e3o9Xq3bgsAQAISWopx4jKMEZZiBAAAWEK1BBtVVd04zvM8iqK4dVsAAACApTQh1E+OUB+ABNQSbBRFEa1W68a5VqsVZVnG+vr6d7cFAGrSlE/XAwAsIc+pAGC+atk8fDgcjj0/GAxu1RYAAAAAALjfattjY5xJIcasbS8vL+Py8nJ0fHFxMYeqAAAAAACAVNQSbGRZ9sWMi8FgEFmW3art4eFh/Pbbb/MstbkmrIm5nCtlfntVTXgcy/kYIiZWltAyNRFTplXrGwugbyyL79kgeTn/P/SN5dKExzGlqqb08Qma8DiW8TFENONxNP53aso+Akn9rUro71REw/uG36mFaPTv1BRJPY6m9I25tL5L6ffx+/walq+rZSmqTqcz9ny73b5V24ODg/jw4cPo6/z8/HaFAgAAAAAASallxkae5zeOq6qKdrs9moVRlmVkWRZ5nn+17adWVlZiZWWljpIBAAC4JZsRAwBwF2rbY6PX68X+/n5sbGzE6elp9Hq90bXDw8PY2NiIvb29r7YFAABgyUxZVgRIXFLLu0R863I7ADRDbcFGnufx8uXLiIjodrs3rn0eXExrCwAAAAAAcK2WPTYAAAAAAADqINgAAAAAAACSIdgAAAAAAACSIdgAAAAAAACSIdgAAAAAAACS8WjRBQAAAACMc/Jsd9ElAABLyIwNAAAAAAAgGWZsAAAAAIt1crLoCgCAhAg2gNlMeKGxvC8/lrcyAAAAAOD7WYoKAAAAAABIhmADAAAAAABIhmADAAAAAABIhj02AACgyaZtyLt5d2XMbFK9m8tYLAAAsAiCDQAAuKemZR4AwBKYGPjfbRkz8cQCuEOCDQAAAABIiAwBuO/ssQEAAAAAACRDsAEAAAAAACRDsAEAAAAAACTDHhsAUIOTZ7uLLgEAAACgkQQbAHAbdu0DAAAAuFOCDQAAAEjVpA9ZbN5tGTPzoRAAYA4EGwAAANAw8gMAoMlsHg4AAAAAACRDsAEAAAAAACTDUlQAAECyTp7
tLroEAADgjgk2AACA5deUDQM2l3VHZwAASIelqAAAAAAAgGQINgAAAAAAgGTUEmxUVRVHR0fR7/fj6OgohsPh1PZlWcbPP/9cRykAAAAAAECD1LLHxtbWVpydnUXEx5Bje3s7er3e2Lb9fj/yPI+yLOsoBQAAAID7ZsLeTMu5Y9NyVgWwzOYebFRVdeM4z/MoimJi+263O+8SAAAAAACAhpr7UlRFUUSr1bpxrtVqmZEBAAAAAADc2txnbEzaT2MwGNz63peXl3F5eTk6vri4uPU9AQAAAACAdNSyefg4X9tAfBaHh4exuro6+lpbW7t9YQAAAAAAQDJmnrHx6tWrePv27cTrL168iE6nE1mWfTE7YzAYRJZl313ktYODg9jd3R0dX1xcCDcAAKiPjUcBAACWzszBxs7OzkztOp1OHB8ff3G+3W7PXtUEKysrsbKycuv7AAAAAAAAaZr7UlR5nt84rqoq2u32aMZGWZZRVdXYn53HclUAAAAAAEBz1bLHRq/Xi/39/ej3+3F8fBy9Xm907fDwMPr9/ui4KIrY398few0AAAAAAOBTMy9F9S3yPI+XL19GRES3271x7dOQI+Lj0lWdTmfUHgAAAAAAYJJaZmwAAAAAAADUQbABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAk49GiCwAAAAAA4A6cnIw/fcdlzGY5q2I5mLEBAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAkQ7ABAAAAAAAk41EdN62qKvr9fuR5HlVVxc7OTmRZNrZtWZZRFEVERJyensbr168ntgUAAAAAAO63WoKNra2tODs7i4iPIcf29nb0er2xbYuiiL29vYiIODo6iufPn49+FgAAAAAA4FNzX4qqqqobx3mej2ZkfK4syzg8PBwdd7vdKMvyi3sAAAAAAABE1DBjoyiKaLVaN861Wq0oyzLW19dvnF9fX4/Xr1+PjofD4ag9AABA45ycjD99x2XMZjmrAgCAuQcb1+HE5waDwdjz3W539P0ff/wRnU5n4h4bl5eXcXl5OTq+uLj47joBAAAAAID0zH0pqkkmBR6fXu/3+xP34oiIODw8jNXV1dHX2tranKsEAAAAAACW2cwzNl69ehVv376deP3Fixej2Rafz84YDAYTZ2Fc29/fjzdv3kxtd3BwELu7u6Pji4sL4QYAAAAAANwjD66urq7mecOqqmJrayvOzs5G53744Yf4+++/J4YWR0dH0e12I8/z0cyOrwUhER+DjdXV1fjw4UM8efJkDtUDAAAAAAB37Vve75/7UlR5nt84rqoq2u32KKgoyzKqqhpd7/f7sb6+Pgo1/vzzz5lCDQAAAAAA4P6Z+4yNiI9hxvHxcWxsbMTp6WkcHByMwoqtra3Y2NiIvb29qKoqnj59euNnsyyL9+/fz/TvfPjwIbIsi/PzczM2AAAAAAAgUddbTwyHw1hdXZ3atpZg4678888/9tgAAAAAAICGOD8/j59++mlqm6SDjf/++y/evXsXjx8/jgcPHiy6nHvlOj0zWwaaR/+GZtPHodn0cWgu/RuaTR+HiKurq/j333/jxx9/jIcPp++i8eiOaqrFw4cPv5rcUK8nT574YwsNpX9Ds+nj0Gz6ODSX/g3Npo9z331tCaprc988HAAAAAAAoC6CDQAAAAAAIBmCDb7LyspK/Prrr7GysrLoUoA507+h2fRxaDZ9HJpL/4Zm08fh2yS9eTgAAAAAAHC/mLEBAAAAAAAkQ7ABAAAAAAAk49GiCyAtVVVFv9+PPM+jqqr
Y2dmJLMsWXRYwJ2VZRkTE+vp6VFUVw+Ew1tfXF1wV8L3Ksozt7e04Ozu7cd54Ds0wqY8bzyF9ZVlGURQREXF6ehqvX78ejdXGcUjftD5uHIfZCDb4JltbW6MXTlVVxfb2dvR6vQVXBczL8fFxvHr1KiIiOp2O/g0Ju37D4/qF0aeM55C+aX3ceA7pK4oi9vb2IiLi6Ogonj9/Phq7jeOQvml93DgOs7F5ODOrqurGE6iIiB9++CHev3+/wKqAeXr16lX88ssvERE+9QUN8eDBg/j06Z7xHJrl8z4eYTyH1JVlGc+fPx+NzVVVxdOnT+Pt27cREcZxSNy0Pp7nuXEcZmSPDWZWFEW0Wq0b51qt1thPiQHpyrLMkydoMOM53A/Gc0jX+vp6vH79enQ8HA4j4uN4bRyH9E3r49eM4/B1lqJiZtd/aD83GAzuthCgNsPhMPr9fkR8XOfzf//7X+R5vuCqgHkynkPzGc8hfd1ud/T9H3/8EZ1OJ7IsM45DQ0zq4xHGcZiVYINbm/TECkjPpxsP5nkeL168GE15B5rNeA7NYTyH5rh+g/PTpacmtQPSM66PG8dhNpaiYmZZln3xKZDBYGBqHDRIVVWj7/M8j6qqbpwD0mc8h+YznkNz7O/vx5s3b0bjtHEcmuXzPh5hHIdZCTaYWafTGXu+3W7fcSVAHa43MPvc52v4AmkznkOzGc+hOY6OjmJ/fz/yPI/hcBjD4dA4Dg0yro8bx2F2gg1m9vl6flVVRbvd9skQaIg8z+Ply5ej46Iootvt6uPQAJ8uT2E8h+b5vI8bzyF9/X4/1tfXR294/vnnn5FlmXEcGmJaHzeOw2weXF1dXS26CNJRVVUcHx/HxsZGnJ6exsHBgT+u0CBlWUZRFJFlWbx9+/bGEyogLUVRxJs3b+Lo6Cj29vZiY2NjtEmh8RzSN62PG88hbVVVxdOnT2+cy7Is3r9/P7puHId0fa2PG8dhNoINAAAAAAAgGZaiAgAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkiHYAAAAAAAAkvF/o/iuQfqV11cAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Perform correction (model to experiment) including chromatic twiss\n", "\n", "# Set response matrix\n", "\n", "matrix = torch.vstack([dtwiss_dkn.reshape(-1, nq), dtwiss_dp_dkn.reshape(-1, nq)])\n", "\n", "# Set target twiss parameters\n", "\n", "twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "chromatic_twiss_error = chromatic_twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "# Set learning rate\n", "\n", "lr = 0.1\n", "\n", "# Set initial values\n", "\n", "kn = torch.zeros_like(error_kn)\n", "\n", "# Fit\n", "\n", "for _ in range(64):\n", " twiss_model = twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", " chromatic_twiss_model = chromatic_twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", " dkn = - lr*torch.linalg.lstsq(matrix, torch.stack([twiss_model - twiss_error, chromatic_twiss_model - chromatic_twiss_error]).flatten(), driver='gelsd').solution\n", " kn += dkn\n", " print(torch.stack([twiss_model - twiss_error, chromatic_twiss_model - chromatic_twiss_error]).norm())\n", "\n", "# Plot final quadrupole settings\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.bar(range(len(error_kn)), error_kn.cpu().numpy(), color='red', alpha=0.75, width=1)\n", "plt.bar(range(len(kn)), +kn.cpu().numpy(), color='blue', alpha=0.75, width=0.75)\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 17, "id": "82db6fb5-20d7-4f69-a934-6e4d66e55172", "metadata": {}, "outputs": [], "source": [ "# Apply corrections\n", "\n", "lattice:Line = error.clone()\n", "\n", "index = 0\n", "label = ''\n", "\n", "for line in lattice.sequence:\n", " for element in line:\n", " if element.__class__.__name__ == 'Quadrupole':\n", " if 
label != element.name:\n", " index +=1\n", " label = element.name\n", " element.kn = (element.kn - kn[index - 1]).item()" ] }, { "cell_type": "code", "execution_count": 18, "id": "c498df81-f365-4830-9620-1cdb2bc89b75", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute twiss and plot beta beating\n", "\n", "ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_final, bx_final, ay_final, by_final = twiss(lattice, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "# Plot beta beating\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_final)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='x')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_final)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='x')\n", "plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 19, "id": "7327d250-aac9-4872-8645-3b094cda26a5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.7376, dtype=torch.float64)\n", "tensor(0., dtype=torch.float64)\n", "\n", "tensor(1.7376, dtype=torch.float64)\n", "tensor(4.8796e-13, dtype=torch.float64)\n", "\n", "tensor(1.7376, dtype=torch.float64)\n", "tensor(1.5623, dtype=torch.float64)\n", "tensor(1.4006, dtype=torch.float64)\n", "tensor(1.2536, dtype=torch.float64)\n", "tensor(1.1222, dtype=torch.float64)\n", "tensor(1.0066, dtype=torch.float64)\n", "tensor(0.9054, 
dtype=torch.float64)\n", "tensor(0.8187, dtype=torch.float64)\n", "tensor(0.7487, dtype=torch.float64)\n", "tensor(0.6968, dtype=torch.float64)\n", "tensor(0.6613, dtype=torch.float64)\n", "tensor(0.6364, dtype=torch.float64)\n", "tensor(0.6152, dtype=torch.float64)\n", "tensor(0.5918, dtype=torch.float64)\n", "tensor(0.5634, dtype=torch.float64)\n", "tensor(0.5298, dtype=torch.float64)\n", "tensor(0.4917, dtype=torch.float64)\n", "tensor(0.4498, dtype=torch.float64)\n", "tensor(0.4053, dtype=torch.float64)\n", "tensor(0.3615, dtype=torch.float64)\n", "tensor(0.3244, dtype=torch.float64)\n", "tensor(0.2996, dtype=torch.float64)\n", "tensor(0.2876, dtype=torch.float64)\n", "tensor(0.2831, dtype=torch.float64)\n", "tensor(0.2812, dtype=torch.float64)\n", "tensor(0.2804, dtype=torch.float64)\n", "tensor(0.2789, dtype=torch.float64)\n", "tensor(0.2738, dtype=torch.float64)\n", "tensor(0.2647, dtype=torch.float64)\n", "tensor(0.2542, dtype=torch.float64)\n", "tensor(0.2450, dtype=torch.float64)\n", "tensor(0.2369, dtype=torch.float64)\n", "tensor(0.2295, dtype=torch.float64)\n", "tensor(0.2242, dtype=torch.float64)\n", "tensor(0.2220, dtype=torch.float64)\n", "tensor(0.2203, dtype=torch.float64)\n", "tensor(0.2160, dtype=torch.float64)\n", "tensor(0.2092, dtype=torch.float64)\n", "tensor(0.2011, dtype=torch.float64)\n", "tensor(0.1925, dtype=torch.float64)\n", "tensor(0.1854, dtype=torch.float64)\n", "tensor(0.1817, dtype=torch.float64)\n", "tensor(0.1799, dtype=torch.float64)\n", "tensor(0.1769, dtype=torch.float64)\n", "tensor(0.1718, dtype=torch.float64)\n", "tensor(0.1649, dtype=torch.float64)\n", "tensor(0.1571, dtype=torch.float64)\n", "tensor(0.1504, dtype=torch.float64)\n", "tensor(0.1459, dtype=torch.float64)\n", "tensor(0.1423, dtype=torch.float64)\n", "tensor(0.1390, dtype=torch.float64)\n", "tensor(0.1354, dtype=torch.float64)\n", "tensor(0.1310, dtype=torch.float64)\n", "tensor(0.1267, dtype=torch.float64)\n", "tensor(0.1231, dtype=torch.float64)\n", 
"tensor(0.1199, dtype=torch.float64)\n", "tensor(0.1169, dtype=torch.float64)\n", "tensor(0.1127, dtype=torch.float64)\n", "tensor(0.1073, dtype=torch.float64)\n", "tensor(0.1017, dtype=torch.float64)\n", "tensor(0.0971, dtype=torch.float64)\n", "tensor(0.0936, dtype=torch.float64)\n", "tensor(0.0902, dtype=torch.float64)\n", "tensor(0.0865, dtype=torch.float64)\n" ] } ], "source": [ "# ML style correction (model to experiment)\n", "\n", "# Set target twiss parameters\n", "\n", "twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "# Set learning rate\n", "\n", "lr = 0.005\n", "\n", "# Set parametric twiss\n", "\n", "def twiss_model(kn):\n", " return twiss(ring, [kn], ('kn', ['Quadrupole'], None, None), alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "# Set objective function\n", "\n", "def objective(kn):\n", " return (twiss_error - twiss_model(kn)).norm()\n", "\n", "# Set initial values\n", "\n", "kn = torch.zeros_like(error_kn)\n", "\n", "# Test objective function\n", "\n", "print(objective(0.0*error_kn))\n", "print(objective(1.0*error_kn))\n", "print()\n", "\n", "# Set normalized objective\n", "\n", "objective = normalize(objective, [(-0.5, 0.5)])\n", "\n", "# Test normalized objective\n", "\n", "print(objective(*forward([0.0*error_kn], [(-0.5, 0.5)])))\n", "print(objective(*forward([1.0*error_kn], [(-0.5, 0.5)])))\n", "print()\n", "\n", "# Normalize initial settings\n", "\n", "kn, *_ = forward([kn], [(-0.5, 0.5)])\n", "\n", "# Set model (forward returns evaluated objective)\n", "\n", "model = Wrapper(objective, kn)\n", "\n", "# Set optimizer\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "\n", "# Perform optimization\n", "\n", "for epoch in range(64):\n", " value = model()\n", " value.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " print(value.detach())" ] }, { "cell_type": "code", "execution_count": 20, "id": 
"b9c45388-8774-4d59-bb8d-2695dc1d1afb", "metadata": {}, "outputs": [], "source": [ "# Apply corrections\n", "\n", "kn, *_ = inverse([kn], [(-0.5, 0.5)])\n", "\n", "lattice:Line = error.clone()\n", "\n", "index = 0\n", "label = ''\n", "\n", "for line in lattice.sequence:\n", " for element in line:\n", " if element.__class__.__name__ == 'Quadrupole':\n", " if label != element.name:\n", " index +=1\n", " label = element.name\n", " element.kn = (element.kn - kn[index - 1]).item()" ] }, { "cell_type": "code", "execution_count": 21, "id": "ebefada3-eea3-4cdb-8278-94b6a10339d3", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute twiss and plot beta beating\n", "\n", "ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_final, bx_final, ay_final, by_final = twiss(lattice, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "# Plot beta beating\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_final)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='x')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_final)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='x')\n", "plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "id": "ff578ab4-cbbc-4e66-b624-6f58b906855d", "metadata": {}, "outputs": [], "source": [ "# AdEMAMix optimizer\n", "\n", "# https://arxiv.org/abs/2409.03137\n", "# https://github.com/apple/ml-ademamix\n", "\n", "import math\n", "import torch\n", "from torch.optim import Optimizer\n", "\n", "\n", "def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):\n", " if step < warmup:\n", " a = step / float(warmup)\n", " return (1.0-a) * alpha_start + a * alpha_end\n", " return alpha_end\n", "\n", "\n", "def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, 
warmup=1):\n", " def f(beta, eps=1e-8):\n", " return math.log(0.5)/math.log(beta+eps)-1\n", " def f_inv(t):\n", " return math.pow(0.5, 1/(t+1))\n", " if step < warmup:\n", " a = step / float(warmup)\n", " return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))\n", " return beta_end\n", "\n", "\n", "class AdEMAMix(Optimizer):\n", " \"\"\"Implements the AdEMAMix algorithm.\n", "\n", " Arguments:\n", " params (iterable): iterable of parameters to optimize or dicts defining\n", " parameter groups\n", " lr (float, optional): learning rate (default: 1e-3)\n", " betas (Tuple[float, float, float], optional): coefficients used for computing\n", " running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) \n", " corresponding to beta_1, beta_2, beta_3 in AdEMAMix\n", " alpha (float): AdEMAMix alpha coefficient mixing the slow and fast EMAs (default: 2)\n", " beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)\n", " alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None)\n", " eps (float, optional): term added to the denominator to improve\n", " numerical stability (default: 1e-8)\n", " weight_decay (float, optional): weight decay as in AdamW (default: 0)\n", " \"\"\"\n", " def __init__(self, \n", " params, \n", " lr=1e-3, \n", " betas=(0.9, 0.999, 0.9999), \n", " alpha=2.0, \n", " beta3_warmup=None, \n", " alpha_warmup=None, \n", " eps=1e-8,\n", " weight_decay=0):\n", " \n", " defaults = dict(lr=lr, \n", " betas=betas, \n", " eps=eps, \n", " alpha=alpha, \n", " beta3_warmup=beta3_warmup,\n", " alpha_warmup=alpha_warmup, \n", " weight_decay=weight_decay)\n", " \n", " super().__init__(params, defaults)\n", "\n", " def __setstate__(self, state):\n", " super().__setstate__(state)\n", "\n", " @torch.no_grad()\n", " def step(self, closure=None):\n", " \"\"\"Performs a single optimization step.\n", "\n", " Arguments:\n", " closure (callable, optional): A closure that reevaluates the 
model\n", " and returns the loss.\n", " \"\"\"\n", " loss = None\n", " if closure is not None:\n", " with torch.enable_grad():\n", " loss = closure()\n", "\n", " for group in self.param_groups:\n", " \n", " lr = group[\"lr\"]\n", " lmbda = group[\"weight_decay\"]\n", " eps = group[\"eps\"]\n", " beta1, beta2, beta3_final = group[\"betas\"]\n", " beta3_warmup = group[\"beta3_warmup\"]\n", " alpha_final = group[\"alpha\"]\n", " alpha_warmup = group[\"alpha_warmup\"]\n", " \n", " for p in group['params']:\n", " if p.grad is None:\n", " continue\n", " grad = p.grad\n", " if grad.is_sparse:\n", " raise RuntimeError('AdEMAMix does not support sparse gradients.')\n", "\n", " state = self.state[p]\n", "\n", "\n", " if len(state) == 0:\n", " state['step'] = 0\n", " if beta1 != 0.0:\n", " state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n", " else: \n", " state['exp_avg_fast'] = None\n", " state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n", " state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n", "\n", " exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq']\n", "\n", " state['step'] += 1\n", " bias_correction1 = 1 - beta1 ** state['step']\n", " bias_correction2 = 1 - beta2 ** state['step']\n", "\n", " if alpha_warmup is not None:\n", " alpha = linear_warmup_scheduler(state[\"step\"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup)\n", " else:\n", " alpha = alpha_final\n", " \n", " if beta3_warmup is not None:\n", " beta3 = linear_hl_warmup_scheduler(state[\"step\"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup)\n", " else:\n", " beta3 = beta3_final\n", "\n", " if beta1 != 0.0:\n", " exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1)\n", " else:\n", " exp_avg_fast = grad\n", " exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3)\n", " exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)\n", "\n", 
" denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)\n", "\n", " update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom\n", "\n", " update.add_(p, alpha=lmbda)\n", "\n", " p.add_(-lr * update)\n", "\n", " return loss" ] }, { "cell_type": "code", "execution_count": 23, "id": "e866fea2-f0c2-4068-8527-c4ace2afa8bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.0472, dtype=torch.float64)\n", "tensor(0.0381, dtype=torch.float64)\n", "tensor(0.0307, dtype=torch.float64)\n", "tensor(0.0247, dtype=torch.float64)\n", "tensor(0.0199, dtype=torch.float64)\n", "tensor(0.0161, dtype=torch.float64)\n", "tensor(0.0133, dtype=torch.float64)\n", "tensor(0.0111, dtype=torch.float64)\n", "tensor(0.0094, dtype=torch.float64)\n", "tensor(0.0082, dtype=torch.float64)\n", "tensor(0.0075, dtype=torch.float64)\n", "tensor(0.0069, dtype=torch.float64)\n", "tensor(0.0066, dtype=torch.float64)\n", "tensor(0.0063, dtype=torch.float64)\n", "tensor(0.0060, dtype=torch.float64)\n", "tensor(0.0057, dtype=torch.float64)\n", "tensor(0.0053, dtype=torch.float64)\n", "tensor(0.0049, dtype=torch.float64)\n", "tensor(0.0045, dtype=torch.float64)\n", "tensor(0.0040, dtype=torch.float64)\n", "tensor(0.0035, dtype=torch.float64)\n", "tensor(0.0031, dtype=torch.float64)\n", "tensor(0.0027, dtype=torch.float64)\n", "tensor(0.0023, dtype=torch.float64)\n", "tensor(0.0020, dtype=torch.float64)\n", "tensor(0.0018, dtype=torch.float64)\n", "tensor(0.0016, dtype=torch.float64)\n", "tensor(0.0015, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0014, dtype=torch.float64)\n", "tensor(0.0013, dtype=torch.float64)\n", "tensor(0.0013, dtype=torch.float64)\n", "tensor(0.0012, dtype=torch.float64)\n", 
"tensor(0.0012, dtype=torch.float64)\n", "tensor(0.0011, dtype=torch.float64)\n", "tensor(0.0010, dtype=torch.float64)\n", "tensor(0.0009, dtype=torch.float64)\n", "tensor(0.0009, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0008, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0007, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n", "tensor(0.0006, dtype=torch.float64)\n" ] } ], "source": [ "# ML style correction (batched)\n", "\n", "# Set target twiss parameters\n", "\n", "twiss_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True)\n", "\n", "# Set learning rate\n", "\n", "lr = 0.005\n", "\n", "# Define rings (Twiss parameters will be computed at each ring start)\n", "\n", "rings:list[Line] = []\n", "\n", "for i, _ in enumerate(ring):\n", " line = ring.clone()\n", " line.roll(i)\n", " rings.append(line)\n", "\n", "# Set batched function\n", "\n", "_, ((_, names, _), *_), _ = group(ring, 0, len(ring) - 1, ('kn', ['Quadrupole'], None, None))\n", "\n", "def task(Is, kn):\n", " result = []\n", " for I in Is:\n", " result.append(twiss(rings[I], [kn], ('kn', None, names, None), alignment=False, matched=True, convert=True))\n", " return torch.stack(result)\n", "\n", "# Set initial values\n", "\n", "kn = torch.zeros_like(error_kn)\n", "\n", "# Normalize 
objective\n", "\n", "task = normalize(task, [(None, None), (-0.5, 0.5)])\n", "\n", "# Normalize initial settings\n", "\n", "kn, *_ = forward([kn], [(-0.5, 0.5)])\n", "\n", "# Set model\n", "\n", "model = Wrapper(task, kn)\n", "\n", "# Set optimizer\n", "\n", "optimizer = AdEMAMix(model.parameters(), lr=lr)\n", "\n", "# Set features and labels \n", "\n", "X = torch.arange(len(ring))\n", "y = twiss_error.clone()\n", "\n", "# Set dataset\n", "# Note, full set is used here, batch size is too small otherwise\n", "\n", "batch_size = 16\n", "dataset = TensorDataset(X.clone(), y.clone())\n", "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "\n", "# Set loss funtion\n", "\n", "lf = torch.nn.MSELoss()\n", "\n", "# Perfom optimization\n", "\n", "for epoch in range(64):\n", " for batch, (X, y) in enumerate(dataloader):\n", " y_hat = model(X)\n", " value = lf(y_hat, y)\n", " value.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " with torch.no_grad():\n", " print(value.detach())" ] }, { "cell_type": "code", "execution_count": 24, "id": "1dcfc6d7-13fa-43fd-b632-ce7f26e4d287", "metadata": {}, "outputs": [], "source": [ "# Apply corrections\n", "\n", "kn, *_ = inverse([kn], [(-0.5, 0.5)])\n", "\n", "lattice:Line = error.clone()\n", "\n", "index = 0\n", "label = ''\n", "\n", "for line in lattice.sequence:\n", " for element in line:\n", " if element.__class__.__name__ == 'Quadrupole':\n", " if label != element.name:\n", " index +=1\n", " label = element.name\n", " element.kn = (element.kn - kn[index - 1]).item()" ] }, { "cell_type": "code", "execution_count": 25, "id": "61c51458-3e83-42d7-95fb-c5b8ae3ad7f3", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute twiss and plot beta beating\n", "\n", "ax_model, bx_model, ay_model, by_model = twiss(ring, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_error, bx_error, ay_error, by_error = twiss(error, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "ax_final, bx_final, ay_final, by_final = twiss(lattice, [], alignment=False, matched=True, advance=True, full=False, convert=True).T\n", "\n", "# Plot beta beating\n", "\n", "plt.figure(figsize=(16, 2))\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_error)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_error)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='o')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((bx_model - bx_final)/bx_model).cpu().numpy(), color='red', alpha=0.75, marker='x')\n", "plt.plot(ring.locations().cpu().numpy(), 100*((by_model - by_final)/by_model).cpu().numpy(), color='blue', alpha=0.75, marker='x')\n", "plt.xticks(ticks=positions, labels=['BPM05', 'BPM07', 'BPM08', 'BPM09', 'BPM10', 'BPM11', 'BPM12', 'BPM13', 'BPM14', 'BPM15', 'BPM16', 'BPM17', 'BPM01', 'BPM02', 'BPM03', 'BPM04'])\n", "plt.tight_layout()\n", "plt.show()" ] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, 
"eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }