From be8da8c146483c6988fc37f2ff6fe185cfa99cb4 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 24 Apr 2025 12:35:45 +0200 Subject: [PATCH 01/10] run tests on 0d arrays --- tests/test_stats.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index cb3cd50..cc2a608 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -23,7 +23,7 @@ Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray - DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool] + DTypeIn = np.float32 | np.float64 | np.int32 | np.bool DTypeOut = type[np.float32 | np.float64 | np.int64] class BenchFun(Protocol): # noqa: D101 @@ -44,14 +44,19 @@ def __call__( # noqa: D102 ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)} -@pytest.fixture(scope="session", params=[0, 1, None]) +@pytest.fixture(scope="session", params=[0, 1, None], ids=["ax0", "ax1", "all"]) def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]: return cast("Literal[0, 1, None]", request.param) +@pytest.fixture(scope="session", params=[1, 2], ids=["1d", "2d"]) +def ndim(request: pytest.FixtureRequest) -> Literal[1, 2]: + return cast("Literal[1, 2]", request.param) + + @pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool]) -def dtype_in(request: pytest.FixtureRequest) -> DTypeIn: - return cast("DTypeIn", request.param) +def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]: + return cast("type[DTypeIn]", request.param) @pytest.fixture(scope="session", params=[np.float32, np.float64, None]) @@ -59,14 +64,22 @@ def dtype_arg(request: pytest.FixtureRequest) -> DTypeOut | None: return cast("DTypeOut | None", request.param) +@pytest.fixture(scope="session") +def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: + np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) + if ndim == 1: + np_arr = np_arr.flatten() + return np_arr + + @pytest.mark.array_type(skip=ATS_SPARSE_DS) def test_sum( array_type: ArrayType[Array], - dtype_in: DTypeIn, + dtype_in: type[DTypeIn], dtype_arg: DTypeOut | None, axis: Literal[0, 1, None], + np_arr: NDArray[DTypeIn], ) -> None: - np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") arr = array_type(np_arr.copy()) @@ -106,9 +119,11 @@ def test_sum( @pytest.mark.array_type(skip=ATS_SPARSE_DS) @pytest.mark.parametrize(("axis", "expected"), [(None, 3.5), (0, [2.5, 3.5, 4.5]), (1, [2.0, 5.0])]) def test_mean( - array_type: ArrayType[Array], axis: Literal[0, 1, None], expected: float | list[float] + array_type: ArrayType[Array], + axis: Literal[0, 1, None], + expected: float | list[float], + np_arr: NDArray[DTypeIn], ) -> None: - np_arr = np.array([[1, 2, 3], [4, 5, 6]]) if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) @@ -132,8 +147,8 @@ def test_mean_var( axis: Literal[0, 1, None], mean_expected: float | list[float], var_expected: float | list[float], + np_arr: NDArray[DTypeIn], ) -> None: - np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected) np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected) From b50b7584dfa714c954ac73b7230b71fe78ea47bc Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 14:32:08 +0200 Subject: [PATCH 02/10] sum mostly works --- src/fast_array_utils/_validation.py | 7 +++- src/fast_array_utils/stats/__init__.py | 6 +-- src/fast_array_utils/stats/_sum.py | 10 +++-- tests/test_stats.py | 57 ++++++++++++++++++-------- 4 files changed, 55 insertions(+), 25 deletions(-) diff --git a/src/fast_array_utils/_validation.py b/src/fast_array_utils/_validation.py index 8bee97b..b40f82b 100644 --- a/src/fast_array_utils/_validation.py +++ b/src/fast_array_utils/_validation.py @@ -2,14 +2,19 @@ from __future__ import annotations import numpy as np +from numpy.exceptions import AxisError -def validate_axis(axis: int | None) -> None: +def validate_axis(ndim: int, axis: int | None) -> None: if axis is None: return if not isinstance(axis, int | np.integer): # pragma: no cover msg = "axis must be integer or None." raise TypeError(msg) + if axis == 0 and ndim == 1: + raise AxisError(axis, ndim, "use axis=None for 1D arrays") + if axis not in range(ndim): + raise AxisError(axis, ndim) if axis not in (0, 1): # pragma: no cover msg = "We only support axis 0 and 1 at the moment" raise NotImplementedError(msg) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 8dabc66..0fcb88b 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -75,7 +75,7 @@ def is_constant( """ from ._is_constant import is_constant_ - validate_axis(axis) + validate_axis(x.ndim, axis) return is_constant_(x, axis=axis) @@ -144,7 +144,7 @@ def mean( """ from ._mean import mean_ - validate_axis(axis) + validate_axis(x.ndim, axis) return mean_(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] # literally the same type, wtf mypy @@ -284,5 +284,5 @@ def sum( """ from ._sum import sum_ - validate_axis(axis) + validate_axis(x.ndim, axis) return sum_(x, axis=axis, dtype=dtype) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index feb64bf..f215336 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -76,15 +76,17 @@ def sum_drop_keepdims( keepdims: bool = False, ) -> NDArray[Any] | types.CupyArray: del keepdims - match axis: - case (0 | 1 as n,): - axis = n - case (0, 1) | (1, 0): + match (axis, a.ndim): + case (0 | (0,), 1) | ((0, 1) | (1, 0), _): axis = None + case (0 | 1 as n,), _: + axis = n case tuple(): # pragma: no cover msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" raise ValueError(msg) rv = sum(a, axis=axis, dtype=dtype) + if a.ndim == 1: # make sure rv is 1D + return np.reshape(rv, (1,)) # make sure rv is 2D return np.reshape(rv, (1, 1 if rv.shape == () else len(rv))) # type: ignore[arg-type] diff --git a/tests/test_stats.py b/tests/test_stats.py index cc2a608..7e70cb5 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -24,7 +24,9 @@ Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray DTypeIn = np.float32 | np.float64 | np.int32 | np.bool - DTypeOut = type[np.float32 | np.float64 | np.int64] + DTypeOut = np.float32 | np.float64 | np.int64 + + NdAndAx = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] class BenchFun(Protocol): # noqa: D101 def __call__( # noqa: D102 @@ -32,7 +34,7 @@ def __call__( # noqa: D102 arr: CpuArray, *, axis: Literal[0, 1, None] = None, - dtype: DTypeOut | None = None, + dtype: type[DTypeOut] | None = None, ) -> NDArray[Any] | np.number[Any] | types.DaskArray: ... @@ -44,14 +46,33 @@ def __call__( # noqa: D102 ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)} -@pytest.fixture(scope="session", params=[0, 1, None], ids=["ax0", "ax1", "all"]) -def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]: - return cast("Literal[0, 1, None]", request.param) +@pytest.fixture( + scope="session", + params=[ + pytest.param((1, None), id="1d-all"), + pytest.param((2, None), id="2d-all"), + pytest.param((2, 0), id="2d-ax0"), + pytest.param((2, 1), id="2d-ax1"), + ], +) +def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx: + return cast("NdAndAx", request.param) + + +@pytest.fixture +def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]: + ndim = ndim_and_axis[0] + inner_cls = array_type.inner.cls if array_type.inner else array_type.cls + if ndim != 2 and issubclass(inner_cls, types.CSMatrix): + pytest.skip("CSMatrix only supports 2D") + if ndim != 2 and inner_cls is types.csc_array: + pytest.skip("csc_array only supports 2D") + return ndim -@pytest.fixture(scope="session", params=[1, 2], ids=["1d", "2d"]) -def ndim(request: pytest.FixtureRequest) -> Literal[1, 2]: - return cast("Literal[1, 2]", request.param) +@pytest.fixture(scope="session") +def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]: + return ndim_and_axis[1] @pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool]) @@ -60,13 +81,14 @@ def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]: @pytest.fixture(scope="session", params=[np.float32, np.float64, None]) -def dtype_arg(request: pytest.FixtureRequest) -> DTypeOut | None: - return cast("DTypeOut | None", request.param) +def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: + return cast("type[DTypeOut] | None", request.param) -@pytest.fixture(scope="session") +@pytest.fixture def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) + np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() return np_arr @@ -76,7 +98,7 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: def test_sum( array_type: ArrayType[Array], dtype_in: type[DTypeIn], - dtype_arg: DTypeOut | None, + dtype_arg: type[DTypeOut] | None, axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn], ) -> None: @@ -117,13 +139,14 @@ def test_sum( @pytest.mark.array_type(skip=ATS_SPARSE_DS) -@pytest.mark.parametrize(("axis", "expected"), [(None, 3.5), (0, [2.5, 3.5, 4.5]), (1, [2.0, 5.0])]) def test_mean( - array_type: ArrayType[Array], - axis: Literal[0, 1, None], - expected: float | list[float], - np_arr: NDArray[DTypeIn], + array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] ) -> None: + expected = { + None: 3.5, + 0: [2.5, 3.5, 4.5], + 1: [2.0, 5.0], + }[None] if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) From b5e4ae7e7c8998931ca8db381b47564999f94caa Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 14:42:25 +0200 Subject: [PATCH 03/10] fix types --- tests/test_stats.py | 8 ++++---- typings/cupy/_core/core.pyi | 1 + typings/cupyx/scipy/sparse/_base.pyi | 1 + typings/h5py.pyi | 1 + 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 7e70cb5..0dfee44 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -87,7 +87,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: @pytest.fixture def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: - np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) + np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)) np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() @@ -149,7 +149,7 @@ def test_mean( }[None] if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") - np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) + np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) # type: ignore[arg-type] arr = array_type(np_arr) result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777 @@ -172,8 +172,8 @@ def test_mean_var( var_expected: float | list[float], np_arr: NDArray[DTypeIn], ) -> None: - np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected) - np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected) + np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected) # type: ignore[arg-type] + np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected) # type: ignore[arg-type] arr = array_type(np_arr) mean, var = stats.mean_var(arr, axis=axis, correction=1) diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi index e95e036..34c9135 100644 --- a/typings/cupy/_core/core.pyi +++ b/typings/cupy/_core/core.pyi @@ -8,6 +8,7 @@ from numpy.typing import NDArray class ndarray: dtype: np.dtype[Any] shape: tuple[int, ...] + ndim: int # cupy-specific def get(self) -> NDArray[Any]: ... diff --git a/typings/cupyx/scipy/sparse/_base.pyi b/typings/cupyx/scipy/sparse/_base.pyi index 4f24482..3cf37c5 100644 --- a/typings/cupyx/scipy/sparse/_base.pyi +++ b/typings/cupyx/scipy/sparse/_base.pyi @@ -9,6 +9,7 @@ from numpy.typing import NDArray class spmatrix: dtype: np.dtype[Any] shape: tuple[int, int] + ndim: int def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ... def __power__(self, other: int) -> Self: ... def __array__(self) -> NDArray[Any]: ... diff --git a/typings/h5py.pyi b/typings/h5py.pyi index 67bfd50..c6a211d 100644 --- a/typings/h5py.pyi +++ b/typings/h5py.pyi @@ -12,6 +12,7 @@ class HLObject: ... class Dataset(HLObject): dtype: np.dtype[Any] shape: tuple[int, ...] + ndim: int class Group(HLObject): ... From 69ab3711ef83b05a4c099d54eec135d54657d719 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 16:17:04 +0200 Subject: [PATCH 04/10] some fixes --- src/fast_array_utils/stats/_sum.py | 8 +++----- tests/test_stats.py | 7 +------ 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index f215336..8b69b38 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -81,14 +81,12 @@ def sum_drop_keepdims( axis = None case (0 | 1 as n,), _: axis = n - case tuple(): # pragma: no cover + case tuple(), _: # pragma: no cover msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" raise ValueError(msg) rv = sum(a, axis=axis, dtype=dtype) - if a.ndim == 1: # make sure rv is 1D - return np.reshape(rv, (1,)) - # make sure rv is 2D - return np.reshape(rv, (1, 1 if rv.shape == () else len(rv))) # type: ignore[arg-type] + shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) + return np.reshape(rv, shape) if dtype is None: # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) diff --git a/tests/test_stats.py b/tests/test_stats.py index 0dfee44..98cf230 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -142,14 +142,9 @@ def test_sum( def test_mean( array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] ) -> None: - expected = { - None: 3.5, - 0: [2.5, 3.5, 4.5], - 1: [2.0, 5.0], - }[None] + expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type] if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") - np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) # type: ignore[arg-type] arr = array_type(np_arr) result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777 From fc4368340196b96eac9aec01f3256219d4726383 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 16:41:57 +0200 Subject: [PATCH 05/10] fmt --- tests/test_stats.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 98cf230..6a67bdd 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -142,16 +142,17 @@ def test_sum( def test_mean( array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] ) -> None: - expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type] if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": pytest.skip("CuPy sparse matrices only support floats") - arr = array_type(np_arr) + result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777 if isinstance(result, types.DaskArray): result = result.compute() if isinstance(result, types.CupyArray | types.CupyCSMatrix): result = result.get() + + expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type] np.testing.assert_array_equal(result, expected) From 308e2a3cf89b7fa0c9044f20d8476309a851ea14 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 17:09:00 +0200 Subject: [PATCH 06/10] fix typing --- src/fast_array_utils/stats/_sum.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 8b69b38..12f3b2b 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -52,7 +52,9 @@ def _sum_cs( if isinstance(x, types.CSMatrix): x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) - return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype)) # type: ignore[call-overload] + if axis is None: + return cast("np.number[Any]", x.data.sum(dtype=dtype)) + return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype)) @sum_.register(types.DaskArray) @@ -76,16 +78,19 @@ def sum_drop_keepdims( keepdims: bool = False, ) -> NDArray[Any] | types.CupyArray: del keepdims - match (axis, a.ndim): - case (0 | (0,), 1) | ((0, 1) | (1, 0), _): - axis = None - case (0 | 1 as n,), _: - axis = n - case tuple(), _: # pragma: no cover - msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" - raise ValueError(msg) + if a.ndim == 1: + axis = None + else: + match axis: + case (0, 1) | (1, 0): + axis = None + case (0 | 1 as n,): + axis = n + case tuple(): # pragma: no cover + msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" + raise ValueError(msg) rv = sum(a, axis=axis, dtype=dtype) - shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) + shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type] return np.reshape(rv, shape) if dtype is None: From f1ec75b090088cd2bf60172edf75853684ffb2c1 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 17:43:29 +0200 Subject: [PATCH 07/10] fix mean_var and cupy tests --- tests/test_stats.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 6a67bdd..d7d540d 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -75,9 +75,13 @@ def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]: return ndim_and_axis[1] -@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool]) -def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]: - return cast("type[DTypeIn]", request.param) +@pytest.fixture(params=[np.float32, np.float64, np.int32, np.bool]) +def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTypeIn]: + dtype = cast("type[DTypeIn]", request.param) + inner_cls = array_type.inner.cls if array_type.inner else array_type.cls + if np.dtype(dtype).kind not in "fdFD" and issubclass(inner_cls, types.CupyCSMatrix): + pytest.skip("Cupy sparse matrices don’t support non-floating dtypes") + return dtype @pytest.fixture(scope="session", params=[np.float32, np.float64, None]) @@ -157,26 +161,21 @@ def test_mean( @pytest.mark.array_type(skip=Flags.Disk) -@pytest.mark.parametrize( - ("axis", "mean_expected", "var_expected"), - [(None, 3.5, 3.5), (0, [2.5, 3.5, 4.5], [4.5, 4.5, 4.5]), (1, [2.0, 5.0], [1.0, 1.0])], -) def test_mean_var( array_type: ArrayType[CpuArray | GpuArray | types.DaskArray], axis: Literal[0, 1, None], - mean_expected: float | list[float], - var_expected: float | list[float], np_arr: NDArray[DTypeIn], ) -> None: - np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected) # type: ignore[arg-type] - np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected) # type: ignore[arg-type] - arr = array_type(np_arr) + mean, var = stats.mean_var(arr, axis=axis, correction=1) if isinstance(mean, types.DaskArray) and isinstance(var, types.DaskArray): mean, var = mean.compute(), var.compute() # type: ignore[assignment] if isinstance(mean, types.CupyArray) and isinstance(var, types.CupyArray): mean, var = mean.get(), var.get() + + mean_expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type] + var_expected = np.var(np_arr, axis=axis, correction=1) # type: ignore[arg-type] np.testing.assert_array_equal(mean, mean_expected) np.testing.assert_array_almost_equal(var, var_expected) # type: ignore[arg-type] From 0d3b5985536d53df332ac3dab6148591f494aff3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 18:25:33 +0200 Subject: [PATCH 08/10] test errors --- src/fast_array_utils/stats/__init__.py | 1 + tests/test_stats.py | 39 ++++++++++++++++++++------ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index 0fcb88b..30d132a 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -219,6 +219,7 @@ def mean_var( """ from ._mean_var import mean_var_ + validate_axis(x.ndim, axis) return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return] diff --git a/tests/test_stats.py b/tests/test_stats.py index d7d540d..466b895 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -6,6 +6,7 @@ import numpy as np import pytest +from numpy.exceptions import AxisError from fast_array_utils import stats, types from testing.fast_array_utils import SUPPORTED_TYPES, Flags @@ -28,7 +29,7 @@ NdAndAx = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] - class BenchFun(Protocol): # noqa: D101 + class StatFun(Protocol): # noqa: D101 def __call__( # noqa: D102 self, arr: CpuArray, @@ -41,6 +42,8 @@ def __call__( # noqa: D102 pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")] +STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant] + # can’t select these using a category filter ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"} ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)} @@ -61,9 +64,12 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx: @pytest.fixture def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]: - ndim = ndim_and_axis[0] + return check_ndim(array_type, ndim_and_axis[0]) + + +def check_ndim(array_type: ArrayType, ndim: Literal[1, 2]) -> Literal[1, 2]: inner_cls = array_type.inner.cls if array_type.inner else array_type.cls - if ndim != 2 and issubclass(inner_cls, types.CSMatrix): + if ndim != 2 and issubclass(inner_cls, types.CSMatrix | types.CupyCSMatrix): pytest.skip("CSMatrix only supports 2D") if ndim != 2 and inner_cls is types.csc_array: pytest.skip("csc_array only supports 2D") @@ -98,6 +104,25 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: return np_arr +@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix}) +@pytest.mark.parametrize("func", STAT_FUNCS) +@pytest.mark.parametrize( + ("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"] +) +def test_ndim_error( + array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1, None] +) -> None: + check_ndim(array_type, ndim) + # not using the fixture because we don’t need to test multiple dtypes + np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + if ndim == 1: + np_arr = np_arr.flatten() + arr = array_type(np_arr) + + with pytest.raises(AxisError): + func(arr, axis=axis) + + @pytest.mark.array_type(skip=ATS_SPARSE_DS) def test_sum( array_type: ArrayType[Array], @@ -106,8 +131,6 @@ def test_sum( axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn], ) -> None: - if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": - pytest.skip("CuPy sparse matrices only support floats") arr = array_type(np_arr.copy()) assert arr.dtype == dtype_in @@ -146,8 +169,6 @@ def test_sum( def test_mean( array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn] ) -> None: - if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f": - pytest.skip("CuPy sparse matrices only support floats") arr = array_type(np_arr) result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777 @@ -231,11 +252,11 @@ def test_dask_constant_blocks( @pytest.mark.benchmark @pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu) -@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant]) +@pytest.mark.parametrize("func", argvalues=STAT_FUNCS) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_stats_benchmark( benchmark: BenchmarkFixture, - func: BenchFun, + func: StatFun, array_type: ArrayType[CpuArray, None], axis: Literal[0, 1, None], dtype: type[np.float32 | np.float64], From a4a36585d6cd43c663a7b383be6b48fd3255261c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 18:37:26 +0200 Subject: [PATCH 09/10] typing --- tests/test_stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 466b895..d9b94de 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -27,12 +27,12 @@ DTypeIn = np.float32 | np.float64 | np.int32 | np.bool DTypeOut = np.float32 | np.float64 | np.int64 - NdAndAx = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] + NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]] class StatFun(Protocol): # noqa: D101 def __call__( # noqa: D102 self, - arr: CpuArray, + arr: Array, *, axis: Literal[0, 1, None] = None, dtype: type[DTypeOut] | None = None, From 6bcd3b78fabc9f9899aff69e036532b34e72778e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Apr 2025 18:38:27 +0200 Subject: [PATCH 10/10] oops --- tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index d9b94de..dfe8fad 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -252,7 +252,7 @@ def test_dask_constant_blocks( @pytest.mark.benchmark @pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu) -@pytest.mark.parametrize("func", argvalues=STAT_FUNCS) +@pytest.mark.parametrize("func", STAT_FUNCS) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_stats_benchmark( benchmark: BenchmarkFixture,