Skip to content
35 changes: 35 additions & 0 deletions cuda_bindings/cuda/bindings/driver.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ ctypedef unsigned long long float_ptr
ctypedef unsigned long long double_ptr
ctypedef unsigned long long void_ptr

cdef dict _cu_mem_alloc_managed_concurrent_access_by_device = {}

#: CUDA API version number
CUDA_VERSION = cydriver.CUDA_VERSION

Expand Down Expand Up @@ -31341,6 +31343,39 @@ def cuMemAllocManaged(size_t bytesize, unsigned int flags):
--------
:py:obj:`~.cuArray3DCreate`, :py:obj:`~.cuArray3DGetDescriptor`, :py:obj:`~.cuArrayCreate`, :py:obj:`~.cuArrayDestroy`, :py:obj:`~.cuArrayGetDescriptor`, :py:obj:`~.cuMemAllocHost`, :py:obj:`~.cuMemAllocPitch`, :py:obj:`~.cuMemcpy2D`, :py:obj:`~.cuMemcpy2DAsync`, :py:obj:`~.cuMemcpy2DUnaligned`, :py:obj:`~.cuMemcpy3D`, :py:obj:`~.cuMemcpy3DAsync`, :py:obj:`~.cuMemcpyAtoA`, :py:obj:`~.cuMemcpyAtoD`, :py:obj:`~.cuMemcpyAtoH`, :py:obj:`~.cuMemcpyAtoHAsync`, :py:obj:`~.cuMemcpyDtoA`, :py:obj:`~.cuMemcpyDtoD`, :py:obj:`~.cuMemcpyDtoDAsync`, :py:obj:`~.cuMemcpyDtoH`, :py:obj:`~.cuMemcpyDtoHAsync`, :py:obj:`~.cuMemcpyHtoA`, :py:obj:`~.cuMemcpyHtoAAsync`, :py:obj:`~.cuMemcpyHtoD`, :py:obj:`~.cuMemcpyHtoDAsync`, :py:obj:`~.cuMemFree`, :py:obj:`~.cuMemFreeHost`, :py:obj:`~.cuMemGetAddressRange`, :py:obj:`~.cuMemGetInfo`, :py:obj:`~.cuMemHostAlloc`, :py:obj:`~.cuMemHostGetDevicePointer`, :py:obj:`~.cuMemsetD2D8`, :py:obj:`~.cuMemsetD2D16`, :py:obj:`~.cuMemsetD2D32`, :py:obj:`~.cuMemsetD8`, :py:obj:`~.cuMemsetD16`, :py:obj:`~.cuMemsetD32`, :py:obj:`~.cuDeviceGetAttribute`, :py:obj:`~.cuStreamAttachMemAsync`, :py:obj:`~.cudaMallocManaged`
"""
# WIP-WIP-WIP THIS CODE NEEDS TO BE PORTED TO THE CODE GENERATOR
cdef int concurrent_access = 0
cdef int device_id = 0
cdef cydriver.CUdevice device
err = cydriver.cuCtxGetDevice(&device)
if err != cydriver.CUDA_SUCCESS:
# cuMemAllocManaged would fail with the same error anyway.
return (_CUresult(err), None)
device_id = <int>device
if device_id in _cu_mem_alloc_managed_concurrent_access_by_device:
if _cu_mem_alloc_managed_concurrent_access_by_device[device_id] == 0:
raise RuntimeError(
"cuMemAllocManaged is not supported when "
"CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
)
else:
err = cydriver.cuDeviceGetAttribute(
&concurrent_access,
cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
device,
)
if err != cydriver.CUDA_SUCCESS:
raise RuntimeError(
"cuDeviceGetAttribute(CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS) failed "
f"while validating cuMemAllocManaged: {_CUresult(err)}"
)
_cu_mem_alloc_managed_concurrent_access_by_device[device_id] = concurrent_access
if concurrent_access == 0:
raise RuntimeError(
"cuMemAllocManaged is not supported when "
"CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS == 0"
)

cdef CUdeviceptr dptr = CUdeviceptr()
with nogil:
err = cydriver.cuMemAllocManaged(<cydriver.CUdeviceptr*>dptr._pvt_ptr, bytesize, flags)
Expand Down
3 changes: 3 additions & 0 deletions cuda_bindings/tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from cuda_python_test_helpers.managed_memory import skip_if_concurrent_managed_access_disabled

import cuda.bindings.driver as cuda
import cuda.bindings.runtime as cudart
Expand Down Expand Up @@ -325,6 +326,7 @@ def test_cuda_memPool_attr():
driverVersionLessThan(11030) or not supportsManagedMemory(), reason="When new attributes were introduced"
)
def test_cuda_pointer_attr():
skip_if_concurrent_managed_access_disabled()
err, ptr = cuda.cuMemAllocManaged(0x1000, cuda.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value)
assert err == cuda.CUresult.CUDA_SUCCESS

Expand Down Expand Up @@ -390,6 +392,7 @@ def test_pointer_get_attributes_device_ordinal():

@pytest.mark.skipif(not supportsManagedMemory(), reason="When new attributes were introduced")
def test_cuda_mem_range_attr(device):
skip_if_concurrent_managed_access_disabled()
size = 0x1000
location_device = cuda.CUmemLocation()
location_device.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
Expand Down
12 changes: 12 additions & 0 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ def skip_if_pinned_memory_unsupported(device):
pytest.skip("PinnedMemoryResource requires CUDA 13.0 or later")


def _skip_if_concurrent_managed_access_disabled(device=None) -> None:
from cuda_python_test_helpers.managed_memory import skip_if_concurrent_managed_access_disabled

skip_if_concurrent_managed_access_disabled(device)


def skip_if_managed_memory_unsupported(device):
_skip_if_concurrent_managed_access_disabled(device)
try:
if not device.properties.memory_pools_supported or not device.properties.concurrent_managed_access:
pytest.skip("Device does not support managed memory pool operations")
Expand All @@ -74,6 +81,11 @@ def create_managed_memory_resource_or_skip(*args, **kwargs):
raise


@pytest.fixture
def requires_concurrent_managed_access():
_skip_if_concurrent_managed_access_disabled()


@pytest.fixture(scope="session", autouse=True)
def session_setup():
# Always init CUDA.
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/graph/test_capture_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def free(self, buffers):
self.stream.sync()


@pytest.mark.usefixtures("requires_concurrent_managed_access")
@pytest.mark.parametrize("mode", ["no_graph", "global", "thread_local", "relaxed"])
@pytest.mark.parametrize("action", ["incr", "fill"])
def test_graph_alloc(mempool_device, mode, action):
Expand Down Expand Up @@ -146,6 +147,7 @@ def apply_kernels(mr, stream, out):
assert compare_buffer_to_constant(out, 6)


@pytest.mark.usefixtures("requires_concurrent_managed_access")
@pytest.mark.skipif(IS_WINDOWS or IS_WSL, reason="auto_free_on_launch not supported on Windows")
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
def test_graph_alloc_with_output(mempool_device, mode):
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_event_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
class TestEventIpc:
"""Check the basic usage of IPC-enabled events with a latch kernel."""

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")

@pytest.mark.flaky(reruns=2)
def test_main(self, ipc_device, ipc_memory_resource):
log = TimestampedLogger(prefix="parent: ", enabled=ENABLE_LOGGING)
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_memory_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
NWORKERS = 2
NTASKS = 2

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")


class TestIpcMempool:
@pytest.mark.flaky(reruns=2)
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_peer_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class TestBufferPeerAccessAfterImport:
setting peer access on the imported memory resource, and that access can be revoked.
"""

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")

@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize("grant_access_in_parent", [True, False])
def test_main(self, mempool_device_x2, grant_access_in_parent):
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_send_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
NTASKS = 7
POOL_SIZE = 2097152

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")


class TestIpcSendBuffers:
@pytest.mark.flaky(reruns=2)
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
NBYTES = 64
POOL_SIZE = 2097152

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")


class TestObjectSerializationDirect:
"""
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/memory_ipc/test_workerpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
NTASKS = 20
POOL_SIZE = 2097152

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")


class TestIpcWorkerPool:
"""
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
ENABLE_LOGGING = False # Set True for test debugging and development
NBYTES = 64

pytestmark = pytest.mark.usefixtures("requires_concurrent_managed_access")


def test_latchkernel():
"""Test LatchKernel."""
Expand Down
86 changes: 66 additions & 20 deletions cuda_core/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,26 @@ def buffer_initialization(dummy_mr: MemoryResource):
buffer.close()


def test_buffer_initialization():
@pytest.mark.parametrize(
("mr_factory", "needs_device"),
[
(DummyDeviceMemoryResource, True),
(DummyHostMemoryResource, False),
(DummyUnifiedMemoryResource, True),
(DummyPinnedMemoryResource, True),
],
ids=["device", "host", "unified", "pinned"],
)
def test_buffer_initialization(mr_factory, needs_device, request):
device = Device()
device.set_current()
buffer_initialization(DummyDeviceMemoryResource(device))
buffer_initialization(DummyHostMemoryResource())
buffer_initialization(DummyUnifiedMemoryResource(device))
buffer_initialization(DummyPinnedMemoryResource(device))
if mr_factory is DummyUnifiedMemoryResource:
request.getfixturevalue("requires_concurrent_managed_access")
mr = mr_factory(device) if needs_device else mr_factory()
buffer_initialization(mr)


def test_buffer_initialization_invalid_mr():
with pytest.raises(TypeError):
buffer_initialization(MemoryResource())

Expand Down Expand Up @@ -198,12 +211,22 @@ def buffer_copy_to(dummy_mr: MemoryResource, device: Device, check=False):
src_buffer.close()


