Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
Expand All @@ -81,7 +81,7 @@ jobs:
--durations=20

- name: Upload coverage report
uses: codecov/codecov-action@v5.4.0
uses: codecov/codecov-action@v5.4.2
with:
token: ${{ secrets.CODECOV_TOKEN }}

Expand All @@ -99,7 +99,7 @@ jobs:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
Expand All @@ -116,6 +116,6 @@ jobs:
--durations=20

- name: Upload coverage report
uses: codecov/codecov-action@v5.4.0
uses: codecov/codecov-action@v5.4.2
with:
token: ${{ secrets.CODECOV_TOKEN }}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ We'll start with a non-batched interpolation:
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
>>> interp # doctest: +SKIP
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation(
ts=f64[1,4097],
Expand Down
8 changes: 4 additions & 4 deletions src/diffraxtra/_src/diffeq_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"params",
]

import functools as ft
import inspect
from collections.abc import Mapping
from dataclasses import _MISSING_TYPE, KW_ONLY, MISSING
from functools import partial
from typing import Any, TypeAlias

import diffrax as dfx
Expand Down Expand Up @@ -78,9 +78,9 @@ class AbstractDiffEqSolver(eqx.Module, strict=True):

# -------------------------------------------

# @partial(quax.quaxify) # TODO: so don't need to strip units
# @ft.partial(quax.quaxify) # TODO: so don't need to strip units
@dispatch
@partial(eqx.filter_jit)
@ft.partial(eqx.filter_jit)
def __call__(
self: "AbstractDiffEqSolver",
terms: PyTree[dfx.AbstractTerm],
Expand Down Expand Up @@ -178,7 +178,7 @@ def from_(


@AbstractDiffEqSolver.__call__.dispatch # type: ignore[attr-defined,misc]
@partial(eqx.filter_jit)
@ft.partial(eqx.filter_jit)
def call(self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any) -> dfx.Solution:
"""Solve a differential equation, with keyword arguments."""
t0 = kwargs.pop("t0")
Expand Down
11 changes: 6 additions & 5 deletions src/diffraxtra/_src/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

__all__ = ["AbstractVectorizedDenseInterpolation", "VectorizedDenseInterpolation"]

import functools as ft
from collections.abc import Callable, Mapping
from functools import partial
from typing import Any, TypeAlias, cast, final
from typing_extensions import override

Expand Down Expand Up @@ -95,7 +95,7 @@ def evaluate(
# Evaluate the scalar interpolation over the batch dimension of the
# interpolator and an array of times.
ys = jax.vmap( # vmap over the batch dimension of the interpolator
lambda interp: jax.vmap(partial(interp.evaluate, left=left))(t0)
lambda interp: jax.vmap(ft.partial(interp.evaluate, left=left))(t0)
)(self.scalar_interpolation)

# Reshape the result to match the input shape in the time axes.
Expand Down Expand Up @@ -185,7 +185,7 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation):
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
>>> interp # doctest: +SKIP
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation(
ts=f64[1,4097],
Expand Down Expand Up @@ -264,9 +264,10 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation):
>>> ys.shape # (batch, *times)
(3, 2, 2)

Let's inspect the rest of the API.
Let's inspect the rest of the API. First, the flattened) original
interpolation

>>> interp.scalar_interpolation # (flattened) original interpolation
>>> interp.scalar_interpolation # doctest: +SKIP
DenseInterpolation(
ts=f64[3,4097],
ts_size=weak_i64[3],
Expand Down