Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ jobs:
run: uv python install ${{ matrix.python-version }}

- name: Install the project
run: uv sync --group test --resolution lowest-direct
run: uv sync --group test --resolution lowest

- name: Test package
run: >-
Expand Down
11 changes: 6 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ci:

repos:
- repo: https://github.com/commitizen-tools/commitizen
rev: v4.1.0
rev: v4.4.1
hooks:
- id: commitizen
additional_dependencies: [cz-conventional-gitmoji]
Expand Down Expand Up @@ -44,14 +44,14 @@ repos:
args: [--prose-wrap=always]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.8.6"
rev: "v0.9.9"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.14.1"
rev: "v1.15.0"
hooks:
- id: mypy
files: src
Expand All @@ -61,9 +61,10 @@ repos:
- equinox
- jaxtyping
- numpy
- plum-dispatch

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
rev: "v2.4.1"
hooks:
- id: codespell

Expand All @@ -87,7 +88,7 @@ repos:
# additional_dependencies: ["validate-pyproject-schema-store[all]"]

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: "0.30.0"
rev: "0.31.2"
hooks:
- id: check-dependabot
- id: check-github-workflows
Expand Down
17 changes: 5 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,28 +173,21 @@ From a `diffrax.AbstractSolver` object.
```pycon
>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
)
DiffEqSolver(solver=Dopri5())

```

(Where all other arguments are their default values and printed only if
changed.)

From a `collections.abc.Mapping`

```pycon
>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
... "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)

```
Expand Down
32 changes: 27 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
dynamic = ["version"]
dependencies = [
"diffrax>=0.6",
"plum-dispatch>=2.5.1",
"equinox>=0.11.5",
"jaxtyping>=0.2.35",
"plum-dispatch>=2.5.7",
"typing_extensions>=4.12.2",
]

Expand All @@ -50,11 +52,11 @@
]
nox = ["nox>=2024.10.9"]
test = [
"attrs",
"pytest >=6",
"attrs >=25.1",
"pytest >=8,<8.1",
"pytest-cov >=3",
"sybil",
]
"sybil >=9",
]

[tool.hatch]
version.source = "vcs"
Expand All @@ -72,6 +74,7 @@
filterwarnings = [
"error",
"ignore:jax.core.Primitive is deprecated:DeprecationWarning",
"ignore:jax.core.pp_eqn_rules is deprecated:DeprecationWarning",
]
log_cli_level = "INFO"
testpaths = [
Expand Down Expand Up @@ -143,3 +146,22 @@ src = ["src"]
"diffrax" = "dfx"
"equinox" = "eqx"
"numpy" = "np"


[tool.uv]
constraint-dependencies = [
# Because IPyKernel doesn't constrain its dependencies
"appnope >=0.1.4",
"cffi >=1.17",
"decorator >=5.1",
"psutil >=6.1.1",
"py >=1.11",
# Matplotlib
"matplotlib >=3.10",
"pillow>=10.3.0",
# Jax
"scipy >=1.14",
"jax>0.4.34",
# Misc
"opt-einsum >=3.3.0",
]
44 changes: 23 additions & 21 deletions src/diffraxtra/_src/diffeq_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ class AbstractDiffEqSolver(eqx.Module, strict=True):
#: so you can override the `max_steps` argument when calling `DiffEqSolver`
max_steps: eqx.AbstractVar[int | None]

@partial(eqx.filter_jit)
# -------------------------------------------

# @partial(quax.quaxify) # TODO: so don't need to strip units
@dispatch
@partial(eqx.filter_jit)
def __call__(
self: "AbstractDiffEqSolver",
terms: PyTree[dfx.AbstractTerm],
Expand Down Expand Up @@ -157,11 +160,26 @@ def __call__(

return soln

@dispatch(precedence=-1) # type: ignore[no-redef]
@partial(eqx.filter_jit)
def __call__(
self: "AbstractDiffEqSolver", terms: Any, /, **kwargs: Any
) -> dfx.Solution:
"""Solve a differential equation, with keyword arguments."""
t0 = kwargs.pop("t0")
t1 = kwargs.pop("t1")
dt0 = kwargs.pop("dt0")
y0 = kwargs.pop("y0")
args = kwargs.pop("args", None)
return self(terms, t0, t1, dt0, y0, args, **kwargs)

# -------------------------------------------

# TODO: a contextmanager for producing a temporary DiffEqSolver with
# different field values.

@classmethod
@dispatch.abstract # type: ignore[misc]
@dispatch.abstract
def from_(
cls: "type[AbstractDiffEqSolver]", *args: Any, **kwargs: Any
) -> "AbstractDiffEqSolver":
Expand Down Expand Up @@ -212,13 +230,7 @@ def from_(

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
)
DiffEqSolver(solver=Dopri5())

"""
return cls(scheme, **kwargs)
Expand All @@ -239,11 +251,7 @@ def from_(
... "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)

"""
Expand All @@ -264,13 +272,7 @@ def from_(cls: type[AbstractDiffEqSolver], obj: eqx.Partial, /) -> AbstractDiffE

>>> solver = DiffEqSolver.from_(partial)
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
)
DiffEqSolver(solver=Dopri5())

"""
obj = eqx.error_if(
Expand Down
10 changes: 5 additions & 5 deletions src/diffraxtra/_src/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ def evaluate(
return ys # noqa: RET504

@override
@property # type: ignore[override]
def t0(self) -> BatchedRealScalar:
@property
def t0(self) -> BatchedRealScalar: # type: ignore[override]
"""The start time of the interpolation."""
flatt0 = jax.vmap(lambda x: x.t0)(self.scalar_interpolation)
return cast(BatchedRealScalar, flatt0.reshape(*self.batch_shape))

@override
@property # type: ignore[override]
def t1(self) -> BatchedRealScalar:
@property
def t1(self) -> BatchedRealScalar: # type: ignore[override]
"""The end time of the interpolation."""
flatt1 = jax.vmap(lambda x: x.t1)(self.scalar_interpolation)
return cast(BatchedRealScalar, flatt1.reshape(*self.batch_shape))
Expand Down Expand Up @@ -349,7 +349,7 @@ def __init__(
)

@classmethod
@dispatch.abstract # type: ignore[misc]
@dispatch.abstract
def from_(
cls: "type[VectorizedDenseInterpolation]", *args: Any, **kw: Any
) -> "VectorizedDenseInterpolation":
Expand Down
Loading