55 changes: 42 additions & 13 deletions distributed/scheduler.py
@@ -274,6 +274,12 @@ class WorkerState:

This attribute is kept in sync with :attr:`TaskState.processing_on`.

.. attribute:: executing: {TaskState: duration}

A dictionary of tasks that are currently being run on this worker.
Each task state is associated with the duration in seconds for
which the task has been running.
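
For illustration, a hypothetical way to inspect this new mapping on the
scheduler (the worker address is made up; `Scheduler.workers` and
`TaskState.key` already exist):

ws = scheduler.workers["tcp://10.0.0.1:9000"]  # a WorkerState
for ts, elapsed in ws.executing.items():
    # elapsed: seconds this task has been running on the worker so far
    print(ts.key, elapsed)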

.. attribute:: has_what: {TaskState}

The set of tasks which currently reside on this worker.
@@ -334,6 +340,7 @@ class WorkerState:
_actors: set
_address: str
_bandwidth: double
_executing: dict
Member:
We're adding a new attribute to the WorkerState class which we recently added type annotations to for Cythonization. @jakirkham I think I took care of everything needed for this addition, but is there some kind of check I can run to make sure Cython is happy? For instance, is being able to successfully build the C-extensions a sufficient check?

Member:
Thanks James! Looks like everything is good so far.

Had a few comments on the new method. IDK if we plan to keep that though or not (if not they can just serve as an example for new functions).

Member:
Oh sorry missed your second question. We build with Cython on the Python 3.7 GitHub Action job and run the tests there. If they pass, we should be good 🙂

_extra: dict
_has_what: set
_hash: Py_hash_t
@@ -360,6 +367,7 @@ class WorkerState:
"_address",
"_bandwidth",
"_extra",
"_executing",
"_has_what",
"_hash",
"_last_seen",
@@ -418,6 +426,7 @@ def __init__(
self._actors = set()
self._has_what = set()
self._processing = {}
self._executing = {}
self._resources = {}
self._used_resources = {}

@@ -447,6 +456,10 @@ def address(self):
def bandwidth(self):
return self._bandwidth

@property
def executing(self):
return self._executing

@property
def extra(self):
return self._extra
@@ -562,6 +575,7 @@ def clean(self):
)
ts: TaskState
ws._processing = {ts._key: cost for ts, cost in self._processing.items()}
ws._executing = {ts._key: duration for ts, duration in self._executing.items()}
return ws

def __repr__(self):
@@ -2127,6 +2141,7 @@ def heartbeat_worker(
resources=None,
host_info=None,
metrics=None,
executing=None,
):
address = self.coerce_address(address, resolve_address)
address = normalize_address(address)
@@ -2165,6 +2180,11 @@

ws._last_seen = time()

if executing is not None:
ws._executing = {
self.tasks[key]: duration for key, duration in executing.items()
}
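
The `executing` payload comes from the worker heartbeat (see the
`distributed/worker.py` change below) as a plain {task key: seconds
elapsed} mapping; a hypothetical example of what arrives here:

executing = {"slowinc-1": 4.7, "inc-2": 0.2}  # made-up keys and durations
# the block above then re-keys this by TaskState, so ws._executing maps
# each running task's TaskState to its elapsed runtime in seconds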

if metrics:
ws._metrics = metrics

@@ -4789,6 +4809,23 @@ def decide_worker(self, ts: TaskState) -> WorkerState:

return ws

def set_duration_estimate(self, ts: TaskState, ws: WorkerState):
"""Estimate task duration using worker state and task state.

If a task has already been running for more than twice its average
duration, we estimate its total duration to be twice its elapsed
runtime; otherwise we use the average duration plus the communication cost.
"""
duration: double = self.get_task_duration(ts)
comm: double = self.get_comm_cost(ts, ws)
total_duration: double = duration + comm
if ts in ws._executing:
exec_time: double = ws._executing[ts]
if exec_time > 2 * duration:
total_duration = 2 * exec_time
ws._processing[ts] = total_duration
return ws._processing[ts]
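
As a sanity check on the rule above, a minimal standalone sketch of the
same logic (placeholder names stand in for get_task_duration,
get_comm_cost, and the ws._executing entry):

def _estimate(avg_duration, comm_cost, elapsed=None):
    total = avg_duration + comm_cost
    if elapsed is not None and elapsed > 2 * avg_duration:
        # already running far longer than expected: assume it needs
        # roughly twice the time it has taken so far
        total = 2 * elapsed
    return total

_estimate(0.5, 0.1)               # 0.6  -> not yet executing
_estimate(0.5, 0.1, elapsed=0.3)  # 0.6  -> running, still on track
_estimate(0.5, 0.1, elapsed=5.0)  # 10.0 -> long-running outlier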

def transition_waiting_processing(self, key):
try:
tasks: dict = self.tasks
@@ -4809,14 +4846,10 @@ def transition_waiting_processing(self, key):
return {}
worker = ws._address

duration = self.get_task_duration(ts)
comm = self.get_comm_cost(ts, ws)
occupancy = duration + comm

ws._processing[ts] = occupancy
duration_estimate = self.set_duration_estimate(ts, ws)
ts._processing_on = ws
ws._occupancy += occupancy
self.total_occupancy += occupancy
ws._occupancy += duration_estimate
self.total_occupancy += duration_estimate
ts.state = "processing"
self.consume_resources(ts, ws)
self.check_idle_saturated(ws)
Expand All @@ -4827,7 +4860,7 @@ def transition_waiting_processing(self, key):

# logger.debug("Send job to worker: %s, %s", worker, key)

self.send_task_to_worker(worker, ts, duration)
self.send_task_to_worker(worker, ts)

return {}
except Exception as e:
@@ -6101,11 +6134,7 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
new = 0
nbytes = 0
for ts in ws._processing:
duration = self.get_task_duration(ts)
comm = self.get_comm_cost(ts, ws)
occupancy = duration + comm
ws._processing[ts] = occupancy
new += occupancy
new += self.set_duration_estimate(ts, ws)

ws._occupancy = new
self.total_occupancy += new - old
16 changes: 16 additions & 0 deletions distributed/tests/test_steal.py
@@ -809,3 +809,19 @@ async def test_worker_stealing_interval(c, s, a, b):
with dask.config.set({"distributed.scheduler.work-stealing-interval": 2}):
ws = WorkStealing(s)
assert ws._pc.callback_time == 2


@gen_cluster(client=True)
async def test_balance_with_longer_task(c, s, a, b):
np = pytest.importorskip("numpy")

await c.submit(slowinc, 0, delay=0) # scheduler learns that slowinc is very fast
x = await c.scatter(np.arange(10000), workers=[a.address])
y = c.submit(
slowinc, 1, delay=5, workers=[a.address], priority=1
) # a surprisingly long task
z = c.submit(
inc, x, workers=[a.address], allow_other_workers=True, priority=0
) # a task after y, suggesting a, but open to b
await z
assert z.key in b.data
14 changes: 7 additions & 7 deletions distributed/tests/test_worker.py
@@ -1682,19 +1682,19 @@ async def test_update_latency(cleanup):


@pytest.mark.asyncio
async def test_heartbeat_executing(cleanup):
async def test_workerstate_executing(cleanup):
async with await Scheduler() as s:
async with await Worker(s.address) as w:
async with Client(s.address, asynchronous=True) as c:
ws = s.workers[w.address]
# Initially there are no active tasks
assert not ws.metrics["executing"]
# Submit a task and ensure the worker's heartbeat includes the task
# in it's executing
assert not ws.executing
# Submit a task and ensure the WorkerState is updated with the task
# it's executing
f = c.submit(slowinc, 1, delay=1)
while not ws.metrics["executing"]:
await w.heartbeat()
assert f.key in ws.metrics["executing"]
while not ws.executing:
await asyncio.sleep(0.01)
assert s.tasks[f.key] in ws.executing
await f


11 changes: 6 additions & 5 deletions distributed/worker.py
@@ -797,6 +797,7 @@ def local_dir(self):
async def get_metrics(self):
now = time()
core = dict(
executing=self.executing_count,
in_memory=len(self.data),
ready=len(self.ready),
in_flight=self.in_flight_tasks,
@@ -805,11 +806,6 @@ async def get_metrics(self):
"workers": dict(self.bandwidth_workers),
"types": keymap(typename, self.bandwidth_types),
},
executing={
key: now - self.tasks[key].start_time
for key in self.active_threads.values()
if key in self.tasks
},
)
custom = {}
for k, metric in self.metrics.items():
@@ -937,6 +933,11 @@ async def heartbeat(self):
address=self.contact_address,
now=time(),
metrics=await self.get_metrics(),
executing={
key: start - self.tasks[key].start_time
for key in self.active_threads.values()
if key in self.tasks
},
)
end = time()
middle = (start + end) / 2
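
Taken together with the scheduler changes above, the net effect on a
WorkerState is roughly the following (illustrative only, not part of the
diff; `s` and `w` are a Scheduler and a Worker as in the tests):

ws = s.workers[w.address]
ws.metrics["executing"]  # now an int: the worker's executing_count
ws.executing             # {TaskState: seconds elapsed}, refreshed on each heartbeat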