2 changes: 1 addition & 1 deletion src/fast_array_utils/stats/__init__.py
@@ -223,7 +223,7 @@ def _generic_op(
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
from ._generic_ops import generic_op

assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation '{op}'"
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"

validate_axis(x.ndim, axis)
return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype)
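A quick aside on the `!r` change above: `repr()` handles the quoting, which stays well-formed even when the value itself contains quotes. A tiny illustration (the `op` value is made up):

op = "it's"
print(f"operation '{op}'")  # -> operation 'it's'   (hand-written quotes collide)
print(f"operation {op!r}")  # -> operation "it's"   (repr picks safe quoting)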
15 changes: 8 additions & 7 deletions src/fast_array_utils/stats/_generic_ops.py
@@ -8,7 +8,7 @@

from .. import types
from ._typing import DtypeOps
from ._utils import _dask_inner
from ._utils import _dask_inner, _dtype_kw


if TYPE_CHECKING:
@@ -29,8 +29,8 @@ def _run_numpy_op(
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
return getattr(np, op)(x, axis=axis, **kwargs) # type: ignore[no-any-return]
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr


@singledispatch
@@ -83,14 +83,15 @@ def _generic_op_cs(
# just convert to sparse array, then `return x.{op}(dtype=dtype)`
# https://github.com/scipy/scipy/issues/23768

kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
if axis is None:
return cast("np.number[Any]", getattr(x.data, op)(**kwargs))
return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
assert isinstance(dtype, np.dtype | type | None)
# convert to array so dimensions collapse as expected
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs) # type: ignore[call-overload]
return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[arg-type]
rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
# old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv


@generic_op.register(types.DaskArray)
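The comment in `_generic_op_cs` notes that older SciPy versions return a 1×n/n×1 sparse container from `sparray.max()`/`.min()` along an axis, hence the `.toarray().squeeze()`. A standalone sketch of the same normalization (the function name is illustrative, not part of this PR):

import numpy as np
import scipy.sparse as sp

def max_along_axis(x: sp.csr_array, axis: int) -> np.ndarray:
    rv = x.max(axis=axis)
    # older SciPy: rv is a 1×n / n×1 sparse container; newer SciPy: already a dense 1-D array
    return rv.toarray().squeeze() if sp.issparse(rv) else np.asarray(rv)

print(max_along_axis(sp.csr_array(np.eye(3)), axis=0))  # -> [1. 1. 1.]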
9 changes: 8 additions & 1 deletion src/fast_array_utils/stats/_typing.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Protocol
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypedDict, TypeVar

import numpy as np

@@ -49,3 +49,10 @@ def __call__(
NoDtypeOps = Literal["max", "min"]
DtypeOps = Literal["sum"]
Ops: TypeAlias = NoDtypeOps | DtypeOps


_DT = TypeVar("_DT", bound="DTypeLike")


class DTypeKw(TypedDict, Generic[_DT], total=False):
dtype: _DT
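For context, `total=False` is what lets `DTypeKw` describe a kwargs dict that may or may not carry `dtype`, so unpacking it with `**` type-checks either way. A minimal sketch with a stand-in class (everything below is illustrative):

from typing import TypedDict

import numpy as np

class _SumKw(TypedDict, total=False):  # stand-in for DTypeKw
    dtype: type[np.float64]

def _kw(pass_dtype: bool) -> _SumKw:
    return {"dtype": np.float64} if pass_dtype else {}

x = np.array([1, 2, 3])
print(x.sum(**_kw(False)))  # -> 6    (numpy keeps its default accumulator dtype)
print(x.sum(**_kw(True)))   # -> 6.0  (dtype forwarded only when the key is present)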
22 changes: 17 additions & 5 deletions src/fast_array_utils/stats/_utils.py
@@ -2,12 +2,13 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Literal, cast, get_args
from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args

import numpy as np
from numpy.exceptions import AxisError

from .. import types
from ..typing import GpuArray
from ._typing import DtypeOps


@@ -16,8 +17,8 @@

from numpy.typing import DTypeLike, NDArray

from ..typing import CpuArray, GpuArray
from ._typing import Ops
from ..typing import CpuArray
from ._typing import DTypeKw, Ops

ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None

@@ -65,13 +66,17 @@ def _dask_block(
axis: ComplexAxis = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
computing_meta: bool = False,
) -> NDArray[Any] | types.CupyArray:
from . import max, min, sum

if computing_meta: # dask.blockwise doesn’t allow to pass `meta` in, and reductions below don’t handle a 0d matrix
return (types.CupyArray if isinstance(a, GpuArray) else np.ndarray)((), dtype or a.dtype)

fns = {fn.__name__: fn for fn in (min, max, sum)}

axis = _normalize_axis(axis, a.ndim)
rv = fns[op](a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload]
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload]
shape = _get_shape(rv, axis=axis, keepdims=keepdims)
return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))

@@ -105,5 +110,12 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
assert axis is not None
return (1, a.size) if axis == 0 else (a.size, 1)
case _: # pragma: no cover
msg = f"{keepdims=}, {type(a)}"
msg = f"{keepdims=}, {a.ndim=}, {type(a)=}"
raise AssertionError(msg)


DT = TypeVar("DT", bound="DTypeLike")


def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]:
return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}
25 changes: 20 additions & 5 deletions tests/test_stats.py
@@ -201,6 +201,20 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
np.testing.assert_array_equal(sum_, expected)


@pytest.mark.array_type(skip=ATS_SPARSE_DS)
@pytest.mark.parametrize("func", [stats.min, stats.max])
def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
rng = np.random.default_rng(0)
np_arr = rng.random((100, 100))
arr = array_type(np_arr)

result = to_np_dense_checked(func(arr, axis=axis), axis, arr)

expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis)
np.testing.assert_array_equal(result, expected)


@pytest.mark.parametrize("func", [stats.sum, stats.min, stats.max])
@pytest.mark.parametrize(
"data",
[
@@ -211,14 +225,15 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.array_type(Flags.Dask)
def test_sum_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]) -> None:
def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]], func: StatFunNoDtype) -> None:
np_arr = np.array(data, dtype=np.float32)
arr = array_type(np_arr)
assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
if isinstance(sum_, types.CupyArray):
sum_ = sum_.get()
np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_)
stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute())
if isinstance(stat, types.CupyArray):
stat = stat.get()
np_func = getattr(np, func.__name__)
np.testing.assert_almost_equal(stat, np_func(np_arr, axis=axis))


@pytest.mark.array_type(skip=ATS_SPARSE_DS)
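For orientation, the 1×n-chunk case that `test_dask_shapes` now exercises for all three reductions looks roughly like this hand-rolled sketch (not the fixture machinery the test uses):

import dask.array as da
import numpy as np

from fast_array_utils import stats

np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
arr = da.from_array(np_arr, chunks=(1, 3))  # 1×n chunks, as the test asserts via `chunksize`
assert 1 in arr.chunksize

print(stats.max(arr, axis=0).compute())  # -> [4. 5. 6.]
np.testing.assert_almost_equal(stats.sum(arr, axis=1).compute(), np_arr.sum(axis=1))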
11 changes: 10 additions & 1 deletion typings/cupy/_core/core.pyi
@@ -3,7 +3,7 @@ from types import EllipsisType
from typing import Any, Literal, Self, overload

import numpy as np
from cupy.cuda import Stream
from cupy.cuda import MemoryPointer, Stream
from numpy._core.multiarray import flagsobj
from numpy.typing import DTypeLike, NDArray

@@ -14,6 +14,15 @@ class ndarray:
ndim: int
flags: flagsobj

def __init__(
self,
shape: tuple[int, ...],
dtype: DTypeLike | None = ...,
memptr: MemoryPointer | None = None,
strides: tuple[int, ...] | None = None,
order: Literal["C", "F"] = "C",
) -> None: ...

# cupy-specific
def get(
self, stream: Stream | None = None, order: Literal["C", "F", "A"] = "C", out: NDArray[Any] | None = None, blocking: bool = True
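The `__init__` stub added above is what lets the `computing_meta` branch in `_utils._dask_block` construct an empty 0-d array as dask's reduction meta and still type-check. A hedged sketch of that pattern (assumes CuPy is installed, falling back to NumPy otherwise):

import numpy as np

try:
    import cupy as cp
    meta = cp.ndarray((), dtype=np.float64)  # 0-d placeholder on the GPU, contents uninitialized
except ImportError:
    meta = np.ndarray((), dtype=np.float64)  # CPU equivalent

print(meta.shape, meta.dtype)  # -> () float64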
1 change: 1 addition & 0 deletions typings/cupy/cuda.pyi
@@ -1,2 +1,3 @@
# SPDX-License-Identifier: MPL-2.0
class Stream: ...
class MemoryPointer: ...