diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ed3f7150e9e..c66c7183d53 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -274,6 +274,12 @@ class WorkerState: This attribute is kept in sync with :attr:`TaskState.processing_on`. + .. attribute:: executing: {TaskState: duration} + + A dictionary of tasks that are currently being run on this worker. + Each task state is asssociated with the duration in seconds which + the task has been running. + .. attribute:: has_what: {TaskState} The set of tasks which currently reside on this worker. @@ -334,6 +340,7 @@ class WorkerState: _actors: set _address: str _bandwidth: double + _executing: dict _extra: dict _has_what: set _hash: Py_hash_t @@ -360,6 +367,7 @@ class WorkerState: "_address", "_bandwidth", "_extra", + "_executing", "_has_what", "_hash", "_last_seen", @@ -418,6 +426,7 @@ def __init__( self._actors = set() self._has_what = set() self._processing = {} + self._executing = {} self._resources = {} self._used_resources = {} @@ -447,6 +456,10 @@ def address(self): def bandwidth(self): return self._bandwidth + @property + def executing(self): + return self._executing + @property def extra(self): return self._extra @@ -562,6 +575,7 @@ def clean(self): ) ts: TaskState ws._processing = {ts._key: cost for ts, cost in self._processing.items()} + ws._executing = {ts._key: duration for ts, duration in self._executing.items()} return ws def __repr__(self): @@ -2127,6 +2141,7 @@ def heartbeat_worker( resources=None, host_info=None, metrics=None, + executing=None, ): address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -2165,6 +2180,11 @@ def heartbeat_worker( ws._last_seen = time() + if executing is not None: + ws._executing = { + self.tasks[key]: duration for key, duration in executing.items() + } + if metrics: ws._metrics = metrics @@ -4789,6 +4809,23 @@ def decide_worker(self, ts: TaskState) -> WorkerState: return ws + def set_duration_estimate(self, ts: TaskState, ws: WorkerState): + """Estimate task duration using worker state and task state. + + If a task takes longer than twice the current average duration we + estimate the task duration to be 2x current-runtime, otherwise we set it + to be the average duration. + """ + duration: double = self.get_task_duration(ts) + comm: double = self.get_comm_cost(ts, ws) + total_duration: double = duration + comm + if ts in ws._executing: + exec_time: double = ws._executing[ts] + if exec_time > 2 * duration: + total_duration = 2 * exec_time + ws._processing[ts] = total_duration + return ws._processing[ts] + def transition_waiting_processing(self, key): try: tasks: dict = self.tasks @@ -4809,14 +4846,10 @@ def transition_waiting_processing(self, key): return {} worker = ws._address - duration = self.get_task_duration(ts) - comm = self.get_comm_cost(ts, ws) - occupancy = duration + comm - - ws._processing[ts] = occupancy + duration_estimate = self.set_duration_estimate(ts, ws) ts._processing_on = ws - ws._occupancy += occupancy - self.total_occupancy += occupancy + ws._occupancy += duration_estimate + self.total_occupancy += duration_estimate ts.state = "processing" self.consume_resources(ts, ws) self.check_idle_saturated(ws) @@ -4827,7 +4860,7 @@ def transition_waiting_processing(self, key): # logger.debug("Send job to worker: %s, %s", worker, key) - self.send_task_to_worker(worker, ts, duration) + self.send_task_to_worker(worker, ts) return {} except Exception as e: @@ -6101,11 +6134,7 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): new = 0 nbytes = 0 for ts in ws._processing: - duration = self.get_task_duration(ts) - comm = self.get_comm_cost(ts, ws) - occupancy = duration + comm - ws._processing[ts] = occupancy - new += occupancy + new += self.set_duration_estimate(ts, ws) ws._occupancy = new self.total_occupancy += new - old diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 8277ede1833..bee0f06544e 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -809,3 +809,19 @@ async def test_worker_stealing_interval(c, s, a, b): with dask.config.set({"distributed.scheduler.work-stealing-interval": 2}): ws = WorkStealing(s) assert ws._pc.callback_time == 2 + + +@gen_cluster(client=True) +async def test_balance_with_longer_task(c, s, a, b): + np = pytest.importorskip("numpy") + + await c.submit(slowinc, 0, delay=0) # scheduler learns that slowinc is very fast + x = await c.scatter(np.arange(10000), workers=[a.address]) + y = c.submit( + slowinc, 1, delay=5, workers=[a.address], priority=1 + ) # a surprisingly long task + z = c.submit( + inc, x, workers=[a.address], allow_other_workers=True, priority=0 + ) # a task after y, suggesting a, but open to b + await z + assert z.key in b.data diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index fc758da2338..112263ecd7d 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1682,19 +1682,19 @@ async def test_update_latency(cleanup): @pytest.mark.asyncio -async def test_heartbeat_executing(cleanup): +async def test_workerstate_executing(cleanup): async with await Scheduler() as s: async with await Worker(s.address) as w: async with Client(s.address, asynchronous=True) as c: ws = s.workers[w.address] # Initially there are no active tasks - assert not ws.metrics["executing"] - # Submit a task and ensure the worker's heartbeat includes the task - # in it's executing + assert not ws.executing + # Submit a task and ensure the WorkerState is updated with the task + # it's executing f = c.submit(slowinc, 1, delay=1) - while not ws.metrics["executing"]: - await w.heartbeat() - assert f.key in ws.metrics["executing"] + while not ws.executing: + await asyncio.sleep(0.01) + assert s.tasks[f.key] in ws.executing await f diff --git a/distributed/worker.py b/distributed/worker.py index 205369623d1..87ee7fe368d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -797,6 +797,7 @@ def local_dir(self): async def get_metrics(self): now = time() core = dict( + executing=self.executing_count, in_memory=len(self.data), ready=len(self.ready), in_flight=self.in_flight_tasks, @@ -805,11 +806,6 @@ async def get_metrics(self): "workers": dict(self.bandwidth_workers), "types": keymap(typename, self.bandwidth_types), }, - executing={ - key: now - self.tasks[key].start_time - for key in self.active_threads.values() - if key in self.tasks - }, ) custom = {} for k, metric in self.metrics.items(): @@ -937,6 +933,11 @@ async def heartbeat(self): address=self.contact_address, now=time(), metrics=await self.get_metrics(), + executing={ + key: start - self.tasks[key].start_time + for key in self.active_threads.values() + if key in self.tasks + }, ) end = time() middle = (start + end) / 2