From 1a4269e52c97a44a83d3aed9e557ea913a762b6d Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Mon, 21 Jul 2025 15:52:41 +0000 Subject: [PATCH 1/3] allow opaque structs to compare equal if the underlying address is the same --- cuda_bindings/cuda/bindings/driver.pyx.in | 116 +++++++++++++++++++++ cuda_bindings/cuda/bindings/nvrtc.pyx.in | 4 + cuda_bindings/cuda/bindings/runtime.pyx.in | 56 ++++++++++ 3 files changed, 176 insertions(+) diff --git a/cuda_bindings/cuda/bindings/driver.pyx.in b/cuda_bindings/cuda/bindings/driver.pyx.in index c6d2e3ca2c..1efc5e3940 100644 --- a/cuda_bindings/cuda/bindings/driver.pyx.in +++ b/cuda_bindings/cuda/bindings/driver.pyx.in @@ -6525,6 +6525,10 @@ cdef class CUcontext: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUcontext): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6556,6 +6560,10 @@ cdef class CUmodule: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmodule): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6587,6 +6595,10 @@ cdef class CUfunction: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUfunction): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6618,6 +6630,10 @@ cdef class CUlibrary: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlibrary): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6649,6 +6665,10 @@ cdef class CUkernel: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUkernel): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6680,6 +6700,10 @@ cdef class CUarray: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUarray): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6711,6 +6735,10 @@ cdef class CUmipmappedArray: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmipmappedArray): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6742,6 +6770,10 @@ cdef class CUtexref: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUtexref): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6773,6 +6805,10 @@ cdef class CUsurfref: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUsurfref): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6804,6 +6840,10 @@ cdef class CUevent: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUevent): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6835,6 +6875,10 @@ cdef class CUstream: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUstream): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6866,6 +6910,10 @@ cdef class CUgraphicsResource: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphicsResource): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6897,6 +6945,10 @@ cdef class CUexternalMemory: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUexternalMemory): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6928,6 +6980,10 @@ cdef class CUexternalSemaphore: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUexternalSemaphore): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6959,6 +7015,10 @@ cdef class CUgraph: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraph): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6990,6 +7050,10 @@ cdef class CUgraphNode: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphNode): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7021,6 +7085,10 @@ cdef class CUgraphExec: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphExec): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7052,6 +7120,10 @@ cdef class CUmemoryPool: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUmemoryPool): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7083,6 +7155,10 @@ cdef class CUuserObject: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUuserObject): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7114,6 +7190,10 @@ cdef class CUgraphDeviceNode: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgraphDeviceNode): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7145,6 +7225,10 @@ cdef class CUasyncCallbackHandle: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUasyncCallbackHandle): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7176,6 +7260,10 @@ cdef class CUgreenCtx: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUgreenCtx): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7205,6 +7293,10 @@ cdef class CUlinkState: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlinkState): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7236,6 +7328,10 @@ cdef class CUdevResourceDesc: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUdevResourceDesc): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7265,6 +7361,10 @@ cdef class CUlogsCallbackHandle: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUlogsCallbackHandle): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7296,6 +7396,10 @@ cdef class CUeglStreamConnection: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, CUeglStreamConnection): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7325,6 +7429,10 @@ cdef class EGLImageKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLImageKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7354,6 +7462,10 @@ cdef class EGLStreamKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLStreamKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7383,6 +7495,10 @@ cdef class EGLSyncKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLSyncKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/nvrtc.pyx.in b/cuda_bindings/cuda/bindings/nvrtc.pyx.in index 08abcbcf13..3e4ef8346f 100644 --- a/cuda_bindings/cuda/bindings/nvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/nvrtc.pyx.in @@ -109,6 +109,10 @@ cdef class nvrtcProgram: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, nvrtcProgram): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/runtime.pyx.in b/cuda_bindings/cuda/bindings/runtime.pyx.in index 426664570b..4c12f387ef 100644 --- a/cuda_bindings/cuda/bindings/runtime.pyx.in +++ b/cuda_bindings/cuda/bindings/runtime.pyx.in @@ -5202,6 +5202,10 @@ cdef class cudaArray_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaArray_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5233,6 +5237,10 @@ cdef class cudaArray_const_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaArray_const_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5264,6 +5272,10 @@ cdef class cudaMipmappedArray_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaMipmappedArray_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5295,6 +5307,10 @@ cdef class cudaMipmappedArray_const_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaMipmappedArray_const_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5326,6 +5342,10 @@ cdef class cudaGraphicsResource_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaGraphicsResource_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5357,6 +5377,10 @@ cdef class cudaExternalMemory_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaExternalMemory_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5388,6 +5412,10 @@ cdef class cudaExternalSemaphore_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaExternalSemaphore_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5419,6 +5447,10 @@ cdef class cudaKernel_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaKernel_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5450,6 +5482,10 @@ cdef class cudaLibrary_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaLibrary_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5481,6 +5517,10 @@ cdef class cudaGraphDeviceNode_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaGraphDeviceNode_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5512,6 +5552,10 @@ cdef class cudaAsyncCallbackHandle_t: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, cudaAsyncCallbackHandle_t): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5541,6 +5585,10 @@ cdef class EGLImageKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLImageKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5570,6 +5618,10 @@ cdef class EGLStreamKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLStreamKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5599,6 +5651,10 @@ cdef class EGLSyncKHR: return '' def __index__(self): return self.__int__() + def __eq__(self, other): + if not isinstance(other, EGLSyncKHR): + return False + return self._pvt_ptr[0] == (other)._pvt_ptr[0] def __int__(self): return self._pvt_ptr[0] def getPtr(self): From 68f66e18015e39a8fcce2c1823b9a948dfcb44b0 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Tue, 22 Jul 2025 02:46:34 +0000 Subject: [PATCH 2/3] add tests --- cuda_bindings/tests/test_cuda.py | 25 +++++++++++++++++++++++++ cuda_bindings/tests/test_cudart.py | 22 +++++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/cuda_bindings/tests/test_cuda.py b/cuda_bindings/tests/test_cuda.py index c87c755598..e3dce787e7 100644 --- a/cuda_bindings/tests/test_cuda.py +++ b/cuda_bindings/tests/test_cuda.py @@ -10,6 +10,7 @@ import cuda.cuda as cuda import cuda.cudart as cudart +from cuda.bindings import driver def driverVersionLessThan(target): @@ -1008,3 +1009,27 @@ def test_private_function_pointer_inspector(): from cuda.bindings._bindings.cydriver import _inspect_function_pointer assert _inspect_function_pointer("__cuGetErrorString") != 0 + + +@pytest.mark.parametrize( + "target", + ( + driver.CUcontext, + driver.CUstream, + driver.CUevent, + driver.CUmodule, + driver.CUlibrary, + driver.CUfunction, + driver.CUkernel, + driver.CUgraph, + driver.CUgraphNode, + driver.CUgraphExec, + driver.CUmemoryPool, + ), +) +def test_struct_pointer_comparison(target): + a = target(123) + b = target(123) + assert a == b + c = target(456) + assert a != c diff --git a/cuda_bindings/tests/test_cudart.py b/cuda_bindings/tests/test_cudart.py index 756d8fac6b..f9c4ee8e57 100644 --- a/cuda_bindings/tests/test_cudart.py +++ b/cuda_bindings/tests/test_cudart.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE import ctypes @@ -9,6 +9,7 @@ import cuda.cuda as cuda import cuda.cudart as cudart +from cuda.bindings import runtime def isSuccess(err): @@ -1366,3 +1367,22 @@ def test_cudart_conditional(): assert len(params.conditional.phGraph_out) == 1 assert int(params.conditional.phGraph_out[0]) != 0 + + +@pytest.mark.parametrize( + "target", + ( + runtime.cudaStream_t, + runtime.cudaEvent_t, + runtime.cudaGraph_t, + runtime.cudaGraphNode_t, + runtime.cudaGraphExec_t, + runtime.cudaMemPool_t, + ), +) +def test_struct_pointer_comparison(target): + a = target(123) + b = target(123) + assert a == b + c = target(456) + assert a != c From 4db82a9cbf03b8ea3902077b918cbd7e8cf09d48 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 23 Jul 2025 08:42:19 +0000 Subject: [PATCH 3/3] ensure __eq__ is always accompanied by __hash__ --- cuda_bindings/cuda/bindings/driver.pyx.in | 60 ++++++++++++++++++- cuda_bindings/cuda/bindings/nvrtc.pyx.in | 4 +- cuda_bindings/cuda/bindings/runtime.pyx.in | 30 +++++++++- .../docs/source/release/12.X.Y-notes.rst | 1 + cuda_bindings/tests/test_cuda.py | 2 + cuda_bindings/tests/test_cudart.py | 2 + 6 files changed, 96 insertions(+), 3 deletions(-) diff --git a/cuda_bindings/cuda/bindings/driver.pyx.in b/cuda_bindings/cuda/bindings/driver.pyx.in index 1efc5e3940..adce0af721 100644 --- a/cuda_bindings/cuda/bindings/driver.pyx.in +++ b/cuda_bindings/cuda/bindings/driver.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -6529,6 +6529,8 @@ cdef class CUcontext: if not isinstance(other, CUcontext): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6564,6 +6566,8 @@ cdef class CUmodule: if not isinstance(other, CUmodule): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6599,6 +6603,8 @@ cdef class CUfunction: if not isinstance(other, CUfunction): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6634,6 +6640,8 @@ cdef class CUlibrary: if not isinstance(other, CUlibrary): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6669,6 +6677,8 @@ cdef class CUkernel: if not isinstance(other, CUkernel): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6704,6 +6714,8 @@ cdef class CUarray: if not isinstance(other, CUarray): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6739,6 +6751,8 @@ cdef class CUmipmappedArray: if not isinstance(other, CUmipmappedArray): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6774,6 +6788,8 @@ cdef class CUtexref: if not isinstance(other, CUtexref): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6809,6 +6825,8 @@ cdef class CUsurfref: if not isinstance(other, CUsurfref): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6844,6 +6862,8 @@ cdef class CUevent: if not isinstance(other, CUevent): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6879,6 +6899,8 @@ cdef class CUstream: if not isinstance(other, CUstream): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6914,6 +6936,8 @@ cdef class CUgraphicsResource: if not isinstance(other, CUgraphicsResource): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6949,6 +6973,8 @@ cdef class CUexternalMemory: if not isinstance(other, CUexternalMemory): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -6984,6 +7010,8 @@ cdef class CUexternalSemaphore: if not isinstance(other, CUexternalSemaphore): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7019,6 +7047,8 @@ cdef class CUgraph: if not isinstance(other, CUgraph): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7054,6 +7084,8 @@ cdef class CUgraphNode: if not isinstance(other, CUgraphNode): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7089,6 +7121,8 @@ cdef class CUgraphExec: if not isinstance(other, CUgraphExec): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7124,6 +7158,8 @@ cdef class CUmemoryPool: if not isinstance(other, CUmemoryPool): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7159,6 +7195,8 @@ cdef class CUuserObject: if not isinstance(other, CUuserObject): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7194,6 +7232,8 @@ cdef class CUgraphDeviceNode: if not isinstance(other, CUgraphDeviceNode): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7229,6 +7269,8 @@ cdef class CUasyncCallbackHandle: if not isinstance(other, CUasyncCallbackHandle): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7264,6 +7306,8 @@ cdef class CUgreenCtx: if not isinstance(other, CUgreenCtx): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7297,6 +7341,8 @@ cdef class CUlinkState: if not isinstance(other, CUlinkState): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7332,6 +7378,8 @@ cdef class CUdevResourceDesc: if not isinstance(other, CUdevResourceDesc): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7365,6 +7413,8 @@ cdef class CUlogsCallbackHandle: if not isinstance(other, CUlogsCallbackHandle): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7400,6 +7450,8 @@ cdef class CUeglStreamConnection: if not isinstance(other, CUeglStreamConnection): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7433,6 +7485,8 @@ cdef class EGLImageKHR: if not isinstance(other, EGLImageKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7466,6 +7520,8 @@ cdef class EGLStreamKHR: if not isinstance(other, EGLStreamKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -7499,6 +7555,8 @@ cdef class EGLSyncKHR: if not isinstance(other, EGLSyncKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/nvrtc.pyx.in b/cuda_bindings/cuda/bindings/nvrtc.pyx.in index 3e4ef8346f..d274acc996 100644 --- a/cuda_bindings/cuda/bindings/nvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/nvrtc.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -113,6 +113,8 @@ cdef class nvrtcProgram: if not isinstance(other, nvrtcProgram): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/cuda/bindings/runtime.pyx.in b/cuda_bindings/cuda/bindings/runtime.pyx.in index 4c12f387ef..1f1fc72727 100644 --- a/cuda_bindings/cuda/bindings/runtime.pyx.in +++ b/cuda_bindings/cuda/bindings/runtime.pyx.in @@ -8,7 +8,7 @@ import cython import ctypes from libc.stdlib cimport calloc, malloc, free from libc cimport string -from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t +from libc.stdint cimport int32_t, uint32_t, int64_t, uint64_t, uintptr_t from libc.stddef cimport wchar_t from libc.limits cimport CHAR_MIN from libcpp.vector cimport vector @@ -5206,6 +5206,8 @@ cdef class cudaArray_t: if not isinstance(other, cudaArray_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5241,6 +5243,8 @@ cdef class cudaArray_const_t: if not isinstance(other, cudaArray_const_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5276,6 +5280,8 @@ cdef class cudaMipmappedArray_t: if not isinstance(other, cudaMipmappedArray_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5311,6 +5317,8 @@ cdef class cudaMipmappedArray_const_t: if not isinstance(other, cudaMipmappedArray_const_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5346,6 +5354,8 @@ cdef class cudaGraphicsResource_t: if not isinstance(other, cudaGraphicsResource_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5381,6 +5391,8 @@ cdef class cudaExternalMemory_t: if not isinstance(other, cudaExternalMemory_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5416,6 +5428,8 @@ cdef class cudaExternalSemaphore_t: if not isinstance(other, cudaExternalSemaphore_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5451,6 +5465,8 @@ cdef class cudaKernel_t: if not isinstance(other, cudaKernel_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5486,6 +5502,8 @@ cdef class cudaLibrary_t: if not isinstance(other, cudaLibrary_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5521,6 +5539,8 @@ cdef class cudaGraphDeviceNode_t: if not isinstance(other, cudaGraphDeviceNode_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5556,6 +5576,8 @@ cdef class cudaAsyncCallbackHandle_t: if not isinstance(other, cudaAsyncCallbackHandle_t): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5589,6 +5611,8 @@ cdef class EGLImageKHR: if not isinstance(other, EGLImageKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5622,6 +5646,8 @@ cdef class EGLStreamKHR: if not isinstance(other, EGLStreamKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): @@ -5655,6 +5681,8 @@ cdef class EGLSyncKHR: if not isinstance(other, EGLSyncKHR): return False return self._pvt_ptr[0] == (other)._pvt_ptr[0] + def __hash__(self): + return hash((self._pvt_ptr[0])) def __int__(self): return self._pvt_ptr[0] def getPtr(self): diff --git a/cuda_bindings/docs/source/release/12.X.Y-notes.rst b/cuda_bindings/docs/source/release/12.X.Y-notes.rst index a6da591858..551cde48e0 100644 --- a/cuda_bindings/docs/source/release/12.X.Y-notes.rst +++ b/cuda_bindings/docs/source/release/12.X.Y-notes.rst @@ -30,6 +30,7 @@ Miscellaneous ------------- * Added PTX utilities including :func:`~utils.get_minimal_required_cuda_ver_from_ptx_ver` and :func:`~utils.get_ptx_ver`. +* Common CUDA objects such as :class:`~runtime.cudaStream_t` now compare equal if the underlying address is the same. Known issues diff --git a/cuda_bindings/tests/test_cuda.py b/cuda_bindings/tests/test_cuda.py index e3dce787e7..da3c6dec61 100644 --- a/cuda_bindings/tests/test_cuda.py +++ b/cuda_bindings/tests/test_cuda.py @@ -1031,5 +1031,7 @@ def test_struct_pointer_comparison(target): a = target(123) b = target(123) assert a == b + assert hash(a) == hash(b) c = target(456) assert a != c + assert hash(a) != hash(c) diff --git a/cuda_bindings/tests/test_cudart.py b/cuda_bindings/tests/test_cudart.py index f9c4ee8e57..70803c0777 100644 --- a/cuda_bindings/tests/test_cudart.py +++ b/cuda_bindings/tests/test_cudart.py @@ -1384,5 +1384,7 @@ def test_struct_pointer_comparison(target): a = target(123) b = target(123) assert a == b + assert hash(a) == hash(b) c = target(456) assert a != c + assert hash(a) != hash(c)