Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError
from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host

import asyncio
from itertools import starmap
from operator import add

from tlz import accumulate, cons, sliding_window

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -190,18 +195,24 @@ async def read(self, deserializers=None):
lengths = await stream.read_bytes(8 * n_frames)
lengths = struct.unpack("Q" * n_frames, lengths)

frames = []
for length in lengths:
if length:
if self._iostream_has_read_into:
frame = bytearray(length)
n = await stream.read_into(frame)
assert n == length, (n, length)
else:
if self._iostream_has_read_into:
frame_arr = bytearray(sum(lengths))
slices = starmap(
slice, sliding_window(2, accumulate(add, cons(0, sizes)))
)
frames = [frames_arr[sl] for sl in slices]
recvd_lengths = await asyncio.gather([
stream.read_into(f) for f in frames if len(f)
])
assert all(recvd_lengths == lengths), (recvd_lengths, lengths)
else:
frames = []
for length in lengths:
if length:
frame = await stream.read_bytes(length)
else:
frame = b""
frames.append(frame)
else:
frame = b""
frames.append(frame)
except StreamClosedError as e:
self.stream = None
if not shutting_down():
Expand Down
52 changes: 36 additions & 16 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import dask
import numpy as np

import asyncio
from itertools import starmap
from operator import add

from tlz import accumulate, cons, sliding_window


logger = logging.getLogger(__name__)

Expand All @@ -42,12 +48,22 @@ def init_once():
ucp = _ucp
ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True)

# Find the function, `as_cuda_array()`, to get array-likes from CUDA
try:
import numba.cuda

as_cuda_array = lambda a: numba.cuda.as_cuda_array(a)
except ImportError:

def as_cuda_array(a):
raise RuntimeError("In order to send/recv CUDA arrays, Numba is required")

# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
try:
import rmm

if hasattr(rmm, "DeviceBuffer"):
cuda_array = lambda n: rmm.DeviceBuffer(size=n)
cuda_array = lambda n: as_cuda_array(rmm.DeviceBuffer(size=n))
else: # pre-0.11.0
cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8)
except ImportError:
Expand All @@ -59,7 +75,7 @@ def init_once():

def cuda_array(n):
raise RuntimeError(
"In order to send/recv CUDA arrays, Numba or RMM is required"
"In order to send/recv CUDA arrays, Numba and RMM are required"
)


Expand Down Expand Up @@ -178,20 +194,24 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
raise CommClosedError("While reading, the connection was closed")
else:
# Recv frames
frames = []
for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()):
if size > 0:
if is_cuda:
frame = cuda_array(size)
else:
frame = np.empty(size, dtype=np.uint8)
await self.ep.recv(frame)
frames.append(frame)
else:
if is_cuda:
frames.append(cuda_array(size))
else:
frames.append(b"")
sizes_dev = sizes[is_cudas]
sizes_host = sizes[~is_cudas]
frames_dev_arr = cuda_array(sum(sizes_dev))
frames_host_arr = np.empty(sum(sizes_host), dtype=np.uint8)
slices_dev = starmap(
slice, sliding_window(2, accumulate(add, cons(0, sizes_dev)))
)
slices_host = starmap(
slice, sliding_window(2, accumulate(add, cons(0, sizes_host)))
)
frames_dev = [frames_dev_arr[sl] for sl in slices_dev]
frames_host = [frames_host_arr[sl] for sl in slices_host]
frames = len(sizes) * [None]
for i, f in zip(is_cudas.nonzero()[0], frames_dev):
frames[i] = f
for i, f in zip((~is_cudas).nonzero()[0], frames_host):
frames[i] = f
await asyncio.gather([self.ep.recv(f) for f in frames if len(f)])
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down