diff --git a/src/fast_array_utils/_validation.py b/src/fast_array_utils/_validation.py index b40f82b..51df2b1 100644 --- a/src/fast_array_utils/_validation.py +++ b/src/fast_array_utils/_validation.py @@ -9,7 +9,7 @@ def validate_axis(ndim: int, axis: int | None) -> None: if axis is None: return if not isinstance(axis, int | np.integer): # pragma: no cover - msg = "axis must be integer or None." + msg = f"axis must be integer or None, not {axis=!r}." raise TypeError(msg) if axis == 0 and ndim == 1: raise AxisError(axis, ndim, "use axis=None for 1D arrays") diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 30d132a..393f22c 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -227,19 +227,56 @@ def mean_var( # https://github.com/scverse/fast-array-utils/issues/52 @overload def sum( - x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None + x: CpuArray | DiskArray, + /, + *, + axis: None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> np.number[Any]: ... @overload def sum( - x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None + x: CpuArray | DiskArray, + /, + *, + axis: Literal[0, 1], + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any]: ... + + @overload def sum( - x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None + x: GpuArray, + /, + *, + axis: None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: Literal[False] = False, +) -> np.number[Any]: ... +@overload +def sum( + x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True] +) -> types.CupyArray: ... +@overload +def sum( + x: GpuArray, + /, + *, + axis: Literal[0, 1], + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.CupyArray: ... + + @overload def sum( - x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.DaskArray: ... @@ -249,6 +286,7 @@ def sum( *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: """Sum over both or one axis. @@ -286,4 +324,4 @@ def sum( from ._sum import sum_ validate_axis(x.ndim, axis) - return sum_(x, axis=axis, dtype=dtype) + return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 12f3b2b..d3652e4 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -2,20 +2,25 @@ from __future__ import annotations from functools import partial, singledispatch -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Literal, cast import numpy as np +from numpy.exceptions import AxisError from .. import types if TYPE_CHECKING: - from typing import Any, Literal + from typing import Any, Literal, TypeAlias from numpy.typing import DTypeLike, NDArray from ..typing import CpuArray, DiskArray, GpuArray + ComplexAxis: TypeAlias = ( + tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + ) + @singledispatch def sum_( @@ -24,7 +29,9 @@ def sum_( *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + del keep_cupy_as_array if TYPE_CHECKING: # these are never passed to this fallback function, but `singledispatch` wants them assert not isinstance( @@ -37,16 +44,31 @@ def sum_( @sum_.register(types.CupyArray | types.CupyCSMatrix) # type: ignore[call-overload,misc] def _sum_cupy( - x: GpuArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: GpuArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.CupyArray | np.number[Any]: arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype)) - return cast("np.number[Any]", arr.get()[()]) if axis is None else arr.squeeze() + return ( + cast("np.number[Any]", arr.get()[()]) + if not keep_cupy_as_array and axis is None + else arr.squeeze() + ) @sum_.register(types.CSBase) # type: ignore[call-overload,misc] def _sum_cs( - x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.CSBase, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any]: + del keep_cupy_as_array import scipy.sparse as sp if isinstance(x, types.CSMatrix): @@ -59,49 +81,92 @@ def _sum_cs( @sum_.register(types.DaskArray) def _sum_dask( - x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.DaskArray: import dask.array as da - from . import sum - if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 msg = "sum does not support numpy matrices" raise TypeError(msg) - def sum_drop_keepdims( - a: CpuArray, - /, - *, - axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None, - dtype: DTypeLike | None = None, - keepdims: bool = False, - ) -> NDArray[Any] | types.CupyArray: - del keepdims - if a.ndim == 1: - axis = None - else: - match axis: - case (0, 1) | (1, 0): - axis = None - case (0 | 1 as n,): - axis = n - case tuple(): # pragma: no cover - msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" - raise ValueError(msg) - rv = sum(a, axis=axis, dtype=dtype) - shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type] - return np.reshape(rv, shape) - if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype - return da.reduction( + rv = da.reduction( x, - sum_drop_keepdims, # type: ignore[arg-type] - partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType] + sum_dask_inner, # type: ignore[arg-type] + partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, meta=np.array([], dtype=dtype), ) + + if axis is not None or ( + isinstance(rv._meta, types.CupyArray) # noqa: SLF001 + and keep_cupy_as_array + ): + return rv + + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + if isinstance(a, types.CupyArray): + a = a.get() + return a.reshape(())[()] # type: ignore[return-value] + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] + + +def sum_dask_inner( + a: CpuArray | GpuArray, + /, + *, + axis: ComplexAxis = None, + dtype: DTypeLike | None = None, + keepdims: bool = False, +) -> NDArray[Any] | types.CupyArray: + from . import sum + + axis = normalize_axis(axis, a.ndim) + rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] + shape = get_shape(rv, axis=axis, keepdims=keepdims) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) + + +def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: + """Adapt `axis` parameter passed by Dask to what we support.""" + match axis: + case int() | None: + pass + case (0 | 1,): + axis = axis[0] + case (0, 1) | (1, 0): + axis = None + case _: # pragma: no cover + raise AxisError(axis, ndim) # type: ignore[call-overload] + if axis == 0 and ndim == 1: + return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays + return axis + + +def get_shape( + a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool +) -> tuple[int] | tuple[int, int]: + """Get the output shape of an axis-flattening operation.""" + match keepdims, a.ndim: + case False, 0: + return (1,) + case True, 0: + return (1, 1) + case False, 1: + return (a.size,) + case True, 1: + assert axis is not None + return (1, a.size) if axis == 0 else (a.size, 1) + # pragma: no cover + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) diff --git a/tests/test_stats.py b/tests/test_stats.py index dfe8fad..8f76dc2 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -97,7 +97,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: @pytest.fixture def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: - np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)) + np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)) np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() @@ -165,6 +165,28 @@ def test_sum( np.testing.assert_array_equal(sum_, expected) +@pytest.mark.parametrize( + "data", + [ + pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"), + pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"), + pytest.param([[1, 0], [0, 2]], id="2x2"), + ], +) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.array_type(Flags.Dask) +def test_sum_dask_shapes( + array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]] +) -> None: + np_arr = np.array(data, dtype=np.float32) + arr = array_type(np_arr) + assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes" + sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute()) + if isinstance(sum_, types.CupyArray): + sum_ = sum_.get() + np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_) + + @pytest.mark.array_type(skip=ATS_SPARSE_DS) def test_mean( array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi index 34c9135..ccd3874 100644 --- a/typings/cupy/_core/core.pyi +++ b/typings/cupy/_core/core.pyi @@ -8,6 +8,7 @@ from numpy.typing import NDArray class ndarray: dtype: np.dtype[Any] shape: tuple[int, ...] + size: int ndim: int # cupy-specific @@ -15,6 +16,7 @@ class ndarray: # operators def __array__(self) -> NDArray[Any]: ... + def __len__(self) -> int: ... def __getitem__( # never returns scalars self, index: int | slice | EllipsisType | tuple[int | slice | EllipsisType | None, ...] ) -> Self: ... @@ -28,6 +30,7 @@ class ndarray: def all(self, axis: None = None) -> np.bool: ... @overload def all(self, axis: int) -> ndarray: ... + def reshape(self, shape: tuple[int, ...] | int) -> ndarray: ... def squeeze(self, axis: int | None = None) -> Self: ... def ravel(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ... def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ... diff --git a/typings/dask/array/core.pyi b/typings/dask/array/core.pyi index 48f044c..741c6b7 100644 --- a/typings/dask/array/core.pyi +++ b/typings/dask/array/core.pyi @@ -44,6 +44,8 @@ class Array: # dask methods and attrs _meta: _Array blocks: BlockView + chunks: tuple[tuple[int, ...], ...] + chunksize: tuple[int, ...] def compute(self) -> _Array: ... def visualize(