From aca984562f7fd93a73765d7377517e694251f53d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 25 Apr 2025 18:45:47 +0200 Subject: [PATCH 1/9] Add failing test --- src/fast_array_utils/stats/_sum.py | 5 ++++- tests/test_stats.py | 13 ++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 12f3b2b..f3fc1f9 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -97,10 +97,13 @@ def sum_drop_keepdims( # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype + def debug_sum(x, *, axis, dtype, keepdims): + return np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + return da.reduction( x, sum_drop_keepdims, # type: ignore[arg-type] - partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType] + partial(debug_sum, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, meta=np.array([], dtype=dtype), diff --git a/tests/test_stats.py b/tests/test_stats.py index dfe8fad..233ce4c 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -95,9 +95,16 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: return cast("type[DTypeOut] | None", request.param) -@pytest.fixture -def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: - np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)) +@pytest.fixture( + params=[ + pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"), + pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"), + ] +) +def np_arr( + request: pytest.FixtureRequest, dtype_in: type[DTypeIn], ndim: Literal[1, 2] +) -> NDArray[DTypeIn]: + np_arr = cast("NDArray[DTypeIn]", np.array(request.param, dtype=dtype_in)) np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() From 4aa1210819b18bc43077d24ea2b091d1063023ce Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 25 Apr 2025 18:54:57 +0200 Subject: [PATCH 2/9] Fix it --- src/fast_array_utils/stats/_sum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index f3fc1f9..4683cfa 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -90,7 +90,7 @@ def sum_drop_keepdims( msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" raise ValueError(msg) rv = sum(a, axis=axis, dtype=dtype) - shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type] + shape = (1,) if a.ndim == 1 else (1 if rv.shape == () else len(rv), 1) # type: ignore[arg-type] return np.reshape(rv, shape) if dtype is None: From b328898e0168d82d0ab6bbd35de1a62453a2c2d2 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 25 Apr 2025 18:55:40 +0200 Subject: [PATCH 3/9] undo sum --- src/fast_array_utils/stats/_sum.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 4683cfa..163f741 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -97,13 +97,10 @@ def sum_drop_keepdims( # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype - def debug_sum(x, *, axis, dtype, keepdims): - return np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) - return da.reduction( x, sum_drop_keepdims, # type: ignore[arg-type] - partial(debug_sum, dtype=dtype), # pyright: ignore[reportArgumentType] + partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, meta=np.array([], dtype=dtype), From 2518c220165d3fd9c30134efd92eb18edf207db6 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 11:34:24 +0200 Subject: [PATCH 4/9] almost works --- src/fast_array_utils/_validation.py | 2 +- src/fast_array_utils/stats/__init__.py | 48 +++++++++-- src/fast_array_utils/stats/_sum.py | 110 +++++++++++++++++++------ tests/test_stats.py | 33 +++++--- typings/cupy/_core/core.pyi | 3 + 5 files changed, 153 insertions(+), 43 deletions(-) diff --git a/src/fast_array_utils/_validation.py b/src/fast_array_utils/_validation.py index b40f82b..51df2b1 100644 --- a/src/fast_array_utils/_validation.py +++ b/src/fast_array_utils/_validation.py @@ -9,7 +9,7 @@ def validate_axis(ndim: int, axis: int | None) -> None: if axis is None: return if not isinstance(axis, int | np.integer): # pragma: no cover - msg = "axis must be integer or None." + msg = f"axis must be integer or None, not {axis=!r}." raise TypeError(msg) if axis == 0 and ndim == 1: raise AxisError(axis, ndim, "use axis=None for 1D arrays") diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 30d132a..393f22c 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -227,19 +227,56 @@ def mean_var( # https://github.com/scverse/fast-array-utils/issues/52 @overload def sum( - x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None + x: CpuArray | DiskArray, + /, + *, + axis: None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> np.number[Any]: ... @overload def sum( - x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None + x: CpuArray | DiskArray, + /, + *, + axis: Literal[0, 1], + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any]: ... + + @overload def sum( - x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None + x: GpuArray, + /, + *, + axis: None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: Literal[False] = False, +) -> np.number[Any]: ... +@overload +def sum( + x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True] +) -> types.CupyArray: ... +@overload +def sum( + x: GpuArray, + /, + *, + axis: Literal[0, 1], + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.CupyArray: ... + + @overload def sum( - x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.DaskArray: ... @@ -249,6 +286,7 @@ def sum( *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: """Sum over both or one axis. @@ -286,4 +324,4 @@ def sum( from ._sum import sum_ validate_axis(x.ndim, axis) - return sum_(x, axis=axis, dtype=dtype) + return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 163f741..3f53afe 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -2,20 +2,25 @@ from __future__ import annotations from functools import partial, singledispatch -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Literal, cast import numpy as np +from numpy.exceptions import AxisError from .. import types if TYPE_CHECKING: - from typing import Any, Literal + from typing import Any, Literal, TypeAlias from numpy.typing import DTypeLike, NDArray from ..typing import CpuArray, DiskArray, GpuArray + ComplexAxis: TypeAlias = ( + tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None] + ) + @singledispatch def sum_( @@ -24,7 +29,9 @@ def sum_( *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + del keep_cupy_as_array if TYPE_CHECKING: # these are never passed to this fallback function, but `singledispatch` wants them assert not isinstance( @@ -37,16 +44,31 @@ def sum_( @sum_.register(types.CupyArray | types.CupyCSMatrix) # type: ignore[call-overload,misc] def _sum_cupy( - x: GpuArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: GpuArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.CupyArray | np.number[Any]: arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype)) - return cast("np.number[Any]", arr.get()[()]) if axis is None else arr.squeeze() + return ( + cast("np.number[Any]", arr.get()[()]) + if not keep_cupy_as_array and axis is None + else arr.squeeze() + ) @sum_.register(types.CSBase) # type: ignore[call-overload,misc] def _sum_cs( - x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.CSBase, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> NDArray[Any] | np.number[Any]: + del keep_cupy_as_array import scipy.sparse as sp if isinstance(x, types.CSMatrix): @@ -59,7 +81,12 @@ def _sum_cs( @sum_.register(types.DaskArray) def _sum_dask( - x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None + x: types.DaskArray, + /, + *, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, ) -> types.DaskArray: import dask.array as da @@ -69,39 +96,68 @@ def _sum_dask( msg = "sum does not support numpy matrices" raise TypeError(msg) - def sum_drop_keepdims( - a: CpuArray, + def sum_dask_inner( + a: CpuArray | GpuArray, /, *, - axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None, + axis: ComplexAxis = None, dtype: DTypeLike | None = None, keepdims: bool = False, ) -> NDArray[Any] | types.CupyArray: - del keepdims - if a.ndim == 1: - axis = None - else: - match axis: - case (0, 1) | (1, 0): - axis = None - case (0 | 1 as n,): - axis = n - case tuple(): # pragma: no cover - msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" - raise ValueError(msg) - rv = sum(a, axis=axis, dtype=dtype) - shape = (1,) if a.ndim == 1 else (1 if rv.shape == () else len(rv), 1) # type: ignore[arg-type] - return np.reshape(rv, shape) + axis = normalize_axis(axis, a.ndim) + rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) + shape = get_shape(rv, axis=axis, keepdims=keepdims) + return rv.reshape(shape) if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype - return da.reduction( + rv = da.reduction( x, - sum_drop_keepdims, # type: ignore[arg-type] - partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType] + sum_dask_inner, # type: ignore[arg-type] + partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, meta=np.array([], dtype=dtype), ) + if ( + axis is None + and not keep_cupy_as_array + and isinstance(x._meta, types.CupyArray | types.CupyCSMatrix) # noqa: SLF001 + ): + return rv.map_blocks(lambda a: a.get()[()], meta=x.dtype.type(0)) + return rv + + +def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: + """Adapt `axis` parameter passed by Dask to what we support.""" + match axis: + case int() | None: + pass + case (0 | 1,): + axis = axis[0] + case (0, 1) | (1, 0): + axis = None + case _: + raise AxisError(axis, ndim) + if axis == 0 and ndim == 1: + return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays + return axis + + +def get_shape( + a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool +) -> tuple[int] | tuple[int, int]: + """Get the output shape of an axis-flattening operation.""" + match keepdims, a.ndim: + case _, 0: + return (1,) + case False, 1: + return (a.size,) + case True, 1: + assert axis is not None + return (1, a.size) if axis == 0 else (a.size, 1) + case _, _: + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) diff --git a/tests/test_stats.py b/tests/test_stats.py index 233ce4c..a1d6a02 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -95,16 +95,9 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: return cast("type[DTypeOut] | None", request.param) -@pytest.fixture( - params=[ - pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"), - pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"), - ] -) -def np_arr( - request: pytest.FixtureRequest, dtype_in: type[DTypeIn], ndim: Literal[1, 2] -) -> NDArray[DTypeIn]: - np_arr = cast("NDArray[DTypeIn]", np.array(request.param, dtype=dtype_in)) +@pytest.fixture +def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: + np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)) np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() @@ -172,6 +165,26 @@ def test_sum( np.testing.assert_array_equal(sum_, expected) +@pytest.mark.parametrize( + "data", + [ + pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"), + pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"), + ], +) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.array_type(Flags.Dask) +def test_sum_dask_shapes( + array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]] +) -> None: + np_arr = np.array(data, dtype=np.float32) + arr = array_type(np_arr) + sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute()) + if isinstance(sum_, types.CupyArray): + sum_ = sum_.get() + np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_) + + @pytest.mark.array_type(skip=ATS_SPARSE_DS) def test_mean( array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi index 34c9135..ccd3874 100644 --- a/typings/cupy/_core/core.pyi +++ b/typings/cupy/_core/core.pyi @@ -8,6 +8,7 @@ from numpy.typing import NDArray class ndarray: dtype: np.dtype[Any] shape: tuple[int, ...] + size: int ndim: int # cupy-specific @@ -15,6 +16,7 @@ class ndarray: # operators def __array__(self) -> NDArray[Any]: ... + def __len__(self) -> int: ... def __getitem__( # never returns scalars self, index: int | slice | EllipsisType | tuple[int | slice | EllipsisType | None, ...] ) -> Self: ... @@ -28,6 +30,7 @@ class ndarray: def all(self, axis: None = None) -> np.bool: ... @overload def all(self, axis: int) -> ndarray: ... + def reshape(self, shape: tuple[int, ...] | int) -> ndarray: ... def squeeze(self, axis: int | None = None) -> Self: ... def ravel(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ... def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ... From 44a280fb7ddd6afc6c3e80bb0c674591eaac15bd Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 12:36:03 +0200 Subject: [PATCH 5/9] It works! --- src/fast_array_utils/stats/_sum.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 3f53afe..0e4de9b 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -105,9 +105,9 @@ def sum_dask_inner( keepdims: bool = False, ) -> NDArray[Any] | types.CupyArray: axis = normalize_axis(axis, a.ndim) - rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) + rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] shape = get_shape(rv, axis=axis, keepdims=keepdims) - return rv.reshape(shape) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) @@ -126,7 +126,11 @@ def sum_dask_inner( and not keep_cupy_as_array and isinstance(x._meta, types.CupyArray | types.CupyCSMatrix) # noqa: SLF001 ): - return rv.map_blocks(lambda a: a.get()[()], meta=x.dtype.type(0)) + + def to_scalar(a: types.CupyArray) -> np.number[Any]: + return cast("np.number[Any]", a.get()[()]) + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] return rv @@ -139,8 +143,8 @@ def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: axis = axis[0] case (0, 1) | (1, 0): axis = None - case _: - raise AxisError(axis, ndim) + case _: # pragma: no cover + raise AxisError(axis, ndim) # type: ignore[call-overload] if axis == 0 and ndim == 1: return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays return axis @@ -151,13 +155,14 @@ def get_shape( ) -> tuple[int] | tuple[int, int]: """Get the output shape of an axis-flattening operation.""" match keepdims, a.ndim: - case _, 0: + case False, 0: return (1,) + case True, 0: + return (1, 1) case False, 1: return (a.size,) case True, 1: assert axis is not None return (1, a.size) if axis == 0 else (a.size, 1) - case _, _: - msg = f"{keepdims=}, {type(a)}" - raise AssertionError(msg) + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) From 522189c202ec5c7db235906964c771d6ca8f38b8 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 16:02:11 +0200 Subject: [PATCH 6/9] everything works --- src/fast_array_utils/stats/_sum.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 0e4de9b..cd86f11 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -121,15 +121,20 @@ def sum_dask_inner( dtype=dtype, meta=np.array([], dtype=dtype), ) - if ( - axis is None - and not keep_cupy_as_array - and isinstance(x._meta, types.CupyArray | types.CupyCSMatrix) # noqa: SLF001 - ): - def to_scalar(a: types.CupyArray) -> np.number[Any]: - return cast("np.number[Any]", a.get()[()]) + if isinstance(x._meta, types.CupyArray | types.CupyCSMatrix): # noqa: SLF001 + if keep_cupy_as_array: + return rv + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + return a.get().reshape(())[()] + + else: + + def to_scalar(a: NDArray[Any]) -> np.number[Any]: + return a.reshape(())[()] + + if axis is None: return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] return rv From 16f8814be126cb85a8456c1e3ee40091a193e13d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 16:11:07 +0200 Subject: [PATCH 7/9] types --- src/fast_array_utils/stats/_sum.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index cd86f11..b54fe76 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -122,21 +122,18 @@ def sum_dask_inner( meta=np.array([], dtype=dtype), ) - if isinstance(x._meta, types.CupyArray | types.CupyCSMatrix): # noqa: SLF001 - if keep_cupy_as_array: - return rv - - def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: - return a.get().reshape(())[()] - - else: - - def to_scalar(a: NDArray[Any]) -> np.number[Any]: - return a.reshape(())[()] - - if axis is None: - return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] - return rv + if axis is not None or ( + isinstance(rv._meta, types.CupyArray) # noqa: SLF001 + and keep_cupy_as_array + ): + return rv + + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + if isinstance(a, types.CupyArray): + a = a.get() + return a.reshape(())[()] # type: ignore[return-value] + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: From 52fcc1ad9bbcc66e053a547c338b912c90009c1d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 16:42:26 +0200 Subject: [PATCH 8/9] =?UTF-8?q?add=201=C3=971=20chunks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_stats.py | 2 ++ typings/dask/array/core.pyi | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_stats.py b/tests/test_stats.py index a1d6a02..8f76dc2 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -170,6 +170,7 @@ def test_sum( [ pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"), pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"), + pytest.param([[1, 0], [0, 2]], id="2x2"), ], ) @pytest.mark.parametrize("axis", [0, 1]) @@ -179,6 +180,7 @@ def test_sum_dask_shapes( ) -> None: np_arr = np.array(data, dtype=np.float32) arr = array_type(np_arr) + assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes" sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute()) if isinstance(sum_, types.CupyArray): sum_ = sum_.get() diff --git a/typings/dask/array/core.pyi b/typings/dask/array/core.pyi index 48f044c..741c6b7 100644 --- a/typings/dask/array/core.pyi +++ b/typings/dask/array/core.pyi @@ -44,6 +44,8 @@ class Array: # dask methods and attrs _meta: _Array blocks: BlockView + chunks: tuple[tuple[int, ...], ...] + chunksize: tuple[int, ...] def compute(self) -> _Array: ... def visualize( From 5812b9d24112eb6ef74220a550ddf83dfb29e920 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 28 Apr 2025 17:24:33 +0200 Subject: [PATCH 9/9] coverage --- src/fast_array_utils/stats/_sum.py | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index b54fe76..d3652e4 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -90,25 +90,10 @@ def _sum_dask( ) -> types.DaskArray: import dask.array as da - from . import sum - if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 msg = "sum does not support numpy matrices" raise TypeError(msg) - def sum_dask_inner( - a: CpuArray | GpuArray, - /, - *, - axis: ComplexAxis = None, - dtype: DTypeLike | None = None, - keepdims: bool = False, - ) -> NDArray[Any] | types.CupyArray: - axis = normalize_axis(axis, a.ndim) - rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] - shape = get_shape(rv, axis=axis, keepdims=keepdims) - return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) - if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype @@ -136,6 +121,22 @@ def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] +def sum_dask_inner( + a: CpuArray | GpuArray, + /, + *, + axis: ComplexAxis = None, + dtype: DTypeLike | None = None, + keepdims: bool = False, +) -> NDArray[Any] | types.CupyArray: + from . import sum + + axis = normalize_axis(axis, a.ndim) + rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] + shape = get_shape(rv, axis=axis, keepdims=keepdims) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) + + def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]: """Adapt `axis` parameter passed by Dask to what we support.""" match axis: @@ -166,5 +167,6 @@ def get_shape( case True, 1: assert axis is not None return (1, a.size) if axis == 0 else (a.size, 1) + # pragma: no cover msg = f"{keepdims=}, {type(a)}" raise AssertionError(msg)