Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ From a `diffrax.AbstractSolver` object.
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
max_steps=4096
)

```
Expand All @@ -190,13 +191,14 @@ From a `collections.abc.Mapping`
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
max_steps=4096
)

```

For a full enumeration of the ways to construct a `DiffEqSolver` object, see
`galax.dynamics.integrate.DiffEqSolver.from_`.
`diffraxtra.DiffEqSolver.from_`.

### `VectorizedDenseInterpolation`

Expand Down
42 changes: 28 additions & 14 deletions src/diffraxtra/diffeq.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""General wrapper around `diffrax.diffeqsolve`.

This is private API.
This module is private. See `diffraxtra` for the public API.

"""
# ruff:noqa: ERA001

__all__ = [
"DiffEqSolver", # exported to Public API
Expand All @@ -14,7 +13,7 @@

import inspect
from collections.abc import Mapping
from dataclasses import KW_ONLY
from dataclasses import _MISSING_TYPE, KW_ONLY, MISSING
from functools import partial
from typing import Any, TypeAlias, final

Expand Down Expand Up @@ -48,13 +47,17 @@ class DiffEqSolver(eqx.Module, strict=True):

This is a convenience wrapper around `diffrax.diffeqsolve`, allowing for
pre-configuration of a `diffrax.AbstractSolver`,
`diffrax.AbstractStepSizeController`, and `diffrax.AbstractAdjoint`.
Pre-configuring these objects can be useful when you want to:
`diffrax.AbstractStepSizeController`, `diffrax.AbstractAdjoint`, and
``max_steps``. Pre-configuring these objects can be useful when you want to:

- repeatedly solve similar differential equations and can reuse the same
solver, step size controller, and adjoint.
solver and associated settings.
- pass the differential equation solver as an argument to a function.

Note that for some `diffrax.SaveAt` options, `max_steps=None` can be
incompatible. In such cases, you can override the `max_steps` argument when
calling the `DiffEqSolver` object.

Examples
--------
>>> import jax.numpy as jnp
Expand Down Expand Up @@ -104,7 +107,7 @@ class DiffEqSolver(eqx.Module, strict=True):
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)

This can be more conveniently done using the `vectorize_interpolation` argument.
This can be more conveniently done using `vectorize_interpolation`.
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Expand All @@ -129,9 +132,14 @@ class DiffEqSolver(eqx.Module, strict=True):
#: See `diffrax` for options.
adjoint: dfx.AbstractAdjoint = eqx.field(default=default_adjoint)

# TODO: should `max_steps` be a field? Given that `max_steps=None` can be
# incompatible with some `SaveAt` options, it would still need to be
# overridable in `__call__`.
#: The maximum number of steps to take before quitting.
#: Some `diffrax.SaveAt` options can be incompatible with `max_steps=None`,
#: so you can override the `max_steps` argument when calling `DiffEqSolver`
max_steps: int | None = eqx.field(default=default_max_steps, static=True)

# TODO: should the event be a field? Again it can be overridden when calling
# `DiffEqSolver`. And should it be static?
# event: dfx.Event | None = eqx.field(default=default_event) # noqa: ERA001

@partial(eqx.filter_jit)
# @partial(quax.quaxify) # TODO: so don't need to strip units
Expand All @@ -148,7 +156,7 @@ def __call__(
# Diffrax options
saveat: dfx.SaveAt = default_saveat,
event: dfx.Event | None = default_event,
max_steps: int | None = default_max_steps,
max_steps: int | None | _MISSING_TYPE = MISSING,
throw: bool = default_throw,
progress_meter: dfx.AbstractProgressMeter[Any] = default_progress_meter,
solver_state: PyTree[ArrayLike] | None = None,
Expand Down Expand Up @@ -182,6 +190,9 @@ def __call__(
using `VectorizedDenseInterpolation`.

"""
# Parse `max_steps`, allowing for it to be overridden.
max_steps = self.max_steps if max_steps is MISSING else max_steps

# Solve with `diffrax.diffeqsolve`, using the `DiffEqSolver`'s `solver`,
# `stepsize_controller` and `adjoint`.
soln: dfx.Solution = dfx.diffeqsolve(
Expand Down Expand Up @@ -258,7 +269,8 @@ def from_(
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
max_steps=4096
)

"""
Expand All @@ -280,7 +292,8 @@ def from_(cls: type[DiffEqSolver], obj: Mapping[str, Any], /) -> DiffEqSolver:
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
max_steps=4096
)

"""
Expand All @@ -304,7 +317,8 @@ def from_(cls: type[DiffEqSolver], obj: eqx.Partial, /) -> DiffEqSolver:
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
max_steps=4096
)

"""
Expand Down