diff --git a/README.md b/README.md index cb261da..a910636 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,8 @@ From a `diffrax.AbstractSolver` object. DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), - adjoint=RecursiveCheckpointAdjoint(checkpoints=None) + adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + max_steps=4096 ) ``` @@ -190,13 +191,14 @@ From a `collections.abc.Mapping` DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=PIDController( ... ), - adjoint=RecursiveCheckpointAdjoint(checkpoints=None) + adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + max_steps=4096 ) ``` For a full enumeration of the ways to construct a `DiffEqSolver` object, see -`galax.dynamics.integrate.DiffEqSolver.from_`. +`diffraxtra.DiffEqSolver.from_`. ### `VectorizedDenseInterpolation` diff --git a/src/diffraxtra/diffeq.py b/src/diffraxtra/diffeq.py index aeff504..a67c273 100644 --- a/src/diffraxtra/diffeq.py +++ b/src/diffraxtra/diffeq.py @@ -1,9 +1,8 @@ """General wrapper around `diffrax.diffeqsolve`. -This is private API. +This module is private. See `diffraxtra` for the public API. """ -# ruff:noqa: ERA001 __all__ = [ "DiffEqSolver", # exported to Public API @@ -14,7 +13,7 @@ import inspect from collections.abc import Mapping -from dataclasses import KW_ONLY +from dataclasses import _MISSING_TYPE, KW_ONLY, MISSING from functools import partial from typing import Any, TypeAlias, final @@ -48,13 +47,17 @@ class DiffEqSolver(eqx.Module, strict=True): This is a convenience wrapper around `diffrax.diffeqsolve`, allowing for pre-configuration of a `diffrax.AbstractSolver`, - `diffrax.AbstractStepSizeController`, and `diffrax.AbstractAdjoint`. - Pre-configuring these objects can be useful when you want to: + `diffrax.AbstractStepSizeController`, `diffrax.AbstractAdjoint`, and + ``max_steps``. Pre-configuring these objects can be useful when you want to: - repeatedly solve similar differential equations and can reuse the same - solver, step size controller, and adjoint. + solver and associated settings. - pass the differential equation solver as an argument to a function. + Note that for some `diffrax.SaveAt` options, `max_steps=None` can be + incompatible. In such cases, you can override the `max_steps` argument when + calling the `DiffEqSolver` object. + Examples -------- >>> import jax.numpy as jnp @@ -104,7 +107,7 @@ class DiffEqSolver(eqx.Module, strict=True): Array([[0.90483742, 0.81872516], [0.74080871, 0.67031456]], dtype=float64) - This can be more conveniently done using the `vectorize_interpolation` argument. + This can be more conveniently done using `vectorize_interpolation`. >>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat, ... vectorize_interpolation=True) >>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)) @@ -129,9 +132,14 @@ class DiffEqSolver(eqx.Module, strict=True): #: See `diffrax` for options. adjoint: dfx.AbstractAdjoint = eqx.field(default=default_adjoint) - # TODO: should `max_steps` be a field? Given that `max_steps=None` can be - # incompatible with some `SaveAt` options, it would still need to be - # overridable in `__call__`. + #: The maximum number of steps to take before quitting. + #: Some `diffrax.SaveAt` options can be incompatible with `max_steps=None`, + #: so you can override the `max_steps` argument when calling `DiffEqSolver` + max_steps: int | None = eqx.field(default=default_max_steps, static=True) + + # TODO: should the event be a field? Again it can be overridden when calling + # `DiffEqSolver`. And should it be static? + # event: dfx.Event | None = eqx.field(default=default_event) # noqa: ERA001 @partial(eqx.filter_jit) # @partial(quax.quaxify) # TODO: so don't need to strip units @@ -148,7 +156,7 @@ def __call__( # Diffrax options saveat: dfx.SaveAt = default_saveat, event: dfx.Event | None = default_event, - max_steps: int | None = default_max_steps, + max_steps: int | None | _MISSING_TYPE = MISSING, throw: bool = default_throw, progress_meter: dfx.AbstractProgressMeter[Any] = default_progress_meter, solver_state: PyTree[ArrayLike] | None = None, @@ -182,6 +190,9 @@ def __call__( using `VectorizedDenseInterpolation`. """ + # Parse `max_steps`, allowing for it to be overridden. + max_steps = self.max_steps if max_steps is MISSING else max_steps + # Solve with `diffrax.diffeqsolve`, using the `DiffEqSolver`'s `solver`, # `stepsize_controller` and `adjoint`. soln: dfx.Solution = dfx.diffeqsolve( @@ -258,7 +269,8 @@ def from_( DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), - adjoint=RecursiveCheckpointAdjoint(checkpoints=None) + adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + max_steps=4096 ) """ @@ -280,7 +292,8 @@ def from_(cls: type[DiffEqSolver], obj: Mapping[str, Any], /) -> DiffEqSolver: DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=PIDController( ... ), - adjoint=RecursiveCheckpointAdjoint(checkpoints=None) + adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + max_steps=4096 ) """ @@ -304,7 +317,8 @@ def from_(cls: type[DiffEqSolver], obj: eqx.Partial, /) -> DiffEqSolver: DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), - adjoint=RecursiveCheckpointAdjoint(checkpoints=None) + adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + max_steps=4096 ) """