diff --git a/distributed/client.py b/distributed/client.py
index 86ca5f506e0..081aaf661c4 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -81,6 +81,7 @@
from distributed.sizeof import sizeof
from distributed.threadpoolexecutor import rejoin
from distributed.utils import (
+ STIMULUS_ID,
All,
Any,
CancelledError,
@@ -93,6 +94,7 @@
import_term,
log_errors,
no_default,
+ set_default_stimulus,
sync,
thread_state,
)
@@ -472,6 +474,7 @@ def __setstate__(self, state):
"tasks": {},
"keys": [stringify(self.key)],
"client": c.id,
+ "stimulus_id": f"client-update-graph-{time()}",
}
)
@@ -1367,7 +1370,7 @@ def _inc_ref(self, key):
self.refcount[key] += 1
def _dec_ref(self, key):
- with self._refcount_lock:
+ with self._refcount_lock, set_default_stimulus(f"client-dec-ref-{time()}"):
self.refcount[key] -= 1
if self.refcount[key] == 0:
del self.refcount[key]
@@ -1381,7 +1384,12 @@ def _release_key(self, key):
st.cancel()
if self.status != "closed":
self._send_to_scheduler(
- {"op": "client-releases-keys", "keys": [key], "client": self.id}
+ {
+ "op": "client-releases-keys",
+ "keys": [key],
+ "client": self.id,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
async def _handle_report(self):
@@ -1527,31 +1535,33 @@ async def _close(self, fast=False):
):
await self.scheduler_comm.close()
- for key in list(self.futures):
- self._release_key(key=key)
+ with set_default_stimulus(f"client-close-{time()}"):
- if self._start_arg is None:
- with suppress(AttributeError):
- await self.cluster.close()
+ for key in list(self.futures):
+ self._release_key(key=key)
- await self.rpc.close()
+ if self._start_arg is None:
+ with suppress(AttributeError):
+ await self.cluster.close()
- self.status = "closed"
+ await self.rpc.close()
- if _get_global_client() is self:
- _set_global_client(None)
+ self.status = "closed"
- if (
- handle_report_task is not None
- and handle_report_task is not current_task
- ):
- with suppress(TimeoutError, asyncio.CancelledError):
- await asyncio.wait_for(handle_report_task, 0 if fast else 2)
+ if _get_global_client() is self:
+ _set_global_client(None)
- with suppress(AttributeError):
- await self.scheduler.close_rpc()
+ if (
+ handle_report_task is not None
+ and handle_report_task is not current_task
+ ):
+ with suppress(TimeoutError, asyncio.CancelledError):
+ await asyncio.wait_for(handle_report_task, 0 if fast else 2)
- self.scheduler = None
+ with suppress(AttributeError):
+ await self.scheduler.close_rpc()
+
+ self.scheduler = None
self.status = "closed"
@@ -2115,7 +2125,11 @@ async def _gather_remote(self, direct, local_worker):
response["data"].update(data2)
else: # ask scheduler to gather data for us
- response = await retry_operation(self.scheduler.gather, keys=keys)
+ response = await retry_operation(
+ self.scheduler.gather,
+ keys=keys,
+ stimulus_id=f"client-gather-remote-{time()}",
+ )
return response
@@ -2201,6 +2215,8 @@ async def _scatter(
d = await self._scatter(keymap(stringify, data), workers, broadcast)
return {k: d[stringify(k)] for k in data}
+ stimulus_id = f"client-scatter-{time()}"
+
if isinstance(data, type(range(0))):
data = list(data)
input_type = type(data)
@@ -2242,6 +2258,7 @@ async def _scatter(
who_has={key: [local_worker.address] for key in data},
nbytes=valmap(sizeof, data),
client=self.id,
+ stimulus_id=stimulus_id,
)
else:
@@ -2264,7 +2281,10 @@ async def _scatter(
)
await self.scheduler.update_data(
- who_has=who_has, nbytes=nbytes, client=self.id
+ who_has=who_has,
+ nbytes=nbytes,
+ client=self.id,
+ stimulus_id=stimulus_id,
)
else:
await self.scheduler.scatter(
@@ -2273,6 +2293,7 @@ async def _scatter(
client=self.id,
broadcast=broadcast,
timeout=timeout,
+ stimulus_id=stimulus_id,
)
out = {k: Future(k, self, inform=False) for k in data}
@@ -2396,7 +2417,12 @@ def scatter(
async def _cancel(self, futures, force=False):
keys = list({stringify(f.key) for f in futures_of(futures)})
- await self.scheduler.cancel(keys=keys, client=self.id, force=force)
+ await self.scheduler.cancel(
+ keys=keys,
+ client=self.id,
+ force=force,
+ stimulus_id=f"client-cancel-{time()}",
+ )
for k in keys:
st = self.futures.pop(k, None)
if st is not None:
@@ -2423,7 +2449,9 @@ def cancel(self, futures, asynchronous=None, force=False):
async def _retry(self, futures):
keys = list({stringify(f.key) for f in futures_of(futures)})
- response = await self.scheduler.retry(keys=keys, client=self.id)
+ response = await self.scheduler.retry(
+ keys=keys, client=self.id, stimulus_id=f"client-retry-{time()}"
+ )
for key in response:
st = self.futures[key]
st.retry()
@@ -2922,6 +2950,7 @@ def _graph_to_futures(
"fifo_timeout": fifo_timeout,
"actors": actors,
"code": self._get_computation_code(),
+ "stimulus_id": f"client-update-graph-hlg-{time()}",
}
)
return futures
@@ -3424,7 +3453,9 @@ async def _rebalance(self, futures=None, workers=None):
keys = list({stringify(f.key) for f in self.futures_of(futures)})
else:
keys = None
- result = await self.scheduler.rebalance(keys=keys, workers=workers)
+ result = await self.scheduler.rebalance(
+ keys=keys, workers=workers, stimulus_id=f"client-rebalance-{time()}"
+ )
if result["status"] == "partial-fail":
raise KeyError(f"Could not rebalance keys: {result['keys']}")
assert result["status"] == "OK", result
@@ -3459,7 +3490,11 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2):
await _wait(futures)
keys = {stringify(f.key) for f in futures}
await self.scheduler.replicate(
- keys=list(keys), n=n, workers=workers, branching_factor=branching_factor
+ keys=list(keys),
+ n=n,
+ workers=workers,
+ branching_factor=branching_factor,
+ stimulus_id=f"client-replicate-{time()}",
)
def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs):
@@ -4177,6 +4212,7 @@ def retire_workers(
self.scheduler.retire_workers,
workers=workers,
close_workers=close_workers,
+ stimulus_id=f"client-retire-workers-{time()}",
**kwargs,
)
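
Note (not part of this diff): the client hunks above import `STIMULUS_ID` and `set_default_stimulus` from `distributed.utils`, but their definitions are outside this section. A minimal sketch of the assumed shape — a `contextvars.ContextVar` plus a context manager that only installs a default when nothing is already bound — might look like this; treat names and semantics here as assumptions for reading the diff, not as the actual `distributed.utils` implementation.

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Assumed: no default, so .get() raises LookupError when no stimulus is bound.
STIMULUS_ID: ContextVar[str] = ContextVar("STIMULUS_ID")


@contextmanager
def set_default_stimulus(stimulus_id):
    # Assumed semantics, suggested by the name and by the nested handler calls
    # in the scheduler hunks below: bind ``stimulus_id`` only when no stimulus
    # is currently set, so inner helpers do not clobber an outer handler's ID.
    try:
        STIMULUS_ID.get()
    except LookupError:
        token = STIMULUS_ID.set(stimulus_id)
        try:
            yield
        finally:
            STIMULUS_ID.reset(token)
    else:
        yield
```

Usage then matches the pattern in the hunks above, e.g. `with set_default_stimulus(f"client-close-{time()}"): ...` in `Client._close`, with downstream code reading the value via `STIMULUS_ID.get()`.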
diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html
index 0b5c10695e0..f10aaad5602 100644
--- a/distributed/http/templates/task.html
+++ b/distributed/http/templates/task.html
@@ -118,16 +118,18 @@ Transition Log
             <th>Key</th>
             <th>Start</th>
             <th>Finish</th>
+            <th>Stimulus ID</th>
             <th>Recommended Key</th>
             <th>Recommended Action</th>
         </tr>
         </thead>
         <tbody>
-        {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %}
+        {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}
         <tr>
             <td>{{ fromtimestamp(transition_time) }}</td>
             <td>{{key}}</td>
             <td>{{ start }}</td>
             <td>{{ finish }}</td>
+            <td>{{ stimulus_id }}</td>
             <td></td>
             <td></td>
@@ -137,6 +139,7 @@ Transition Log
             <td></td>
             <td></td>
             <td></td>
             <td></td>
+            <td></td>
             <td>{{key2}}</td>
             <td>{{ rec }}</td>
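
The template change above relies on the record format introduced in the scheduler diff below: each `transition_log` entry becomes a six-tuple `(key, start, finish, recommendations, stimulus_id, timestamp)`, so anything iterating `scheduler.story(...)` must unpack the extra field. A hypothetical helper (illustration only, not part of the diff) showing that unpacking:

```python
def format_story(scheduler, *keys):
    # Each story record now carries a stimulus_id between the recommendations
    # dict and the transition timestamp.
    lines = []
    for key, start, finish, recommendations, stimulus_id, ts in scheduler.story(*keys):
        lines.append(f"{ts:.3f} {key}: {start} -> {finish} (stimulus: {stimulus_id})")
    return "\n".join(lines)
```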
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index fc96d5d02ec..7cd12a0b9d4 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -85,6 +85,7 @@
from distributed.stealing import WorkStealing
from distributed.stories import scheduler_story
from distributed.utils import (
+ STIMULUS_ID,
All,
TimeoutError,
empty_context,
@@ -94,6 +95,7 @@
log_errors,
no_default,
recursive_to_dict,
+ set_default_stimulus,
validate_key,
)
from distributed.utils_comm import (
@@ -2370,19 +2372,26 @@ def _transition(self, key, finish: str, *args, **kwargs):
else:
raise RuntimeError("Impossible transition from %r to %r" % start_finish)
- finish2 = ts._state
# FIXME downcast antipattern
scheduler = pep484_cast(Scheduler, self)
+
+ stimulus_id = STIMULUS_ID.get(Scheduler.STIMULUS_ID_NOT_SET)
+
+ finish2 = ts._state
scheduler.transition_log.append(
- (key, start, finish2, recommendations, time())
+ (key, start, finish2, recommendations, stimulus_id, time())
)
if parent._validate:
+ if stimulus_id == Scheduler.STIMULUS_ID_NOT_SET:
+ raise LookupError(STIMULUS_ID.name)
+
logger.debug(
- "Transitioned %r %s->%s (actual: %s). Consequence: %s",
+ "Transitioned %r %s->%s (actual: %s) from %s. Consequence: %s",
key,
start,
finish2,
ts._state,
+ stimulus_id,
dict(recommendations),
)
if self.plugins:
@@ -2853,11 +2862,12 @@ def transition_processing_memory(
ws,
key,
)
+
worker_msgs[ts._processing_on.address] = [
{
"op": "cancel-compute",
"key": key,
- "stimulus_id": f"processing-memory-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
@@ -2942,10 +2952,11 @@ def transition_memory_released(self, key, safe: bint = False):
dts._waiting_on.add(ts)
# XXX factor this out?
+
worker_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"memory-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
for ws in ts._who_has:
worker_msgs[ws._address] = [worker_msg]
@@ -3048,7 +3059,7 @@ def transition_erred_released(self, key):
w_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"erred-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
for ws_addr in ts._erred_on:
worker_msgs[ws_addr] = [w_msg]
@@ -3127,7 +3138,7 @@ def transition_processing_released(self, key):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"processing-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
@@ -3318,7 +3329,13 @@ def transition_memory_forgotten(self, key):
for ws in ts._who_has:
ws._actors.discard(ts)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(
+ self,
+ ts,
+ recommendations,
+ worker_msgs,
+ STIMULUS_ID.get(),
+ )
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -3356,7 +3373,13 @@ def transition_released_forgotten(self, key):
else:
assert 0, (ts,)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(
+ self,
+ ts,
+ recommendations,
+ worker_msgs,
+ STIMULUS_ID.get(),
+ )
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -3682,6 +3705,7 @@ class Scheduler(SchedulerState, ServerNode):
Time we expect certain functions to take, e.g. ``{'sum': 0.25}``
"""
+ STIMULUS_ID_NOT_SET = ""
default_port = 8786
_instances: "ClassVar[weakref.WeakSet[Scheduler]]" = weakref.WeakSet()
@@ -4548,48 +4572,49 @@ async def add_worker(
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}
- if nbytes:
- assert isinstance(nbytes, dict)
- already_released_keys = []
- for key in nbytes:
- ts: TaskState = parent._tasks.get(key) # type: ignore
- if ts is not None and ts.state != "released":
- if ts.state == "memory":
- self.add_keys(worker=address, keys=[key])
+ with set_default_stimulus(f"add-worker-{time()}"):
+ if nbytes:
+ assert isinstance(nbytes, dict)
+ already_released_keys = []
+ for key in nbytes:
+ ts: TaskState = parent._tasks.get(key) # type: ignore
+ if ts is not None and ts.state != "released":
+ if ts.state == "memory":
+ self.add_keys(worker=address, keys=[key])
+ else:
+ t: tuple = parent._transition(
+ key,
+ "memory",
+ worker=address,
+ nbytes=nbytes[key],
+ typename=types[key],
+ )
+ recommendations, client_msgs, worker_msgs = t
+ parent._transitions(
+ recommendations, client_msgs, worker_msgs
+ )
+ recommendations = {}
else:
- t: tuple = parent._transition(
- key,
- "memory",
- worker=address,
- nbytes=nbytes[key],
- typename=types[key],
+ already_released_keys.append(key)
+ if already_released_keys:
+ if address not in worker_msgs:
+ worker_msgs[address] = []
+ worker_msgs[address].append(
+ {
+ "op": "remove-replicas",
+ "keys": already_released_keys,
+ "stimulus_id": f"reconnect-already-released-{time()}",
+ }
)
- recommendations, client_msgs, worker_msgs = t
- parent._transitions(
- recommendations, client_msgs, worker_msgs
- )
- recommendations = {}
- else:
- already_released_keys.append(key)
- if already_released_keys:
- if address not in worker_msgs:
- worker_msgs[address] = []
- worker_msgs[address].append(
- {
- "op": "remove-replicas",
- "keys": already_released_keys,
- "stimulus_id": f"reconnect-already-released-{time()}",
- }
- )
- if ws._status == Status.running:
- for ts in parent._unrunnable:
- valid: set = self.valid_workers(ts)
- if valid is None or ws in valid:
- recommendations[ts._key] = "waiting"
+ if ws._status == Status.running:
+ for ts in parent._unrunnable:
+ valid: set = self.valid_workers(ts)
+ if valid is None or ws in valid:
+ recommendations[ts._key] = "waiting"
- if recommendations:
- parent._transitions(recommendations, client_msgs, worker_msgs)
+ if recommendations:
+ parent._transitions(recommendations, client_msgs, worker_msgs)
self.send_all(client_msgs, worker_msgs)
@@ -4646,46 +4671,51 @@ def update_graph_hlg(
actors=None,
fifo_timeout=0,
code=None,
+ stimulus_id=None,
):
- unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg)
- dsk = unpacked_graph["dsk"]
- dependencies = unpacked_graph["deps"]
- annotations = unpacked_graph["annotations"]
-
- # Remove any self-dependencies (happens on test_publish_bag() and others)
- for k, v in dependencies.items():
- deps = set(v)
- if k in deps:
- deps.remove(k)
- dependencies[k] = deps
-
- if priority is None:
- # Removing all non-local keys before calling order()
- dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys
- stripped_deps = {
- k: v.intersection(dsk_keys)
- for k, v in dependencies.items()
- if k in dsk_keys
- }
- priority = dask.order.order(dsk, dependencies=stripped_deps)
-
- return self.update_graph(
- client,
- dsk,
- keys,
- dependencies,
- restrictions,
- priority,
- loose_restrictions,
- resources,
- submitting_task,
- retries,
- user_priority,
- actors,
- fifo_timeout,
- annotations,
- code=code,
- )
+ with set_default_stimulus(stimulus_id):
+ unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg)
+ dsk = unpacked_graph["dsk"]
+ dependencies = unpacked_graph["deps"]
+ annotations = unpacked_graph["annotations"]
+
+ # Remove any self-dependencies (happens on test_publish_bag() and others)
+ for k, v in dependencies.items():
+ deps = set(v)
+ if k in deps:
+ deps.remove(k)
+ dependencies[k] = deps
+
+ if priority is None:
+ # Removing all non-local keys before calling order()
+ dsk_keys = set(
+ dsk
+ ) # intersection() of sets is much faster than dict_keys
+ stripped_deps = {
+ k: v.intersection(dsk_keys)
+ for k, v in dependencies.items()
+ if k in dsk_keys
+ }
+ priority = dask.order.order(dsk, dependencies=stripped_deps)
+
+ return self.update_graph(
+ client,
+ dsk,
+ keys,
+ dependencies,
+ restrictions,
+ priority,
+ loose_restrictions,
+ resources,
+ submitting_task,
+ retries,
+ user_priority,
+ actors,
+ fifo_timeout,
+ annotations,
+ code=code,
+ stimulus_id=STIMULUS_ID.get(),
+ )
def update_graph(
self,
@@ -4704,284 +4734,293 @@ def update_graph(
fifo_timeout=0,
annotations=None,
code=None,
+ stimulus_id=None,
):
"""
Add new computations to the internal dask graph
This happens whenever the Client calls submit, map, get, or compute.
"""
- parent: SchedulerState = cast(SchedulerState, self)
- start = time()
- fifo_timeout = parse_timedelta(fifo_timeout)
- keys = set(keys)
- if len(tasks) > 1:
- self.log_event(
- ["all", client], {"action": "update_graph", "count": len(tasks)}
- )
-
- # Remove aliases
- for k in list(tasks):
- if tasks[k] is k:
- del tasks[k]
+ with set_default_stimulus(stimulus_id):
+ parent: SchedulerState = cast(SchedulerState, self)
+ start = time()
+ fifo_timeout = parse_timedelta(fifo_timeout)
+ keys = set(keys)
+ if len(tasks) > 1:
+ self.log_event(
+ ["all", client], {"action": "update_graph", "count": len(tasks)}
+ )
- dependencies = dependencies or {}
+ # Remove aliases
+ for k in list(tasks):
+ if tasks[k] is k:
+ del tasks[k]
- if parent._total_occupancy > 1e-9 and parent._computations:
- # Still working on something. Assign new tasks to same computation
- computation = cast(Computation, parent._computations[-1])
- else:
- computation = Computation()
- parent._computations.append(computation)
+ dependencies = dependencies or {}
- if code and code not in computation._code: # add new code blocks
- computation._code.add(code)
+ if parent._total_occupancy > 1e-9 and parent._computations:
+ # Still working on something. Assign new tasks to same computation
+ computation = cast(Computation, parent._computations[-1])
+ else:
+ computation = Computation()
+ parent._computations.append(computation)
+
+ if code and code not in computation._code: # add new code blocks
+ computation._code.add(code)
+
+ n = 0
+ while len(tasks) != n: # walk through new tasks, cancel any bad deps
+ n = len(tasks)
+ for k, deps in list(dependencies.items()):
+ if any(
+ dep not in parent._tasks and dep not in tasks for dep in deps
+ ): # bad key
+ logger.info("User asked for computation on lost data, %s", k)
+ del tasks[k]
+ del dependencies[k]
+ if k in keys:
+ keys.remove(k)
+ self.report({"op": "cancelled-key", "key": k}, client=client)
+ self.client_releases_keys(keys=[k], client=client)
+
+ # Avoid computation that is already finished
+ ts: TaskState
+ already_in_memory = set() # tasks that are already done
+ for k, v in dependencies.items():
+ if v and k in parent._tasks:
+ ts = parent._tasks[k]
+ if ts._state in ("memory", "erred"):
+ already_in_memory.add(k)
- n = 0
- while len(tasks) != n: # walk through new tasks, cancel any bad deps
- n = len(tasks)
- for k, deps in list(dependencies.items()):
- if any(
- dep not in parent._tasks and dep not in tasks for dep in deps
- ): # bad key
- logger.info("User asked for computation on lost data, %s", k)
- del tasks[k]
- del dependencies[k]
- if k in keys:
- keys.remove(k)
- self.report({"op": "cancelled-key", "key": k}, client=client)
- self.client_releases_keys(keys=[k], client=client)
+ dts: TaskState
+ if already_in_memory:
+ dependents = dask.core.reverse_dict(dependencies)
+ stack = list(already_in_memory)
+ done = set(already_in_memory)
+ while stack: # remove unnecessary dependencies
+ key = stack.pop()
+ ts = parent._tasks[key]
+ try:
+ deps = dependencies[key]
+ except KeyError:
+ deps = self.dependencies[key]
+ for dep in deps:
+ if dep in dependents:
+ child_deps = dependents[dep]
+ else:
+ child_deps = self.dependencies[dep]
+ if all(d in done for d in child_deps):
+ if dep in parent._tasks and dep not in done:
+ done.add(dep)
+ stack.append(dep)
- # Avoid computation that is already finished
- ts: TaskState
- already_in_memory = set() # tasks that are already done
- for k, v in dependencies.items():
- if v and k in parent._tasks:
- ts = parent._tasks[k]
- if ts._state in ("memory", "erred"):
- already_in_memory.add(k)
+ for d in done:
+ tasks.pop(d, None)
+ dependencies.pop(d, None)
- dts: TaskState
- if already_in_memory:
- dependents = dask.core.reverse_dict(dependencies)
- stack = list(already_in_memory)
- done = set(already_in_memory)
- while stack: # remove unnecessary dependencies
- key = stack.pop()
- ts = parent._tasks[key]
- try:
- deps = dependencies[key]
- except KeyError:
- deps = self.dependencies[key]
- for dep in deps:
- if dep in dependents:
- child_deps = dependents[dep]
- else:
- child_deps = self.dependencies[dep]
- if all(d in done for d in child_deps):
- if dep in parent._tasks and dep not in done:
- done.add(dep)
- stack.append(dep)
-
- for d in done:
- tasks.pop(d, None)
- dependencies.pop(d, None)
-
- # Get or create task states
- stack = list(keys)
- touched_keys = set()
- touched_tasks = []
- while stack:
- k = stack.pop()
- if k in touched_keys:
- continue
- # XXX Have a method get_task_state(self, k) ?
- ts = parent._tasks.get(k)
- if ts is None:
- ts = parent.new_task(
- k, tasks.get(k), "released", computation=computation
- )
- elif not ts._run_spec:
- ts._run_spec = tasks.get(k)
+ # Get or create task states
+ stack = list(keys)
+ touched_keys = set()
+ touched_tasks = []
+ while stack:
+ k = stack.pop()
+ if k in touched_keys:
+ continue
+ # XXX Have a method get_task_state(self, k) ?
+ ts = parent._tasks.get(k)
+ if ts is None:
+ ts = parent.new_task(
+ k, tasks.get(k), "released", computation=computation
+ )
+ elif not ts._run_spec:
+ ts._run_spec = tasks.get(k)
- touched_keys.add(k)
- touched_tasks.append(ts)
- stack.extend(dependencies.get(k, ()))
+ touched_keys.add(k)
+ touched_tasks.append(ts)
+ stack.extend(dependencies.get(k, ()))
- self.client_desires_keys(keys=keys, client=client)
+ self.client_desires_keys(keys=keys, client=client)
- # Add dependencies
- for key, deps in dependencies.items():
- ts = parent._tasks.get(key)
- if ts is None or ts._dependencies:
- continue
- for dep in deps:
- dts = parent._tasks[dep]
- ts.add_dependency(dts)
-
- # Compute priorities
- if isinstance(user_priority, Number):
- user_priority = {k: user_priority for k in tasks}
-
- annotations = annotations or {}
- restrictions = restrictions or {}
- loose_restrictions = loose_restrictions or []
- resources = resources or {}
- retries = retries or {}
-
- # Override existing taxonomy with per task annotations
- if annotations:
- if "priority" in annotations:
- user_priority.update(annotations["priority"])
-
- if "workers" in annotations:
- restrictions.update(annotations["workers"])
-
- if "allow_other_workers" in annotations:
- loose_restrictions.extend(
- k for k, v in annotations["allow_other_workers"].items() if v
- )
+ # Add dependencies
+ for key, deps in dependencies.items():
+ ts = parent._tasks.get(key)
+ if ts is None or ts._dependencies:
+ continue
+ for dep in deps:
+ dts = parent._tasks[dep]
+ ts.add_dependency(dts)
+
+ # Compute priorities
+ if isinstance(user_priority, Number):
+ user_priority = {k: user_priority for k in tasks}
+
+ annotations = annotations or {}
+ restrictions = restrictions or {}
+ loose_restrictions = loose_restrictions or []
+ resources = resources or {}
+ retries = retries or {}
+
+ # Override existing taxonomy with per task annotations
+ if annotations:
+ if "priority" in annotations:
+ user_priority.update(annotations["priority"])
+
+ if "workers" in annotations:
+ restrictions.update(annotations["workers"])
+
+ if "allow_other_workers" in annotations:
+ loose_restrictions.extend(
+ k for k, v in annotations["allow_other_workers"].items() if v
+ )
- if "retries" in annotations:
- retries.update(annotations["retries"])
+ if "retries" in annotations:
+ retries.update(annotations["retries"])
+
+ if "resources" in annotations:
+ resources.update(annotations["resources"])
+
+ for a, kv in annotations.items():
+ for k, v in kv.items():
+ # Tasks might have been culled, in which case
+ # we have nothing to annotate.
+ ts = parent._tasks.get(k)
+ if ts is not None:
+ ts._annotations[a] = v
+
+ # Add actors
+ if actors is True:
+ actors = list(keys)
+ for actor in actors or []:
+ ts = parent._tasks[actor]
+ ts._actor = True
+
+ priority = priority or dask.order.order(
+ tasks
+ ) # TODO: define order wrt old graph
+
+ if submitting_task: # sub-tasks get better priority than parent tasks
+ ts = parent._tasks.get(submitting_task)
+ if ts is not None:
+ generation = ts._priority[0] - 0.01
+ else: # super-task already cleaned up
+ generation = self.generation
+ elif self._last_time + fifo_timeout < start:
+ self.generation += 1 # older graph generations take precedence
+ generation = self.generation
+ self._last_time = start
+ else:
+ generation = self.generation
- if "resources" in annotations:
- resources.update(annotations["resources"])
+ for key in set(priority) & touched_keys:
+ ts = parent._tasks[key]
+ if ts._priority is None:
+ ts._priority = (
+ -(user_priority.get(key, 0)),
+ generation,
+ priority[key],
+ )
- for a, kv in annotations.items():
- for k, v in kv.items():
- # Tasks might have been culled, in which case
- # we have nothing to annotate.
+ # Ensure all runnables have a priority
+ runnables = [ts for ts in touched_tasks if ts._run_spec]
+ for ts in runnables:
+ if ts._priority is None and ts._run_spec:
+ ts._priority = (self.generation, 0)
+
+ if restrictions:
+ # *restrictions* is a dict keying task ids to lists of
+ # restriction specifications (either worker names or addresses)
+ for k, v in restrictions.items():
+ if v is None:
+ continue
ts = parent._tasks.get(k)
- if ts is not None:
- ts._annotations[a] = v
-
- # Add actors
- if actors is True:
- actors = list(keys)
- for actor in actors or []:
- ts = parent._tasks[actor]
- ts._actor = True
-
- priority = priority or dask.order.order(
- tasks
- ) # TODO: define order wrt old graph
-
- if submitting_task: # sub-tasks get better priority than parent tasks
- ts = parent._tasks.get(submitting_task)
- if ts is not None:
- generation = ts._priority[0] - 0.01
- else: # super-task already cleaned up
- generation = self.generation
- elif self._last_time + fifo_timeout < start:
- self.generation += 1 # older graph generations take precedence
- generation = self.generation
- self._last_time = start
- else:
- generation = self.generation
-
- for key in set(priority) & touched_keys:
- ts = parent._tasks[key]
- if ts._priority is None:
- ts._priority = (-(user_priority.get(key, 0)), generation, priority[key])
-
- # Ensure all runnables have a priority
- runnables = [ts for ts in touched_tasks if ts._run_spec]
- for ts in runnables:
- if ts._priority is None and ts._run_spec:
- ts._priority = (self.generation, 0)
-
- if restrictions:
- # *restrictions* is a dict keying task ids to lists of
- # restriction specifications (either worker names or addresses)
- for k, v in restrictions.items():
- if v is None:
- continue
- ts = parent._tasks.get(k)
- if ts is None:
- continue
- ts._host_restrictions = set()
- ts._worker_restrictions = set()
- # Make sure `v` is a collection and not a single worker name / address
- if not isinstance(v, (list, tuple, set)):
- v = [v]
- for w in v:
- try:
- w = self.coerce_address(w)
- except ValueError:
- # Not a valid address, but perhaps it's a hostname
- ts._host_restrictions.add(w)
- else:
- ts._worker_restrictions.add(w)
+ if ts is None:
+ continue
+ ts._host_restrictions = set()
+ ts._worker_restrictions = set()
+ # Make sure `v` is a collection and not a single worker name / address
+ if not isinstance(v, (list, tuple, set)):
+ v = [v]
+ for w in v:
+ try:
+ w = self.coerce_address(w)
+ except ValueError:
+ # Not a valid address, but perhaps it's a hostname
+ ts._host_restrictions.add(w)
+ else:
+ ts._worker_restrictions.add(w)
- if loose_restrictions:
- for k in loose_restrictions:
- ts = parent._tasks[k]
- ts._loose_restrictions = True
+ if loose_restrictions:
+ for k in loose_restrictions:
+ ts = parent._tasks[k]
+ ts._loose_restrictions = True
- if resources:
- for k, v in resources.items():
- if v is None:
- continue
- assert isinstance(v, dict)
- ts = parent._tasks.get(k)
- if ts is None:
- continue
- ts._resource_restrictions = v
+ if resources:
+ for k, v in resources.items():
+ if v is None:
+ continue
+ assert isinstance(v, dict)
+ ts = parent._tasks.get(k)
+ if ts is None:
+ continue
+ ts._resource_restrictions = v
- if retries:
- for k, v in retries.items():
- assert isinstance(v, int)
- ts = parent._tasks.get(k)
- if ts is None:
- continue
- ts._retries = v
+ if retries:
+ for k, v in retries.items():
+ assert isinstance(v, int)
+ ts = parent._tasks.get(k)
+ if ts is None:
+ continue
+ ts._retries = v
- # Compute recommendations
- recommendations: dict = {}
+ # Compute recommendations
+ recommendations: dict = {}
- for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True):
- if ts._state == "released" and ts._run_spec:
- recommendations[ts._key] = "waiting"
+ for ts in sorted(
+ runnables, key=operator.attrgetter("priority"), reverse=True
+ ):
+ if ts._state == "released" and ts._run_spec:
+ recommendations[ts._key] = "waiting"
- for ts in touched_tasks:
- for dts in ts._dependencies:
- if dts._exception_blame:
- ts._exception_blame = dts._exception_blame
- recommendations[ts._key] = "erred"
- break
+ for ts in touched_tasks:
+ for dts in ts._dependencies:
+ if dts._exception_blame:
+ ts._exception_blame = dts._exception_blame
+ recommendations[ts._key] = "erred"
+ break
- for plugin in list(self.plugins.values()):
- try:
- plugin.update_graph(
- self,
- client=client,
- tasks=tasks,
- keys=keys,
- restrictions=restrictions or {},
- dependencies=dependencies,
- priority=priority,
- loose_restrictions=loose_restrictions,
- resources=resources,
- annotations=annotations,
- )
- except Exception as e:
- logger.exception(e)
+ for plugin in list(self.plugins.values()):
+ try:
+ plugin.update_graph(
+ self,
+ client=client,
+ tasks=tasks,
+ keys=keys,
+ restrictions=restrictions or {},
+ dependencies=dependencies,
+ priority=priority,
+ loose_restrictions=loose_restrictions,
+ resources=resources,
+ annotations=annotations,
+ )
+ except Exception as e:
+ logger.exception(e)
- self.transitions(recommendations)
+ self.transitions(recommendations)
- for ts in touched_tasks:
- if ts._state in ("memory", "erred"):
- self.report_on_key(ts=ts, client=client)
+ for ts in touched_tasks:
+ if ts._state in ("memory", "erred"):
+ self.report_on_key(ts=ts, client=client)
- end = time()
- if self.digests is not None:
- self.digests["update-graph-duration"].add(end - start)
+ end = time()
+ if self.digests is not None:
+ self.digests["update-graph-duration"].add(end - start)
- # TODO: balance workers
+ # TODO: balance workers
def stimulus_task_finished(self, key=None, worker=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
parent: SchedulerState = cast(SchedulerState, self)
+
logger.debug("Stimulus task finished %s, %s", key, worker)
recommendations: dict = {}
@@ -5003,7 +5042,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"already-released-or-forgotten-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
elif ts._state == "memory":
@@ -5018,10 +5057,16 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
return recommendations, client_msgs, worker_msgs
def stimulus_task_erred(
- self, key=None, worker=None, exception=None, traceback=None, **kwargs
+ self,
+ key=None,
+ worker=None,
+ exception=None,
+ traceback=None,
+ **kwargs,
):
"""Mark that a task has erred on a particular worker"""
parent: SchedulerState = cast(SchedulerState, self)
+
logger.debug("Stimulus task erred %s, %s", key, worker)
ts: TaskState = parent._tasks.get(key)
@@ -5042,37 +5087,41 @@ def stimulus_task_erred(
**kwargs,
)
- def stimulus_retry(self, keys, client=None):
+ def stimulus_retry(self, keys, client=None, stimulus_id=None):
parent: SchedulerState = cast(SchedulerState, self)
- logger.info("Client %s requests to retry %d keys", client, len(keys))
- if client:
- self.log_event(client, {"action": "retry", "count": len(keys)})
- stack = list(keys)
- seen = set()
- roots = []
- ts: TaskState
- dts: TaskState
- while stack:
- key = stack.pop()
- seen.add(key)
- ts = parent._tasks[key]
- erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"]
- if erred_deps:
- stack.extend(erred_deps)
- else:
- roots.append(key)
+ with set_default_stimulus(stimulus_id):
+ logger.info("Client %s requests to retry %d keys", client, len(keys))
+ if client:
+ self.log_event(client, {"action": "retry", "count": len(keys)})
- recommendations: dict = {key: "waiting" for key in roots}
- self.transitions(recommendations)
+ stack = list(keys)
+ seen = set()
+ roots = []
+ ts: TaskState
+ dts: TaskState
+ while stack:
+ key = stack.pop()
+ seen.add(key)
+ ts = parent._tasks[key]
+ erred_deps = [
+ dts._key for dts in ts._dependencies if dts._state == "erred"
+ ]
+ if erred_deps:
+ stack.extend(erred_deps)
+ else:
+ roots.append(key)
+
+ recommendations: dict = {key: "waiting" for key in roots}
+ self.transitions(recommendations)
- if parent._validate:
- for key in seen:
- assert not parent._tasks[key].exception_blame
+ if parent._validate:
+ for key in seen:
+ assert not parent._tasks[key].exception_blame
- return tuple(seen)
+ return tuple(seen)
- async def remove_worker(self, address, safe=False, close=True):
+ async def remove_worker(self, address, safe=False, close=True, stimulus_id=None):
"""
Remove worker from cluster
@@ -5081,7 +5130,9 @@ async def remove_worker(self, address, safe=False, close=True):
state.
"""
parent: SchedulerState = cast(SchedulerState, self)
- with log_errors():
+ with log_errors(), set_default_stimulus(
+ stimulus_id or f"remove-worker-{time()}"
+ ):
if self.status == Status.closed:
return
@@ -5188,70 +5239,80 @@ def remove_worker_from_events():
return "OK"
- def stimulus_cancel(self, comm, keys=None, client=None, force=False):
+ def stimulus_cancel(
+ self, comm, keys=None, client=None, force=False, stimulus_id=None
+ ):
"""Stop execution on a list of keys"""
- logger.info("Client %s requests to cancel %d keys", client, len(keys))
- if client:
- self.log_event(
- client, {"action": "cancel", "count": len(keys), "force": force}
- )
- for key in keys:
- self.cancel_key(key, client, force=force)
+ with set_default_stimulus(stimulus_id):
+ logger.info("Client %s requests to cancel %d keys", client, len(keys))
+ if client:
+ self.log_event(
+ client, {"action": "cancel", "count": len(keys), "force": force}
+ )
+ for key in keys:
+ self.cancel_key(key, client, force=force)
- def cancel_key(self, key, client, retries=5, force=False):
+ def cancel_key(self, key, client, retries=5, force=False, stimulus_id=None):
"""Cancel a particular key and all dependents"""
# TODO: this should be converted to use the transition mechanism
parent: SchedulerState = cast(SchedulerState, self)
- ts: TaskState = parent._tasks.get(key)
- dts: TaskState
- try:
- cs: ClientState = parent._clients[client]
- except KeyError:
- return
- if ts is None or not ts._who_wants: # no key yet, lets try again in a moment
- if retries:
- self.loop.call_later(
- 0.2, lambda: self.cancel_key(key, client, retries - 1)
- )
- return
- if force or ts._who_wants == {cs}: # no one else wants this key
- for dts in list(ts._dependents):
- self.cancel_key(dts._key, client, force=force)
- logger.info("Scheduler cancels key %s. Force=%s", key, force)
- self.report({"op": "cancelled-key", "key": key})
- clients = list(ts._who_wants) if force else [cs]
- for cs in clients:
- self.client_releases_keys(keys=[key], client=cs._client_key)
-
- def client_desires_keys(self, keys=None, client=None):
+ with set_default_stimulus(stimulus_id or f"cancel-key-{time()}"):
+ ts: TaskState = parent._tasks.get(key)
+ dts: TaskState
+ try:
+ cs: ClientState = parent._clients[client]
+ except KeyError:
+ return
+ if (
+ ts is None or not ts._who_wants
+ ): # no key yet, lets try again in a moment
+ if retries:
+ self.loop.call_later(
+ 0.2, lambda: self.cancel_key(key, client, retries - 1)
+ )
+ return
+ if force or ts._who_wants == {cs}: # no one else wants this key
+ for dts in list(ts._dependents):
+ self.cancel_key(dts._key, client, force=force)
+ logger.info("Scheduler cancels key %s. Force=%s", key, force)
+ self.report({"op": "cancelled-key", "key": key})
+ clients = list(ts._who_wants) if force else [cs]
+ for cs in clients:
+ self.client_releases_keys(keys=[key], client=cs._client_key)
+
+ def client_desires_keys(self, keys=None, client=None, stimulus_id=None):
parent: SchedulerState = cast(SchedulerState, self)
- cs: ClientState = parent._clients.get(client)
- if cs is None:
- # For publish, queues etc.
- parent._clients[client] = cs = ClientState(client)
- ts: TaskState
- for k in keys:
- ts = parent._tasks.get(k)
- if ts is None:
+ with set_default_stimulus(stimulus_id or f"client-desires-keys-{time()}"):
+ cs: ClientState = parent._clients.get(client)
+ if cs is None:
# For publish, queues etc.
- ts = parent.new_task(k, None, "released")
- ts._who_wants.add(cs)
- cs._wants_what.add(ts)
+ parent._clients[client] = cs = ClientState(client)
+ ts: TaskState
+ for k in keys:
+ ts = parent._tasks.get(k)
+ if ts is None:
+ # For publish, queues etc.
+ ts = parent.new_task(k, None, "released")
+ ts._who_wants.add(cs)
+ cs._wants_what.add(ts)
- if ts._state in ("memory", "erred"):
- self.report_on_key(ts=ts, client=client)
+ if ts._state in ("memory", "erred"):
+ self.report_on_key(ts=ts, client=client)
- def client_releases_keys(self, keys=None, client=None):
+ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
"""Remove keys from client desired list"""
parent: SchedulerState = cast(SchedulerState, self)
- if not isinstance(keys, list):
- keys = list(keys)
- cs: ClientState = parent._clients[client]
- recommendations: dict = {}
+ with set_default_stimulus(stimulus_id or f"client-releases-keys-{time()}"):
+ if not isinstance(keys, list):
+ keys = list(keys)
+ cs: ClientState = parent._clients[client]
+ recommendations: dict = {}
- _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations)
- self.transitions(recommendations)
+ _client_releases_keys(
+ parent, keys=keys, cs=cs, recommendations=recommendations
+ )
+ self.transitions(recommendations)
def client_heartbeat(self, client=None):
"""Handle heartbeats from Client"""
@@ -5468,6 +5529,17 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None:
We listen to all future messages from this Comm.
"""
parent: SchedulerState = cast(SchedulerState, self)
+
+ if self._validate:
+ try:
+ stimulus_id = STIMULUS_ID.get()
+ raise AssertionError(
+ f"STIMULUS_ID {stimulus_id} set in Scheduler.add_client"
+ )
+
+ except LookupError:
+ pass
+
assert client is not None
comm.name = "Scheduler->Client"
logger.info("Receive client connection: %s", client)
@@ -5562,34 +5634,40 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1):
def handle_uncaught_error(self, **msg):
logger.exception(clean_exception(**msg)[1])
- def handle_task_finished(self, key=None, worker=None, **msg):
+ def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg):
parent: SchedulerState = cast(SchedulerState, self)
- if worker not in parent._workers_dv:
- return
- validate_key(key)
- recommendations: dict
- client_msgs: dict
- worker_msgs: dict
+ with set_default_stimulus(stimulus_id):
+ if worker not in parent._workers_dv:
+ return
+ validate_key(key)
- r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
- recommendations, client_msgs, worker_msgs = r
- parent._transitions(recommendations, client_msgs, worker_msgs)
+ recommendations: dict
+ client_msgs: dict
+ worker_msgs: dict
- self.send_all(client_msgs, worker_msgs)
+ r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
+ recommendations, client_msgs, worker_msgs = r
+ parent._transitions(recommendations, client_msgs, worker_msgs)
- def handle_task_erred(self, key=None, **msg):
+ self.send_all(client_msgs, worker_msgs)
+
+ def handle_task_erred(self, key=None, stimulus_id=None, **msg):
parent: SchedulerState = cast(SchedulerState, self)
recommendations: dict
client_msgs: dict
worker_msgs: dict
- r: tuple = self.stimulus_task_erred(key=key, **msg)
- recommendations, client_msgs, worker_msgs = r
- parent._transitions(recommendations, client_msgs, worker_msgs)
- self.send_all(client_msgs, worker_msgs)
+ with set_default_stimulus(stimulus_id):
+ r: tuple = self.stimulus_task_erred(key=key, **msg)
+ recommendations, client_msgs, worker_msgs = r
+ parent._transitions(recommendations, client_msgs, worker_msgs)
+
+ self.send_all(client_msgs, worker_msgs)
- def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
+ def handle_missing_data(
+ self, key=None, errant_worker=None, stimulus_id=None, **kwargs
+ ):
"""Signal that `errant_worker` does not hold `key`
This may either indicate that `errant_worker` is dead or that we may be
@@ -5605,115 +5683,132 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
Task key that could not be found, by default None
errant_worker : str, optional
Address of the worker supposed to hold a replica, by default None
+ stimulus_id : str, optional
+ Stimulus ID that generated this function call.
"""
parent: SchedulerState = cast(SchedulerState, self)
logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
- self.log_event(errant_worker, {"action": "missing-data", "key": key})
- ts: TaskState = parent._tasks.get(key)
- if ts is None:
- return
- ws: WorkerState = parent._workers_dv.get(errant_worker)
- if ws is not None and ws in ts._who_has:
- parent.remove_replica(ts, ws)
- if ts.state == "memory" and not ts._who_has:
- if ts._run_spec:
- self.transitions({key: "released"})
- else:
- self.transitions({key: "forgotten"})
+ with set_default_stimulus(stimulus_id):
+ self.log_event(errant_worker, {"action": "missing-data", "key": key})
+ ts: TaskState = parent._tasks.get(key)
+ if ts is None:
+ return
+ ws: WorkerState = parent._workers_dv.get(errant_worker)
+
+ if ws is not None and ws in ts._who_has:
+ parent.remove_replica(ts, ws)
+ if ts.state == "memory" and not ts._who_has:
+ if ts._run_spec:
+ self.transitions({key: "released"})
+ else:
+ self.transitions({key: "forgotten"})
- def release_worker_data(self, key, worker):
+ def release_worker_data(self, key, worker, stimulus_id=None):
parent: SchedulerState = cast(SchedulerState, self)
- ws: WorkerState = parent._workers_dv.get(worker)
- ts: TaskState = parent._tasks.get(key)
- if not ws or not ts:
- return
- recommendations: dict = {}
- if ws in ts._who_has:
- parent.remove_replica(ts, ws)
- if not ts._who_has:
- recommendations[ts._key] = "released"
- if recommendations:
- self.transitions(recommendations)
+ with set_default_stimulus(stimulus_id or f"release-worker-data-{time()}"):
+ ws: WorkerState = parent._workers_dv.get(worker)
+ ts: TaskState = parent._tasks.get(key)
+ if not ws or not ts:
+ return
+ recommendations: dict = {}
+ if ws in ts._who_has:
+ parent.remove_replica(ts, ws)
+ if not ts._who_has:
+ recommendations[ts._key] = "released"
+ if recommendations:
+ self.transitions(recommendations)
- def handle_long_running(self, key=None, worker=None, compute_duration=None):
+ def handle_long_running(
+ self, key=None, worker=None, compute_duration=None, stimulus_id=None
+ ):
"""A task has seceded from the thread pool
We stop the task from being stolen in the future, and change task
duration accounting as if the task has stopped.
"""
parent: SchedulerState = cast(SchedulerState, self)
- if key not in parent._tasks:
- logger.debug("Skipping long_running since key %s was already released", key)
- return
- ts: TaskState = parent._tasks[key]
- steal = parent._extensions.get("stealing")
- if steal is not None:
- steal.remove_key_from_stealable(ts)
- ws: WorkerState = ts._processing_on
- if ws is None:
- logger.debug("Received long-running signal from duplicate task. Ignoring.")
- return
+ with set_default_stimulus(stimulus_id):
+ if key not in parent._tasks:
+ logger.debug(
+ "Skipping long_running since key %s was already released", key
+ )
+ return
+ ts: TaskState = parent._tasks[key]
+ steal = parent._extensions.get("stealing")
+ if steal is not None:
+ steal.remove_key_from_stealable(ts)
- if compute_duration:
- old_duration: double = ts._prefix._duration_average
- new_duration: double = compute_duration
- avg_duration: double
- if old_duration < 0:
- avg_duration = new_duration
- else:
- avg_duration = 0.5 * old_duration + 0.5 * new_duration
-
- ts._prefix._duration_average = avg_duration
-
- occ: double = ws._processing[ts]
- ws._occupancy -= occ
- parent._total_occupancy -= occ
- # Cannot remove from processing since we're using this for things like
- # idleness detection. Idle workers are typically targeted for
- # downscaling but we should not downscale workers with long running
- # tasks
- ws._processing[ts] = 0
- ws._long_running.add(ts)
- self.check_idle_saturated(ws)
+ ws: WorkerState = ts._processing_on
+ if ws is None:
+ logger.debug(
+ "Received long-running signal from duplicate task. Ignoring."
+ )
+ return
+
+ if compute_duration:
+ old_duration: double = ts._prefix._duration_average
+ new_duration: double = compute_duration
+ avg_duration: double
+ if old_duration < 0:
+ avg_duration = new_duration
+ else:
+ avg_duration = 0.5 * old_duration + 0.5 * new_duration
+
+ ts._prefix._duration_average = avg_duration
+
+ occ: double = ws._processing[ts]
+ ws._occupancy -= occ
+ parent._total_occupancy -= occ
+ # Cannot remove from processing since we're using this for things like
+ # idleness detection. Idle workers are typically targeted for
+ # downscaling but we should not downscale workers with long running
+ # tasks
+ ws._processing[ts] = 0
+ ws._long_running.add(ts)
+ self.check_idle_saturated(ws)
- def handle_worker_status_change(self, status: str, worker: str) -> None:
+ def handle_worker_status_change(
+ self, status: str, worker: str, stimulus_id: str
+ ) -> None:
parent: SchedulerState = cast(SchedulerState, self)
- ws: WorkerState = parent._workers_dv.get(worker) # type: ignore
- if not ws:
- return
- prev_status = ws._status
- ws._status = Status.lookup[status] # type: ignore
- if ws._status == prev_status:
- return
- self.log_event(
- ws._address,
- {
- "action": "worker-status-change",
- "prev-status": prev_status.name,
- "status": status,
- },
- )
+ with set_default_stimulus(stimulus_id):
+ ws: WorkerState = parent._workers_dv.get(worker) # type: ignore
+ if not ws:
+ return
+ prev_status = ws._status
+ ws._status = Status.lookup[status] # type: ignore
+ if ws._status == prev_status:
+ return
- if ws._status == Status.running:
- parent._running.add(ws)
+ self.log_event(
+ ws._address,
+ {
+ "action": "worker-status-change",
+ "prev-status": prev_status.name,
+ "status": status,
+ },
+ )
- recs = {}
- ts: TaskState
- for ts in parent._unrunnable:
- valid: set = self.valid_workers(ts)
- if valid is None or ws in valid:
- recs[ts._key] = "waiting"
- if recs:
- client_msgs: dict = {}
- worker_msgs: dict = {}
- parent._transitions(recs, client_msgs, worker_msgs)
- self.send_all(client_msgs, worker_msgs)
+ if ws._status == Status.running:
+ parent._running.add(ws)
- else:
- parent._running.discard(ws)
+ recs = {}
+ ts: TaskState
+ for ts in parent._unrunnable:
+ valid: set = self.valid_workers(ts)
+ if valid is None or ws in valid:
+ recs[ts._key] = "waiting"
+ if recs:
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+ parent._transitions(recs, client_msgs, worker_msgs)
+ self.send_all(client_msgs, worker_msgs)
+
+ else:
+ parent._running.discard(ws)
async def handle_worker(self, comm=None, worker=None):
"""
@@ -5725,6 +5820,7 @@ async def handle_worker(self, comm=None, worker=None):
--------
Scheduler.handle_client: Equivalent coroutine for clients
"""
+
comm.name = "Scheduler connection to worker"
worker_comm = self.stream_comms[worker]
worker_comm.start(comm)
@@ -5931,6 +6027,7 @@ async def scatter(
client=None,
broadcast=False,
timeout=2,
+ stimulus_id=None,
):
"""Send data out to workers
@@ -5941,105 +6038,109 @@ async def scatter(
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState
- start = time()
- while True:
- if workers is None:
- wss = parent._running
- else:
- workers = [self.coerce_address(w) for w in workers]
- wss = {parent._workers_dv[w] for w in workers}
- wss = {ws for ws in wss if ws._status == Status.running}
+ with set_default_stimulus(stimulus_id or f"scatter-{time()}"):
+ start = time()
+ while True:
+ if workers is None:
+ wss = parent._running
+ else:
+ workers = [self.coerce_address(w) for w in workers]
+ wss = {parent._workers_dv[w] for w in workers}
+ wss = {ws for ws in wss if ws._status == Status.running}
- if wss:
- break
- if time() > start + timeout:
- raise TimeoutError("No valid workers found")
- await asyncio.sleep(0.1)
+ if wss:
+ break
+ if time() > start + timeout:
+ raise TimeoutError("No valid workers found")
+ await asyncio.sleep(0.1)
- nthreads = {ws._address: ws.nthreads for ws in wss}
+ nthreads = {ws._address: ws.nthreads for ws in wss}
- assert isinstance(data, dict)
+ assert isinstance(data, dict)
- keys, who_has, nbytes = await scatter_to_workers(
- nthreads, data, rpc=self.rpc, report=False
- )
+ keys, who_has, nbytes = await scatter_to_workers(
+ nthreads, data, rpc=self.rpc, report=False
+ )
- self.update_data(who_has=who_has, nbytes=nbytes, client=client)
+ self.update_data(who_has=who_has, nbytes=nbytes, client=client)
- if broadcast:
- n = len(nthreads) if broadcast is True else broadcast
- await self.replicate(keys=keys, workers=workers, n=n)
+ if broadcast:
+ n = len(nthreads) if broadcast is True else broadcast
+ await self.replicate(keys=keys, workers=workers, n=n)
- self.log_event(
- [client, "all"], {"action": "scatter", "client": client, "count": len(data)}
- )
- return keys
+ self.log_event(
+ [client, "all"],
+ {"action": "scatter", "client": client, "count": len(data)},
+ )
+ return keys
- async def gather(self, keys, serializers=None):
+ async def gather(self, keys, serializers=None, stimulus_id=None):
"""Collect data from workers to the scheduler"""
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState
- keys = list(keys)
- who_has = {}
- for key in keys:
- ts: TaskState = parent._tasks.get(key)
- if ts is not None:
- who_has[key] = [ws._address for ws in ts._who_has]
- else:
- who_has[key] = []
- data, missing_keys, missing_workers = await gather_from_workers(
- who_has, rpc=self.rpc, close=False, serializers=serializers
- )
- if not missing_keys:
- result = {"status": "OK", "data": data}
- else:
- missing_states = [
- (parent._tasks[key].state if key in parent._tasks else None)
- for key in missing_keys
- ]
- logger.exception(
- "Couldn't gather keys %s state: %s workers: %s",
- missing_keys,
- missing_states,
- missing_workers,
+ with set_default_stimulus(stimulus_id or f"gather-{time()}"):
+ keys = list(keys)
+ who_has = {}
+ for key in keys:
+ ts: TaskState = parent._tasks.get(key)
+ if ts is not None:
+ who_has[key] = [ws._address for ws in ts._who_has]
+ else:
+ who_has[key] = []
+
+ data, missing_keys, missing_workers = await gather_from_workers(
+ who_has, rpc=self.rpc, close=False, serializers=serializers
)
- result = {"status": "error", "keys": missing_keys}
- with log_errors():
- # Remove suspicious workers from the scheduler but allow them to
- # reconnect.
- await asyncio.gather(
- *(
- self.remove_worker(address=worker, close=False)
- for worker in missing_workers
- )
+ if not missing_keys:
+ result = {"status": "OK", "data": data}
+ else:
+ missing_states = [
+ (parent._tasks[key].state if key in parent._tasks else None)
+ for key in missing_keys
+ ]
+ logger.exception(
+ "Couldn't gather keys %s state: %s workers: %s",
+ missing_keys,
+ missing_states,
+ missing_workers,
)
- recommendations: dict
- client_msgs: dict = {}
- worker_msgs: dict = {}
- for key, workers in missing_keys.items():
- # Task may already be gone if it was held by a
- # `missing_worker`
- ts: TaskState = parent._tasks.get(key)
- logger.exception(
- "Workers don't have promised key: %s, %s",
- str(workers),
- str(key),
+ result = {"status": "error", "keys": missing_keys}
+ with log_errors():
+ # Remove suspicious workers from the scheduler but allow them to
+ # reconnect.
+ await asyncio.gather(
+ *(
+ self.remove_worker(address=worker, close=False)
+ for worker in missing_workers
+ )
)
- if not workers or ts is None:
- continue
- recommendations: dict = {key: "released"}
- for worker in workers:
- ws = parent._workers_dv.get(worker)
- if ws is not None and ws in ts._who_has:
- parent.remove_replica(ts, ws)
- parent._transitions(
- recommendations, client_msgs, worker_msgs
- )
- self.send_all(client_msgs, worker_msgs)
-
- self.log_event("all", {"action": "gather", "count": len(keys)})
- return result
+ recommendations: dict
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+ for key, workers in missing_keys.items():
+ # Task may already be gone if it was held by a
+ # `missing_worker`
+ ts: TaskState = parent._tasks.get(key)
+ logger.exception(
+ "Workers don't have promised key: %s, %s",
+ str(workers),
+ str(key),
+ )
+ if not workers or ts is None:
+ continue
+ recommendations: dict = {key: "released"}
+ for worker in workers:
+ ws = parent._workers_dv.get(worker)
+ if ws is not None and ws in ts._who_has:
+ parent.remove_replica(ts, ws)
+ parent._transitions(
+ recommendations, client_msgs, worker_msgs
+ )
+ self.send_all(client_msgs, worker_msgs)
+
+ self.log_event("all", {"action": "gather", "count": len(keys)})
+ return result
def clear_task_state(self):
# XXX what about nested state such as ClientState.wants_what
@@ -6048,10 +6149,10 @@ def clear_task_state(self):
for collection in self._task_state_collections:
collection.clear()
- async def restart(self, client=None, timeout=30):
+ async def restart(self, client=None, timeout=30, stimulus_id=None):
"""Restart all workers. Reset local state."""
parent: SchedulerState = cast(SchedulerState, self)
- with log_errors():
+ with log_errors(), set_default_stimulus(stimulus_id or f"restart-{time()}"):
n_workers = len(parent._workers_dv)
@@ -6265,7 +6366,9 @@ async def gather_on_worker(
return keys_failed
async def delete_worker_data(
- self, worker_address: str, keys: "Collection[str]"
+ self,
+ worker_address: str,
+ keys: "Collection[str]",
) -> None:
"""Delete data from a worker and update the corresponding worker/task states
@@ -6278,41 +6381,44 @@ async def delete_worker_data(
"""
parent: SchedulerState = cast(SchedulerState, self)
- try:
- await retry_operation(
- self.rpc(addr=worker_address).free_keys,
- keys=list(keys),
- stimulus_id=f"delete-data-{time()}",
- )
- except OSError as e:
- # This can happen e.g. if the worker is going through controlled shutdown;
- # it doesn't necessarily mean that it went unexpectedly missing
- logger.warning(
- f"Communication with worker {worker_address} failed during "
- f"replication: {e.__class__.__name__}: {e}"
- )
- return
+ with set_default_stimulus(f"delete-worker-data-{time()}"):
- ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore
- if ws is None:
- return
+ try:
+ await retry_operation(
+ self.rpc(addr=worker_address).free_keys,
+ keys=list(keys),
+ stimulus_id=f"delete-data-{time()}",
+ )
+ except OSError as e:
+ # This can happen e.g. if the worker is going through controlled shutdown;
+ # it doesn't necessarily mean that it went unexpectedly missing
+ logger.warning(
+ f"Communication with worker {worker_address} failed during "
+ f"replication: {e.__class__.__name__}: {e}"
+ )
+ return
- for key in keys:
- ts: TaskState = parent._tasks.get(key) # type: ignore
- if ts is not None and ws in ts._who_has:
- assert ts._state == "memory"
- parent.remove_replica(ts, ws)
- if not ts._who_has:
- # Last copy deleted
- self.transitions({key: "released"})
+ ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore
+ if ws is None:
+ return
+
+ for key in keys:
+ ts: TaskState = parent._tasks.get(key) # type: ignore
+ if ts is not None and ws in ts._who_has:
+ assert ts._state == "memory"
+ parent.remove_replica(ts, ws)
+ if not ts._who_has:
+ # Last copy deleted
+ self.transitions({key: "released"})
- self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys})
+ self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys})
async def rebalance(
self,
comm=None,
keys: "Iterable[Hashable]" = None,
workers: "Iterable[str]" = None,
+ stimulus_id=None,
) -> dict:
"""Rebalance keys so that each worker ends up with roughly the same process
memory (managed+unmanaged).
@@ -6378,10 +6484,11 @@ async def rebalance(
allowlist of workers addresses to be considered as senders or recipients.
All other workers will be ignored. The mean cluster occupancy will be
calculated only using the allowed workers.
+ stimulus_id: str, optional
+ Stimulus ID that caused this function call
"""
parent: SchedulerState = cast(SchedulerState, self)
-
- with log_errors():
+ with log_errors(), set_default_stimulus(stimulus_id or f"rebalance-{time()}"):
wss: "Collection[WorkerState]"
if workers is not None:
wss = [parent._workers_dv[w] for w in workers]
@@ -6682,6 +6789,7 @@ async def replicate(
branching_factor=2,
delete=True,
lock=True,
+ stimulus_id=None,
):
"""Replicate data throughout cluster
@@ -6699,6 +6807,8 @@ async def replicate(
The larger the branching factor, the more data we copy in
a single step, but the more a given worker risks being
swamped by data requests.
+ stimulus_id: str, optional
+ Stimulus ID that caused this function call.
See also
--------
@@ -6709,86 +6819,89 @@ async def replicate(
wws: WorkerState
ts: TaskState
- assert branching_factor > 0
- async with self._lock if lock else empty_context:
- if workers is not None:
- workers = {parent._workers_dv[w] for w in self.workers_list(workers)}
- workers = {ws for ws in workers if ws._status == Status.running}
- else:
- workers = parent._running
-
- if n is None:
- n = len(workers)
- else:
- n = min(n, len(workers))
- if n == 0:
- raise ValueError("Can not use replicate to delete data")
-
- tasks = {parent._tasks[k] for k in keys}
- missing_data = [ts._key for ts in tasks if not ts._who_has]
- if missing_data:
- return {"status": "partial-fail", "keys": missing_data}
-
- # Delete extraneous data
- if delete:
- del_worker_tasks = defaultdict(set)
- for ts in tasks:
- del_candidates = tuple(ts._who_has & workers)
- if len(del_candidates) > n:
- for ws in random.sample(
- del_candidates, len(del_candidates) - n
- ):
- del_worker_tasks[ws].add(ts)
-
- # Note: this never raises exceptions
- await asyncio.gather(
- *[
- self.delete_worker_data(ws._address, [t.key for t in tasks])
- for ws, tasks in del_worker_tasks.items()
- ]
- )
+ with set_default_stimulus(stimulus_id or f"replicate-{time()}"):
+ assert branching_factor > 0
+ async with self._lock if lock else empty_context:
+ if workers is not None:
+ workers = {
+ parent._workers_dv[w] for w in self.workers_list(workers)
+ }
+ workers = {ws for ws in workers if ws._status == Status.running}
+ else:
+ workers = parent._running
- # Copy not-yet-filled data
- while tasks:
- gathers = defaultdict(dict)
- for ts in list(tasks):
- if ts._state == "forgotten":
- # task is no longer needed by any client or dependant task
- tasks.remove(ts)
- continue
- n_missing = n - len(ts._who_has & workers)
- if n_missing <= 0:
- # Already replicated enough
- tasks.remove(ts)
- continue
+ if n is None:
+ n = len(workers)
+ else:
+ n = min(n, len(workers))
+ if n == 0:
+ raise ValueError("Can not use replicate to delete data")
- count = min(n_missing, branching_factor * len(ts._who_has))
- assert count > 0
+ tasks = {parent._tasks[k] for k in keys}
+ missing_data = [ts._key for ts in tasks if not ts._who_has]
+ if missing_data:
+ return {"status": "partial-fail", "keys": missing_data}
- for ws in random.sample(tuple(workers - ts._who_has), count):
- gathers[ws._address][ts._key] = [
- wws._address for wws in ts._who_has
+ # Delete extraneous data
+ if delete:
+ del_worker_tasks = defaultdict(set)
+ for ts in tasks:
+ del_candidates = tuple(ts._who_has & workers)
+ if len(del_candidates) > n:
+ for ws in random.sample(
+ del_candidates, len(del_candidates) - n
+ ):
+ del_worker_tasks[ws].add(ts)
+
+ # Note: this never raises exceptions
+ await asyncio.gather(
+ *[
+ self.delete_worker_data(ws._address, [t.key for t in tasks])
+ for ws, tasks in del_worker_tasks.items()
]
+ )
- await asyncio.gather(
- *(
- # Note: this never raises exceptions
- self.gather_on_worker(w, who_has)
- for w, who_has in gathers.items()
+ # Copy not-yet-filled data
+ while tasks:
+ gathers = defaultdict(dict)
+ for ts in list(tasks):
+ if ts._state == "forgotten":
+                            # task is no longer needed by any client or dependent task
+ tasks.remove(ts)
+ continue
+ n_missing = n - len(ts._who_has & workers)
+ if n_missing <= 0:
+ # Already replicated enough
+ tasks.remove(ts)
+ continue
+
+ count = min(n_missing, branching_factor * len(ts._who_has))
+ assert count > 0
+
+ for ws in random.sample(tuple(workers - ts._who_has), count):
+ gathers[ws._address][ts._key] = [
+ wws._address for wws in ts._who_has
+ ]
+
+ await asyncio.gather(
+ *(
+ # Note: this never raises exceptions
+ self.gather_on_worker(w, who_has)
+ for w, who_has in gathers.items()
+ )
)
- )
- for r, v in gathers.items():
- self.log_event(r, {"action": "replicate-add", "who_has": v})
+ for r, v in gathers.items():
+ self.log_event(r, {"action": "replicate-add", "who_has": v})
- self.log_event(
- "all",
- {
- "action": "replicate",
- "workers": list(workers),
- "key-count": len(keys),
- "branching-factor": branching_factor,
- },
- )
+ self.log_event(
+ "all",
+ {
+ "action": "replicate",
+ "workers": list(workers),
+ "key-count": len(keys),
+ "branching-factor": branching_factor,
+ },
+ )
def workers_to_close(
self,
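To make the ``branching_factor`` trade-off above concrete: each gather round copies at most ``branching_factor * len(ts._who_has)`` new replicas per key, so with the default of 2 the replica count can at most roughly triple per round. A back-of-the-envelope sketch, not part of the patch (the helper name is made up):

    def rounds_to_replicate(n, branching_factor=2, start=1):
        # Illustrative only: upper bound on gather rounds needed to reach n
        # copies when each round adds at most branching_factor * current new
        # replicas, mirroring count = min(n_missing, branching_factor * len(ts._who_has)).
        copies, rounds = start, 0
        while copies < n:
            copies += branching_factor * copies
            rounds += 1
        return rounds

    assert rounds_to_replicate(3) == 1    # 1 -> 3
    assert rounds_to_replicate(9) == 2    # 1 -> 3 -> 9
    assert rounds_to_replicate(100) == 5  # 1 -> 3 -> 9 -> 27 -> 81 -> 243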
@@ -6934,6 +7047,7 @@ async def retire_workers(
names: "list | None" = None,
close_workers: bool = False,
remove: bool = True,
+ stimulus_id=None,
**kwargs,
) -> dict:
"""Gracefully retire workers from cluster
@@ -6954,6 +7068,8 @@ async def retire_workers(
remove: bool (defaults to True)
Whether or not to remove the worker metadata immediately or else
wait for the worker to contact us
+ stimulus_id: str, optional
+ Stimulus ID that caused this function call.
**kwargs: dict
Extra options to pass to workers_to_close to determine which
workers we should drop
@@ -6970,7 +7086,9 @@ async def retire_workers(
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState
ts: TaskState
- with log_errors():
+ with log_errors(), set_default_stimulus(
+ stimulus_id or f"retire-workers-{time()}"
+ ):
# This lock makes retire_workers, rebalance, and replicate mutually
# exclusive and will no longer be necessary once rebalance and replicate are
# migrated to the Active Memory Manager.
@@ -7026,7 +7144,11 @@ async def retire_workers(
ws.status = Status.closing_gracefully
self.running.discard(ws)
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": ws.status.name}
+ {
+ "op": "worker-status-change",
+ "status": ws.status.name,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
coros.append(
@@ -7071,7 +7193,11 @@ async def _track_retire_worker(
# conditions and we can wait for a scheduler->worker->scheduler
# round-trip.
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": prev_status.name}
+ {
+ "op": "worker-status-change",
+ "status": prev_status.name,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
return None, {}
@@ -7100,38 +7226,35 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None):
reasons. However, it is sent by workers from time to time.
"""
parent: SchedulerState = cast(SchedulerState, self)
- if worker not in parent._workers_dv:
- return "not found"
- ws: WorkerState = parent._workers_dv[worker]
- redundant_replicas = []
- for key in keys:
- ts: TaskState = parent._tasks.get(key)
- if ts is not None and ts._state == "memory":
- if ws not in ts._who_has:
- parent.add_replica(ts, ws)
- else:
- redundant_replicas.append(key)
+ with set_default_stimulus(stimulus_id or f"add-keys-{time()}"):
+ if worker not in parent._workers_dv:
+ return "not found"
+ ws: WorkerState = parent._workers_dv[worker]
+ redundant_replicas = []
+ for key in keys:
+ ts: TaskState = parent._tasks.get(key)
+ if ts is not None and ts._state == "memory":
+ if ws not in ts._who_has:
+ parent.add_replica(ts, ws)
+ else:
+ redundant_replicas.append(key)
- if redundant_replicas:
- if not stimulus_id:
- stimulus_id = f"redundant-replicas-{time()}"
- self.worker_send(
- worker,
- {
- "op": "remove-replicas",
- "keys": redundant_replicas,
- "stimulus_id": stimulus_id,
- },
- )
+ if redundant_replicas:
+ self.worker_send(
+ worker,
+ {
+ "op": "remove-replicas",
+ "keys": redundant_replicas,
+ "stimulus_id": STIMULUS_ID.get(
+ stimulus_id or f"redundant-replicas-{time()}"
+ ),
+ },
+ )
- return "OK"
+ return "OK"
def update_data(
- self,
- *,
- who_has: dict,
- nbytes: dict,
- client=None,
+ self, *, who_has: dict, nbytes: dict, client=None, stimulus_id=None
):
"""
Learn that new data has entered the network from an external source
@@ -7141,7 +7264,7 @@ def update_data(
Scheduler.mark_key_in_memory
"""
parent: SchedulerState = cast(SchedulerState, self)
- with log_errors():
+ with log_errors(), set_default_stimulus(stimulus_id or f"update-data-{time()}"):
who_has = {
k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items()
}
@@ -7607,7 +7730,7 @@ def story(self, *keys):
transition_story = story
- def reschedule(self, key=None, worker=None):
+ def reschedule(self, key=None, worker=None, stimulus_id=None):
"""Reschedule a task
Things may have shifted and this task may now be better suited to run
@@ -7615,19 +7738,20 @@ def reschedule(self, key=None, worker=None):
"""
parent: SchedulerState = cast(SchedulerState, self)
ts: TaskState
- try:
- ts = parent._tasks[key]
- except KeyError:
- logger.warning(
- "Attempting to reschedule task {}, which was not "
- "found on the scheduler. Aborting reschedule.".format(key)
- )
- return
- if ts._state != "processing":
- return
- if worker and ts._processing_on.address != worker:
- return
- self.transitions({key: "released"})
+ with set_default_stimulus(stimulus_id or f"reschedule-{key}"):
+ try:
+ ts = parent._tasks[key]
+ except KeyError:
+ logger.warning(
+ "Attempting to reschedule task {}, which was not "
+ "found on the scheduler. Aborting reschedule.".format(key)
+ )
+ return
+ if ts._state != "processing":
+ return
+ if worker and ts._processing_on.address != worker:
+ return
+ self.transitions({key: "released"})
#####################
# Utility functions #
@@ -8304,7 +8428,11 @@ def _add_to_memory(
@cfunc
@exceptval(check=False)
def _propagate_forgotten(
- state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict
+ state: SchedulerState,
+ ts: TaskState,
+ recommendations: dict,
+ worker_msgs: dict,
+ stimulus_id: str,
):
ts.state = "forgotten"
key: str = ts._key
@@ -8337,7 +8465,7 @@ def _propagate_forgotten(
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"propagate-forgotten-{time()}",
+ "stimulus_id": stimulus_id,
}
]
state.remove_all_replicas(ts)
@@ -8380,7 +8508,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
"key": ts._key,
"priority": ts._priority,
"duration": duration,
- "stimulus_id": f"compute-task-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
"who_has": {},
}
if ts._resource_restrictions:
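Since `_task_to_msg` now reads the stimulus from `STIMULUS_ID` instead of minting a fresh `compute-task-{time()}` id, every scheduler entry point that can lead to a compute-task message must install a stimulus first, otherwise `STIMULUS_ID.get()` raises `LookupError`. A minimal sketch of that calling convention, assuming the utilities added by this patch (the function name is hypothetical):

    from distributed.metrics import time
    from distributed.utils import STIMULUS_ID, set_default_stimulus

    def example_entry_point(stimulus_id=None):
        # Hypothetical handler, mirroring how rebalance/replicate/retire_workers
        # are wrapped in this patch: an explicit argument wins, otherwise a
        # default id tied to this entry point is installed for the call.
        with set_default_stimulus(stimulus_id or f"example-entry-{time()}"):
            # Anything reached from here (transitions, _task_to_msg, worker_send)
            # recovers the id via STIMULUS_ID.get() instead of a threaded argument.
            return STIMULUS_ID.get()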
diff --git a/distributed/stealing.py b/distributed/stealing.py
index 54ef0098c63..27a2e850ab7 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -17,7 +17,7 @@
from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.diagnostics.plugin import SchedulerPlugin
-from distributed.utils import log_errors, recursive_to_dict
+from distributed.utils import log_errors, recursive_to_dict, set_default_stimulus
# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
@@ -333,7 +333,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
self.scheduler.total_occupancy += d["thief_duration"]
self.put_key_in_stealable(ts)
- self.scheduler.send_task_to_worker(thief.address, ts)
+ with set_default_stimulus(stimulus_id):
+ self.scheduler.send_task_to_worker(thief.address, ts)
self.log(("confirm", *_log_msg))
else:
raise ValueError(f"Unexpected task state: {state}")
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 1a0df078376..af6239f24a8 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -6,7 +6,7 @@
from distributed.core import CommClosedError
from distributed.utils_test import (
_LockedCommPool,
- assert_worker_story,
+ assert_story,
gen_cluster,
inc,
slowinc,
@@ -82,7 +82,7 @@ def f(ev):
while "f1" in a.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
a.story("f1"),
[
("f1", "compute-task"),
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index c8b594c8467..ed4c334027a 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -4576,7 +4576,7 @@ def test_auto_normalize_collection_sync(c):
def assert_no_data_loss(scheduler):
- for key, start, finish, recommendations, _ in scheduler.transition_log:
+ for key, start, finish, recommendations, _, _ in scheduler.transition_log:
if start == "memory" and finish == "released":
for k, v in recommendations.items():
assert not (k == key and v == "waiting")
diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py
index b01cf2611ca..1762929d378 100644
--- a/distributed/tests/test_cluster_dump.py
+++ b/distributed/tests/test_cluster_dump.py
@@ -8,7 +8,7 @@
import distributed
from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state
-from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc
+from distributed.utils_test import assert_story, gen_cluster, gen_test, inc
@pytest.mark.parametrize(
@@ -140,7 +140,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path):
assert story.keys() == {"f1", "f2"}
for k, task_story in story.items():
- assert_worker_story(
+ assert_story(
task_story,
[
(k, "compute-task"),
diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py
index 0555ab20232..dfbe6a1a0cf 100644
--- a/distributed/tests/test_failed_workers.py
+++ b/distributed/tests/test_failed_workers.py
@@ -15,7 +15,7 @@
from distributed.compatibility import MACOS
from distributed.metrics import time
from distributed.scheduler import COMPILED
-from distributed.utils import CancelledError, sync
+from distributed.utils import CancelledError, set_default_stimulus, sync
from distributed.utils_test import (
captured_logger,
cluster,
@@ -492,7 +492,8 @@ async def test_forget_data_not_supposed_to_have(s, a, b):
ts = TaskState("key", state="flight")
a.tasks["key"] = ts
recommendations = {ts: ("memory", 123)}
- a.transitions(recommendations, stimulus_id="test")
+ with set_default_stimulus("test"):
+ a.transitions(recommendations)
assert a.data
while a.data:
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 2750868dc01..be6e1d11dca 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -37,6 +37,7 @@
from distributed.utils import TimeoutError
from distributed.utils_test import (
BrokenComm,
+ assert_story,
captured_logger,
cluster,
dec,
@@ -810,7 +811,7 @@ async def test_story(c, s, a, b):
story = s.story(x.key)
assert all(line in s.transition_log for line in story)
assert len(story) < len(s.transition_log)
- assert all(x.key == line[0] or x.key in line[-2] for line in story)
+ assert all(x.key == line[0] or x.key in line[3] for line in story)
assert len(s.story(x.key, y.key)) > len(story)
@@ -3253,7 +3254,7 @@ async def test_worker_reconnect_task_memory(c, s, a):
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3277,7 +3278,7 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a):
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3538,3 +3539,40 @@ async def test_repr(s, a):
repr(ws_b)
== f""
)
+
+
+@gen_cluster(client=True)
+async def test_stimuli(c, s, a, b):
+ f = c.submit(inc, 1)
+ key = f.key
+
+ await f
+ await c.close()
+
+ assert_story(
+ s.story(key),
+ [
+ (key, "released", "waiting", {key: "processing"}),
+ (key, "waiting", "processing", {}),
+ (key, "processing", "memory", {}),
+ (
+ key,
+ "memory",
+ "forgotten",
+ {},
+ ),
+ ],
+ )
+
+ stimuli = [
+ "client-update-graph-hlg",
+ "client-update-graph-hlg",
+ "task-finished",
+ "client-releases-keys",
+ ]
+
+ stories = s.story(key)
+ assert len(stories) == len(stimuli)
+
+ for stimulus_id, story in zip(stimuli, stories):
+ assert story[-2].startswith(stimulus_id), (story[-2], stimulus_id)
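The test changes in this file unpack `transition_log` entries into six fields and read the stimulus via `story[-2]`, which implies the entry layout sketched below; the helper and field names are only illustrative:

    # Entry layout inferred from the six-element unpacking and story[-2] above:
    # (key, start, finish, recommendations, stimulus_id, timestamp)
    def stimulus_ids_of(transition_log):
        return {stimulus_id for _, _, _, _, stimulus_id, _ in transition_log}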
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index dc481aae17f..c0d377acc01 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -21,7 +21,7 @@
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
- assert_worker_story,
+ assert_story,
check_process_leak,
cluster,
dump_cluster_state,
@@ -406,7 +406,7 @@ async def inner_test(c, s, a, b):
assert "workers" in state
-def test_assert_worker_story():
+def test_assert_story():
now = time()
story = [
("foo", "id1", now - 600),
@@ -414,38 +414,38 @@ def test_assert_worker_story():
("baz", {1: 2}, "id2", now),
]
# strict=False
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
- assert_worker_story(story, [])
- assert_worker_story(story, [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",)])
- assert_worker_story(story, [("baz", lambda d: d[1] == 2)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
+ assert_story(story, [])
+ assert_story(story, [("foo",)])
+ assert_story(story, [("foo",), ("bar",)])
+ assert_story(story, [("baz", lambda d: d[1] == 2)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo", "nomatch")])
+ assert_story(story, [("foo", "nomatch")])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz",)])
+ assert_story(story, [("baz",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", {1: 3})])
+ assert_story(story, [("baz", {1: 3})])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", lambda d: d[1] == 3)])
+ assert_story(story, [("baz", lambda d: d[1] == 3)])
with pytest.raises(KeyError): # Faulty lambda
- assert_worker_story(story, [("baz", lambda d: d[2] == 1)])
- assert_worker_story([], [])
- assert_worker_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("baz", lambda d: d[2] == 1)])
+ assert_story([], [])
+ assert_story([("foo", "id1", now)], [("foo",)])
with pytest.raises(AssertionError):
- assert_worker_story([], [("foo",)])
+ assert_story([], [("foo",)])
# strict=True
- assert_worker_story([], [], strict=True)
- assert_worker_story([("foo", "id1", now)], [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
+ assert_story([], [], strict=True)
+ assert_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",)], strict=True)
+ assert_story(story, [("foo",), ("bar",)], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True)
+ assert_story(story, [("foo",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [], strict=True)
+ assert_story(story, [], strict=True)
@pytest.mark.parametrize(
@@ -466,11 +466,11 @@ def test_assert_worker_story():
),
],
)
-def test_assert_worker_story_malformed_story(story_factory):
+def test_assert_story_malformed_story(story_factory):
# defer the calls to time() to when the test runs rather than collection
story = story_factory()
with pytest.raises(AssertionError, match="Malformed story event"):
- assert_worker_story(story, [])
+ assert_story(story, [])
@gen_cluster()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index aa8549b07ba..eae0e9831ff 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -4,6 +4,7 @@
import importlib
import logging
import os
+import re
import sys
import threading
import traceback
@@ -41,11 +42,11 @@
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
-from distributed.utils import TimeoutError
+from distributed.utils import TimeoutError, set_default_stimulus
from distributed.utils_test import (
TaskStateMetadataPlugin,
_LockedCommPool,
- assert_worker_story,
+ assert_story,
captured_logger,
dec,
div,
@@ -1403,7 +1404,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
"""Test that an in-memory key was transferred from worker w_from to worker w_to by
the Active Memory Manager and it was not recalculated on w_to
"""
- assert_worker_story(
+ assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
@@ -1687,7 +1688,12 @@ async def test_story_with_deps(c, s, a, b):
# Story now includes randomized stimulus_ids and timestamps.
stimulus_ids = {ev[-2] for ev in story}
- assert len(stimulus_ids) == 3, stimulus_ids
+ assert {sid[: re.search(r"\d", sid).start()] for sid in stimulus_ids} == {
+ "ensure-computing-",
+ "ensure-communicating-",
+ "task-finished-",
+ }
+
# This is a simple transition log
expected = [
("res", "compute-task"),
@@ -1697,7 +1703,7 @@ async def test_story_with_deps(c, s, a, b):
("res", "put-in-memory"),
("res", "executing", "memory", "memory", {}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
story = b.story("dep")
stimulus_ids = {ev[-2] for ev in story}
@@ -1712,7 +1718,7 @@ async def test_story_with_deps(c, s, a, b):
("dep", "put-in-memory"),
("dep", "flight", "memory", "memory", {"res": "ready"}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
@gen_cluster(client=True)
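The prefix check added to `test_story_with_deps` slices each id up to its first digit, collapsing randomized ids to their stable prefixes. A standalone illustration (helper name and sample ids are invented):

    import re

    def stimulus_prefix(sid):
        # Stable part of ids such as "task-finished-1652341234.56": everything
        # before the first digit.
        match = re.search(r"\d", sid)
        return sid[: match.start()] if match else sid

    assert stimulus_prefix("task-finished-1652341234.56") == "task-finished-"
    assert stimulus_prefix("ensure-communicating-17.0") == "ensure-communicating-"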
@@ -2541,7 +2547,9 @@ def __call__(self, *args, **kwargs):
ts = s.tasks[fut.key]
a.handle_steal_request(fut.key, stimulus_id="test")
- stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+
+ with set_default_stimulus("test"):
+ stealing_ext.scheduler.send_task_to_worker(b.address, ts)
fut2 = c.submit(inc, fut, workers=[a.address])
fut3 = c.submit(inc, fut2, workers=[a.address])
@@ -2666,7 +2674,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
# same communication channel
for fut in (futA, futB):
- assert_worker_story(
+ assert_story(
b.story(fut.key),
[
("gather-dependencies", a.address, {fut.key}),
@@ -2725,7 +2733,7 @@ def __getstate__(self):
assert await y == 123
story = await c.run(lambda dask_worker: dask_worker.story("x"))
- assert_worker_story(
+ assert_story(
story[b],
[
("x", "ensure-task-exists", "released"),
@@ -2902,7 +2910,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
await f2
- assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")])
+ assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0
@@ -2982,7 +2990,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b):
while b.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
b.story(ts),
[("f1", "missing", "released", "released", {"f1": "forgotten"})],
)
@@ -3094,7 +3102,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
("f1", "put-in-memory"),
("f1", "executing", "memory", "memory", {}),
]
- assert_worker_story(sum_story, expected_sum_story, strict=True)
+ assert_story(sum_story, expected_sum_story, strict=True)
@gen_cluster(client=True)
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 78597a37e67..214f6dba251 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -86,6 +86,7 @@ def test_unique_task_heap():
assert repr(heap) == ""
+@pytest.mark.xfail(reason="slots not compatible with default_factory")
@pytest.mark.parametrize(
"cls",
chain(
@@ -107,5 +108,9 @@ def test_slots(cls):
def test_sendmsg_to_dict():
# Arbitrary sample class
- smsg = ReleaseWorkerDataMsg(key="x")
- assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
+ smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test")
+ assert smsg.to_dict() == {
+ "op": "release-worker-data",
+ "key": "x",
+ "stimulus_id": "test",
+ }
diff --git a/distributed/utils.py b/distributed/utils.py
index cd0d93d08dd..4fe311e5bd2 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -23,13 +23,13 @@
from collections.abc import Collection, Container, KeysView, ValuesView
from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401
from contextlib import contextmanager, suppress
-from contextvars import ContextVar
+from contextvars import ContextVar, Token
from hashlib import md5
from importlib.util import cache_from_source
from time import sleep
from types import ModuleType
from typing import Any as AnyType
-from typing import ClassVar
+from typing import ClassVar, Iterator
import click
import tblib.pickling_support
@@ -67,6 +67,36 @@
no_default = "__no_default__"
+STIMULUS_ID: ContextVar[str] = ContextVar("STIMULUS_ID")
+
+
+@contextmanager
+def set_default_stimulus(name: str) -> Iterator[str]:
+    """Context manager that sets a default ``stimulus_id`` for the current context.
+
+    If a stimulus_id has already been set further up the call stack, it is left
+    untouched and yielded instead of ``name``.
+
+    Parameters
+    ----------
+    name : str
+        The stimulus_id to install if none is set yet.
+
+    Yields
+    ------
+    str
+        The stimulus_id in effect inside the block.
+    """
+ token: Token[str] | None
+ try:
+ stimulus_id = STIMULUS_ID.get()
+ except LookupError:
+ token = STIMULUS_ID.set(name)
+ stimulus_id = name
+ else:
+ token = None
+
+ try:
+ yield stimulus_id
+ finally:
+ if token:
+ STIMULUS_ID.reset(token)
+
def _initialize_mp_context():
method = dask.config.get("distributed.worker.multiprocessing-method")
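A minimal usage sketch of the context manager above, assuming the module as patched here; the id strings are invented for the example. The outermost caller wins and nested calls merely yield the existing value:

    from distributed.utils import STIMULUS_ID, set_default_stimulus

    with set_default_stimulus("retire-workers-123") as outer:
        assert outer == "retire-workers-123"
        with set_default_stimulus("rebalance-456") as inner:
            # Already set further up the stack: the existing id is kept and yielded.
            assert inner == "retire-workers-123"
            assert STIMULUS_ID.get() == "retire-workers-123"

    # Outside the outermost block the variable is unset again.
    try:
        STIMULUS_ID.get()
    except LookupError:
        pass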
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 42b58b30588..0ea9c9cb62b 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1878,7 +1878,7 @@ def xfail_ssl_issue5601():
raise
-def assert_worker_story(
+def assert_story(
story: list[tuple], expect: list[tuple], *, strict: bool = False
) -> None:
"""Test the output of ``Worker.story``
@@ -1944,7 +1944,7 @@ def assert_worker_story(
break
except StopIteration:
raise AssertionError(
- f"assert_worker_story({strict=}) failed\n"
+ f"assert_story({strict=}) failed\n"
f"story:\n{_format_story(story)}\n"
f"expect:\n{_format_story(expect)}"
) from None
diff --git a/distributed/worker.py b/distributed/worker.py
index 7a0628763dd..ceeb70f7477 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -25,6 +25,7 @@
)
from concurrent.futures import Executor
from contextlib import suppress
+from contextvars import copy_context
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
@@ -79,6 +80,7 @@
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
LRU,
+ STIMULUS_ID,
TimeoutError,
_maybe_complex,
get_ip,
@@ -92,6 +94,7 @@
offload,
parse_ports,
recursive_to_dict,
+ set_default_stimulus,
silence_logging,
thread_state,
warn_on_duration,
@@ -919,10 +922,11 @@ def status(self, value):
"""
prev_status = self.status
ServerNode.status.__set__(self, value)
- self._send_worker_status_change()
- if prev_status == Status.paused and value == Status.running:
- self.ensure_computing()
- self.ensure_communicating()
+ with set_default_stimulus(f"init-{time()}"):
+ self._send_worker_status_change()
+ if prev_status == Status.paused and value == Status.running:
+ self.ensure_computing()
+ self.ensure_communicating()
def _send_worker_status_change(self) -> None:
if (
@@ -931,7 +935,11 @@ def _send_worker_status_change(self) -> None:
and not self.batched_stream.comm.closed()
):
self.batched_stream.send(
- {"op": "worker-status-change", "status": self._status.name}
+ {
+ "op": "worker-status-change",
+ "status": self._status.name,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
elif self._status != Status.closed:
self.loop.call_later(0.05, self._send_worker_status_change)
@@ -1657,8 +1665,7 @@ def update_data(
report: bool = True,
stimulus_id: str = None,
) -> dict[str, Any]:
- if stimulus_id is None:
- stimulus_id = f"update-data-{time()}"
+ STIMULUS_ID.set(stimulus_id or f"update-data-{time()}")
recommendations: Recs = {}
instructions: Instructions = []
for key, value in data.items():
@@ -1669,19 +1676,19 @@ def update_data(
self.tasks[key] = ts = TaskState(key)
try:
- recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
+ recs = self._put_key_in_memory(ts, value)
except Exception as e:
msg = error_message(e)
recommendations = {ts: tuple(msg.values())}
else:
recommendations.update(recs)
- self.log.append((key, "receive-from-scatter", stimulus_id, time()))
+ self.log.append((key, "receive-from-scatter", STIMULUS_ID.get(), time()))
if report:
- instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id))
+ instructions.append(AddKeysMsg(keys=list(data)))
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.transitions(recommendations)
self._handle_instructions(instructions)
return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
@@ -1696,14 +1703,15 @@ def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None:
still decide to hold on to the data and task since it is required by an
upstream dependency.
"""
- self.log.append(("free-keys", keys, stimulus_id, time()))
- recommendations: Recs = {}
- for key in keys:
- ts = self.tasks.get(key)
- if ts:
- recommendations[ts] = "released"
+ with set_default_stimulus(stimulus_id):
+ self.log.append(("free-keys", keys, STIMULUS_ID.get(), time()))
+ recommendations: Recs = {}
+ for key in keys:
+ ts = self.tasks.get(key)
+ if ts:
+ recommendations[ts] = "released"
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.transitions(recommendations)
def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:
"""Stream handler notifying the worker that it might be holding unreferenced,
@@ -1724,28 +1732,32 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:
For stronger guarantees, see handler free_keys
"""
- self.log.append(("remove-replicas", keys, stimulus_id, time()))
- recommendations: Recs = {}
- rejected = []
- for key in keys:
- ts = self.tasks.get(key)
- if ts is None or ts.state != "memory":
- continue
- if not ts.is_protected():
+ with set_default_stimulus(stimulus_id):
+ self.log.append(("remove-replicas", keys, STIMULUS_ID.get(), time()))
+ recommendations: Recs = {}
+
+ rejected = []
+ for key in keys:
+ ts = self.tasks.get(key)
+ if ts is None or ts.state != "memory":
+ continue
+ if not ts.is_protected():
+ self.log.append(
+ (ts.key, "remove-replica-confirmed", STIMULUS_ID.get(), time())
+ )
+ recommendations[ts] = "released"
+ else:
+ rejected.append(key)
+
+ if rejected:
self.log.append(
- (ts.key, "remove-replica-confirmed", stimulus_id, time())
+ ("remove-replica-rejected", rejected, STIMULUS_ID.get(), time())
)
- recommendations[ts] = "released"
- else:
- rejected.append(key)
-
- if rejected:
- self.log.append(("remove-replica-rejected", rejected, stimulus_id, time()))
- smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id)
- self._handle_instructions([smsg])
+ smsg = AddKeysMsg(keys=rejected)
+ self._handle_instructions([smsg])
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.transitions(recommendations)
return "OK"
@@ -1773,7 +1785,7 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None:
is in state `waiting` or `ready`.
Nothing will happen otherwise.
"""
- self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id))
+ self.handle_stimulus(CancelComputeEvent(key=key))
def handle_acquire_replicas(
self,
@@ -1782,36 +1794,38 @@ def handle_acquire_replicas(
who_has: dict[str, Collection[str]],
stimulus_id: str,
) -> None:
- recommendations: Recs = {}
- for key in keys:
- ts = self.ensure_task_exists(
- key=key,
- # Transfer this data after all dependency tasks of computations with
- # default or explicitly high (>0) user priority and before all
- # computations with low priority (<0). Note that the priority= parameter
- # of compute() is multiplied by -1 before it reaches TaskState.priority.
- priority=(1,),
- stimulus_id=stimulus_id,
- )
- if ts.state != "memory":
- recommendations[ts] = "fetch"
+ with set_default_stimulus(stimulus_id):
+ recommendations: Recs = {}
+ for key in keys:
+ ts = self.ensure_task_exists(
+ key=key,
+ # Transfer this data after all dependency tasks of computations with
+ # default or explicitly high (>0) user priority and before all
+ # computations with low priority (<0). Note that the priority= parameter
+ # of compute() is multiplied by -1 before it reaches TaskState.priority.
+ priority=(1,),
+ )
+ if ts.state != "memory":
+ recommendations[ts] = "fetch"
- self.update_who_has(who_has)
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.update_who_has(who_has)
+ self.transitions(recommendations)
- def ensure_task_exists(
- self, key: str, *, priority: tuple[int, ...], stimulus_id: str
- ) -> TaskState:
+ def ensure_task_exists(self, key: str, *, priority: tuple[int, ...]) -> TaskState:
try:
ts = self.tasks[key]
- logger.debug("Data task %s already known (stimulus_id=%s)", ts, stimulus_id)
+ logger.debug(
+ "Data task %s already known (stimulus_id=%s)", ts, STIMULUS_ID.get()
+ )
except KeyError:
self.tasks[key] = ts = TaskState(key)
if not ts.priority:
assert priority
ts.priority = priority
- self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time()))
+ self.log.append(
+ (key, "ensure-task-exists", ts.state, STIMULUS_ID.get(), time())
+ )
return ts
def handle_compute_task(
@@ -1831,75 +1845,78 @@ def handle_compute_task(
annotations: dict | None = None,
stimulus_id: str,
) -> None:
- self.log.append((key, "compute-task", stimulus_id, time()))
- try:
- ts = self.tasks[key]
- logger.debug(
- "Asked to compute an already known task %s",
- {"task": ts, "stimulus_id": stimulus_id},
- )
- except KeyError:
- self.tasks[key] = ts = TaskState(key)
-
- ts.run_spec = SerializedTask(function, args, kwargs, task)
+ with set_default_stimulus(stimulus_id):
+ self.log.append((key, "compute-task", STIMULUS_ID.get(), time()))
+ try:
+ ts = self.tasks[key]
+ logger.debug(
+ "Asked to compute an already known task %s",
+ {"task": ts, "stimulus_id": STIMULUS_ID.get()},
+ )
+ except KeyError:
+ self.tasks[key] = ts = TaskState(key)
- assert isinstance(priority, tuple)
- priority = priority + (self.generation,)
- self.generation -= 1
+ ts.run_spec = SerializedTask(function, args, kwargs, task)
- if actor:
- self.actors[ts.key] = None
+ assert isinstance(priority, tuple)
+ priority = priority + (self.generation,)
+ self.generation -= 1
- ts.exception = None
- ts.traceback = None
- ts.exception_text = ""
- ts.traceback_text = ""
- ts.priority = priority
- ts.duration = duration
- if resource_restrictions:
- ts.resource_restrictions = resource_restrictions
- ts.annotations = annotations
+ if actor:
+ self.actors[ts.key] = None
- recommendations: Recs = {}
- instructions: Instructions = []
- for dependency in who_has:
- dep_ts = self.ensure_task_exists(
- key=dependency,
- priority=priority,
- stimulus_id=stimulus_id,
- )
+ ts.exception = None
+ ts.traceback = None
+ ts.exception_text = ""
+ ts.traceback_text = ""
+ ts.priority = priority
+ ts.duration = duration
+ if resource_restrictions:
+ ts.resource_restrictions = resource_restrictions
+ ts.annotations = annotations
+
+ recommendations: Recs = {}
+ instructions: Instructions = []
+ for dependency in who_has:
+ dep_ts = self.ensure_task_exists(
+ key=dependency,
+ priority=priority,
+ )
- # link up to child / parents
- ts.dependencies.add(dep_ts)
- dep_ts.dependents.add(ts)
+ # link up to child / parents
+ ts.dependencies.add(dep_ts)
+ dep_ts.dependents.add(ts)
- if ts.state in READY | {"executing", "waiting", "resumed"}:
- pass
- elif ts.state == "memory":
- recommendations[ts] = "memory"
- instructions.append(self._get_task_finished_msg(ts))
- elif ts.state in {
- "released",
- "fetch",
- "flight",
- "missing",
- "cancelled",
- "error",
- }:
- recommendations[ts] = "waiting"
- else: # pragma: no cover
- raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
+ if ts.state in READY | {"executing", "waiting", "resumed"}:
+ pass
+ elif ts.state == "memory":
+ recommendations[ts] = "memory"
+ instructions.append(self._get_task_finished_msg(ts))
+ elif ts.state in {
+ "released",
+ "fetch",
+ "flight",
+ "missing",
+ "cancelled",
+ "error",
+ }:
+ recommendations[ts] = "waiting"
+ else: # pragma: no cover
+ raise RuntimeError(
+ f"Unexpected task state encountered {ts} {STIMULUS_ID.get()}"
+ )
- self._handle_instructions(instructions)
- self.update_who_has(who_has)
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self._handle_instructions(instructions)
+ self.update_who_has(who_has)
+ self.transitions(recommendations)
- if nbytes is not None:
- for key, value in nbytes.items():
- self.tasks[key].nbytes = value
+ if nbytes is not None:
+ for key, value in nbytes.items():
+ self.tasks[key].nbytes = value
def transition_missing_fetch(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert ts.state == "missing"
@@ -1912,17 +1929,17 @@ def transition_missing_fetch(
return {}, []
def transition_missing_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
self._missing_dep_flight.discard(ts)
- recs, instructions = self.transition_generic_released(
- ts, stimulus_id=stimulus_id
- )
+ recs, instructions = self.transition_generic_released(ts)
assert ts.key in self.tasks
return recs, instructions
def transition_flight_missing(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
assert ts.done
ts.state = "missing"
@@ -1931,7 +1948,8 @@ def transition_flight_missing(
return {}, []
def transition_released_fetch(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert ts.state == "released"
@@ -1944,9 +1962,10 @@ def transition_released_fetch(
return {}, []
def transition_generic_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
- self.release_key(ts.key, stimulus_id=stimulus_id)
+ self.release_key(ts.key)
recs: Recs = {}
for dependency in ts.dependencies:
if (
@@ -1961,7 +1980,8 @@ def transition_generic_released(
return recs, []
def transition_released_waiting(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert ts.state == "released"
@@ -1987,7 +2007,9 @@ def transition_released_waiting(
return recommendations, []
def transition_fetch_flight(
- self, ts: TaskState, worker, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ worker,
) -> RecsInstrs:
if self.validate:
assert ts.state == "fetch"
@@ -2000,16 +2022,16 @@ def transition_fetch_flight(
return {}, []
def transition_memory_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
- recs, instructions = self.transition_generic_released(
- ts, stimulus_id=stimulus_id
- )
- instructions.append(ReleaseWorkerDataMsg(ts.key))
+ recs, instructions = self.transition_generic_released(ts)
+ instructions.append(ReleaseWorkerDataMsg(key=ts.key))
return recs, instructions
def transition_waiting_constrained(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert ts.state == "waiting"
@@ -2025,14 +2047,16 @@ def transition_waiting_constrained(
return {}, []
def transition_long_running_rescheduled(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
recs: Recs = {ts: "released"}
smsg = RescheduleMsg(key=ts.key, worker=self.address)
return recs, [smsg]
def transition_executing_rescheduled(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
@@ -2043,7 +2067,8 @@ def transition_executing_rescheduled(
return recs, [smsg]
def transition_waiting_ready(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert ts.state == "waiting"
@@ -2066,8 +2091,6 @@ def transition_cancelled_error(
traceback: Serialize | None,
exception_text: str,
traceback_text: str,
- *,
- stimulus_id: str,
) -> RecsInstrs:
recs: Recs = {}
instructions: Instructions = []
@@ -2078,7 +2101,6 @@ def transition_cancelled_error(
traceback,
exception_text,
traceback_text,
- stimulus_id=stimulus_id,
)
elif ts._previous == "flight":
recs, instructions = self.transition_flight_error(
@@ -2087,7 +2109,6 @@ def transition_cancelled_error(
traceback,
exception_text,
traceback_text,
- stimulus_id=stimulus_id,
)
if ts._next:
recs[ts] = ts._next
@@ -2100,8 +2121,6 @@ def transition_generic_error(
traceback: Serialize | None,
exception_text: str,
traceback_text: str,
- *,
- stimulus_id: str,
) -> RecsInstrs:
ts.exception = exception
ts.traceback = traceback
@@ -2127,8 +2146,6 @@ def transition_executing_error(
traceback: Serialize | None,
exception_text: str,
traceback_text: str,
- *,
- stimulus_id: str,
) -> RecsInstrs:
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
@@ -2139,11 +2156,12 @@ def transition_executing_error(
traceback,
exception_text,
traceback_text,
- stimulus_id=stimulus_id,
)
def _transition_from_resumed(
- self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ finish: TaskStateState,
) -> RecsInstrs:
"""`resumed` is an intermediate degenerate state which splits further up
into two states depending on what the last signal / next state is
@@ -2171,9 +2189,7 @@ def _transition_from_resumed(
# if the next state is already intended to be waiting or if the
# coro/thread is still running (ts.done==False), this is a noop
if ts._next != finish:
- recs, instructions = self.transition_generic_released(
- ts, stimulus_id=stimulus_id
- )
+ recs, instructions = self.transition_generic_released(ts)
assert next_state
recs[ts] = next_state
else:
@@ -2181,29 +2197,35 @@ def _transition_from_resumed(
return recs, instructions
def transition_resumed_fetch(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
"""
See Worker._transition_from_resumed
"""
- return self._transition_from_resumed(ts, "fetch", stimulus_id=stimulus_id)
+ return self._transition_from_resumed(ts, "fetch")
def transition_resumed_missing(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
"""
See Worker._transition_from_resumed
"""
- return self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id)
+ return self._transition_from_resumed(ts, "missing")
- def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str):
+ def transition_resumed_waiting(
+ self,
+ ts: TaskState,
+ ):
"""
See Worker._transition_from_resumed
"""
- return self._transition_from_resumed(ts, "waiting", stimulus_id=stimulus_id)
+ return self._transition_from_resumed(ts, "waiting")
def transition_cancelled_fetch(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if ts.done:
return {ts: "released"}, []
@@ -2215,14 +2237,17 @@ def transition_cancelled_fetch(
return {ts: ("resumed", "fetch")}, []
def transition_cancelled_resumed(
- self, ts: TaskState, next: TaskStateState, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ next: TaskStateState,
) -> RecsInstrs:
ts._next = next
ts.state = "resumed"
return {}, []
def transition_cancelled_waiting(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if ts.done:
return {ts: "released"}, []
@@ -2234,7 +2259,8 @@ def transition_cancelled_waiting(
return {ts: ("resumed", "waiting")}, []
def transition_cancelled_forgotten(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
ts._next = "forgotten"
if not ts.done:
@@ -2242,7 +2268,8 @@ def transition_cancelled_forgotten(
return {ts: "released"}, []
def transition_cancelled_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if not ts.done:
ts._next = "released"
@@ -2254,15 +2281,14 @@ def transition_cancelled_released(
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
- recs, instructions = self.transition_generic_released(
- ts, stimulus_id=stimulus_id
- )
+ recs, instructions = self.transition_generic_released(ts)
if next_state != "released":
recs[ts] = next_state
return recs, instructions
def transition_executing_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
ts._previous = ts.state
ts._next = "released"
@@ -2272,13 +2298,17 @@ def transition_executing_released(
return {}, []
def transition_long_running_memory(
- self, ts: TaskState, value=no_value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value=no_value,
) -> RecsInstrs:
self.executed_count += 1
- return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
+ return self.transition_generic_memory(ts, value=value)
def transition_generic_memory(
- self, ts: TaskState, value=no_value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value=no_value,
) -> RecsInstrs:
if value is no_value and ts.key not in self.data:
raise RuntimeError(
@@ -2293,7 +2323,7 @@ def transition_generic_memory(
self._in_flight_tasks.discard(ts)
ts.coming_from = None
try:
- recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
+ recs = self._put_key_in_memory(ts, value)
except Exception as e:
msg = error_message(e)
recs = {ts: tuple(msg.values())}
@@ -2304,7 +2334,9 @@ def transition_generic_memory(
return recs, [smsg]
def transition_executing_memory(
- self, ts: TaskState, value=no_value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value=no_value,
) -> RecsInstrs:
if self.validate:
assert ts.state == "executing" or ts.key in self.long_running
@@ -2313,10 +2345,11 @@ def transition_executing_memory(
self._executing.discard(ts)
self.executed_count += 1
- return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
+ return self.transition_generic_memory(ts, value=value)
def transition_constrained_executing(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert not ts.waiting_for_data
@@ -2330,11 +2363,12 @@ def transition_constrained_executing(
self.available_resources[resource] -= quantity
ts.state = "executing"
self._executing.add(ts)
- instr = Execute(key=ts.key, stimulus_id=stimulus_id)
+ instr = Execute(key=ts.key)
return {}, [instr]
def transition_ready_executing(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if self.validate:
assert not ts.waiting_for_data
@@ -2348,10 +2382,13 @@ def transition_ready_executing(
ts.state = "executing"
self._executing.add(ts)
- instr = Execute(key=ts.key, stimulus_id=stimulus_id)
+ instr = Execute(key=ts.key)
return {}, [instr]
- def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsInstrs:
+ def transition_flight_fetch(
+ self,
+ ts: TaskState,
+ ) -> RecsInstrs:
# If this transition is called after the flight coroutine has finished,
# we can reset the task and transition to fetch again. If it is not yet
# finished, this should be a no-op
@@ -2377,8 +2414,6 @@ def transition_flight_error(
traceback: Serialize | None,
exception_text: str,
traceback_text: str,
- *,
- stimulus_id: str,
) -> RecsInstrs:
self._in_flight_tasks.discard(ts)
ts.coming_from = None
@@ -2388,16 +2423,16 @@ def transition_flight_error(
traceback,
exception_text,
traceback_text,
- stimulus_id=stimulus_id,
)
def transition_flight_released(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
if ts.done:
# FIXME: Is this even possible? Would an assert instead be more
# sensible?
- return self.transition_generic_released(ts, stimulus_id=stimulus_id)
+ return self.transition_generic_released(ts)
else:
ts._previous = "flight"
ts._next = "released"
@@ -2406,13 +2441,17 @@ def transition_flight_released(
return {}, []
def transition_cancelled_memory(
- self, ts: TaskState, value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value,
) -> RecsInstrs:
assert ts._next
return {ts: ts._next}, []
def transition_executing_long_running(
- self, ts: TaskState, compute_duration: float, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ compute_duration: float,
) -> RecsInstrs:
ts.state = "long-running"
self._executing.discard(ts)
@@ -2422,33 +2461,38 @@ def transition_executing_long_running(
return {}, [smsg]
def transition_released_memory(
- self, ts: TaskState, value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value,
) -> RecsInstrs:
try:
- recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
+ recs = self._put_key_in_memory(ts, value)
except Exception as e:
msg = error_message(e)
recs = {ts: tuple(msg.values())}
return recs, []
- smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id)
+ smsg = AddKeysMsg(keys=[ts.key])
return recs, [smsg]
def transition_flight_memory(
- self, ts: TaskState, value, *, stimulus_id: str
+ self,
+ ts: TaskState,
+ value,
) -> RecsInstrs:
self._in_flight_tasks.discard(ts)
ts.coming_from = None
try:
- recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
+ recs = self._put_key_in_memory(ts, value)
except Exception as e:
msg = error_message(e)
recs = {ts: tuple(msg.values())}
return recs, []
- smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id)
+ smsg = AddKeysMsg(keys=[ts.key])
return recs, [smsg]
def transition_released_forgotten(
- self, ts: TaskState, *, stimulus_id: str
+ self,
+ ts: TaskState,
) -> RecsInstrs:
recommendations: Recs = {}
# Dependents _should_ be released by the scheduler before this
@@ -2465,7 +2509,7 @@ def transition_released_forgotten(
return recommendations, []
def _transition(
- self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs
+ self, ts: TaskState, finish: str | tuple, *args, **kwargs
) -> RecsInstrs:
if isinstance(finish, tuple):
# the concatenated transition path might need to access the tuple
@@ -2480,15 +2524,13 @@ def _transition(
if func is not None:
self._transition_counter += 1
- recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs)
+ recs, instructions = func(ts, *args, **kwargs)
self._notify_plugins("transition", ts.key, start, finish, **kwargs)
elif "released" not in (start, finish):
# start -> "released" -> finish
try:
- recs, instructions = self._transition(
- ts, "released", stimulus_id=stimulus_id
- )
+ recs, instructions = self._transition(ts, "released")
v = recs.get(ts, (finish, *args))
v_state: str
v_args: list | tuple
@@ -2496,9 +2538,7 @@ def _transition(
v_state, *v_args = v
else:
v_state, v_args = v, ()
- b_recs, b_instructions = self._transition(
- ts, v_state, *v_args, stimulus_id=stimulus_id
- )
+ b_recs, b_instructions = self._transition(ts, v_state, *v_args)
recs.update(b_recs)
instructions += b_instructions
except InvalidTransition:
@@ -2523,15 +2563,13 @@ def _transition(
ts.state,
# new recommendations
{ts.key: new for ts, new in recs.items()},
- stimulus_id,
+ STIMULUS_ID.get(),
time(),
)
)
return recs, instructions
- def transition(
- self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs
- ) -> None:
+ def transition(self, ts: TaskState, finish: str, **kwargs) -> None:
"""Transition a key from its current state to the finish state
Examples
@@ -2547,13 +2585,14 @@ def transition(
--------
Scheduler.transitions: transitive version of this function
"""
- recs, instructions = self._transition(
- ts, finish, stimulus_id=stimulus_id, **kwargs
- )
+ recs, instructions = self._transition(ts, finish, **kwargs)
self._handle_instructions(instructions)
- self.transitions(recs, stimulus_id=stimulus_id)
+ self.transitions(recs)
- def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
+ def transitions(
+ self,
+ recommendations: Recs,
+ ) -> None:
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
@@ -2566,9 +2605,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
while remaining_recs:
ts, finish = remaining_recs.popitem()
tasks.add(ts)
- a_recs, a_instructions = self._transition(
- ts, finish, stimulus_id=stimulus_id
- )
+ a_recs, a_instructions = self._transition(ts, finish)
remaining_recs.update(a_recs)
instructions += a_instructions
@@ -2588,12 +2625,16 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
def handle_stimulus(self, stim: StateMachineEvent) -> None:
with log_errors():
- # self.stimulus_history.append(stim) # TODO
- recs, instructions = self.handle_event(stim)
- self.transitions(recs, stimulus_id=stim.stimulus_id)
- self._handle_instructions(instructions)
- self.ensure_computing()
- self.ensure_communicating()
+ token = STIMULUS_ID.set(stim.stimulus_id) # type: ignore
+ try:
+ # self.stimulus_history.append(stim) # TODO
+ recs, instructions = self.handle_event(stim)
+ self.transitions(recs)
+ self._handle_instructions(instructions)
+ self.ensure_computing()
+ self.ensure_communicating()
+ finally:
+ STIMULUS_ID.reset(token)
def _handle_stimulus_from_future(
self, future: asyncio.Future[StateMachineEvent | None]
@@ -2608,25 +2649,27 @@ def _handle_instructions(self, instructions: Instructions) -> None:
# TODO this method is temporary.
# See final design: https://github.com/dask/distributed/issues/5894
for inst in instructions:
- if isinstance(inst, SendMessageToScheduler):
- self.batched_stream.send(inst.to_dict())
- elif isinstance(inst, Execute):
- coro = self.execute(inst.key, stimulus_id=inst.stimulus_id)
- task = asyncio.create_task(coro)
- # TODO track task (at the moment it's fire-and-forget)
- task.add_done_callback(self._handle_stimulus_from_future)
- else:
- raise TypeError(inst) # pragma: nocover
- def maybe_transition_long_running(
- self, ts: TaskState, *, compute_duration: float, stimulus_id: str
- ):
+ token = STIMULUS_ID.set(inst.stimulus_id) # type: ignore
+ try:
+ if isinstance(inst, SendMessageToScheduler):
+ self.batched_stream.send(inst.to_dict())
+ elif isinstance(inst, Execute):
+ coro = self.execute(inst.key)
+ task = asyncio.create_task(coro)
+ # TODO track task (at the moment it's fire-and-forget)
+ task.add_done_callback(self._handle_stimulus_from_future)
+ else:
+ raise TypeError(inst) # pragma: nocover
+ finally:
+ STIMULUS_ID.reset(token)
+
+ def maybe_transition_long_running(self, ts: TaskState, *, compute_duration: float):
if ts.state == "executing":
self.transition(
ts,
"long-running",
compute_duration=compute_duration,
- stimulus_id=stimulus_id,
)
assert ts.state == "long-running"
@@ -2644,62 +2687,70 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
return worker_story(keys, self.log)
def ensure_communicating(self) -> None:
- stimulus_id = f"ensure-communicating-{time()}"
- skipped_worker_in_flight = []
+ with set_default_stimulus(f"ensure-communicating-{time()}"):
+ skipped_worker_in_flight = []
- while self.data_needed and (
- len(self.in_flight_workers) < self.total_out_connections
- or self.comm_nbytes < self.comm_threshold_bytes
- ):
- logger.debug(
- "Ensure communicating. Pending: %d. Connections: %d/%d",
- len(self.data_needed),
- len(self.in_flight_workers),
- self.total_out_connections,
- )
+ while self.data_needed and (
+ len(self.in_flight_workers) < self.total_out_connections
+ or self.comm_nbytes < self.comm_threshold_bytes
+ ):
+ logger.debug(
+ "Ensure communicating. Pending: %d. Connections: %d/%d",
+ len(self.data_needed),
+ len(self.in_flight_workers),
+ self.total_out_connections,
+ )
- ts = self.data_needed.pop()
+ ts = self.data_needed.pop()
- if ts.state != "fetch":
- continue
+ if ts.state != "fetch":
+ continue
- workers = [w for w in ts.who_has if w not in self.in_flight_workers]
- if not workers:
- assert ts.priority is not None
- skipped_worker_in_flight.append(ts)
- continue
+ workers = [w for w in ts.who_has if w not in self.in_flight_workers]
+ if not workers:
+ assert ts.priority is not None
+ skipped_worker_in_flight.append(ts)
+ continue
- host = get_address_host(self.address)
- local = [w for w in workers if get_address_host(w) == host]
- if local:
- worker = random.choice(local)
- else:
- worker = random.choice(list(workers))
- assert worker != self.address
+ host = get_address_host(self.address)
+ local = [w for w in workers if get_address_host(w) == host]
+ if local:
+ worker = random.choice(local)
+ else:
+ worker = random.choice(list(workers))
+ assert worker != self.address
- to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key)
+ to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key)
- self.log.append(
- ("gather-dependencies", worker, to_gather, stimulus_id, time())
- )
+ self.log.append(
+ (
+ "gather-dependencies",
+ worker,
+ to_gather,
+ STIMULUS_ID.get(),
+ time(),
+ )
+ )
- self.comm_nbytes += total_nbytes
- self.in_flight_workers[worker] = to_gather
- recommendations: Recs = {
- self.tasks[d]: ("flight", worker) for d in to_gather
- }
- self.transitions(recommendations, stimulus_id=stimulus_id)
-
- self.loop.add_callback(
- self.gather_dep,
- worker=worker,
- to_gather=to_gather,
- total_nbytes=total_nbytes,
- stimulus_id=stimulus_id,
- )
+ self.comm_nbytes += total_nbytes
+ self.in_flight_workers[worker] = to_gather
+ recommendations: Recs = {
+ self.tasks[d]: ("flight", worker) for d in to_gather
+ }
+ self.transitions(recommendations)
+ ctx = copy_context()
+ self.loop.add_callback(
+ ctx.run,
+ functools.partial(
+ self.gather_dep,
+ worker=worker,
+ to_gather=to_gather,
+ total_nbytes=total_nbytes,
+ ),
+ )
- for el in skipped_worker_in_flight:
- self.data_needed.push(el)
+ for el in skipped_worker_in_flight:
+ self.data_needed.push(el)
def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
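`ensure_communicating` now hands `gather_dep` to the event loop through `copy_context().run(...)`, so the callback later executes against a snapshot of the context that still contains the stimulus set above, even though the `with set_default_stimulus(...)` block has been exited and reset it by then. A reduced sketch of that pattern using plain asyncio (`call_soon` stands in for `IOLoop.add_callback`; names are illustrative):

    import asyncio
    import functools
    from contextvars import copy_context

    from distributed.utils import STIMULUS_ID, set_default_stimulus

    async def main():
        loop = asyncio.get_running_loop()
        seen = loop.create_future()

        def fake_gather_dep():
            # Runs inside the snapshotted context, so the id set below is still
            # visible even after the with-block has reset the variable here.
            seen.set_result(STIMULUS_ID.get())

        with set_default_stimulus("ensure-communicating-example"):
            ctx = copy_context()  # snapshot including STIMULUS_ID
            loop.call_soon(ctx.run, functools.partial(fake_gather_dep))

        assert await seen == "ensure-communicating-example"

    asyncio.run(main())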
@@ -2729,7 +2780,11 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
startstops=ts.startstops,
)
- def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
+ def _put_key_in_memory(self, ts: TaskState, value) -> Recs:
"""
Put a key into memory and set data related task state attributes.
On success, generate recommendations for dependents.
@@ -2779,7 +2834,7 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
self.waiting_for_data_count -= 1
recommendations[dep] = "ready"
- self.log.append((ts.key, "put-in-memory", stimulus_id, time()))
+ self.log.append((ts.key, "put-in-memory", STIMULUS_ID.get(), time()))
return recommendations
def select_keys_for_gather(self, worker, dep):
@@ -2912,8 +2967,6 @@ async def gather_dep(
worker: str,
to_gather: Iterable[str],
total_nbytes: int,
- *,
- stimulus_id: str,
) -> None:
"""Gather dependencies for a task from a worker who has them
@@ -2943,7 +2996,13 @@ async def gather_dep(
if not to_gather_keys:
self.log.append(
- ("nothing-to-gather", worker, to_gather, stimulus_id, time())
+ (
+ "nothing-to-gather",
+ worker,
+ to_gather,
+ STIMULUS_ID.get(),
+ time(),
+ )
)
return
@@ -2953,7 +3012,13 @@ async def gather_dep(
del to_gather
self.log.append(
- ("request-dep", worker, to_gather_keys, stimulus_id, time())
+ (
+ "request-dep",
+ worker,
+ to_gather_keys,
+ STIMULUS_ID.get(),
+ time(),
+ )
)
logger.debug(
"Request %d keys for task %s from %s",
@@ -2978,7 +3043,13 @@ async def gather_dep(
worker=worker,
)
self.log.append(
- ("receive-dep", worker, set(response["data"]), stimulus_id, time())
+ (
+ "receive-dep",
+ worker,
+ set(response["data"]),
+ STIMULUS_ID.get(),
+ time(),
+ )
)
except OSError:
@@ -2986,7 +3057,13 @@ async def gather_dep(
has_what = self.has_what.pop(worker)
self.pending_data_per_worker.pop(worker)
self.log.append(
- ("receive-dep-failed", worker, has_what, stimulus_id, time())
+ (
+ "receive-dep-failed",
+ worker,
+ has_what,
+ STIMULUS_ID.get(),
+ time(),
+ )
)
for d in has_what:
ts = self.tasks[d]
@@ -3010,7 +3087,13 @@ async def gather_dep(
if busy:
self.log.append(
- ("busy-gather", worker, to_gather_keys, stimulus_id, time())
+ (
+ "busy-gather",
+ worker,
+ to_gather_keys,
+ STIMULUS_ID.get(),
+ time(),
+ )
)
for d in self.in_flight_workers.pop(worker):
@@ -3028,13 +3111,18 @@ async def gather_dep(
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
- self.log.append((d, "missing-dep", stimulus_id, time()))
+ self.log.append((d, "missing-dep", STIMULUS_ID.get(), time()))
self.batched_stream.send(
- {"op": "missing-data", "errant_worker": worker, "key": d}
+ {
+ "op": "missing-data",
+ "errant_worker": worker,
+ "key": d,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
recommendations[ts] = "fetch" if ts.who_has else "missing"
del data, response
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.transitions(recommendations)
self.ensure_computing()
if not busy:
@@ -3049,7 +3137,7 @@ async def gather_dep(
self.ensure_communicating()
async def find_missing(self) -> None:
- with log_errors():
+ with log_errors(), set_default_stimulus(f"find-missing-{time()}"):
if not self._missing_dep_flight:
return
try:
@@ -3057,7 +3145,6 @@ async def find_missing(self) -> None:
for ts in self._missing_dep_flight:
assert not ts.who_has
- stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
@@ -3068,7 +3155,7 @@ async def find_missing(self) -> None:
for ts in self._missing_dep_flight:
if ts.who_has:
recommendations[ts] = "fetch"
- self.transitions(recommendations, stimulus_id=stimulus_id)
+ self.transitions(recommendations)
finally:
# This is quite arbitrary but the heartbeat has scaling implemented
@@ -3117,49 +3204,52 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None:
# There may be a race condition between stealing and releasing a task.
# In this case the self.tasks is already cleared. The `None` will be
# registered as `already-computing` on the other end
- ts = self.tasks.get(key)
- state = ts.state if ts is not None else None
+ with set_default_stimulus(stimulus_id):
+ ts = self.tasks.get(key)
+ state = ts.state if ts is not None else None
- response = {
- "op": "steal-response",
- "key": key,
- "state": state,
- "stimulus_id": stimulus_id,
- }
- self.batched_stream.send(response)
+ response = {
+ "op": "steal-response",
+ "key": key,
+ "state": state,
+ "stimulus_id": stimulus_id,
+ }
+ self.batched_stream.send(response)
- if state in READY | {"waiting"}:
- assert ts
- # If task is marked as "constrained" we haven't yet assigned it an
- # `available_resources` to run on, that happens in
- # `transition_constrained_executing`
- self.transition(ts, "released", stimulus_id=stimulus_id)
+ if state in READY | {"waiting"}:
+ assert ts
+ # If task is marked as "constrained" we haven't yet assigned it an
+ # `available_resources` to run on, that happens in
+ # `transition_constrained_executing`
+ self.transition(ts, "released")
- def handle_worker_status_change(self, status: str) -> None:
- new_status = Status.lookup[status] # type: ignore
+ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None:
+ with set_default_stimulus(stimulus_id):
+ new_status = Status.lookup[status] # type: ignore
- if (
- new_status == Status.closing_gracefully
- and self._status not in Status.ANY_RUNNING # type: ignore
- ):
- logger.error(
- "Invalid Worker.status transition: %s -> %s", self._status, new_status
- )
- # Reiterate the current status to the scheduler to restore sync
- self._send_worker_status_change()
- else:
- # Update status and send confirmation to the Scheduler (see status.setter)
- self.status = new_status
+ if (
+ new_status == Status.closing_gracefully
+ and self._status not in Status.ANY_RUNNING # type: ignore
+ ):
+ logger.error(
+ "Invalid Worker.status transition: %s -> %s",
+ self._status,
+ new_status,
+ )
+ # Reiterate the current status to the scheduler to restore sync
+ self._send_worker_status_change()
+ else:
+ # Update status and send confirmation to the Scheduler (see status.setter)
+ self.status = new_status
def release_key(
self,
key: str,
cause: TaskState | None = None,
report: bool = True,
- *,
- stimulus_id: str,
) -> None:
try:
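+ # ``stimulus_id`` is no longer a parameter; read the current default
+ # from the STIMULUS_ID ContextVar instead.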
+ stimulus_id = STIMULUS_ID.get()
if self.validate:
assert not isinstance(key, TaskState)
ts = self.tasks[key]
@@ -3338,7 +3428,8 @@ def meets_resource_constraints(self, key: str) -> bool:
return True
async def _maybe_deserialize_task(
- self, ts: TaskState, *, stimulus_id: str
+ self, ts: TaskState
) -> tuple[Callable, tuple, dict[str, Any]] | None:
if ts.run_spec is None:
return None
@@ -3358,14 +3449,13 @@ async def _maybe_deserialize_task(
return function, args, kwargs
except Exception as e:
logger.error("Could not deserialize task", exc_info=True)
- self.log.append((ts.key, "deserialize-error", stimulus_id, time()))
+ self.log.append((ts.key, "deserialize-error", STIMULUS_ID.get(), time()))
emsg = error_message(e)
del emsg["status"] # type: ignore
self.transition(
ts,
"error",
**emsg,
- stimulus_id=stimulus_id,
)
raise
@@ -3373,7 +3463,7 @@ def ensure_computing(self) -> None:
if self.status in (Status.paused, Status.closing_gracefully):
return
try:
- stimulus_id = f"ensure-computing-{time()}"
+ STIMULUS_ID.set(f"ensure-computing-{time()}")
while self.constrained and self.executing_count < self.nthreads:
key = self.constrained[0]
ts = self.tasks.get(key, None)
@@ -3382,7 +3472,7 @@ def ensure_computing(self) -> None:
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
- self.transition(ts, "executing", stimulus_id=stimulus_id)
+ self.transition(ts, "executing")
else:
break
while self.ready and self.executing_count < self.nthreads:
@@ -3395,9 +3485,9 @@ def ensure_computing(self) -> None:
# continue through the heap
continue
elif ts.key in self.data:
- self.transition(ts, "memory", stimulus_id=stimulus_id)
+ self.transition(ts, "memory")
elif ts.state in READY:
- self.transition(ts, "executing", stimulus_id=stimulus_id)
+ self.transition(ts, "executing")
except Exception as e: # pragma: no cover
logger.exception(e)
if LOG_PDB:
@@ -3406,7 +3496,10 @@ def ensure_computing(self) -> None:
pdb.set_trace()
raise
- async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
+ async def execute(self, key: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
return None
ts = self.tasks.get(key)
@@ -3417,7 +3510,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
"Trying to execute task %s which is not in executing state anymore",
ts,
)
- return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id)
+ return AlreadyCancelledEvent(key=ts.key)
try:
if self.validate:
@@ -3426,7 +3519,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
assert ts.run_spec is not None
function, args, kwargs = await self._maybe_deserialize_task( # type: ignore
- ts, stimulus_id=stimulus_id
+ ts
)
args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)
@@ -3480,45 +3573,49 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
self.threads[key] = result["thread"]
- if result["op"] == "task-finished":
- if self.digests is not None:
- self.digests["task-duration"].add(result["stop"] - result["start"])
- return ExecuteSuccessEvent(
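+ # Derive the stimulus for the resulting event from the task outcome.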
+ token = STIMULUS_ID.set(f"{result['op']}-{time()}")
+ try:
+ if result["op"] == "task-finished":
+ if self.digests is not None:
+ self.digests["task-duration"].add(
+ result["stop"] - result["start"]
+ )
+ return ExecuteSuccessEvent(
+ key=key,
+ value=result["result"],
+ start=result["start"],
+ stop=result["stop"],
+ nbytes=result["nbytes"],
+ type=result["type"],
+ )
+ if isinstance(result["actual-exception"], Reschedule):
+ return RescheduleEvent(key=ts.key)
+
+ logger.warning(
+ "Compute Failed\n"
+ "Key: %s\n"
+ "Function: %s\n"
+ "args: %s\n"
+ "kwargs: %s\n"
+ "Exception: %r\n",
+ key,
+ str(funcname(function))[:1000],
+ convert_args_to_str(args2, max_len=1000),
+ convert_kwargs_to_str(kwargs2, max_len=1000),
+ result["exception_text"],
+ )
+ return ExecuteFailureEvent(
key=key,
- value=result["result"],
start=result["start"],
stop=result["stop"],
- nbytes=result["nbytes"],
- type=result["type"],
- stimulus_id=stimulus_id,
+ exception=result["exception"],
+ traceback=result["traceback"],
+ exception_text=result["exception_text"],
+ traceback_text=result["traceback_text"],
)
-
- if isinstance(result["actual-exception"], Reschedule):
- return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id)
-
- logger.warning(
- "Compute Failed\n"
- "Key: %s\n"
- "Function: %s\n"
- "args: %s\n"
- "kwargs: %s\n"
- "Exception: %r\n",
- key,
- str(funcname(function))[:1000],
- convert_args_to_str(args2, max_len=1000),
- convert_kwargs_to_str(kwargs2, max_len=1000),
- result["exception_text"],
- )
- return ExecuteFailureEvent(
- key=key,
- start=result["start"],
- stop=result["stop"],
- exception=result["exception"],
- traceback=result["traceback"],
- exception_text=result["exception_text"],
- traceback_text=result["traceback_text"],
- stimulus_id=stimulus_id,
- )
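+ # Restore the previous stimulus only if building the event fails;
+ # on success the stimulus set above remains in effect for the caller.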
+ except Exception:
+ STIMULUS_ID.reset(token)
+ raise
except Exception as exc:
logger.error("Exception during execution of task %s.", key, exc_info=True)
@@ -3531,7 +3628,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
@functools.singledispatchmethod
@@ -3545,7 +3642,7 @@ def _(self, ev: CancelComputeEvent) -> RecsInstrs:
if not ts or ts.state not in READY | {"waiting"}:
return {}, []
- self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time()))
+ self.log.append((ev.key, "cancel-compute", STIMULUS_ID.get(), time()))
# All possible dependents of ts should not be in state Processing on
# scheduler side and therefore should not be assigned to a worker, yet.
assert not ts.dependents
@@ -3596,7 +3693,7 @@ def _(self, ev: ExecuteFailureEvent) -> RecsInstrs:
ev.traceback,
ev.exception_text,
ev.traceback_text,
- )
+ ),
}, []
@handle_event.register
@@ -4172,12 +4269,15 @@ def secede():
worker = get_worker()
tpe_secede() # have this thread secede from the thread pool
duration = time() - thread_state.start_time
- worker.loop.add_callback(
- worker.maybe_transition_long_running,
- worker.tasks[thread_state.key],
- compute_duration=duration,
- stimulus_id=f"secede-{thread_state.key}-{time()}",
- )
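+ # Set the stimulus while the callback is queued on the event loop,
+ # then restore the previous value.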
+ token = STIMULUS_ID.set(f"secede-{thread_state.key}-{time()}")
+ try:
+ worker.loop.add_callback(
+ worker.maybe_transition_long_running,
+ worker.tasks[thread_state.key],
+ compute_duration=duration,
+ )
+ finally:
+ STIMULUS_ID.reset(token)
class Reschedule(Exception):
diff --git a/distributed/worker_client.py b/distributed/worker_client.py
index 5a775b38191..52a77a58890 100644
--- a/distributed/worker_client.py
+++ b/distributed/worker_client.py
@@ -5,6 +5,7 @@
from distributed.metrics import time
from distributed.threadpoolexecutor import rejoin, secede
+from distributed.utils import set_default_stimulus
from distributed.worker import get_client, get_worker, thread_state
@@ -53,13 +54,13 @@ def worker_client(timeout=None, separate_thread=True):
if separate_thread:
duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
- worker.loop.add_callback(
- worker.transition,
- worker.tasks[thread_state.key],
- "long-running",
- stimulus_id=f"worker-client-secede-{time()}",
- compute_duration=duration,
- )
+ with set_default_stimulus(f"worker-client-secede-{time()}"):
+ worker.loop.add_callback(
+ worker.transition,
+ worker.tasks[thread_state.key],
+ "long-running",
+ compute_duration=duration,
+ )
yield client
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 8ae454417c9..1b468bafe5e 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -12,7 +12,7 @@
from dask.utils import parse_bytes
from distributed.protocol.serialize import Serialize
-from distributed.utils import recursive_to_dict
+from distributed.utils import STIMULUS_ID, recursive_to_dict
if TYPE_CHECKING:
# TODO move to typing and get out of TYPE_CHECKING (requires Python >=3.10)
@@ -69,6 +69,10 @@ class InvalidTransition(Exception):
pass
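+# Default factory for the ``stimulus_id`` dataclass fields below: returns
+# the stimulus currently stored in the STIMULUS_ID ContextVar.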
+def stimulus_id_factory() -> str:
+ return STIMULUS_ID.get()
+
+
@lru_cache
def _default_data_size() -> int:
return parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))
@@ -260,15 +264,14 @@ class Instruction:
@dataclass
class Execute(Instruction):
- __slots__ = ("key", "stimulus_id")
key: str
- stimulus_id: str
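+ # Defaults to the stimulus active in the ContextVar when the
+ # instruction is constructed.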
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
class SendMessageToScheduler(Instruction):
- __slots__ = ()
#: Matches a key in Scheduler.stream_handlers
op: ClassVar[str]
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
def to_dict(self) -> dict[str, Any]:
"""Convert object to dict so that it can be serialized with msgpack"""
@@ -288,7 +291,8 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
- __slots__ = tuple(__annotations__) # type: ignore
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
+ # __slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
d = super().to_dict()
@@ -307,7 +311,7 @@ class TaskErredMsg(SendMessageToScheduler):
traceback_text: str
thread: int | None
startstops: list[StartStop]
- __slots__ = tuple(__annotations__) # type: ignore
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
def to_dict(self) -> dict[str, Any]:
d = super().to_dict()
@@ -319,8 +323,8 @@ def to_dict(self) -> dict[str, Any]:
class ReleaseWorkerDataMsg(SendMessageToScheduler):
op = "release-worker-data"
- __slots__ = ("key",)
key: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@@ -328,9 +332,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"
+ # Not to be confused with the distributed.Reschedule Exception
__slots__ = ("key", "worker")
key: str
worker: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
@dataclass
@@ -340,21 +346,21 @@ class LongRunningMsg(SendMessageToScheduler):
__slots__ = ("key", "compute_duration")
key: str
compute_duration: float
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
@dataclass
class AddKeysMsg(SendMessageToScheduler):
op = "add-keys"
- __slots__ = ("keys", "stimulus_id")
+ __slots__ = "keys"
keys: list[str]
- stimulus_id: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
@dataclass
class StateMachineEvent:
- __slots__ = ("stimulus_id",)
- stimulus_id: str
+ key: str
+ # stimulus_id: str
@dataclass
@@ -365,7 +371,8 @@ class ExecuteSuccessEvent(StateMachineEvent):
stop: float
nbytes: int
type: type | None
- __slots__ = tuple(__annotations__) # type: ignore
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
+ # __slots__ = tuple(__annotations__) # type: ignore
@dataclass
@@ -377,19 +384,22 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback: Serialize | None
exception_text: str
traceback_text: str
- __slots__ = tuple(__annotations__) # type: ignore
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
+ # __slots__ = tuple(__annotations__) # type: ignore
@dataclass
class CancelComputeEvent(StateMachineEvent):
__slots__ = ("key",)
key: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
@dataclass
class AlreadyCancelledEvent(StateMachineEvent):
__slots__ = ("key",)
key: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
@@ -397,6 +407,7 @@ class AlreadyCancelledEvent(StateMachineEvent):
class RescheduleEvent(StateMachineEvent):
__slots__ = ("key",)
key: str
+ stimulus_id: str = field(default_factory=stimulus_id_factory)
if TYPE_CHECKING: