Merged
30 commits
bd61588
Ensure client.restart waits for workers to leave
fjetter Jun 27, 2022
4f4c5dc
Expect all workers to come back; timeout otherwise
gjoseph92 Jul 11, 2022
010c911
Call restart as RPC from client
gjoseph92 Jul 11, 2022
f0a4938
Client restart cleanup even if restart fails
gjoseph92 Jul 11, 2022
e18fd1e
Merge branch 'main' into restart-wait-for-workers
gjoseph92 Jul 14, 2022
5a388d1
logging to track down why restart is hanging
gjoseph92 Jul 14, 2022
85afcbc
Add back broadcasting restart message to clients
gjoseph92 Jul 14, 2022
f7da0d2
Make restart timeout 2x longer by default
gjoseph92 Jul 14, 2022
fdf8358
Revert "logging to track down why restart is hanging"
gjoseph92 Jul 15, 2022
659e2d1
Fix `test_AllProgress`: reorder plugin.restart()
gjoseph92 Jul 15, 2022
b5c6ff0
Inner function for single `wait_for` timeout
gjoseph92 Jul 15, 2022
e18ea37
docstring on client as well
gjoseph92 Jul 15, 2022
b82e9ff
Clarify restart docstring
gjoseph92 Jul 15, 2022
c22f99c
Move other clearing ops first
gjoseph92 Jul 18, 2022
025d02e
Don't apply timeout to `remove_worker`
gjoseph92 Jul 18, 2022
648f1e9
Connect to nannies in parallel
gjoseph92 Jul 18, 2022
868c0c2
Explain `TimeoutError` contract
gjoseph92 Jul 18, 2022
2864e23
Revert "Explain `TimeoutError` contract"
gjoseph92 Jul 18, 2022
08014e9
Only apply timeout to worker-waiting
gjoseph92 Jul 18, 2022
55f16cd
move proc restart testing to more appropriate test
gjoseph92 Jul 18, 2022
0e3d96d
remove `_worker_coroutines`
gjoseph92 Jul 18, 2022
bcc3ad6
Merge remote-tracking branch 'upstream/main' into restart-wait-for-wo…
gjoseph92 Jul 18, 2022
3d6a938
`kill` nannies instead of `restart`
gjoseph92 Jul 18, 2022
b7c5e40
Add `wait_for_workers` option
gjoseph92 Jul 18, 2022
1dd9620
Fix docstrings & error messages
gjoseph92 Jul 19, 2022
675c425
Note new workers may take place of old ones
gjoseph92 Jul 19, 2022
8b4fa1a
decrease wait_for_workers poll interval
gjoseph92 Jul 19, 2022
7eb97ba
Missed one typo
gjoseph92 Jul 19, 2022
b4b9605
Drop redundant geninc (#6740)
hendrikmakait Jul 18, 2022
9ead9c4
fix test_restart_nanny_timeout_exceeded
gjoseph92 Jul 19, 2022
53 changes: 34 additions & 19 deletions distributed/client.py
@@ -1469,8 +1469,9 @@ def _handle_restart(self):
for state in self.futures.values():
state.cancel()
self.futures.clear()
with suppress(AttributeError):
self._restart_event.set()
self.generation += 1
with self._refcount_lock:
self.refcount.clear()

def _handle_error(self, exception=None):
logger.warning("Scheduler exception:")
@@ -3319,32 +3320,46 @@ def persist(
else:
return result

async def _restart(self, timeout=no_default):
async def _restart(self, timeout=no_default, wait_for_workers=True):
if timeout == no_default:
timeout = self._timeout * 2
timeout = self._timeout * 4
if timeout is not None:
timeout = parse_timedelta(timeout, "s")

self._send_to_scheduler({"op": "restart", "timeout": timeout})
self._restart_event = asyncio.Event()
try:
await asyncio.wait_for(self._restart_event.wait(), timeout)
except TimeoutError:
logger.error("Restart timed out after %.2f seconds", timeout)
await self.scheduler.restart(timeout=timeout, wait_for_workers=wait_for_workers)
return self

self.generation += 1
with self._refcount_lock:
self.refcount.clear()
def restart(self, timeout=no_default, wait_for_workers=True):
"""
Restart all workers. Reset local state. Optionally wait for workers to return.

return self
Workers without nannies are shut down, hoping an external deployment system
will restart them. Therefore, if not using nannies and your deployment system
does not automatically restart workers, ``restart`` will just shut down all
workers, then time out!

def restart(self, **kwargs):
"""Restart the distributed network
After `restart`, all connected workers are new, regardless of whether `TimeoutError`
was raised. Any workers that failed to shut down in time are removed, and
may or may not shut down on their own in the future.

This kills all active work, deletes all data on the network, and
restarts the worker processes.
Parameters
----------
timeout:
How long to wait for workers to shut down and come back, if `wait_for_workers`
is True, otherwise just how long to wait for workers to shut down.
Raises `asyncio.TimeoutError` if this is exceeded.
wait_for_workers:
Whether to wait for all workers to reconnect, or just for them to shut down
(default True). Use ``restart(wait_for_workers=False)`` combined with
`Client.wait_for_workers` for granular control over how many workers to
wait for.
See also
----------
Scheduler.restart
"""
return self.sync(self._restart, **kwargs)
return self.sync(
self._restart, timeout=timeout, wait_for_workers=wait_for_workers
)

async def _upload_large_file(self, local_filename, remote_filename=None):
if remote_filename is None:
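As context for the new keyword above, here is a minimal usage sketch (not part of this diff) of the pattern the docstring recommends; the scheduler address and worker count are placeholders:

```python
from distributed import Client

client = Client("tcp://scheduler:8786")  # placeholder address

# Default behaviour: wait (up to the timeout) for workers to shut down
# and for the same number of workers to reconnect.
client.restart()

# Granular control, as suggested in the docstring: only wait for workers to
# shut down, then explicitly wait for however many workers you actually need.
client.restart(wait_for_workers=False)
client.wait_for_workers(n_workers=2)
```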
2 changes: 1 addition & 1 deletion distributed/nanny.py
@@ -379,7 +379,7 @@ async def kill(self, timeout=2):
informed
"""
if self.process is None:
return "OK"
return

deadline = time() + timeout
await self.process.kill(timeout=0.8 * (deadline - time()))
153 changes: 103 additions & 50 deletions distributed/scheduler.py
@@ -73,7 +73,7 @@
from distributed.event import EventExtension
from distributed.http import get_handlers
from distributed.lock import LockExtension
from distributed.metrics import time
from distributed.metrics import monotonic, time
from distributed.multi_lock import MultiLockExtension
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
@@ -3077,7 +3077,6 @@ def __init__(
"client-releases-keys": self.client_releases_keys,
"heartbeat-client": self.client_heartbeat,
"close-client": self.remove_client,
"restart": self.restart,
"subscribe-topic": self.subscribe_topic,
"unsubscribe-topic": self.unsubscribe_topic,
}
@@ -3114,6 +3113,7 @@ def __init__(
"rebalance": self.rebalance,
"replicate": self.replicate,
"run_function": self.run_function,
"restart": self.restart,
"update_data": self.update_data,
"set_resources": self.add_resources,
"retire_workers": self.retire_workers,
@@ -5112,23 +5112,59 @@ def clear_task_state(self):
collection.clear()

@log_errors
async def restart(self, client=None, timeout=30):
"""Restart all workers. Reset local state."""
async def restart(self, client=None, timeout=30, wait_for_workers=True):
"""
Restart all workers. Reset local state. Optionally wait for workers to return.

Workers without nannies are shut down, hoping an external deployment system
will restart them. Therefore, if not using nannies and your deployment system
does not automatically restart workers, ``restart`` will just shut down all
workers, then time out!

After `restart`, all connected workers are new, regardless of whether `TimeoutError`
was raised. Any workers that failed to shut down in time are removed, and
may or may not shut down on their own in the future.

Parameters
----------
timeout:
How long to wait for workers to shut down and come back, if `wait_for_workers`
is True, otherwise just how long to wait for workers to shut down.
Raises `asyncio.TimeoutError` if this is exceeded.
wait_for_workers:
Whether to wait for all workers to reconnect, or just for them to shut down
(default True). Use ``restart(wait_for_workers=False)`` combined with
`Client.wait_for_workers` for granular control over how many workers to
wait for.
See also
----------
Client.restart
"""
stimulus_id = f"restart-{time()}"
n_workers = len(self.workers)

logger.info("Send lost future signal to clients")
logger.info("Releasing all requested keys")
for cs in self.clients.values():
self.client_releases_keys(
keys=[ts.key for ts in cs.wants_what],
client=cs.client_key,
stimulus_id=stimulus_id,
)

self.clear_task_state()
self.erred_tasks.clear()
self.computations.clear()
self.report({"op": "restart"})

for plugin in list(self.plugins.values()):
try:
plugin.restart(self)
except Exception as e:
logger.exception(e)

n_workers = len(self.workers)
nanny_workers = {
addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
}

# Close non-Nanny workers. We have no way to restart them, so we just let them go,
# and assume a deployment system is going to restart them for us.
await asyncio.gather(
@@ -5139,59 +5175,76 @@ async def restart(self, client=None, timeout=30):
)
)

self.clear_task_state()

for plugin in list(self.plugins.values()):
try:
plugin.restart(self)
except Exception as e:
logger.exception(e)

logger.debug("Send kill signal to nannies: %s", nanny_workers)
async with contextlib.AsyncExitStack() as stack:
nannies = [
await stack.enter_async_context(
rpc(nanny_address, connection_args=self.connection_args)
nannies = await asyncio.gather(
*(
stack.enter_async_context(
rpc(nanny_address, connection_args=self.connection_args)
)
for nanny_address in nanny_workers.values()
)
for nanny_address in nanny_workers.values()
]
)

try:
resps = await asyncio.wait_for(
asyncio.gather(
*(
nanny.restart(close=True, timeout=timeout * 0.8)
for nanny in nannies
)
),
timeout,
)
# NOTE: the `WorkerState` entries for these workers will be removed
# naturally when they disconnect from the scheduler.
except TimeoutError:
logger.error(
"Nannies didn't report back restarted within "
"timeout. Continuing with restart process"
)
else:
if not all(resp == "OK" for resp in resps):
logger.error(
"Not all workers responded positively: %s",
resps,
exc_info=True,
start = monotonic()
resps = await asyncio.gather(
*(
asyncio.wait_for(
# FIXME does not raise if the process fails to shut down,
# see https://github.com/dask/distributed/pull/6427/files#r894917424
# NOTE: Nanny will automatically restart worker process when it's killed
nanny.kill(timeout=timeout),
timeout,
)
for nanny in nannies
),
return_exceptions=True,
)
# NOTE: the `WorkerState` entries for these workers will be removed
# naturally when they disconnect from the scheduler.

self.clear_task_state()
# Remove any workers that failed to shut down, so we can guarantee
# that after `restart`, there are no old workers around.
bad_nannies = [
addr for addr, resp in zip(nanny_workers, resps) if resp is not None
]
if bad_nannies:
await asyncio.gather(
*(
self.remove_worker(addr, stimulus_id=stimulus_id)
for addr in bad_nannies
)
)

self.erred_tasks.clear()
self.computations.clear()
raise TimeoutError(
f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not shut down within {timeout}s"
)

self.log_event([client, "all"], {"action": "restart", "client": client})
start = time()
while time() < start + 10 and len(self.workers) < n_workers:
await asyncio.sleep(0.01)

self.report({"op": "restart"})
if wait_for_workers:
while monotonic() < start + timeout:
# NOTE: if new (unrelated) workers join while we're waiting, we may return before
# our shut-down workers have come back up. That's fine; workers are interchangeable.
if len(self.workers) >= n_workers:
return
await asyncio.sleep(0.2)
else:
msg = (
f"Waited for {n_workers} worker(s) to reconnect after restarting, "
f"but after {timeout}s, only {len(self.workers)} have returned. "
"Consider a longer timeout, or `wait_for_workers=False`."
)

if (n_nanny := len(nanny_workers)) < n_workers:
msg += (
f" The {n_workers - n_nanny} worker(s) not using Nannies were just shut "
"down instead of restarted (restart is only possible with Nannies). If "
"your deployment system does not automatically re-launch terminated "
"processes, then those workers will never come back, and `Client.restart` "
"will always time out. Do not use `Client.restart` in that case."
)
raise TimeoutError(msg) from None

async def broadcast(
self,
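To illustrate the `TimeoutError` contract above — after a timeout, workers that failed to shut down have been removed, so any worker still attached is new — here is a hedged sketch (not part of this diff) of how a caller might react; the address, timeout, and worker count are placeholders, and the exception is caught broadly because it is `asyncio.TimeoutError` on older Python versions:

```python
import asyncio

from distributed import Client

client = Client("tcp://scheduler:8786")  # placeholder address

try:
    # If this returns without raising, all attached workers are new.
    client.restart(timeout="60s")
except (TimeoutError, asyncio.TimeoutError):
    # Stragglers were removed during restart; the workers that remain are new.
    # Fall back to waiting for however many workers this workload needs.
    client.wait_for_workers(n_workers=4)
```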
29 changes: 12 additions & 17 deletions distributed/tests/test_client.py
@@ -92,7 +92,6 @@
double,
gen_cluster,
gen_test,
geninc,
get_cert,
inc,
map_varying,
@@ -2662,13 +2661,13 @@ def func(x, y=10):

@gen_cluster(client=True)
async def test_run_coroutine(c, s, a, b):
results = await c.run(geninc, 1, delay=0.05)
results = await c.run(asyncinc, 1, delay=0.05)
assert results == {a.address: 2, b.address: 2}

results = await c.run(geninc, 1, delay=0.05, workers=[a.address])
results = await c.run(asyncinc, 1, delay=0.05, workers=[a.address])
assert results == {a.address: 2}

results = await c.run(geninc, 1, workers=[])
results = await c.run(asyncinc, 1, workers=[])
assert results == {}

with pytest.raises(RuntimeError, match="hello"):
@@ -2679,14 +2678,14 @@ async def test_run_coroutine(c, s, a, b):


def test_run_coroutine_sync(c, s, a, b):
result = c.run(geninc, 2, delay=0.01)
result = c.run(asyncinc, 2, delay=0.01)
assert result == {a["address"]: 3, b["address"]: 3}

result = c.run(geninc, 2, workers=[a["address"]])
result = c.run(asyncinc, 2, workers=[a["address"]])
assert result == {a["address"]: 3}

t1 = time()
result = c.run(geninc, 2, delay=10, wait=False)
result = c.run(asyncinc, 2, delay=10, wait=False)
t2 = time()
assert result is None
assert t2 - t1 <= 1.0
@@ -3498,13 +3497,17 @@ def block(ev):


@pytest.mark.slow
@gen_cluster(Worker=Nanny, client=True, timeout=60)
@gen_cluster(client=True)
async def test_Client_clears_references_after_restart(c, s, a, b):
x = c.submit(inc, 1)
assert x.key in c.refcount
assert x.key in c.futures

with pytest.raises(TimeoutError):
await c.restart(timeout=5)

await c.restart()
assert x.key not in c.refcount
assert not c.futures

key = x.key
del x
@@ -3513,14 +3516,6 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
assert key not in c.refcount


@gen_cluster(Worker=Nanny, client=True)
async def test_restart_timeout_is_logged(c, s, a, b):
with captured_logger(logging.getLogger("distributed.client")) as logger:
await c.restart(timeout="0.5s")
text = logger.getvalue()
assert "Restart timed out after 0.50 seconds" in text


def test_get_stops_work_after_error(c):
with pytest.raises(RuntimeError):
c.get({"x": (throws, 1), "y": (sleep, 1.5)}, ["x", "y"])