From 63cc96603638d227bdf9a781b338b509921c9a29 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 1 Sep 2025 03:38:42 -0700 Subject: [PATCH 1/3] Remove UCX comm and tests --- distributed/comm/__init__.py | 5 - distributed/comm/tests/test_comms.py | 11 - distributed/comm/tests/test_ucx.py | 462 --------------- distributed/comm/tests/test_ucx_config.py | 163 ------ distributed/comm/ucx.py | 665 ---------------------- distributed/tests/test_nanny.py | 8 +- distributed/tests/test_worker.py | 15 - distributed/utils_test.py | 42 -- pyproject.toml | 1 - 9 files changed, 2 insertions(+), 1370 deletions(-) delete mode 100644 distributed/comm/tests/test_ucx.py delete mode 100644 distributed/comm/tests/test_ucx_config.py delete mode 100644 distributed/comm/ucx.py diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index d1733762ccd..e1b597cc70b 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -22,10 +22,5 @@ def _register_transports(): backends["tcp"] = tcp.TCPBackend() backends["tls"] = tcp.TLSBackend() - try: - from distributed.comm import ucx - except ImportError: - pass - _register_transports() diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 55acfd52867..f4f93dc9688 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -570,16 +570,6 @@ async def client_communicate(key, delay=0): listener.stop() -@pytest.mark.gpu -@gen_test() -async def test_ucx_client_server(ucx_loop): - pytest.importorskip("distributed.comm.ucx") - ucp = pytest.importorskip("ucp") - - addr = ucp.get_address() - await check_client_server("ucx://" + addr) - - def tcp_eq(expected_host, expected_port=None): def checker(loc): host, port = parse_host_port(loc) @@ -1395,7 +1385,6 @@ async def test_do_not_share_buffers(tcp, list_cls): See Also -------- test_share_buffer_with_header - test_ucx.py::test_do_not_share_buffers """ np = pytest.importorskip("numpy") diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py deleted file mode 100644 index 988dc183fbc..00000000000 --- a/distributed/comm/tests/test_ucx.py +++ /dev/null @@ -1,462 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -from unittest.mock import patch - -import pytest - -import dask - -from distributed.utils import wait_for - -pytestmark = pytest.mark.gpu - -ucp = pytest.importorskip("ucp") - -from distributed import Client, Scheduler, wait -from distributed.comm import connect, listen, parse_address, ucx -from distributed.comm.core import CommClosedError -from distributed.comm.registry import backends, get_backend -from distributed.deploy.local import LocalCluster -from distributed.diagnostics.nvml import ( - device_get_count, - get_device_index_and_uuid, - get_device_mig_mode, - has_cuda_context, -) -from distributed.protocol import to_serialize -from distributed.protocol.utils_test import get_host_array -from distributed.utils_test import gen_test, inc - -try: - HOST = ucp.get_address() -except Exception: - HOST = "127.0.0.1" - - -def test_registered(ucx_loop): - assert "ucx" in backends - backend = get_backend("ucx") - assert isinstance(backend, ucx.UCXBackend) - - -async def get_comm_pair( - listen_addr=f"ucx://{HOST}", listen_args=None, connect_args=None, **kwargs -): - listen_args = listen_args or {} - connect_args = connect_args or {} - q = asyncio.queues.Queue() - - async def handle_comm(comm): - await q.put(comm) - - listener = listen(listen_addr, handle_comm, **listen_args, **kwargs) - async with listener: - comm = await connect(listener.contact_address, **connect_args, **kwargs) - serv_comm = await q.get() - return (comm, serv_comm) - - -@gen_test() -async def test_ping_pong(ucx_loop): - com, serv_com = await get_comm_pair() - msg = {"op": "ping"} - await com.write(msg) - result = await serv_com.read() - assert result == msg - result["op"] = "pong" - - await serv_com.write(result) - - result = await com.read() - assert result == {"op": "pong"} - - await com.close() - await serv_com.close() - - -@gen_test() -async def test_comm_objs(ucx_loop): - comm, serv_comm = await get_comm_pair() - - scheme, loc = parse_address(comm.peer_address) - assert scheme == "ucx" - - scheme, loc = parse_address(serv_comm.peer_address) - assert scheme == "ucx" - - assert comm.peer_address == serv_comm.local_address - - -@gen_test() -async def test_ucx_specific(ucx_loop): - """ - Test concrete UCX API. - """ - # TODO: - # 1. ensure exceptions in handle_comm fail the test - # 2. Use dict in read / write, put seralization there. - # 3. Test peer_address - # 4. Test cleanup - address = f"ucx://{HOST}:{0}" - - async def handle_comm(comm): - msg = await comm.read() - msg["op"] = "pong" - await comm.write(msg) - await comm.read() - await comm.close() - assert comm.closed() is True - - listener = await ucx.UCXListener(address, handle_comm) - host, port = listener.get_host_port() - assert host.count(".") == 3 - assert port > 0 - - l = [] - - async def client_communicate(key, delay=0): - addr = "%s:%d" % (host, port) - comm = await connect(listener.contact_address) - # TODO: peer_address - # assert comm.peer_address == 'ucx://' + addr - assert comm.extra_info == {} - msg = {"op": "ping", "data": key} - await comm.write(msg) - if delay: - await asyncio.sleep(delay) - msg = await comm.read() - assert msg == {"op": "pong", "data": key} - await comm.write({"op": "client closed"}) - l.append(key) - return comm - - comm = await client_communicate(key=1234, delay=0.5) - - # Many clients at once - N = 2 - futures = [client_communicate(key=i, delay=0.05) for i in range(N)] - await asyncio.gather(*futures) - assert set(l) == {1234} | set(range(N)) - - listener.stop() - - -@gen_test() -async def test_ping_pong_data(ucx_loop): - np = pytest.importorskip("numpy") - - data = np.ones((10, 10)) - - com, serv_com = await get_comm_pair() - msg = {"op": "ping", "data": to_serialize(data)} - await com.write(msg) - result = await serv_com.read() - result["op"] = "pong" - data2 = result.pop("data") - np.testing.assert_array_equal(data2, data) - - await serv_com.write(result) - - result = await com.read() - assert result == {"op": "pong"} - - await com.close() - await serv_com.close() - - -@gen_test() -async def test_ucx_deserialize(ucx_loop): - # Note we see this error on some systems with this test: - # `socket.gaierror: [Errno -5] No address associated with hostname` - # This may be due to a system configuration issue. - from distributed.comm.tests.test_comms import check_deserialize - - await check_deserialize("tcp://") - - -@pytest.mark.parametrize( - "g", - [ - lambda cudf: cudf.Series([1, 2, 3]), - lambda cudf: cudf.Series([], dtype=object), - lambda cudf: cudf.DataFrame([], dtype=object), - lambda cudf: cudf.DataFrame([1]).head(0), - lambda cudf: cudf.DataFrame([1.0]).head(0), - lambda cudf: cudf.DataFrame({"a": []}), - lambda cudf: cudf.DataFrame({"a": ["a"]}).head(0), - lambda cudf: cudf.DataFrame({"a": [1.0]}).head(0), - lambda cudf: cudf.DataFrame({"a": [1]}).head(0), - lambda cudf: cudf.DataFrame({"a": [1, 2, None], "b": [1.0, 2.0, None]}), - lambda cudf: cudf.DataFrame({"a": ["Check", "str"], "b": ["Sup", "port"]}), - ], -) -@gen_test() -async def test_ping_pong_cudf(ucx_loop, g): - # if this test appears after cupy an import error arises - # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' - # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) - cudf = pytest.importorskip("cudf") - from cudf.testing import assert_eq - - cudf_obj = g(cudf) - - com, serv_com = await get_comm_pair() - msg = {"op": "ping", "data": to_serialize(cudf_obj)} - - await com.write(msg) - result = await serv_com.read() - - cudf_obj_2 = result.pop("data") - assert result["op"] == "ping" - assert_eq(cudf_obj, cudf_obj_2) - - await com.close() - await serv_com.close() - - -@pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)]) -@gen_test() -async def test_ping_pong_cupy(ucx_loop, shape): - cupy = pytest.importorskip("cupy") - com, serv_com = await get_comm_pair() - - arr = cupy.random.random(shape) - msg = {"op": "ping", "data": to_serialize(arr)} - - _, result = await asyncio.gather(com.write(msg), serv_com.read()) - data2 = result.pop("data") - - assert result["op"] == "ping" - cupy.testing.assert_array_equal(arr, data2) - await com.close() - await serv_com.close() - - -@pytest.mark.slow -@pytest.mark.parametrize("n", [int(1e9), int(2.5e9)]) -@gen_test() -async def test_large_cupy(ucx_loop, n, cleanup): - cupy = pytest.importorskip("cupy") - com, serv_com = await get_comm_pair() - - arr = cupy.ones(n, dtype="u1") - msg = {"op": "ping", "data": to_serialize(arr)} - - _, result = await asyncio.gather(com.write(msg), serv_com.read()) - data2 = result.pop("data") - - assert result["op"] == "ping" - assert len(data2) == len(arr) - await com.close() - await serv_com.close() - - -@gen_test() -async def test_ping_pong_numba(ucx_loop): - np = pytest.importorskip("numpy") - numba = pytest.importorskip("numba") - import numba.cuda - - arr = np.arange(10) - arr = numba.cuda.to_device(arr) - - com, serv_com = await get_comm_pair() - msg = {"op": "ping", "data": to_serialize(arr)} - - await com.write(msg) - result = await serv_com.read() - data2 = result.pop("data") - assert result["op"] == "ping" - - -@pytest.mark.parametrize("processes", [True, False]) -@gen_test() -async def test_ucx_localcluster(ucx_loop, processes, cleanup): - async with LocalCluster( - protocol="ucx", - host=HOST, - dashboard_address=":0", - n_workers=2, - threads_per_worker=1, - processes=processes, - asynchronous=True, - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - x = client.submit(inc, 1) - await x - assert x.key in cluster.scheduler.tasks - if not processes: - assert any(w.data == {x.key: 2} for w in cluster.workers.values()) - assert len(cluster.scheduler.workers) == 2 - - -@pytest.mark.slow -@gen_test(timeout=120) -async def test_stress( - ucx_loop, -): - da = pytest.importorskip("dask.array") - - chunksize = "10 MB" - - async with LocalCluster( - protocol="ucx", - dashboard_address=":0", - asynchronous=True, - host=HOST, - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - rs = da.random.RandomState() - x = rs.random((10000, 10000), chunks=(-1, chunksize)) - x = client.persist(x) - await wait(x) - - for _ in range(10): - x = x.rechunk((chunksize, -1)) - x = x.rechunk((-1, chunksize)) - x = client.persist(x) - await wait(x) - - -@gen_test() -async def test_simple( - ucx_loop, -): - async with LocalCluster( - protocol="ucx", n_workers=2, threads_per_worker=2, asynchronous=True - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - assert cluster.scheduler_address.startswith("ucx://") - assert await client.submit(lambda x: x + 1, 10) == 11 - - -@pytest.mark.xfail(reason="If running on Docker, requires --pid=host") -@gen_test() -async def test_cuda_context( - ucx_loop, -): - try: - device_info = get_device_index_and_uuid( - next( - filter( - lambda i: get_device_mig_mode(i)[0] == 0, range(device_get_count()) - ) - ) - ) - except StopIteration: - pytest.skip("No CUDA device in non-MIG mode available") - - with patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": device_info.uuid.decode("utf-8")} - ): - with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}): - async with LocalCluster( - protocol="ucx", n_workers=1, asynchronous=True - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - assert cluster.scheduler_address.startswith("ucx://") - ctx = has_cuda_context() - assert ctx.has_context and ctx.device_info == device_info - worker_cuda_context = await client.run(has_cuda_context) - assert len(worker_cuda_context) == 1 - worker_cuda_context = list(worker_cuda_context.values()) - assert ( - worker_cuda_context[0].has_context - and worker_cuda_context[0].device_info == device_info - ) - - -@gen_test() -async def test_transpose( - ucx_loop, -): - da = pytest.importorskip("dask.array") - - async with LocalCluster( - protocol="ucx", n_workers=2, threads_per_worker=2, asynchronous=True - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - assert cluster.scheduler_address.startswith("ucx://") - x = client.persist(da.ones((10000, 10000), chunks=(1000, 1000))) - await x - y = (x + x.T).sum() - await y - - -@pytest.mark.parametrize("port", [0, 1234]) -@gen_test() -async def test_ucx_protocol(ucx_loop, cleanup, port): - async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s: - assert s.address.startswith("ucx://") - - -@pytest.mark.skipif( - not hasattr(ucp.exceptions, "UCXUnreachable"), - reason="Requires UCX-Py support for UCXUnreachable exception", -) -@gen_test() -async def test_ucx_unreachable( - ucx_loop, -): - with pytest.raises(OSError, match="Timed out trying to connect to"): - await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True) - - -@gen_test() -async def test_comm_closed_on_read_error(): - reader, writer = await get_comm_pair() - - # Depending on the UCP protocol selected, it may raise either - # `asyncio.TimeoutError` or `CommClosedError`, so validate either one. - with pytest.raises((asyncio.TimeoutError, CommClosedError)): - await wait_for(reader.read(), 0.01) - - assert reader.closed() - - -@gen_test() -async def test_embedded_cupy_array( - ucx_loop, -): - cupy = pytest.importorskip("cupy") - da = pytest.importorskip("dask.array") - np = pytest.importorskip("numpy") - - async with LocalCluster( - protocol="ucx", n_workers=1, threads_per_worker=1, asynchronous=True - ) as cluster: - async with Client(cluster, asynchronous=True) as client: - assert cluster.scheduler_address.startswith("ucx://") - a = cupy.arange(10000) - x = da.from_array(a, chunks=(10000,)) - b = await client.compute(x) - cupy.testing.assert_array_equal(a, b) - - -@gen_test() -async def test_do_not_share_buffers(ucx_loop): - """Test that two objects with buffer interface in the same message do not share - their buffer upon deserialization. - - See Also - -------- - test_comms.py::test_do_not_share_buffers - """ - np = pytest.importorskip("numpy") - - com, serv_com = await get_comm_pair() - msg = {"data": to_serialize([np.array([1, 2]), np.array([3, 4])])} - - await com.write(msg) - result = await serv_com.read() - await com.close() - await serv_com.close() - - a, b = result["data"] - ha = get_host_array(a) - hb = get_host_array(b) - assert ha is not hb - assert ha.nbytes == a.nbytes - assert hb.nbytes == a.nbytes diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py deleted file mode 100644 index d7a91703263..00000000000 --- a/distributed/comm/tests/test_ucx_config.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import annotations - -import os -import sys -from time import sleep - -import pytest - -pytestmark = pytest.mark.gpu - -import dask - -from distributed import Client -from distributed.comm.ucx import _prepare_ucx_config -from distributed.utils import get_ip, open_port -from distributed.utils_test import gen_test, popen - -try: - HOST = get_ip() -except Exception: - HOST = "127.0.0.1" - -ucp = pytest.importorskip("ucp") -rmm = pytest.importorskip("rmm") - - -@gen_test() -async def test_ucx_config(ucx_loop, cleanup): - ucx = { - "nvlink": True, - "infiniband": True, - "rdmacm": False, - "tcp": True, - "cuda-copy": True, - } - - with dask.config.set({"distributed.comm.ucx": ucx}): - ucx_config, ucx_environment = _prepare_ucx_config() - assert ucx_config == { - "TLS": "rc,tcp,cuda_copy,cuda_ipc", - "SOCKADDR_TLS_PRIORITY": "tcp", - } - assert ucx_environment == {} - - ucx = { - "nvlink": False, - "infiniband": True, - "rdmacm": False, - "tcp": True, - "cuda-copy": False, - } - - with dask.config.set({"distributed.comm.ucx": ucx}): - ucx_config, ucx_environment = _prepare_ucx_config() - assert ucx_config == {"TLS": "rc,tcp", "SOCKADDR_TLS_PRIORITY": "tcp"} - assert ucx_environment == {} - - ucx = { - "nvlink": False, - "infiniband": True, - "rdmacm": True, - "tcp": True, - "cuda-copy": True, - } - - with dask.config.set({"distributed.comm.ucx": ucx}): - ucx_config, ucx_environment = _prepare_ucx_config() - assert ucx_config == { - "TLS": "rc,tcp,cuda_copy", - "SOCKADDR_TLS_PRIORITY": "rdmacm", - } - assert ucx_environment == {} - - ucx = { - "nvlink": None, - "infiniband": None, - "rdmacm": None, - "tcp": None, - "cuda-copy": None, - } - - with dask.config.set({"distributed.comm.ucx": ucx}): - ucx_config, ucx_environment = _prepare_ucx_config() - assert ucx_config == {} - assert ucx_environment == {} - - ucx = { - "nvlink": False, - "infiniband": True, - "rdmacm": True, - "tcp": True, - "cuda-copy": True, - } - - with dask.config.set( - { - "distributed.comm.ucx": ucx, - "distributed.comm.ucx.environment": { - "tls": "all", - "memtrack-dest": "stdout", - }, - } - ): - ucx_config, ucx_environment = _prepare_ucx_config() - assert ucx_config == { - "TLS": "rc,tcp,cuda_copy", - "SOCKADDR_TLS_PRIORITY": "rdmacm", - } - assert ucx_environment == {"UCX_MEMTRACK_DEST": "stdout"} - - -def test_ucx_config_w_env_var(ucx_loop, cleanup, loop): - env = os.environ.copy() - env["DASK_DISTRIBUTED__RMM__POOL_SIZE"] = "1000.00 MB" - - port = str(open_port()) - # Using localhost appears to be less flaky than {HOST}. Additionally, this is - # closer to how other dask worker tests are written. - sched_addr = f"ucx://127.0.0.1:{port}" - - with popen( - [ - sys.executable, - "-m", - "dask", - "scheduler", - "--no-dashboard", - "--protocol", - "ucx", - "--port", - port, - ], - env=env, - ): - with popen( - [ - sys.executable, - "-m", - "dask", - "worker", - sched_addr, - "--host", - "127.0.0.1", - "--no-dashboard", - "--protocol", - "ucx", - "--no-nanny", - ], - env=env, - ): - with Client(sched_addr, loop=loop, timeout=30) as c: - while not c.scheduler_info()["workers"]: - sleep(0.1) - - # Check for RMM pool resource type - rmm_resource = c.run_on_scheduler( - rmm.mr.get_current_device_resource_type - ) - assert rmm_resource == rmm.mr.PoolMemoryResource - - rmm_resource_workers = c.run(rmm.mr.get_current_device_resource_type) - for v in rmm_resource_workers.values(): - assert v == rmm.mr.PoolMemoryResource diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py deleted file mode 100644 index 54b14fec443..00000000000 --- a/distributed/comm/ucx.py +++ /dev/null @@ -1,665 +0,0 @@ -""" -:ref:`UCX`_ based communications for distributed. - -See :ref:`communications` for more. - -.. _UCX: https://github.com/openucx/ucx -""" - -from __future__ import annotations - -import functools -import logging -import os -import struct -import weakref -from collections.abc import Awaitable, Callable, Collection -from typing import TYPE_CHECKING, Any -from unittest.mock import patch - -import dask -from dask.utils import parse_bytes - -from distributed.comm.addressing import parse_host_port, unparse_host_port -from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector -from distributed.comm.registry import Backend, backends -from distributed.comm.utils import ensure_concrete_host, from_frames, to_frames -from distributed.diagnostics.nvml import ( - CudaDeviceInfo, - get_device_index_and_uuid, - has_cuda_context, -) -from distributed.protocol.utils import host_array -from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes - -logger = logging.getLogger(__name__) - -# In order to avoid double init when forking/spawning new processes (multiprocess), -# we make sure only to import and initialize UCX once at first use. This is also -# required to ensure Dask configuration gets propagated to UCX, which needs -# variables to be set before being imported. -if TYPE_CHECKING: - try: - import ucp - except ImportError: - pass -else: - ucp = None - -device_array = None -pre_existing_cuda_context = False -cuda_context_created = False - - -_warning_suffix = ( - "This is often the result of a CUDA-enabled library calling a CUDA runtime function before " - "Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen " - "at import time or in the global scope of a program." -) - - -def _get_device_and_uuid_str(device_info: CudaDeviceInfo) -> str: - return f"{device_info.device_index} ({str(device_info.uuid)})" - - -def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None: - device_uuid_str = _get_device_and_uuid_str(device_info) - logger.warning( - f"A CUDA context for device {device_uuid_str} already exists " - f"on process ID {pid}. {_warning_suffix}" - ) - - -def _warn_cuda_context_wrong_device( - device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int -) -> None: - expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected) - actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual) - logger.warning( - f"Worker with process ID {pid} should have a CUDA context assigned to device " - f"{expected_device_uuid_str}, but instead the CUDA context is on device " - f"{actual_device_uuid_str}. {_warning_suffix}" - ) - - -def synchronize_stream(stream=0): - import numba.cuda - - ctx = numba.cuda.current_context() - cu_stream = numba.cuda.driver.drvapi.cu_stream(stream) - stream = numba.cuda.driver.Stream(ctx, cu_stream, None) - stream.synchronize() - - -def init_once(): - global ucp, device_array - global ucx_create_endpoint, ucx_create_listener - global pre_existing_cuda_context, cuda_context_created - - if ucp is not None: - return - - # remove/process dask.ucx flags for valid ucx options - ucx_config, ucx_environment = _prepare_ucx_config() - - # We ensure the CUDA context is created before initializing UCX. This can't - # be safely handled externally because communications in Dask start before - # preload scripts run. - # Precedence: - # 1. external environment - # 2. ucx_config (high level settings passed to ucp.init) - # 3. ucx_environment (low level settings equivalent to environment variables) - ucx_tls = os.environ.get( - "UCX_TLS", - ucx_config.get("TLS", ucx_environment.get("UCX_TLS", "")), - ) - if ( - dask.config.get("distributed.comm.ucx.create-cuda-context") is True - # This is not foolproof, if UCX_TLS=all we might require CUDA - # depending on configuration of UCX, but this is better than - # nothing - or ("cuda" in ucx_tls and "^cuda" not in ucx_tls) - ): - try: - import numba.cuda - except ImportError: - raise ImportError( - "CUDA support with UCX requires Numba for context management" - ) - - cuda_visible_device = get_device_index_and_uuid( - os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0] - ) - pre_existing_cuda_context = has_cuda_context() - if pre_existing_cuda_context.has_context: - _warn_existing_cuda_context( - pre_existing_cuda_context.device_info, os.getpid() - ) - - numba.cuda.current_context() - - cuda_context_created = has_cuda_context() - if ( - cuda_context_created.has_context - and cuda_context_created.device_info.uuid != cuda_visible_device.uuid - ): - _warn_cuda_context_wrong_device( - cuda_visible_device, cuda_context_created.device_info, os.getpid() - ) - - import ucp as _ucp - - ucp = _ucp - - with patch.dict(os.environ, ucx_environment): - # We carefully ensure that ucx_environment only contains things - # that don't override ucx_config or existing slots in the - # environment, so the user's external environment can safely - # override things here. - ucp.init(options=ucx_config, env_takes_precedence=True) - - pool_size_str = dask.config.get("distributed.rmm.pool-size") - - # Find the function, `cuda_array()`, to use when allocating new CUDA arrays - try: - import rmm - - def device_array(n): - return rmm.DeviceBuffer(size=n) - - if pool_size_str is not None: - pool_size = parse_bytes(pool_size_str) - rmm.reinitialize( - pool_allocator=True, managed_memory=False, initial_pool_size=pool_size - ) - except ImportError: - try: - import numba.cuda - - def numba_device_array(n): - a = numba.cuda.device_array((n,), dtype="u1") - weakref.finalize(a, numba.cuda.current_context) - return a - - device_array = numba_device_array - - except ImportError: - - def device_array(n): - raise RuntimeError( - "In order to send/recv CUDA arrays, Numba or RMM is required" - ) - - if pool_size_str is not None: - logger.warning( - "Initial RMM pool size defined, but RMM is not available. " - "Please consider installing RMM or removing the pool size option." - ) - - -def _close_comm(ref): - """Callback to close Dask Comm when UCX Endpoint closes or errors - - Parameters - ---------- - ref: weak reference to a Dask UCX comm - """ - comm = ref() - if comm is not None: - comm._closed = True - - -class UCX(Comm): - """Comm object using UCP. - - Parameters - ---------- - ep : ucp.Endpoint - The UCP endpoint. - address : str - The address, prefixed with `ucx://` to use. - deserialize : bool, default True - Whether to deserialize data in :meth:`distributed.protocol.loads` - - Notes - ----- - The read-write cycle uses the following pattern: - - Each msg is serialized into a number of "data" frames. We prepend these - real frames with two additional frames - - 1. is_gpus: Boolean indicator for whether the frame should be - received into GPU memory. Packed in '?' format. Unpack with - ``?`` format. - 2. frame_size : Unsigned int describing the size of frame (in bytes) - to receive. Packed in 'Q' format, so a length-0 frame is equivalent - to an unsized frame. Unpacked with ``Q``. - - The expected read cycle is - - 1. Read the frame describing if connection is closing and number of frames - 2. Read the frame describing whether each data frame is gpu-bound - 3. Read the frame describing whether each data frame is sized - 4. Read all the data frames. - """ - - def __init__( # type: ignore[no-untyped-def] - self, ep, local_addr: str, peer_addr: str, deserialize: bool = True - ): - super().__init__(deserialize=deserialize) - self._ep = ep - if local_addr: - assert local_addr.startswith("ucx") - assert peer_addr.startswith("ucx") - self._local_addr = local_addr - self._peer_addr = peer_addr - self.comm_flag = None - - # When the UCX endpoint closes or errors the registered callback - # is called. - if hasattr(self._ep, "set_close_callback"): - ref = weakref.ref(self) - self._ep.set_close_callback(functools.partial(_close_comm, ref)) - self._closed = False - self._has_close_callback = True - else: - self._has_close_callback = False - - logger.debug("UCX.__init__ %s", self) - - @property - def local_address(self) -> str: - return self._local_addr - - @property - def peer_address(self) -> str: - return self._peer_addr - - @property - def same_host(self) -> bool: - """Unlike in TCP, local_address can be blank""" - return super().same_host if self._local_addr else False - - @log_errors - async def write( - self, - msg: dict, - serializers: Collection[str] | None = None, - on_error: str = "message", - ) -> int: - if self.closed(): - raise CommClosedError("Endpoint is closed -- unable to send message") - - if serializers is None: - serializers = ("cuda", "dask", "pickle", "error") - # msg can also be a list of dicts when sending batched messages - frames = await to_frames( - msg, - serializers=serializers, - on_error=on_error, - allow_offload=self.allow_offload, - ) - nframes = len(frames) - cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames) - sizes = tuple(nbytes(f) for f in frames) - cuda_send_frames, send_frames = zip( - *( - (is_cuda, each_frame) - for is_cuda, each_frame in zip(cuda_frames, frames) - if nbytes(each_frame) > 0 - ) - ) - - try: - # Send meta data - - # Send close flag and number of frames (_Bool, int64) - await self.ep.send(struct.pack("?Q", False, nframes)) - # Send which frames are CUDA (bool) and - # how large each frame is (uint64) - await self.ep.send( - struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) - ) - - # Send frames - - # It is necessary to first synchronize the default stream before start - # sending We synchronize the default stream because UCX is not - # stream-ordered and syncing the default stream will wait for other - # non-blocking CUDA streams. Note this is only sufficient if the memory - # being sent is not currently in use on non-blocking CUDA streams. - if any(cuda_send_frames): - synchronize_stream(0) - - for each_frame in send_frames: - await self.ep.send(each_frame) - return sum(sizes) - except ucp.exceptions.UCXBaseException: - self.abort() - raise CommClosedError("While writing, the connection was closed") - - @log_errors - async def read(self, deserializers=("cuda", "dask", "pickle", "error")): - if deserializers is None: - deserializers = ("cuda", "dask", "pickle", "error") - - try: - # Recv meta data - - # Recv close flag and number of frames (_Bool, int64) - msg = host_array(struct.calcsize("?Q")) - await self.ep.recv(msg) - (shutdown, nframes) = struct.unpack("?Q", msg) - - if shutdown: # The writer is closing the connection - raise CommClosedError("Connection closed by writer") - - # Recv which frames are CUDA (bool) and - # how large each frame is (uint64) - header_fmt = nframes * "?" + nframes * "Q" - header = host_array(struct.calcsize(header_fmt)) - await self.ep.recv(header) - header = struct.unpack(header_fmt, header) - cuda_frames, sizes = header[:nframes], header[nframes:] - except BaseException as e: # noqa: B036 - # In addition to UCX exceptions, may be CancelledError or another - # "low-level" exception. The only safe thing to do is to abort. - # (See also https://github.com/dask/distributed/pull/6574). - self.abort() - raise CommClosedError( - f"Connection closed by writer.\nInner exception: {e!r}" - ) - else: - # Recv frames - frames = [ - device_array(each_size) if is_cuda else host_array(each_size) - for is_cuda, each_size in zip(cuda_frames, sizes) - ] - cuda_recv_frames, recv_frames = zip( - *( - (is_cuda, each_frame) - for is_cuda, each_frame in zip(cuda_frames, frames) - if nbytes(each_frame) > 0 - ) - ) - - # It is necessary to first populate `frames` with CUDA arrays and synchronize - # the default stream before starting receiving to ensure buffers have been allocated - if any(cuda_recv_frames): - synchronize_stream(0) - - try: - for each_frame in recv_frames: - await self.ep.recv(each_frame) - except BaseException as e: # noqa: B036 - # In addition to UCX exceptions, may be CancelledError or another - # "low-level" exception. The only safe thing to do is to abort. - # (See also https://github.com/dask/distributed/pull/6574). - self.abort() - raise CommClosedError( - f"Connection closed by writer.\nInner exception: {e!r}" - ) - - try: - msg = await from_frames( - frames, - deserialize=self.deserialize, - deserializers=deserializers, - allow_offload=self.allow_offload, - ) - except EOFError: - # Frames possibly garbled or truncated by communication error - self.abort() - raise CommClosedError("Aborted stream on truncated data") - return msg - - async def close(self): - self._closed = True - if self._ep is not None: - try: - await self.ep.send(struct.pack("?Q", True, 0)) - except ( # noqa: B030 - ucp.exceptions.UCXError, - ucp.exceptions.UCXCloseError, - ucp.exceptions.UCXCanceled, - ) + (getattr(ucp.exceptions, "UCXConnectionReset", ()),): - # If the other end is in the process of closing, - # UCX will sometimes raise a `Input/output` error, - # which we can ignore. - pass - self.abort() - self._ep = None - - def abort(self): - self._closed = True - if self._ep is not None: - self._ep.abort() - self._ep = None - - @property - def ep(self): - if self._ep is not None: - return self._ep - else: - raise CommClosedError("UCX Endpoint is closed") - - def closed(self): - if self._has_close_callback is True: - # The self._closed flag is separate from the endpoint's lifetime, even when - # the endpoint has closed or errored, there may be messages on its buffer - # still to be received, even though sending is not possible anymore. - return self._closed - else: - return self._ep is None - - -class UCXConnector(Connector): - prefix = "ucx://" - comm_class = UCX - encrypted = False - - async def connect( - self, address: str, deserialize: bool = True, **connection_args: Any - ) -> UCX: - logger.debug("UCXConnector.connect: %s", address) - ip, port = parse_host_port(address) - init_once() - try: - ep = await ucp.create_endpoint(ip, port) - except ucp.exceptions.UCXBaseException: - raise CommClosedError("Connection closed before handshake completed") - return self.comm_class( - ep, - local_addr="", - peer_addr=self.prefix + address, - deserialize=deserialize, - ) - - -class UCXListener(BaseListener): - prefix = UCXConnector.prefix - comm_class = UCXConnector.comm_class - encrypted = UCXConnector.encrypted - - def __init__( - self, - address: str, - comm_handler: Callable[[UCX], Awaitable[None]] | None = None, - deserialize: bool = False, - allow_offload: bool = True, - **connection_args: Any, - ): - super().__init__() - if not address.startswith("ucx"): - address = "ucx://" + address - self.ip, self._input_port = parse_host_port(address, default_port=0) - self.comm_handler = comm_handler - self.deserialize = deserialize - self.allow_offload = allow_offload - self._ep = None # type: ucp.Endpoint - self.ucp_server = None - self.connection_args = connection_args - - @property - def port(self): - return self.ucp_server.port - - @property - def address(self): - return "ucx://" + self.ip + ":" + str(self.port) - - async def start(self): - async def serve_forever(client_ep): - ucx = UCX( - client_ep, - local_addr=self.address, - peer_addr=self.address, - deserialize=self.deserialize, - ) - ucx.allow_offload = self.allow_offload - try: - await self.on_connection(ucx) - except CommClosedError: - logger.debug("Connection closed before handshake completed") - return - if self.comm_handler: - await self.comm_handler(ucx) - - init_once() - self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port) - - def stop(self): - self.ucp_server = None - - def get_host_port(self): - # TODO: TCP raises if this hasn't started yet. - return self.ip, self.port - - @property - def listen_address(self): - return self.prefix + unparse_host_port(*self.get_host_port()) - - @property - def contact_address(self): - host, port = self.get_host_port() - host = ensure_concrete_host(host) # TODO: ensure_concrete_host - return self.prefix + unparse_host_port(host, port) - - @property - def bound_address(self): - # TODO: Does this become part of the base API? Kinda hazy, since - # we exclude in for inproc. - return self.get_host_port() - - -class UCXBackend(Backend): - # I / O - - def get_connector(self): - return UCXConnector() - - def get_listener(self, loc, handle_comm, deserialize, **connection_args): - return UCXListener(loc, handle_comm, deserialize, **connection_args) - - # Address handling - # This duplicates BaseTCPBackend - - def get_address_host(self, loc): - return parse_host_port(loc)[0] - - def get_address_host_port(self, loc): - return parse_host_port(loc) - - def resolve_address(self, loc): - host, port = parse_host_port(loc) - return unparse_host_port(ensure_ip(host), port) - - def get_local_address_for(self, loc): - host, port = parse_host_port(loc) - host = ensure_ip(host) - if ":" in host: - local_host = get_ipv6(host) - else: - local_host = get_ip(host) - return unparse_host_port(local_host, None) - - -backends["ucx"] = UCXBackend() - - -def _prepare_ucx_config(): - """Translate dask config options to appropriate UCX config options - - Returns - ------- - tuple - Options suitable for passing to ``ucp.init`` and additional - UCX options that will be inserted directly into the environment - while calling ``ucp.init``. - """ - - # configuration of UCX can happen in two ways: - # 1) high level on/off flags which correspond to UCX configuration - # 2) explicitly defined UCX configuration flags in distributed.comm.ucx.environment - # High-level settings in (1) are preferred to settings in (2) - # Settings in the external environment override both - - high_level_options = {} - - # if any of the high level flags are set, as long as they are not Null/None, - # we assume we should configure basic TLS settings for UCX, otherwise we - # leave UCX to its default configuration - if any( - [ - dask.config.get("distributed.comm.ucx.tcp"), - dask.config.get("distributed.comm.ucx.nvlink"), - dask.config.get("distributed.comm.ucx.infiniband"), - ] - ): - if dask.config.get("distributed.comm.ucx.rdmacm"): - tls = "tcp" - tls_priority = "rdmacm" - else: - tls = "tcp" - tls_priority = "tcp" - - # CUDA COPY can optionally be used with ucx -- we rely on the user - # to define when messages will include CUDA objects. Note: - # defining only the Infiniband flag will not enable cuda_copy - if any( - [ - dask.config.get("distributed.comm.ucx.nvlink"), - dask.config.get("distributed.comm.ucx.cuda-copy"), - ] - ): - tls = tls + ",cuda_copy" - - if dask.config.get("distributed.comm.ucx.infiniband"): - tls = "rc," + tls - if dask.config.get("distributed.comm.ucx.nvlink"): - tls = tls + ",cuda_ipc" - - high_level_options = {"TLS": tls, "SOCKADDR_TLS_PRIORITY": tls_priority} - - # Pick up any other ucx environment settings - environment_options = {} - for k, v in dask.config.get("distributed.comm.ucx.environment", {}).items(): - # {"some-name": value} is translated to {"UCX_SOME_NAME": value} - key = "_".join(map(str.upper, ("UCX", *k.split("-")))) - if (hl_key := key[4:]) in high_level_options: - logger.warning( - f"Ignoring {k}={v} ({key=}) in ucx.environment, " - f"preferring {hl_key}={high_level_options[hl_key]} " - "from high level options" - ) - elif key in os.environ: - # This is only info because setting UCX configuration via - # environment variables is a reasonably common approach - logger.info( - f"Ignoring {k}={v} ({key=}) in ucx.environment, " - f"preferring {key}={os.environ[key]} from external environment" - ) - else: - environment_options[key] = v - - return high_level_options, environment_options diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index e6924cb60a3..df39e5e87ef 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -567,13 +567,9 @@ def raise_err(): self.loop.add_callback(raise_err) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) @gen_test() -async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol): - if protocol == "ucx": # Skip if UCX isn't available - pytest.importorskip("ucp") - - async with Scheduler(protocol=protocol, dashboard_address=":0") as s: +async def test_nanny_closed_by_keyboard_interrupt(): + async with Scheduler(dashboard_address=":0") as s: async with Nanny( s.address, nthreads=1, worker_class=KeyboardInterruptWorker ) as n: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 86856f2deea..6d1946570df 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1450,21 +1450,6 @@ async def test_interface_async(Worker): assert all("127.0.0.1" == d["host"] for d in info["workers"].values()) -@pytest.mark.gpu -@pytest.mark.parametrize("Worker", [Worker, Nanny]) -@gen_test() -async def test_protocol_from_scheduler_address(ucx_loop, Worker): - pytest.importorskip("ucp") - - async with Scheduler(protocol="ucx", dashboard_address=":0") as s: - assert s.address.startswith("ucx://") - async with Worker(s.address) as w: - assert w.address.startswith("ucx://") - async with Client(s.address, asynchronous=True) as c: - info = c.scheduler_info() - assert info["address"].startswith("ucx://") - - @gen_test() async def test_host_uses_scheduler_protocol(monkeypatch): # Ensure worker uses scheduler's protocol to determine host address, not the default scheme diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 1b4639baf7a..8720dcd87cd 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2089,48 +2089,6 @@ def raises_with_cause( raise exc -def ucx_exception_handler(loop, context): - """UCX exception handler for `ucx_loop` during test. - - Prints the exception and its message. - - Parameters - ---------- - loop: object - Reference to the running event loop - context: dict - Dictionary containing exception details. - """ - msg = context.get("exception", context["message"]) - print(msg) - - -# Let's make sure that UCX gets time to cancel -# progress tasks before closing the event loop. -@pytest.fixture(scope="function") -def ucx_loop(): - """Allows UCX to cancel progress tasks before closing event loop. - - When UCX tasks are not completed in time (e.g., by unexpected Endpoint - closure), clean up tasks before closing the event loop to prevent unwanted - errors from being raised. - """ - ucp = pytest.importorskip("ucp") - - loop = asyncio.new_event_loop() - loop.set_exception_handler(ucx_exception_handler) - ucp.reset() - yield loop - ucp.reset() - loop.close() - - # Reset also Distributed's UCX initialization, i.e., revert the effects of - # `distributed.comm.ucx.init_once()`. - import distributed.comm.ucx - - distributed.comm.ucx.ucp = None - - def wait_for_log_line( match: bytes, stream: IO[bytes] | None, max_lines: int | None = 10 ) -> bytes: diff --git a/pyproject.toml b/pyproject.toml index 8da71aecf16..2ef69c2f856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,7 +230,6 @@ omit = [ "distributed/deploy/ssh.py", "distributed/_version.py", "distributed/pytest_resourceleaks.py", - "distributed/comm/ucx.py", ] [tool.coverage.report] From f5102db3b10d7be21f029dbfdeed9b4b320eda73 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 3 Sep 2025 14:14:57 -0700 Subject: [PATCH 2/3] Raise deprecation exception --- distributed/comm/__init__.py | 10 ++++ distributed/comm/ucx.py | 95 ++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 distributed/comm/ucx.py diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index e1b597cc70b..0b9801c5cd3 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -22,5 +22,15 @@ def _register_transports(): backends["tcp"] = tcp.TCPBackend() backends["tls"] = tcp.TLSBackend() + try: + # If `distributed-ucxx` is installed, it takes over the protocol="ucx" support + import distributed_ucxx + except ImportError: + try: + # Else protocol="ucx" will raise a deprecation warning and exception + from distributed.comm import ucx + except ImportError: + pass + _register_transports() diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py new file mode 100644 index 00000000000..600ae30fcdf --- /dev/null +++ b/distributed/comm/ucx.py @@ -0,0 +1,95 @@ +""" +:ref:`UCX`_ based communications for distributed. + +See :ref:`communications` for more. + +.. _UCX: https://github.com/openucx/ucx +""" + +from __future__ import annotations + +import textwrap +import warnings +from collections.abc import Awaitable, Callable +from typing import Any + +from distributed.comm.core import BaseListener, Comm, Connector +from distributed.comm.registry import Backend, backends + + +def _raise_deprecated(): + message = textwrap.dedent( + """\ + The 'ucx' protocol was removed from Distributed because UCX-Py has been + deprecated. To continue using protocol='ucx', please install + 'distributed-ucxx' (conda-forge) or 'distributed-ucxx-cu[12,13]' (PyPI, + selecting 12 for CUDA version 12, and 13 for CUDA version 13. + """ + ) + warnings.warn(message, FutureWarning) + raise FutureWarning(message) + + +class UCX(Comm): + def __init__(self): + _raise_deprecated() + + +class UCXConnector(Connector): + prefix = "ucx://" + comm_class = UCX + encrypted = False + + +class UCXListener(BaseListener): + prefix = UCXConnector.prefix + comm_class = UCXConnector.comm_class + encrypted = UCXConnector.encrypted + + def __init__( + self, + address: str, + comm_handler: Callable[[UCX], Awaitable[None]] | None = None, + deserialize: bool = False, + allow_offload: bool = True, + **connection_args: Any, + ): + _raise_deprecated() + + @property + def port(self): + return self.ucp_server.port + + async def start(self): + _raise_deprecated() + + def stop(self): + _raise_deprecated() + + @property + def listen_address(self): + _raise_deprecated() + + @property + def contact_address(self): + _raise_deprecated() + + +class UCXBackend(Backend): + def get_connector(self): + return UCXConnector() + + def get_listener(self, loc, handle_comm, deserialize, **connection_args): + return UCXListener(loc, handle_comm, deserialize, **connection_args) + + def get_address_host(self, loc): + _raise_deprecated() + + def resolve_address(self, loc): + _raise_deprecated() + + def get_local_address_for(self, loc): + _raise_deprecated() + + +backends["ucx"] = UCXBackend() From 52a17b319920e4cbe9eed5fed782ebf8b17fde90 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 4 Sep 2025 01:39:15 -0700 Subject: [PATCH 3/3] Improve warning/error message --- distributed/comm/ucx.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 600ae30fcdf..4061701702f 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -1,11 +1,3 @@ -""" -:ref:`UCX`_ based communications for distributed. - -See :ref:`communications` for more. - -.. _UCX: https://github.com/openucx/ucx -""" - from __future__ import annotations import textwrap @@ -20,10 +12,10 @@ def _raise_deprecated(): message = textwrap.dedent( """\ - The 'ucx' protocol was removed from Distributed because UCX-Py has been - deprecated. To continue using protocol='ucx', please install - 'distributed-ucxx' (conda-forge) or 'distributed-ucxx-cu[12,13]' (PyPI, - selecting 12 for CUDA version 12, and 13 for CUDA version 13. + The 'ucx' protocol was removed from Distributed because UCX-Py has been deprecated. + To continue using protocol='ucx', please install 'distributed-ucxx' (conda-forge) + or 'distributed-ucxx-cu[12,13]' (PyPI, selecting 12 for CUDA version 12.*, and 13 + for CUDA version 13.*). """ ) warnings.warn(message, FutureWarning)