diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index c4e3c775a3a..2b96ae59c4f 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -297,3 +297,19 @@ def f(): result = yield c.submit(f) assert result + + +@gen_cluster() +def test_submit_different_names(s, a, b): + # https://github.com/dask/distributed/issues/2058 + da = pytest.importorskip('dask.array') + c = yield Client('localhost:' + s.address.split(":")[-1], loop=s.loop, + asynchronous=True) + try: + X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) + yield wait(X) + + fut = yield c.submit(lambda x: x.sum().compute(), X) + assert fut > 0 + finally: + yield c.close() diff --git a/distributed/worker.py b/distributed/worker.py index 74dbc0949db..948d27db476 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -25,7 +25,7 @@ from tornado.ioloop import IOLoop from tornado.locks import Event -from . import profile +from . import profile, comm from .batched import BatchedSend from .comm import get_address_host, get_local_address_for, connect from .comm.utils import offload @@ -2596,11 +2596,25 @@ def get_worker(): raise ValueError("No workers found") -def get_client(address=None, timeout=3): - """ Get a client while within a task +def get_client(address=None, timeout=3, resolve_address=True): + """Get a client while within a task. This client connects to the same scheduler to which the worker is connected + Parameters + ---------- + address : str, optional + The address of the scheduler to connect to. Defaults to the scheduler + the worker is connected to. + timeout : int, default 3 + Timeout (in seconds) for getting the Client + resolve_address : bool, default True + Whether to resolve `address` to its canonical form. + + Returns + ------- + Client + Examples -------- >>> def f(): @@ -2619,6 +2633,8 @@ def get_client(address=None, timeout=3): worker_client secede """ + if address and resolve_address: + address = comm.resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker