From 2833ad4153f1d6950e79ad112b2642ab038a6e05 Mon Sep 17 00:00:00 2001 From: Quentin Blampey Date: Wed, 27 Aug 2025 09:21:41 +0200 Subject: [PATCH 01/12] add min max --- src/fast_array_utils/stats/__init__.py | 24 ++++ src/fast_array_utils/stats/_min_max.py | 190 +++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 src/fast_array_utils/stats/_min_max.py diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 60dcae6..9f1bae4 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -266,3 +266,27 @@ def sum( validate_axis(x.ndim, axis) return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) + + +def min( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: + from ._min_max import min_max + + return min_max("min", x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + + +def max( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: + from ._min_max import min_max + + return min_max("max", x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_min_max.py b/src/fast_array_utils/stats/_min_max.py new file mode 100644 index 0000000..7397e02 --- /dev/null +++ b/src/fast_array_utils/stats/_min_max.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from functools import partial, singledispatch +from typing import TYPE_CHECKING, cast, overload + +import numpy as np +from numpy.exceptions import AxisError + +from .. import types +from . 
import validate_axis + + +if TYPE_CHECKING: + from typing import Any, Literal, TypeAlias + + from numpy.typing import NDArray + + from ..typing import CpuArray, DiskArray, GpuArray + + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + + MinMaxOps = Literal["max", "min"] + + +@overload +def min_max(ops: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +@overload +def min_max(ops: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... + + +@overload +def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... + + +@overload +def min_max(ops: MinMaxOps, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
+ + +def min_max( + ops: MinMaxOps, + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: + from ._min_max import min_max_ + + validate_axis(x.ndim, axis) + return min_max_(ops, x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + + +@singledispatch +def min_max_( + op: MinMaxOps, + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + del keep_cupy_as_array + if TYPE_CHECKING: + # these are never passed to this fallback function, but `singledispatch` wants them + assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix) + # np supports these, but doesn’t know it. (TODO: test cupy) + assert not isinstance(x, types.ZarrArray | types.H5Dataset) + return cast("NDArray[Any] | np.number[Any]", getattr(np, op)(x, axis=axis)) + + +@min_max_.register(types.CupyArray | types.CupyCSMatrix) +def _min_max_cupy( + op: MinMaxOps, + x: GpuArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> types.CupyArray | np.number[Any]: + arr = cast("types.CupyArray", getattr(np, op)(x, axis=axis)) + return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() + + +@min_max_.register(types.CSBase) +def _min_max_cs( + op: MinMaxOps, + x: types.CSBase, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | np.number[Any]: + del keep_cupy_as_array + import scipy.sparse as sp + + if isinstance(x, types.CSMatrix): + x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) + + if axis is None: + return cast("np.number[Any]", getattr(sp, op)(x)) + return cast("NDArray[Any] | np.number[Any]", getattr(sp, op)(x, 
axis=axis)) + + +@min_max_.register(types.DaskArray) +def _min_max_dask( + op: MinMaxOps, + x: types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, +) -> types.DaskArray: + import dask.array as da + + if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 + msg = "sum/max/min does not support numpy matrices" + raise TypeError(msg) + + rv = da.reduction( + x, + partial(min_max_dask_inner, op), # type: ignore[arg-type] + partial(min_max_dask_inner, op), # pyright: ignore[reportArgumentType] + axis=axis, + meta=np.array([], dtype=x.dtype), + ) + + if axis is not None or ( + isinstance(rv._meta, types.CupyArray) # noqa: SLF001 + and keep_cupy_as_array + ): + return rv + + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + if isinstance(a, types.CupyArray): + a = a.get() + return a.reshape(())[()] # type: ignore[return-value] + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] + + +def min_max_dask_inner( + op: MinMaxOps, + a: CpuArray | GpuArray, + /, + *, + axis: ComplexAxis = None, + keepdims: bool = False, +) -> NDArray[Any] | types.CupyArray: + axis = normalize_axis(axis, a.ndim) + rv = min_max(op, a, axis=axis, keep_cupy_as_array=True) # type: ignore[misc,arg-type] + shape = get_shape(rv, axis=axis, keepdims=keepdims) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) + + +def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: + """Adapt `axis` parameter passed by Dask to what we support.""" + match axis: + case int() | None: + pass + case (0 | 1,): + axis = axis[0] + case (0, 1) | (1, 0): + axis = None + case _: # pragma: no cover + raise AxisError(axis, ndim) # type: ignore[call-overload] + if axis == 0 and ndim == 1: + return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays + return axis + + +def get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], 
keepdims: bool) -> tuple[int] | tuple[int, int]: + """Get the output shape of an axis-flattening operation.""" + match keepdims, a.ndim: + case False, 0: + return (1,) + case True, 0: + return (1, 1) + case False, 1: + return (a.size,) + case True, 1: + assert axis is not None + return (1, a.size) if axis == 0 else (a.size, 1) + # pragma: no cover + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) From ddc450bd595796b6bf0b7cc48114e7143601beb8 Mon Sep 17 00:00:00 2001 From: Quentin Blampey Date: Wed, 27 Aug 2025 09:32:49 +0200 Subject: [PATCH 02/12] rename ops -> op --- src/fast_array_utils/stats/_min_max.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/fast_array_utils/stats/_min_max.py b/src/fast_array_utils/stats/_min_max.py index 7397e02..93a62aa 100644 --- a/src/fast_array_utils/stats/_min_max.py +++ b/src/fast_array_utils/stats/_min_max.py @@ -24,25 +24,25 @@ @overload -def min_max(ops: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +def min_max(op: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... @overload -def min_max(ops: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +def min_max(op: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... @overload -def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... @overload -def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... 
+def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... @overload -def min_max(ops: MinMaxOps, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... +def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... @overload -def min_max(ops: MinMaxOps, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... +def min_max(op: MinMaxOps, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... def min_max( - ops: MinMaxOps, + op: MinMaxOps, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, @@ -52,7 +52,7 @@ def min_max( from ._min_max import min_max_ validate_axis(x.ndim, axis) - return min_max_(ops, x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + return min_max_(op, x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) @singledispatch From 15347f7c1d741de4e867790c430b8c7058f049d9 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 28 Aug 2025 17:22:01 +0200 Subject: [PATCH 03/12] deduplicate and test --- src/fast_array_utils/stats/__init__.py | 39 ++++---- src/fast_array_utils/stats/_mean.py | 5 +- src/fast_array_utils/stats/_min_max.py | 128 +++---------------------- src/fast_array_utils/stats/_sum.py | 81 +--------------- src/fast_array_utils/stats/_typing.py | 43 +++++++++ src/fast_array_utils/stats/_utils.py | 102 ++++++++++++++++++++ tests/test_stats.py | 23 +---- 7 files changed, 187 insertions(+), 234 deletions(-) create mode 100644 src/fast_array_utils/stats/_typing.py create mode 100644 src/fast_array_utils/stats/_utils.py diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 9f1bae4..3723a16 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, cast, overload from .._validation import validate_axis from ..typing import CpuArray, DiskArray, GpuArray # noqa: TC001 @@ -21,6 +21,8 @@ from optype.numpy import ToDType from .. 
import types + from ._min_max import MinMaxOps + from ._typing import StatFun __all__ = ["is_constant", "mean", "mean_var", "sum"] @@ -131,7 +133,7 @@ def mean( from ._mean import mean_ validate_axis(x.ndim, axis) - return mean_(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] # literally the same type, wtf mypy + return mean_(x, axis=axis, dtype=dtype) @overload @@ -268,25 +270,22 @@ def sum( return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) -def min( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: - from ._min_max import min_max +def _mk_min_max(op: MinMaxOps) -> StatFun: + def _min_max( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + keep_cupy_as_array: bool = False, + ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + from ._min_max import min_max - return min_max("min", x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + validate_axis(x.ndim, axis) + return min_max(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + _min_max.__name__ = op + return cast("StatFun", _min_max) -def max( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: - from ._min_max import min_max - return min_max("max", x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) +min = _mk_min_max("min") +max = _mk_min_max("max") diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py index f6dd63f..d697a16 100644 --- a/src/fast_array_utils/stats/_mean.py +++ b/src/fast_array_utils/stats/_mean.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from typing import TYPE_CHECKING, no_type_check 
+from typing import TYPE_CHECKING import numpy as np @@ -17,7 +17,6 @@ from ..typing import CpuArray, DiskArray, GpuArray -@no_type_check # mypy is very confused def mean_( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, @@ -27,4 +26,4 @@ def mean_( ) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray: total = sum_(x, axis=axis, dtype=dtype) n = np.prod(x.shape) if axis is None else x.shape[axis] - return total / n + return total / n # type: ignore[call-overload,operator,return-value] diff --git a/src/fast_array_utils/stats/_min_max.py b/src/fast_array_utils/stats/_min_max.py index 93a62aa..240aac5 100644 --- a/src/fast_array_utils/stats/_min_max.py +++ b/src/fast_array_utils/stats/_min_max.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import partial, singledispatch -from typing import TYPE_CHECKING, cast, overload +from functools import singledispatch +from typing import TYPE_CHECKING, cast import numpy as np -from numpy.exceptions import AxisError from .. import types -from . import validate_axis +from ._utils import _dask_inner if TYPE_CHECKING: @@ -23,43 +22,11 @@ MinMaxOps = Literal["max", "min"] -@overload -def min_max(op: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... -@overload -def min_max(op: MinMaxOps, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... - - -@overload -def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... -@overload -def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... -@overload -def min_max(op: MinMaxOps, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... 
- - -@overload -def min_max(op: MinMaxOps, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... - - +@singledispatch def min_max( - op: MinMaxOps, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: - from ._min_max import min_max_ - - validate_axis(x.ndim, axis) - return min_max_(op, x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) - - -@singledispatch -def min_max_( op: MinMaxOps, - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False, @@ -73,11 +40,11 @@ def min_max_( return cast("NDArray[Any] | np.number[Any]", getattr(np, op)(x, axis=axis)) -@min_max_.register(types.CupyArray | types.CupyCSMatrix) +@min_max.register(types.CupyArray | types.CupyCSMatrix) def _min_max_cupy( - op: MinMaxOps, x: GpuArray, /, + op: MinMaxOps, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False, @@ -86,11 +53,11 @@ def _min_max_cupy( return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() -@min_max_.register(types.CSBase) +@min_max.register(types.CSBase) def _min_max_cs( - op: MinMaxOps, x: types.CSBase, /, + op: MinMaxOps, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False, @@ -106,85 +73,16 @@ def _min_max_cs( return cast("NDArray[Any] | np.number[Any]", getattr(sp, op)(x, axis=axis)) -@min_max_.register(types.DaskArray) +@min_max.register(types.DaskArray) def _min_max_dask( - op: MinMaxOps, x: types.DaskArray, /, + op: MinMaxOps, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False, ) -> types.DaskArray: - import dask.array as da - - if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 - msg = "sum/max/min does not support numpy matrices" - raise TypeError(msg) + 
from . import max, min - rv = da.reduction( - x, - partial(min_max_dask_inner, op), # type: ignore[arg-type] - partial(min_max_dask_inner, op), # pyright: ignore[reportArgumentType] - axis=axis, - meta=np.array([], dtype=x.dtype), - ) - - if axis is not None or ( - isinstance(rv._meta, types.CupyArray) # noqa: SLF001 - and keep_cupy_as_array - ): - return rv - - def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: - if isinstance(a, types.CupyArray): - a = a.get() - return a.reshape(())[()] # type: ignore[return-value] - - return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] - - -def min_max_dask_inner( - op: MinMaxOps, - a: CpuArray | GpuArray, - /, - *, - axis: ComplexAxis = None, - keepdims: bool = False, -) -> NDArray[Any] | types.CupyArray: - axis = normalize_axis(axis, a.ndim) - rv = min_max(op, a, axis=axis, keep_cupy_as_array=True) # type: ignore[misc,arg-type] - shape = get_shape(rv, axis=axis, keepdims=keepdims) - return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) - - -def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: - """Adapt `axis` parameter passed by Dask to what we support.""" - match axis: - case int() | None: - pass - case (0 | 1,): - axis = axis[0] - case (0, 1) | (1, 0): - axis = None - case _: # pragma: no cover - raise AxisError(axis, ndim) # type: ignore[call-overload] - if axis == 0 and ndim == 1: - return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays - return axis - - -def get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool) -> tuple[int] | tuple[int, int]: - """Get the output shape of an axis-flattening operation.""" - match keepdims, a.ndim: - case False, 0: - return (1,) - case True, 0: - return (1, 1) - case False, 1: - return (a.size,) - case True, 1: - assert axis is not None - return (1, a.size) if axis == 0 else (a.size, 1) - # pragma: no cover - msg = f"{keepdims=}, 
{type(a)}" - raise AssertionError(msg) + fns = {fn.__name__: fn for fn in (min, max)} + return _dask_inner(x, fns[op], axis=axis, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index e5eb711..9c78bb7 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import partial, singledispatch +from functools import singledispatch from typing import TYPE_CHECKING, Literal, cast import numpy as np -from numpy.exceptions import AxisError from .. import types +from ._utils import _dask_inner if TYPE_CHECKING: @@ -80,83 +80,10 @@ def _sum_dask( dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> types.DaskArray: - import dask.array as da - - if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 - msg = "sum does not support numpy matrices" - raise TypeError(msg) + from . import sum if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype - rv = da.reduction( - x, - sum_dask_inner, # type: ignore[arg-type] - partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] - axis=axis, - dtype=dtype, - meta=np.array([], dtype=dtype), - ) - - if axis is not None or ( - isinstance(rv._meta, types.CupyArray) # noqa: SLF001 - and keep_cupy_as_array - ): - return rv - - def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: - if isinstance(a, types.CupyArray): - a = a.get() - return a.reshape(())[()] # type: ignore[return-value] - - return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] - - -def sum_dask_inner( - a: CpuArray | GpuArray, - /, - *, - axis: ComplexAxis = None, - dtype: DTypeLike | None = None, - keepdims: bool = False, -) -> NDArray[Any] | types.CupyArray: - from . 
import sum - - axis = normalize_axis(axis, a.ndim) - rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] - shape = get_shape(rv, axis=axis, keepdims=keepdims) - return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) - - -def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: - """Adapt `axis` parameter passed by Dask to what we support.""" - match axis: - case int() | None: - pass - case (0 | 1,): - axis = axis[0] - case (0, 1) | (1, 0): - axis = None - case _: # pragma: no cover - raise AxisError(axis, ndim) # type: ignore[call-overload] - if axis == 0 and ndim == 1: - return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays - return axis - - -def get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool) -> tuple[int] | tuple[int, int]: - """Get the output shape of an axis-flattening operation.""" - match keepdims, a.ndim: - case False, 0: - return (1,) - case True, 0: - return (1, 1) - case False, 1: - return (a.size,) - case True, 1: - assert axis is not None - return (1, a.size) if axis == 0 else (a.size, 1) - # pragma: no cover - msg = f"{keepdims=}, {type(a)}" - raise AssertionError(msg) + return _dask_inner(x, sum, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py new file mode 100644 index 0000000..ce13e33 --- /dev/null +++ b/src/fast_array_utils/stats/_typing.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Protocol, overload + +import numpy as np + +from fast_array_utils import types + +from ..typing import CpuArray, DiskArray, GpuArray + + +if TYPE_CHECKING: + from typing import Any, TypeAlias + + from numpy.typing import NDArray + + +Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray + 
+DTypeIn: TypeAlias = np.float32 | np.float64 | np.int32 | np.bool_ +DTypeOut: TypeAlias = np.float32 | np.float64 | np.int64 + +NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] + + +class StatFun(Protocol): + __name__: str + + @overload + def __call__(self, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... + @overload + def __call__(self, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... + + @overload + def __call__(self, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... + @overload + def __call__(self, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... + @overload + def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... + + @overload + def __call__(self, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py new file mode 100644 index 0000000..a2eb555 --- /dev/null +++ b/src/fast_array_utils/stats/_utils.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Literal, cast + +import numpy as np +from numpy.exceptions import AxisError + +from .. 
import types + + +if TYPE_CHECKING: + from typing import Any, Literal, TypeAlias + + from numpy.typing import DTypeLike, NDArray + + from ..typing import CpuArray, GpuArray + from ._typing import StatFun + + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + + +__all__ = ["_dask_inner"] + + +def _dask_inner(x: types.DaskArray, op: StatFun, /, *, axis: Literal[0, 1, None], dtype: DTypeLike | None = None, keep_cupy_as_array: bool) -> types.DaskArray: + import dask.array as da + + if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 + msg = "sum/max/min does not support numpy matrices" + raise TypeError(msg) + + rv = da.reduction( + x, + partial(_dask_block, op), + partial(_dask_block, op, dtype=dtype), + axis=axis, + dtype=dtype, + meta=np.array([], dtype=dtype), + ) + + if axis is not None or ( + isinstance(rv._meta, types.CupyArray) # noqa: SLF001 + and keep_cupy_as_array + ): + return rv + + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + if isinstance(a, types.CupyArray): + a = a.get() + return a.reshape(())[()] # type: ignore[return-value] + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] + + +def _dask_block( + op: StatFun, + a: CpuArray | GpuArray, + /, + *, + axis: ComplexAxis = None, + dtype: DTypeLike | None = None, + keepdims: bool = False, +) -> NDArray[Any] | types.CupyArray: + axis = _normalize_axis(axis, a.ndim) + rv = op(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload] + shape = _get_shape(rv, axis=axis, keepdims=keepdims) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) + + +def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: + """Adapt `axis` parameter passed by Dask to what we support.""" + match axis: + case int() | None: + pass + case (0 | 1,): + axis = axis[0] + case (0, 1) | (1, 0): + axis = None + case _: # pragma: no cover + raise AxisError(axis, 
ndim) # type: ignore[call-overload] + if axis == 0 and ndim == 1: + return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays + return axis + + +def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool) -> tuple[int] | tuple[int, int]: + """Get the output shape of an axis-flattening operation.""" + match keepdims, a.ndim: + case False, 0: + return (1,) + case True, 0: + return (1, 1) + case False, 1: + return (a.size,) + case True, 1: + assert axis is not None + return (1, a.size) if axis == 0 else (a.size, 1) + # pragma: no cover + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) diff --git a/tests/test_stats.py b/tests/test_stats.py index 8edb826..ac3fdab 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -14,35 +14,20 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing import Any, Literal, Protocol, TypeAlias + from typing import Any, Literal from numpy.typing import NDArray from pytest_codspeed import BenchmarkFixture + from fast_array_utils.stats._typing import Array, DTypeIn, DTypeOut, NdAndAx, StatFun from fast_array_utils.typing import CpuArray, DiskArray, GpuArray from testing.fast_array_utils import ArrayType - Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray - - DTypeIn = np.float32 | np.float64 | np.int32 | np.bool_ - DTypeOut = np.float32 | np.float64 | np.int64 - - NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] - - class StatFun(Protocol): # noqa: D101 - def __call__( # noqa: D102 - self, - arr: Array, - *, - axis: Literal[0, 1, None] = None, - dtype: type[DTypeOut] | None = None, - ) -> NDArray[Any] | np.number[Any] | types.DaskArray: ... 
- pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")] -STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant] +STAT_FUNCS = [stats.sum, stats.min, stats.max, stats.mean, stats.mean_var, stats.is_constant] # can’t select these using a category filter ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"} @@ -314,4 +299,4 @@ def test_stats_benchmark( arr = array_type.random(shape, dtype=dtype) func(arr, axis=axis) # warmup: numba compile - benchmark(func, arr, axis=axis) + benchmark(func, arr, axis=axis) # type: ignore[arg-type] From fb6e14e356588435646c79a4c9d96e05ab7cef93 Mon Sep 17 00:00:00 2001 From: Quentin Blampey Date: Mon, 6 Oct 2025 11:50:36 +0200 Subject: [PATCH 04/12] deduplication using generic function for sum/max/min --- src/fast_array_utils/stats/__init__.py | 108 ++++-------------- .../stats/{_sum.py => _generic_ops.py} | 63 ++++++---- src/fast_array_utils/stats/_mean.py | 6 +- src/fast_array_utils/stats/_min_max.py | 88 -------------- src/fast_array_utils/stats/_typing.py | 9 +- src/fast_array_utils/stats/_utils.py | 27 +++-- 6 files changed, 91 insertions(+), 210 deletions(-) rename src/fast_array_utils/stats/{_sum.py => _generic_ops.py} (53%) delete mode 100644 src/fast_array_utils/stats/_min_max.py diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 3723a16..6e57921 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -7,10 +7,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, cast, get_args, overload from .._validation import validate_axis from ..typing import CpuArray, DiskArray, GpuArray # noqa: TC001 +from ._generic_ops import DtypeOps if TYPE_CHECKING: @@ -21,11 +22,11 @@ from optype.numpy import ToDType from .. 
import types - from ._min_max import MinMaxOps + from ._generic_ops import Ops from ._typing import StatFun -__all__ = ["is_constant", "mean", "mean_var", "sum"] +__all__ = ["is_constant", "max", "mean", "mean_var", "min", "sum"] @overload @@ -35,14 +36,14 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND @overload def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ... @overload -def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray: ... +def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None) -> types.DaskArray: ... def is_constant( x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray, /, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, ) -> bool | NDArray[np.bool_] | types.CupyArray | types.DaskArray: """Check whether values in array are constant. @@ -82,7 +83,7 @@ def is_constant( # TODO(flying-sheep): support CSDataset (TODO) # https://github.com/scverse/fast-array-utils/issues/52 @overload -def mean(x: CpuArray | GpuArray | DiskArray, /, *, axis: Literal[None] = None, dtype: DTypeLike | None = None) -> np.number[Any]: ... +def mean(x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None) -> np.number[Any]: ... @overload def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> NDArray[np.number[Any]]: ... @overload @@ -95,7 +96,7 @@ def mean( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, ) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray: """Mean over both or one axis. @@ -137,20 +138,20 @@ def mean( @overload -def mean_var(x: CpuArray | GpuArray, /, *, axis: Literal[None] = None, correction: int = 0) -> tuple[np.float64, np.float64]: ... 
+def mean_var(x: CpuArray | GpuArray, /, *, axis: None = None, correction: int = 0) -> tuple[np.float64, np.float64]: ... @overload def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ... @overload def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ... @overload -def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ... +def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ... def mean_var( x: CpuArray | GpuArray | types.DaskArray, /, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, correction: int = 0, ) -> ( tuple[np.float64, np.float64] @@ -205,87 +206,26 @@ def mean_var( # TODO(flying-sheep): support CSDataset (TODO) # https://github.com/scverse/fast-array-utils/issues/52 -@overload -def sum(x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... -@overload -def sum(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> NDArray[Any]: ... - - -@overload -def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... -@overload -def sum(x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... -@overload -def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... - - -@overload -def sum(x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
- - -def sum( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, - *, - axis: Literal[0, 1, None] = None, - dtype: DTypeLike | None = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: - """Sum over both or one axis. - - Parameters - ---------- - x - Array to sum up. - axis - Axis to reduce over. - - Returns - ------- - If ``axis`` is :data:`None`, then the sum over all elements is returned as a scalar. - Otherwise, the sum over the given axis is returned as a 1D array. - - Example - ------- - >>> import numpy as np - >>> x = np.array([ - ... [0, 1, 2], - ... [0, 0, 0], - ... ]) - >>> sum(x) - 3 - >>> sum(x, axis=0) - array([0, 1, 2]) - >>> sum(x, axis=1) - array([3, 0]) - - See Also - -------- - :func:`numpy.sum` - - """ - from ._sum import sum_ - - validate_axis(x.ndim, axis) - return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) - - -def _mk_min_max(op: MinMaxOps) -> StatFun: - def _min_max( +def _mk_generic_op(op: Ops) -> StatFun: + def _generic_op( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: - from ._min_max import min_max + from ._generic_ops import generic_op + + assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation '{op}'" validate_axis(x.ndim, axis) - return min_max(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype) - _min_max.__name__ = op - return cast("StatFun", _min_max) + _generic_op.__name__ = op + return cast("StatFun", _generic_op) -min = _mk_min_max("min") -max = _mk_min_max("max") +min = _mk_generic_op("min") +max = _mk_generic_op("max") +sum = _mk_generic_op("sum") diff --git 
a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_generic_ops.py similarity index 53% rename from src/fast_array_utils/stats/_sum.py rename to src/fast_array_utils/stats/_generic_ops.py index 9c78bb7..cdcd485 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -2,11 +2,12 @@ from __future__ import annotations from functools import singledispatch -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, cast, get_args import numpy as np from .. import types +from ._typing import DtypeOps from ._utils import _dask_inner @@ -16,16 +17,29 @@ from numpy.typing import DTypeLike, NDArray from ..typing import CpuArray, DiskArray, GpuArray + from ._typing import Ops - ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None + + +def _run_numpy_op( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, +) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} + return getattr(np, op)(x, axis=axis, **kwargs) @singledispatch -def sum_( +def generic_op( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, + op: Ops, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: @@ -33,30 +47,32 @@ def sum_( if TYPE_CHECKING: # these are never passed to this fallback function, but `singledispatch` wants them assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix) - # np.sum supports these, but doesn’t know it. (TODO: test cupy) + # np supports these, but doesn’t know it. 
(TODO: test cupy) assert not isinstance(x, types.ZarrArray | types.H5Dataset) - return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype)) + return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype)) -@sum_.register(types.CupyArray | types.CupyCSMatrix) -def _sum_cupy( +@generic_op.register(types.CupyArray | types.CupyCSMatrix) +def _generic_op_cupy( x: GpuArray, /, + op: Ops, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> types.CupyArray | np.number[Any]: - arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype)) + arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype)) return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() -@sum_.register(types.CSBase) -def _sum_cs( +@generic_op.register(types.CSBase) +def _generic_op_cs( x: types.CSBase, /, + op: Ops, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any]: @@ -66,24 +82,25 @@ def _sum_cs( if isinstance(x, types.CSMatrix): x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) + kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} + if axis is None: - return cast("np.number[Any]", x.data.sum(dtype=dtype)) - return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype)) + return cast("np.number[Any]", getattr(sp, op)(x, **kwargs)) + return cast("NDArray[Any] | np.number[Any]", getattr(sp, op)(x, axis=axis, **kwargs)) -@sum_.register(types.DaskArray) -def _sum_dask( +@generic_op.register(types.DaskArray) +def _generic_op_dask( x: types.DaskArray, /, + op: Ops, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False, ) -> types.DaskArray: - from . 
import sum - - if dtype is None: + if op in get_args(DtypeOps) and dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) - dtype = np.zeros(1, dtype=x.dtype).sum().dtype + dtype = getattr(np, op)(np.zeros(1, dtype=x.dtype)).dtype - return _dask_inner(x, sum, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) + return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py index d697a16..db4e352 100644 --- a/src/fast_array_utils/stats/_mean.py +++ b/src/fast_array_utils/stats/_mean.py @@ -5,7 +5,7 @@ import numpy as np -from ._sum import sum_ +from . import sum if TYPE_CHECKING: @@ -21,9 +21,9 @@ def mean_( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, - axis: Literal[0, 1, None] = None, + axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, ) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray: - total = sum_(x, axis=axis, dtype=dtype) + total = sum(x, axis=axis, dtype=dtype) n = np.prod(x.shape) if axis is None else x.shape[axis] return total / n # type: ignore[call-overload,operator,return-value] diff --git a/src/fast_array_utils/stats/_min_max.py b/src/fast_array_utils/stats/_min_max.py deleted file mode 100644 index 240aac5..0000000 --- a/src/fast_array_utils/stats/_min_max.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: MPL-2.0 -from __future__ import annotations - -from functools import singledispatch -from typing import TYPE_CHECKING, cast - -import numpy as np - -from .. 
import types -from ._utils import _dask_inner - - -if TYPE_CHECKING: - from typing import Any, Literal, TypeAlias - - from numpy.typing import NDArray - - from ..typing import CpuArray, DiskArray, GpuArray - - ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] - - MinMaxOps = Literal["max", "min"] - - -@singledispatch -def min_max( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, - op: MinMaxOps, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: - del keep_cupy_as_array - if TYPE_CHECKING: - # these are never passed to this fallback function, but `singledispatch` wants them - assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix) - # np supports these, but doesn’t know it. (TODO: test cupy) - assert not isinstance(x, types.ZarrArray | types.H5Dataset) - return cast("NDArray[Any] | np.number[Any]", getattr(np, op)(x, axis=axis)) - - -@min_max.register(types.CupyArray | types.CupyCSMatrix) -def _min_max_cupy( - x: GpuArray, - /, - op: MinMaxOps, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> types.CupyArray | np.number[Any]: - arr = cast("types.CupyArray", getattr(np, op)(x, axis=axis)) - return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() - - -@min_max.register(types.CSBase) -def _min_max_cs( - x: types.CSBase, - /, - op: MinMaxOps, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | np.number[Any]: - del keep_cupy_as_array - import scipy.sparse as sp - - if isinstance(x, types.CSMatrix): - x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) - - if axis is None: - return cast("np.number[Any]", getattr(sp, op)(x)) - return cast("NDArray[Any] | np.number[Any]", getattr(sp, op)(x, axis=axis)) - - 
-@min_max.register(types.DaskArray) -def _min_max_dask( - x: types.DaskArray, - /, - op: MinMaxOps, - *, - axis: Literal[0, 1, None] = None, - keep_cupy_as_array: bool = False, -) -> types.DaskArray: - from . import max, min - - fns = {fn.__name__: fn for fn in (min, max)} - return _dask_inner(x, fns[op], axis=axis, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py index ce13e33..fa48d8c 100644 --- a/src/fast_array_utils/stats/_typing.py +++ b/src/fast_array_utils/stats/_typing.py @@ -21,7 +21,7 @@ DTypeIn: TypeAlias = np.float32 | np.float64 | np.int32 | np.bool_ DTypeOut: TypeAlias = np.float32 | np.float64 | np.int64 -NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] +NdAndAx: TypeAlias = tuple[Literal[1], None] | tuple[Literal[2], Literal[0, 1] | None] class StatFun(Protocol): @@ -40,4 +40,9 @@ def __call__(self, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[Tr def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... @overload - def __call__(self, x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... + def __call__(self, x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... + + +NoDtypeOps = Literal["max", "min"] +DtypeOps = Literal["sum"] +Ops: TypeAlias = NoDtypeOps | DtypeOps diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py index a2eb555..5509324 100644 --- a/src/fast_array_utils/stats/_utils.py +++ b/src/fast_array_utils/stats/_utils.py @@ -2,12 +2,13 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal, cast, get_args import numpy as np from numpy.exceptions import AxisError from .. 
import types +from ._typing import DtypeOps if TYPE_CHECKING: @@ -16,28 +17,30 @@ from numpy.typing import DTypeLike, NDArray from ..typing import CpuArray, GpuArray - from ._typing import StatFun + from ._typing import Ops - ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None __all__ = ["_dask_inner"] -def _dask_inner(x: types.DaskArray, op: StatFun, /, *, axis: Literal[0, 1, None], dtype: DTypeLike | None = None, keep_cupy_as_array: bool) -> types.DaskArray: +def _dask_inner(x: types.DaskArray, op: Ops, /, *, axis: Literal[0, 1] | None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool) -> types.DaskArray: import dask.array as da if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 msg = "sum/max/min does not support numpy matrices" raise TypeError(msg) + res_dtype = dtype if op in get_args(DtypeOps) else x.dtype + rv = da.reduction( x, partial(_dask_block, op), partial(_dask_block, op, dtype=dtype), axis=axis, - dtype=dtype, - meta=np.array([], dtype=dtype), + dtype=res_dtype, + meta=np.array([], dtype=res_dtype), ) if axis is not None or ( @@ -55,7 +58,7 @@ def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: def _dask_block( - op: StatFun, + op: Ops, a: CpuArray | GpuArray, /, *, @@ -63,13 +66,17 @@ def _dask_block( dtype: DTypeLike | None = None, keepdims: bool = False, ) -> NDArray[Any] | types.CupyArray: + from . 
import max, min, sum + + fns = {fn.__name__: fn for fn in (min, max, sum)} + axis = _normalize_axis(axis, a.ndim) - rv = op(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload] + rv = fns[op](a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload] shape = _get_shape(rv, axis=axis, keepdims=keepdims) return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) -def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: +def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None: """Adapt `axis` parameter passed by Dask to what we support.""" match axis: case int() | None: @@ -85,7 +92,7 @@ def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: return axis -def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool) -> tuple[int] | tuple[int, int]: +def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1] | None, keepdims: bool) -> tuple[int] | tuple[int, int]: """Get the output shape of an axis-flattening operation.""" match keepdims, a.ndim: case False, 0: From f99bcd3e0aca4f9e6a3749ca1e31f61294014312 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 11:04:57 +0200 Subject: [PATCH 05/12] correct statfun --- src/fast_array_utils/stats/__init__.py | 12 ++++++++--- src/fast_array_utils/stats/_typing.py | 29 ++++++++++++++++++++++++-- tests/test_stats.py | 6 +++--- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 2791976..027bfa5 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -23,7 +23,7 @@ from .. 
import types from ._generic_ops import Ops - from ._typing import StatFun + from ._typing import NoDtypeOps, StatFunDtype, StatFunNoDtype __all__ = ["is_constant", "max", "mean", "mean_var", "min", "sum"] @@ -204,9 +204,15 @@ def mean_var( return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return] +@overload +def _mk_generic_op(op: NoDtypeOps) -> StatFunNoDtype: ... +@overload +def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ... + + # TODO(flying-sheep): support CSDataset (TODO) # https://github.com/scverse/fast-array-utils/issues/52 -def _mk_generic_op(op: Ops) -> StatFun: +def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype: def _generic_op( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, @@ -223,7 +229,7 @@ def _generic_op( return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype) _generic_op.__name__ = op - return cast("StatFun", _generic_op) + return cast("StatFunNoDtype | StatFunDtype", _generic_op) min = _mk_generic_op("min") diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py index fa48d8c..318b905 100644 --- a/src/fast_array_utils/stats/_typing.py +++ b/src/fast_array_utils/stats/_typing.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from typing import Any, TypeAlias - from numpy.typing import NDArray + from numpy.typing import DTypeLike, NDArray Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray @@ -24,7 +24,7 @@ NdAndAx: TypeAlias = tuple[Literal[1], None] | tuple[Literal[2], Literal[0, 1] | None] -class StatFun(Protocol): +class StatFunNoDtype(Protocol): __name__: str @overload @@ -43,6 +43,31 @@ def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: b def __call__(self, x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
+class StatFunDtype(Protocol): + __name__: str + + @overload + def __call__( + self, x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False + ) -> np.number[Any]: ... + @overload + def __call__( + self, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False + ) -> NDArray[Any]: ... + + @overload + def __call__(self, x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... + @overload + def __call__(self, x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... + @overload + def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... + + @overload + def __call__( + self, x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False + ) -> types.DaskArray: ... 
+ + NoDtypeOps = Literal["max", "min"] DtypeOps = Literal["sum"] Ops: TypeAlias = NoDtypeOps | DtypeOps diff --git a/tests/test_stats.py b/tests/test_stats.py index c910490..94c2310 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray from pytest_codspeed import BenchmarkFixture - from fast_array_utils.stats._typing import Array, DTypeIn, DTypeOut, NdAndAx, StatFun + from fast_array_utils.stats._typing import Array, DTypeIn, DTypeOut, NdAndAx, StatFunNoDtype from fast_array_utils.typing import CpuArray, DiskArray, GpuArray from testing.fast_array_utils import ArrayType @@ -92,7 +92,7 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: @pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix}) @pytest.mark.parametrize("func", STAT_FUNCS) @pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]) -def test_ndim_error(array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1] | None) -> None: +def test_ndim_error(array_type: ArrayType[Array], func: StatFunNoDtype, ndim: Literal[1, 2], axis: Literal[0, 1] | None) -> None: check_ndim(array_type, ndim) # not using the fixture because we don’t need to test multiple dtypes np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) @@ -290,7 +290,7 @@ def test_dask_constant_blocks(dask_viz: Callable[[object], None], array_type: Ar @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_stats_benchmark( benchmark: BenchmarkFixture, - func: StatFun, + func: StatFunNoDtype, array_type: ArrayType[CpuArray, None], axis: Literal[0, 1] | None, dtype: type[np.float32 | np.float64], From dc43d44567e78751a380fc5760e83f301fee0899 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 11:27:02 +0200 Subject: [PATCH 06/12] fix sparse tests --- src/fast_array_utils/stats/_generic_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py index cdcd485..b94f506 100644 --- a/src/fast_array_utils/stats/_generic_ops.py +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -30,7 +30,7 @@ def _run_numpy_op( dtype: DTypeLike | None = None, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} - return getattr(np, op)(x, axis=axis, **kwargs) + return getattr(np, op)(x, axis=axis, **kwargs) # type: ignore[no-any-return] @singledispatch @@ -85,8 +85,8 @@ def _generic_op_cs( kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} if axis is None: - return cast("np.number[Any]", getattr(sp, op)(x, **kwargs)) - return cast("NDArray[Any] | np.number[Any]", getattr(sp, op)(x, axis=axis, **kwargs)) + return cast("np.number[Any]", getattr(x, op)(**kwargs)) + return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis, **kwargs)) @generic_op.register(types.DaskArray) From 51e8e3c9ab97fceb401dcb26321263640b26ad55 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Wed, 15 Oct 2025 18:28:12 +0200 Subject: [PATCH 07/12] fix performance --- src/fast_array_utils/stats/_generic_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py index cd20e8b..62c3876 100644 --- a/src/fast_array_utils/stats/_generic_ops.py +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -85,7 +85,7 @@ def _generic_op_cs( kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} if axis is None: - return cast("np.number[Any]", getattr(x, op)(**kwargs)) + return cast("np.number[Any]", getattr(x.data, op)(**kwargs)) if TYPE_CHECKING: # scipy-stubs thinks e.g. 
"int64" is invalid, which isn’t true assert isinstance(dtype, np.dtype | type | None) # convert to array so dimensions collapse as expected From 05aac3b149450bc015ecc56c11e7fa2bdd52fb6c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 28 Oct 2025 10:00:55 +0100 Subject: [PATCH 08/12] mypy is useless --- src/fast_array_utils/stats/__init__.py | 172 ++++++++++++++++++++- src/fast_array_utils/stats/_generic_ops.py | 2 +- src/fast_array_utils/stats/_mean.py | 4 +- src/fast_array_utils/stats/_typing.py | 4 +- 4 files changed, 174 insertions(+), 8 deletions(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 027bfa5..c6c2f78 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -232,6 +232,172 @@ def _generic_op( return cast("StatFunNoDtype | StatFunDtype", _generic_op) -min = _mk_generic_op("min") -max = _mk_generic_op("max") -sum = _mk_generic_op("sum") +_min = _mk_generic_op("min") +_max = _mk_generic_op("max") +_sum = _mk_generic_op("sum") + + +@overload +def min(x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +@overload +def min(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +@overload +def min(x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... +@overload +def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
+def min(
+    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
+    /,
+    *,
+    axis: Literal[0, 1] | None = None,
+    keep_cupy_as_array: bool = False,
+) -> object:
+    """Find the minimum along both or one axis.
+
+    Parameters
+    ----------
+    x
+        Array to find the minimum(s) in.
+    axis
+        Axis to reduce over.
+
+    Returns
+    -------
+    If ``axis`` is :data:`None`, then the minimum element is returned as a scalar.
+    Otherwise, the minimum along the given axis is returned as a 1D array.
+
+    Example
+    -------
+    >>> import numpy as np
+    >>> x = np.array([
+    ...     [0, 1, 2],
+    ...     [1, 1, 1],
+    ... ])
+    >>> min(x)
+    0
+    >>> min(x, axis=0)
+    array([0, 1, 1])
+    >>> min(x, axis=1)
+    array([0, 1])
+
+    See Also
+    --------
+    :func:`numpy.min`
+
+    """
+    return _min(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array)  # type: ignore[misc,arg-type]
+
+
+@overload
+def max(x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ...
+@overload
+def max(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ...
+@overload
+def max(x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ...
+@overload
+def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ...
+@overload
+def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
+@overload
+def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
+def max(
+    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
+    /,
+    *,
+    axis: Literal[0, 1] | None = None,
+    keep_cupy_as_array: bool = False,
+) -> object:
+    """Find the maximum along both or one axis.
+
+    Parameters
+    ----------
+    x
+        Array to find the maximum(s) in.
+    axis
+        Axis to reduce over.
+
+    Returns
+    -------
+    If ``axis`` is :data:`None`, then the maximum element is returned as a scalar.
+ Otherwise, the maximum along the given axis is returned as a 1D array. + + Example + ------- + >>> import numpy as np + >>> x = np.array([ + ... [0, 1, 2], + ... [0, 0, 0], + ... ]) + >>> max(x) + 2 + >>> sum(x, axis=0) + array([0, 1, 2]) + >>> sum(x, axis=1) + array([2, 0]) + + See Also + -------- + :func:`numpy.max` + + """ + return _max(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) # type: ignore[misc,arg-type] + + +@overload +def sum(x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +@overload +def sum(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +@overload +def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... +@overload +def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... +def sum( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: + """Sum over both or one axis. + + Parameters + ---------- + x + Array to sum up. + axis + Axis to reduce over. + + Returns + ------- + If ``axis`` is :data:`None`, then the sum over all elements is returned as a scalar. + Otherwise, the sum over the given axis is returned as a 1D array. + + Example + ------- + >>> import numpy as np + >>> x = np.array([ + ... 
[0, 1, 2], + ... [0, 0, 0], + ... ]) + >>> sum(x) + 3 + >>> sum(x, axis=0) + array([0, 1, 2]) + >>> sum(x, axis=1) + array([3, 0]) + + See Also + -------- + :func:`numpy.sum` + + """ + return _sum(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) # type: ignore[misc,arg-type] diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py index 62c3876..d342517 100644 --- a/src/fast_array_utils/stats/_generic_ops.py +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -89,7 +89,7 @@ def _generic_op_cs( if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true assert isinstance(dtype, np.dtype | type | None) # convert to array so dimensions collapse as expected - x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs) + x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs) # type: ignore[call-overload] return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis)) diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py index cd754ef..ba08164 100644 --- a/src/fast_array_utils/stats/_mean.py +++ b/src/fast_array_utils/stats/_mean.py @@ -24,6 +24,6 @@ def mean_( axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, ) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray: - total = sum(x, axis=axis, dtype=dtype) + total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type] n = np.prod(x.shape) if axis is None else x.shape[axis] - return total / n # type: ignore[operator,return-value] + return total / n # type: ignore[no-any-return] diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py index 318b905..6187aa6 100644 --- a/src/fast_array_utils/stats/_typing.py +++ b/src/fast_array_utils/stats/_typing.py @@ -35,7 +35,7 @@ def __call__(self, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy @overload def __call__(self, x: GpuArray, /, *, axis: None = 
None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... @overload - def __call__(self, x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... + def __call__(self, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... @overload def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... @@ -58,7 +58,7 @@ def __call__( @overload def __call__(self, x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... @overload - def __call__(self, x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... + def __call__(self, x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... @overload def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... From eb06f9bb0fac84e7514234073d8da40fdc1407aa Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 28 Oct 2025 10:45:34 +0100 Subject: [PATCH 09/12] oops --- src/fast_array_utils/stats/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index c6c2f78..2b69d9a 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -334,9 +334,9 @@ def max( ... ]) >>> max(x) 2 - >>> sum(x, axis=0) + >>> max(x, axis=0) array([0, 1, 2]) - >>> sum(x, axis=1) + >>> max(x, axis=1) array([2, 0]) See Also From 17e3b7c99662b2cdfcd93594cda6d99c2cdbbd17 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 28 Oct 2025 10:52:05 +0100 Subject: [PATCH 10/12] simplify --- src/fast_array_utils/stats/__init__.py | 6 ++-- src/fast_array_utils/stats/_typing.py | 46 +++++++------------------- tests/test_stats.py | 2 +- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 2b69d9a..e525220 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -289,7 +289,7 @@ def min( :func:`numpy.min` """ - return _min(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) # type: ignore[misc,arg-type] + return _min(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) @overload @@ -344,7 +344,7 @@ def max( :func:`numpy.max` """ - return _max(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) # type: ignore[misc,arg-type] + return _max(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) @overload @@ -400,4 +400,4 @@ def sum( :func:`numpy.sum` """ - return _sum(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) # type: ignore[misc,arg-type] + return _sum(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py index 6187aa6..e8b0b65 100644 --- a/src/fast_array_utils/stats/_typing.py +++ b/src/fast_array_utils/stats/_typing.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from typing import TYPE_CHECKING, Literal, Protocol, overload +from typing import TYPE_CHECKING, Literal, Protocol import numpy as np @@ -27,45 +27,23 @@ class StatFunNoDtype(Protocol): __name__: str - @overload - def __call__(self, x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... - @overload - def __call__(self, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... 
- - @overload - def __call__(self, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... - @overload - def __call__(self, x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... - @overload - def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... - - @overload - def __call__(self, x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... + def __call__( + self, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False + ) -> types.DaskArray: ... class StatFunDtype(Protocol): __name__: str - @overload - def __call__( - self, x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False - ) -> np.number[Any]: ... - @overload def __call__( - self, x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False - ) -> NDArray[Any]: ... - - @overload - def __call__(self, x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... - @overload - def __call__(self, x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... - @overload - def __call__(self, x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... - - @overload - def __call__( - self, x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False - ) -> types.DaskArray: ... 
+ self, + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, + ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: ... NoDtypeOps = Literal["max", "min"] diff --git a/tests/test_stats.py b/tests/test_stats.py index b508714..849df4d 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -358,4 +358,4 @@ def test_stats_benchmark( arr = array_type.random(shape, dtype=dtype) func(arr, axis=axis) # warmup: numba compile - benchmark(func, arr, axis=axis) # type: ignore[arg-type] + benchmark(func, arr, axis=axis) From e35a4367e29db2ea27605689aa6f2dffcdab9373 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 28 Oct 2025 10:54:32 +0100 Subject: [PATCH 11/12] oops2 --- src/fast_array_utils/stats/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index e525220..e87bf29 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -278,7 +278,7 @@ def min( ... [1, 1, 1], ... ]) >>> min(x) - 2 + 0 >>> min(x, axis=0) array([0, 1, 1]) >>> min(x, axis=1) From 9a30eb11f420a0466d423267bcf389c5a17d97d4 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 28 Oct 2025 11:02:19 +0100 Subject: [PATCH 12/12] coverage --- src/fast_array_utils/stats/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py index cd634c7..7d1fcbb 100644 --- a/src/fast_array_utils/stats/_utils.py +++ b/src/fast_array_utils/stats/_utils.py @@ -79,7 +79,7 @@ def _dask_block( def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None: """Adapt `axis` parameter passed by Dask to what we support.""" match axis: - case int() | None: + case int() | None: # pragma: no cover pass case (0 | 1,): axis = axis[0] @@ -104,6 +104,6 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite case True, 1: assert axis is not None return (1, a.size) if axis == 0 else (a.size, 1) - # pragma: no cover - msg = f"{keepdims=}, {type(a)}" - raise AssertionError(msg) + case _: # pragma: no cover + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg)