{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "556562f3-8ece-4517-8c93-ee5e2fc29131", "metadata": {}, "source": [ "# Example-06: Non-autonomous hamiltonian integration" ] }, { "cell_type": "code", "execution_count": 1, "id": "c9e9a3dc-a1c5-4f5b-b40d-dc97f4759d40", "metadata": {}, "outputs": [], "source": [ "# In this example, integration of a non-autonomous hamiltonian is illustrated\n", "# Such integration has only limited support, since function iteration tools do not carry time\n", "# Thus, only one second order integration step can be performed and time should be adjusted manually after each step, i.e. using a normal python loop or a custom scan body\n", "\n", "# Support for the more general case would require modifying function iterations, for example, instead of the following loop:\n", "# for _ in range(n): x = f(x, *args)\n", "# nesting should correspond to:\n", "# for _ in range(n): x = f(x, dt, t, *args) ; t = t + dt\n", "# Similarly, fold (and other functions) should be modified to carry time\n", "\n", "# Instead, it is possible to use extended phase space with midpoint or tao integrators" ] }, { "cell_type": "code", "execution_count": 2, "id": "d92f8ee7-ea12-4a11-921d-a0b44783f3ac", "metadata": {}, "outputs": [], "source": [ "# Import \n", "\n", "import jax\n", "from jax import Array\n", "from jax import jit\n", "from jax import vmap\n", "\n", "from sympint import fold\n", "from sympint import nest\n", "from sympint import midpoint\n", "from sympint import sequence\n", "\n", "jax.numpy.set_printoptions(linewidth=256, precision=12)" ] }, { "cell_type": "code", "execution_count": 3, "id": "24d31165-c4b0-4e24-bd10-25695e6e008a", "metadata": {}, "outputs": [], "source": [ "# Set data type\n", "\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "0bbcbcc1-5fd7-4dc5-91d2-04157d0774d1", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "\n", "device, *_ = jax.devices('cpu')\n", 
"jax.config.update('jax_default_device', device)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f1208433-15a8-403d-b1ae-45d8f5b00b6f", "metadata": {}, "outputs": [], "source": [ "# Set parameters\n", "\n", "si = jax.numpy.array(0.0)\n", "ds = jax.numpy.array(0.01)\n", "kn = jax.numpy.array(1.0)" ] }, { "cell_type": "code", "execution_count": 6, "id": "4d85b8da-a6f6-425d-8b13-069af52e8541", "metadata": {}, "outputs": [], "source": [ "# Set initial condition\n", "\n", "qs = jax.numpy.array([0.1, 0.1])\n", "ps = jax.numpy.array([0.0, 0.0])\n", "x = jax.numpy.hstack([qs, ps])" ] }, { "cell_type": "code", "execution_count": 7, "id": "8b4ab5d9-676d-450d-8e9e-4927881bc542", "metadata": {}, "outputs": [], "source": [ "# Define hamiltonian\n", "\n", "def hamiltonian(qs, ps, s, kn, *args):\n", " q_x, q_y = qs\n", " p_x, p_y = ps\n", " return 1/2*(p_x**2 + p_y**2) + 1/2*kn*(1 + jax.numpy.cos(s))*(q_x**2 + q_y**2)" ] }, { "cell_type": "code", "execution_count": 8, "id": "dc84c3f8-14ce-4344-b5d4-faa013098b27", "metadata": {}, "outputs": [], "source": [ "# Set implicit midpoint integration step\n", "\n", "integrator = jit(fold(sequence(0, 0, [midpoint(hamiltonian, ns=2**4)], merge=False)))" ] }, { "cell_type": "code", "execution_count": 9, "id": "8c2ef87f-263b-47a6-a8f1-1f2242fb2773", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0.017983795895 0.017983795895 -0.133154567382 -0.133154567382]\n" ] } ], "source": [ "# Perform integration with explicit time update\n", "\n", "time = si\n", "data = x\n", "for _ in range(10**2):\n", " data = integrator(data, ds, time, kn)\n", " time = time + ds\n", "print(data)" ] }, { "cell_type": "code", "execution_count": 10, "id": "2c5650e3-bf09-4d95-acbf-110d56001e47", "metadata": {}, "outputs": [], "source": [ "# Define hamiltonian (extended)\n", "\n", "def extended(qs, ps, s, kn, *args):\n", " q_x, q_y, q_t = qs\n", " p_x, p_y, p_t = ps\n", " return p_t + 1/2*(p_x**2 + p_y**2) + 1/2*kn*(1 + 
jax.numpy.cos(q_t))*(q_x**2 + q_y**2)" ] }, { "cell_type": "code", "execution_count": 11, "id": "184e803c-716e-4410-89fc-705ccc3e50ac", "metadata": {}, "outputs": [], "source": [ "# Set extended initial condition\n", "\n", "Qs = jax.numpy.concat([qs, si.reshape(-1)])\n", "Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, kn).reshape(-1)])\n", "X = jax.numpy.hstack([Qs, Ps])" ] }, { "cell_type": "code", "execution_count": 12, "id": "e00ec64a-601b-4f56-9404-f42d399a1616", "metadata": {}, "outputs": [], "source": [ "# Set implicit midpoint integration step using extended hamiltonian\n", "\n", "integrator = jit(fold(sequence(0, 0, [midpoint(extended, ns=2**4)], merge=False)))" ] }, { "cell_type": "code", "execution_count": 13, "id": "b2137273-e1a7-485d-8746-3601849015c6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0.017983795895 0.017983795895 1. -0.133154567382 -0.133154567382 -0.018228323463]\n" ] } ], "source": [ "# Set and compile element\n", "\n", "element = jit(nest(10**2, integrator))\n", "out = element(X, ds, si, kn)\n", "print(out)" ] }, { "cell_type": "code", "execution_count": null, "id": "886d34e5-f90c-473e-b6fc-f99d7b534d71", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": 
false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }