Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion distributed/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, interval, loop=None):
self.stream = None
self.message_count = 0
self.batch_count = 0
self.byte_count = 0
self.next_deadline = None

def start(self, stream):
Expand Down Expand Up @@ -79,7 +80,8 @@ def _background_send(self):
self.batch_count += 1
self.next_deadline = self.loop.time() + self.interval
try:
yield write(self.stream, payload)
nbytes = yield write(self.stream, payload)
self.byte_count += nbytes
except StreamClosedError:
logger.info("Batched Stream Closed")
break
Expand Down
8 changes: 8 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,11 @@ def restart(self, scheduler, **kwargs):

def transition(self, key, start, finish, *args, **kwargs):
""" Run when a task changes state

No-op by default; plugins override this hook to react to transitions.

Parameters
----------
key: the task whose state is changing
start: name of the state the task is leaving
finish: name of the state the task is entering
*args, **kwargs: ignored here; accepted so callers may pass extra context
"""
pass

def add_worker(self, scheduler=None, worker=None, **kwargs):
""" Run when a new worker enters the cluster

No-op by default; plugins override this hook.

Parameters
----------
scheduler: the scheduler invoking the hook (it passes ``scheduler=self``)
worker: address of the worker that was just registered
**kwargs: ignored here; accepted so callers may pass extra context
"""
pass

def remove_worker(self, scheduler=None, worker=None, **kwargs):
""" Run when a worker leaves the cluster

No-op by default; plugins override this hook.

Parameters
----------
scheduler: the scheduler invoking the hook (it passes ``scheduler=self``)
worker: address of the worker that was removed
**kwargs: ignored here; accepted so callers may pass extra context
"""
pass
24 changes: 18 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(self, center=None, loop=None,
self.nbytes = dict()
self.worker_bytes = dict()
self.processing = dict()
self.rprocessing = defaultdict(set)
self.rprocessing = dict()
self.task_duration = {prefix: 0.00001 for prefix in fast_tasks}
self.unknown_durations = defaultdict(set)
self.host_restrictions = dict()
Expand All @@ -241,6 +241,7 @@ def __init__(self, center=None, loop=None,
self.tracebacks = dict()
self.exceptions_blame = dict()
self.datasets = dict()
self.n_tasks = 0

self.idle = SortedSet()
self.saturated = set()
Expand Down Expand Up @@ -526,7 +527,11 @@ def add_worker(self, stream=None, address=None, keys=(), ncores=None,
if recommendations:
self.transitions(recommendations)

# self.(address)
for plugin in self.plugins[:]:
try:
plugin.add_worker(scheduler=self, worker=address)
except Exception as e:
logger.exception(e)

logger.info("Register %s", str(address))
return 'OK'
Expand Down Expand Up @@ -770,6 +775,12 @@ def remove_worker(self, stream=None, address=None, safe=False):

self.transitions(recommendations)

for plugin in self.plugins[:]:
try:
plugin.remove_worker(scheduler=self, worker=address)
except Exception as e:
logger.exception(e)

if not self.processing:
logger.info("Lost all workers")

Expand Down Expand Up @@ -1845,12 +1856,12 @@ def transition_waiting_processing(self, key):
if len(self.idle) < 20: # smart but linear in small case
worker = min(self.idle, key=self.occupancy.get)
else: # dumb but fast in large case
worker = random.choice(self.idle)
worker = self.idle[self.n_tasks % len(self.idle)]
else:
if len(self.workers) < 20: # smart but linear in small case
worker = min(self.workers, key=self.occupancy.get)
else: # dumb but fast in large case
worker = random.choice(self.workers)
worker = self.workers[self.n_tasks % len(self.workers)]

assert worker

Expand All @@ -1862,12 +1873,13 @@ def transition_waiting_processing(self, key):
duration = 0.5

self.processing[worker][key] = duration
self.rprocessing[key].add(worker)
self.rprocessing[key] = {worker}
self.occupancy[worker] += duration
self.total_occupancy += duration
self.task_state[key] = 'processing'
self.consume_resources(key, worker)
self.check_idle_saturated(worker)
self.n_tasks += 1

# logger.debug("Send job to worker: %s, %s", worker, key)

Expand Down Expand Up @@ -2531,7 +2543,7 @@ def check_idle_saturated(self, worker):
self.idle.remove(worker)

pending = occ * (p - nc) / nc
if p > nc and pending > 0.2 and pending > 1.9 * avg:
if p > nc and pending > 0.4 and pending > 1.9 * avg:
self.saturated.add(worker)

def valid_workers(self, key):
Expand Down
161 changes: 94 additions & 67 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
from time import time

from toolz import first
from tornado.iostream import StreamClosedError
from tornado.ioloop import PeriodicCallback

Expand All @@ -14,17 +15,25 @@

BANDWIDTH = 100e6
LATENCY = 10e-3
log_2 = log(2)

logger = logging.getLogger(__name__)



class WorkStealing(SchedulerPlugin):
def __init__(self, scheduler):
self.scheduler = scheduler
self.stealable = [set() for i in range(15)]
self.stealable = dict()
self.key_stealable = dict()
self.stealable_unknown_durations = defaultdict(set)

self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)]
self.cost_multipliers[0] = 1

for worker in scheduler.workers:
self.add_worker(worker=worker)

self._pc = PeriodicCallback(callback=self.balance,
callback_time=100,
io_loop=self.scheduler.loop)
Expand All @@ -33,6 +42,12 @@ def __init__(self, scheduler):
self.scheduler.extensions['stealing'] = self
self.log = deque(maxlen=100000)

def add_worker(self, scheduler=None, worker=None, **kwargs):
    """ Create empty stealable bins for a newly registered worker

    Parameters
    ----------
    scheduler: unused; accepted to match the plugin hook signature
    worker: address of the worker, used as the key into ``self.stealable``
    **kwargs: ignored; accepted so extra hook arguments do not raise
    """
    # One set per cost level, so that ``balance`` can index the bins with the
    # positions it enumerates from ``self.cost_multipliers``.
    self.stealable[worker] = [set() for _ in self.cost_multipliers]

def remove_worker(self, scheduler=None, worker=None, **kwargs):
    """ Drop the stealable bins of a worker that left the cluster

    Removing a worker that has no bins (never added, or already removed) is
    a no-op rather than a ``KeyError``.

    Parameters
    ----------
    scheduler: unused; accepted to match the plugin hook signature
    worker: address of the departing worker
    **kwargs: ignored; accepted so extra hook arguments do not raise
    """
    self.stealable.pop(worker, None)

def teardown(self):
""" Stop the periodic callback that drives ``balance`` """
self._pc.stop()

Expand All @@ -41,25 +56,33 @@ def transition(self, key, start, finish, compute_start=None,
if finish == 'processing':
self.put_key_in_stealable(key)

if start == 'processing' and finish == 'memory':
if start == 'processing':
self.remove_key_from_stealable(key)
ks = key_split(key)
if ks in self.stealable_unknown_durations:
for key in self.stealable_unknown_durations.pop(ks):
self.put_key_in_stealable(key, split=ks)
if finish == 'memory':
ks = key_split(key)
if ks in self.stealable_unknown_durations:
for k in self.stealable_unknown_durations.pop(ks):
if self.scheduler.task_state[k] == 'processing':
self.put_key_in_stealable(k, split=ks)


def put_key_in_stealable(self, key, split=None):
    """ Record ``key`` as a candidate for work stealing

    The key is filed under a worker currently processing it, in the bin
    (``level``) chosen by ``steal_time_ratio``.  Keys for which
    ``steal_time_ratio`` returns ``(None, None)`` are left untracked.

    Parameters
    ----------
    key: task key to track
    split: optional value forwarded unchanged to ``steal_time_ratio``
    """
    workers = self.scheduler.rprocessing.get(key)
    if not workers:
        # There is no worker to file the key under.  Log and skip rather
        # than dropping into a debugger, which would block the scheduler.
        logger.warning("Key %s is not assigned to a worker; "
                       "not marking it stealable", key)
        return
    worker = next(iter(workers))

    cost_multiplier, level = self.steal_time_ratio(key, split=split)
    if cost_multiplier is not None:
        self.stealable[worker][level].add(key)
        # Remember where the key lives so it can be removed in O(1) later
        self.key_stealable[key] = (worker, level)

def remove_key_from_stealable(self, key):
    """ Stop tracking ``key`` as a stealing candidate

    Does nothing for keys that are not currently tracked.

    Parameters
    ----------
    key: task key to forget
    """
    result = self.key_stealable.pop(key, None)
    if result is None:
        return

    worker, level = result
    try:
        self.stealable[worker][level].remove(key)
    except KeyError:
        # Either the worker's bins are already gone (``remove_worker``) or
        # the key was already taken out of its bin (``balance``), so there is
        # nothing left to undo.  Other exceptions indicate corrupted
        # bookkeeping and are allowed to propagate.
        pass

Expand All @@ -69,8 +92,9 @@ def steal_time_ratio(self, key, split=None):
Returns
-------

ratio: The compute/communication time ratio of the task
loc: The self.stealable bin into which this key should go
cost_multiplier: The increased cost from moving this task as a factor.
For example a result of zero implies a task without dependencies.
level: The location within a stealable list to place this value
"""
if (key not in self.scheduler.loose_restrictions
and (key in self.scheduler.host_restrictions or
Expand All @@ -79,7 +103,7 @@ def steal_time_ratio(self, key, split=None):
return None, None # don't steal

if not self.scheduler.dependencies[key]: # no dependencies fast path
return 10000, 0
return 0, 0

nbytes = sum(self.scheduler.nbytes.get(k, 1000)
for k in self.scheduler.dependencies[key])
Expand All @@ -94,26 +118,21 @@ def steal_time_ratio(self, key, split=None):
self.stealable_unknown_durations[split].add(key)
return None, None
else:
try:
ratio = compute_time / transfer_time
except ZeroDivisionError:
ratio = 10000
if ratio == 10000:
loc = 0
elif ratio > 32:
loc = 1
elif ratio < 2**-8:
loc = -1
else:
loc = int(-round(log(ratio) / log(2), 0) + 5) + 1
return ratio, loc
cost_multiplier = transfer_time / compute_time
if cost_multiplier > 100:
return None, None

level = int(round(log(cost_multiplier) / log_2 + 6, 0))
level = max(1, level)
return cost_multiplier, level

def move_task(self, key, victim, thief):
with log_errors():
if self.scheduler.validate:
if victim not in self.scheduler.rprocessing[key]:
import pdb; pdb.set_trace()

# self.remove_key_from_stealable(key)
logger.info("Moved %s, %s: %2f -> %s: %2f", key,
victim, self.scheduler.occupancy[victim],
thief, self.scheduler.occupancy[thief])
Expand All @@ -139,53 +158,61 @@ def move_task(self, key, victim, thief):

def balance(self):
with log_errors():
if not self.scheduler.idle or not self.scheduler.saturated:
i = 0
s = self.scheduler
occupancy = s.occupancy
idle = s.idle
saturated = s.saturated
if not idle or not saturated:
return

broken = False
start = time()

with log_errors():
start = time()
for level, stealable in enumerate(self.stealable[:-1]):
if broken or not stealable:
continue
broken = False
seen = False
acted = False

original = stealable
if len(s.saturated) < 20:
saturated = sorted(saturated, key=occupancy.get, reverse=True)

ratio = 2 ** (level - 5 + 1)
if len(idle) < 20:
idle = sorted(idle, key=occupancy.get)

n_stealable = sum(len(s) for s in self.stealable[level:-1])
duration_if_hold = n_stealable / len(self.scheduler.saturated)
duration_if_steal = ratio

if level > 1 and duration_if_hold < duration_if_steal:
break
for level, cost_multiplier in enumerate(self.cost_multipliers):
if not idle or not saturated:
break
for sat in list(saturated):
stealable = self.stealable[sat][level]
if not stealable or not idle:
continue
else:
seen = True

for key in list(stealable):
if self.scheduler.task_state.get(key) != 'processing':
original.remove(key)
continue
victim = max(self.scheduler.rprocessing[key],
key=self.scheduler.occupancy.get)
if victim not in self.scheduler.idle:
thief = random.choice(self.scheduler.idle)
self.move_task(key, victim, thief)
self.log.append((level, victim, thief, key))
self.scheduler.check_idle_saturated(victim)
self.scheduler.check_idle_saturated(thief)
original.remove(key)

if not self.scheduler.idle or not self.scheduler.saturated:
broken = True
break

stop = time()
if self.scheduler.digests:
self.scheduler.digests['steal-duration'].add(stop - start)

def restart(self):
for stealable in self.stealable:
stealable.clear()
idl = idle[i % len(idle)]
i += 1
duration = s.task_duration[key_split(key)]

if (occupancy[idl] + cost_multiplier * duration
<= occupancy[sat] - duration / 2):
self.move_task(key, sat, idl)
self.log.append((level, sat, idl, key))
self.scheduler.check_idle_saturated(sat)
self.scheduler.check_idle_saturated(idl)
stealable.remove(key)
acted = True
if seen and not acted:
break

stop = time()
if self.scheduler.digests:
self.scheduler.digests['steal-duration'].add(stop - start)

def restart(self, scheduler):
for stealable in self.stealable.values():
for s in stealable:
s.clear()

self.key_stealable.clear()
self.stealable_unknown_durations.clear()
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def test_BatchedSend():
result = yield read(stream); assert result == ['hello', 'hello', 'world']
result = yield read(stream); assert result == ['HELLO', 'HELLO']

assert b.byte_count > 1


@gen_test()
def test_send_before_start():
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def test_close_on_disconnect(s, w):
start = time()
while w.status != 'closed':
yield gen.sleep(0.01)
assert time() < start + 5
assert time() < start + 9
15 changes: 15 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,18 @@ def test_learn_occupancy_2(c, s, a, b):
yield gen.sleep(0.01)

assert 50 < s.total_occupancy < 200


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30)
def test_balance_many_workers(c, s, *workers):
# The ncores list has 30 single-core entries (presumably one worker each),
# and only 20 tasks are submitted, so a balanced placement leaves every
# worker with at most one result.
futures = c.map(slowinc, range(20), delay=0.2)
yield _wait(futures)
# Per-worker counts taken from s.has_what must all be 0 or 1
assert set(map(len, s.has_what.values())) == {0, 1}


@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 30)
def test_balance_many_workers_2(c, s, *workers):
# Stretch the stealing extension's periodic callback interval to a very
# large value so that (presumably) it never fires during this test; the
# even spread below must then come from initial task placement alone.
s.extensions['stealing']._pc.callback_time = 100000000
# 90 tasks over the 30 ncores entries: an even spread is 3 per worker
futures = c.map(slowinc, range(90), delay=0.2)
yield _wait(futures)
assert set(map(len, s.has_what.values())) == {3}
Loading