From 48400613bcd7c55593fd795af60e4a1c1803bd75 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Wed, 23 Jul 2025 17:53:50 +0000 Subject: [PATCH 1/3] cache primary context --- cuda_core/cuda/core/experimental/_device.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index f7c760f5d6..e31b39427e 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1022,6 +1022,18 @@ def _check_context_initialized(self): f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?" ) + def _get_primary_context(self) -> driver.CUcontext: + try: + primary_ctxs = _tls.primary_ctxs + except AttributeError: + total = len(_tls.devices) + primary_ctxs = _tls.primary_ctxs = [None] * total + ctx = primary_ctxs[self._id] + if ctx is None: + ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id)) + primary_ctxs[self._id] = ctx + return ctx + def _get_current_context(self, check_consistency=False) -> driver.CUcontext: err, ctx = driver.cuCtxGetCurrent() @@ -1189,13 +1201,13 @@ def set_current(self, ctx: Context = None) -> Union[Context, None]: ctx = handle_return(driver.cuCtxGetCurrent()) if int(ctx) == 0: # use primary ctx - ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id)) + ctx = self._get_primary_context() handle_return(driver.cuCtxPushCurrent(ctx)) else: ctx_id = handle_return(driver.cuCtxGetDevice()) if ctx_id != self._id: # use primary ctx - ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id)) + ctx = self._get_primary_context() handle_return(driver.cuCtxPushCurrent(ctx)) else: # no-op, a valid context already exists and is set current From c4afd331b16af9766b08f2593b613827e10ac288 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 25 Jul 2025 03:13:24 +0000 Subject: [PATCH 2/3] avoid increasing stack size --- cuda_core/cuda/core/experimental/_device.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index e31b39427e..deed7f651b 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1202,13 +1202,13 @@ def set_current(self, ctx: Context = None) -> Union[Context, None]: if int(ctx) == 0: # use primary ctx ctx = self._get_primary_context() - handle_return(driver.cuCtxPushCurrent(ctx)) + handle_return(driver.cuCtxSetCurrent(ctx)) else: ctx_id = handle_return(driver.cuCtxGetDevice()) if ctx_id != self._id: # use primary ctx ctx = self._get_primary_context() - handle_return(driver.cuCtxPushCurrent(ctx)) + handle_return(driver.cuCtxSetCurrent(ctx)) else: # no-op, a valid context already exists and is set current pass From 562340c7747952bacbb21f699371586decd19742 Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Fri, 25 Jul 2025 03:55:03 +0000 Subject: [PATCH 3/3] unconditionally set primary context to current --- cuda_core/cuda/core/experimental/_device.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index deed7f651b..6268fd5389 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1198,20 +1198,9 @@ def set_current(self, ctx: Context = None) -> Union[Context, None]: if int(prev_ctx) != 0: return Context._from_ctx(prev_ctx, self._id) else: - ctx = handle_return(driver.cuCtxGetCurrent()) - if int(ctx) == 0: - # use primary ctx - ctx = self._get_primary_context() - handle_return(driver.cuCtxSetCurrent(ctx)) - else: - ctx_id = handle_return(driver.cuCtxGetDevice()) - if ctx_id != self._id: - # use primary ctx - ctx = self._get_primary_context() - handle_return(driver.cuCtxSetCurrent(ctx)) - else: - # no-op, a valid context already exists and is set current - pass + # use primary ctx + ctx = self._get_primary_context() + handle_return(driver.cuCtxSetCurrent(ctx)) self._has_inited = True def create_context(self, options: ContextOptions = None) -> Context: