diff --git a/src/diffraxtra/_src/diffeq_abc.py b/src/diffraxtra/_src/diffeq_abc.py index 7141f39..55a9be2 100644 --- a/src/diffraxtra/_src/diffeq_abc.py +++ b/src/diffraxtra/_src/diffeq_abc.py @@ -160,19 +160,6 @@ def __call__( return soln - @dispatch(precedence=-1) # type: ignore[no-redef] - @partial(eqx.filter_jit) - def __call__( - self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any - ) -> dfx.Solution: - """Solve a differential equation, with keyword arguments.""" - t0 = kwargs.pop("t0") - t1 = kwargs.pop("t1") - dt0 = kwargs.pop("dt0") - y0 = kwargs.pop("y0") - args = kwargs.pop("args", None) - return self(terms, t0, t1, dt0, y0, args, **kwargs) - # ------------------------------------------- # TODO: a contextmanager for producing a temporary DiffEqSolver with @@ -190,6 +177,22 @@ def from_( # ========================================================== +@AbstractDiffEqSolver.__call__.dispatch # type: ignore[attr-defined,misc] +@partial(eqx.filter_jit) +def call(self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any) -> dfx.Solution: + """Solve a differential equation, with keyword arguments.""" + t0 = kwargs.pop("t0") + t1 = kwargs.pop("t1") + dt0 = kwargs.pop("dt0") + y0 = kwargs.pop("y0") + args = kwargs.pop("args", None) + out: dfx.Solution = self(terms, t0, t1, dt0, y0, args, **kwargs) # type: ignore[assignment, call-arg] + return out + + +# ========================================================== + + @AbstractDiffEqSolver.from_.dispatch def from_( cls: type[AbstractDiffEqSolver], obj: AbstractDiffEqSolver, /