{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "556562f3-8ece-4517-8c93-ee5e2fc29131", "metadata": {}, "source": [ "# Example-01: Functional iteration" ] }, { "cell_type": "code", "execution_count": 1, "id": "baa39129-4d73-44c2-bf1e-88996ea48c84", "metadata": {}, "outputs": [], "source": [ "# In this example, usage of the nest and fold function factories is illustrated" ] }, { "cell_type": "code", "execution_count": 2, "id": "abeb45e5-3d66-4992-8739-e444b7b4eaaa", "metadata": {}, "outputs": [], "source": [ "# Import\n", "\n", "import jax\n", "\n", "from sympint import nest\n", "from sympint import nest_list\n", "\n", "from sympint import fold\n", "from sympint import fold_list" ] }, { "cell_type": "code", "execution_count": 3, "id": "690143cf-63c2-4353-9792-e22adffa5a7d", "metadata": {}, "outputs": [], "source": [ "# Set data type\n", "\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "866b9b35-6e07-4128-960c-eb7bc11e5ec1", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "\n", "device, *_ = jax.devices('cpu')\n", "jax.config.update('jax_default_device', device)" ] }, { "cell_type": "code", "execution_count": 5, "id": "009d09c3-a6f7-4390-a779-8ca77367a6f2", "metadata": {}, "outputs": [], "source": [ "# Define a simple symplectic mapping\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "def fn(x, a, b):\n", " q, p, *_ = x\n", " return jax.numpy.stack([p, -q + a*p + b*p**2])" ] }, { "cell_type": "code", "execution_count": 6, "id": "fa39dc40-b681-46c1-856a-6b69437418d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "193 µs ± 6.88 µs per loop (mean ± std. dev. 
of 7 runs, 10,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "fn(x, a, b)" ] }, { "cell_type": "code", "execution_count": 7, "id": "9ad9bcb3-8bf3-426b-b352-940e22efa321", "metadata": {}, "outputs": [], "source": [ "# The above mapping is compatible (composable) with JAX functions\n", "# In particular, jit can be used to speed it up; this is useful for more complicated mappings" ] }, { "cell_type": "code", "execution_count": 8, "id": "131e3c01-c4a7-40b3-b0f4-8ee48550a934", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([ 0. , -0.1], dtype=float64)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Wrap and compile\n", "# Once compiled, the resulting function can be used efficiently with different inputs\n", "\n", "fj = jax.jit(fn)\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 9, "id": "d99b4dc5-46af-4f29-90b1-c0c138435230", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5.16 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 10, "id": "b3116091-645b-4175-88e1-a9dd2022ec28", "metadata": {}, "outputs": [], "source": [ "# A common task is to iterate a given mapping repeatedly\n", "# Normally, this can be done with a regular Python loop\n", "# But Python loops are known to be slow and can't be compiled (without unrolling)" ] }, { "cell_type": "code", "execution_count": 11, "id": "749e2755-1839-4d0c-a101-716d6d8fa98e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12.7 ms ± 108 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "for _ in range(2**6):\n", " x = fn(x, a, b)" ] }, { "cell_type": "code", "execution_count": 12, "id": "76ccc327-447b-4abb-9a96-260aa9362930", "metadata": {}, "outputs": [], "source": [ "# jax.lax provides several constructs for efficient compilable looping, e.g. jax.lax.scan\n", "# While a regular for loop will be unrolled, resulting in impractical compilation times for a large number of iterations, scan avoids this\n", "# For repeated mapping application, the nest function can be used, which is a wrapper around jax.lax.scan" ] }, { "cell_type": "code", "execution_count": 13, "id": "324e6bca-6448-4238-80e3-be18b349fb7a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([-0.09687625, -0.05042709], dtype=float64)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Wrap and compile\n", "# Once compiled, the resulting function can be used efficiently with different inputs\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "fj = jax.jit(nest(2**6, fn))\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 14, "id": "15600de2-0a45-43f2-bd3f-e92061b3b69b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.18 µs ± 87.4 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 15, "id": "c639c33a-5573-4039-8d77-8b5b360f73fe", "metadata": {}, "outputs": [], "source": [ "# jax.lax.scan also allows accumulating intermediate results\n", "# For mappings, the nest_list function accumulates the output at each iteration (excluding the initial value)" ] }, { "cell_type": "code", "execution_count": 16, "id": "86a9ee2b-1259-49eb-82c2-fb90661a1583", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(64, 2)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fj = jax.jit(nest_list(2**6, fn))\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "fj(x, a, b).shape" ] }, { "cell_type": "code", "execution_count": 17, "id": "db8d0b2a-d1f7-457e-ac1a-cc631ebef623", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11.3 µs ± 72.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 18, "id": "d321b4e2-db44-47c6-ab86-2c9e97255f63", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([-0.09687625, -0.05042709], dtype=float64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The fold function allows applying a sequence of mappings\n", "# While the mappings can be different, an identical signature is assumed\n", "\n", "fj = jax.jit(fold(2**6*[fn]))\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 19, "id": "9214b4d9-f7dc-499d-8112-58bcb93d7878", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "13.4 µs ± 46.9 ns per loop (mean ± std. dev. 
of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 20, "id": "2b20f7d7-914e-4c20-bf57-4d66ac6e07b8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(64, 2)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fold with accumulation is also available\n", "\n", "fj = jax.jit(fold_list(2**6*[fn]))\n", "\n", "a = jax.numpy.array(0.5)\n", "b = jax.numpy.array(1.0)\n", "x = jax.numpy.array([0.1, 0.0])\n", "\n", "fj(x, a, b).shape" ] }, { "cell_type": "code", "execution_count": 21, "id": "729ce84e-91e2-4e6c-a06e-f9c28f86ed87", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "15.6 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "fj(x, a, b)" ] }, { "cell_type": "code", "execution_count": 22, "id": "d839559d-fa2d-4f03-9f12-30cb60b8a5b8", "metadata": {}, "outputs": [], "source": [ "# Other JAX functions, like vmap and grad, can be applied to the results of nest and fold functions" ] }, { "cell_type": "code", "execution_count": 23, "id": "96f6f77d-a4d6-402f-8fb3-70fcb5853949", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(6, 64, 2)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Vectorized map\n", "\n", "xs = jax.numpy.array([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [0.3, 0.0], [0.4, 0.0], [0.5, 0.0]])\n", "jax.vmap(fj, (0, None, None))(xs, a, b).shape" ] }, { "cell_type": "code", "execution_count": 24, "id": "24ab4a60-0142-4676-9321-3c8c13b315db", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(64, 2, 2)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Jacobian\n", "\n", "jax.jacrev(fj)(x, a, b).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "d8f77163-5601-4c55-b70c-d1f21fddf3a1", "metadata": {}, "outputs": [], "source": 
[] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }