3 changes: 0 additions & 3 deletions distributed/core.py
@@ -136,7 +136,6 @@ def __init__(
connection_args=None,
timeout=None,
io_loop=None,
**kwargs,
):
self.handlers = {
"identity": self.identity,
@@ -238,8 +237,6 @@ def set_thread_ident():

self.__stopped = False

super().__init__(**kwargs)

@property
def status(self):
return self._status
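
The core.py change above removes the **kwargs catch-all from Server.__init__ (together with the super().__init__(**kwargs) call), so a misspelled option now fails with TypeError instead of being silently swallowed. Below is a minimal sketch of the general effect; LenientServer and StrictServer are illustrative stand-ins, not the real Server class:

```python
# Hypothetical classes illustrating the difference, not the real Server.

class LenientServer:
    def __init__(self, handlers, timeout=None, **kwargs):
        # A catch-all signature absorbs anything, including typos.
        self.handlers = handlers
        self.timeout = timeout


class StrictServer:
    def __init__(self, handlers, timeout=None):
        # No **kwargs: unknown options raise immediately.
        self.handlers = handlers
        self.timeout = timeout


LenientServer({}, tiemout=5)  # typo accepted and silently ignored
try:
    StrictServer({}, tiemout=5)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'tiemout'
```
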
18 changes: 18 additions & 0 deletions distributed/deploy/tests/test_local.py
@@ -1063,3 +1063,21 @@ async def test_cluster_names():

async with LocalCluster(processes=False, asynchronous=True) as unnamed_cluster2:
assert unnamed_cluster2 != unnamed_cluster


@pytest.mark.asyncio
@pytest.mark.parametrize("nanny", [True, False])
async def test_local_cluster_redundant_kwarg(nanny):
with pytest.raises(TypeError, match="unexpected keyword argument"):
# Extra arguments are forwarded to the worker class. Depending on
# whether we use the nanny or not, the error treatment is quite
# different and we should assert that an exception is raised
async with await LocalCluster(
typo_kwarg="foo", processes=nanny, n_workers=1
) as cluster:

# This will never work but is a reliable way to block without hard
# coding any sleep values
async with Client(cluster) as c:
f = c.submit(sleep, 0)
await f
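
As the in-test comment notes, keyword arguments that LocalCluster does not itself consume are forwarded to the worker class, which is where the unexpected-keyword TypeError now surfaces. Below is a hedged sketch of that forwarding pattern; FakeCluster and FakeWorker are made-up names, not the real implementation:

```python
# Illustrative only: FakeWorker/FakeCluster stand in for the real classes.

class FakeWorker:
    def __init__(self, scheduler_address, nthreads=1):
        self.scheduler_address = scheduler_address
        self.nthreads = nthreads


class FakeCluster:
    def __init__(self, n_workers=1, **worker_kwargs):
        # Anything the cluster does not recognize is passed straight through
        # to the worker constructor, so typos fail there.
        self.workers = [
            FakeWorker("tcp://127.0.0.1:8786", **worker_kwargs)
            for _ in range(n_workers)
        ]


try:
    FakeCluster(n_workers=1, typo_kwarg="foo")
except TypeError as err:
    print(err)  # unexpected keyword argument 'typo_kwarg'
```

When a nanny sits in between, the same typo only surfaces inside the worker process and has to be propagated back to the caller, which is why the test parametrizes over both paths.
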
209 changes: 121 additions & 88 deletions distributed/nanny.py
@@ -9,6 +9,7 @@
import weakref
from contextlib import suppress
from multiprocessing.queues import Empty
from time import sleep as sync_sleep

import psutil
from tornado import gen
@@ -362,7 +363,6 @@ async def instantiate(self, comm=None) -> Status:
config=self.config,
)

self.auto_restart = True
if self.death_timeout:
try:
result = await asyncio.wait_for(
@@ -378,7 +378,11 @@
raise

else:
result = await self.process.start()
try:
result = await self.process.start()
except Exception:
await self.close()
raise
return result

async def restart(self, comm=None, timeout=2, executor_wait=True):
@@ -414,9 +418,10 @@ def memory_monitor(self):
""" Track worker's memory. Restart if it goes above terminate fraction """
if self.status != Status.running:
return
if self.process is None or self.process.process is None:
return None
process = self.process.process
if process is None:
return

try:
proc = self._psutil_process
memory = proc.memory_info().rss
@@ -519,6 +524,9 @@ async def close(self, comm=None, timeout=5, report=None):


class WorkerProcess:
# How often to check the msg queue for the init message
_init_msg_interval = 0.05

def __init__(
self,
worker_kwargs,
@@ -584,9 +592,14 @@ async def start(self) -> Status:
except OSError:
logger.exception("Nanny failed to start process", exc_info=True)
self.process.terminate()
return

msg = await self._wait_until_connected(uid)
self.status = Status.failed
return self.status
try:
msg = await self._wait_until_connected(uid)
except Exception:
self.status = Status.failed
self.process.terminate()
raise
if not msg:
return self.status
self.worker_address = msg["address"]
@@ -683,14 +696,15 @@ async def kill(self, timeout=2, executor_wait=True):
logger.error("Failed to kill worker process: %s", e)

async def _wait_until_connected(self, uid):
delay = 0.05
while True:
if self.status != Status.starting:
return
# This is a multiprocessing queue and we'd block the event loop if
# we simply called get
try:
msg = self.init_result_q.get_nowait()
except Empty:
await asyncio.sleep(delay)
await asyncio.sleep(self._init_msg_interval)
continue

if msg["uid"] != uid: # ensure that we didn't cross queues
@@ -700,7 +714,6 @@
logger.error(
"Failed while trying to start worker process: %s", msg["exception"]
)
await self.process.join()
raise msg["exception"]
else:
return msg
@@ -718,88 +731,108 @@ def _run(
config,
Worker,
): # pragma: no cover
os.environ.update(env)
dask.config.set(config)
try:
[Review comment from the PR author] This big diff is mostly wrapping the function in a try/except

from dask.multiprocessing import initialize_worker_process
except ImportError: # old Dask version
pass
else:
initialize_worker_process()
os.environ.update(env)
dask.config.set(config)
try:
from dask.multiprocessing import initialize_worker_process
except ImportError: # old Dask version
pass
else:
initialize_worker_process()

if silence_logs:
logger.setLevel(silence_logs)
if silence_logs:
logger.setLevel(silence_logs)

IOLoop.clear_instance()
loop = IOLoop()
loop.make_current()
worker = Worker(**worker_kwargs)
IOLoop.clear_instance()
loop = IOLoop()
loop.make_current()
worker = Worker(**worker_kwargs)

async def do_stop(timeout=5, executor_wait=True):
try:
await worker.close(
report=True,
nanny=False,
safe=True, # TODO: Graceful or not?
executor_wait=executor_wait,
timeout=timeout,
)
finally:
loop.stop()

def watch_stop_q():
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
while True:
try:
msg = child_stop_q.get(timeout=1000)
except Empty:
pass
else:
child_stop_q.close()
assert msg.pop("op") == "stop"
loop.add_callback(do_stop, **msg)
break

t = threading.Thread(target=watch_stop_q, name="Nanny stop queue watch")
t.daemon = True
t.start()

async def run():
"""
Try to start worker and inform parent of outcome.
"""
try:
await worker
except Exception as e:
logger.exception("Failed to start worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
else:
async def do_stop(timeout=5, executor_wait=True):
try:
assert worker.address
except ValueError:
pass
else:
init_result_q.put(
{
"address": worker.address,
"dir": worker.local_directory,
"uid": uid,
}
await worker.close(
report=True,
nanny=False,
safe=True, # TODO: Graceful or not?
executor_wait=executor_wait,
timeout=timeout,
[Review comment from the PR author on lines +754 to +759] This is new

)
finally:
loop.stop()

def watch_stop_q():
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
while True:
try:
msg = child_stop_q.get(timeout=1000)
except Empty:
pass
else:
child_stop_q.close()
assert msg.pop("op") == "stop"
loop.add_callback(do_stop, **msg)
break

t = threading.Thread(target=watch_stop_q, name="Nanny stop queue watch")
t.daemon = True
t.start()

async def run():
"""
Try to start worker and inform parent of outcome.
"""
try:
await worker
except Exception as e:
logger.exception("Failed to start worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
await worker.finished()
logger.info("Worker closed")

try:
loop.run_sync(run)
except (TimeoutError, gen.TimeoutError):
# Loop was stopped before wait_until_closed() returned, ignore
pass
except KeyboardInterrupt:
# At this point the loop is not running thus we have to run
# do_stop() explicitly.
loop.run_sync(do_stop)
# If we hit an exception here we need to wait for at least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good
# measure)
sync_sleep(cls._init_msg_interval * 2)
[Review comment from the PR author on lines +794 to +801] This resolves the race condition in test_worker_start_exception

else:
try:
assert worker.address
except ValueError:
pass
else:
init_result_q.put(
{
"address": worker.address,
"dir": worker.local_directory,
"uid": uid,
}
)
init_result_q.close()
await worker.finished()
logger.info("Worker closed")

except Exception as e:
logger.exception("Failed to initialize Worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for at least one
# interval for the outside to pick up this message. Otherwise we
# arrive in a race condition where the process cleanup wipes the
# queue before the exception can be properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good measure)
sync_sleep(cls._init_msg_interval * 2)
[Review comment from the PR author on lines +822 to +828] This resolves the race condition in test_failure_during_worker_initialization

else:
try:
loop.run_sync(run)
except (TimeoutError, gen.TimeoutError):
# Loop was stopped before wait_until_closed() returned, ignore
pass
except KeyboardInterrupt:
# At this point the loop is not running thus we have to run
# do_stop() explicitly.
loop.run_sync(do_stop)
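
Taken together, the nanny.py changes implement one pattern, described in the author's inline comments: the child process wraps all of its startup in try/except, reports any failure back to the parent over the init result queue, and then sleeps for a couple of poll intervals so the parent can drain the queue before process cleanup wipes it; on the parent side, the multiprocessing queue is polled with get_nowait plus asyncio.sleep so the event loop is never blocked. Below is a self-contained, hedged sketch of that pattern; names such as POLL_INTERVAL, child, and wait_until_connected are illustrative, not the real nanny API:

```python
import asyncio
import multiprocessing
from queue import Empty
from time import sleep as sync_sleep

POLL_INTERVAL = 0.05  # stand-in for the _init_msg_interval idea


def child(init_q, uid):
    # Child side: wrap all initialization in try/except and report the
    # outcome to the parent instead of dying silently.
    try:
        raise ValueError("worker failed to initialize")  # simulated failure
    except Exception as e:
        init_q.put({"uid": uid, "exception": e})
        init_q.close()
        # Give the parent at least one poll interval to pick the message up
        # before the process exits, otherwise cleanup can wipe the queue.
        sync_sleep(POLL_INTERVAL * 2)


async def wait_until_connected(init_q, uid):
    # Parent side: a blocking get() on a multiprocessing queue would stall
    # the event loop, so poll with get_nowait() and yield in between.
    while True:
        try:
            msg = init_q.get_nowait()
        except Empty:
            await asyncio.sleep(POLL_INTERVAL)
            continue
        if msg["uid"] != uid:  # make sure we did not cross queues
            continue
        if "exception" in msg:
            raise msg["exception"]
        return msg


async def main():
    uid = "abc"
    init_q = multiprocessing.Queue()
    proc = multiprocessing.Process(target=child, args=(init_q, uid))
    proc.start()
    try:
        await wait_until_connected(init_q, uid)
    except ValueError as err:
        print("child reported:", err)
    finally:
        proc.join()


if __name__ == "__main__":
    asyncio.run(main())
```

Running the sketch surfaces the child's exception in the parent process, which is the behavior the two race-condition comments above are protecting.
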
6 changes: 6 additions & 0 deletions distributed/tests/test_core.py
@@ -880,3 +880,9 @@ async def sleep(comm=None):

# weakref set/dict should be cleaned up
assert not len(server._ongoing_coroutines)


@pytest.mark.asyncio
async def test_server_redundant_kwarg():
with pytest.raises(TypeError, match="unexpected keyword argument"):
await Server({}, typo_kwarg="foo")
12 changes: 10 additions & 2 deletions distributed/tests/test_nanny.py
@@ -14,7 +14,6 @@
import dask

from distributed import Client, Nanny, Scheduler, Worker, rpc, wait, worker
from distributed.compatibility import MACOS
from distributed.core import CommClosedError, Status
from distributed.diagnostics import SchedulerPlugin
from distributed.metrics import time
@@ -565,10 +564,19 @@ async def start(self):
raise StartException("broken")


@pytest.mark.flaky(reruns=10, reruns_delay=5, condition=MACOS)
@pytest.mark.asyncio
async def test_worker_start_exception(cleanup):
# make sure this raises the right Exception:
with pytest.raises(StartException):
async with Nanny("tcp://localhost:1", worker_class=BrokenWorker) as n:
await n.start()


@pytest.mark.asyncio
async def test_failure_during_worker_initialization(cleanup):
with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
async with Scheduler() as s:
with pytest.raises(Exception):
async with Nanny(s.address, foo="bar") as n:
await n
assert "Restarting worker" not in logs.getvalue()