From c263986685d7e0808d898af267c1ee443f2c8836 Mon Sep 17 00:00:00 2001
From: Quentin Blampey
Date: Wed, 29 Oct 2025 12:10:56 +0100
Subject: [PATCH 1/7] pass dtype only when needed in dask

---
 src/fast_array_utils/stats/_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index 7d1fcbb..69765e7 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -68,10 +68,12 @@ def _dask_block(
 ) -> NDArray[Any] | types.CupyArray:
     from . import max, min, sum
 
+    kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
+
     fns = {fn.__name__: fn for fn in (min, max, sum)}
     axis = _normalize_axis(axis, a.ndim)
 
-    rv = fns[op](a, axis=axis, dtype=dtype, keep_cupy_as_array=True)  # type: ignore[misc,call-overload]
+    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **kwargs)  # type: ignore[misc,call-overload]
     shape = _get_shape(rv, axis=axis, keepdims=keepdims)
     return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
 

From ec7f2792613df001ba22142304710afc8de626a1 Mon Sep 17 00:00:00 2001
From: Quentin Blampey
Date: Wed, 29 Oct 2025 12:11:18 +0100
Subject: [PATCH 2/7] support DiskArray in power and mean_var

---
 src/fast_array_utils/stats/_mean_var.py | 4 ++--
 src/fast_array_utils/stats/_power.py    | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py
index 9037567..5e6cd0d 100644
--- a/src/fast_array_utils/stats/_mean_var.py
+++ b/src/fast_array_utils/stats/_mean_var.py
@@ -15,12 +15,12 @@
 
     from numpy.typing import NDArray
 
-    from ..typing import CpuArray, GpuArray
+    from ..typing import CpuArray, DiskArray, GpuArray
 
 
 @no_type_check  # mypy is extremely confused
 def mean_var_(
-    x: CpuArray | GpuArray | types.DaskArray,
+    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
     /,
     *,
     axis: Literal[0, 1] | None = None,
diff --git a/src/fast_array_utils/stats/_power.py b/src/fast_array_utils/stats/_power.py
index 43dada2..56b6bb2 100644
--- a/src/fast_array_utils/stats/_power.py
+++ b/src/fast_array_utils/stats/_power.py
@@ -14,10 +14,10 @@
 
     from numpy.typing import DTypeLike
 
-    from fast_array_utils.typing import CpuArray, GpuArray
+    from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
 
     # All supported array types except for disk ones and CSDataset
-    Array: TypeAlias = CpuArray | GpuArray | types.DaskArray
+    Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.DaskArray
     _Arr = TypeVar("_Arr", bound=Array)
     _Mat = TypeVar("_Mat", bound=types.CSBase | types.CupyCSMatrix)
 
@@ -33,7 +33,7 @@ def power(x: _Arr, n: int, /, dtype: DTypeLike | None = None) -> _Arr:
 def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
     if TYPE_CHECKING:
         assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
-    return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]
+    return np.power(x, n, dtype=dtype)  # type: ignore[operator]
 
 
 @_power.register(types.CSBase | types.CupyCSMatrix)

From 6a82c72664b1e0b053c35501230b082757724efb Mon Sep 17 00:00:00 2001
From: Phil Schaf
Date: Thu, 30 Oct 2025 13:28:32 +0100
Subject: [PATCH 3/7] dtype kw helper

---
 src/fast_array_utils/stats/_generic_ops.py | 10 ++++------
 src/fast_array_utils/stats/_typing.py      |  6 +++++-
 src/fast_array_utils/stats/_utils.py       | 10 ++++++----
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py
index d342517..110d069 100644
--- a/src/fast_array_utils/stats/_generic_ops.py
+++ b/src/fast_array_utils/stats/_generic_ops.py
@@ -8,7 +8,7 @@
 
 from .. import types
 from ._typing import DtypeOps
-from ._utils import _dask_inner
+from ._utils import _dask_inner, _dtype_kw
 
 
 if TYPE_CHECKING:
@@ -29,8 +29,7 @@ def _run_numpy_op(
     axis: Literal[0, 1] | None = None,
     dtype: DTypeLike | None = None,
 ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
-    kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
-    return getattr(np, op)(x, axis=axis, **kwargs)  # type: ignore[no-any-return]
+    return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))  # type: ignore[no-any-return]
 
 
 @singledispatch
@@ -83,13 +82,12 @@ def _generic_op_cs(
     # just convert to sparse array, then `return x.{op}(dtype=dtype)`
     # https://github.com/scipy/scipy/issues/23768
 
-    kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
     if axis is None:
-        return cast("np.number[Any]", getattr(x.data, op)(**kwargs))
+        return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
     if TYPE_CHECKING:  # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
         assert isinstance(dtype, np.dtype | type | None)
     # convert to array so dimensions collapse as expected
-    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs)  # type: ignore[call-overload]
+    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[call-overload]
     return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
 
 
diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py
index e8b0b65..5411372 100644
--- a/src/fast_array_utils/stats/_typing.py
+++ b/src/fast_array_utils/stats/_typing.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: MPL-2.0
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, Protocol
+from typing import TYPE_CHECKING, Literal, Protocol, TypedDict
 
 import numpy as np
 
@@ -49,3 +49,7 @@ def __call__(
 NoDtypeOps = Literal["max", "min"]
 DtypeOps = Literal["sum"]
 Ops: TypeAlias = NoDtypeOps | DtypeOps
+
+
+class DTypeKw(TypedDict, total=False):
+    dtype: DTypeLike
diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index 69765e7..94b9ba6 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -17,7 +17,7 @@
     from numpy.typing import DTypeLike, NDArray
 
     from ..typing import CpuArray, GpuArray
-    from ._typing import Ops
+    from ._typing import DTypeKw, Ops
 
     ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
 
@@ -68,12 +68,10 @@ def _dask_block(
 ) -> NDArray[Any] | types.CupyArray:
     from . import max, min, sum
 
-    kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
-
     fns = {fn.__name__: fn for fn in (min, max, sum)}
     axis = _normalize_axis(axis, a.ndim)
 
-    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **kwargs)  # type: ignore[misc,call-overload]
+    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))  # type: ignore[misc,call-overload]
     shape = _get_shape(rv, axis=axis, keepdims=keepdims)
     return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
 
@@ -109,3 +107,7 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
         case _:  # pragma: no cover
             msg = f"{keepdims=}, {type(a)}"
             raise AssertionError(msg)
+
+
+def _dtype_kw(dtype: DTypeLike | None, op: Ops) -> DTypeKw:
+    return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}

From 54116a4032f2506bf07a9f05903f2f672a7a13a7 Mon Sep 17 00:00:00 2001
From: Phil Schaf
Date: Thu, 30 Oct 2025 13:34:53 +0100
Subject: [PATCH 4/7] undo disk array

---
 src/fast_array_utils/stats/_generic_ops.py | 2 +-
 src/fast_array_utils/stats/_mean_var.py    | 4 ++--
 src/fast_array_utils/stats/_power.py       | 6 +++---
 src/fast_array_utils/stats/_typing.py      | 9 ++++++---
 src/fast_array_utils/stats/_utils.py       | 9 ++++++---
 5 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py
index 110d069..c8563ce 100644
--- a/src/fast_array_utils/stats/_generic_ops.py
+++ b/src/fast_array_utils/stats/_generic_ops.py
@@ -87,7 +87,7 @@ def _generic_op_cs(
     if TYPE_CHECKING:  # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
         assert isinstance(dtype, np.dtype | type | None)
     # convert to array so dimensions collapse as expected
-    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[call-overload]
+    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
     return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
 
 
diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py
index 5e6cd0d..9037567 100644
--- a/src/fast_array_utils/stats/_mean_var.py
+++ b/src/fast_array_utils/stats/_mean_var.py
@@ -15,12 +15,12 @@
 
     from numpy.typing import NDArray
 
-    from ..typing import CpuArray, DiskArray, GpuArray
+    from ..typing import CpuArray, GpuArray
 
 
 @no_type_check  # mypy is extremely confused
 def mean_var_(
-    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
+    x: CpuArray | GpuArray | types.DaskArray,
     /,
     *,
     axis: Literal[0, 1] | None = None,
diff --git a/src/fast_array_utils/stats/_power.py b/src/fast_array_utils/stats/_power.py
index 56b6bb2..43dada2 100644
--- a/src/fast_array_utils/stats/_power.py
+++ b/src/fast_array_utils/stats/_power.py
@@ -14,10 +14,10 @@
 
     from numpy.typing import DTypeLike
 
-    from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
+    from fast_array_utils.typing import CpuArray, GpuArray
 
     # All supported array types except for disk ones and CSDataset
-    Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.DaskArray
+    Array: TypeAlias = CpuArray | GpuArray | types.DaskArray
     _Arr = TypeVar("_Arr", bound=Array)
     _Mat = TypeVar("_Mat", bound=types.CSBase | types.CupyCSMatrix)
 
@@ -33,7 +33,7 @@ def power(x: _Arr, n: int, /, dtype: DTypeLike | None = None) -> _Arr:
 def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
     if TYPE_CHECKING:
         assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
-    return np.power(x, n, dtype=dtype)  # type: ignore[operator]
+    return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]
 
 
 @_power.register(types.CSBase | types.CupyCSMatrix)
diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py
index 5411372..2ddc4c0 100644
--- a/src/fast_array_utils/stats/_typing.py
+++ b/src/fast_array_utils/stats/_typing.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: MPL-2.0
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, Protocol, TypedDict
+from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypedDict, TypeVar
 
 import numpy as np
 
@@ -51,5 +51,8 @@ def __call__(
 Ops: TypeAlias = NoDtypeOps | DtypeOps
 
 
-class DTypeKw(TypedDict, total=False):
-    dtype: DTypeLike
+_DT = TypeVar("_DT", bound="DTypeLike")
+
+
+class DTypeKw(TypedDict, Generic[_DT], total=False):
+    dtype: _DT
diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index 94b9ba6..a567484 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import TYPE_CHECKING, Literal, cast, get_args
+from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args
 
 import numpy as np
 from numpy.exceptions import AxisError
@@ -71,7 +71,7 @@ def _dask_block(
     fns = {fn.__name__: fn for fn in (min, max, sum)}
     axis = _normalize_axis(axis, a.ndim)
 
-    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))  # type: ignore[misc,call-overload]
+    rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op))  # type: ignore[call-overload]
     shape = _get_shape(rv, axis=axis, keepdims=keepdims)
     return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
 
@@ -109,5 +109,8 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
             raise AssertionError(msg)
 
 
-def _dtype_kw(dtype: DTypeLike | None, op: Ops) -> DTypeKw:
+DT = TypeVar("DT", bound="DTypeLike")
+
+
+def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]:
     return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}

From a963725b21c84fd5e1b2ccbbf181dc982480f135 Mon Sep 17 00:00:00 2001
From: Phil Schaf
Date: Thu, 30 Oct 2025 13:56:51 +0100
Subject: [PATCH 5/7] add tests

---
 tests/test_stats.py | 25 ++++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/tests/test_stats.py b/tests/test_stats.py
index 984b460..e5b1dc7 100644
--- a/tests/test_stats.py
+++ b/tests/test_stats.py
@@ -201,6 +201,20 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
     np.testing.assert_array_equal(sum_, expected)
 
 
+@pytest.mark.array_type(skip=ATS_SPARSE_DS)
+@pytest.mark.parametrize("func", [stats.min, stats.max])
+def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
+    rng = np.random.default_rng(0)
+    np_arr = rng.random((100, 100))
+    arr = array_type(np_arr)
+
+    result = to_np_dense_checked(func(arr, axis=axis), axis, arr)
+
+    expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis)
+    np.testing.assert_array_equal(result, expected)
+
+
+@pytest.mark.parametrize("func", [stats.sum, stats.min, stats.max])
 @pytest.mark.parametrize(
     "data",
     [
...
     ],
 )
 @pytest.mark.parametrize("axis", [0, 1])
 @pytest.mark.array_type(Flags.Dask)
-def test_sum_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]) -> None:
+def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]], func: StatFunNoDtype) -> None:
     np_arr = np.array(data, dtype=np.float32)
     arr = array_type(np_arr)
     assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
-    sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
-    if isinstance(sum_, types.CupyArray):
-        sum_ = sum_.get()
-    np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_)
+    stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute())
+    if isinstance(stat, types.CupyArray):
+        stat = stat.get()
+    np_func = getattr(np, func.__name__)
+    np.testing.assert_almost_equal(stat, np_func(np_arr, axis=axis))
 
 
 @pytest.mark.array_type(skip=ATS_SPARSE_DS)

From ceb2e132281512709f9c8488099e6ba2a734e18d Mon Sep 17 00:00:00 2001
From: "Philipp A."
Date: Tue, 18 Nov 2025 10:30:37 +0100
Subject: [PATCH 6/7] Fix tests

---
 src/fast_array_utils/stats/__init__.py     |  2 +-
 src/fast_array_utils/stats/_generic_ops.py |  6 ++++--
 src/fast_array_utils/stats/_utils.py       |  7 ++++++-
 typings/cupy/_core/core.pyi                | 11 ++++++++++-
 typings/cupy/cuda.pyi                      |  1 +
 5 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py
index b49f6ec..8010444 100644
--- a/src/fast_array_utils/stats/__init__.py
+++ b/src/fast_array_utils/stats/__init__.py
@@ -223,7 +223,7 @@ def _generic_op(
 ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
     from ._generic_ops import generic_op
 
-    assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation '{op}'"
+    assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"
     validate_axis(x.ndim, axis)
     return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype)
 
diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py
index c8563ce..5db675b 100644
--- a/src/fast_array_utils/stats/_generic_ops.py
+++ b/src/fast_array_utils/stats/_generic_ops.py
@@ -29,7 +29,8 @@ def _run_numpy_op(
     axis: Literal[0, 1] | None = None,
     dtype: DTypeLike | None = None,
 ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
-    return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))  # type: ignore[no-any-return]
+    arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
+    return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
 
 
 @singledispatch
@@ -88,7 +89,8 @@ def _generic_op_cs(
     assert isinstance(dtype, np.dtype | type | None)
     # convert to array so dimensions collapse as expected
     x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
-    return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
+    rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
+    return rv.toarray() if isinstance(rv, types.coo_array) else rv
 
 
 @generic_op.register(types.DaskArray)
diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index a567484..7ba8952 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -8,6 +8,7 @@
 from numpy.exceptions import AxisError
 
 from .. import types
+from ..typing import GpuArray
 from ._typing import DtypeOps
 
 
@@ -16,7 +17,7 @@
 
     from numpy.typing import DTypeLike, NDArray
 
-    from ..typing import CpuArray, GpuArray
+    from ..typing import CpuArray
     from ._typing import DTypeKw, Ops
 
     ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
@@ -65,9 +66,13 @@ def _dask_block(
     axis: ComplexAxis = None,
     dtype: DTypeLike | None = None,
     keepdims: bool = False,
+    computing_meta: bool = False,
 ) -> NDArray[Any] | types.CupyArray:
     from . import max, min, sum
 
+    if computing_meta:  # dask.blockwise doesn’t allow to pass `meta` in, and reductions below don’t handle a 0d matrix
+        return (types.CupyArray if isinstance(a, GpuArray) else np.ndarray)((), dtype or a.dtype)
+
     fns = {fn.__name__: fn for fn in (min, max, sum)}
     axis = _normalize_axis(axis, a.ndim)
 
diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi
index 7d1bf96..2995181 100644
--- a/typings/cupy/_core/core.pyi
+++ b/typings/cupy/_core/core.pyi
@@ -3,7 +3,7 @@
 from types import EllipsisType
 from typing import Any, Literal, Self, overload
 
 import numpy as np
-from cupy.cuda import Stream
+from cupy.cuda import MemoryPointer, Stream
 from numpy._core.multiarray import flagsobj
 from numpy.typing import DTypeLike, NDArray
@@ -14,6 +14,15 @@ class ndarray:
     ndim: int
     flags: flagsobj
 
+    def __init__(
+        self,
+        shape: tuple[int, ...],
+        dtype: DTypeLike | None = ...,
+        memptr: MemoryPointer | None = None,
+        strides: tuple[int, ...] | None = None,
+        order: Literal["C", "F"] = "C",
+    ) -> None: ...
+
     # cupy-specific
     def get(
         self, stream: Stream | None = None, order: Literal["C", "F", "A"] = "C", out: NDArray[Any] | None = None, blocking: bool = True
diff --git a/typings/cupy/cuda.pyi b/typings/cupy/cuda.pyi
index 659d7ee..595d499 100644
--- a/typings/cupy/cuda.pyi
+++ b/typings/cupy/cuda.pyi
@@ -1,2 +1,3 @@
 # SPDX-License-Identifier: MPL-2.0
 class Stream: ...
+class MemoryPointer: ...

From c7325e33d7e3c6555863eb0de64f917f7a6ca510 Mon Sep 17 00:00:00 2001
From: "Philipp A."
Date: Tue, 18 Nov 2025 11:02:16 +0100
Subject: [PATCH 7/7] Fix for old scipy

---
 src/fast_array_utils/stats/_generic_ops.py | 3 ++-
 src/fast_array_utils/stats/_utils.py       | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py
index 5db675b..ef833ed 100644
--- a/src/fast_array_utils/stats/_generic_ops.py
+++ b/src/fast_array_utils/stats/_generic_ops.py
@@ -90,7 +90,8 @@ def _generic_op_cs(
     # convert to array so dimensions collapse as expected
     x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
     rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
-    return rv.toarray() if isinstance(rv, types.coo_array) else rv
+    # old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
+    return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
 
 
 @generic_op.register(types.DaskArray)
diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py
index 7ba8952..1ca4065 100644
--- a/src/fast_array_utils/stats/_utils.py
+++ b/src/fast_array_utils/stats/_utils.py
@@ -110,7 +110,7 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
             assert axis is not None
             return (1, a.size) if axis == 0 else (a.size, 1)
         case _:  # pragma: no cover
-            msg = f"{keepdims=}, {type(a)}"
+            msg = f"{keepdims=}, {a.ndim=}, {type(a)=}"
             raise AssertionError(msg)
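
The through-line of patches 1, 3 and 4 is that a `dtype` keyword only gets forwarded when the operation actually accepts one. Below is a minimal, self-contained sketch of that pattern, not the library code: the `run_op` wrapper is a made-up stand-in for the real dispatch functions, while the `Literal` aliases and the `_dtype_kw` body mirror `_typing.py` and `_utils.py` from the patches above.

    from typing import Literal, get_args

    import numpy as np

    NoDtypeOps = Literal["max", "min"]  # ops whose NumPy counterparts take no dtype=
    DtypeOps = Literal["sum"]           # ops that accept dtype=

    def _dtype_kw(dtype, op):
        # Emit the keyword only for ops that accept it, and only if one was requested.
        return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}

    def run_op(x, op, *, axis=None, dtype=None):
        # np.sum takes dtype=..., np.min/np.max do not; the helper keeps them apart.
        return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))

    x = np.arange(6, dtype=np.int8).reshape(2, 3)
    print(run_op(x, "sum", dtype=np.int64).dtype)  # int64
    print(run_op(x, "min", axis=0))                # [0 1 2]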
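
Patch 6's `computing_meta` guard exists because, per its own comment, dask probes reduction chunk functions with `computing_meta=True` and 0-d placeholder inputs before running the real computation. A rough, NumPy-only sketch of the idea, with a hypothetical `sum_chunk` standing in for `_dask_block`:

    import numpy as np

    def sum_chunk(a, *, dtype=None, computing_meta=False, **_):
        if computing_meta:
            # Short-circuit with an empty 0-d array of the requested dtype instead
            # of reducing the 0-d placeholder for real.
            return np.ndarray((), dtype or a.dtype)
        return a.sum(axis=0, dtype=dtype)

    meta = sum_chunk(np.ndarray((), np.float32), computing_meta=True)
    print(type(meta).__name__, meta.dtype)  # ndarray float32

    block = np.arange(6, dtype=np.float32).reshape(2, 3)
    print(sum_chunk(block, dtype=np.float64))  # [3. 5. 7.]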
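
Patch 7's `.squeeze()` handles the fact that reducing a scipy sparse array along an axis does not come back with the same shape or type in every scipy version. A standalone sketch of the defensive normalization, not the library code:

    import numpy as np
    import scipy.sparse as sp

    x = sp.csr_array(np.array([[0.0, 1.0], [2.0, 3.0]]))
    rv = x.max(axis=0)
    # Depending on the scipy version this is dense or a (possibly 1×n) COO array.
    if isinstance(rv, sp.coo_array):
        rv = rv.toarray().squeeze()
    print(rv)  # [2. 3.]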