Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 120 additions & 8 deletions src/diffraxtra/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

__all__ = ["AbstractVectorizedDenseInterpolation", "VectorizedDenseInterpolation"]

from collections.abc import Callable
from collections.abc import Callable, Mapping
from functools import partial
from typing import Any, TypeAlias, cast, final
from typing_extensions import override
Expand All @@ -16,6 +16,7 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Float, Int, PyTree, Real, Shaped
from plum import dispatch

BatchedRealTimes: TypeAlias = Real[Array, "{self.batch_shape} times"]
BatchedRealScalar: TypeAlias = Real[Array, "{self.batch_shape}"]
Expand Down Expand Up @@ -313,25 +314,101 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation):

#: The batch shape of the interpolation without vectorization over the
#: solver that produced this interpolation. E.g.
batch_shape: Shape
batch_shape: Shape = eqx.field(converter=tuple)

#: The shape of the solution.
y0_shape: PyTree[Shape, "Y"]

def __init__(self, interp: dfx.DenseInterpolation, /) -> None:
# # Store the batch shape
self.batch_shape = jnp.shape(interp.t0_if_trivial)
self.y0_shape = jax.tree.map(
lambda x: x.shape[self.batch_ndim :], interp.y0_if_trivial
def __init__(
self,
scalar_interpolation: dfx.DenseInterpolation,
batch_shape: Shape | None = None,
y0_shape: PyTree[Shape, "Y"] | None = None, # type: ignore[name-defined]
) -> None:
# Store the batch shape
self.batch_shape = (
jnp.shape(scalar_interpolation.t0_if_trivial)
if batch_shape is None
else batch_shape
)

# Store the shape of the solution
self.y0_shape = (
jax.tree.map(
lambda x: x.shape[self.batch_ndim :], scalar_interpolation.y0_if_trivial
)
if y0_shape is None
else y0_shape
)

# Flatten the batch shape of the interpolation
self.scalar_interpolation = jax.tree.map(
lambda x: x.reshape(-1, *x.shape[self.batch_ndim :]),
interp,
scalar_interpolation,
is_leaf=eqx.is_array,
)

@classmethod
@dispatch.abstract # type: ignore[misc]
def from_(
cls: "type[VectorizedDenseInterpolation]", *args: Any, **kw: Any
) -> "VectorizedDenseInterpolation":
"""Construct a `VectorizedDenseInterpolation` from arguments.

Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> soln = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)

This can be constructed from a `diffrax.DenseInterpolation`:

>>> interp = VectorizedDenseInterpolation.from_(soln.interpolation)
>>> interp
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation( ... ),
batch_shape=(), y0_shape=()
)

Or from a `VectorizedDenseInterpolation`, returning the same object:

>>> VectorizedDenseInterpolation.from_(interp) is interp
True

The `batch_shape` and `y0_shape` can be specified manually:

>>> interp = VectorizedDenseInterpolation.from_(
... soln.interpolation, (), ())
>>> interp
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation( ... ),
batch_shape=(), y0_shape=()
)

Everything can be packaged in a `Mapping`:

>>> interp = VectorizedDenseInterpolation.from_(
... {"scalar_interpolation": soln.interpolation})
>>> interp
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation( ... ),
batch_shape=(), y0_shape=()
)

"""
raise NotImplementedError # pragma: no cover

# =======================
# Convenience methods

Expand All @@ -350,3 +427,38 @@ def apply_to_solution(cls, soln: dfx.Solution, /) -> dfx.Solution:
lambda tree: tree.interpolation, soln, cls(soln.interpolation)
)
return vec_soln


# ===================================================================


@VectorizedDenseInterpolation.from_.dispatch
def from_(
_: type[VectorizedDenseInterpolation], obj: VectorizedDenseInterpolation, /
) -> VectorizedDenseInterpolation:
"""Construct from a `VectorizedDenseInterpolation`.

This is a no-op.

"""
return obj


@VectorizedDenseInterpolation.from_.dispatch # type: ignore[no-redef]
def from_(
cls: type[VectorizedDenseInterpolation], obj: Mapping[str, Any], /
) -> VectorizedDenseInterpolation:
"""Construct from a `Mapping`."""
return cls(**obj)


@VectorizedDenseInterpolation.from_.dispatch # type: ignore[no-redef]
def from_(
cls: type[VectorizedDenseInterpolation],
obj: dfx.DenseInterpolation,
batch_shape: Any | None = None,
y0_shape: Any | None = None,
/,
) -> VectorizedDenseInterpolation:
"""Construct from a `diffrax.DenseInterpolation`."""
return cls(obj, batch_shape, y0_shape)