From 2afa8afe4cca7098b409c96b8a6faeaf52d4a838 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 3 Mar 2025 13:10:51 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20fix-simple:=20remove=20call=20di?= =?UTF-8?q?spatch=20precedence?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nathaniel Starkman --- src/diffraxtra/_src/diffeq_abc.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/diffraxtra/_src/diffeq_abc.py b/src/diffraxtra/_src/diffeq_abc.py index 7141f39..55a9be2 100644 --- a/src/diffraxtra/_src/diffeq_abc.py +++ b/src/diffraxtra/_src/diffeq_abc.py @@ -160,19 +160,6 @@ def __call__( return soln - @dispatch(precedence=-1) # type: ignore[no-redef] - @partial(eqx.filter_jit) - def __call__( - self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any - ) -> dfx.Solution: - """Solve a differential equation, with keyword arguments.""" - t0 = kwargs.pop("t0") - t1 = kwargs.pop("t1") - dt0 = kwargs.pop("dt0") - y0 = kwargs.pop("y0") - args = kwargs.pop("args", None) - return self(terms, t0, t1, dt0, y0, args, **kwargs) - # ------------------------------------------- # TODO: a contextmanager for producing a temporary DiffEqSolver with @@ -190,6 +177,22 @@ def from_( # ========================================================== +@AbstractDiffEqSolver.__call__.dispatch # type: ignore[attr-defined,misc] +@partial(eqx.filter_jit) +def call(self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any) -> dfx.Solution: + """Solve a differential equation, with keyword arguments.""" + t0 = kwargs.pop("t0") + t1 = kwargs.pop("t1") + dt0 = kwargs.pop("dt0") + y0 = kwargs.pop("y0") + args = kwargs.pop("args", None) + out: dfx.Solution = self(terms, t0, t1, dt0, y0, args, **kwargs) # type: ignore[assignment, call-arg] + return out + + +# ========================================================== + + @AbstractDiffEqSolver.from_.dispatch def from_( cls: type[AbstractDiffEqSolver], obj: AbstractDiffEqSolver, /