{ "cells": [
 { "attachments": {}, "cell_type": "markdown", "id": "556562f3-8ece-4517-8c93-ee5e2fc29131", "metadata": {}, "source": [
   "# Example-05: Demo"
 ] },
 { "cell_type": "code", "execution_count": 1, "id": "e06a70db-cb85-4323-ab8a-3ad7a2e3cdae", "metadata": {}, "outputs": [], "source": [
   "# In this demo, the construction of symplectic integrators is illustrated for basic accelerator elements"
 ] },
 { "cell_type": "code", "execution_count": 2, "id": "952c915e-fd30-4aaf-9138-3b5cead35b95", "metadata": {}, "outputs": [], "source": [
   "# Import\n",
   "\n",
   "import torch\n",
   "import jax\n",
   "\n",
   "# Exact solutions\n",
   "\n",
   "from model.library.transformations import drift\n",
   "from model.library.transformations import quadrupole\n",
   "from model.library.transformations import bend\n",
   "\n",
   "# Function iterations\n",
   "\n",
   "from sympint import nest\n",
   "from sympint import fold\n",
   "\n",
   "# Integrators and composer\n",
   "\n",
   "from sympint import sequence\n",
   "from sympint import midpoint\n",
   "from sympint import tao"
 ] },
 { "cell_type": "code", "execution_count": 3, "id": "ee4e1220-79e1-4fdd-96f9-bd1130a9fdf1", "metadata": {}, "outputs": [], "source": [
   "# Set data type\n",
   "\n",
   "jax.config.update(\"jax_enable_x64\", True)"
 ] },
 { "cell_type": "code", "execution_count": 4, "id": "ef64756e-46fd-418d-80c9-0ffa62ede28b", "metadata": {}, "outputs": [], "source": [
   "# Set device\n",
   "\n",
   "device, *_ = jax.devices('cpu')\n",
   "jax.config.update('jax_default_device', device)"
 ] },
 { "cell_type": "code", "execution_count": 5, "id": "cc438bb8-7605-448a-a0ba-e26271991432", "metadata": {}, "outputs": [], "source": [
   "# Define Hamiltonian functions for accelerator elements\n",
   "# Each function unpacks qs as (qx, qy) and ps as (px, py); t is accepted but not used\n",
   "\n",
   "def h_drif(qs, ps, t, dp, *args):\n",
   "    \"\"\"Drift Hamiltonian: 1/2*(px**2 + py**2)/(1 + dp); qs does not enter.\"\"\"\n",
   "    px, py = ps\n",
   "    return 1/2*(px**2 + py**2)/(1 + dp)\n",
   "\n",
   "def h_quad(qs, ps, t, kn, ks, dp, *args):\n",
   "    \"\"\"Quadrupole Hamiltonian: drift term + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy.\"\"\"\n",
   "    qx, qy = qs\n",
   "    px, py = ps\n",
   "    return 1/2*(px**2 + py**2)/(1 + dp) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy\n",
   "\n",
   "def h_bend(qs, ps, t, rd, kn, ks, dp, *args):\n",
   "    \"\"\"Bend Hamiltonian: quadrupole terms - qx*dp/rd + qx**2/(2*rd**2).\"\"\"\n",
   "    qx, qy = qs\n",
   "    px, py = ps\n",
   "    return 1/2*(px**2 + py**2)/(1 + dp) - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy"
 ] },
 { "cell_type": "code", "execution_count": 6, "id": "ac1e4ab4-f9ed-4a62-a48e-5da82fca0b73", "metadata": {}, "outputs": [], "source": [
   "# Set parameters\n",
   "\n",
   "ti = torch.tensor(0.0, dtype=torch.float64)\n",
   "dt = torch.tensor(0.1, dtype=torch.float64)\n",
   "rd = torch.tensor(25.0, dtype=torch.float64)\n",
   "kn = torch.tensor(2.0, dtype=torch.float64)\n",
   "ks = torch.tensor(0.1, dtype=torch.float64)\n",
   "dp = torch.tensor(0.001, dtype=torch.float64)"
 ] },
 { "cell_type": "code", "execution_count": 7, "id": "28d483c8-b97f-458b-aff7-7e01500cf600", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [
   "# Hamiltonian conservation (drif)\n",
   "# The Hamiltonian is evaluated before and after the exact map and the two values are compared\n",
   "\n",
   "(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hi = h_drif(qs, ps, ti, dp)\n",
   "\n",
   "(qx, px, qy, py) = drift(x, dp, dt)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hf = h_drif(qs, ps, ti, dp)\n",
   "\n",
   "print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))"
 ] },
 { "cell_type": "code", "execution_count": 8, "id": "a9d9f83a-815e-4962-86d7-260edf692851", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [
   "# Hamiltonian conservation (quad)\n",
   "# The Hamiltonian is evaluated before and after the exact map and the two values are compared\n",
   "\n",
   "(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hi = h_quad(qs, ps, ti, kn, ks, dp)\n",
   "\n",
   "(qx, px, qy, py) = quadrupole(x, kn, ks, dp, dt)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hf = h_quad(qs, ps, ti, kn, ks, dp)\n",
   "\n",
   "print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))"
 ] },
 { "cell_type": "code", "execution_count": 9, "id": "8584fc98-84ea-45cb-857c-e5bf2787f29e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [
   "# Hamiltonian conservation (bend)\n",
   "# The Hamiltonian is evaluated before and after the exact map and the two values are compared\n",
   "\n",
   "(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hi = h_bend(qs, ps, ti, rd, kn, ks, dp)\n",
   "\n",
   "(qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)\n",
   "qs = torch.stack([qx, qy])\n",
   "ps = torch.stack([px, py])\n",
   "hf = h_bend(qs, ps, ti, rd, kn, ks, dp)\n",
   "\n",
   "print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))"
 ] },
 { "cell_type": "code", "execution_count": 10, "id": "5fe2e1d0-8a07-47f7-9f66-5dd482945646", "metadata": {}, "outputs": [], "source": [
   "# To illustrate a (multi-map) split and (Yoshida) composition explicit symplectic integrator, consider the following split\n",
   "# h = h1 + h2 = 1/2*(px**2 + py**2)/(1 + dp) - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy\n",
   "# h1 = 1/2*(px**2 + py**2)/(1 + dp)\n",
   "# qx = qx + dt*px/(1 + dp)\n",
   "# px = px\n",
   "# qy = qy + dt*py/(1 + dp)\n",
   "# py = py\n",
   "# h2 = - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy\n",
   "# qx = qx\n",
   "# px = px + dt*(dp/rd - qx/rd**2 - kn*qx + ks*qy)\n",
   "# qy = qy\n",
   "# py = py + dt*(kn*qy + ks*qx)\n",
   "# Note: both maps take and return the state ordered as (qx, qy, px, py)\n",
   "\n",
   "def fa(x, dt, rd, kn, ks, dp):\n",
   "    \"\"\"Map generated by h1: advance (qx, qy) over dt, keep (px, py); rd, kn, ks are unused.\"\"\"\n",
   "    qx, qy, px, py = x\n",
   "    return jax.numpy.stack([qx + dt*px/(1 + dp), qy + dt*py/(1 + dp), px, py])\n",
   "\n",
   "def fb(x, dt, rd, kn, ks, dp):\n",
   "    \"\"\"Map generated by h2: keep (qx, qy), kick (px, py) over dt.\"\"\"\n",
   "    qx, qy, px, py = x\n",
   "    return jax.numpy.stack([qx, qy, px + dt*(dp/rd - qx/rd**2 - kn*qx + ks*qy), py + dt*(kn*qy + ks*qx)])"
 ] },
 { "cell_type": "code", "execution_count": 11, "id": "8d1f1ca5-de8f-4822-a7ae-b028e92d2d5b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.030437562437563e-05\n", "8.030437539389926e-05\n", "\n", "8.030437562437563e-05\n", "8.03043756243756e-05\n", "\n", "[
0.00999747 -0.0049949 -0.00105058 -0.0003978 ]\n", "[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]\n", "9.384719084898185e-08\n" ] } ], "source": [
   "# Yoshida (bend)\n",
   "\n",
   "# Element parameters as plain Python floats, in the order (rd, kn, ks, dp)\n",
   "\n",
   "knobs = (rd.item(), kn.item(), ks.item(), dp.item())\n",
   "\n",
   "# Generate integration step from the (fa, fb) split\n",
   "\n",
   "step = fold(sequence(0, 1, [fa, fb], merge=True))\n",
   "\n",
   "# Evaluate integration step (the Hamiltonian is printed before and after the step)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qsps = step(qsps, dt.item(), *knobs)\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Evaluate exact solution (same initial condition, passed through the imported bend map)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qx, px, qy, py = bend(x, rd, kn, ks, dp, dt)\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Compare integrator result with the exact one (both ordered as qx, qy, px, py)\n",
   "\n",
   "print(qsps)\n",
   "print(QsPs)\n",
   "print(jax.numpy.linalg.norm(qsps - QsPs))"
 ] },
 { "cell_type": "code", "execution_count": 12, "id": "ae81ccde-f39a-4db5-950c-5d820aa2bd41", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.030437562437563e-05\n", "8.030437562437561e-05\n", "\n", "8.030437562437563e-05\n", "8.03043756243756e-05\n", "\n", "[ 0.00999747 -0.0049949 -0.00105061
-0.00039781]\n", "[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]\n", "5.8710763609389174e-08\n" ] } ], "source": [
   "# Midpoint (bend)\n",
   "\n",
   "# Element parameters as plain Python floats, in the order (rd, kn, ks, dp)\n",
   "\n",
   "knobs = (rd.item(), kn.item(), ks.item(), dp.item())\n",
   "\n",
   "# Generate integration step from the midpoint integrator for h_bend\n",
   "\n",
   "step = fold(sequence(0, 1, [midpoint(h_bend, ns=1)], merge=False))\n",
   "\n",
   "# Evaluate integration step (the Hamiltonian is printed before and after the step)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qsps = step(qsps, dt.item(), ti.item(), *knobs)\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Evaluate exact solution (same initial condition, passed through the imported bend map)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qx, px, qy, py = bend(x, rd, kn, ks, dp, dt)\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Compare integrator result with the exact one (both ordered as qx, qy, px, py)\n",
   "\n",
   "print(qsps)\n",
   "print(QsPs)\n",
   "print(jax.numpy.linalg.norm(qsps - QsPs))"
 ] },
 { "cell_type": "code", "execution_count": 13, "id": "8cb1fc97-88ad-4911-9485-d98939316fe6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.030437562437563e-05\n", "8.030437585721544e-05\n", "\n", "8.030437562437563e-05\n", "8.03043756243756e-05\n", "\n", "[ 0.00999747 -0.0049949 -0.00105064
-0.00039783]\n", "[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]\n", "2.50766882131807e-08\n" ] } ], "source": [
   "# Tao (bend)\n",
   "\n",
   "# Element parameters as plain Python floats, in the order (rd, kn, ks, dp)\n",
   "\n",
   "knobs = (rd.item(), kn.item(), ks.item(), dp.item())\n",
   "\n",
   "# Generate integration step from the tao integrator for h_bend\n",
   "\n",
   "step = fold(sequence(0, 1, [tao(h_bend, binding=0.0)], merge=False))\n",
   "\n",
   "# Evaluate integration step (the Hamiltonian is printed before and after the step)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qsps = step(qsps, dt.item(), ti.item(), *knobs)\n",
   "qs, ps = qsps.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Evaluate exact solution (same initial condition, passed through the imported bend map)\n",
   "\n",
   "x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)\n",
   "qx, px, qy, py = x\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "\n",
   "qx, px, qy, py = bend(x, rd, kn, ks, dp, dt)\n",
   "qs, ps = torch.stack([qx, qy]), torch.stack([px, py])\n",
   "QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())\n",
   "qs, ps = QsPs.reshape(2, -1)\n",
   "print(h_bend(qs, ps, ti.item(), *knobs))\n",
   "print()\n",
   "\n",
   "# Compare integrator result with the exact one (both ordered as qx, qy, px, py)\n",
   "\n",
   "print(qsps)\n",
   "print(QsPs)\n",
   "print(jax.numpy.linalg.norm(qsps - QsPs))"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "666415bc-bb5c-4016-9e6d-1ee05bfeeed2", "metadata": {}, "outputs": [], "source": [] }
 ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language":
"python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }