Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/fast_array_utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from __future__ import annotations

import numpy as np
from numpy.exceptions import AxisError


def validate_axis(axis: int | None) -> None:
def validate_axis(ndim: int, axis: int | None) -> None:
if axis is None:
return
if not isinstance(axis, int | np.integer): # pragma: no cover
msg = "axis must be integer or None."
raise TypeError(msg)
if axis == 0 and ndim == 1:
raise AxisError(axis, ndim, "use axis=None for 1D arrays")
if axis not in range(ndim):
raise AxisError(axis, ndim)
if axis not in (0, 1): # pragma: no cover
msg = "We only support axis 0 and 1 at the moment"
raise NotImplementedError(msg)
7 changes: 4 additions & 3 deletions src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def is_constant(
"""
from ._is_constant import is_constant_

validate_axis(axis)
validate_axis(x.ndim, axis)
return is_constant_(x, axis=axis)


Expand Down Expand Up @@ -144,7 +144,7 @@ def mean(
"""
from ._mean import mean_

validate_axis(axis)
validate_axis(x.ndim, axis)
return mean_(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] # literally the same type, wtf mypy


Expand Down Expand Up @@ -219,6 +219,7 @@ def mean_var(
"""
from ._mean_var import mean_var_

validate_axis(x.ndim, axis)
return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return]


Expand Down Expand Up @@ -284,5 +285,5 @@ def sum(
"""
from ._sum import sum_

validate_axis(axis)
validate_axis(x.ndim, axis)
return sum_(x, axis=axis, dtype=dtype)
27 changes: 16 additions & 11 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
if isinstance(x, types.CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)

return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype)) # type: ignore[call-overload]
if axis is None:
return cast("np.number[Any]", x.data.sum(dtype=dtype))
Comment on lines +55 to +56
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this, 1D csr_arrays sum up to np.float64 no matter if bool, int, or float32.

that’s not the case for 2D ones, so I should probably report this as an issue to scipy.

return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype))


@sum_.register(types.DaskArray)
Expand All @@ -76,17 +78,20 @@
keepdims: bool = False,
) -> NDArray[Any] | types.CupyArray:
del keepdims
match axis:
case (0 | 1 as n,):
axis = n
case (0, 1) | (1, 0):
axis = None
case tuple(): # pragma: no cover
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)
if a.ndim == 1:
axis = None

Check warning on line 82 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L81-L82

Added lines #L81 - L82 were not covered by tests
else:
match axis:
case (0, 1) | (1, 0):
axis = None
case (0 | 1 as n,):
axis = n

Check warning on line 88 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L84-L88

Added lines #L84 - L88 were not covered by tests
case tuple(): # pragma: no cover
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)
rv = sum(a, axis=axis, dtype=dtype)
# make sure rv is 2D
return np.reshape(rv, (1, 1 if rv.shape == () else len(rv))) # type: ignore[arg-type]
shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type]
return np.reshape(rv, shape)

Check warning on line 94 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L93-L94

Added lines #L93 - L94 were not covered by tests

if dtype is None:
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
Expand Down
79 changes: 54 additions & 25 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pytest
from numpy.exceptions import AxisError

from fast_array_utils import stats, types
from testing.fast_array_utils import SUPPORTED_TYPES, Flags
Expand All @@ -26,12 +27,12 @@
DTypeIn = np.float32 | np.float64 | np.int32 | np.bool
DTypeOut = np.float32 | np.float64 | np.int64

NdAndAx: TypeAlias = tuple[Literal[2], Literal[0, 1, None]]
NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]]

class BenchFun(Protocol): # noqa: D101
class StatFun(Protocol): # noqa: D101
def __call__( # noqa: D102
self,
arr: CpuArray,
arr: Array,
*,
axis: Literal[0, 1, None] = None,
dtype: type[DTypeOut] | None = None,
Expand All @@ -41,6 +42,8 @@ def __call__( # noqa: D102
pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")]


STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant]

# can’t select these using a category filter
ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"}
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
Expand All @@ -49,6 +52,7 @@ def __call__( # noqa: D102
@pytest.fixture(
scope="session",
params=[
pytest.param((1, None), id="1d-all"),
pytest.param((2, None), id="2d-all"),
pytest.param((2, 0), id="2d-ax0"),
pytest.param((2, 1), id="2d-ax1"),
Expand All @@ -59,18 +63,31 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:


@pytest.fixture
def ndim(ndim_and_axis: NdAndAx) -> Literal[2]:
return ndim_and_axis[0]
def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]:
return check_ndim(array_type, ndim_and_axis[0])


def check_ndim(array_type: ArrayType, ndim: Literal[1, 2]) -> Literal[1, 2]:
inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
if ndim != 2 and issubclass(inner_cls, types.CSMatrix | types.CupyCSMatrix):
pytest.skip("CSMatrix only supports 2D")
if ndim != 2 and inner_cls is types.csc_array:
pytest.skip("csc_array only supports 2D")
return ndim


@pytest.fixture(scope="session")
def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]:
return ndim_and_axis[1]


@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]:
return cast("type[DTypeIn]", request.param)
@pytest.fixture(params=[np.float32, np.float64, np.int32, np.bool])
def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTypeIn]:
dtype = cast("type[DTypeIn]", request.param)
inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
if np.dtype(dtype).kind not in "fdFD" and issubclass(inner_cls, types.CupyCSMatrix):
pytest.skip("Cupy sparse matrices don’t support non-floating dtypes")
return dtype


