From 732aaf8b0fb8724e5e5bf836a37e542feb6a1658 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 4 Feb 2025 22:18:29 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20VectorizedDenseInterp?= =?UTF-8?q?olation=20constructor=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nathaniel Starkman --- src/diffraxtra/interp.py | 128 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 8 deletions(-) diff --git a/src/diffraxtra/interp.py b/src/diffraxtra/interp.py index 0409da6..77b51ad 100644 --- a/src/diffraxtra/interp.py +++ b/src/diffraxtra/interp.py @@ -6,7 +6,7 @@ __all__ = ["AbstractVectorizedDenseInterpolation", "VectorizedDenseInterpolation"] -from collections.abc import Callable +from collections.abc import Callable, Mapping from functools import partial from typing import Any, TypeAlias, cast, final from typing_extensions import override @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike, Float, Int, PyTree, Real, Shaped +from plum import dispatch BatchedRealTimes: TypeAlias = Real[Array, "{self.batch_shape} times"] BatchedRealScalar: TypeAlias = Real[Array, "{self.batch_shape}"] @@ -313,25 +314,101 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation): #: The batch shape of the interpolation without vectorization over the #: solver that produced this interpolation. E.g. - batch_shape: Shape + batch_shape: Shape = eqx.field(converter=tuple) #: The shape of the solution. y0_shape: PyTree[Shape, "Y"] - def __init__(self, interp: dfx.DenseInterpolation, /) -> None: - # # Store the batch shape - self.batch_shape = jnp.shape(interp.t0_if_trivial) - self.y0_shape = jax.tree.map( - lambda x: x.shape[self.batch_ndim :], interp.y0_if_trivial + def __init__( + self, + scalar_interpolation: dfx.DenseInterpolation, + batch_shape: Shape | None = None, + y0_shape: PyTree[Shape, "Y"] | None = None, # type: ignore[name-defined] + ) -> None: + # Store the batch shape + self.batch_shape = ( + jnp.shape(scalar_interpolation.t0_if_trivial) + if batch_shape is None + else batch_shape + ) + + # Store the shape of the solution + self.y0_shape = ( + jax.tree.map( + lambda x: x.shape[self.batch_ndim :], scalar_interpolation.y0_if_trivial + ) + if y0_shape is None + else y0_shape ) # Flatten the batch shape of the interpolation self.scalar_interpolation = jax.tree.map( lambda x: x.reshape(-1, *x.shape[self.batch_ndim :]), - interp, + scalar_interpolation, is_leaf=eqx.is_array, ) + @classmethod + @dispatch.abstract # type: ignore[misc] + def from_( + cls: "type[VectorizedDenseInterpolation]", *args: Any, **kw: Any + ) -> "VectorizedDenseInterpolation": + """Construct a `VectorizedDenseInterpolation` from arguments. + + Examples + -------- + >>> import jax + >>> import jax.numpy as jnp + >>> import diffrax as dfx + + >>> vector_field = lambda t, y, args: -y + >>> term = dfx.ODETerm(vector_field) + >>> solver = dfx.Dopri5() + >>> ts = jnp.array([0.0, 1, 2, 3]) + >>> saveat = dfx.SaveAt(ts=ts, dense=True) + >>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5) + + >>> soln = dfx.diffeqsolve( + ... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat, + ... stepsize_controller=stepsize_controller) + + This can be constructed from a `diffrax.DenseInterpolation`: + + >>> interp = VectorizedDenseInterpolation.from_(soln.interpolation) + >>> interp + VectorizedDenseInterpolation( + scalar_interpolation=DenseInterpolation( ... ), + batch_shape=(), y0_shape=() + ) + + Or from a `VectorizedDenseInterpolation`, returning the same object: + + >>> VectorizedDenseInterpolation.from_(interp) is interp + True + + The `batch_shape` and `y0_shape` can be specified manually: + + >>> interp = VectorizedDenseInterpolation.from_( + ... soln.interpolation, (), ()) + >>> interp + VectorizedDenseInterpolation( + scalar_interpolation=DenseInterpolation( ... ), + batch_shape=(), y0_shape=() + ) + + Everything can be packaged in a `Mapping`: + + >>> interp = VectorizedDenseInterpolation.from_( + ... {"scalar_interpolation": soln.interpolation}) + >>> interp + VectorizedDenseInterpolation( + scalar_interpolation=DenseInterpolation( ... ), + batch_shape=(), y0_shape=() + ) + + """ + raise NotImplementedError # pragma: no cover + # ======================= # Convenience methods @@ -350,3 +427,38 @@ def apply_to_solution(cls, soln: dfx.Solution, /) -> dfx.Solution: lambda tree: tree.interpolation, soln, cls(soln.interpolation) ) return vec_soln + + +# =================================================================== + + +@VectorizedDenseInterpolation.from_.dispatch +def from_( + _: type[VectorizedDenseInterpolation], obj: VectorizedDenseInterpolation, / +) -> VectorizedDenseInterpolation: + """Construct from a `VectorizedDenseInterpolation`. + + This is a no-op. + + """ + return obj + + +@VectorizedDenseInterpolation.from_.dispatch # type: ignore[no-redef] +def from_( + cls: type[VectorizedDenseInterpolation], obj: Mapping[str, Any], / +) -> VectorizedDenseInterpolation: + """Construct from a `Mapping`.""" + return cls(**obj) + + +@VectorizedDenseInterpolation.from_.dispatch # type: ignore[no-redef] +def from_( + cls: type[VectorizedDenseInterpolation], + obj: dfx.DenseInterpolation, + batch_shape: Any | None = None, + y0_shape: Any | None = None, + /, +) -> VectorizedDenseInterpolation: + """Construct from a `diffrax.DenseInterpolation`.""" + return cls(obj, batch_shape, y0_shape)