Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions cuda_core/cuda/core/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ cdef class StridedMemoryView:
def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
cdef StridedMemoryView buf
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# ignore the warning triggered by calling the constructor
# inside the library we're allowed to do this
warnings.simplefilter("ignore", DeprecationWarning)
buf = cls()
view_as_dlpack(obj, stream_ptr, buf)
return buf
Expand All @@ -148,11 +150,20 @@ cdef class StridedMemoryView:
def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
cdef StridedMemoryView buf
with warnings.catch_warnings():
warnings.simplefilter("ignore")
warnings.simplefilter("ignore", DeprecationWarning)
buf = cls()
view_as_cai(obj, stream_ptr, buf)
return buf

@classmethod
def from_array_interface(cls, obj: object) -> StridedMemoryView:
cdef StridedMemoryView buf
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
buf = cls()
view_as_array_interface(obj, buf)
return buf

@classmethod
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
if check_has_dlpack(obj):
Expand Down Expand Up @@ -597,6 +608,23 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
return buf


cpdef StridedMemoryView view_as_array_interface(obj, view=None):
cdef dict data = obj.__array_interface__
if data["version"] < 3:
raise BufferError("only NumPy Array Interface v3 or above is supported")
if data.get("mask") is not None:
raise BufferError("mask is not supported")

cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
buf.exporting_obj = obj
buf.metadata = data
buf.dl_tensor = NULL
buf.ptr, buf.readonly = data["data"]
buf.is_device_accessible = False
buf.device_id = handle_return(driver.cuCtxGetDevice())
return buf


def args_viewable_as_strided_memory(tuple arg_indices):
"""
Decorator to create proxy objects to :obj:`StridedMemoryView` for the
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_cuda():
driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC)
)

yield
yield device
_ = _device_unset_current()


Expand Down
64 changes: 64 additions & 0 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cuda.core import Device
from cuda.core._layout import _StridedLayout
from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory
from pytest import param


def test_cast_to_3_tuple_success():
Expand Down Expand Up @@ -460,3 +461,66 @@ def test_struct_array():
# full dtype information doesn't seem to be preserved due to use of type strings,
# which are lossy, e.g., dtype([("a", "int")]).str == "V8"
assert smv.dtype == np.dtype(f"V{x.itemsize}")


@pytest.mark.parametrize(
("x", "expected_dtype"),
[
# 1D arrays with different dtypes
param(np.array([1, 2, 3], dtype=np.int32), "int32", id="1d-int32"),
param(np.array([1.0, 2.0, 3.0], dtype=np.float64), "float64", id="1d-float64"),
param(np.array([1 + 2j, 3 + 4j], dtype=np.complex128), "complex128", id="1d-complex128"),
param(np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64), "complex64", id="1d-complex64"),
param(np.array([1, 2, 3, 4, 5], dtype=np.uint8), "uint8", id="1d-uint8"),
param(np.array([1, 2], dtype=np.int64), "int64", id="1d-int64"),
param(np.array([100, 200, 300], dtype=np.int16), "int16", id="1d-int16"),
param(np.array([1000, 2000, 3000], dtype=np.uint16), "uint16", id="1d-uint16"),
param(np.array([10000, 20000, 30000], dtype=np.uint64), "uint64", id="1d-uint64"),
# 2D arrays - C-contiguous
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), "int32", id="2d-c-int32"),
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), "float32", id="2d-c-float32"),
# 2D arrays - Fortran-contiguous
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32, order="F"), "int32", id="2d-f-int32"),
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F"), "float64", id="2d-f-float64"),
# 3D arrays
param(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int32), "int32", id="3d-int32"),
param(np.ones((2, 3, 4), dtype=np.float64), "float64", id="3d-float64"),
# Sliced/strided arrays
param(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)[::2], "int32", id="1d-strided-int32"),
param(np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float64)[:, ::2], "float64", id="2d-strided-float64"),
param(np.arange(20, dtype=np.int32).reshape(4, 5)[::2, ::2], "int32", id="2d-strided-2x2-int32"),
# Scalar (0-D array)
param(np.array(42, dtype=np.int32), "int32", id="scalar-int32"),
param(np.array(3.14, dtype=np.float64), "float64", id="scalar-float64"),
# Empty arrays
param(np.array([], dtype=np.int32), "int32", id="empty-1d-int32"),
param(np.empty((0, 3), dtype=np.float64), "float64", id="empty-2d-float64"),
# Single element
param(np.array([1], dtype=np.int32), "int32", id="single-element"),
# Structured dtype
param(np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]), "V12", id="structured-dtype"),
],
)
def test_from_array_interface(x, init_cuda, expected_dtype):
smv = StridedMemoryView.from_array_interface(x)
assert smv.size == x.size
assert smv.dtype == np.dtype(expected_dtype)
assert smv.shape == x.shape
assert smv.ptr == x.ctypes.data
assert smv.device_id == init_cuda.device_id
assert smv.is_device_accessible is False
assert smv.exporting_obj is x
assert smv.readonly is not x.flags.writeable
# Check strides
strides_in_counts = convert_strides_to_counts(x.strides, x.dtype.itemsize)
assert (x.flags.c_contiguous and smv.strides is None) or smv.strides == strides_in_counts


def test_from_array_interface_unsupported_strides(init_cuda):
# Create an array with strides that aren't a multiple of itemsize
x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")])
b = x["b"]
smv = StridedMemoryView.from_array_interface(b)
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
# TODO: ideally this would raise on construction
smv.strides # noqa: B018
Loading