{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "18a00f4b-a5c8-47d3-a8d3-5983bfa1021a", "metadata": {}, "source": [ "# Example-24: Wrapper" ] }, { "cell_type": "code", "execution_count": 1, "id": "528a7ace-202c-4f58-93a3-d46f74c2f021", "metadata": {}, "outputs": [], "source": [ "# In this example construction of parametric call wrappers is illustrated\n", "# For elements, all deviation parameters are passed as dictionary\n", "# Wrapped elements are invoked using positional agruments" ] }, { "cell_type": "code", "execution_count": 2, "id": "4021a7e1-6901-4ca4-8ce9-87975e3f465c", "metadata": {}, "outputs": [], "source": [ "# Import\n", "\n", "from pprint import pprint\n", "\n", "import torch\n", "\n", "from twiss import twiss\n", "\n", "from ndmap.pfp import parametric_fixed_point\n", "from ndmap.evaluate import evaluate\n", "from ndmap.signature import chop\n", "\n", "from model.library.drift import Drift\n", "from model.library.multipole import Multipole\n", "from model.library.dipole import Dipole\n", "from model.library.line import Line\n", "\n", "from model.command.wrapper import wrapper" ] }, { "cell_type": "code", "execution_count": 3, "id": "9917e2ae-8f66-448c-a28c-33ceafe5ddc2", "metadata": {}, "outputs": [], "source": [ "# Define simple FODO based lattice using nested lines\n", "\n", "QF = Multipole('QF', 0.5, +0.20)\n", "QD = Multipole('QD', 0.5, -0.19)\n", "DR = Drift('DR', 0.75)\n", "BM = Dipole('BM', 3.50, torch.pi/4.0)\n", "\n", "FODO = [QF, DR, BM, DR, QD, QD, DR, BM, DR, QF]\n", "\n", "FODO_A = Line('FODO_A', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "FODO_B = Line('FODO_B', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "FODO_C = Line('FODO_C', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "FODO_D = Line('FODO_D', FODO, propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "\n", "LINE_AB = Line('LINE_AB', [FODO_A, FODO_B], propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "LINE_CD = Line('LINE_CD', [FODO_C, FODO_D], propagate=True, dp=0.0, exact=False, output=False, matrix=False)\n", "\n", "RING = Line('RING', [LINE_AB, LINE_CD], propagate=True, dp=0.0, exact=False, output=False, matrix=False)" ] }, { "cell_type": "code", "execution_count": 4, "id": "ce6538bb-75bb-41ee-a450-d8bb5d5383bd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'LINE_AB': {'FODO_A': {'QF': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'DR': {'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'BM': {'dw': tensor(0., dtype=torch.float64),\n", " 'e1': tensor(0., dtype=torch.float64),\n", " 'e2': tensor(0., dtype=torch.float64),\n", " 'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'QD': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)}},\n", " 'FODO_B': {'QF': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'DR': {'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'BM': {'dw': tensor(0., dtype=torch.float64),\n", " 'e1': tensor(0., dtype=torch.float64),\n", " 'e2': tensor(0., dtype=torch.float64),\n", " 'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'QD': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)}}},\n", " 'LINE_CD': {'FODO_C': {'QF': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'DR': {'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'BM': {'dw': tensor(0., dtype=torch.float64),\n", " 'e1': tensor(0., dtype=torch.float64),\n", " 'e2': tensor(0., dtype=torch.float64),\n", " 'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'QD': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)}},\n", " 'FODO_D': {'QF': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'DR': {'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'BM': {'dw': tensor(0., dtype=torch.float64),\n", " 'e1': tensor(0., dtype=torch.float64),\n", " 'e2': tensor(0., dtype=torch.float64),\n", " 'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)},\n", " 'QD': {'kn': tensor(0., dtype=torch.float64),\n", " 'ks': tensor(0., dtype=torch.float64),\n", " 'ms': tensor(0., dtype=torch.float64),\n", " 'mo': tensor(0., dtype=torch.float64),\n", " 'dp': tensor(0., dtype=torch.float64),\n", " 'dl': tensor(0., dtype=torch.float64)}}}}\n" ] } ], "source": [ "# Deviation variables are passed to elements/lines as dictionaries\n", "# In order to compute derivatives with respect to a deviation variable\n", "# A tensor should be binded to a corresponding leaf deviation dictionary value\n", "\n", "pprint(RING.data(alignment=False), sort_dicts=False)" ] }, { "cell_type": "code", "execution_count": 5, "id": "5f7c80ac-43ca-48b8-b987-5fc2849989d9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([0., 0., 0., 0.], dtype=torch.float64),\n", " tensor([[4.4462],\n", " [0.0000],\n", " [0.0000],\n", " [0.0000]], dtype=torch.float64)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute parametric closed orbit (first order with respect to momentum deviation)\n", "\n", "# Without wrapping, all momenta deviation occurances should be binded to a singel tensor\n", "# Hence, deviation table should be traversed recursively down to all leafs\n", "\n", "def scan(data, name, target):\n", " for key, value in data.items():\n", " if isinstance(value, dict):\n", " scan(value, name, target)\n", " elif key == name:\n", " data[key] = target\n", "\n", "# Set ring function\n", "\n", "def ring(state, dp):\n", " dp, *_ = dp\n", " data = RING.data()\n", " scan(data, 'dp', dp)\n", " return RING(state, data=data)\n", "\n", "# Set deviations\n", "\n", "fp = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "# Compute pfp\n", "\n", "pfp, *_ = parametric_fixed_point((1, ), fp, [dp], ring)\n", "chop(pfp)\n", "pfp" ] }, { "cell_type": "code", "execution_count": 6, "id": "cbdb067a-c35d-41c2-b026-da59ba80cde3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([0., 0., 0., 0.], dtype=torch.float64),\n", " tensor([[4.4462],\n", " [0.0000],\n", " [0.0000],\n", " [0.0000]], dtype=torch.float64)]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Using wrapper we can define the about ring function as follows\n", "\n", "fn = wrapper(RING, (None, None, 'dp'))\n", "\n", "# Set deviations\n", "\n", "fp = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "# Compute pfp\n", "\n", "pfp, *_ = parametric_fixed_point((1, ), fp, [dp], fn)\n", "chop(pfp)\n", "pfp" ] }, { "cell_type": "code", "execution_count": 7, "id": "4798c66e-42d8-4b68-a40b-b7b8ca587aaa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n" ] } ], "source": [ "# Compute chromaticity (without wrapping)\n", "\n", "def scan(data, name, target):\n", " for key, value in data.items():\n", " if isinstance(value, dict):\n", " scan(value, name, target)\n", " elif key == name:\n", " data[key] = target\n", "\n", "# Set ring function\n", "\n", "def ring(state, dp):\n", " dp, *_ = dp\n", " data = RING.data()\n", " scan(data, 'dp', dp)\n", " return RING(state , data=data)\n", "\n", "# Set ring function around pfp\n", "\n", "def pfp_ring(state, dp):\n", " return ring(state + evaluate(pfp, [dp]), dp) - evaluate(pfp, [dp])\n", "\n", "# Set tune function\n", "\n", "def tune(dp):\n", " matrix = torch.func.jacrev(pfp_ring)(state, dp)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp))\n", "print(torch.func.jacrev(tune)(dp).squeeze())" ] }, { "cell_type": "code", "execution_count": 8, "id": "8bfde912-0d8e-421a-86cb-ffae76d99611", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n" ] } ], "source": [ "# Compute chromaticity (with wrapping)\n", "\n", "# Set ring function\n", "\n", "fn = wrapper(RING, (None, None, 'dp'))\n", "\n", "# Set ring function around pfp\n", "\n", "def pfp_ring(state, dp):\n", " return fn(state + evaluate(pfp, [dp]), dp) - evaluate(pfp, [dp])\n", "\n", "# Set tune function\n", "\n", "def tune(dp):\n", " matrix = torch.func.jacrev(pfp_ring)(state, dp)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp))\n", "print(torch.func.jacrev(tune)(dp).squeeze())" ] }, { "cell_type": "code", "execution_count": 9, "id": "54572eed-7903-4c13-9fe4-c0866268fe32", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n", "tensor([[ 25.8500, 1.0470],\n", " [ -9.0271, -16.4821]], dtype=torch.float64)\n" ] } ], "source": [ "# Compute chromaticity derivative with respect to sextupole ampitudes (without wrapping)\n", "\n", "def scan(data, name, target):\n", " for key, value in data.items():\n", " if isinstance(value, dict):\n", " scan(value, name, target)\n", " elif key == name:\n", " data[key] = target\n", "\n", "def ring(state, dp, dms):\n", " dp, *_ = dp\n", " dmsf, dmsd, *_ = dms\n", " data = RING.data()\n", " scan(data, 'dp', dp)\n", " data['LINE_AB']['FODO_A']['QF']['ms'] = dmsf\n", " data['LINE_AB']['FODO_B']['QF']['ms'] = dmsf\n", " data['LINE_CD']['FODO_C']['QF']['ms'] = dmsf\n", " data['LINE_CD']['FODO_D']['QF']['ms'] = dmsf\n", " data['LINE_AB']['FODO_A']['QD']['ms'] = dmsd\n", " data['LINE_AB']['FODO_B']['QD']['ms'] = dmsd\n", " data['LINE_CD']['FODO_C']['QD']['ms'] = dmsd\n", " data['LINE_CD']['FODO_D']['QD']['ms'] = dmsd \n", " return RING(state, data=data)\n", "\n", "def pfp_ring(state, dp, dms):\n", " return ring(state + evaluate(pfp, [dp]), dp, dms) - evaluate(pfp, [dp])\n", "\n", "\n", "def tune(dp, dms): \n", " matrix = torch.func.jacrev(pfp_ring)(state, dp, dms)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "def chromaticity(dms):\n", " return torch.func.jacrev(tune)(dp, dms)\n", "\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dms = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp, dms))\n", "print(chromaticity(dms).squeeze())\n", "print(torch.func.jacrev(chromaticity)(dms).squeeze())" ] }, { "cell_type": "code", "execution_count": 10, "id": "6bc3d32d-fdb2-4899-b6ff-3ca46a598e75", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n", "tensor([[ 25.8500, 1.0470],\n", " [ -9.0271, -16.4821]], dtype=torch.float64)\n" ] } ], "source": [ "# Compute chromaticity derivative with respect to sextupole ampitudes (with wrapping)\n", "\n", "def scan(data, name, target):\n", " for key, value in data.items():\n", " if isinstance(value, dict):\n", " scan(value, name, target)\n", " elif key == name:\n", " data[key] = target\n", "\n", "ring = wrapper(RING, (None, None, 'dp'), (None, ['QF', 'QD'], 'ms'))\n", "\n", "def pfp_ring(state, dp, dms):\n", " return ring(state + evaluate(pfp, [dp]), dp, dms) - evaluate(pfp, [dp])\n", "\n", "\n", "def tune(dp, dms): \n", " matrix = torch.func.jacrev(pfp_ring)(state, dp, dms)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "def chromaticity(dms):\n", " return torch.func.jacrev(tune)(dp, dms)\n", "\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dms = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp, dms))\n", "print(chromaticity(dms).squeeze())\n", "print(torch.func.jacrev(chromaticity)(dms).squeeze())" ] }, { "cell_type": "code", "execution_count": 11, "id": "d08a3b1f-efa1-4a4c-b033-04b661766f7b", "metadata": {}, "outputs": [], "source": [ "# The above examples demonstrate how to bind tensors to all leafs or to given elements (in all lines)\n", "\n", "# (None, None, parameter:str) -- bind tensor to all leaf parameters\n", "# (None, names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements\n", "# (path:list[str], names:list[str], parameter:str) -- bind tensor to all leaf parameters in specified elements in given path (path to specific line)" ] }, { "cell_type": "code", "execution_count": 12, "id": "3ba0d7fe-1108-4064-b555-92884db99b4a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n", "tensor([[ 25.8500, 1.0470],\n", " [ -9.0271, -16.4821]], dtype=torch.float64)\n" ] } ], "source": [ "# Bind QF and QD in all sublines of a given line (1/2 of sextupoles)\n", "\n", "ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB'], ['QF', 'QD'], 'ms'))\n", "\n", "def pfp_ring(state, dp, dms):\n", " return ring(state + evaluate(pfp, [dp]), dp, dms) - evaluate(pfp, [dp])\n", "\n", "def tune(dp, dms): \n", " matrix = torch.func.jacrev(pfp_ring)(state, dp, dms)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "def chromaticity(dms):\n", " return torch.func.jacrev(tune)(dp, dms)\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dms = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp, dms))\n", "print(chromaticity(dms).squeeze())\n", "print(2*torch.func.jacrev(chromaticity)(dms).squeeze())" ] }, { "cell_type": "code", "execution_count": 13, "id": "a8574ed7-bcc2-4914-aa6b-b1d635a94cfb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n", "tensor([[ 25.8500, 1.0470],\n", " [ -9.0271, -16.4821]], dtype=torch.float64)\n" ] } ], "source": [ "# Bind QF and QD in a given leaf line (1/4 of sextupoles)\n", "\n", "ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB', 'FODO_A'], ['QF', 'QD'], 'ms'))\n", "\n", "def pfp_ring(state, dp, dms):\n", " return ring(state + evaluate(pfp, [dp]), dp, dms) - evaluate(pfp, [dp])\n", "\n", "def tune(dp, dms): \n", " matrix = torch.func.jacrev(pfp_ring)(state, dp, dms)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "def chromaticity(dms):\n", " return torch.func.jacrev(tune)(dp, dms)\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dms = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp, dms))\n", "print(chromaticity(dms).squeeze())\n", "print(4*torch.func.jacrev(chromaticity)(dms).squeeze())" ] }, { "cell_type": "code", "execution_count": 14, "id": "9a162ff6-0c09-4151-936a-d12d86be1683", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.6951, 0.7019], dtype=torch.float64)\n", "tensor([-2.0649, -0.8260], dtype=torch.float64)\n", "tensor([[12.9250, 0.5235],\n", " [-4.5135, -8.2411]], dtype=torch.float64)\n", "tensor([[12.9250, 0.5235],\n", " [-4.5135, -8.2411]], dtype=torch.float64)\n", "tensor([[[12.9250, 0.5235],\n", " [12.9250, 0.5235]],\n", "\n", " [[-4.5135, -8.2411],\n", " [-4.5135, -8.2411]]], dtype=torch.float64)\n" ] } ], "source": [ "# Several sextupole groups\n", "\n", "ring = wrapper(RING, (None, None, 'dp'), (['LINE_AB'], ['QF', 'QD'], 'ms'), (['LINE_CD'], ['QF', 'QD'], 'ms'))\n", "\n", "def pfp_ring(state, dp, dms_ab, dms_cd):\n", " return ring(state + evaluate(pfp, [dp]), dp, dms_ab, dms_cd) - evaluate(pfp, [dp])\n", "\n", "def tune(dp, dms_ab, dms_cd): \n", " matrix = torch.func.jacrev(pfp_ring)(state, dp, dms_ab, dms_cd)\n", " tunes, *_ = twiss(matrix)\n", " return tunes\n", "\n", "def chromaticity(dms_ab, dms_cd):\n", " return torch.func.jacrev(tune)(dp, dms_ab, dms_cd)\n", "\n", "state = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float64)\n", "dms_ab = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dms_cd = torch.tensor([0.0, 0.0], dtype=torch.float64)\n", "dp = torch.tensor([0.0], dtype=torch.float64)\n", "\n", "print(tune(dp, dms_ab, dms_cd))\n", "print(chromaticity(dms_ab, dms_cd).squeeze())\n", "print(torch.func.jacrev(chromaticity, 0)(dms_ab, dms_cd).squeeze())\n", "print(torch.func.jacrev(chromaticity, 1)(dms_ab, dms_cd).squeeze())\n", "\n", "def fn(dms):\n", " dms_ab, dms_cd = dms\n", " return chromaticity(dms_ab, dms_cd)\n", "\n", "print(torch.func.jacrev(fn)(torch.stack([dms_ab, dms_cd])).squeeze())" ] } ], "metadata": { "colab": { "collapsed_sections": [ "myt0_gMIOq7b", "5d97819c" ], "name": "03_frequency.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false } }, "nbformat": 4, "nbformat_minor": 5 }