diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 581e17b8027..58a374c8bbf 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -357,6 +357,15 @@ async def test_ucx_protocol(ucx_loop, cleanup, port): assert s.address.startswith("ucx://") +@gen_test() +async def test_ucx_listener_ip(ucx_loop, cleanup): + async with Scheduler( + protocol="ucx", interface="localhost", dashboard_address=":0" + ) as s: + assert s.address.startswith("ucx://127.0.0.1") + assert s.listener.ucp.server.ip == "127.0.0.1" + + @pytest.mark.skipif( not hasattr(ucp.exceptions, "UCXUnreachable"), reason="Requires UCX-Py support for UCXUnreachable exception", diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 16ce82d49d2..955751d257c 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -472,7 +472,9 @@ async def serve_forever(client_ep): await self.comm_handler(ucx) init_once() - self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port) + self.ucp_server = ucp.create_listener( + serve_forever, port=self._input_port, ip_address=self.ip + ) def stop(self): self.ucp_server = None