2 changes: 1 addition & 1 deletion distributed/comm/core.py
@@ -59,7 +59,7 @@ def read(self, deserializers=None):
"""

@abstractmethod
def write(self, msg, on_error=None):
def write(self, msg, serializers=None, on_error=None):
"""
Write a message (a Python object).

14 changes: 9 additions & 5 deletions distributed/scheduler.py
@@ -2545,25 +2545,29 @@ async def gather(self, comm=None, keys=None, serializers=None):
(self.tasks[key].state if key in self.tasks else None)
for key in missing_keys
]
logger.debug(
logger.exception(
"Couldn't gather keys %s state: %s workers: %s",
missing_keys,
missing_states,
missing_workers,
)
result = {"status": "error", "keys": missing_keys}
with log_errors():
# Remove suspicious workers from the scheduler but allow them to
# reconnect.
for worker in missing_workers:
self.remove_worker(address=worker) # this is extreme
self.remove_worker(address=worker, close=False)
for key, workers in missing_keys.items():
if not workers:
continue
ts = self.tasks[key]
# Task may already be gone if it was held by a
# `missing_worker`
ts = self.tasks.get(key)
logger.exception(
"Workers don't have promised key: %s, %s",
str(workers),
str(key),
)
if not workers or ts is None:
continue
for worker in workers:
ws = self.workers.get(worker)
if ws is not None and ts in ws.has_what:
145 changes: 144 additions & 1 deletion distributed/tests/test_scheduler.py
@@ -17,7 +17,8 @@
import pytest

from distributed import Nanny, Worker, Client, wait, fire_and_forget
from distributed.core import connect, rpc
from distributed.comm import Comm
from distributed.core import connect, rpc, ConnectionPool
from distributed.scheduler import Scheduler, TaskState
from distributed.client import wait
from distributed.metrics import time
@@ -1704,3 +1705,145 @@ async def test_no_danglng_asyncio_tasks(cleanup):

tasks = asyncio.all_tasks()
assert tasks == start


class BrokenComm(Comm):
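    """A Comm whose reads and writes always raise EnvironmentError, simulating a broken connection."""
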
peer_address = None
local_address = None

def close(self):
pass

def closed(self):
pass

def abort(self):
pass

def read(self, deserializers=None):
raise EnvironmentError

def write(self, msg, serializers=None, on_error=None):
raise EnvironmentError


class FlakyConnectionPool(ConnectionPool):
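    """A ConnectionPool whose first ``failing_connections`` connection attempts return a BrokenComm."""
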
def __init__(self, *args, failing_connections=0, **kwargs):
self.cnn_count = 0
self.failing_connections = failing_connections
super(FlakyConnectionPool, self).__init__(*args, **kwargs)

async def connect(self, *args, **kwargs):
self.cnn_count += 1
if self.cnn_count > self.failing_connections:
return await super(FlakyConnectionPool, self).connect(*args, **kwargs)
else:
return BrokenComm()


@gen_cluster(client=True)
async def test_gather_failing_cnn_recover(c, s, a, b):
orig_rpc = s.rpc
x = await c.scatter({"x": 1}, workers=a.address)

s.rpc = FlakyConnectionPool(failing_connections=1)
res = await s.gather(keys=["x"])
assert res["status"] == "OK"


@gen_cluster(client=True)
async def test_gather_failing_cnn_error(c, s, a, b):
orig_rpc = s.rpc
x = await c.scatter({"x": 1}, workers=a.address)

s.rpc = FlakyConnectionPool(failing_connections=10)
res = await s.gather(keys=["x"])
assert res["status"] == "error"
assert list(res["keys"]) == ["x"]


@gen_cluster(client=True)
async def test_gather_no_workers(c, s, a, b):
await asyncio.sleep(1)
x = await c.scatter({"x": 1}, workers=a.address)

await a.close()
await b.close()

res = await s.gather(keys=["x"])
assert res["status"] == "error"
assert list(res["keys"]) == ["x"]


@gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
async def test_gather_allow_worker_reconnect(c, s, a, b):
"""
Test that client resubmissions allow failed workers to reconnect and re-use
their results. Failure scenario would be a connection issue during result
gathering.
Upon connection failure, the worker is flagged as suspicious and removed
from the scheduler. If the worker is healthy and reconnencts we want to use
its results instead of recomputing them.
"""
# GH3246
ALREADY_CALCULATED = []

import time

def inc_slow(x):
        # Once the graph below is rescheduled, this computation runs again. We
        # need to sleep for at least 0.5 seconds to give the worker a chance to
        # reconnect (heartbeat timing).
if x in ALREADY_CALCULATED:
time.sleep(0.5)
ALREADY_CALCULATED.append(x)
return x + 1

x = c.submit(inc_slow, 1)
y = c.submit(inc_slow, 2)

def reducer(x, y):
return x + y

z = c.submit(reducer, x, y)

s.rpc = FlakyConnectionPool(failing_connections=4)

with captured_logger(logging.getLogger("distributed.scheduler")) as sched_logger:
with captured_logger(logging.getLogger("distributed.client")) as client_logger:
with captured_logger(
logging.getLogger("distributed.worker")
) as worker_logger:
# Gather using the client (as an ordinary user would)
# Upon a missing key, the client will reschedule the computations
res = await c.gather(z)

assert res == 5

sched_logger = sched_logger.getvalue()
client_logger = client_logger.getvalue()
worker_logger = worker_logger.getvalue()

    # Ensure that the communication was done via the scheduler, i.e. that we
    # actually hit a bad connection
assert s.rpc.cnn_count > 0

assert "Encountered connection issue during data collection" in worker_logger

    # The reducer task was not found on the first collection attempt; the
    # client will reschedule the graph.
assert "Couldn't gather 1 keys, rescheduling" in client_logger
    # There will also be an `Unexpected worker completed task` message, but this
    # is an artifact rather than the intended behavior.
assert "Workers don't have promised key" in sched_logger

    # Once the worker reconnects, it will also submit the keys it holds, so
    # that the scheduler again knows about the results.
    # The final reduce result should then be taken from the reconnected worker
    # instead of being recomputed.

starts = []
finish_processing_transitions = 0
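    # The reducer should have entered the "processing" state exactly once, i.e.
    # its result was reused from the reconnected worker rather than recomputed.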
for transition in s.transition_log:
key, start, finish, recommendations, timestamp = transition
if "reducer" in key and finish == "processing":
finish_processing_transitions += 1
assert finish_processing_transitions == 1
43 changes: 39 additions & 4 deletions distributed/tests/test_utils_comm.py
@@ -1,6 +1,5 @@
import pytest

from distributed.core import rpc
from distributed.core import ConnectionPool
from distributed.comm import Comm
from distributed.utils_test import gen_cluster
from distributed.utils_comm import pack_data, gather_from_workers

@@ -12,9 +11,9 @@ def test_pack_data():
assert pack_data({"a": ["x"], "b": "y"}, data) == {"a": [1], "b": "y"}


@pytest.mark.xfail(reason="rpc now needs to be a connection pool")
@gen_cluster(client=True)
def test_gather_from_workers_permissive(c, s, a, b):
rpc = ConnectionPool()
x = yield c.scatter({"x": 1}, workers=a.address)

data, missing, bad_workers = yield gather_from_workers(
@@ -23,3 +22,39 @@ def test_gather_from_workers_permissive(c, s, a, b):

assert data == {"x": 1}
assert list(missing) == ["y"]


class BrokenComm(Comm):
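    """A Comm whose reads and writes always raise EnvironmentError, simulating a broken connection."""
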
peer_address = None
local_address = None

def close(self):
pass

def closed(self):
pass

def abort(self):
pass

def read(self, deserializers=None):
raise EnvironmentError

def write(self, msg, serializers=None, on_error=None):
raise EnvironmentError


class BrokenConnectionPool(ConnectionPool):
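    """A ConnectionPool that only ever hands out broken comms."""
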
async def connect(self, *args, **kwargs):
return BrokenComm()


@gen_cluster(client=True)
def test_gather_from_workers_permissive_flaky(c, s, a, b):
x = yield c.scatter({"x": 1}, workers=a.address)

rpc = BrokenConnectionPool()
data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc)

assert missing == {"x": [a.address]}
assert bad_workers == [a.address]
59 changes: 39 additions & 20 deletions distributed/worker.py
@@ -3116,27 +3116,46 @@ async def get_data_from_worker(
if deserializers is None:
deserializers = rpc.deserializers

comm = await rpc.connect(worker)
comm.name = "Ephemeral Worker->Worker for gather"
try:
response = await send_recv(
comm,
serializers=serializers,
deserializers=deserializers,
op="get_data",
keys=keys,
who=who,
max_connections=max_connections,
)
retry_count = 0
max_retries = 3

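    # Retry the transfer with exponential backoff (0.1 s, 0.2 s, 0.4 s) if the
    # connection to the worker breaks while gathering data.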
while True:
comm = await rpc.connect(worker)
comm.name = "Ephemeral Worker->Worker for gather"
try:
status = response["status"]
except KeyError:
raise ValueError("Unexpected response", response)
else:
if status == "OK":
await comm.write("OK")
finally:
rpc.reuse(worker, comm)
response = await send_recv(
comm,
serializers=serializers,
deserializers=deserializers,
op="get_data",
keys=keys,
who=who,
max_connections=max_connections,
)
try:
status = response["status"]
except KeyError:
raise ValueError("Unexpected response", response)
else:
if status == "OK":
await comm.write("OK")
break
except (EnvironmentError, CommClosedError):
if retry_count < max_retries:
await asyncio.sleep(0.1 * (2 ** retry_count))
retry_count += 1
logger.info(
"Encountered connection issue during data collection of keys %s on worker %s. Retrying (%s / %s)",
keys,
worker,
retry_count,
max_retries,
)
continue
else:
raise
finally:
rpc.reuse(worker, comm)

return response
