Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cuda_core/cuda/core/experimental/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ from cuda.core.experimental._utils.cuda_utils import (
from cuda.core.experimental._stream cimport default_stream



# TODO: I prefer to type these as "cdef object" and avoid accessing them from within Python,
# but it seems it is very convenient to expose them for testing purposes...
_tls = threading.local()
Expand Down Expand Up @@ -1273,7 +1274,7 @@ class Device:
"""
self._check_context_initialized()
ctx = self._get_current_context()
return Event._init(self._id, ctx, options)
return Event._init(self._id, ctx, options, True)

def allocate(self, size, stream: Optional[Stream] = None) -> Buffer:
"""Allocate device memory from a specified stream.
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/experimental/_event.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ cdef class Event:
cydriver.CUevent _handle
bint _timing_disabled
bint _busy_waited
bint _ipc_enabled
object _ipc_descriptor
int _device_id
object _ctx_handle

Expand Down
104 changes: 89 additions & 15 deletions cuda_core/cuda/core/experimental/_event.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from __future__ import annotations

cimport cpython
from libc.stdint cimport uintptr_t
from libc.string cimport memcpy

from cuda.bindings cimport cydriver

Expand All @@ -14,6 +16,7 @@ from cuda.core.experimental._utils.cuda_utils cimport (
)

from dataclasses import dataclass
import multiprocessing
from typing import TYPE_CHECKING, Optional

from cuda.core.experimental._context import Context
Expand All @@ -40,15 +43,15 @@ cdef class EventOptions:
has actually been completed.
Otherwise, the CPU thread will busy-wait until the event has
been completed. (Default to False)
support_ipc : bool, optional
ipc_enabled : bool, optional
Event will be suitable for interprocess use.
Note that enable_timing must be False. (Default to False)

"""

enable_timing: Optional[bool] = False
busy_waited_sync: Optional[bool] = False
support_ipc: Optional[bool] = False
ipc_enabled: Optional[bool] = False


cdef class Event:
Expand Down Expand Up @@ -86,24 +89,35 @@ cdef class Event:
raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).")

@classmethod
def _init(cls, device_id: int, ctx_handle: Context, options=None):
def _init(cls, device_id: int, ctx_handle: Context, options=None, is_free=False):
cdef Event self = Event.__new__(cls)
cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options")
cdef unsigned int flags = 0x0
self._timing_disabled = False
self._busy_waited = False
self._ipc_enabled = False
self._ipc_descriptor = None
if not opts.enable_timing:
flags |= cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING
self._timing_disabled = True
if opts.busy_waited_sync:
flags |= cydriver.CUevent_flags.CU_EVENT_BLOCKING_SYNC
self._busy_waited = True
if opts.support_ipc:
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
if opts.ipc_enabled:
if is_free:
raise TypeError(
"IPC-enabled events must be bound; use Stream.record for creation."
)
flags |= cydriver.CUevent_flags.CU_EVENT_INTERPROCESS
self._ipc_enabled = True
if not self._timing_disabled:
raise TypeError("IPC-enabled events cannot use timing.")
with nogil:
HANDLE_RETURN(cydriver.cuEventCreate(&self._handle, flags))
self._device_id = device_id
self._ctx_handle = ctx_handle
if opts.ipc_enabled:
self.get_ipc_descriptor()
return self

cpdef close(self):
Expand Down Expand Up @@ -151,6 +165,40 @@ cdef class Event:
raise CUDAError(err)
raise RuntimeError(explanation)

def get_ipc_descriptor(self) -> IPCEventDescriptor:
"""Export an event allocated for sharing between processes."""
if self._ipc_descriptor is not None:
return self._ipc_descriptor
if not self.is_ipc_enabled:
raise RuntimeError("Event is not IPC-enabled")
cdef cydriver.CUipcEventHandle data
with nogil:
HANDLE_RETURN(cydriver.cuIpcGetEventHandle(&data, <cydriver.CUevent>(self._handle)))
cdef bytes data_b = cpython.PyBytes_FromStringAndSize(<char*>(data.reserved), sizeof(data.reserved))
self._ipc_descriptor = IPCEventDescriptor._init(data_b, self._busy_waited)
return self._ipc_descriptor

@classmethod
def from_ipc_descriptor(cls, ipc_descriptor: IPCEventDescriptor) -> Event:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: it doesn't seem that this public API is tested?

"""Import an event that was exported from another process."""
cdef cydriver.CUipcEventHandle data
memcpy(data.reserved, <const void*><const char*>(ipc_descriptor._reserved), sizeof(data.reserved))
cdef Event self = Event.__new__(cls)
with nogil:
HANDLE_RETURN(cydriver.cuIpcOpenEventHandle(&self._handle, data))
self._timing_disabled = True
self._busy_waited = ipc_descriptor._busy_waited
self._ipc_enabled = True
self._ipc_descriptor = ipc_descriptor
self._device_id = -1 # ??
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when we send over the handle we should also send the device id

self._ctx_handle = None # ??
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be (lazily? not sure if it's possible) initialized in the child process (to the specified device's current context).

return self

@property
def is_ipc_enabled(self) -> bool:
"""Return True if the event can be shared across process boundaries, otherwise False."""
return self._ipc_enabled

@property
def is_timing_disabled(self) -> bool:
"""Return True if the event does not record timing data, otherwise False."""
Expand All @@ -161,11 +209,6 @@ cdef class Event:
"""Return True if the event synchronization would keep the CPU busy-waiting, otherwise False."""
return self._busy_waited

@property
def is_ipc_supported(self) -> bool:
"""Return True if this event can be used as an interprocess event, otherwise False."""
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")

def sync(self):
"""Synchronize until the event completes.

Expand Down Expand Up @@ -212,12 +255,43 @@ cdef class Event:
context is set current after a event is created.

"""

from cuda.core.experimental._device import Device # avoid circular import

return Device(self._device_id)
if self._device_id >= 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, then this awkward check can be removed.

from ._device import Device # avoid circular import
return Device(self._device_id)

@property
def context(self) -> Context:
"""Return the :obj:`~_context.Context` associated with this event."""
return Context._from_ctx(self._ctx_handle, self._device_id)
if self._ctx_handle is not None and self._device_id >= 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

return Context._from_ctx(self._ctx_handle, self._device_id)


cdef class IPCEventDescriptor:
"""Serializable object describing an event that can be shared between processes."""

cdef:
bytes _reserved
bint _busy_waited

def __init__(self, *arg, **kwargs):
raise RuntimeError("IPCEventDescriptor objects cannot be instantiated directly. Please use Event APIs.")

@classmethod
def _init(cls, reserved: bytes, busy_waited: bint):
cdef IPCEventDescriptor self = IPCEventDescriptor.__new__(cls)
self._reserved = reserved
self._busy_waited = busy_waited
return self

def __eq__(self, IPCEventDescriptor rhs):
# No need to check self._busy_waited.
return self._reserved == rhs._reserved

def __reduce__(self):
return self._init, (self._reserved, self._busy_waited)


def _reduce_event(event):
return event.from_ipc_descriptor, (event.get_ipc_descriptor(),)

multiprocessing.reduction.register(Event, _reduce_event)
8 changes: 4 additions & 4 deletions cuda_core/cuda/core/experimental/_memory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,11 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
if stream is None:
# Note: match this behavior to DeviceMemoryResource.allocate()
stream = default_stream()
cdef cydriver.CUmemPoolPtrExportData share_data
memcpy(share_data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(share_data.reserved))
cdef cydriver.CUmemPoolPtrExportData data
memcpy(data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(data.reserved))
cdef cydriver.CUdeviceptr ptr
with nogil:
HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&ptr, mr._mempool_handle, &share_data))
HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&ptr, mr._mempool_handle, &data))
return Buffer._init(<intptr_t>ptr, ipc_buffer.size, mr, stream)

def copy_to(self, dst: Buffer = None, *, stream: Stream) -> Buffer:
Expand Down Expand Up @@ -511,7 +511,7 @@ cdef class DeviceMemoryResourceOptions:
(Default to 0)
"""
ipc_enabled : cython.bint = False
max_size : cython.int = 0
max_size : cython.size_t = 0


# TODO: cythonize this?
Expand Down
8 changes: 7 additions & 1 deletion cuda_core/cuda/core/experimental/_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,13 @@ cdef class Stream:
# and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
if event is None:
self._get_device_and_context()
event = Event._init(<int>(self._device_id), <uintptr_t>(self._ctx_handle), options)
event = Event._init(<int>(self._device_id), <uintptr_t>(self._ctx_handle), options, False)
elif event.is_ipc_enabled:
raise TypeError(
"IPC-enabled events should not be re-recorded, instead create a "
"new event by supplying options."
)
Comment on lines +264 to +268
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imagine we are doing ping pong between two processes. We should be able to reuse the same event in parent and child. Does the driver actually disallow this?


cdef cydriver.CUevent e = (<cyEvent?>(event))._handle
with nogil:
HANDLE_RETURN(cydriver.cuEventRecord(e, self._handle))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import pathlib
import platform
import sys

CUDA_PATH = os.environ.get("CUDA_PATH")
Expand All @@ -22,12 +23,13 @@
import cuda_python_test_helpers
except ImportError:
# Import shared platform helpers for tests across repos
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[2] / "cuda_python_test_helpers"))
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[3] / "cuda_python_test_helpers"))
import cuda_python_test_helpers


IS_WSL = cuda_python_test_helpers.IS_WSL
supports_ipc_mempool = cuda_python_test_helpers.supports_ipc_mempool
IS_WINDOWS = platform.system() == "Windows"


del cuda_python_test_helpers
122 changes: 122 additions & 0 deletions cuda_core/tests/helpers/buffers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import ctypes
import sys

from cuda.core.experimental import Buffer, MemoryResource
from cuda.core.experimental._utils.cuda_utils import driver, handle_return

if sys.platform.startswith("win"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should migrate this to a common place so that this pattern doesn't get replicated in multiple files. It would be a maintenance headache if we had to update it in multiple places.

libc = ctypes.CDLL("msvcrt.dll")
else:
libc = ctypes.CDLL("libc.so.6")


__all__ = ["DummyUnifiedMemoryResource", "PatternGen", "make_scratch_buffer", "compare_equal_buffers"]


class DummyUnifiedMemoryResource(MemoryResource):
def __init__(self, device):
self.device = device

def allocate(self, size, stream=None) -> Buffer:
ptr = handle_return(driver.cuMemAllocManaged(size, driver.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value))
return Buffer.from_handle(ptr=ptr, size=size, mr=self)

def deallocate(self, ptr, size, stream=None):
handle_return(driver.cuMemFree(ptr))

@property
def is_device_accessible(self) -> bool:
return True

@property
def is_host_accessible(self) -> bool:
return True

@property
def device_id(self) -> int:
return self.device


class PatternGen:
"""
Provides methods to fill a target buffer with known test patterns and
verify the expected values.
If a stream is provided, operations are synchronized with respect to that
stream. Otherwise, they are synchronized over the device.
The test pattern is either a fixed value or a cyclic pattern generated from
an 8-bit seed. Only one of `value` or `seed` should be supplied.
Distinct test patterns are stored in private buffers called pattern
buffers. Calls to `fill_buffer` copy from a pattern buffer to the target
buffer. Calls to `verify_buffer` copy from the target buffer to a scratch
buffer and then perform a comparison.
"""

def __init__(self, device, size, stream=None):
self.device = device
self.size = size
self.stream = stream if stream is not None else device.create_stream()
self.sync_target = stream if stream is not None else device
self.pattern_buffers = {}

def fill_buffer(self, buffer, seed=None, value=None):
"""Fill a device buffer with a sequential test pattern using unified memory."""
assert buffer.size == self.size
pattern_buffer = self._get_pattern_buffer(seed, value)
buffer.copy_from(pattern_buffer, stream=self.stream)

def verify_buffer(self, buffer, seed=None, value=None):
"""Verify the buffer contents against a sequential pattern."""
assert buffer.size == self.size
scratch_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.size)
ptr_test = self._ptr(scratch_buffer)
pattern_buffer = self._get_pattern_buffer(seed, value)
ptr_expected = self._ptr(pattern_buffer)
scratch_buffer.copy_from(buffer, stream=self.stream)
self.sync_target.sync()
assert libc.memcmp(ptr_test, ptr_expected, self.size) == 0

@staticmethod
def _ptr(buffer):
"""Get a pointer to the specified buffer."""
return ctypes.cast(int(buffer.handle), ctypes.POINTER(ctypes.c_ubyte))

def _get_pattern_buffer(self, seed, value):
"""Get a buffer holding the specified test pattern."""
assert seed is None or value is None
if value is None:
seed = (0 if seed is None else seed) & 0xFF
key = seed, value
pattern_buffer = self.pattern_buffers.get(key, None)
if pattern_buffer is None:
if value is not None:
pattern_buffer = make_scratch_buffer(self.device, value, self.size)
else:
pattern_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.size)
ptr = self._ptr(pattern_buffer)
for i in range(self.size):
ptr[i] = (seed + i) & 0xFF
self.pattern_buffers[key] = pattern_buffer
return pattern_buffer


def make_scratch_buffer(device, value, nbytes):
"""Create a unified memory buffer with the specified value."""
buffer = DummyUnifiedMemoryResource(device).allocate(nbytes)
ptr = ctypes.cast(int(buffer.handle), ctypes.POINTER(ctypes.c_byte))
ctypes.memset(ptr, value & 0xFF, nbytes)
return buffer


def compare_equal_buffers(buffer1, buffer2):
"""Compare the contents of two host-accessible buffers for bitwise equality."""
if buffer1.size != buffer2.size:
return False
ptr1 = ctypes.cast(int(buffer1.handle), ctypes.POINTER(ctypes.c_byte))
ptr2 = ctypes.cast(int(buffer2.handle), ctypes.POINTER(ctypes.c_byte))
return libc.memcmp(ptr1, ptr2, buffer1.size) == 0
Loading
Loading