{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "556562f3-8ece-4517-8c93-ee5e2fc29131", "metadata": {}, "source": [ "# Example-02: Yoshida composition" ] }, { "cell_type": "code", "execution_count": 1, "id": "9fce7378-f37a-4133-8105-5d952ccacd61", "metadata": {}, "outputs": [], "source": [ "# Given a time-reversible integration step of difference order 2n\n", "# Yoshida composition procedure can be used to construct integration step of difference order 2(n+1)\n", "# Using Yoshida coefficients, new integration step is w(2(n+1))(dt) = w(2n)(x1 dt) o w(2n)(x2 dt) o w(2n)(x1 dt)\n", "\n", "# If a Hamiltonian vector field can be split into several solvable parts\n", "# Second order time-reversible symmetric integrator can be easily constructed as follows\n", "# w1(dt/2) o w2(dt/2) o ... o wn(dt/2) o wn(dt/2) o ... o w2(dt/2) o w1(dt/2)\n", "# where each wi is a mapping for corresponding Hamiltonian\n", "# Yoshida composition procedure can be then applied repeatedly to obtain higher order integration steps" ] }, { "cell_type": "code", "execution_count": 2, "id": "cade4bf1-60ab-483a-83e3-b58a92a4007a", "metadata": {}, "outputs": [], "source": [ "# Import\n", "\n", "import jax\n", "\n", "# Function iterations\n", "\n", "from sympint import fold\n", "from sympint import nest\n", "from sympint import nest_list\n", "\n", "# Yoshida composition\n", "\n", "from sympint import weights\n", "from sympint import coefficients\n", "from sympint import table\n", "from sympint import sequence" ] }, { "cell_type": "code", "execution_count": 3, "id": "619bd94e-d295-4170-b629-715230982ce8", "metadata": {}, "outputs": [], "source": [ "# Set data type\n", "\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "9bd1cb3f-aa48-4890-a885-e8213d1a9bad", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "\n", "device, *_ = jax.devices('cpu')\n", "jax.config.update('jax_default_device', device)" ] }, { "cell_type": 
"code", "execution_count": 5, "id": "cf615475-3668-44ea-8eb8-4c674cfdaab5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['1.351', '-1.702', '1.351']\n", "['1.175', '-1.349', '1.175']\n", "['1.116', '-1.232', '1.116']\n", "['1.087', '-1.174', '1.087']\n" ] } ], "source": [ "# Given integration step of difference order 2n\n", "# Yoshida weights for 2(n+1) order can be computed using weights function\n", "# Note, sum of weights is equal to one\n", "\n", "print([f'{weight:.3f}' for weight in weights(1)]) # 2 -> 4\n", "print([f'{weight:.3f}' for weight in weights(2)]) # 4 -> 6\n", "print([f'{weight:.3f}' for weight in weights(3)]) # 6 -> 8\n", "print([f'{weight:.3f}' for weight in weights(4)]) # 8 -> 10" ] }, { "cell_type": "code", "execution_count": 6, "id": "686ca893-f356-4f78-822b-7d7f418ad13a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['1.351', '-1.702', '1.351']\n", "['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']\n", "['1.175', '-1.349', '1.175']\n" ] } ], "source": [ "# Given integration step of difference order 2n\n", "# Yoshida coefficents for 2m difference order step can be computed using coefficients function\n", "# Note, sum of coefficients is equal to one\n", "\n", "print([f'{coefficient:.3f}' for coefficient in coefficients(1, 1)]) # 2 -> 4\n", "print([f'{coefficient:.3f}' for coefficient in coefficients(1, 2)]) # 2 -> 6\n", "print([f'{coefficient:.3f}' for coefficient in coefficients(2, 2)]) # 4 -> 6" ] }, { "cell_type": "code", "execution_count": 7, "id": "8e127617-1167-46ab-a0bd-307a144dc962", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0, 0, 0], ['1.351', '-1.702', '1.351']]\n", "[[0, 0, 0, 0, 0, 0, 0, 0, 0], ['1.587', '-2.000', '1.587', '-1.823', '2.297', '-1.823', '1.587', '-2.000', '1.587']]\n", "[[0, 0, 0], ['1.175', '-1.349', '1.175']]\n", "\n", "[[0, 1, 0], ['0.500', '1.000', 
'0.500']]\n", "[[0, 1, 0, 0, 1, 0, 0, 1, 0], ['0.676', '1.351', '0.676', '-0.851', '-1.702', '-0.851', '0.676', '1.351', '0.676']]\n", "\n", "[[0, 1, 0], ['0.500', '1.000', '0.500']]\n", "[[0, 1, 0, 1, 0, 1, 0], ['0.676', '1.351', '-0.176', '-1.702', '-0.176', '1.351', '0.676']]\n", "\n" ] } ], "source": [ "# Given a collection of mappings along with initial and final Yoshida orders (half the corresponding difference orders)\n", "# Corresponding Yoshida table can be computed using table function\n", "# Note, mapping can be an integration step\n", "\n", "# If mapping is an integration step, the last argument should be set to False\n", "\n", "ns, cs = table(1, 1, 1, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 4\n", "ns, cs = table(1, 1, 2, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 6\n", "ns, cs = table(1, 2, 2, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 4 -> 6\n", "print()\n", "\n", "# Construct table from two mappings without merging\n", "# Note, number of mappings can be arbitrary\n", "\n", "ns, cs = table(2, 0, 0, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 2\n", "ns, cs = table(2, 0, 1, False) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 4\n", "print()\n", "\n", "# Construct table from two mappings with merging\n", "\n", "ns, cs = table(2, 0, 0, True) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 2\n", "ns, cs = table(2, 0, 1, True) ; print([ns, [f'{c:.3f}' for c in cs]]) # 2 -> 4\n", "print()" ] }, { "cell_type": "code", "execution_count": 8, "id": "83a3b9cc-f0e9-416e-a49e-c578840faf95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 3 [0.121875 0.07226562]\n", " 9 [0.12162308 0.07215494]\n", " 27 [0.12163202 0.07215096]\n", " 81 [0.12163163 0.07215085]\n", "243 [0.12163164 0.07215085]\n", "\n", " 3 [0.121875 0.07226562]\n", " 7 [0.12162308 0.07215494]\n", " 19 [0.12163202 0.07215096]\n", " 55 [0.12163163 0.07215085]\n", "163 [0.12163164 0.07215085]\n", "\n" ] } ], "source": [ "# 
Construct integration step for a simple rotation Hamiltonian\n", "\n", "# H = H1 + H2\n", "# H1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]\n", "# H2 = 1/2 p**2 -> [q, p] -> [q + t*p, p]\n", "\n", "# Set mappings\n", "\n", "def fn(x, dt):\n", " q, p = x\n", " return jax.numpy.stack([q, p - dt*q])\n", "\n", "def gn(x, dt):\n", " q, p = x\n", " return jax.numpy.stack([q + dt*p, p])\n", "\n", "# Set time step\n", "\n", "dt = jax.numpy.array(0.25)\n", "\n", "# Set initial condition\n", "\n", "x = jax.numpy.array([0.1, 0.1])\n", "\n", "# Generate and fold transformations (without merging) for different final Yoshida orders\n", "\n", "for i in range(5):\n", " fns = sequence(0, i, [fn, gn], merge=False)\n", " print(f'{len(fns):>3} {fold(fns)(x, dt)}')\n", "print()\n", "\n", "# Generate and fold transformations (with merging) for different final Yoshida orders\n", "\n", "for i in range(5):\n", " fns = sequence(0, i, [fn, gn], merge=True)\n", " print(f'{len(fns):>3} {fold(fns)(x, dt)}')\n", "print()" ] }, { "cell_type": "code", "execution_count": 9, "id": "7998eb68-64fb-4c9d-8a78-6b4fa8a6c778", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.25 s ± 69.6 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Even with merging, scanning through a large number of mappings is slow\n", "# Note, fold is a wrapper around jax.lax.scan\n", " \n", "fold(fns)(x, dt)" ] }, { "cell_type": "code", "execution_count": 10, "id": "fb11e6fa-eec8-44d8-a5e6-cb206368ca1f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([0.12163164, 0.07215085], dtype=float64)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# This can be remedied by JAX jit compilation\n", "# The first execution will be still slow (compilation step)\n", "\n", "fj = jax.jit(fold(fns))\n", "fj(x, dt)" ] }, { "cell_type": "code", "execution_count": 11, "id": "b42e65c3-74fe-48e4-a285-d235366d0fa2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28.1 µs ± 256 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Compiled transformation is expected to be much faster\n", "# If the intent is to use it repeatedly with different initials, jit compilation is a way to go\n", "\n", "fj(x, dt)" ] }, { "cell_type": "code", "execution_count": 12, "id": "05056a2d-9fde-40ef-8390-4ab4b0d017e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.86 ms ± 532 µs per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Compiled step is compatible with JAX functions\n", "# For example, it is possible to compute jacobian with respect to initial condition\n", "# Note, this might trigger a recompile\n", "\n", "jax.jacrev(fj)(x, dt)" ] }, { "cell_type": "code", "execution_count": 13, "id": "7e794b6d-62ce-47fc-a446-e6529dd8be4c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([[ 0.96891242, 0.24740395],\n", " [-0.24740396, 0.96891242]], dtype=float64)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Add one more layer of JIT and compile\n", "\n", "jacobian = jax.jit(jax.jacrev(fj))\n", "jacobian(x, dt)" ] }, { "cell_type": "code", "execution_count": 14, "id": "d0806855-653d-47a7-adb6-3420206fb317", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "81.5 µs ± 2.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Time the resulting jacobian\n", "\n", "jacobian(x, dt)" ] }, { "cell_type": "code", "execution_count": 15, "id": "97947680-a9a6-4ee1-8182-c428ff1b943e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.68 s ± 109 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Perform several integration steps (native Python loop over steps and step parts)\n", "\n", "fs = sequence(0, 5, [fn, gn], merge=True)\n", "\n", "dt = jax.numpy.array(0.25)\n", "x = jax.numpy.array([0.1, 0.1])\n", "\n", "for _ in range(64):\n", "    for f in fs:\n", "        x = f(x, dt)" ] }, { "cell_type": "code", "execution_count": 16, "id": "c7a4ef32-2743-450e-9dc2-38568b8c37d8", "metadata": {}, "outputs": [], "source": [ "# Compile\n", "\n", "dt = jax.numpy.array(0.25)\n", "x = jax.numpy.array([0.1, 0.1])\n", "\n", "fs = sequence(0, 5, [fn, gn], merge=True)\n", "fj = jax.jit(fold(fs))\n", "fj(x, dt) ;\n", "\n", "fj = nest(64, fj)\n", "fj = jax.jit(fj)\n", "fj(x, dt) ;" ] }, { "cell_type": "code", "execution_count": 17, "id": "594c7b27-2db3-4c2f-ae2d-d6ede07ffb40", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.97 ms ± 51.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "# Test (compilation time is excluded)\n", "\n", "fj(x, dt)" ] }, { "cell_type": "code", "execution_count": 18, "id": "5bd71185-d6c5-461f-9f43-5d5cef164f37", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.10943036 0.08841961]\n", "[0.10943036 0.08841961]\n" ] } ], "source": [ "# Setup a multistep integrator\n", "\n", "# H = H1 + H2\n", "# H1 = 1/2 q**2 + 1/3 q**3 -> [q, p] -> [q, p - t*q - t*q**2]\n", "# H2 = 1/2 p**2 -> [q, p] -> [q + t*p, p]\n", "\n", "dt = jax.numpy.array(0.1)\n", "x = jax.numpy.array([0.1, 0.1])\n", "\n", "def fn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q, p - t*q - t*q**2])\n", "\n", "def gn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q + t*p, p])\n", "\n", "fs = sequence(0, 1, [fn, gn], merge=True)\n", "print(fold(fs)(x, dt))\n", "\n", "# H = H1 + H2 + H3\n", "# H1 = 1/2 q**2 -> [q, p] -> [q, p - t*q]\n", "# H2 = 1/3 q**3 -> [q, p] -> [q, p - t*q**2]\n", "# H3 = 1/2 
p**2 -> [q, p] -> [q + t*p, p]\n", "\n", "def fn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q, p - t*q])\n", "\n", "def gn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q, p - t*q**2])\n", "\n", "def hn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q + t*p, p])\n", "\n", "# Note, the last mapping in the list has the smallest number of evaluations\n", "\n", "fs = sequence(0, 1, [fn, gn, hn], merge=True)\n", "print(fold(fs)(x, dt))\n", "\n", "# Note, the result is identical since the two parts H1 and H2 commute (both depend only on q)" ] }, { "cell_type": "code", "execution_count": 19, "id": "9006c247-448e-40ba-9bf2-e39f257c952b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "True\n" ] } ], "source": [ "# Increase order of an existing integration step\n", "\n", "# Set time step and initial condition\n", "\n", "dt = jax.numpy.array(0.1)\n", "x = jax.numpy.array([0.1, 0.1])\n", "\n", "# Set transformations for solvable parts\n", "\n", "def fn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q, p - t*q - t*q**2])\n", "\n", "def gn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q + t*p, p])\n", "\n", "# Define 2nd, 4th and 6th order integration step from parts\n", "\n", "s2 = fold(sequence(0, 0, [fn, gn], merge=True))\n", "s4 = fold(sequence(0, 1, [fn, gn], merge=True))\n", "s6 = fold(sequence(0, 2, [fn, gn], merge=True))\n", "\n", "# Construct 4th order integration step from a 2nd order one\n", "# And compare with 4th order step constructed from parts\n", "\n", "w4 = fold(sequence(1, 1, [s2], merge=False))\n", "print(jax.numpy.allclose(s4(x, dt), w4(x, dt)))\n", "\n", "# Construct 6th order from 4th order and compare\n", "\n", "w6 = fold(sequence(2, 2, [s4], merge=False))\n", "print(jax.numpy.allclose(s6(x, dt), w6(x, dt)))\n", "\n", "\n", "# Construct 6th order from 2nd order and compare\n", "\n", "w6 = fold(sequence(1, 2, [s2], merge=False))\n", "print(jax.numpy.allclose(s6(x, dt), w6(x, dt)))" ] }, { "cell_type": 
"code", "execution_count": 20, "id": "a42f2620-ca50-45a6-9ab4-2044cadc38d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.10943036 0.08841961]\n", "\n" ] } ], "source": [ "# Pass fixed parameters\n", "\n", "def fn(x, t, a, b):\n", " q, p = x\n", " return jax.numpy.stack([q, p - a*t*q - b*t*q**2])\n", "\n", "def gn(x, t):\n", " q, p = x\n", " return jax.numpy.stack([q + t*p, p])\n", "\n", "t = jax.numpy.array(0.1)\n", "x = jax.numpy.array([0.1, 0.1])\n", "a = jax.numpy.array(1.0)\n", "b = jax.numpy.array(1.0)\n", "\n", "fj = jax.jit(fold(sequence(0, 1, [fn, gn], merge=True, parameters=[[a, b], []])))\n", "\n", "print(fj(x, t))\n", "print()" ] }, { "cell_type": "code", "execution_count": 21, "id": "13607bc3-9609-4d60-a274-7661064754b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.10943036 0.08841961]\n", "\n" ] } ], "source": [ "# Integration step with parameters (matching signatures)\n", "\n", "def fn(x, t, a, b):\n", " q, p = x\n", " return jax.numpy.stack([q, p - a*t*q - b*t*q**2])\n", "\n", "def gn(x, t, a, b):\n", " q, p = x\n", " return jax.numpy.stack([q + t*p, p])\n", "\n", "t = jax.numpy.array(0.1)\n", "x = jax.numpy.array([0.1, 0.1])\n", "a = jax.numpy.array(1.0)\n", "b = jax.numpy.array(1.0)\n", "\n", "fj = jax.jit(fold(sequence(0, 1, [fn, gn], merge=True)))\n", "\n", "print(fj(x, t, a, b))\n", "print()" ] }, { "cell_type": "code", "execution_count": 22, "id": "8a76ce07-e600-4055-82c2-30fefb7ffd4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.99397345 0.09979712]\n", " [-0.12071788 0.99394275]]\n", "\n", "[ 0.08841321 -0.12140165]\n", "\n", "[-0.0005159 -0.0104603]\n", "\n", "[-5.32985126e-05 -1.09712073e-03]\n", "\n" ] } ], "source": [ "# Matched signatures allow to compute derivatives with respect to matched parameters\n", "\n", "for i in range(4): \n", " print(jax.jacrev(fj, i)(x, t, a, b))\n", " print()" ] } ], 
"metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }