diff --git a/src/fast_array_utils/_validation.py b/src/fast_array_utils/_validation.py
index 8bee97b..b40f82b 100644
--- a/src/fast_array_utils/_validation.py
+++ b/src/fast_array_utils/_validation.py
@@ -2,14 +2,19 @@
 from __future__ import annotations
 
 import numpy as np
+from numpy.exceptions import AxisError
 
 
-def validate_axis(axis: int | None) -> None:
+def validate_axis(ndim: int, axis: int | None) -> None:
     if axis is None:
         return
     if not isinstance(axis, int | np.integer):  # pragma: no cover
         msg = "axis must be integer or None."
         raise TypeError(msg)
+    if axis == 0 and ndim == 1:
+        raise AxisError(axis, ndim, "use axis=None for 1D arrays")
+    if axis not in range(ndim):
+        raise AxisError(axis, ndim)
     if axis not in (0, 1):  # pragma: no cover
         msg = "We only support axis 0 and 1 at the moment"
         raise NotImplementedError(msg)
diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py
index 8dabc66..30d132a 100644
--- a/src/fast_array_utils/stats/__init__.py
+++ b/src/fast_array_utils/stats/__init__.py
@@ -75,7 +75,7 @@ def is_constant(
     """
     from ._is_constant import is_constant_
 
-    validate_axis(axis)
+    validate_axis(x.ndim, axis)
     return is_constant_(x, axis=axis)
 
 
@@ -144,7 +144,7 @@ def mean(
     """
     from ._mean import mean_
 
-    validate_axis(axis)
+    validate_axis(x.ndim, axis)
     return mean_(x, axis=axis, dtype=dtype)  # type: ignore[no-any-return]  # literally the same type, wtf mypy
 
 
@@ -219,6 +219,7 @@ def mean_var(
     """
     from ._mean_var import mean_var_
 
+    validate_axis(x.ndim, axis)
     return mean_var_(x, axis=axis, correction=correction)  # type: ignore[no-any-return]
 
 
@@ -284,5 +285,5 @@ def sum(
     """
     from ._sum import sum_
 
-    validate_axis(axis)
+    validate_axis(x.ndim, axis)
     return sum_(x, axis=axis, dtype=dtype)
diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py
index feb64bf..12f3b2b 100644
--- a/src/fast_array_utils/stats/_sum.py
+++ b/src/fast_array_utils/stats/_sum.py
@@ -52,7 +52,9 @@ def _sum_cs(
 
     if isinstance(x, types.CSMatrix):
         x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
-    return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype))  # type: ignore[call-overload]
+    if axis is None:
+        return cast("np.number[Any]", x.data.sum(dtype=dtype))
+    return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype))
 
 
 @sum_.register(types.DaskArray)
@@ -76,17 +78,20 @@ def sum_drop_keepdims(
         keepdims: bool = False,
     ) -> NDArray[Any] | types.CupyArray:
         del keepdims
-        match axis:
-            case (0 | 1 as n,):
-                axis = n
-            case (0, 1) | (1, 0):
-                axis = None
-            case tuple():  # pragma: no cover
-                msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
-                raise ValueError(msg)
+        if a.ndim == 1:
+            axis = None
+        else:
+            match axis:
+                case (0, 1) | (1, 0):
+                    axis = None
+                case (0 | 1 as n,):
+                    axis = n
+                case tuple():  # pragma: no cover
+                    msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
+                    raise ValueError(msg)
         rv = sum(a, axis=axis, dtype=dtype)
-        # make sure rv is 2D
-        return np.reshape(rv, (1, 1 if rv.shape == () else len(rv)))  # type: ignore[arg-type]
+        shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv))  # type: ignore[arg-type]
+        return np.reshape(rv, shape)
 
     if dtype is None:
         # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
diff --git a/tests/test_stats.py b/tests/test_stats.py
index 75a9b8c..dfe8fad 100644
--- a/tests/test_stats.py
+++ b/tests/test_stats.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import pytest
+from numpy.exceptions import AxisError
 
 from fast_array_utils import stats, types
 from testing.fast_array_utils import SUPPORTED_TYPES, Flags
@@ -26,12 +27,12 @@
     DTypeIn = np.float32 | np.float64 | np.int32 | np.bool
     DTypeOut = np.float32 | np.float64 | np.int64
 
-    NdAndAx: TypeAlias = tuple[Literal[2], Literal[0, 1, None]]
+    NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]]
 
-    class BenchFun(Protocol):  # noqa: D101
+    class StatFun(Protocol):  # noqa: D101
         def __call__(  # noqa: D102
             self,
-            arr: CpuArray,
+            arr: Array,
             *,
             axis: Literal[0, 1, None] = None,
             dtype: type[DTypeOut] | None = None,
@@ -41,6 +42,8 @@ def __call__(  # noqa: D102
 
 pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")]
 
+STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant]
+
 # can’t select these using a category filter
 ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"}
 ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
@@ -49,6 +52,7 @@ def __call__(  # noqa: D102
 @pytest.fixture(
     scope="session",
     params=[
+        pytest.param((1, None), id="1d-all"),
         pytest.param((2, None), id="2d-all"),
         pytest.param((2, 0), id="2d-ax0"),
         pytest.param((2, 1), id="2d-ax1"),
@@ -59,8 +63,17 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:
 
 
 @pytest.fixture
-def ndim(ndim_and_axis: NdAndAx) -> Literal[2]:
-    return ndim_and_axis[0]
+def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]:
+    return check_ndim(array_type, ndim_and_axis[0])
+
+
+def check_ndim(array_type: ArrayType, ndim: Literal[1, 2]) -> Literal[1, 2]:
+    inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
+    if ndim != 2 and issubclass(inner_cls, types.CSMatrix | types.CupyCSMatrix):
+        pytest.skip("CSMatrix only supports 2D")
+    if ndim != 2 and inner_cls is types.csc_array:
+        pytest.skip("csc_array only supports 2D")
+    return ndim
 
 
 @pytest.fixture(scope="session")
@@ -68,9 +81,13 @@ def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]:
     return ndim_and_axis[1]
 
 
-@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
-def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]:
-    return cast("type[DTypeIn]", request.param)
+@pytest.fixture(params=[np.float32, np.float64, np.int32, np.bool])
+def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTypeIn]:
+    dtype = cast("type[DTypeIn]", request.param)
+    inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
+    if np.dtype(dtype).kind not in "fdFD" and issubclass(inner_cls, types.CupyCSMatrix):
+        pytest.skip("Cupy sparse matrices don’t support non-floating dtypes")
+    return dtype
 
 
 @pytest.fixture(scope="session", params=[np.float32, np.float64, None])
@@ -79,12 +96,33 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
 
 
 @pytest.fixture
-def np_arr(dtype_in: type[DTypeIn]) -> NDArray[DTypeIn]:
+def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
     np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
     np_arr.flags.writeable = False
+    if ndim == 1:
+        np_arr = np_arr.flatten()
     return np_arr
 
 
+@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
+@pytest.mark.parametrize("func", STAT_FUNCS)
+@pytest.mark.parametrize(
+    ("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]
+)
+def test_ndim_error(
+    array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1, None]
+) -> None:
+    check_ndim(array_type, ndim)
+    # not using the fixture because we don’t need to test multiple dtypes
+    np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
+    if ndim == 1:
+        np_arr = np_arr.flatten()
+    arr = array_type(np_arr)
+
+    with pytest.raises(AxisError):
+        func(arr, axis=axis)
+
+
 @pytest.mark.array_type(skip=ATS_SPARSE_DS)
 def test_sum(
     array_type: ArrayType[Array],
@@ -93,8 +131,6 @@ def test_sum(
     axis: Literal[0, 1, None],
     np_arr: NDArray[DTypeIn],
 ) -> None:
-    if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
-        pytest.skip("CuPy sparse matrices only support floats")
     arr = array_type(np_arr.copy())
     assert arr.dtype == dtype_in
 
@@ -133,8 +169,6 @@ def test_sum(
 def test_mean(
     array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]
 ) -> None:
-    if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
-        pytest.skip("CuPy sparse matrices only support floats")
     arr = array_type(np_arr)
     result = stats.mean(arr, axis=axis)  # type: ignore[arg-type]  # https://github.com/python/mypy/issues/16777
 
@@ -148,26 +182,21 @@ def test_mean(
 
 
 @pytest.mark.array_type(skip=Flags.Disk)
-@pytest.mark.parametrize(
-    ("axis", "mean_expected", "var_expected"),
-    [(None, 3.5, 3.5), (0, [2.5, 3.5, 4.5], [4.5, 4.5, 4.5]), (1, [2.0, 5.0], [1.0, 1.0])],
-)
 def test_mean_var(
     array_type: ArrayType[CpuArray | GpuArray | types.DaskArray],
     axis: Literal[0, 1, None],
-    mean_expected: float | list[float],
-    var_expected: float | list[float],
+    np_arr: NDArray[DTypeIn],
 ) -> None:
-    np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
-    np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
-    np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)
-
     arr = array_type(np_arr)
+
     mean, var = stats.mean_var(arr, axis=axis, correction=1)
     if isinstance(mean, types.DaskArray) and isinstance(var, types.DaskArray):
         mean, var = mean.compute(), var.compute()  # type: ignore[assignment]
     if isinstance(mean, types.CupyArray) and isinstance(var, types.CupyArray):
         mean, var = mean.get(), var.get()
+
+    mean_expected = np.mean(np_arr, axis=axis)  # type: ignore[arg-type]
+    var_expected = np.var(np_arr, axis=axis, correction=1)  # type: ignore[arg-type]
     np.testing.assert_array_equal(mean, mean_expected)
     np.testing.assert_array_almost_equal(var, var_expected)  # type: ignore[arg-type]
 
@@ -223,11 +252,11 @@ def test_dask_constant_blocks(
 
 @pytest.mark.benchmark
 @pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
-@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant])
+@pytest.mark.parametrize("func", STAT_FUNCS)
 @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
 def test_stats_benchmark(
     benchmark: BenchmarkFixture,
-    func: BenchFun,
+    func: StatFun,
     array_type: ArrayType[CpuArray, None],
     axis: Literal[0, 1, None],
     dtype: type[np.float32 | np.float64],
diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi
index e95e036..34c9135 100644
--- a/typings/cupy/_core/core.pyi
+++ b/typings/cupy/_core/core.pyi
@@ -8,6 +8,7 @@ from numpy.typing import NDArray
 class ndarray:
     dtype: np.dtype[Any]
     shape: tuple[int, ...]
+    ndim: int
 
     # cupy-specific
     def get(self) -> NDArray[Any]: ...
diff --git a/typings/cupyx/scipy/sparse/_base.pyi b/typings/cupyx/scipy/sparse/_base.pyi
index 4f24482..3cf37c5 100644
--- a/typings/cupyx/scipy/sparse/_base.pyi
+++ b/typings/cupyx/scipy/sparse/_base.pyi
@@ -9,6 +9,7 @@ from numpy.typing import NDArray
 class spmatrix:
     dtype: np.dtype[Any]
     shape: tuple[int, int]
+    ndim: int
     def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ...
     def __power__(self, other: int) -> Self: ...
     def __array__(self) -> NDArray[Any]: ...
diff --git a/typings/h5py.pyi b/typings/h5py.pyi
index 67bfd50..c6a211d 100644
--- a/typings/h5py.pyi
+++ b/typings/h5py.pyi
@@ -12,6 +12,7 @@ class HLObject: ...
 class Dataset(HLObject):
     dtype: np.dtype[Any]
     shape: tuple[int, ...]
+    ndim: int
 
 class Group(HLObject): ...
 