diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 39a8b123cd3..0befb36d712 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -59,7 +59,7 @@ def read(self, deserializers=None): """ @abstractmethod - def write(self, msg, on_error=None): + def write(self, msg, serializers=None, on_error=None): """ Write a message (a Python object). diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6724524344d..0afd16f1867 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2545,7 +2545,7 @@ async def gather(self, comm=None, keys=None, serializers=None): (self.tasks[key].state if key in self.tasks else None) for key in missing_keys ] - logger.debug( + logger.exception( "Couldn't gather keys %s state: %s workers: %s", missing_keys, missing_states, @@ -2553,17 +2553,21 @@ async def gather(self, comm=None, keys=None, serializers=None): ) result = {"status": "error", "keys": missing_keys} with log_errors(): + # Remove suspicious workers from the scheduler but allow them to + # reconnect. for worker in missing_workers: - self.remove_worker(address=worker) # this is extreme + self.remove_worker(address=worker, close=False) for key, workers in missing_keys.items(): - if not workers: - continue - ts = self.tasks[key] + # Task may already be gone if it was held by a + # `missing_worker` + ts = self.tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), str(key), ) + if not workers or ts is None: + continue for worker in workers: ws = self.workers.get(worker) if ws is not None and ts in ws.has_what: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f57dbfb9e07..4e0e9a8710c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,8 @@ import pytest from distributed import Nanny, Worker, Client, wait, fire_and_forget -from distributed.core import connect, rpc +from distributed.comm import Comm +from distributed.core import connect, rpc, ConnectionPool from distributed.scheduler import Scheduler, TaskState from distributed.client import wait from distributed.metrics import time @@ -1704,3 +1705,145 @@ async def test_no_danglng_asyncio_tasks(cleanup): tasks = asyncio.all_tasks() assert tasks == start + + +class BrokenComm(Comm): + peer_address = None + local_address = None + + def close(self): + pass + + def closed(self): + pass + + def abort(self): + pass + + def read(self, deserializers=None): + raise EnvironmentError + + def write(self, msg, serializers=None, on_error=None): + raise EnvironmentError + + +class FlakyConnectionPool(ConnectionPool): + def __init__(self, *args, failing_connections=0, **kwargs): + self.cnn_count = 0 + self.failing_connections = failing_connections + super(FlakyConnectionPool, self).__init__(*args, **kwargs) + + async def connect(self, *args, **kwargs): + self.cnn_count += 1 + if self.cnn_count > self.failing_connections: + return await super(FlakyConnectionPool, self).connect(*args, **kwargs) + else: + return BrokenComm() + + +@gen_cluster(client=True) +async def test_gather_failing_cnn_recover(c, s, a, b): + orig_rpc = s.rpc + x = await c.scatter({"x": 1}, workers=a.address) + + s.rpc = FlakyConnectionPool(failing_connections=1) + res = await s.gather(keys=["x"]) + assert res["status"] == "OK" + + +@gen_cluster(client=True) +async def test_gather_failing_cnn_error(c, s, a, b): + orig_rpc = s.rpc + x = await c.scatter({"x": 1}, workers=a.address) + + s.rpc = 
FlakyConnectionPool(failing_connections=10)
+    res = await s.gather(keys=["x"])
+    assert res["status"] == "error"
+    assert list(res["keys"]) == ["x"]
+
+
+@gen_cluster(client=True)
+async def test_gather_no_workers(c, s, a, b):
+    await asyncio.sleep(1)
+    x = await c.scatter({"x": 1}, workers=a.address)
+
+    await a.close()
+    await b.close()
+
+    res = await s.gather(keys=["x"])
+    assert res["status"] == "error"
+    assert list(res["keys"]) == ["x"]
+
+
+@gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
+async def test_gather_allow_worker_reconnect(c, s, a, b):
+    """
+    Test that client resubmissions allow failed workers to reconnect and reuse
+    their results. The failure scenario is a connection issue during result
+    gathering.
+    Upon connection failure, the worker is flagged as suspicious and removed
+    from the scheduler. If the worker is healthy and reconnects, we want to use
+    its results instead of recomputing them.
+    """
+    # GH3246
+    ALREADY_CALCULATED = []
+
+    import time
+
+    def inc_slow(x):
+        # Once the graph below is rescheduled this computation runs again. We
+        # need to sleep for at least 0.5 seconds to give the worker a chance to
+        # reconnect (heartbeat timing).
+        if x in ALREADY_CALCULATED:
+            time.sleep(0.5)
+        ALREADY_CALCULATED.append(x)
+        return x + 1
+
+    x = c.submit(inc_slow, 1)
+    y = c.submit(inc_slow, 2)
+
+    def reducer(x, y):
+        return x + y
+
+    z = c.submit(reducer, x, y)
+
+    s.rpc = FlakyConnectionPool(failing_connections=4)
+
+    with captured_logger(logging.getLogger("distributed.scheduler")) as sched_logger:
+        with captured_logger(logging.getLogger("distributed.client")) as client_logger:
+            with captured_logger(
+                logging.getLogger("distributed.worker")
+            ) as worker_logger:
+                # Gather using the client (as an ordinary user would).
+                # Upon a missing key, the client will reschedule the computations.
+                res = await c.gather(z)
+
+    assert res == 5
+
+    sched_logger = sched_logger.getvalue()
+    client_logger = client_logger.getvalue()
+    worker_logger = worker_logger.getvalue()
+
+    # Ensure that the communication went through the scheduler, i.e. we actually hit a bad connection
+    assert s.rpc.cnn_count > 0
+
+    assert "Encountered connection issue during data collection" in worker_logger
+
+    # The reducer task was not found upon first collection; the client will reschedule the graph
+    assert "Couldn't gather 1 keys, rescheduling" in client_logger
+    # There will also be an `Unexpected worker completed task` message, but this
+    # is an artifact rather than the intention
+    assert "Workers don't have promised key" in sched_logger
+
+    # Once the worker reconnects, it will also submit the keys it holds so
+    # that the scheduler again knows about the result.
+    # The final reduce step should then be fetched from the reconnected worker
+    # instead of being recomputed.
+ + starts = [] + finish_processing_transitions = 0 + for transition in s.transition_log: + key, start, finish, recommendations, timestamp = transition + if "reducer" in key and finish == "processing": + finish_processing_transitions += 1 + assert finish_processing_transitions == 1 diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 224b4b7f181..f66d3ba62d5 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,6 +1,5 @@ -import pytest - -from distributed.core import rpc +from distributed.core import ConnectionPool +from distributed.comm import Comm from distributed.utils_test import gen_cluster from distributed.utils_comm import pack_data, gather_from_workers @@ -12,9 +11,9 @@ def test_pack_data(): assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"} -@pytest.mark.xfail(reason="rpc now needs to be a connection pool") @gen_cluster(client=True) def test_gather_from_workers_permissive(c, s, a, b): + rpc = ConnectionPool() x = yield c.scatter({"x": 1}, workers=a.address) data, missing, bad_workers = yield gather_from_workers( @@ -23,3 +22,39 @@ def test_gather_from_workers_permissive(c, s, a, b): assert data == {"x": 1} assert list(missing) == ["y"] + + +class BrokenComm(Comm): + peer_address = None + local_address = None + + def close(self): + pass + + def closed(self): + pass + + def abort(self): + pass + + def read(self, deserializers=None): + raise EnvironmentError + + def write(self, msg, serializers=None, on_error=None): + raise EnvironmentError + + +class BrokenConnectionPool(ConnectionPool): + async def connect(self, *args, **kwargs): + return BrokenComm() + + +@gen_cluster(client=True) +def test_gather_from_workers_permissive_flaky(c, s, a, b): + x = yield c.scatter({"x": 1}, workers=a.address) + + rpc = BrokenConnectionPool() + data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc) + + assert missing == {"x": [a.address]} + assert bad_workers == [a.address] diff --git a/distributed/worker.py b/distributed/worker.py index 2a705320baa..327127cd39d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3116,27 +3116,46 @@ async def get_data_from_worker( if deserializers is None: deserializers = rpc.deserializers - comm = await rpc.connect(worker) - comm.name = "Ephemeral Worker->Worker for gather" - try: - response = await send_recv( - comm, - serializers=serializers, - deserializers=deserializers, - op="get_data", - keys=keys, - who=who, - max_connections=max_connections, - ) + retry_count = 0 + max_retries = 3 + + while True: + comm = await rpc.connect(worker) + comm.name = "Ephemeral Worker->Worker for gather" try: - status = response["status"] - except KeyError: - raise ValueError("Unexpected response", response) - else: - if status == "OK": - await comm.write("OK") - finally: - rpc.reuse(worker, comm) + response = await send_recv( + comm, + serializers=serializers, + deserializers=deserializers, + op="get_data", + keys=keys, + who=who, + max_connections=max_connections, + ) + try: + status = response["status"] + except KeyError: + raise ValueError("Unexpected response", response) + else: + if status == "OK": + await comm.write("OK") + break + except (EnvironmentError, CommClosedError): + if retry_count < max_retries: + await asyncio.sleep(0.1 * (2 ** retry_count)) + retry_count += 1 + logger.info( + "Encountered connection issue during data collection of keys %s on worker %s. 
Retrying (%s / %s)", + keys, + worker, + retry_count, + max_retries, + ) + continue + else: + raise + finally: + rpc.reuse(worker, comm) return response
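

For context, the new loop in get_data_from_worker retries failed transfers with a
capped exponential backoff before giving up. Below is a minimal standalone sketch of
that pattern, not part of the patch: the fetch argument is a hypothetical stand-in
for the send_recv call, and only OSError is caught here (the patch additionally
catches CommClosedError).

    import asyncio


    async def fetch_with_retries(fetch, max_retries=3, base_delay=0.1):
        # Hypothetical helper: await the ``fetch`` coroutine function, retrying
        # on connection errors and sleeping 0.1s, 0.2s, 0.4s, ... between
        # attempts; re-raise once ``max_retries`` retries have been used up.
        retry_count = 0
        while True:
            try:
                return await fetch()
            except OSError:  # EnvironmentError is an alias of OSError on Python 3
                if retry_count >= max_retries:
                    raise
                await asyncio.sleep(base_delay * 2 ** retry_count)
                retry_count += 1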