Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def set_thread_ident():
self.thread_id = threading.get_ident()

self.io_loop.add_callback(set_thread_ident)
self._startup_lock = asyncio.Lock()
self._startstop_lock = asyncio.Lock()
self.__startup_exc: Exception | None = None

self.rpc = ConnectionPool(
Expand Down Expand Up @@ -465,7 +465,7 @@ async def start_unsafe(self):

@final
async def start(self):
async with self._startup_lock:
async with self._startstop_lock:
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
Expand Down Expand Up @@ -866,33 +866,47 @@ async def handle_stream(self, comm, extra=None):
await comm.close()
assert comm.closed()

async def close(self, timeout=None):
try:
for pc in self.periodic_callbacks.values():
pc.stop()

if not self.__stopped:
self.__stopped = True
_stops = set()
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
warnings.warn(
f"{type(listener)} is using an asynchronous `stop` method. "
"Support for asynchronous `Listener.stop` will be removed in a future version",
PendingDeprecationWarning,
)
_stops.add(future)
if _stops:
await asyncio.gather(*_stops)
async def close_unsafe(
self, timeout: float | None = None, reason: str | None = None, **kwargs: Any
) -> None:
"""Attempt to close the server. This is not idempotent and not protected against concurrent closing attempts.

# TODO: Deal with exceptions
await self._ongoing_background_tasks.stop()
This is intended to be overwritten or called by subclasses. For a safe
close, please use ``Server.close`` instead.
"""

await self.rpc.close()
await asyncio.gather(*[comm.close() for comm in list(self._comms)])
finally:
self._event_finished.set()
@final
async def close(
self, timeout: float | None = None, reason: str | None = None, **kwargs: Any
) -> None:
async with self._startstop_lock:
if self.status in (Status.closed, Status.closing, Status.failed):
return None

self.status = Status.closing
logger.info(
"Closing %r at %r. Reason: %s",
type(self).__name__,
self.address_safe,
reason,
)
try:
for pc in self.periodic_callbacks.values():
pc.stop()
self.periodic_callbacks.clear()
self.stop()

await asyncio.wait_for(self.close_unsafe(**kwargs), timeout)

# TODO: Deal with exceptions
await self._ongoing_background_tasks.stop()

await asyncio.gather(*[comm.close() for comm in list(self._comms)])
await self.rpc.close()
finally:
self._event_finished.set()
# TODO: This might break the worker
self.status = Status.closed


def pingpong(comm):
Expand Down
38 changes: 11 additions & 27 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from collections.abc import Collection
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, Callable, ClassVar, Literal
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal

from toolz import merge
from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -570,22 +569,15 @@ def close_gracefully(self, reason: str = "nanny-close-gracefully") -> None:
"Closing Nanny gracefully at %r. Reason: %s", self.address_safe, reason
)

async def close(
self, timeout: float = 5, reason: str = "nanny-close"
) -> Literal["OK"]:
async def close_unsafe(
self,
timeout: float | None = 5,
reason: str | None = "nanny-close",
**kwargs: Any,
) -> None:
"""
Close the worker process, stop all comms.
"""
if self.status == Status.closing:
await self.finished()
assert self.status == Status.closed

if self.status == Status.closed:
return "OK"

self.status = Status.closing
logger.info("Closing Nanny at %r. Reason: %s", self.address_safe, reason)

for preload in self.preloads:
await preload.teardown()

Expand All @@ -600,14 +592,13 @@ async def close(
self.stop()
try:
if self.process is not None:
assert timeout is not None
assert reason is not None
await self.kill(timeout=timeout, reason=reason)
except Exception:
logger.exception("Error in Nanny killing Worker subprocess")
self.process = None
await self.rpc.close()
self.status = Status.closed
await super().close()
return "OK"

async def _log_event(self, topic, msg):
await self.scheduler.log_event(
Expand Down Expand Up @@ -817,8 +808,8 @@ async def kill(
"reason": reason,
}
)
await asyncio.sleep(0) # otherwise we get broken pipe errors
queue.close()
queue.join_thread()
del queue

try:
Expand Down Expand Up @@ -943,14 +934,7 @@ async def run() -> None:
logger.exception(f"Failed to {failure_type} worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 3 is for good
# measure)
sync_sleep(cls._init_msg_interval * 3)
init_result_q.join_thread()

Comment on lines -947 to 938
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a drive by. I can break this out in anothe rPR if necessary

with contextlib.ExitStack() as stack:

Expand Down
24 changes: 3 additions & 21 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3889,21 +3889,15 @@ def del_scheduler_file():
setproctitle(f"dask scheduler [{self.address}]")
return self

async def close(self, fast=None, close_workers=None):
async def close_unsafe(
self, timeout: float | None = None, reason: str | None = None, **kwargs
):
"""Send cleanup signal to all coroutines then wait until finished

See Also
--------
Scheduler.cleanup
"""
if fast is not None or close_workers is not None:
warnings.warn(
"The 'fast' and 'close_workers' parameters in Scheduler.close have no effect and will be removed in a future version of distributed.",
FutureWarning,
)
if self.status in (Status.closing, Status.closed):
await self.finished()
return

async def log_errors(func):
try:
Expand All @@ -3915,8 +3909,6 @@ async def log_errors(func):
*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
)

self.status = Status.closing

logger.info("Scheduler closing...")
setproctitle("dask scheduler [closing]")

Expand All @@ -3930,10 +3922,6 @@ async def log_errors(func):
*[log_errors(plugin.close) for plugin in list(self.plugins.values())]
)

for pc in self.periodic_callbacks.values():
pc.stop()
self.periodic_callbacks.clear()

self.stop_services()

for ext in self.extensions.values():
Expand Down Expand Up @@ -3961,12 +3949,6 @@ async def log_errors(func):
for comm in self.client_comms.values():
comm.abort()

await self.rpc.close()

self.status = Status.closed
self.stop()
await super().close()

setproctitle("dask scheduler [closed]")
disable_gc_diagnosis()

Expand Down
25 changes: 24 additions & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import CommClosedError, Status
from distributed.core import CommClosedError, ConnectionPool, Status
from distributed.diagnostics import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
Expand Down Expand Up @@ -760,3 +760,26 @@ async def test_worker_inherits_temp_config(c, s):
async with Nanny(s.address):
out = await c.submit(lambda: dask.config.get("test123"))
assert out == 123


@pytest.mark.parametrize("api", ["restart", "kill"])
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_restart_stress_nanny_only(c, s, a, api):
async def keep_killing():
pool = await ConnectionPool()
try:
rpc = pool(a.address)
for _ in range(2):
try:
meth = getattr(rpc, api)
await meth(reason="test-trigger")
except OSError:
break

await asyncio.sleep(0.1)
finally:
await pool.close()

kill_tasks = [asyncio.create_task(keep_killing()) for _ in range(2)]
await asyncio.gather(*kill_tasks)
assert a.status == Status.running
19 changes: 3 additions & 16 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,12 +1456,13 @@ async def start_unsafe(self):
return self

@log_errors
async def close( # type: ignore
async def close_unsafe( # type: ignore
self,
timeout: float = 30,
executor_wait: bool = True,
nanny: bool = True,
reason: str = "worker-close",
**kwargs,
) -> str | None:
"""Close the worker

Expand All @@ -1488,14 +1489,6 @@ async def close( # type: ignore
# is the other way round. If an external caller wants to close
# nanny+worker, the nanny must be notified first. ==> Remove kwarg
# nanny, see also Scheduler.retire_workers
if self.status in (Status.closed, Status.closing, Status.failed):
logging.debug(
"Attempted to close worker that is already %s. Reason: %s",
self.status,
reason,
)
await self.finished()
return None

if self.status == Status.init:
# If the worker is still in startup/init and is started by a nanny,
Expand Down Expand Up @@ -1526,7 +1519,7 @@ async def close( # type: ignore
pc.stop()

# Cancel async instructions
await BaseWorker.close(self, timeout=timeout)
await BaseWorker.close_unsafe(self, timeout=timeout)

for preload in self.preloads:
try:
Expand Down Expand Up @@ -1626,12 +1619,6 @@ def _close(executor, wait):
executor=executor, wait=executor_wait
) # Just run it directly

self.stop()
await self.rpc.close()

self.status = Status.closed
await ServerNode.close(self)

setproctitle("dask worker [closed]")
return "OK"

Expand Down
4 changes: 3 additions & 1 deletion distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3666,7 +3666,9 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

async def close(self, timeout: float = 30) -> None:
async def close_unsafe(
self, timeout: float | None = 30, reason: str | None = None
) -> None:
"""Cancel all asynchronous instructions"""
if not self._async_instructions:
return
Expand Down