diff --git a/distributed/queues.py b/distributed/queues.py index d022b010e1c..e368d329d03 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -7,7 +7,7 @@ from .client import Future, Client from .utils import sync, thread_state -from .worker import get_client +from .worker import get_client, get_worker from .utils import parse_timedelta logger = logging.getLogger(__name__) @@ -150,8 +150,8 @@ class Queue: Name used by other clients and the scheduler to identify the queue. If not given, a random name will be generated. client: Client (optional) - Client used for communication with the scheduler. Defaults to the - value of ``Client.current()``. + Client used for communication with the scheduler. + If not given, the default global client will be used. maxsize: int (optional) Number of items allowed in the queue. If 0 (the default), the queue size is unbounded. @@ -170,7 +170,11 @@ class Queue: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or Client.current() + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "queue-" + uuid.uuid4().hex self._event_started = asyncio.Event() if self.client.asynchronous or getattr( diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 94d80c9dbcf..8f400498854 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -6,7 +6,7 @@ from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError from distributed.metrics import time -from distributed.utils_test import gen_cluster, inc, div +from distributed.utils_test import gen_cluster, inc, div, popen from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 @@ -276,3 +276,22 @@ def get(): res = c.submit(get) await c.gather([res, fut]) + + +def test_queue_in_task(loop): + # Ensure that we can create a Queue inside a task on a + # worker in a separate Python process than the client + with popen(["dask-scheduler", "--no-dashboard"]): + with popen(["dask-worker", "127.0.0.1:8786"]): + with Client("tcp://127.0.0.1:8786", loop=loop) as c: + c.wait_for_workers(1) + + x = Queue("x") + x.put(123) + + def foo(): + y = Queue("x") + return y.get() + + result = c.submit(foo).result() + assert result == 123 diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 5d9ece6ee54..37b3c756be7 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -10,9 +10,8 @@ from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError from distributed.metrics import time from distributed.compatibility import WINDOWS -from distributed.utils_test import gen_cluster, inc, div +from distributed.utils_test import gen_cluster, inc, div, captured_logger, popen from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 -from distributed.utils_test import captured_logger @gen_cluster(client=True) @@ -40,6 +39,25 @@ async def test_variable(c, s, a, b): assert time() < start + 5 +def test_variable_in_task(loop): + # Ensure that we can create a Variable inside a task on a + # worker in a separate Python process than the client + with popen(["dask-scheduler", "--no-dashboard"]): + with popen(["dask-worker", "127.0.0.1:8786"]): + with Client("tcp://127.0.0.1:8786", loop=loop) as c: + c.wait_for_workers(1) + + x = Variable("x") + x.set(123) + + def foo(): + y = Variable("x") + return y.get() + + result = c.submit(foo).result() + assert result == 123 + + @gen_cluster(client=True) async def test_delete_unset_variable(c, s, a, b): x = Variable() diff --git a/distributed/variable.py b/distributed/variable.py index 19fbd2bb031..c3fdc94d0d7 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -9,13 +9,13 @@ from dask.utils import stringify from .client import Future, Client from .utils import log_errors, TimeoutError, parse_timedelta -from .worker import get_client +from .worker import get_client, get_worker logger = logging.getLogger(__name__) class VariableExtension: - """An extension for the scheduler to manage queues + """An extension for the scheduler to manage Variables This adds the following routes to the scheduler @@ -145,8 +145,8 @@ class Variable: Name used by other clients and the scheduler to identify the variable. If not given, a random name will be generated. client: Client (optional) - Client used for communication with the scheduler. Defaults to the - value of ``Client.current()``. + Client used for communication with the scheduler. + If not given, the default global client will be used. Examples -------- @@ -165,7 +165,11 @@ class Variable: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or Client.current() + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "variable-" + uuid.uuid4().hex async def _set(self, value):