diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 944969852e..85d606e364 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -89,7 +89,7 @@ cdef class Event: @classmethod def _init(cls, device_id: int, ctx_handle: Context, options=None): - cdef Event self = Event.__new__(Event) + cdef Event self = Event.__new__(cls) cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options") flags = 0x0 self._timing_disabled = False diff --git a/cuda_core/cuda/core/experimental/_stream.pyx b/cuda_core/cuda/core/experimental/_stream.pyx index dc8d8e942c..a40b6f90db 100644 --- a/cuda_core/cuda/core/experimental/_stream.pyx +++ b/cuda_core/cuda/core/experimental/_stream.pyx @@ -122,7 +122,7 @@ cdef class Stream: @classmethod def _legacy_default(cls): - cdef Stream self = Stream.__new__(Stream) + cdef Stream self = Stream.__new__(cls) self._handle = driver.CUstream(driver.CU_STREAM_LEGACY) self._owner = None self._builtin = True @@ -134,7 +134,7 @@ cdef class Stream: @classmethod def _per_thread_default(cls): - cdef Stream self = Stream.__new__(Stream) + cdef Stream self = Stream.__new__(cls) self._handle = driver.CUstream(driver.CU_STREAM_PER_THREAD) self._owner = None self._builtin = True @@ -146,7 +146,7 @@ cdef class Stream: @classmethod def _init(cls, obj: Optional[IsStreamT] = None, options=None, device_id: int = None): - cdef Stream self = Stream.__new__(Stream) + cdef Stream self = Stream.__new__(cls) self._handle = None self._owner = None self._builtin = False diff --git a/cuda_core/tests/test_event.py b/cuda_core/tests/test_event.py index 827fa22b4c..d20914cbc0 100644 --- a/cuda_core/tests/test_event.py +++ b/cuda_core/tests/test_event.py @@ -14,6 +14,7 @@ import cuda.core.experimental from cuda.core.experimental import ( Device, + Event, EventOptions, LaunchConfig, LegacyPinnedMemoryResource, @@ -191,3 +192,13 @@ def test_event_context(init_cuda): event = Device().create_event(options=EventOptions()) context = event.context assert context is not None + + +def test_event_subclassing(): + class MyEvent(Event): + pass + + dev = Device() + dev.set_current() + event = MyEvent._init(dev.device_id, dev.context) + assert isinstance(event, MyEvent) diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 7a3ff8b2ca..1c782c7194 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -111,3 +111,29 @@ def test_per_thread_default_stream(): def test_default_stream(): stream = default_stream() assert isinstance(stream, Stream) + + +def test_stream_subclassing(init_cuda): + class MyStream(Stream): + pass + + dev = Device() + dev.set_current() + stream = MyStream._init(options=StreamOptions(), device_id=dev.device_id) + assert isinstance(stream, MyStream) + + +def test_stream_legacy_default_subclassing(): + class MyStream(Stream): + pass + + stream = MyStream._legacy_default() + assert isinstance(stream, MyStream) + + +def test_stream_per_thread_default_subclassing(): + class MyStream(Stream): + pass + + stream = MyStream._per_thread_default() + assert isinstance(stream, MyStream)