From d4725ff50a902c6af0ffe094704f147c9aa77f88 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 10 Aug 2020 12:16:46 -0700 Subject: [PATCH] Use PTDS for synchronization in UCX communication This changes the stream we synchronize on to support the per-thread default stream. --- distributed/comm/ucx.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index c016cc5a9d5..7fa69a45e35 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -39,12 +39,10 @@ ucx_create_listener = None -def synchronize_stream(stream=0): +def synchronize_ptds(): import numba.cuda - ctx = numba.cuda.current_context() - cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) - stream = numba.cuda.driver.Stream(ctx, cu_stream, None) + stream = numba.cuda.per_thread_default_stream() stream.synchronize() @@ -221,13 +219,13 @@ async def write( # Send frames - # It is necessary to first synchronize the default stream before start sending - # We synchronize the default stream because UCX is not stream-ordered and - # syncing the default stream will wait for other non-blocking CUDA streams. + # It is necessary to first synchronize the per-thread default stream before starting to send. + # We synchronize the per-thread default stream because UCX is not stream-ordered and + # syncing the per-thread default stream will wait for other non-blocking CUDA streams. # Note this is only sufficient if the memory being sent is not currently in use on # non-blocking CUDA streams. 
if any(cuda_send_frames): - synchronize_stream(0) + synchronize_ptds() for each_frame in send_frames: await self.ep.send(each_frame) @@ -280,9 +278,9 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): ) # It is necessary to first populate `frames` with CUDA arrays and synchronize - # the default stream before starting receiving to ensure buffers have been allocated + # the per-thread default stream before starting receiving to ensure buffers have been allocated if any(cuda_recv_frames): - synchronize_stream(0) + synchronize_ptds() for each_frame in recv_frames: await self.ep.recv(each_frame)