Functional
Collection of factories for functional iteration. The functions returned by the nest and fold factories can be composed with regular JAX functions.
- sympint.functional.fold(mappings: Sequence[Callable[[...], jax.Array]]) -> Callable[[...], jax.Array]
Create a function that sequentially applies mappings from a given list
- Parameters:
mappings (Sequence[Callable[[Array, *Any], Array]]) – ordered sequence of transformation functions with identical signatures, each mapping R^n x … -> R^n
- Return type:
Callable[[Array, *Any], Array]
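The sequential application described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the library's implementation: `fold_sketch`, `shift`, and `double` are hypothetical names, and plain floats stand in for JAX arrays so that the sketch runs without any dependencies.

```python
from typing import Any, Callable, Sequence


def fold_sketch(mappings: Sequence[Callable[..., Any]]) -> Callable[..., Any]:
    """Hypothetical stand-in for sympint.functional.fold; not the library code."""
    def closure(x: Any, *args: Any) -> Any:
        # Apply each mapping in order, feeding the output of one into the next.
        for mapping in mappings:
            x = mapping(x, *args)
        return x
    return closure


# Two illustrative mappings sharing the signature (x, k) -> x.
def shift(x: float, k: float) -> float:
    return x + k


def double(x: float, k: float) -> float:
    return 2.0 * x  # k is accepted but unused, to keep the signatures identical


step = fold_sketch([shift, double])
result = step(1.0, 0.5)  # same as double(shift(1.0, 0.5), 0.5)
print(result)  # 3.0
```

Every mapping receives the same extra arguments, which is why the parameter description asks for identical signatures.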
- sympint.functional.fold_list(mappings: Sequence[Callable[[...], jax.Array]]) -> Callable[[...], jax.Array]
Create a function that sequentially applies mappings from a given list and accumulates intermediate results
- Parameters:
mappings (Sequence[Callable[[Array, *Any], Array]]) – ordered sequence of transformation functions with identical signatures, each mapping R^n x … -> R^n
- Return type:
Callable[[Array, *Any], Array]
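A minimal plain-Python sketch of the accumulating variant follows. It rests on two assumptions: that the initial value is left out of the output (this is documented for nest_list, but not stated here), and that a Python list is an acceptable stand-in for the array the library returns. `fold_list_sketch`, `shift`, and `double` are hypothetical names.

```python
from typing import Any, Callable, List, Sequence


def fold_list_sketch(mappings: Sequence[Callable[..., Any]]) -> Callable[..., List[Any]]:
    """Hypothetical stand-in for sympint.functional.fold_list; not the library code."""
    def closure(x: Any, *args: Any) -> List[Any]:
        xs = []
        for mapping in mappings:
            x = mapping(x, *args)
            xs.append(x)  # record the state after each mapping (assumption: initial x excluded)
        return xs
    return closure


def shift(x: float, k: float) -> float:
    return x + k


def double(x: float, k: float) -> float:
    return 2.0 * x  # k is accepted but unused, to keep the signatures identical


states = fold_list_sketch([shift, double])(1.0, 0.5)
print(states)  # [1.5, 3.0]
```

Under these assumptions the last recorded state equals what the non-accumulating fold would return for the same inputs.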
- sympint.functional.nest(length: int, mapping: Callable[[...], jax.Array]) -> Callable[[...], jax.Array]
Create a function that iteratively applies a state transformation mapping
- Parameters:
length (int, positive) – number of iterations to perform
mapping (Callable[[Array, *Any], Array]) – state transformation mapping R^n x … -> R^n
- Return type:
Callable[[Array, *Any], Array]
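Iterated application of a single mapping can be sketched the same way. The names `nest_sketch` and `growth` are hypothetical, and plain floats are used instead of JAX arrays; the library's own implementation may differ, for example by using a JAX loop primitive.

```python
from typing import Any, Callable


def nest_sketch(length: int, mapping: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical stand-in for sympint.functional.nest; not the library code."""
    def closure(x: Any, *args: Any) -> Any:
        # Feed the state back into the same mapping `length` times.
        for _ in range(length):
            x = mapping(x, *args)
        return x
    return closure


def growth(x: float, k: float) -> float:
    # Illustrative state transformation: multiply the state by a fixed factor.
    return x * k


final = nest_sketch(3, growth)(1.0, 2.0)  # growth applied three times to 1.0
print(final)  # 8.0
```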
- sympint.functional.nest_list(length: int, mapping: Callable[[...], jax.Array]) -> Callable[[...], jax.Array]
Create a function that iteratively applies a state transformation mapping and accumulates intermediate results
- Parameters:
length (int, positive) – number of iterations to perform
mapping (Callable[[Array, *Any], Array]) – state transformation mapping R^n x … -> R^n
- Return type:
Callable[[Array, *Any], Array]
Note
The initial value is not included in the output, so the output length is equal to the number of iterations. The accumulation is equivalent to the following Python loop:

    xs = []
    for _ in range(n):
        x = f(x, *args)
        xs.append(x)
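The loop in the note can be wrapped into a factory to make the output convention concrete. This is a plain-Python sketch with hypothetical names (`nest_list_sketch`, `growth`); it returns a Python list, whereas the documented return type is an array.

```python
from typing import Any, Callable, List


def nest_list_sketch(length: int, mapping: Callable[..., Any]) -> Callable[..., List[Any]]:
    """Hypothetical stand-in for sympint.functional.nest_list; not the library code."""
    def closure(x: Any, *args: Any) -> List[Any]:
        xs = []
        for _ in range(length):
            x = mapping(x, *args)
            xs.append(x)  # only states after each iteration are stored
        return xs
    return closure


def growth(x: float, k: float) -> float:
    # Illustrative state transformation: multiply the state by a fixed factor.
    return x * k


trajectory = nest_list_sketch(3, growth)(1.0, 2.0)
print(trajectory)       # [2.0, 4.0, 8.0]
print(len(trajectory))  # 3, one entry per iteration; the initial 1.0 is absent
```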