From 301242bbc59ec3b4dbc99c676e5ae36b24230040 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 20 Jun 2018 12:44:39 -0500 Subject: [PATCH 1/3] BUG: Normalize address before comparison Closes https://github.com/dask/distributed/issues/2058 --- distributed/tests/test_worker_client.py | 14 ++++++++++++++ distributed/worker.py | 5 ++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index c4e3c775a3a..a6b1ac2d148 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -297,3 +297,17 @@ def f(): result = yield c.submit(f) assert result + + +@gen_cluster(client=True) +def test_submit_different_names(c, s, a, b): + # https://github.com/dask/distributed/issues/2058 + da = pytest.importorskip('dask.array') + c = yield Client('localhost:' + s.address.split(":")[-1], loop=s.loop, + asynchronous=True) + with c: + X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) + yield wait(X) + + fut = yield c.submit(lambda x: x.sum().compute(), X) + assert fut > 0 diff --git a/distributed/worker.py b/distributed/worker.py index 74dbc0949db..b6a735a06f2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -27,7 +27,8 @@ from . import profile from .batched import BatchedSend -from .comm import get_address_host, get_local_address_for, connect +from .comm import (get_address_host, get_local_address_for, connect, + resolve_address) from .comm.utils import offload from .compatibility import unicode, get_thread_identity, finalize from .core import (error_message, CommClosedError, @@ -2619,6 +2620,8 @@ def get_client(address=None, timeout=3): worker_client secede """ + if address: + address = resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker From df96738799e15ddd93bc7ce44507d28dd30df4ac Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 20 Jun 2018 13:15:42 -0500 Subject: [PATCH 2/3] Make a keyword --- distributed/worker.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index b6a735a06f2..948d27db476 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,10 +25,9 @@ from tornado.ioloop import IOLoop from tornado.locks import Event -from . import profile +from . import profile, comm from .batched import BatchedSend -from .comm import (get_address_host, get_local_address_for, connect, - resolve_address) +from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload from .compatibility import unicode, get_thread_identity, finalize from .core import (error_message, CommClosedError, @@ -2597,11 +2596,25 @@ def get_worker(): raise ValueError("No workers found") -def get_client(address=None, timeout=3): - """ Get a client while within a task +def get_client(address=None, timeout=3, resolve_address=True): + """Get a client while within a task. This client connects to the same scheduler to which the worker is connected + Parameters + ---------- + address : str, optional + The address of the scheduler to connect to. Defaults to the scheduler + the worker is connected to. + timeout : int, default 3 + Timeout (in seconds) for getting the Client + resolve_address : bool, default True + Whether to resolve `address` to its canonical form. + + Returns + ------- + Client + Examples -------- >>> def f(): @@ -2620,8 +2633,8 @@ def get_client(address=None, timeout=3): worker_client secede """ - if address: - address = resolve_address(address) + if address and resolve_address: + address = comm.resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker From e435b2c8fe20a6651987a809af515514950a576b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 20 Jun 2018 16:12:21 -0500 Subject: [PATCH 3/3] Cleanup test --- distributed/tests/test_worker_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index a6b1ac2d148..2b96ae59c4f 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -299,15 +299,17 @@ def f(): assert result -@gen_cluster(client=True) -def test_submit_different_names(c, s, a, b): +@gen_cluster() +def test_submit_different_names(s, a, b): # https://github.com/dask/distributed/issues/2058 da = pytest.importorskip('dask.array') c = yield Client('localhost:' + s.address.split(":")[-1], loop=s.loop, asynchronous=True) - with c: + try: X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) yield wait(X) fut = yield c.submit(lambda x: x.sum().compute(), X) assert fut > 0 + finally: + yield c.close()