From 0de248fe44665cef1801a00969305feeed7c09ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 May 2025 08:02:17 +0000 Subject: [PATCH 1/2] build(deps): bump the actions group with 2 updates Bumps the actions group with 2 updates: [astral-sh/setup-uv](https://github.com/astral-sh/setup-uv) and [codecov/codecov-action](https://github.com/codecov/codecov-action). Updates `astral-sh/setup-uv` from 5 to 6 - [Release notes](https://github.com/astral-sh/setup-uv/releases) - [Commits](https://github.com/astral-sh/setup-uv/compare/v5...v6) Updates `codecov/codecov-action` from 5.4.0 to 5.4.2 - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v5.4.0...v5.4.2) --- updated-dependencies: - dependency-name: astral-sh/setup-uv dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: codecov/codecov-action dependency-version: 5.4.2 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa7e3fc..1e7392e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 with: enable-cache: true cache-dependency-glob: "uv.lock" @@ -59,7 +59,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 with: enable-cache: true cache-dependency-glob: "uv.lock" @@ -81,7 +81,7 @@ jobs: --durations=20 - name: Upload coverage report - uses: codecov/codecov-action@v5.4.0 + uses: codecov/codecov-action@v5.4.2 with: token: ${{ secrets.CODECOV_TOKEN }} @@ -99,7 +99,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 with: enable-cache: true cache-dependency-glob: "uv.lock" @@ -116,6 +116,6 @@ jobs: --durations=20 - name: Upload coverage report - uses: codecov/codecov-action@v5.4.0 + uses: codecov/codecov-action@v5.4.2 with: token: ${{ secrets.CODECOV_TOKEN }} From 905012d9f9f582494d5a494077a1c09d456b1e69 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Fri, 2 May 2025 07:26:12 -0400 Subject: [PATCH 2/2] tests: doctest skip Signed-off-by: Nathaniel Starkman --- README.md | 2 +- src/diffraxtra/_src/diffeq_abc.py | 8 ++++---- src/diffraxtra/_src/interp.py | 11 ++++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f78742e..1f990eb 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,7 @@ We'll start with a non-batched interpolation: ... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat, ... stepsize_controller=stepsize_controller) >>> interp = VectorizedDenseInterpolation(sol.interpolation) ->>> interp +>>> interp # doctest: +SKIP VectorizedDenseInterpolation( scalar_interpolation=DenseInterpolation( ts=f64[1,4097], diff --git a/src/diffraxtra/_src/diffeq_abc.py b/src/diffraxtra/_src/diffeq_abc.py index 55a9be2..40a89cc 100644 --- a/src/diffraxtra/_src/diffeq_abc.py +++ b/src/diffraxtra/_src/diffeq_abc.py @@ -9,10 +9,10 @@ "params", ] +import functools as ft import inspect from collections.abc import Mapping from dataclasses import _MISSING_TYPE, KW_ONLY, MISSING -from functools import partial from typing import Any, TypeAlias import diffrax as dfx @@ -78,9 +78,9 @@ class AbstractDiffEqSolver(eqx.Module, strict=True): # ------------------------------------------- - # @partial(quax.quaxify) # TODO: so don't need to strip units + # @ft.partial(quax.quaxify) # TODO: so don't need to strip units @dispatch - @partial(eqx.filter_jit) + @ft.partial(eqx.filter_jit) def __call__( self: "AbstractDiffEqSolver", terms: PyTree[dfx.AbstractTerm], @@ -178,7 +178,7 @@ def from_( @AbstractDiffEqSolver.__call__.dispatch # type: ignore[attr-defined,misc] -@partial(eqx.filter_jit) +@ft.partial(eqx.filter_jit) def call(self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any) -> dfx.Solution: """Solve a differential equation, with keyword arguments.""" t0 = kwargs.pop("t0") diff --git a/src/diffraxtra/_src/interp.py b/src/diffraxtra/_src/interp.py index e0ae8cd..ffa1fb5 100644 --- a/src/diffraxtra/_src/interp.py +++ b/src/diffraxtra/_src/interp.py @@ -6,8 +6,8 @@ __all__ = ["AbstractVectorizedDenseInterpolation", "VectorizedDenseInterpolation"] +import functools as ft from collections.abc import Callable, Mapping -from functools import partial from typing import Any, TypeAlias, cast, final from typing_extensions import override @@ -95,7 +95,7 @@ def evaluate( # Evaluate the scalar interpolation over the batch dimension of the # interpolator and an array of times. ys = jax.vmap( # vmap over the batch dimension of the interpolator - lambda interp: jax.vmap(partial(interp.evaluate, left=left))(t0) + lambda interp: jax.vmap(ft.partial(interp.evaluate, left=left))(t0) )(self.scalar_interpolation) # Reshape the result to match the input shape in the time axes. @@ -185,7 +185,7 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation): ... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat, ... stepsize_controller=stepsize_controller) >>> interp = VectorizedDenseInterpolation(sol.interpolation) - >>> interp + >>> interp # doctest: +SKIP VectorizedDenseInterpolation( scalar_interpolation=DenseInterpolation( ts=f64[1,4097], @@ -264,9 +264,10 @@ class VectorizedDenseInterpolation(AbstractVectorizedDenseInterpolation): >>> ys.shape # (batch, *times) (3, 2, 2) - Let's inspect the rest of the API. + Let's inspect the rest of the API. First, the flattened) original + interpolation - >>> interp.scalar_interpolation # (flattened) original interpolation + >>> interp.scalar_interpolation # doctest: +SKIP DenseInterpolation( ts=f64[3,4097], ts_size=weak_i64[3],