diff --git a/README.md b/README.md index a910636..6c22775 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,7 @@ DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + event=None, max_steps=4096 ) @@ -192,6 +193,7 @@ DiffEqSolver( solver=Dopri5(scan_kind=None), stepsize_controller=PIDController( ... ), adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + event=None, max_steps=4096 ) diff --git a/src/diffraxtra/_src/diffeq.py b/src/diffraxtra/_src/diffeq.py index 6dca39f..a905376 100644 --- a/src/diffraxtra/_src/diffeq.py +++ b/src/diffraxtra/_src/diffeq.py @@ -122,6 +122,9 @@ class DiffEqSolver(AbstractDiffEqSolver, strict=True): #: See `diffrax` for options. adjoint: dfx.AbstractAdjoint = eqx.field(default=default_adjoint) + #: Event. Can override the `event` argument when calling `DiffEqSolver` + event: dfx.Event | None = None + #: The maximum number of steps to take before quitting. #: Some `diffrax.SaveAt` options can be incompatible with `max_steps=None`, #: so you can override the `max_steps` argument when calling `DiffEqSolver` diff --git a/src/diffraxtra/_src/diffeq_abc.py b/src/diffraxtra/_src/diffeq_abc.py index 1f8640f..8747d53 100644 --- a/src/diffraxtra/_src/diffeq_abc.py +++ b/src/diffraxtra/_src/diffeq_abc.py @@ -68,15 +68,14 @@ class AbstractDiffEqSolver(eqx.Module, strict=True): #: See `diffrax` for options. adjoint: eqx.AbstractVar[dfx.AbstractAdjoint] + #: Event. Can override the `event` argument when calling `DiffEqSolver` + event: eqx.AbstractVar[dfx.Event | None] + #: The maximum number of steps to take before quitting. #: Some `diffrax.SaveAt` options can be incompatible with `max_steps=None`, #: so you can override the `max_steps` argument when calling `DiffEqSolver` max_steps: eqx.AbstractVar[int | None] - # TODO: should the event be a field? Again it can be overridden when calling - # `DiffEqSolver`. And should it be static? - # event: dfx.Event | None = eqx.field(default=default_event) # noqa: ERA001 - @partial(eqx.filter_jit) # @partial(quax.quaxify) # TODO: so don't need to strip units def __call__( @@ -91,7 +90,7 @@ def __call__( *, # Diffrax options saveat: dfx.SaveAt = default_saveat, - event: dfx.Event | None = default_event, + event: dfx.Event | None | _MISSING_TYPE = MISSING, max_steps: int | None | _MISSING_TYPE = MISSING, throw: bool = default_throw, progress_meter: dfx.AbstractProgressMeter[Any] = default_progress_meter, @@ -128,6 +127,8 @@ def __call__( """ # Parse `max_steps`, allowing for it to be overridden. max_steps = self.max_steps if max_steps is MISSING else max_steps + # Parse `event`, allowing for it to be overridden. + event = self.event if event is MISSING else event # Solve with `diffrax.diffeqsolve`, using the `DiffEqSolver`'s `solver`, # `stepsize_controller` and `adjoint`. @@ -215,6 +216,7 @@ def from_( solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + event=None, max_steps=4096 ) @@ -240,6 +242,7 @@ def from_( solver=Dopri5(scan_kind=None), stepsize_controller=PIDController( ... ), adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + event=None, max_steps=4096 ) @@ -265,6 +268,7 @@ def from_(cls: type[AbstractDiffEqSolver], obj: eqx.Partial, /) -> AbstractDiffE solver=Dopri5(scan_kind=None), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(checkpoints=None), + event=None, max_steps=4096 )