diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py index a4aaa29..e87bf29 100644 --- a/src/fast_array_utils/stats/__init__.py +++ b/src/fast_array_utils/stats/__init__.py @@ -7,10 +7,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, cast, get_args, overload from .._validation import validate_axis from ..typing import CpuArray, DiskArray, GpuArray # noqa: TC001 +from ._generic_ops import DtypeOps if TYPE_CHECKING: @@ -21,9 +22,11 @@ from optype.numpy import ToDType from .. import types + from ._generic_ops import Ops + from ._typing import NoDtypeOps, StatFunDtype, StatFunNoDtype -__all__ = ["is_constant", "mean", "mean_var", "sum"] +__all__ = ["is_constant", "max", "mean", "mean_var", "min", "sum"] @overload @@ -201,26 +204,161 @@ def mean_var( return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return] +@overload +def _mk_generic_op(op: NoDtypeOps) -> StatFunNoDtype: ... +@overload +def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ... 
+ + # TODO(flying-sheep): support CSDataset (TODO) # https://github.com/scverse/fast-array-utils/issues/52 +def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype: + def _generic_op( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, + ) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + from ._generic_ops import generic_op + + assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation '{op}'" + + validate_axis(x.ndim, axis) + return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype) + + _generic_op.__name__ = op + return cast("StatFunNoDtype | StatFunDtype", _generic_op) + + +_min = _mk_generic_op("min") +_max = _mk_generic_op("max") +_sum = _mk_generic_op("sum") + + @overload -def sum(x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +def min(x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... @overload -def sum(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +def min(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +@overload +def min(x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... +@overload +def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
+def min( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + keep_cupy_as_array: bool = False, +) -> object: + """Find the minimum along both or one axis. + + Parameters + ---------- + x + Array to find the minimum(s) in. + axis + Axis to reduce over. + + Returns + ------- + If ``axis`` is :data:`None`, then the minimum element is returned as a scalar. + Otherwise, the minimum along the given axis is returned as a 1D array. + + Example + ------- + >>> import numpy as np + >>> x = np.array([ + ... [0, 1, 2], + ... [1, 1, 1], + ... ]) + >>> min(x) + 0 + >>> min(x, axis=0) + array([0, 1, 1]) + >>> min(x, axis=1) + array([0, 1]) + + See Also + -------- + :func:`numpy.min` + + """ + return _min(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) @overload -def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +def max(x: CpuArray | DiskArray, /, *, axis: None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... @overload -def sum(x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +def max(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> NDArray[Any]: ... @overload -def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... +def max(x: GpuArray, /, *, axis: None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ... +@overload +def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
+def max( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + keep_cupy_as_array: bool = False, +) -> object: + """Find the maximum along both or one axis. + + Parameters + ---------- + x + Array to find the maximum(s) in. + axis + Axis to reduce over. + Returns + ------- + If ``axis`` is :data:`None`, then the maximum element is returned as a scalar. + Otherwise, the maximum along the given axis is returned as a 1D array. -@overload -def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... + Example + ------- + >>> import numpy as np + >>> x = np.array([ + ... [0, 1, 2], + ... [0, 0, 0], + ... ]) + >>> max(x) + 2 + >>> max(x, axis=0) + array([0, 1, 2]) + >>> max(x, axis=1) + array([2, 0]) + See Also + -------- + :func:`numpy.max` + """ + return _max(x, axis=axis, keep_cupy_as_array=keep_cupy_as_array) + + +@overload +def sum(x: CpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> np.number[Any]: ... +@overload +def sum(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> NDArray[Any]: ... +@overload +def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[False] = False) -> np.number[Any]: ... +@overload +def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]) -> types.CupyArray: ... +@overload +def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ... +@overload +def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ... 
def sum( x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, @@ -262,7 +400,4 @@ def sum( :func:`numpy.sum` """ - from ._sum import sum_ - - validate_axis(x.ndim, axis) - return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) + return _sum(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_generic_ops.py b/src/fast_array_utils/stats/_generic_ops.py new file mode 100644 index 0000000..d342517 --- /dev/null +++ b/src/fast_array_utils/stats/_generic_ops.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from functools import singledispatch +from typing import TYPE_CHECKING, cast, get_args + +import numpy as np + +from .. import types +from ._typing import DtypeOps +from ._utils import _dask_inner + + +if TYPE_CHECKING: + from typing import Any, Literal, TypeAlias + + from numpy.typing import DTypeLike, NDArray + + from ..typing import CpuArray, DiskArray, GpuArray + from ._typing import Ops + + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None + + +def _run_numpy_op( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, +) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} + return getattr(np, op)(x, axis=axis, **kwargs) # type: ignore[no-any-return] + + +@singledispatch +def generic_op( + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: + del keep_cupy_as_array + if TYPE_CHECKING: + # these are never passed to this fallback function, but `singledispatch` wants them + assert not isinstance(x, types.CSBase | 
types.DaskArray | types.CupyArray | types.CupyCSMatrix) + # np supports these, but doesn’t know it. (TODO: test cupy) + assert not isinstance(x, types.ZarrArray | types.H5Dataset) + return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype)) + + +@generic_op.register(types.CupyArray | types.CupyCSMatrix) +def _generic_op_cupy( + x: GpuArray, + /, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> types.CupyArray | np.number[Any]: + arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype)) + return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() + + +@generic_op.register(types.CSBase) +def _generic_op_cs( + x: types.CSBase, + /, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> NDArray[Any] | np.number[Any]: + del keep_cupy_as_array + import scipy.sparse as sp + + # TODO(flying-sheep): once scipy fixes this issue, instead of all this, + # just convert to sparse array, then `return x.{op}(dtype=dtype)` + # https://github.com/scipy/scipy/issues/23768 + + kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {} + if axis is None: + return cast("np.number[Any]", getattr(x.data, op)(**kwargs) if op in get_args(DtypeOps) or x.nnz == np.prod(x.shape) else getattr(np, op)(np.append(x.data, x.dtype.type(0))))  # min/max must account for implicit zeros absent from x.data + if TYPE_CHECKING: # scipy-stubs thinks e.g.
"int64" is invalid, which isn’t true + assert isinstance(dtype, np.dtype | type | None) + # convert to array so dimensions collapse as expected + x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs) # type: ignore[call-overload] + return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis)) + + +@generic_op.register(types.DaskArray) +def _generic_op_dask( + x: types.DaskArray, + /, + op: Ops, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, +) -> types.DaskArray: + if op in get_args(DtypeOps) and dtype is None: + # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) + dtype = getattr(np, op)(np.zeros(1, dtype=x.dtype)).dtype + + return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array) diff --git a/src/fast_array_utils/stats/_mean.py b/src/fast_array_utils/stats/_mean.py index 9588a77..ba08164 100644 --- a/src/fast_array_utils/stats/_mean.py +++ b/src/fast_array_utils/stats/_mean.py @@ -5,7 +5,7 @@ import numpy as np -from ._sum import sum_ +from . 
import sum if TYPE_CHECKING: @@ -24,6 +24,6 @@ def mean_( axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, ) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray: - total = sum_(x, axis=axis, dtype=dtype) + total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type] n = np.prod(x.shape) if axis is None else x.shape[axis] - return total / n # type: ignore[operator,return-value] + return total / n # type: ignore[no-any-return] diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py deleted file mode 100644 index d3f09c8..0000000 --- a/src/fast_array_utils/stats/_sum.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: MPL-2.0 -from __future__ import annotations - -from functools import partial, singledispatch -from typing import TYPE_CHECKING, Literal, cast - -import numpy as np -from numpy.exceptions import AxisError - -from .. import types - - -if TYPE_CHECKING: - from typing import Any, Literal, TypeAlias - - from numpy.typing import DTypeLike, NDArray - - from ..typing import CpuArray, DiskArray, GpuArray - - ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None - - -@singledispatch -def sum_( - x: CpuArray | GpuArray | DiskArray | types.DaskArray, - /, - *, - axis: Literal[0, 1] | None = None, - dtype: DTypeLike | None = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: - del keep_cupy_as_array - if TYPE_CHECKING: - # these are never passed to this fallback function, but `singledispatch` wants them - assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix) - # np.sum supports these, but doesn’t know it. 
(TODO: test cupy) - assert not isinstance(x, types.ZarrArray | types.H5Dataset) - return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype)) - - -@sum_.register(types.CupyArray | types.CupyCSMatrix) -def _sum_cupy( - x: GpuArray, - /, - *, - axis: Literal[0, 1] | None = None, - dtype: DTypeLike | None = None, - keep_cupy_as_array: bool = False, -) -> types.CupyArray | np.number[Any]: - arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype)) - return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze() - - -@sum_.register(types.CSBase) -def _sum_cs( - x: types.CSBase, - /, - *, - axis: Literal[0, 1] | None = None, - dtype: DTypeLike | None = None, - keep_cupy_as_array: bool = False, -) -> NDArray[Any] | np.number[Any]: - del keep_cupy_as_array - import scipy.sparse as sp - - # TODO(flying-sheep): once scipy fixes this issue, instead of all this, - # just convert to sparse array, then `return x.sum(dtype=dtype)` - # https://github.com/scipy/scipy/issues/23768 - - if axis is None: - return cast("NDArray[Any] | np.number[Any]", x.data.sum(dtype=dtype)) - - if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true - assert isinstance(dtype, np.dtype | type | None) - # convert to array so dimensions collapse as expected - x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype) - return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis)) - - -@sum_.register(types.DaskArray) -def _sum_dask( - x: types.DaskArray, - /, - *, - axis: Literal[0, 1] | None = None, - dtype: DTypeLike | None = None, - keep_cupy_as_array: bool = False, -) -> types.DaskArray: - import dask.array as da - - if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 - msg = "sum does not support numpy matrices" - raise TypeError(msg) - - if dtype is None: - # Explicitly use numpy result dtype (e.g. 
`NDArray[bool].sum().dtype == int64`) - dtype = np.zeros(1, dtype=x.dtype).sum().dtype - - rv = da.reduction( - x, - partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] - partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] - axis=axis, - dtype=dtype, - meta=np.array([], dtype=dtype), - ) - - if axis is not None or ( - isinstance(rv._meta, types.CupyArray) # noqa: SLF001 - and keep_cupy_as_array - ): - return rv - - def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: - if isinstance(a, types.CupyArray): - a = a.get() - return a.reshape(())[()] # type: ignore[return-value] - - return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] - - -def sum_dask_inner( - a: CpuArray | GpuArray, - /, - *, - axis: ComplexAxis = None, - dtype: DTypeLike | None = None, - keepdims: bool = False, -) -> NDArray[Any] | types.CupyArray: - from . import sum - - axis = normalize_axis(axis, a.ndim) - rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type] - shape = get_shape(rv, axis=axis, keepdims=keepdims) - return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) - - -def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None: - """Adapt `axis` parameter passed by Dask to what we support.""" - match axis: - case int() | None: - pass - case (0 | 1,): - axis = axis[0] - case (0, 1) | (1, 0): - axis = None - case _: # pragma: no cover - raise AxisError(axis, ndim) # type: ignore[call-overload] - if axis == 0 and ndim == 1: - return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays - return axis - - -def get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1] | None, keepdims: bool) -> tuple[int] | tuple[int, int]: - """Get the output shape of an axis-flattening operation.""" - match keepdims, a.ndim: - case False, 0: - return (1,) - case True, 0: - return (1, 1) - case False, 1: - return 
(a.size,) - case True, 1: - assert axis is not None - return (1, a.size) if axis == 0 else (a.size, 1) - # pragma: no cover - msg = f"{keepdims=}, {type(a)}" - raise AssertionError(msg) diff --git a/src/fast_array_utils/stats/_typing.py b/src/fast_array_utils/stats/_typing.py new file mode 100644 index 0000000..e8b0b65 --- /dev/null +++ b/src/fast_array_utils/stats/_typing.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Protocol + +import numpy as np + +from fast_array_utils import types + +from ..typing import CpuArray, DiskArray, GpuArray + + +if TYPE_CHECKING: + from typing import Any, TypeAlias + + from numpy.typing import DTypeLike, NDArray + + +Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray + +DTypeIn: TypeAlias = np.float32 | np.float64 | np.int32 | np.bool_ +DTypeOut: TypeAlias = np.float32 | np.float64 | np.int64 + +NdAndAx: TypeAlias = tuple[Literal[1], None] | tuple[Literal[2], Literal[0, 1] | None] + + +class StatFunNoDtype(Protocol): + __name__: str + + def __call__( + self, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False + ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: ... + + +class StatFunDtype(Protocol): + __name__: str + + def __call__( + self, + x: CpuArray | GpuArray | DiskArray | types.DaskArray, + /, + *, + axis: Literal[0, 1] | None = None, + dtype: DTypeLike | None = None, + keep_cupy_as_array: bool = False, + ) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: ...
+ + +NoDtypeOps = Literal["max", "min"] +DtypeOps = Literal["sum"] +Ops: TypeAlias = NoDtypeOps | DtypeOps diff --git a/src/fast_array_utils/stats/_utils.py b/src/fast_array_utils/stats/_utils.py new file mode 100644 index 0000000..7d1fcbb --- /dev/null +++ b/src/fast_array_utils/stats/_utils.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Literal, cast, get_args + +import numpy as np +from numpy.exceptions import AxisError + +from .. import types +from ._typing import DtypeOps + + +if TYPE_CHECKING: + from typing import Any, Literal, TypeAlias + + from numpy.typing import DTypeLike, NDArray + + from ..typing import CpuArray, GpuArray + from ._typing import Ops + + ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None + + +__all__ = ["_dask_inner"] + + +def _dask_inner(x: types.DaskArray, op: Ops, /, *, axis: Literal[0, 1] | None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool) -> types.DaskArray: + import dask.array as da + + if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001 + msg = "sum/max/min does not support numpy matrices" + raise TypeError(msg) + + res_dtype = dtype if op in get_args(DtypeOps) else x.dtype + + rv = da.reduction( + x, + partial(_dask_block, op, dtype=dtype), + partial(_dask_block, op, dtype=dtype), + axis=axis, + dtype=res_dtype, + meta=np.array([], dtype=res_dtype), + ) + + if axis is not None or ( + isinstance(rv._meta, types.CupyArray) # noqa: SLF001 + and keep_cupy_as_array + ): + return rv + + def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]: + if isinstance(a, types.CupyArray): + a = a.get() + return a.reshape(())[()] # type: ignore[return-value] + + return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type] + + +def _dask_block( + op: Ops, + a: CpuArray | GpuArray, + /, + *, + axis: ComplexAxis = None, + dtype: 
DTypeLike | None = None, + keepdims: bool = False, +) -> NDArray[Any] | types.CupyArray: + from . import max, min, sum + + fns = {fn.__name__: fn for fn in (min, max, sum)} + + axis = _normalize_axis(axis, a.ndim) + rv = fns[op](a, axis=axis, **({"dtype": dtype} if op in get_args(DtypeOps) else {}), keep_cupy_as_array=True)  # type: ignore[misc,call-overload] + shape = _get_shape(rv, axis=axis, keepdims=keepdims) + return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) + + +def _normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1] | None: + """Adapt `axis` parameter passed by Dask to what we support.""" + match axis: + case int() | None: # pragma: no cover + pass + case (0 | 1,): + axis = axis[0] + case (0, 1) | (1, 0): + axis = None + case _: # pragma: no cover + raise AxisError(axis, ndim) # type: ignore[call-overload] + if axis == 0 and ndim == 1: + return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays + return axis + + +def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1] | None, keepdims: bool) -> tuple[int] | tuple[int, int]: + """Get the output shape of an axis-flattening operation.""" + match keepdims, a.ndim: + case False, 0: + return (1,) + case True, 0: + return (1, 1) + case False, 1: + return (a.size,) + case True, 1: + assert axis is not None + return (1, a.size) if axis == 0 else (a.size, 1) + case _: # pragma: no cover + msg = f"{keepdims=}, {type(a)}" + raise AssertionError(msg) diff --git a/tests/test_stats.py b/tests/test_stats.py index ffde919..849df4d 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -19,35 +19,20 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing import Any, Literal, Protocol, TypeAlias + from typing import Any, Literal from numpy.typing import NDArray from pytest_codspeed import BenchmarkFixture + from fast_array_utils.stats._typing import Array, DTypeIn, DTypeOut, NdAndAx, StatFunNoDtype from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
from testing.fast_array_utils import ArrayType - Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray - - DTypeIn = np.float32 | np.float64 | np.int32 | np.bool - DTypeOut = np.float32 | np.float64 | np.int64 - - NdAndAx: TypeAlias = tuple[Literal[1], None] | tuple[Literal[2], Literal[0, 1] | None] - - class StatFun(Protocol): # noqa: D101 - def __call__( # noqa: D102 - self, - arr: Array, - *, - axis: Literal[0, 1] | None = None, - dtype: type[DTypeOut] | None = None, - ) -> NDArray[Any] | np.number[Any] | types.DaskArray: ... - pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")] -STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant] +STAT_FUNCS = [stats.sum, stats.min, stats.max, stats.mean, stats.mean_var, stats.is_constant] # can’t select these using a category filter ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"} @@ -149,7 +134,7 @@ def pbmc64k_reduced_raw() -> sps.csr_array[np.float32]: @pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix}) @pytest.mark.parametrize("func", STAT_FUNCS) @pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]) -def test_ndim_error(array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1] | None) -> None: +def test_ndim_error(array_type: ArrayType[Array], func: StatFunNoDtype, ndim: Literal[1, 2], axis: Literal[0, 1] | None) -> None: check_ndim(array_type, ndim) # not using the fixture because we don’t need to test multiple dtypes np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) @@ -364,7 +349,7 @@ def test_dask_constant_blocks(dask_viz: Callable[[object], None], array_type: Ar @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_stats_benchmark( benchmark: BenchmarkFixture, - func: StatFun, + func: StatFunNoDtype, array_type: ArrayType[CpuArray, None], axis: Literal[0, 1] | None, dtype: 
type[np.float32 | np.float64],