@pytest.fixture(scope="session", params=[np.float32, np.float64, None])
Expand All @@ -79,12 +96,33 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:


@pytest.fixture
def np_arr(dtype_in: type[DTypeIn]) -> NDArray[DTypeIn]:
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
np_arr.flags.writeable = False
if ndim == 1:
np_arr = np_arr.flatten()
return np_arr


@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
@pytest.mark.parametrize("func", STAT_FUNCS)
@pytest.mark.parametrize(
("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]
)
def test_ndim_error(
array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1, None]
) -> None:
check_ndim(array_type, ndim)
# not using the fixture because we don’t need to test multiple dtypes
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
if ndim == 1:
np_arr = np_arr.flatten()
arr = array_type(np_arr)

with pytest.raises(AxisError):
func(arr, axis=axis)


@pytest.mark.array_type(skip=ATS_SPARSE_DS)
def test_sum(
array_type: ArrayType[Array],
Expand All @@ -93,8 +131,6 @@ def test_sum(
axis: Literal[0, 1, None],
np_arr: NDArray[DTypeIn],
) -> None:
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
pytest.skip("CuPy sparse matrices only support floats")
arr = array_type(np_arr.copy())
assert arr.dtype == dtype_in

Expand Down Expand Up @@ -133,8 +169,6 @@ def test_sum(
def test_mean(
array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]
) -> None:
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
pytest.skip("CuPy sparse matrices only support floats")
arr = array_type(np_arr)

result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
Expand All @@ -148,26 +182,21 @@ def test_mean(


@pytest.mark.array_type(skip=Flags.Disk)
@pytest.mark.parametrize(
("axis", "mean_expected", "var_expected"),
[(None, 3.5, 3.5), (0, [2.5, 3.5, 4.5], [4.5, 4.5, 4.5]), (1, [2.0, 5.0], [1.0, 1.0])],
)
def test_mean_var(
array_type: ArrayType[CpuArray | GpuArray | types.DaskArray],
axis: Literal[0, 1, None],
mean_expected: float | list[float],
var_expected: float | list[float],
np_arr: NDArray[DTypeIn],
) -> None:
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)

arr = array_type(np_arr)

mean, var = stats.mean_var(arr, axis=axis, correction=1)
if isinstance(mean, types.DaskArray) and isinstance(var, types.DaskArray):
mean, var = mean.compute(), var.compute() # type: ignore[assignment]
if isinstance(mean, types.CupyArray) and isinstance(var, types.CupyArray):
mean, var = mean.get(), var.get()

mean_expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type]
var_expected = np.var(np_arr, axis=axis, correction=1) # type: ignore[arg-type]
np.testing.assert_array_equal(mean, mean_expected)
np.testing.assert_array_almost_equal(var, var_expected) # type: ignore[arg-type]

Expand Down Expand Up @@ -223,11 +252,11 @@ def test_dask_constant_blocks(

@pytest.mark.benchmark
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant])
@pytest.mark.parametrize("func", STAT_FUNCS)
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
def test_stats_benchmark(
benchmark: BenchmarkFixture,
func: BenchFun,
func: StatFun,
array_type: ArrayType[CpuArray, None],
axis: Literal[0, 1, None],
dtype: type[np.float32 | np.float64],
Expand Down
1 change: 1 addition & 0 deletions typings/cupy/_core/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from numpy.typing import NDArray
class ndarray:
dtype: np.dtype[Any]
shape: tuple[int, ...]
ndim: int

# cupy-specific
def get(self) -> NDArray[Any]: ...
Expand Down
1 change: 1 addition & 0 deletions typings/cupyx/scipy/sparse/_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from numpy.typing import NDArray
class spmatrix:
dtype: np.dtype[Any]
shape: tuple[int, int]
ndim: int
def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ...
def __power__(self, other: int) -> Self: ...
def __array__(self) -> NDArray[Any]: ...
Expand Down
1 change: 1 addition & 0 deletions typings/h5py.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class HLObject: ...
class Dataset(HLObject):
dtype: np.dtype[Any]
shape: tuple[int, ...]
ndim: int

class Group(HLObject): ...

Expand Down
Loading