From b3ca1a4f25d2a0276d3e8d85f37201e2a5d36c2f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 12 May 2022 22:51:42 +0100 Subject: [PATCH 1/2] Transition table as a ClassVar --- distributed/scheduler.py | 118 ++++++++++++++++++---------------- distributed/worker.py | 135 ++++++++++++++++++++++----------------- 2 files changed, 140 insertions(+), 113 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f236e7e0d1c..2ae2f802a4f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -24,6 +24,7 @@ Hashable, Iterable, Iterator, + Mapping, Set, ) from contextlib import suppress @@ -1256,7 +1257,6 @@ class SchedulerState: "replicated_tasks", "total_nthreads", "total_occupancy", - "transitions_table", "unknown_durations", "unrunnable", "validate", @@ -1308,23 +1308,6 @@ def __init__( self.task_metadata = {} # type: ignore self.total_nthreads = 0 self.total_occupancy = 0.0 - self.transitions_table = { - ("released", "waiting"): self.transition_released_waiting, - ("waiting", "released"): self.transition_waiting_released, - ("waiting", "processing"): self.transition_waiting_processing, - ("waiting", "memory"): self.transition_waiting_memory, - ("processing", "released"): self.transition_processing_released, - ("processing", "memory"): self.transition_processing_memory, - ("processing", "erred"): self.transition_processing_erred, - ("no-worker", "released"): self.transition_no_worker_released, - ("no-worker", "waiting"): self.transition_no_worker_waiting, - ("no-worker", "memory"): self.transition_no_worker_memory, - ("released", "forgotten"): self.transition_released_forgotten, - ("memory", "forgotten"): self.transition_memory_forgotten, - ("erred", "released"): self.transition_erred_released, - ("memory", "released"): self.transition_memory_released, - ("released", "erred"): self.transition_released_erred, - } self.unknown_durations: dict[str, set[TaskState]] = {} self.unrunnable = unrunnable self.validate = validate @@ -1457,15 +1440,14 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): dependents = set(ts.dependents) dependencies = set(ts.dependencies) - start_finish = (start, finish) - func = self.transitions_table.get(start_finish) + func = self.TRANSITIONS_TABLE.get((start, finish)) if func is not None: recommendations, client_msgs, worker_msgs = func( - key, stimulus_id, *args, **kwargs + self, key, stimulus_id, *args, **kwargs ) # type: ignore - elif "released" not in start_finish: - assert not args and not kwargs, (args, kwargs, start_finish) + elif "released" not in (start, finish): + assert not args and not kwargs, (args, kwargs, start, finish) a_recs: dict a_cmsgs: dict a_wmsgs: dict @@ -1473,11 +1455,11 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key, finish) - func = self.transitions_table["released", v] + func = self.TRANSITIONS_TABLE["released", v] b_recs: dict b_cmsgs: dict b_wmsgs: dict - b: tuple = func(key, stimulus_id) # type: ignore + b: tuple = func(self, key, stimulus_id) # type: ignore b_recs, b_cmsgs, b_wmsgs = b recommendations.update(a_recs) @@ -1510,7 +1492,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): start = "released" else: - raise RuntimeError("Impossible transition from %r to %r" % start_finish) + raise RuntimeError(f"Impossible transition from {start} to {finish}") if not stimulus_id: stimulus_id = STIMULUS_ID_UNSET @@ -1849,32 +1831,6 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None: return ws - def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float: - """Estimate task duration using worker state and task state. - - If a task takes longer than twice the current average duration we - estimate the task duration to be 2x current-runtime, otherwise we set it - to be the average duration. - - See also ``_remove_from_processing`` - """ - exec_time: float = ws.executing.get(ts, 0) - duration: float = self.get_task_duration(ts) - total_duration: float - if exec_time > 2 * duration: - total_duration = 2 * exec_time - else: - comm: float = self.get_comm_cost(ts, ws) - total_duration = duration + comm - old = ws.processing.get(ts, 0) - ws.processing[ts] = total_duration - - if ts not in ws.long_running: - self.total_occupancy += total_duration - old - ws.occupancy += total_duration - old - - return total_duration - def transition_waiting_processing(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] @@ -2436,7 +2392,7 @@ def transition_no_worker_released(self, key, stimulus_id): pdb.set_trace() raise - def remove_key(self, key): + def _remove_key(self, key): ts: TaskState = self.tasks.pop(key) assert ts.state == "forgotten" self.unrunnable.discard(ts) @@ -2479,7 +2435,7 @@ def transition_memory_forgotten(self, key, stimulus_id): _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(ts) - self.remove_key(key) + self._remove_key(key) return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2517,7 +2473,7 @@ def transition_released_forgotten(self, key, stimulus_id): _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(ts) - self.remove_key(key) + self._remove_key(key) return recommendations, client_msgs, worker_msgs except Exception as e: @@ -2528,10 +2484,62 @@ def transition_released_forgotten(self, key, stimulus_id): pdb.set_trace() raise + # { + # (start, finish): + # transition__( + # self, key: str, stimulus_id: str, *args, **kwargs + # ) -> (recommendations, client_msgs, worker_msgs) + # } + TRANSITIONS_TABLE: ClassVar[ + Mapping[tuple[str, str], Callable[..., tuple[dict, dict, dict]]] + ] = { + ("released", "waiting"): transition_released_waiting, + ("waiting", "released"): transition_waiting_released, + ("waiting", "processing"): transition_waiting_processing, + ("waiting", "memory"): transition_waiting_memory, + ("processing", "released"): transition_processing_released, + ("processing", "memory"): transition_processing_memory, + ("processing", "erred"): transition_processing_erred, + ("no-worker", "released"): transition_no_worker_released, + ("no-worker", "waiting"): transition_no_worker_waiting, + ("no-worker", "memory"): transition_no_worker_memory, + ("released", "forgotten"): transition_released_forgotten, + ("memory", "forgotten"): transition_memory_forgotten, + ("erred", "released"): transition_erred_released, + ("memory", "released"): transition_memory_released, + ("released", "erred"): transition_released_erred, + } + ############################## # Assigning Tasks to Workers # ############################## + def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float: + """Estimate task duration using worker state and task state. + + If a task takes longer than twice the current average duration we + estimate the task duration to be 2x current-runtime, otherwise we set it + to be the average duration. + + See also ``_remove_from_processing`` + """ + exec_time: float = ws.executing.get(ts, 0) + duration: float = self.get_task_duration(ts) + total_duration: float + if exec_time > 2 * duration: + total_duration = 2 * exec_time + else: + comm: float = self.get_comm_cost(ts, ws) + total_duration = duration + comm + old = ws.processing.get(ts, 0) + ws.processing[ts] = total_duration + + if ts not in ws.long_running: + self.total_occupancy += total_duration - old + ws.occupancy += total_duration - old + + return total_duration + def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): """Update the status of the idle and saturated state diff --git a/distributed/worker.py b/distributed/worker.py index 653b11a2e9c..83297b114a4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -452,7 +452,6 @@ class Worker(ServerNode): outgoing_transfer_log: deque[dict[str, Any]] target_message_size: int validate: bool - _transitions_table: dict[tuple[str, str], Callable] transition_counter: int transition_counter_max: int | Literal[False] incoming_count: int @@ -609,57 +608,6 @@ def __init__( if validate is None: validate = dask.config.get("distributed.scheduler.validate") self.validate = validate - self._transitions_table = { - ("cancelled", "fetch"): self.transition_cancelled_fetch, - ("cancelled", "released"): self.transition_cancelled_released, - ("cancelled", "missing"): self.transition_cancelled_released, - ("cancelled", "waiting"): self.transition_cancelled_waiting, - ("cancelled", "forgotten"): self.transition_cancelled_forgotten, - ("cancelled", "memory"): self.transition_cancelled_memory, - ("cancelled", "error"): self.transition_cancelled_error, - ("resumed", "memory"): self.transition_generic_memory, - ("resumed", "error"): self.transition_generic_error, - ("resumed", "released"): self.transition_resumed_released, - ("resumed", "waiting"): self.transition_resumed_waiting, - ("resumed", "fetch"): self.transition_resumed_fetch, - ("resumed", "missing"): self.transition_resumed_missing, - ("constrained", "executing"): self.transition_constrained_executing, - ("constrained", "released"): self.transition_generic_released, - ("error", "released"): self.transition_generic_released, - ("executing", "error"): self.transition_executing_error, - ("executing", "long-running"): self.transition_executing_long_running, - ("executing", "memory"): self.transition_executing_memory, - ("executing", "released"): self.transition_executing_released, - ("executing", "rescheduled"): self.transition_executing_rescheduled, - ("fetch", "flight"): self.transition_fetch_flight, - ("fetch", "missing"): self.transition_generic_missing, - ("fetch", "released"): self.transition_generic_released, - ("flight", "error"): self.transition_flight_error, - ("flight", "fetch"): self.transition_flight_fetch, - ("flight", "memory"): self.transition_flight_memory, - ("flight", "missing"): self.transition_flight_missing, - ("flight", "released"): self.transition_flight_released, - ("long-running", "error"): self.transition_generic_error, - ("long-running", "memory"): self.transition_long_running_memory, - ("long-running", "rescheduled"): self.transition_executing_rescheduled, - ("long-running", "released"): self.transition_executing_released, - ("memory", "released"): self.transition_memory_released, - ("missing", "fetch"): self.transition_missing_fetch, - ("missing", "released"): self.transition_missing_released, - ("missing", "error"): self.transition_generic_error, - ("ready", "error"): self.transition_generic_error, - ("ready", "executing"): self.transition_ready_executing, - ("ready", "released"): self.transition_generic_released, - ("released", "error"): self.transition_generic_error, - ("released", "fetch"): self.transition_released_fetch, - ("released", "missing"): self.transition_released_fetch, - ("released", "forgotten"): self.transition_released_forgotten, - ("released", "memory"): self.transition_released_memory, - ("released", "waiting"): self.transition_released_waiting, - ("waiting", "constrained"): self.transition_waiting_constrained, - ("waiting", "ready"): self.transition_waiting_ready, - ("waiting", "released"): self.transition_generic_released, - } self.transition_counter = 0 self.transition_counter_max = dask.config.get( @@ -2025,6 +1973,10 @@ def handle_compute_task( self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) + ######################## + # Worker State Machine # + ######################## + def _add_to_data_needed(self, ts: TaskState, stimulus_id: str) -> RecsInstrs: self.data_needed.push(ts) for w in ts.who_has: @@ -2646,8 +2598,73 @@ def transition_released_forgotten( self.tasks.pop(ts.key, None) return recommendations, [] + # { + # (start, finish): + # transition__( + # self, ts: TaskState, *args, stimulus_id: str + # ) -> (recommendations, instructions) + # } + TRANSITIONS_TABLE: ClassVar[ + Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] + ] = { + ("cancelled", "fetch"): transition_cancelled_fetch, + ("cancelled", "released"): transition_cancelled_released, + ("cancelled", "missing"): transition_cancelled_released, + ("cancelled", "waiting"): transition_cancelled_waiting, + ("cancelled", "forgotten"): transition_cancelled_forgotten, + ("cancelled", "memory"): transition_cancelled_memory, + ("cancelled", "error"): transition_cancelled_error, + ("resumed", "memory"): transition_generic_memory, + ("resumed", "error"): transition_generic_error, + ("resumed", "released"): transition_resumed_released, + ("resumed", "waiting"): transition_resumed_waiting, + ("resumed", "fetch"): transition_resumed_fetch, + ("resumed", "missing"): transition_resumed_missing, + ("constrained", "executing"): transition_constrained_executing, + ("constrained", "released"): transition_generic_released, + ("error", "released"): transition_generic_released, + ("executing", "error"): transition_executing_error, + ("executing", "long-running"): transition_executing_long_running, + ("executing", "memory"): transition_executing_memory, + ("executing", "released"): transition_executing_released, + ("executing", "rescheduled"): transition_executing_rescheduled, + ("fetch", "flight"): transition_fetch_flight, + ("fetch", "missing"): transition_generic_missing, + ("fetch", "released"): transition_generic_released, + ("flight", "error"): transition_flight_error, + ("flight", "fetch"): transition_flight_fetch, + ("flight", "memory"): transition_flight_memory, + ("flight", "missing"): transition_flight_missing, + ("flight", "released"): transition_flight_released, + ("long-running", "error"): transition_generic_error, + ("long-running", "memory"): transition_long_running_memory, + ("long-running", "rescheduled"): transition_executing_rescheduled, + ("long-running", "released"): transition_executing_released, + ("memory", "released"): transition_memory_released, + ("missing", "fetch"): transition_missing_fetch, + ("missing", "released"): transition_missing_released, + ("missing", "error"): transition_generic_error, + ("ready", "error"): transition_generic_error, + ("ready", "executing"): transition_ready_executing, + ("ready", "released"): transition_generic_released, + ("released", "error"): transition_generic_error, + ("released", "fetch"): transition_released_fetch, + ("released", "missing"): transition_generic_missing, + ("released", "forgotten"): transition_released_forgotten, + ("released", "memory"): transition_released_memory, + ("released", "waiting"): transition_released_waiting, + ("waiting", "constrained"): transition_waiting_constrained, + ("waiting", "ready"): transition_waiting_ready, + ("waiting", "released"): transition_generic_released, + } + def _transition( - self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs + self, + ts: TaskState, + finish: TaskStateState | tuple, + *args, + stimulus_id: str, + **kwargs, ) -> RecsInstrs: if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple @@ -2658,7 +2675,7 @@ def _transition( return {}, [] start = ts.state - func = self._transitions_table.get((start, cast(str, finish))) + func = self.TRANSITIONS_TABLE.get((start, cast(TaskStateState, finish))) # Notes: # - in case of transition through released, this counter is incremented by 2 @@ -2682,7 +2699,9 @@ def _transition( raise TransitionCounterMaxExceeded(ts.key, start, finish, self.story(ts)) if func is not None: - recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + recs, instructions = func( + self, ts, *args, stimulus_id=stimulus_id, **kwargs + ) self._notify_plugins("transition", ts.key, start, finish, **kwargs) elif "released" not in (start, finish): @@ -2691,7 +2710,7 @@ def _transition( recs, instructions = self._transition( ts, "released", stimulus_id=stimulus_id ) - v_state: str + v_state: TaskStateState v_args: list | tuple while v := recs.pop(ts, None): if isinstance(v, tuple): @@ -2755,13 +2774,13 @@ def _transition( return recs, instructions def transition( - self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs + self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str, **kwargs ) -> None: """Transition a key from its current state to the finish state Examples -------- - >>> self.transition('x', 'waiting') + >>> self.transition('x', 'waiting', stimulus_id=f"test-{(time()}") {'x': 'processing'} Returns From d31acdaf1c3abe2ad40366359ef5cc6991f03392 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 13 May 2022 11:16:01 +0100 Subject: [PATCH 2/2] code review --- distributed/scheduler.py | 6 +++--- distributed/worker.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2ae2f802a4f..8cf312feb2e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1440,7 +1440,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): dependents = set(ts.dependents) dependencies = set(ts.dependencies) - func = self.TRANSITIONS_TABLE.get((start, finish)) + func = self._TRANSITIONS_TABLE.get((start, finish)) if func is not None: recommendations, client_msgs, worker_msgs = func( self, key, stimulus_id, *args, **kwargs @@ -1455,7 +1455,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key, finish) - func = self.TRANSITIONS_TABLE["released", v] + func = self._TRANSITIONS_TABLE["released", v] b_recs: dict b_cmsgs: dict b_wmsgs: dict @@ -2490,7 +2490,7 @@ def transition_released_forgotten(self, key, stimulus_id): # self, key: str, stimulus_id: str, *args, **kwargs # ) -> (recommendations, client_msgs, worker_msgs) # } - TRANSITIONS_TABLE: ClassVar[ + _TRANSITIONS_TABLE: ClassVar[ Mapping[tuple[str, str], Callable[..., tuple[dict, dict, dict]]] ] = { ("released", "waiting"): transition_released_waiting, diff --git a/distributed/worker.py b/distributed/worker.py index 83297b114a4..3c3ec2ae747 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2604,7 +2604,7 @@ def transition_released_forgotten( # self, ts: TaskState, *args, stimulus_id: str # ) -> (recommendations, instructions) # } - TRANSITIONS_TABLE: ClassVar[ + _TRANSITIONS_TABLE: ClassVar[ Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] ] = { ("cancelled", "fetch"): transition_cancelled_fetch, @@ -2675,7 +2675,7 @@ def _transition( return {}, [] start = ts.state - func = self.TRANSITIONS_TABLE.get((start, cast(TaskStateState, finish))) + func = self._TRANSITIONS_TABLE.get((start, cast(TaskStateState, finish))) # Notes: # - in case of transition through released, this counter is incremented by 2