diff --git a/cuda_bindings/cuda/bindings/driver.pyx.in b/cuda_bindings/cuda/bindings/driver.pyx.in index c6d2e3ca2c..adce0af721 100644 --- a/cuda_bindings/cuda/bindings/driver.pyx.in +++ b/cuda_bindings/cuda/bindings/driver.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -6525,6 +6525,12 @@ cdef class CUcontext: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUcontext): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6556,6 +6562,12 @@ cdef class CUmodule: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmodule): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6587,6 +6599,12 @@ cdef class CUfunction: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUfunction): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6618,6 +6636,12 @@ cdef class CUlibrary: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlibrary): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6649,6 +6673,12 @@ cdef class CUkernel: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUkernel): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6680,6 +6710,12 @@ cdef class CUarray: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUarray): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6711,6 +6747,12 @@ cdef class CUmipmappedArray: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmipmappedArray): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6742,6 +6784,12 @@ cdef class CUtexref: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUtexref): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6773,6 +6821,12 @@ cdef class CUsurfref: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUsurfref): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6804,6 +6858,12 @@ cdef class CUevent: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUevent): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6835,6 +6895,12 @@ cdef class CUstream: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUstream): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6866,6 +6932,12 @@ cdef class CUgraphicsResource: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphicsResource): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6897,6 +6969,12 @@ cdef class CUexternalMemory: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUexternalMemory): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6928,6 +7006,12 @@ cdef class CUexternalSemaphore: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUexternalSemaphore): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6959,6 +7043,12 @@ cdef class CUgraph: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraph): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6990,6 +7080,12 @@ cdef class CUgraphNode: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphNode): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7021,6 +7117,12 @@ cdef class CUgraphExec: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphExec): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7052,6 +7154,12 @@ cdef class CUmemoryPool: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmemoryPool): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7083,6 +7191,12 @@ cdef class CUuserObject: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUuserObject): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7114,6 +7228,12 @@ cdef class CUgraphDeviceNode: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphDeviceNode): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7145,6 +7265,12 @@ cdef class CUasyncCallbackHandle: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUasyncCallbackHandle): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7176,6 +7302,12 @@ cdef class CUgreenCtx: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgreenCtx): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7205,6 +7337,12 @@ cdef class CUlinkState: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlinkState): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7236,6 +7374,12 @@ cdef class CUdevResourceDesc: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUdevResourceDesc): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7265,6 +7409,12 @@ cdef class CUlogsCallbackHandle: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlogsCallbackHandle): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7296,6 +7446,12 @@ cdef class CUeglStreamConnection: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUeglStreamConnection): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7325,6 +7481,12 @@ cdef class EGLImageKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLImageKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7354,6 +7516,12 @@ cdef class EGLStreamKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLStreamKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7383,6 +7551,12 @@ cdef class EGLSyncKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLSyncKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/nvrtc.pyx.in b/cuda_bindings/cuda/bindings/nvrtc.pyx.in index 08abcbcf13..d274acc996 100644 --- a/cuda_bindings/cuda/bindings/nvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/nvrtc.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -109,6 +109,12 @@ cdef class nvrtcProgram: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, nvrtcProgram): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/runtime.pyx.in b/cuda_bindings/cuda/bindings/runtime.pyx.in index 426664570b..1f1fc72727 100644 --- a/cuda_bindings/cuda/bindings/runtime.pyx.in +++ b/cuda_bindings/cuda/bindings/runtime.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -5202,6 +5202,12 @@ cdef class cudaArray_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaArray_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5233,6 +5239,12 @@ cdef class cudaArray_const_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaArray_const_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5264,6 +5276,12 @@ cdef class cudaMipmappedArray_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaMipmappedArray_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5295,6 +5313,12 @@ cdef class cudaMipmappedArray_const_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaMipmappedArray_const_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5326,6 +5350,12 @@ cdef class cudaGraphicsResource_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaGraphicsResource_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5357,6 +5387,12 @@ cdef class cudaExternalMemory_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaExternalMemory_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5388,6 +5424,12 @@ cdef class cudaExternalSemaphore_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaExternalSemaphore_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5419,6 +5461,12 @@ cdef class cudaKernel_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaKernel_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5450,6 +5498,12 @@ cdef class cudaLibrary_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaLibrary_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5481,6 +5535,12 @@ cdef class cudaGraphDeviceNode_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaGraphDeviceNode_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5512,6 +5572,12 @@ cdef class cudaAsyncCallbackHandle_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaAsyncCallbackHandle_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5541,6 +5607,12 @@ cdef class EGLImageKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLImageKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5570,6 +5642,12 @@ cdef class EGLStreamKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLStreamKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5599,6 +5677,12 @@ cdef class EGLSyncKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLSyncKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/docs/source/release/12.X.Y-notes.rst b/cuda_bindings/docs/source/release/12.X.Y-notes.rst index eb45c39309..80cd405308 100644 --- a/cuda_bindings/docs/source/release/12.X.Y-notes.rst +++ b/cuda_bindings/docs/source/release/12.X.Y-notes.rst @@ -36,6 +36,7 @@ Miscellaneous ------------- * Added PTX utilities including :func:`~utils.get_minimal_required_cuda_ver_from_ptx_ver` and :func:`~utils.get_ptx_ver`. +* Common CUDA objects such as :class:`~runtime.cudaStream_t` now compare equal if the underlying address is the same. Known issues diff --git a/cuda_bindings/tests/test_cuda.py b/cuda_bindings/tests/test_cuda.py index c87c755598..da3c6dec61 100644 --- a/cuda_bindings/tests/test_cuda.py +++ b/cuda_bindings/tests/test_cuda.py @@ -10,6 +10,7 @@ import cuda.cuda as cuda import cuda.cudart as cudart +from cuda.bindings import driver def driverVersionLessThan(target): @@ -1008,3 +1009,29 @@ def test_private_function_pointer_inspector(): from cuda.bindings._bindings.cydriver import _inspect_function_pointer assert _inspect_function_pointer("__cuGetErrorString") != 0 + + +@pytest.mark.parametrize( + "target", + ( + driver.CUcontext, + driver.CUstream, + driver.CUevent, + driver.CUmodule, + driver.CUlibrary, + driver.CUfunction, + driver.CUkernel, + driver.CUgraph, + driver.CUgraphNode, + driver.CUgraphExec, + driver.CUmemoryPool, + ), +) +def test_struct_pointer_comparison(target): + a = target(123) + b = target(123) + assert a == b + assert hash(a) == hash(b) + c = target(456) + assert a != c + assert hash(a) != hash(c) diff --git a/cuda_bindings/tests/test_cudart.py b/cuda_bindings/tests/test_cudart.py index 756d8fac6b..70803c0777 100644 --- a/cuda_bindings/tests/test_cudart.py +++ b/cuda_bindings/tests/test_cudart.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE import ctypes @@ -9,6 +9,7 @@ import cuda.cuda as cuda import cuda.cudart as cudart +from cuda.bindings import runtime def isSuccess(err): @@ -1366,3 +1367,24 @@ def test_cudart_conditional(): assert len(params.conditional.phGraph_out) == 1 assert int(params.conditional.phGraph_out[0]) != 0 + + +@pytest.mark.parametrize( + "target", + ( + runtime.cudaStream_t, + runtime.cudaEvent_t, + runtime.cudaGraph_t, + runtime.cudaGraphNode_t, + runtime.cudaGraphExec_t, + runtime.cudaMemPool_t, + ), +) +def test_struct_pointer_comparison(target): + a = target(123) + b = target(123) + assert a == b + assert hash(a) == hash(b) + c = target(456) + assert a != c + assert hash(a) != hash(c)