diff --git a/distributed/batched.py b/distributed/batched.py
index a9b2be3a43b..f2c370d820b 100644
--- a/distributed/batched.py
+++ b/distributed/batched.py
@@ -50,6 +50,7 @@ def __init__(self, interval, loop=None):
         self.stream = None
         self.message_count = 0
         self.batch_count = 0
+        self.byte_count = 0
         self.next_deadline = None
 
     def start(self, stream):
@@ -79,7 +80,8 @@ def _background_send(self):
             self.batch_count += 1
             self.next_deadline = self.loop.time() + self.interval
             try:
-                yield write(self.stream, payload)
+                nbytes = yield write(self.stream, payload)
+                self.byte_count += nbytes
             except StreamClosedError:
                 logger.info("Batched Stream Closed")
                 break
diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py
index 42126f6cff1..224211124d8 100644
--- a/distributed/diagnostics/plugin.py
+++ b/distributed/diagnostics/plugin.py
@@ -50,3 +50,11 @@ def restart(self, scheduler, **kwargs):
 
     def transition(self, key, start, finish, *args, **kwargs):
         pass
+
+    def add_worker(self, scheduler=None, worker=None, **kwargs):
+        """ Run when a new worker enters the cluster """
+        pass
+
+    def remove_worker(self, scheduler=None, worker=None, **kwargs):
+        """ Run when a worker leaves the cluster"""
+        pass
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 29ae8541a63..e86c826c5f4 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -220,7 +220,7 @@ def __init__(self, center=None, loop=None,
         self.nbytes = dict()
         self.worker_bytes = dict()
         self.processing = dict()
-        self.rprocessing = defaultdict(set)
+        self.rprocessing = dict()
        self.task_duration = {prefix: 0.00001 for prefix in fast_tasks}
         self.unknown_durations = defaultdict(set)
         self.host_restrictions = dict()
@@ -241,6 +241,7 @@ def __init__(self, center=None, loop=None,
         self.tracebacks = dict()
         self.exceptions_blame = dict()
         self.datasets = dict()
+        self.n_tasks = 0
 
         self.idle = SortedSet()
         self.saturated = set()
@@ -526,7 +527,11 @@ def add_worker(self, stream=None, address=None, keys=(), ncores=None,
         if recommendations:
             self.transitions(recommendations)
 
-        # self.(address)
+        for plugin in self.plugins[:]:
+            try:
+                plugin.add_worker(scheduler=self, worker=address)
+            except Exception as e:
+                logger.exception(e)
 
         logger.info("Register %s", str(address))
         return 'OK'
@@ -770,6 +775,12 @@ def remove_worker(self, stream=None, address=None, safe=False):
 
             self.transitions(recommendations)
 
+            for plugin in self.plugins[:]:
+                try:
+                    plugin.remove_worker(scheduler=self, worker=address)
+                except Exception as e:
+                    logger.exception(e)
+
             if not self.processing:
                 logger.info("Lost all workers")
 
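NOTE: the scheduler now notifies every registered plugin when a worker joins or leaves, passing keyword arguments and catching plugin exceptions so a faulty plugin cannot break (de)registration. A minimal sketch of a plugin built on the new hooks; the class name, log messages, and registration line are illustrative, not part of this patch:

    import logging

    from distributed.diagnostics.plugin import SchedulerPlugin

    logger = logging.getLogger(__name__)


    class WorkerLogger(SchedulerPlugin):
        """ Hypothetical plugin: log workers entering and leaving the cluster """

        def add_worker(self, scheduler=None, worker=None, **kwargs):
            logger.info("Worker %s joined, %d total", worker, len(scheduler.workers))

        def remove_worker(self, scheduler=None, worker=None, **kwargs):
            logger.info("Worker %s left, %d total", worker, len(scheduler.workers))

    # scheduler.plugins is a plain list, so registration could be as simple as:
    # scheduler.plugins.append(WorkerLogger())
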
@@ -1845,12 +1856,12 @@ def transition_waiting_processing(self, key):
                 if len(self.idle) < 20:  # smart but linear in small case
                     worker = min(self.idle, key=self.occupancy.get)
                 else:  # dumb but fast in large case
-                    worker = random.choice(self.idle)
+                    worker = self.idle[self.n_tasks % len(self.idle)]
             else:
                 if len(self.workers) < 20:  # smart but linear in small case
                     worker = min(self.workers, key=self.occupancy.get)
                 else:  # dumb but fast in large case
-                    worker = random.choice(self.workers)
+                    worker = self.workers[self.n_tasks % len(self.workers)]
 
             assert worker
 
@@ -1862,12 +1873,13 @@ def transition_waiting_processing(self, key):
                 duration = 0.5
 
             self.processing[worker][key] = duration
-            self.rprocessing[key].add(worker)
+            self.rprocessing[key] = {worker}
             self.occupancy[worker] += duration
             self.total_occupancy += duration
             self.task_state[key] = 'processing'
             self.consume_resources(key, worker)
             self.check_idle_saturated(worker)
+            self.n_tasks += 1
 
             # logger.debug("Send job to worker: %s, %s", worker, key)
@@ -2531,7 +2543,7 @@ def check_idle_saturated(self, worker):
             self.idle.remove(worker)
 
         pending = occ * (p - nc) / nc
-        if p > nc and pending > 0.2 and pending > 1.9 * avg:
+        if p > nc and pending > 0.4 and pending > 1.9 * avg:
             self.saturated.add(worker)
 
     def valid_workers(self, key):
diff --git a/distributed/stealing.py b/distributed/stealing.py
index dd59fccb24a..7e2ad59421a 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -6,6 +6,7 @@
 import random
 from time import time
 
+from toolz import first
 from tornado.iostream import StreamClosedError
 from tornado.ioloop import PeriodicCallback
 
@@ -14,17 +15,25 @@
 BANDWIDTH = 100e6
 LATENCY = 10e-3
 
+log_2 = log(2)
 logger = logging.getLogger(__name__)
 
+
 class WorkStealing(SchedulerPlugin):
     def __init__(self, scheduler):
         self.scheduler = scheduler
-        self.stealable = [set() for i in range(15)]
+        self.stealable = dict()
         self.key_stealable = dict()
         self.stealable_unknown_durations = defaultdict(set)
 
+        self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)]
+        self.cost_multipliers[0] = 1
+
+        for worker in scheduler.workers:
+            self.add_worker(worker=worker)
+
         self._pc = PeriodicCallback(callback=self.balance,
                                     callback_time=100,
                                     io_loop=self.scheduler.loop)
@@ -33,6 +42,12 @@ def __init__(self, scheduler):
         self.scheduler.extensions['stealing'] = self
         self.log = deque(maxlen=100000)
 
+    def add_worker(self, scheduler=None, worker=None):
+        self.stealable[worker] = [set() for i in range(15)]
+
+    def remove_worker(self, scheduler=None, worker=None):
+        del self.stealable[worker]
+
     def teardown(self):
         self._pc.stop()
 
@@ -41,25 +56,33 @@ def transition(self, key, start, finish, compute_start=None,
         if finish == 'processing':
             self.put_key_in_stealable(key)
 
-        if start == 'processing' and finish == 'memory':
+        if start == 'processing':
             self.remove_key_from_stealable(key)
-            ks = key_split(key)
-            if ks in self.stealable_unknown_durations:
-                for key in self.stealable_unknown_durations.pop(ks):
-                    self.put_key_in_stealable(key, split=ks)
+            if finish == 'memory':
+                ks = key_split(key)
+                if ks in self.stealable_unknown_durations:
+                    for k in self.stealable_unknown_durations.pop(ks):
+                        if self.scheduler.task_state[k] == 'processing':
+                            self.put_key_in_stealable(k, split=ks)
 
     def put_key_in_stealable(self, key, split=None):
-        ratio, loc = self.steal_time_ratio(key, split=split)
-        if ratio is not None:
-            self.stealable[loc].add(key)
-            self.key_stealable[key] = loc
+        try:
+            worker = first(self.scheduler.rprocessing[key])
+        except Exception as e:
+            logger.exception(e)
+            import pdb; pdb.set_trace()
+        cost_multiplier, level = self.steal_time_ratio(key, split=split)
+        if cost_multiplier is not None:
+            self.stealable[worker][level].add(key)
+            self.key_stealable[key] = (worker, level)
 
     def remove_key_from_stealable(self, key):
-        loc = self.key_stealable.pop(key, None)
-        if loc is not None:
+        result = self.key_stealable.pop(key, None)
+        if result is not None:
+            worker, level = result
             try:
-                self.stealable[loc].remove(key)
+                self.stealable[worker][level].remove(key)
             except:
                 pass
 
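NOTE: each worker now owns fifteen stealable bins, and the bin index (the "level") maps to a fixed movement penalty through cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)], with level 0 clamped to 1. A small reference sketch of the resulting values and of the per-worker structure; the worker address is made up:

    cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)]
    cost_multipliers[0] = 1
    # level:       0      1      2      3      4     5    6  7  8  9  ...  14
    # multiplier:  1   1.031  1.062  1.125  1.25   1.5   2  3  5  9  ...  257

    # self.stealable maps a worker address to its 15 per-level sets of task keys
    stealable = {'127.0.0.1:9999': [set() for _ in range(15)]}
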
@@ -69,8 +92,9 @@ def steal_time_ratio(self, key, split=None):
 
         Returns
         -------
-        ratio: The compute/communication time ratio of the task
-        loc: The self.stealable bin into which this key should go
+        cost_multiplier: The increased cost from moving this task as a factor.
+            For example a result of zero implies a task without dependencies.
+        level: The location within a stealable list to place this value
         """
         if (key not in self.scheduler.loose_restrictions and
                 (key in self.scheduler.host_restrictions or
@@ -79,7 +103,7 @@
             return None, None  # don't steal
 
         if not self.scheduler.dependencies[key]:  # no dependencies fast path
-            return 10000, 0
+            return 0, 0
 
         nbytes = sum(self.scheduler.nbytes.get(k, 1000)
                      for k in self.scheduler.dependencies[key])
@@ -94,19 +118,13 @@
             self.stealable_unknown_durations[split].add(key)
             return None, None
         else:
-            try:
-                ratio = compute_time / transfer_time
-            except ZeroDivisionError:
-                ratio = 10000
-            if ratio == 10000:
-                loc = 0
-            elif ratio > 32:
-                loc = 1
-            elif ratio < 2**-8:
-                loc = -1
-            else:
-                loc = int(-round(log(ratio) / log(2), 0) + 5) + 1
-            return ratio, loc
+            cost_multiplier = transfer_time / compute_time
+            if cost_multiplier > 100:
+                return None, None
+
+            level = int(round(log(cost_multiplier) / log_2 + 6, 0))
+            level = max(1, level)
+            return cost_multiplier, level
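NOTE: a worked example of the new level computation above, using made-up numbers and assuming transfer_time is derived from the dependency bytes as nbytes / BANDWIDTH + LATENCY, as elsewhere in this module:

    from math import log

    BANDWIDTH = 100e6
    LATENCY = 10e-3

    nbytes = 10e6                                    # 10 MB of dependencies (hypothetical)
    compute_time = 1.0                               # learned task duration (hypothetical)
    transfer_time = nbytes / BANDWIDTH + LATENCY     # 0.11 seconds, assumed formula
    cost_multiplier = transfer_time / compute_time   # 0.11; ratios above 100 are not stealable
    level = max(1, int(round(log(cost_multiplier) / log(2) + 6, 0)))   # -> level 3
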
 
     def move_task(self, key, victim, thief):
         with log_errors():
@@ -114,6 +132,7 @@
             if victim not in self.scheduler.rprocessing[key]:
                 import pdb; pdb.set_trace()
 
+            # self.remove_key_from_stealable(key)
             logger.info("Moved %s, %s: %2f -> %s: %2f", key,
                         victim, self.scheduler.occupancy[victim],
                         thief, self.scheduler.occupancy[thief])
@@ -139,53 +158,61 @@ def balance(self):
         with log_errors():
-            if not self.scheduler.idle or not self.scheduler.saturated:
+            i = 0
+            s = self.scheduler
+            occupancy = s.occupancy
+            idle = s.idle
+            saturated = s.saturated
+            if not idle or not saturated:
                 return
 
-            broken = False
+            start = time()
 
-            with log_errors():
-                start = time()
-                for level, stealable in enumerate(self.stealable[:-1]):
-                    if broken or not stealable:
-                        continue
+            broken = False
+            seen = False
+            acted = False
 
-                    original = stealable
+            if len(s.saturated) < 20:
+                saturated = sorted(saturated, key=occupancy.get, reverse=True)
 
-                    ratio = 2 ** (level - 5 + 1)
+            if len(idle) < 20:
+                idle = sorted(idle, key=occupancy.get)
 
-                    n_stealable = sum(len(s) for s in self.stealable[level:-1])
-                    duration_if_hold = n_stealable / len(self.scheduler.saturated)
-                    duration_if_steal = ratio
-                    if level > 1 and duration_if_hold < duration_if_steal:
-                        break
+            for level, cost_multiplier in enumerate(self.cost_multipliers):
+                if not idle or not saturated:
+                    break
+                for sat in list(saturated):
+                    stealable = self.stealable[sat][level]
+                    if not stealable or not idle:
+                        continue
+                    else:
+                        seen = True
 
                     for key in list(stealable):
-                        if self.scheduler.task_state.get(key) != 'processing':
-                            original.remove(key)
-                            continue
-                        victim = max(self.scheduler.rprocessing[key],
-                                     key=self.scheduler.occupancy.get)
-                        if victim not in self.scheduler.idle:
-                            thief = random.choice(self.scheduler.idle)
-                            self.move_task(key, victim, thief)
-                            self.log.append((level, victim, thief, key))
-                            self.scheduler.check_idle_saturated(victim)
-                            self.scheduler.check_idle_saturated(thief)
-                            original.remove(key)
-
-                    if not self.scheduler.idle or not self.scheduler.saturated:
-                        broken = True
-                        break
-
-            stop = time()
-            if self.scheduler.digests:
-                self.scheduler.digests['steal-duration'].add(stop - start)
-
-    def restart(self):
-        for stealable in self.stealable:
-            stealable.clear()
+                        idl = idle[i % len(idle)]
+                        i += 1
+                        duration = s.task_duration[key_split(key)]
+
+                        if (occupancy[idl] + cost_multiplier * duration
+                                <= occupancy[sat] - duration / 2):
+                            self.move_task(key, sat, idl)
+                            self.log.append((level, sat, idl, key))
+                            self.scheduler.check_idle_saturated(sat)
+                            self.scheduler.check_idle_saturated(idl)
+                            stealable.remove(key)
+                            acted = True
+
+                if seen and not acted:
+                    break
+
+            stop = time()
+            if self.scheduler.digests:
+                self.scheduler.digests['steal-duration'].add(stop - start)
+
+    def restart(self, scheduler):
+        for stealable in self.stealable.values():
+            for s in stealable:
+                s.clear()
         self.key_stealable.clear()
         self.stealable_unknown_durations.clear()
diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py
index e8b6b1ebb15..a9cdca72616 100644
--- a/distributed/tests/test_batched.py
+++ b/distributed/tests/test_batched.py
@@ -77,6 +77,8 @@ def test_BatchedSend():
     result = yield read(stream); assert result == ['hello', 'hello', 'world']
     result = yield read(stream); assert result == ['HELLO', 'HELLO']
 
+    assert b.byte_count > 1
+
 
 @gen_test()
 def test_send_before_start():
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index a7d0f48389b..c3c4d62bf3c 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -151,4 +151,4 @@ def test_close_on_disconnect(s, w):
     start = time()
     while w.status != 'closed':
         yield gen.sleep(0.01)
-        assert time() < start + 5
+        assert time() < start + 9
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index a8cabbc7099..dc796fc410f 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -828,3 +828,18 @@ def test_learn_occupancy_2(c, s, a, b):
         yield gen.sleep(0.01)
 
     assert 50 < s.total_occupancy < 200
+
+
+@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30)
+def test_balance_many_workers(c, s, *workers):
+    futures = c.map(slowinc, range(20), delay=0.2)
+    yield _wait(futures)
+    assert set(map(len, s.has_what.values())) == {0, 1}
+
+
+@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30)
+def test_balance_many_workers_2(c, s, *workers):
+    s.extensions['stealing']._pc.callback_time = 100000000
+    futures = c.map(slowinc, range(90), delay=0.2)
+    yield _wait(futures)
+    assert set(map(len, s.has_what.values())) == {3}
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 82acc747d85..8f7c1cf82e1 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1,12 +1,14 @@
 from __future__ import print_function, division, absolute_import
 
+import itertools
+from functools import partial
 from operator import mul
 import random
 import sys
 from time import sleep
 
 import pytest
-from toolz import sliding_window
+from toolz import sliding_window, concat
 from tornado import gen
 
 import dask
@@ -14,6 +16,7 @@
 from distributed import Worker
 from distributed.client import Client, _wait, wait
 from distributed.metrics import time
+from distributed.scheduler import BANDWIDTH, key_split
 from distributed.utils_test import (cluster, slowinc, slowadd, randominc,
         loop, inc, dec, div, throws, gen_cluster, gen_test, double, deep,
         slowidentity)
@@ -333,3 +336,105 @@ def slow2(x):
 
     # good future moves first
     assert any(future.key in w.task_state for w in rest)
+
+
+def func(x):
+    sleep(1)
+    pass
+
+
+def assert_balanced(inp, out, c, s, *workers):
+    steal = s.extensions['stealing']
+    steal._pc.callback_time = 1000000000
+    counter = itertools.count()
+    B = BANDWIDTH
+    tasks = list(concat(inp))
+    data = yield c._scatter(range(len(tasks)))
+
+    for t, f in zip(tasks, data):
+        s.nbytes[f.key] = BANDWIDTH * t
+        s.task_duration[str(int(t))] = 1
+
+    futures = []
+    data_seq = iter(data)
+    for w, ts in zip(workers, inp):
+        for t in ts:
+            dat = next(data_seq) if t else 123
+            f = c.submit(func, dat, key='%d-%d' % (int(t), next(counter)),
+                         workers=w.address, allow_other_workers=True)
+            futures.append(f)
+
+    while not any(s.processing.values()):
+        yield gen.sleep(0.001)
+
+    s.extensions['stealing'].balance()
+
+    result = [sorted([int(key_split(k)) for k in s.processing[w.address]],
+                     reverse=True)
+              for w in workers]
+
+    result = set(map(tuple, result))
+    out = set(map(tuple, out))
+
+    if result != out:
+        import pdb; pdb.set_trace()
+
+    assert result == out
+
+
+@pytest.mark.parametrize('inp,out', [
+    ([[1], []],  # don't move unnecessarily
+     [[1], []]),
+
+    ([[0, 0], []],  # balance
+     [[0], [0]]),
+
+    ([[0.1, 0.1], []],  # balance even if results in even
+     [[0], [0]]),
+
+    ([[0, 0, 0], []],  # don't over balance
+     [[0, 0], [0]]),
+
+    ([[0, 0], [0, 0, 0], []],  # move from larger
+     [[0, 0], [0, 0], [0]]),
+
+    ([[0, 0, 0], [0], []],  # move to smaller
+     [[0, 0], [0], [0]]),
+
+    ([[0, 1], []],  # choose easier first
+     [[1], [0]]),
+
+    ([[0, 0, 0, 0], [], []],  # spread evenly
+     [[0, 0], [0], [0]]),
+
+    ([[1, 0, 2, 0], [], []],  # move easier
+     [[2, 1], [0], [0]]),
+
+    ([[1, 1, 1], []],  # be willing to move costly items
+     [[1, 1], [1]]),
+
+    ([[1, 1, 1, 1], []],  # but don't move too many
+     [[1, 1, 1], [1]]),
+
+    ([[4, 2, 2, 2, 2, 1, 1],
+      [4, 2, 1, 1],
+      [],
+      [],
+      []],
+     [[4, 2, 2, 2],
+      [4, 2, 1],
+      [2, 1],
+      [1],
+      [1]]),
+
+    ([[1, 1, 1, 1, 1, 1, 1],
+      [1, 1], [1, 1], [1, 1],
+      []],
+     [[1, 1, 1, 1],
+      [1, 1], [1, 1], [1, 1],
+      [1, 1, 1]])
+    ])
+def test_balance(inp, out):
+    test = lambda *args, **kwargs: assert_balanced(inp, out, *args, **kwargs)
+    test = gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * len(inp))(test)
+    test()
diff --git a/distributed/worker.py b/distributed/worker.py
index e33ce5b1101..5ddb597e441 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1276,15 +1276,10 @@ def gather_dep(self, dep, slot, cause=None):
 
     @gen.coroutine
     def query_who_has(self, *deps):
-        try:
+        with log_errors():
             response = yield self.scheduler.who_has(keys=deps)
             self.update_who_has(response)
             raise gen.Return(response)
-        except Exception as e:
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb; pdb.set_trace()
-            raise
 
     def update_who_has(self, who_has):
         try:
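
NOTE: the new balance() above only moves a task when the idle worker stays cheaper than the saturated one even after paying the level's transfer penalty. A worked check of that criterion with made-up occupancies:

    occupancy = {'saturated-worker': 4.0, 'idle-worker': 0.5}   # seconds of queued work (hypothetical)
    duration = 1.0                                              # expected runtime of the candidate task
    cost_multiplier = 1.5                                       # level-5 penalty, 1 + 2 ** (5 - 6)

    # steal only if the thief's new load stays below the victim's load minus half a task
    steal = (occupancy['idle-worker'] + cost_multiplier * duration
             <= occupancy['saturated-worker'] - duration / 2)
    assert steal   # 0.5 + 1.5 <= 4.0 - 0.5, i.e. 2.0 <= 3.5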