def test_buffer_copy_to():
@pytest.mark.parametrize(
("mr_factory", "check"),
[
(DummyDeviceMemoryResource, False),
(DummyUnifiedMemoryResource, False),
(DummyPinnedMemoryResource, True),
],
ids=["device", "unified", "pinned"],
)
def test_buffer_copy_to(mr_factory, check, request):
device = Device()
device.set_current()
buffer_copy_to(DummyDeviceMemoryResource(device), device)
buffer_copy_to(DummyUnifiedMemoryResource(device), device)
buffer_copy_to(DummyPinnedMemoryResource(device), device, check=True)
if mr_factory is DummyUnifiedMemoryResource:
request.getfixturevalue("requires_concurrent_managed_access")
mr = mr_factory(device)
buffer_copy_to(mr, device, check=check)


def buffer_copy_from(dummy_mr: MemoryResource, device, check=False):
Expand All @@ -229,12 +252,22 @@ def buffer_copy_from(dummy_mr: MemoryResource, device, check=False):
src_buffer.close()


def test_buffer_copy_from():
@pytest.mark.parametrize(
("mr_factory", "check"),
[
(DummyDeviceMemoryResource, False),
(DummyUnifiedMemoryResource, False),
(DummyPinnedMemoryResource, True),
],
ids=["device", "unified", "pinned"],
)
def test_buffer_copy_from(mr_factory, check, request):
device = Device()
device.set_current()
buffer_copy_from(DummyDeviceMemoryResource(device), device)
buffer_copy_from(DummyUnifiedMemoryResource(device), device)
buffer_copy_from(DummyPinnedMemoryResource(device), device, check=True)
if mr_factory is DummyUnifiedMemoryResource:
request.getfixturevalue("requires_concurrent_managed_access")
mr = mr_factory(device)
buffer_copy_from(mr, device, check=check)


def _bytes_repeat(pattern: bytes, size: int) -> bytes:
Expand All @@ -256,6 +289,7 @@ def fill_env(request):
if request.param == "device":
mr = DummyDeviceMemoryResource(device)
elif request.param == "unified":
request.getfixturevalue("requires_concurrent_managed_access")
mr = DummyUnifiedMemoryResource(device)
else:
mr = DummyPinnedMemoryResource(device)
Expand Down Expand Up @@ -345,13 +379,23 @@ def buffer_close(dummy_mr: MemoryResource):
assert buffer.memory_resource is None


def test_buffer_close():
@pytest.mark.parametrize(
("mr_factory", "needs_device"),
[
(DummyDeviceMemoryResource, True),
(DummyHostMemoryResource, False),
(DummyUnifiedMemoryResource, True),
(DummyPinnedMemoryResource, True),
],
ids=["device", "host", "unified", "pinned"],
)
def test_buffer_close(mr_factory, needs_device, request):
device = Device()
device.set_current()
buffer_close(DummyDeviceMemoryResource(device))
buffer_close(DummyHostMemoryResource())
buffer_close(DummyUnifiedMemoryResource(device))
buffer_close(DummyPinnedMemoryResource(device))
if mr_factory is DummyUnifiedMemoryResource:
request.getfixturevalue("requires_concurrent_managed_access")
mr = mr_factory(device) if needs_device else mr_factory()
buffer_close(mr)


def test_buffer_external_host():
Expand Down Expand Up @@ -447,7 +491,7 @@ def test_buffer_external_pinned_registered(change_device):


@pytest.mark.parametrize("change_device", [True, False])
def test_buffer_external_managed(change_device):
def test_buffer_external_managed(change_device, requires_concurrent_managed_access):
n = ccx_system.get_num_devices()
if n < 1:
pytest.skip("No devices found")
Expand Down Expand Up @@ -560,9 +604,11 @@ def test_buffer_dunder_dlpack():
(DummyPinnedMemoryResource, (DLDeviceType.kDLCUDAHost, 0)),
],
)
def test_buffer_dunder_dlpack_device_success(DummyMR, expected):
def test_buffer_dunder_dlpack_device_success(DummyMR, expected, request):
device = Device()
device.set_current()
if DummyMR is DummyUnifiedMemoryResource:
request.getfixturevalue("requires_concurrent_managed_access")
dummy_mr = DummyMR() if DummyMR is DummyHostMemoryResource else DummyMR(device)
buffer = dummy_mr.allocate(size=1024)
assert buffer.__dlpack_device__() == expected
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/tests/test_memory_peer_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NBYTES = 1024


@pytest.mark.usefixtures("requires_concurrent_managed_access")
def test_peer_access_basic(mempool_device_x2):
"""Basic tests for dmr.peer_accessible_by."""
dev0, dev1 = mempool_device_x2
Expand Down Expand Up @@ -81,6 +82,7 @@ def check(expected):
dmr.peer_accessible_by = [num_devices] # device ID out of bounds


@pytest.mark.usefixtures("requires_concurrent_managed_access")
def test_peer_access_transitions(mempool_device_x3):
"""Advanced tests for dmr.peer_accessible_by."""

Expand Down
Loading
Loading