diff --git a/distributed/core.py b/distributed/core.py index 4efa17680f1..15205f4f72c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -135,6 +135,7 @@ def __init__( connection_args=None, timeout=None, io_loop=None, + **kwargs, ): self.handlers = { "identity": self.identity, @@ -236,6 +237,8 @@ def set_thread_ident(): self.__stopped = False + super().__init__(**kwargs) + @property def status(self): return self._status diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py index 6e5a222dd23..96199faba38 100644 --- a/distributed/http/scheduler/info.py +++ b/distributed/http/scheduler/info.py @@ -33,7 +33,13 @@ def get(self): "workers.html", title="Workers", scheduler=self.server, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) @@ -49,7 +55,13 @@ def get(self, worker): title="Worker: " + worker, scheduler=self.server, Worker=worker, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) @@ -65,7 +77,13 @@ def get(self, task): title="Task: " + task, Task=task, scheduler=self.server, - **merge(self.server.__dict__, ns, self.extra, rel_path_statics), + **merge( + self.server.__dict__, + self.server.__pdict__, + ns, + self.extra, + rel_path_statics, + ), ) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9b4a5b85a3b..896b10e1380 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1527,43 +1527,33 @@ def _task_key_or_none(task): return task.key if task is not None else None -class Scheduler(ServerNode): - """Dynamic distributed task scheduler - - The scheduler tracks the current state of workers, data, and computations. - The scheduler listens for events and responds by controlling workers - appropriately. It continuously tries to use the workers to execute an ever - growing dask graph. - - All events are handled quickly, in linear time with respect to their input - (which is often of constant size) and generally within a millisecond. To - accomplish this the scheduler tracks a lot of state. Every operation - maintains the consistency of this state. - - The scheduler communicates with the outside world through Comm objects. - It maintains a consistent and valid view of the world even when listening - to several clients at once. - - A Scheduler is typically started either with the ``dask-scheduler`` - executable:: +@cclass +class SchedulerState: + """Underlying task state of dynamic scheduler - $ dask-scheduler - Scheduler started at 127.0.0.1:8786 + Tracks the current state of workers, data, and computations. - Or within a LocalCluster a Client starts up without connection - information:: + Handles transitions between different task states. Notifies the + Scheduler of changes by message passing through Queues, which the + Scheduler listens to and responds accordingly. - >>> c = Client() # doctest: +SKIP - >>> c.cluster.scheduler # doctest: +SKIP - Scheduler(...) + All events are handled quickly, in linear time with respect to their + input (which is often of constant size) and generally within a + millisecond. Additionally, when Cythonized, this can be faster still. + To accomplish this the scheduler tracks a lot of state. Every + operation maintains the consistency of this state.
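Illustrative sketch of how a driver loop might consume the ``(recommendations, worker_msgs, client_msgs)`` triple that the ``transition_*`` methods below return; ``state`` and the loop bodies here are assumed names for illustration, not part of the patch::

    recommendations, worker_msgs, client_msgs = state.transition_released_waiting(key)
    for key, finish in recommendations.items():
        ...  # feed each recommended (key, next-state) pair back into the transition engine
    for addr, msg in worker_msgs.items():
        ...  # send over that worker's batched stream comm
    for client, msg in client_msgs.items():
        ...  # report task status back to the interested client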
- Users typically do not interact with the scheduler directly but rather with - the client object ``Client``. + Users typically do not interact with ``Transitions`` directly. Instead + users interact with the ``Client``, which in turn engages the + ``Scheduler`` affecting different transitions here under-the-hood. In + the background ``Worker``s also engage with the ``Scheduler`` + affecting these state transitions as well. **State** - The scheduler contains the following state variables. Each variable is - listed along with what it stores and a brief description. + The ``Transitions`` object contains the following state variables. + Each variable is listed along with what it stores and a brief + description. * **tasks:** ``{task key: TaskState}`` Tasks currently known to the scheduler @@ -1577,3921 +1567,4241 @@ class Scheduler(ServerNode): * **saturated:** ``{WorkerState}``: Set of workers that are not over-utilized - * **host_info:** ``{hostname: dict}``: - Information about each worker host - * **clients:** ``{client key: ClientState}`` Clients currently connected to the scheduler - * **services:** ``{str: port}``: - Other services running on this scheduler, like Bokeh - * **loop:** ``IOLoop``: - The running Tornado IOLoop - * **client_comms:** ``{client key: Comm}`` - For each client, a Comm object used to receive task requests and - report task status updates. - * **stream_comms:** ``{worker key: Comm}`` - For each worker, a Comm object from which we both accept stimuli and - report results * **task_duration:** ``{key-prefix: time}`` Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` """ - default_port = 8786 - _instances = weakref.WeakSet() + _aliases: dict + _bandwidth: double + _clients: dict + _extensions: dict + _host_info: object + _idle: object + _idle_dv: dict + _n_tasks: Py_ssize_t + _resources: object + _saturated: set + _tasks: dict + _task_metadata: dict + _total_nthreads: Py_ssize_t + _total_occupancy: double + _unknown_durations: object + _unrunnable: set + _validate: bint + _workers: object + _workers_dv: dict def __init__( self, - loop=None, - delete_interval="500ms", - synchronize_worker_interval="60s", - services=None, - service_kwargs=None, - allowed_failures=None, - extensions=None, - validate=None, - scheduler_file=None, - security=None, - worker_ttl=None, - idle_timeout=None, - interface=None, - host=None, - port=0, - protocol=None, - dashboard_address=None, - dashboard=None, - http_prefix="/", - preload=None, - preload_argv=(), - plugins=(), + aliases: dict = None, + clients: dict = None, + workers=None, + host_info=None, + resources=None, + tasks: dict = None, + unrunnable: set = None, + validate: bint = False, **kwargs, ): - self._setup_logging(logger) - - # Attributes - if allowed_failures is None: - allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") - self.allowed_failures = allowed_failures - if validate is None: - validate = dask.config.get("distributed.scheduler.validate") - self.validate = validate - self.proc = psutil.Process() - self.delete_interval = parse_timedelta(delete_interval, default="ms") - self.synchronize_worker_interval = parse_timedelta( - synchronize_worker_interval, default="ms" - ) - self.digests = None - self.service_specs = services or {} - self.service_kwargs = service_kwargs or {} - self.services = {} - self.scheduler_file = scheduler_file - worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") - self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None - 
idle_timeout = idle_timeout or dask.config.get( - "distributed.scheduler.idle-timeout" + if aliases is not None: + self._aliases = aliases + else: + self._aliases = dict() + self._bandwidth = parse_bytes( + dask.config.get("distributed.scheduler.bandwidth") ) - if idle_timeout: - self.idle_timeout = parse_timedelta(idle_timeout) + if clients is not None: + self._clients = clients else: - self.idle_timeout = None - self.idle_since = time() - self.time_started = self.idle_since # compatibility for dask-gateway - self._lock = asyncio.Lock() - self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) - self.bandwidth_workers = defaultdict(float) - self.bandwidth_types = defaultdict(float) + self._clients = dict() + self._clients["fire-and-forget"] = ClientState("fire-and-forget") + self._extensions = dict() + if host_info is not None: + self._host_info = host_info + else: + self._host_info = defaultdict(dict) + self._idle = sortedcontainers.SortedDict() + self._idle_dv: dict = cast(dict, self._idle) + self._n_tasks = 0 + if resources is not None: + self._resources = resources + else: + self._resources = defaultdict(dict) + self._saturated = set() + if tasks is not None: + self._tasks = tasks + else: + self._tasks = dict() + self._task_metadata = dict() + self._total_nthreads = 0 + self._total_occupancy = 0 + self._unknown_durations = defaultdict(set) + if unrunnable is not None: + self._unrunnable = unrunnable + else: + self._unrunnable = set() + self._validate = validate + if workers is not None: + self._workers = workers + else: + self._workers = sortedcontainers.SortedDict() + self._workers_dv: dict = cast(dict, self._workers) + super().__init__(**kwargs) - if not preload: - preload = dask.config.get("distributed.scheduler.preload") - if not preload_argv: - preload_argv = dask.config.get("distributed.scheduler.preload-argv") - self.preloads = preloading.process_preloads(self, preload, preload_argv) + @property + def aliases(self): + return self._aliases - if isinstance(security, dict): - security = Security(**security) - self.security = security or Security() - assert isinstance(self.security, Security) - self.connection_args = self.security.get_connection_args("scheduler") - self.connection_args["handshake_overrides"] = { # common denominator - "pickle-protocol": 4 - } + @property + def bandwidth(self): + return self._bandwidth - self._start_address = addresses_from_user_args( - host=host, - port=port, - interface=interface, - protocol=protocol, - security=security, - default_port=self.default_port, - ) + @property + def clients(self): + return self._clients - http_server_modules = dask.config.get("distributed.scheduler.http.routes") - show_dashboard = dashboard or (dashboard is None and dashboard_address) - missing_bokeh = False - # install vanilla route if show_dashboard but bokeh is not installed - if show_dashboard: - try: - import distributed.dashboard.scheduler - except ImportError: - missing_bokeh = True - http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers( - server=self, modules=http_server_modules, prefix=http_prefix - ) - self.start_http_server(routes, dashboard_address, default_port=8787) - if show_dashboard and not missing_bokeh: - distributed.dashboard.scheduler.connect( - self.http_application, self.http_server, self, prefix=http_prefix - ) + @property + def extensions(self): + return self._extensions - # Communication state - self.loop = loop or IOLoop.current() - self.client_comms = dict() - 
self.stream_comms = dict() - self._worker_coroutines = [] - self._ipython_kernel = None + @property + def host_info(self): + return self._host_info - # Task state - self.tasks = dict() - self.task_groups = dict() - self.task_prefixes = dict() - for old_attr, new_attr, wrap in [ - ("priority", "priority", None), - ("dependencies", "dependencies", _legacy_task_key_set), - ("dependents", "dependents", _legacy_task_key_set), - ("retries", "retries", None), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.tasks, func)) + @property + def idle(self): + return self._idle - for old_attr, new_attr, wrap in [ - ("nbytes", "nbytes", None), - ("who_wants", "who_wants", _legacy_client_key_set), - ("who_has", "who_has", _legacy_worker_key_set), - ("waiting", "waiting_on", _legacy_task_key_set), - ("waiting_data", "waiters", _legacy_task_key_set), - ("rprocessing", "processing_on", None), - ("host_restrictions", "host_restrictions", None), - ("worker_restrictions", "worker_restrictions", None), - ("resource_restrictions", "resource_restrictions", None), - ("suspicious_tasks", "suspicious", None), - ("exceptions", "exception", None), - ("tracebacks", "traceback", None), - ("exceptions_blame", "exception_blame", _task_key_or_none), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _OptionalStateLegacyMapping(self.tasks, func)) + @property + def n_tasks(self): + return self._n_tasks - for old_attr, new_attr, wrap in [ - ("loose_restrictions", "loose_restrictions", None) - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacySet(self.tasks, func)) + @property + def resources(self): + return self._resources - self.generation = 0 - self._last_client = None - self._last_time = 0 - self.unrunnable = set() + @property + def saturated(self): + return self._saturated - self.n_tasks = 0 - self.task_metadata = dict() - self.datasets = dict() + @property + def tasks(self): + return self._tasks - # Prefix-keyed containers - self.unknown_durations = defaultdict(set) + @property + def task_metadata(self): + return self._task_metadata - # Client state - self.clients = dict() - for old_attr, new_attr, wrap in [ - ("wants_what", "wants_what", _legacy_task_key_set) - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.clients, func)) - self.clients["fire-and-forget"] = ClientState("fire-and-forget") + @property + def total_nthreads(self): + return self._total_nthreads - # Worker state - self.workers = sortedcontainers.SortedDict() - for old_attr, new_attr, wrap in [ - ("nthreads", "nthreads", None), - ("worker_bytes", "nbytes", None), - ("worker_resources", "resources", None), - ("used_resources", "used_resources", None), - ("occupancy", "occupancy", None), - ("worker_info", "metrics", None), - ("processing", "processing", _legacy_task_key_dict), - ("has_what", "has_what", _legacy_task_key_set), - ]: - func = operator.attrgetter(new_attr) - if wrap is not None: - func = compose(wrap, func) - setattr(self, old_attr, _StateLegacyMapping(self.workers, func)) + @property + def total_occupancy(self): + return self._total_occupancy - self.idle = sortedcontainers.SortedDict() - self.saturated = set() + @total_occupancy.setter + def total_occupancy(self, v: double): + self._total_occupancy = v - 
self.total_nthreads = 0 - self.total_occupancy = 0 - self.host_info = defaultdict(dict) - self.resources = defaultdict(dict) - self.aliases = dict() + @property + def unknown_durations(self): + return self._unknown_durations - self._task_state_collections = [self.unrunnable] + @property + def unrunnable(self): + return self._unrunnable - self._worker_collections = [ - self.workers, - self.host_info, - self.resources, - self.aliases, - ] + @property + def validate(self): + return self._validate - self.extensions = {} - self.plugins = list(plugins) - self.transition_log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) - self.log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) - self.events = defaultdict(lambda: deque(maxlen=100000)) - self.event_counts = defaultdict(int) - self.worker_plugins = [] + @validate.setter + def validate(self, v: bint): + self._validate = v - worker_handlers = { - "task-finished": self.handle_task_finished, - "task-erred": self.handle_task_erred, - "release": self.handle_release_data, - "release-worker-data": self.release_worker_data, - "add-keys": self.add_keys, - "missing-data": self.handle_missing_data, - "long-running": self.handle_long_running, - "reschedule": self.reschedule, - "keep-alive": lambda *args, **kwargs: None, - "log-event": self.log_worker_event, - } + @property + def workers(self): + return self._workers - client_handlers = { - "update-graph": self.update_graph, - "update-graph-hlg": self.update_graph_hlg, - "client-desires-keys": self.client_desires_keys, - "update-data": self.update_data, - "report-key": self.report_on_key, - "client-releases-keys": self.client_releases_keys, - "heartbeat-client": self.client_heartbeat, - "close-client": self.remove_client, - "restart": self.restart, + @property + def __pdict__(self): + return { + "bandwidth": self._bandwidth, + "resources": self._resources, + "saturated": self._saturated, + "unrunnable": self._unrunnable, + "n_tasks": self._n_tasks, + "unknown_durations": self._unknown_durations, + "validate": self._validate, + "tasks": self._tasks, + "total_nthreads": self._total_nthreads, + "total_occupancy": self._total_occupancy, + "extensions": self._extensions, + "clients": self._clients, + "workers": self._workers, + "idle": self._idle, + "host_info": self._host_info, } - self.handlers = { - "register-client": self.add_client, - "scatter": self.scatter, - "register-worker": self.add_worker, - "unregister": self.remove_worker, - "gather": self.gather, - "cancel": self.stimulus_cancel, - "retry": self.stimulus_retry, - "feed": self.feed, - "terminate": self.close, - "broadcast": self.broadcast, - "proxy": self.proxy, - "ncores": self.get_ncores, - "has_what": self.get_has_what, - "who_has": self.get_who_has, - "processing": self.get_processing, - "call_stack": self.get_call_stack, - "profile": self.get_profile, - "performance_report": self.performance_report, - "get_logs": self.get_logs, - "logs": self.get_logs, - "worker_logs": self.get_worker_logs, - "log_event": self.log_worker_event, - "events": self.get_events, - "nbytes": self.get_nbytes, - "versions": self.versions, - "add_keys": self.add_keys, - "rebalance": self.rebalance, - "replicate": self.replicate, - "start_ipython": self.start_ipython, - "run_function": self.run_function, - "update_data": self.update_data, - "set_resources": self.add_resources, - "retire_workers": self.retire_workers, - "get_metadata": self.get_metadata, - "set_metadata": self.set_metadata, - 
"heartbeat_worker": self.heartbeat_worker, - "get_task_status": self.get_task_status, - "get_task_stream": self.get_task_stream, - "register_worker_plugin": self.register_worker_plugin, - "adaptive_target": self.adaptive_target, - "workers_to_close": self.workers_to_close, - "subscribe_worker_status": self.subscribe_worker_status, - "start_task_metadata": self.start_task_metadata, - "stop_task_metadata": self.stop_task_metadata, - } + def transition_released_waiting(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - self._transitions = { - ("released", "waiting"): self.transition_released_waiting, - ("waiting", "released"): self.transition_waiting_released, - ("waiting", "processing"): self.transition_waiting_processing, - ("waiting", "memory"): self.transition_waiting_memory, - ("processing", "released"): self.transition_processing_released, - ("processing", "memory"): self.transition_processing_memory, - ("processing", "erred"): self.transition_processing_erred, - ("no-worker", "released"): self.transition_no_worker_released, - ("no-worker", "waiting"): self.transition_no_worker_waiting, - ("released", "forgotten"): self.transition_released_forgotten, - ("memory", "forgotten"): self.transition_memory_forgotten, - ("erred", "forgotten"): self.transition_released_forgotten, - ("erred", "released"): self.transition_erred_released, - ("memory", "released"): self.transition_memory_released, - ("released", "erred"): self.transition_released_erred, - } + if self._validate: + assert ts._run_spec + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([dts._state == "forgotten" for dts in ts._dependencies]) - connection_limit = get_fileno_limit() / 2 + if ts._has_lost_dependencies: + return {key: "forgotten"}, worker_msgs, client_msgs - super().__init__( - handlers=self.handlers, - stream_handlers=merge(worker_handlers, client_handlers), - io_loop=self.loop, - connection_limit=connection_limit, - deserialize=False, - connection_args=self.connection_args, - **kwargs, - ) + ts.state = "waiting" - if self.worker_ttl: - pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl) - self.periodic_callbacks["worker-ttl"] = pc + recommendations: dict = {} - if self.idle_timeout: - pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4) - self.periodic_callbacks["idle-timeout"] = pc + dts: TaskState + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[key] = "erred" + return recommendations, worker_msgs, client_msgs - if extensions is None: - extensions = list(DEFAULT_EXTENSIONS) - if dask.config.get("distributed.scheduler.work-stealing"): - extensions.append(WorkStealing) - for ext in extensions: - ext(self) + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) + if dts._state == "released": + recommendations[dep] = "waiting" + else: + dts._waiters.add(ts) - setproctitle("dask-scheduler [not started]") - Scheduler._instances.add(self) - self.rpc.allow_offload = False - self.status = Status.undefined + ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"} - ################## - # Administration # - ################## + if not ts._waiting_on: + if self._workers_dv: + recommendations[key] = "processing" + else: + self._unrunnable.add(ts) + ts.state = "no-worker" - def __repr__(self): - return '' % ( - self.address, - len(self.workers), - 
self.total_nthreads, - ) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - def identity(self, comm=None): - """ Basic information about ourselves and our cluster """ - d = { - "type": type(self).__name__, - "id": str(self.id), - "address": self.address, - "services": {key: v.port for (key, v) in self.services.items()}, - "started": self.time_started, - "workers": { - worker.address: worker.identity() for worker in self.workers.values() - }, - } - return d + pdb.set_trace() + raise - def get_worker_service_addr(self, worker, service_name, protocol=False): - """ - Get the (host, port) address of the named service on the *worker*. - Returns None if the service doesn't exist. + def transition_no_worker_waiting(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - Parameters - ---------- - worker : address - service_name : str - Common services include 'bokeh' and 'nanny' - protocol : boolean - Whether or not to include a full address with protocol (True) - or just a (host, port) pair - """ - ws: WorkerState = self.workers[worker] - port = ws._services.get(service_name) - if port is None: - return None - elif protocol: - return "%(protocol)s://%(host)s:%(port)d" % { - "protocol": ws._address.split("://")[0], - "host": ws.host, - "port": port, - } - else: - return ws.host, port + if self._validate: + assert ts in self._unrunnable + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on - async def start(self): - """ Clear out old state and restart all running coroutines """ - await super().start() - assert self.status != Status.running + self._unrunnable.remove(ts) - enable_gc_diagnosis() + if ts._has_lost_dependencies: + return {key: "forgotten"}, worker_msgs, client_msgs - self.clear_task_state() + recommendations: dict = {} - with suppress(AttributeError): - for c in self._worker_coroutines: - c.cancel() + for dts in ts._dependencies: + dep = dts._key + if not dts._who_has: + ts._waiting_on.add(dts) + if dts._state == "released": + recommendations[dep] = "waiting" + else: + dts._waiters.add(ts) - for addr in self._start_address: - await self.listen( - addr, - allow_offload=False, - handshake_overrides={"pickle-protocol": 4, "compression": None}, - **self.security.get_listen_args("scheduler"), - ) - self.ip = get_address_host(self.listen_address) - listen_ip = self.ip + ts.state = "waiting" - if listen_ip == "0.0.0.0": - listen_ip = "" + if not ts._waiting_on: + if self._workers_dv: + recommendations[key] = "processing" + else: + self._unrunnable.add(ts) + ts.state = "no-worker" - if self.address.startswith("inproc://"): - listen_ip = "localhost" + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - # Services listen on all addresses - self.start_services(listen_ip) + pdb.set_trace() + raise - for listener in self.listeners: - logger.info(" Scheduler at: %25s", listener.contact_address) - for k, v in self.services.items(): - logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) + @ccall + @exceptval(check=False) + def decide_worker(self, ts: TaskState) -> WorkerState: + """ + Decide on a worker for task *ts*. Return a WorkerState. 
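Roughly, the selection logic in the body that follows reduces to the sketch below for unrestricted, dependency-free tasks (``pool`` and ``n_tasks`` stand in for ``self._idle``/``self._workers`` and ``self._n_tasks``; tasks with dependencies or worker restrictions instead go through ``decide_worker`` with ``worker_objective``)::

    pool = idle or workers                      # prefer idle workers when any exist
    if len(pool) < 20:                          # small pool: linear scan for the least busy worker
        ws = min(pool.values(), key=operator.attrgetter("occupancy"))
    else:                                       # large pool: cheap round-robin by task count
        ws = pool.values()[n_tasks % len(pool)]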
+ """ + ws: WorkerState = None + valid_workers: set = self.valid_workers(ts) - self.loop.add_callback(self.reevaluate_occupancy) + if ( + valid_workers is not None + and not valid_workers + and not ts._loose_restrictions + and self._workers_dv + ): + self._unrunnable.add(ts) + ts.state = "no-worker" + return ws - if self.scheduler_file: - with open(self.scheduler_file, "w") as f: - json.dump(self.identity(), f, indent=2) - - fn = self.scheduler_file # remove file when we close the process - - def del_scheduler_file(): - if os.path.exists(fn): - os.remove(fn) - - weakref.finalize(self, del_scheduler_file) + if ts._dependencies or valid_workers is not None: + ws = decide_worker( + ts, + self._workers_dv.values(), + valid_workers, + partial(self.worker_objective, ts), + ) + else: + worker_pool = self._idle or self._workers + worker_pool_dv = cast(dict, worker_pool) + n_workers: Py_ssize_t = len(worker_pool_dv) + if n_workers < 20: # smart but linear in small case + ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) + else: # dumb but fast in large case + ws = worker_pool.values()[self._n_tasks % n_workers] - for preload in self.preloads: - await preload.start() + if self._validate: + assert ws is None or isinstance(ws, WorkerState), ( + type(ws), + ws, + ) + assert ws._address in self._workers_dv - await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) + return ws - self.start_periodic_callbacks() + @ccall + def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: + """Estimate task duration using worker state and task state. - setproctitle("dask-scheduler [%s]" % (self.address,)) - return self + If a task takes longer than twice the current average duration we + estimate the task duration to be 2x current-runtime, otherwise we set it + to be the average duration. 
+ """ + exec_time: double = ws._executing.get(ts, 0) + duration: double = self.get_task_duration(ts) + total_duration: double + if exec_time > 2 * duration: + total_duration = 2 * exec_time + else: + comm: double = self.get_comm_cost(ts, ws) + total_duration = duration + comm + ws._processing[ts] = total_duration + return total_duration - async def close(self, comm=None, fast=False, close_workers=False): - """Send cleanup signal to all coroutines then wait until finished + def transition_waiting_processing(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - See Also - -------- - Scheduler.cleanup - """ - if self.status in (Status.closing, Status.closed, Status.closing_gracefully): - await self.finished() - return - self.status = Status.closing + if self._validate: + assert not ts._waiting_on + assert not ts._who_has + assert not ts._exception_blame + assert not ts._processing_on + assert not ts._has_lost_dependencies + assert ts not in self._unrunnable + assert all([dts._who_has for dts in ts._dependencies]) - logger.info("Scheduler closing...") - setproctitle("dask-scheduler [closing]") + ws: WorkerState = self.decide_worker(ts) + if ws is None: + return {}, worker_msgs, client_msgs + worker = ws._address - for preload in self.preloads: - await preload.teardown() + duration_estimate = self.set_duration_estimate(ts, ws) + ts._processing_on = ws + ws._occupancy += duration_estimate + self._total_occupancy += duration_estimate + ts.state = "processing" + self.consume_resources(ts, ws) + self.check_idle_saturated(ws) + self._n_tasks += 1 - if close_workers: - await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in self.workers: - self.worker_send(worker, {"op": "close"}) - for i in range(20): # wait a second for send signals to clear - if self.workers: - await asyncio.sleep(0.05) - else: - break + if ts._actor: + ws._actors.add(ts) - await asyncio.gather(*[plugin.close() for plugin in self.plugins]) + # logger.debug("Send job to worker: %s, %s", worker, key) - for pc in self.periodic_callbacks.values(): - pc.stop() - self.periodic_callbacks.clear() + worker_msgs[worker] = _task_to_msg(self, ts) - self.stop_services() + return {}, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - for ext in self.extensions.values(): - with suppress(AttributeError): - ext.teardown() - logger.info("Scheduler closing all comms") + pdb.set_trace() + raise - futures = [] - for w, comm in list(self.stream_comms.items()): - if not comm.closed(): - comm.send({"op": "close", "report": False}) - comm.send({"op": "close-stream"}) - with suppress(AttributeError): - futures.append(comm.close()) + def transition_waiting_memory( + self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + ): + try: + ws: WorkerState = self._workers_dv[worker] + ts: TaskState = self._tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - for future in futures: # TODO: do all at once - await future + if self._validate: + assert not ts._processing_on + assert ts._waiting_on + assert ts._state == "waiting" - for comm in self.client_comms.values(): - comm.abort() + ts._waiting_on.clear() - await self.rpc.close() + if nbytes is not None: + ts.set_nbytes(nbytes) - self.status = Status.closed - self.stop() - await super().close() + self.check_idle_saturated(ws) - setproctitle("dask-scheduler [closed]") - disable_gc_diagnosis() + recommendations: dict = {} + client_msgs: dict 
= {} - async def close_worker(self, comm=None, worker=None, safe=None): - """Remove a worker from the cluster + _add_to_memory( + self, ts, ws, recommendations, client_msgs, type=type, typename=typename + ) - This both removes the worker from our local state and also sends a - signal to the worker to shut down. This works regardless of whether or - not the worker has a nanny process restarting it - """ - logger.info("Closing worker %s", worker) - with log_errors(): - self.log_event(worker, {"action": "close-worker"}) - ws: WorkerState = self.workers[worker] - nanny_addr = ws._nanny - address = nanny_addr or worker + if self._validate: + assert not ts._processing_on + assert not ts._waiting_on + assert ts._who_has - self.worker_send(worker, {"op": "close", "report": False}) - await self.remove_worker(address=worker, safe=safe) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - ########### - # Stimuli # - ########### + pdb.set_trace() + raise - def heartbeat_worker( + def transition_processing_memory( self, - comm=None, - address=None, - resolve_address=True, - now=None, - resources=None, - host_info=None, - metrics=None, - executing=None, + key, + nbytes=None, + type=None, + typename: str = None, + worker=None, + startstops=None, + **kwargs, ): - address = self.coerce_address(address, resolve_address) - address = normalize_address(address) - if address not in self.workers: - return {"status": "missing"} + ws: WorkerState + wws: WorkerState + worker_msgs: dict = {} + client_msgs: dict = {} + try: + ts: TaskState = self._tasks[key] + assert worker + assert isinstance(worker, str) - host = get_address_host(address) - local_now = time() - now = now or time() - assert metrics - host_info = host_info or {} + if self._validate: + assert ts._processing_on + ws = ts._processing_on + assert ts in ws._processing + assert not ts._waiting_on + assert not ts._who_has, (ts, ts._who_has) + assert not ts._exception_blame + assert ts._state == "processing" - self.host_info[host]["last-seen"] = local_now - frac = 1 / len(self.workers) - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac - ) - for other, (bw, count) in metrics["bandwidth"]["workers"].items(): - if (address, other) not in self.bandwidth_workers: - self.bandwidth_workers[address, other] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[ - address, other - ] * alpha + bw * (1 - alpha) - for typ, (bw, count) in metrics["bandwidth"]["types"].items(): - if typ not in self.bandwidth_types: - self.bandwidth_types[typ] = bw / count - else: - alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( - 1 - alpha + ws = self._workers_dv.get(worker) + if ws is None: + return {key: "released"}, worker_msgs, client_msgs + + if ws != ts._processing_on: # someone else has this task + logger.info( + "Unexpected worker completed task, likely due to" + " work stealing. 
Expected: %s, Got: %s, Key: %s", + ts._processing_on, + ws, + key, ) + return {}, worker_msgs, client_msgs - ws: WorkerState = self.workers[address] + if startstops: + L = list() + for startstop in startstops: + stop = startstop["stop"] + start = startstop["start"] + action = startstop["action"] + if action == "compute": + L.append((start, stop)) - ws._last_seen = time() + # record timings of all actions -- a cheaper way of + # getting timing info compared with get_task_stream() + ts._prefix._all_durations[action] += stop - start - if executing is not None: - ws._executing = { - self.tasks[key]: duration for key, duration in executing.items() - } + if len(L) > 0: + compute_start, compute_stop = L[0] + else: # This is very rare + compute_start = compute_stop = None + else: + compute_start = compute_stop = None - if metrics: - ws._metrics = metrics + ############################# + # Update Timing Information # + ############################# + if compute_start and ws._processing.get(ts, True): + # Update average task duration for worker + old_duration = ts._prefix._duration_average + new_duration = compute_stop - compute_start + if old_duration < 0: + avg_duration = new_duration + else: + avg_duration = 0.5 * old_duration + 0.5 * new_duration - if host_info: - self.host_info[host].update(host_info) - - delay = time() - now - ws._time_delay = delay + ts._prefix._duration_average = avg_duration + ts._group._duration += new_duration - if resources: - self.add_resources(worker=address, resources=resources) + tts: TaskState + for tts in self._unknown_durations.pop(ts._prefix._name, ()): + if tts._processing_on: + wws = tts._processing_on + old = wws._processing[tts] + comm = self.get_comm_cost(tts, wws) + wws._processing[tts] = avg_duration + comm + wws._occupancy += avg_duration + comm - old + self._total_occupancy += avg_duration + comm - old - self.log_event(address, merge({"action": "heartbeat"}, metrics)) + ############################ + # Update State Information # + ############################ + if nbytes is not None: + ts.set_nbytes(nbytes) - return { - "status": "OK", - "time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - } + recommendations: dict = {} + client_msgs: dict = {} - async def add_worker( - self, - comm=None, - address=None, - keys=(), - nthreads=None, - name=None, - resolve_address=True, - nbytes=None, - types=None, - now=None, - resources=None, - host_info=None, - memory_limit=None, - metrics=None, - pid=0, - services=None, - local_directory=None, - versions=None, - nanny=None, - extra=None, - ): - """ Add a new worker to the cluster """ - with log_errors(): - address = self.coerce_address(address, resolve_address) - address = normalize_address(address) - host = get_address_host(address) + _remove_from_processing(self, ts) - ws: WorkerState = self.workers.get(address) - if ws is not None: - raise ValueError("Worker already exists %s" % ws) + _add_to_memory( + self, ts, ws, recommendations, client_msgs, type=type, typename=typename + ) - if name in self.aliases: - logger.warning( - "Worker tried to connect with a duplicate name: %s", name - ) - msg = { - "status": "error", - "message": "name taken, %s" % name, - "time": time(), - } - if comm: - await comm.write(msg) - return + if self._validate: + assert not ts._processing_on + assert not ts._waiting_on - self.workers[address] = ws = WorkerState( - address=address, - pid=pid, - nthreads=nthreads, - memory_limit=memory_limit or 0, - name=name, - local_directory=local_directory, - 
services=services, - versions=versions, - nanny=nanny, - extra=extra, - ) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if "addresses" not in self.host_info[host]: - self.host_info[host].update({"addresses": set(), "nthreads": 0}) + pdb.set_trace() + raise - self.host_info[host]["addresses"].add(address) - self.host_info[host]["nthreads"] += nthreads + def transition_memory_released(self, key, safe: bint = False): + ws: WorkerState + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - self.total_nthreads += nthreads - self.aliases[name] = address + if self._validate: + assert not ts._waiting_on + assert not ts._processing_on + if safe: + assert not ts._waiters - response = self.heartbeat_worker( - address=address, - resolve_address=resolve_address, - now=now, - resources=resources, - host_info=host_info, - metrics=metrics, - ) + if ts._actor: + for ws in ts._who_has: + ws._actors.discard(ts) + if ts._who_wants: + ts._exception_blame = ts + ts._exception = "Worker holding Actor was lost" + return ( + {ts._key: "erred"}, + worker_msgs, + client_msgs, + ) # don't try to recreate - # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot exist before this. - self.check_idle_saturated(ws) + recommendations: dict = {} - # for key in keys: # TODO - # self.mark_key_in_memory(key, [address]) + for dts in ts._waiters: + if dts._state in ("no-worker", "processing"): + recommendations[dts._key] = "waiting" + elif dts._state == "waiting": + dts._waiting_on.add(ts) - self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) + # XXX factor this out? + for ws in ts._who_has: + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + ts._group._nbytes_in_memory -= ts.get_nbytes() + worker_msgs[ws._address] = { + "op": "delete-data", + "keys": [key], + "report": False, + } - if ws._nthreads > len(ws._processing): - self.idle[ws._address] = ws + ts._who_has.clear() - for plugin in self.plugins[:]: - try: - result = plugin.add_worker(scheduler=self, worker=address) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) + ts.state = "released" - recommendations: dict - if nbytes: - for key in nbytes: - tasks: dict = self.tasks - ts: TaskState = tasks.get(key) - if ts is not None and ts._state in ("processing", "waiting"): - recommendations = self.transition( - key, - "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], - ) - self.transitions(recommendations) + report_msg = {"op": "lost-data", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - recommendations = {} - for ts in list(self.unrunnable): - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recommendations[ts._key] = "waiting" + if not ts._run_spec: # pure data + recommendations[key] = "forgotten" + elif ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif ts._who_wants or ts._waiters: + recommendations[key] = "waiting" - if recommendations: - self.transitions(recommendations) + if self._validate: + assert not ts._waiting_on - self.log_event(address, {"action": "add-worker"}) - self.log_event("all", {"action": "add-worker", "worker": address}) - logger.info("Register worker %s", ws) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - msg = { - "status": "OK", - 
"time": time(), - "heartbeat-interval": heartbeat_interval(len(self.workers)), - "worker-plugins": self.worker_plugins, - } + pdb.set_trace() + raise - cs: ClientState - version_warning = version_module.error_message( - version_module.get_versions(), - merge( - {w: ws._versions for w, ws in self.workers.items()}, - {c: cs._versions for c, cs in self.clients.items() if cs._versions}, - ), - versions, - client_name="This Worker", - ) - msg.update(version_warning) + def transition_released_erred(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + failing_ts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - if comm: - await comm.write(msg) - await self.handle_worker(comm=comm, worker=address) + if self._validate: + with log_errors(pdb=LOG_PDB): + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters - def update_graph_hlg( - self, - client=None, - hlg=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - ): + recommendations: dict = {} - dsk, dependencies, annotations = highlevelgraph_unpack(hlg) + failing_ts = ts._exception_blame - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps + for dts in ts._dependents: + dts._exception_blame = failing_ts + if not dts._who_has: + recommendations[dts._key] = "erred" - if priority is None: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, } - priority = dask.order.order(dsk, dependencies=stripped_deps) + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - return self.update_graph( - client, - dsk, - keys, - dependencies, - restrictions, - priority, - loose_restrictions, - resources, - submitting_task, - retries, - user_priority, - actors, - fifo_timeout, - annotations, - ) + ts.state = "erred" - def update_graph( - self, - client=None, - tasks=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - annotations=None, - ): - """ - Add new computations to the internal dask graph + # TODO: waiting data? + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - This happens whenever the Client calls submit, map, get, or compute. 
- """ - start = time() - fifo_timeout = parse_timedelta(fifo_timeout) - keys = set(keys) - if len(tasks) > 1: - self.log_event( - ["all", client], {"action": "update_graph", "count": len(tasks)} - ) + pdb.set_trace() + raise - # Remove aliases - for k in list(tasks): - if tasks[k] is k: - del tasks[k] + def transition_erred_released(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - dependencies = dependencies or {} + if self._validate: + with log_errors(pdb=LOG_PDB): + assert all([dts._state != "erred" for dts in ts._dependencies]) + assert ts._exception_blame + assert not ts._who_has + assert not ts._waiting_on + assert not ts._waiters - n = 0 - while len(tasks) != n: # walk through new tasks, cancel any bad deps - n = len(tasks) - for k, deps in list(dependencies.items()): - if any( - dep not in self.tasks and dep not in tasks for dep in deps - ): # bad key - logger.info("User asked for computation on lost data, %s", k) - del tasks[k] - del dependencies[k] - if k in keys: - keys.remove(k) - self.report({"op": "cancelled-key", "key": k}, client=client) - self.client_releases_keys(keys=[k], client=client) + recommendations: dict = {} - # Avoid computation that is already finished - ts: TaskState - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in self.tasks: - ts = self.tasks[k] - if ts._state in ("memory", "erred"): - already_in_memory.add(k) + ts._exception = None + ts._exception_blame = None + ts._traceback = None - dts: TaskState - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - done = set(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - ts = self.tasks[key] - try: - deps = dependencies[key] - except KeyError: - deps = self.dependencies[key] - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - else: - child_deps = self.dependencies[dep] - if all(d in done for d in child_deps): - if dep in self.tasks and dep not in done: - done.add(dep) - stack.append(dep) + for dts in ts._dependents: + if dts._state == "erred": + recommendations[dts._key] = "waiting" - for d in done: - tasks.pop(d, None) - dependencies.pop(d, None) + report_msg = {"op": "task-retried", "key": key} + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - # Get or create task states - stack = list(keys) - touched_keys = set() - touched_tasks = [] - while stack: - k = stack.pop() - if k in touched_keys: - continue - # XXX Have a method get_task_state(self, k) ? 
- ts = self.tasks.get(k) - if ts is None: - ts = self.new_task(k, tasks.get(k), "released") - elif not ts._run_spec: - ts._run_spec = tasks.get(k) + ts.state = "released" - touched_keys.add(k) - touched_tasks.append(ts) - stack.extend(dependencies.get(k, ())) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - self.client_desires_keys(keys=keys, client=client) + pdb.set_trace() + raise - # Add dependencies - for key, deps in dependencies.items(): - ts = self.tasks.get(key) - if ts is None or ts._dependencies: - continue - for dep in deps: - dts = self.tasks[dep] - ts.add_dependency(dts) + def transition_waiting_released(self, key): + try: + ts: TaskState = self._tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - # Compute priorities - if isinstance(user_priority, Number): - user_priority = {k: user_priority for k in tasks} + if self._validate: + assert not ts._who_has + assert not ts._processing_on - annotations = annotations or {} - restrictions = restrictions or {} - loose_restrictions = loose_restrictions or [] - resources = resources or {} - retries = retries or {} + recommendations: dict = {} - # Override existing taxonomy with per task annotations - if annotations: - if "priority" in annotations: - user_priority.update(annotations["priority"]) + dts: TaskState + for dts in ts._dependencies: + s = dts._waiters + if ts in s: + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + ts._waiting_on.clear() - if "workers" in annotations: - restrictions.update(annotations["workers"]) + ts.state = "released" - if "allow_other_workers" in annotations: - loose_restrictions.extend( - k for k, v in annotations["allow_other_workers"].items() if v - ) + if ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif not ts._exception_blame and (ts._who_wants or ts._waiters): + recommendations[key] = "waiting" + else: + ts._waiters.clear() - if "retries" in annotations: - retries.update(annotations["retries"]) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if "resources" in annotations: - resources.update(annotations["resources"]) + pdb.set_trace() + raise - for a, kv in annotations.items(): - for k, v in kv.items(): - ts = self.tasks[k] - ts._annotations[a] = v + def transition_processing_released(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - # Add actors - if actors is True: - actors = list(keys) - for actor in actors or []: - ts = self.tasks[actor] - ts._actor = True + if self._validate: + assert ts._processing_on + assert not ts._who_has + assert not ts._waiting_on + assert self._tasks[key].state == "processing" - priority = priority or dask.order.order( - tasks - ) # TODO: define order wrt old graph + w: str = _remove_from_processing(self, ts) + if w: + worker_msgs[w] = {"op": "release-task", "key": key} - if submitting_task: # sub-tasks get better priority than parent tasks - ts = self.tasks.get(submitting_task) - if ts is not None: - generation = ts._priority[0] - 0.01 - else: # super-task already cleaned up - generation = self.generation - elif self._last_time + fifo_timeout < start: - self.generation += 1 # older graph generations take precedence - generation = self.generation - self._last_time = start - else: - generation = self.generation + ts.state = "released" - for key in set(priority) & 
touched_keys: - ts = self.tasks[key] - if ts._priority is None: - ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) + recommendations: dict = {} - # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks if ts._run_spec] - for ts in runnables: - if ts._priority is None and ts._run_spec: - ts._priority = (self.generation, 0) + if ts._has_lost_dependencies: + recommendations[key] = "forgotten" + elif ts._waiters or ts._who_wants: + recommendations[key] = "waiting" - if restrictions: - # *restrictions* is a dict keying task ids to lists of - # restriction specifications (either worker names or addresses) - for k, v in restrictions.items(): - if v is None: - continue - ts = self.tasks.get(k) - if ts is None: - continue - ts._host_restrictions = set() - ts._worker_restrictions = set() - for w in v: - try: - w = self.coerce_address(w) - except ValueError: - # Not a valid address, but perhaps it's a hostname - ts._host_restrictions.add(w) - else: - ts._worker_restrictions.add(w) + if recommendations.get(key) != "waiting": + for dts in ts._dependencies: + if dts._state != "released": + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + ts._waiters.clear() - if loose_restrictions: - for k in loose_restrictions: - ts = self.tasks[k] - ts._loose_restrictions = True + if self._validate: + assert not ts._processing_on - if resources: - for k, v in resources.items(): - if v is None: - continue - assert isinstance(v, dict) - ts = self.tasks.get(k) - if ts is None: - continue - ts._resource_restrictions = v + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if retries: - for k, v in retries.items(): - assert isinstance(v, int) - ts = self.tasks.get(k) - if ts is None: - continue - ts._retries = v + pdb.set_trace() + raise - # Compute recommendations - recommendations: dict = {} + def transition_processing_erred( + self, key, cause=None, exception=None, traceback=None, **kwargs + ): + ws: WorkerState + try: + ts: TaskState = self._tasks[key] + dts: TaskState + failing_ts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): - if ts._state == "released" and ts._run_spec: - recommendations[ts._key] = "waiting" + if self._validate: + assert cause or ts._exception_blame + assert ts._processing_on + assert not ts._who_has + assert not ts._waiting_on - for ts in touched_tasks: - for dts in ts._dependencies: - if dts._exception_blame: - ts._exception_blame = dts._exception_blame - recommendations[ts._key] = "erred" - break + if ts._actor: + ws = ts._processing_on + ws._actors.remove(ts) - for plugin in self.plugins[:]: - try: - plugin.update_graph( - self, - client=client, - tasks=tasks, - keys=keys, - restrictions=restrictions or {}, - dependencies=dependencies, - priority=priority, - loose_restrictions=loose_restrictions, - resources=resources, - annotations=annotations, - ) - except Exception as e: - logger.exception(e) + _remove_from_processing(self, ts) - self.transitions(recommendations) + if exception is not None: + ts._exception = exception + if traceback is not None: + ts._traceback = traceback + if cause is not None: + failing_ts = self._tasks[cause] + ts._exception_blame = failing_ts + else: + failing_ts = ts._exception_blame - for ts in touched_tasks: - if ts._state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + 
recommendations: dict = {} - end = time() - if self.digests is not None: - self.digests["update-graph-duration"].add(end - start) + for dts in ts._dependents: + dts._exception_blame = failing_ts + recommendations[dts._key] = "erred" - # TODO: balance workers + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" - def new_task(self, key, spec, state): - """ Create a new task, and associated states """ - ts: TaskState = TaskState(key, spec) - tp: TaskPrefix - tg: TaskGroup - ts._state = state - prefix_key = key_split(key) - try: - tp = self.task_prefixes[prefix_key] - except KeyError: - self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) - ts._prefix = tp + ts._waiters.clear() # do anything with this? - group_key = ts._group_key - try: - tg = self.task_groups[group_key] - except KeyError: - self.task_groups[group_key] = tg = TaskGroup(group_key) - tg._prefix = tp - tp._groups.append(tg) - tg.add(ts) - self.tasks[key] = ts - return ts + ts.state = "erred" - def stimulus_task_finished(self, key=None, worker=None, **kwargs): - """ Mark that a task has finished execution on a particular worker """ - logger.debug("Stimulus task finished %s, %s", key, worker) + report_msg = { + "op": "task-erred", + "key": key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + cs: ClientState + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg - tasks: dict = self.tasks - ts: TaskState = tasks.get(key) - if ts is None: - return {} - workers: dict = cast(dict, self.workers) - ws: WorkerState = workers[worker] - ts._metadata.update(kwargs["metadata"]) + cs = self._clients["fire-and-forget"] + if ts in cs._wants_what: + _client_releases_keys( + self, + cs=cs, + keys=[key], + recommendations=recommendations, + ) - recommendations: dict - if ts._state == "processing": - recommendations = self.transition(key, "memory", worker=worker, **kwargs) + if self._validate: + assert not ts._processing_on - if ts._state == "memory": - assert ws in ts._who_has - else: - logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", - worker, - ts._state, - key, - ts._who_has, - ) - if ws not in ts._who_has: - self.worker_send(worker, {"op": "release-task", "key": key}) - recommendations = {} + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - return recommendations + pdb.set_trace() + raise - def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs - ): - """ Mark that a task has erred on a particular worker """ - logger.debug("Stimulus task erred %s, %s", key, worker) + def transition_no_worker_released(self, key): + try: + ts: TaskState = self._tasks[key] + dts: TaskState + worker_msgs: dict = {} + client_msgs: dict = {} - ts: TaskState = self.tasks.get(key) - if ts is None: - return {} + if self._validate: + assert self._tasks[key].state == "no-worker" + assert not ts._who_has + assert not ts._waiting_on - recommendations: dict - if ts._state == "processing": - retries = ts._retries - if retries > 0: - ts._retries = retries - 1 - recommendations = self.transition(key, "waiting") - else: - recommendations = self.transition( - key, - "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, - ) - else: - recommendations = {} + self._unrunnable.remove(ts) + ts.state = "released" - return 
recommendations + for dts in ts._dependencies: + dts._waiters.discard(ts) - def stimulus_missing_data( - self, cause=None, key=None, worker=None, ensure=True, **kwargs - ): - """ Mark that certain keys have gone missing. Recover. """ - with log_errors(): - logger.debug("Stimulus missing data %s, %s", key, worker) + ts._waiters.clear() - ts: TaskState = self.tasks.get(key) - if ts is None or ts._state == "memory": - return {} - cts: TaskState = self.tasks.get(cause) + return {}, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - recommendations: dict = {} + pdb.set_trace() + raise - if cts is not None and cts._state == "memory": # couldn't find this - ws: WorkerState - for ws in cts._who_has: # TODO: this behavior is extreme - ws._has_what.remove(cts) - ws._nbytes -= cts.get_nbytes() - cts._who_has.clear() - recommendations[cause] = "released" + @ccall + def remove_key(self, key): + ts: TaskState = self._tasks.pop(key) + assert ts._state == "forgotten" + self._unrunnable.discard(ts) + cs: ClientState + for cs in ts._who_wants: + cs._wants_what.remove(ts) + ts._who_wants.clear() + ts._processing_on = None + ts._exception_blame = ts._exception = ts._traceback = None + self._task_metadata.pop(key, None) - if key: - recommendations[key] = "released" + def transition_memory_forgotten(self, key): + ws: WorkerState + try: + ts: TaskState = self._tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - self.transitions(recommendations) + if self._validate: + assert ts._state == "memory" + assert not ts._processing_on + assert not ts._waiting_on + if not ts._run_spec: + # It's ok to forget a pure data task + pass + elif ts._has_lost_dependencies: + # It's ok to forget a task with forgotten dependencies + pass + elif not ts._who_wants and not ts._waiters and not ts._dependents: + # It's ok to forget a task that nobody needs + pass + else: + assert 0, (ts,) - if self.validate: - assert cause not in self.who_has + recommendations: dict = {} - return {} + if ts._actor: + for ws in ts._who_has: + ws._actors.discard(ts) - def stimulus_retry(self, comm=None, keys=None, client=None): - logger.info("Client %s requests to retry %d keys", client, len(keys)) - if client: - self.log_event(client, {"action": "retry", "count": len(keys)}) + _propagate_forgotten(self, ts, recommendations, worker_msgs) - stack = list(keys) - seen = set() - roots = [] - ts: TaskState - dts: TaskState - while stack: - key = stack.pop() - seen.add(key) - ts = self.tasks[key] - erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] - if erred_deps: - stack.extend(erred_deps) - else: - roots.append(key) + client_msgs = _task_to_client_msgs(self, ts) + self.remove_key(key) - recommendations: dict = {key: "waiting" for key in roots} - self.transitions(recommendations) + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - if self.validate: - for key in seen: - assert not self.tasks[key].exception_blame + pdb.set_trace() + raise - return tuple(seen) + def transition_released_forgotten(self, key): + try: + ts: TaskState = self._tasks[key] + worker_msgs: dict = {} + client_msgs: dict = {} - async def remove_worker(self, comm=None, address=None, safe=False, close=True): - """ - Remove worker from cluster + if self._validate: + assert ts._state in ("released", "erred") + assert not ts._who_has + assert not ts._processing_on + assert not ts._waiting_on, (ts, ts._waiting_on) + if not ts._run_spec: + 
# It's ok to forget a pure data task + pass + elif ts._has_lost_dependencies: + # It's ok to forget a task with forgotten dependencies + pass + elif not ts._who_wants and not ts._waiters and not ts._dependents: + # It's ok to forget a task that nobody needs + pass + else: + assert 0, (ts,) - We do this when a worker reports that it plans to leave or when it - appears to be unresponsive. This may send its tasks back to a released - state. - """ - with log_errors(): - if self.status == Status.closed: - return + recommendations: dict = {} + _propagate_forgotten(self, ts, recommendations, worker_msgs) - address = self.coerce_address(address) + client_msgs = _task_to_client_msgs(self, ts) + self.remove_key(key) - if address not in self.workers: - return "already-removed" + return recommendations, worker_msgs, client_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - host = get_address_host(address) + pdb.set_trace() + raise - ws: WorkerState = self.workers[address] + ############################## + # Assigning Tasks to Workers # + ############################## - self.log_event( - ["all", address], - { - "action": "remove-worker", - "worker": address, - "processing-tasks": dict(ws._processing), - }, - ) - logger.info("Remove worker %s", ws) - if close: - with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send({"op": "close", "report": False}) + @ccall + @exceptval(check=False) + def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): + """Update the status of the idle and saturated state - self.remove_resources(address) + The scheduler keeps track of workers that are .. - self.host_info[host]["nthreads"] -= ws._nthreads - self.host_info[host]["addresses"].remove(address) - self.total_nthreads -= ws._nthreads + - Saturated: have enough work to stay busy + - Idle: do not have enough work to stay busy - if not self.host_info[host]["addresses"]: - del self.host_info[host] + They are considered saturated if they both have enough tasks to occupy + all of their threads, and if the expected runtime of those tasks is + large enough. - self.rpc.remove(address) - del self.stream_comms[address] - del self.aliases[ws._name] - self.idle.pop(ws._address, None) - self.saturated.discard(ws) - del self.workers[address] - ws.status = Status.closed - self.total_occupancy -= ws._occupancy + This is useful for load balancing and adaptivity. 
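The thresholds described above can be made concrete with a small standalone sketch (plain Python, not part of this changeset; `nthreads`, `n_processing`, `occupancy`, `total_occupancy`, and `total_nthreads` stand in for the corresponding `WorkerState`/`SchedulerState` fields):

```python
def classify_worker(nthreads, n_processing, occupancy, total_occupancy, total_nthreads):
    """Sketch of the idle/saturated classification performed by check_idle_saturated."""
    if total_nthreads == 0:
        return "neither"  # degenerate case: nothing to compare against
    # Average occupancy per thread across the whole cluster.
    avg = total_occupancy / total_nthreads
    # Idle: fewer queued tasks than threads, or well below the cluster average.
    if n_processing < nthreads or occupancy < nthreads * avg / 2:
        return "idle"
    # Saturated: a sizeable backlog beyond what the threads can absorb.
    pending = occupancy * (n_processing - nthreads) / (n_processing * nthreads)
    if pending > 0.4 and pending > 1.9 * avg:
        return "saturated"
    return "neither"


# 4 threads, 10 queued tasks worth ~8 s of work, on a lightly loaded cluster:
print(classify_worker(4, 10, 8.0, 10.0, 40))  # -> "saturated"
```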
+ """ + total_nthreads: Py_ssize_t = self._total_nthreads + if total_nthreads == 0 or ws.status == Status.closed: + return + if occ < 0: + occ = ws._occupancy - recommendations: dict = {} + nc: Py_ssize_t = ws._nthreads + p: Py_ssize_t = len(ws._processing) + total_occupancy: double = self._total_occupancy + avg: double = total_occupancy / total_nthreads - ts: TaskState - for ts in list(ws._processing): - k = ts._key - recommendations[k] = "released" - if not safe: - ts._suspicious += 1 - ts._prefix._suspicious += 1 - if ts._suspicious > self.allowed_failures: - del recommendations[k] - e = pickle.dumps( - KilledWorker(task=k, last_worker=ws.clean()), protocol=4 - ) - r = self.transition(k, "erred", exception=e, cause=k) - recommendations.update(r) - logger.info( - "Task %s marked as failed because %d workers died" - " while trying to run it", - ts._key, - self.allowed_failures, - ) + idle = self._idle + saturated: set = self._saturated + if p < nc or occ < nc * avg / 2: + idle[ws._address] = ws + saturated.discard(ws) + else: + idle.pop(ws._address, None) - for ts in ws._has_what: - ts._who_has.remove(ws) - if not ts._who_has: - if ts._run_spec: - recommendations[ts._key] = "released" - else: # pure data - recommendations[ts._key] = "forgotten" - ws._has_what.clear() + if p > nc: + pending: double = occ * (p - nc) / (p * nc) + if 0.4 < pending > 1.9 * avg: + saturated.add(ws) + return - self.transitions(recommendations) + saturated.discard(ws) - for plugin in self.plugins[:]: - try: - result = plugin.remove_worker(scheduler=self, worker=address) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) + @ccall + def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: + """ + Get the estimated communication cost (in s.) to compute the task + on the given worker. + """ + dts: TaskState + deps: set = ts._dependencies - ws._has_what + nbytes: Py_ssize_t = 0 + bandwidth: double = self._bandwidth + for dts in deps: + nbytes += dts._nbytes + return nbytes / bandwidth - if not self.workers: - logger.info("Lost all workers") + @ccall + def get_task_duration(self, ts: TaskState, default: double = -1) -> double: + """ + Get the estimated computation cost of the given task + (not including any communication cost). + """ + duration: double = ts._prefix._duration_average + if duration < 0: + s: set = self._unknown_durations[ts._prefix._name] + s.add(ts) + if default < 0: + duration = UNKNOWN_TASK_DURATION + else: + duration = default - for w in self.workers: - self.bandwidth_workers.pop((address, w), None) - self.bandwidth_workers.pop((w, address), None) + return duration - def remove_worker_from_events(): - # If the worker isn't registered anymore after the delay, remove from events - if address not in self.workers and address in self.events: - del self.events[address] + @ccall + @exceptval(check=False) + def valid_workers(self, ts: TaskState) -> set: + """Return set of currently valid workers for key - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) - self.loop.call_later(cleanup_delay, remove_worker_from_events) - logger.debug("Removed worker %s", ws) + If all workers are valid then this returns ``None``. 
+ This checks tracks the following state: - return "OK" + * worker_restrictions + * host_restrictions + * resource_restrictions + """ + s: set = None - def stimulus_cancel(self, comm, keys=None, client=None, force=False): - """ Stop execution on a list of keys """ - logger.info("Client %s requests to cancel %d keys", client, len(keys)) - if client: - self.log_event( - client, {"action": "cancel", "count": len(keys), "force": force} - ) - for key in keys: - self.cancel_key(key, client, force=force) + if ts._worker_restrictions: + s = {w for w in ts._worker_restrictions if w in self._workers_dv} - def cancel_key(self, key, client, retries=5, force=False): - """ Cancel a particular key and all dependents """ - # TODO: this should be converted to use the transition mechanism - ts: TaskState = self.tasks.get(key) - dts: TaskState - try: - cs: ClientState = self.clients[client] - except KeyError: - return - if ts is None or not ts._who_wants: # no key yet, lets try again in a moment - if retries: - self.loop.call_later( - 0.2, lambda: self.cancel_key(key, client, retries - 1) - ) - return - if force or ts._who_wants == {cs}: # no one else wants this key - for dts in list(ts._dependents): - self.cancel_key(dts._key, client, force=force) - logger.info("Scheduler cancels key %s. Force=%s", key, force) - self.report({"op": "cancelled-key", "key": key}) - clients = list(ts._who_wants) if force else [cs] - for cs in clients: - self.client_releases_keys(keys=[key], client=cs._client_key) + if ts._host_restrictions: + # Resolve the alias here rather than early, for the worker + # may not be connected when host_restrictions is populated + hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] + # XXX need HostState? + sl: list = [ + self._host_info[h]["addresses"] for h in hr if h in self._host_info + ] + ss: set = set.union(*sl) if sl else set() + if s is None: + s = ss + else: + s |= ss - def client_desires_keys(self, keys=None, client=None): - cs: ClientState = self.clients.get(client) - if cs is None: - # For publish, queues etc. - self.clients[client] = cs = ClientState(client) - ts: TaskState - for k in keys: - ts = self.tasks.get(k) - if ts is None: - # For publish, queues etc. 
- ts = self.new_task(k, None, "released") - ts._who_wants.add(cs) - cs._wants_what.add(ts) + if ts._resource_restrictions: + dw: dict = { + resource: { + w + for w, supplied in self._resources[resource].items() + if supplied >= required + } + for resource, required in ts._resource_restrictions.items() + } - if ts._state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + ww: set = set.intersection(*dw.values()) + if s is None: + s = ww + else: + s &= ww - def client_releases_keys(self, keys=None, client=None): - """ Remove keys from client desired list """ - logger.debug("Client %s releases keys: %s", client, keys) - cs: ClientState = self.clients[client] - ts: TaskState - tasks2 = set() - for key in list(keys): - ts = self.tasks.get(key) - if ts is not None and ts in cs._wants_what: - cs._wants_what.remove(ts) - s = ts._who_wants - s.remove(cs) - if not s: - tasks2.add(ts) + if s is not None: + s = {self._workers_dv[w] for w in s} - recommendations: dict = {} - for ts in tasks2: - if not ts._dependents: - # No live dependents, can forget - recommendations[ts._key] = "forgotten" - elif ts._state != "erred" and not ts._waiters: - recommendations[ts._key] = "released" + return s - self.transitions(recommendations) + @ccall + def consume_resources(self, ts: TaskState, ws: WorkerState): + if ts._resource_restrictions: + for r, required in ts._resource_restrictions.items(): + ws._used_resources[r] += required - def client_heartbeat(self, client=None): - """ Handle heartbeats from Client """ - cs: ClientState = self.clients[client] - cs._last_seen = time() + @ccall + def release_resources(self, ts: TaskState, ws: WorkerState): + if ts._resource_restrictions: + for r, required in ts._resource_restrictions.items(): + ws._used_resources[r] -= required - ################### - # Task Validation # - ################### - - def validate_released(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._state == "released" - assert not ts._waiters - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in self.unrunnable + @ccall + def coerce_hostname(self, host): + """ + Coerce the hostname of a worker. + """ + if host in self._aliases: + return self._workers_dv[self._aliases[host]].host + else: + return host - def validate_waiting(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert ts not in self.unrunnable - for dts in ts._dependencies: - # We are waiting on a dependency iff it's not stored - assert (not not dts._who_has) != (dts in ts._waiting_on) - assert ts in dts._waiters # XXX even if dts._who_has? + @ccall + @exceptval(check=False) + def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: + """ + Objective function to determine which worker should get the task - def validate_processing(self, key): - ts: TaskState = self.tasks[key] + Minimize expected start time. If a tie then break with data storage. 
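A rough standalone illustration of this objective (plain Python, not part of this changeset; the parameters stand in for the `WorkerState` occupancy/nthreads/nbytes fields and the `SchedulerState` bandwidth used in the method body below):

```python
def worker_objective_sketch(occupancy, nthreads, worker_nbytes, missing_dep_bytes, bandwidth):
    """Smaller tuples win: earliest expected start time, then least data already stored."""
    stack_time = occupancy / nthreads          # time for the worker's current backlog to drain
    comm_time = missing_dep_bytes / bandwidth  # time to fetch dependencies it does not hold
    return (stack_time + comm_time, worker_nbytes)


# A lightly loaded worker that already holds the inputs beats a busier one that must fetch them:
candidates = [
    worker_objective_sketch(1.0, 4, 2e9, 0, 1e8),    # -> (0.25, 2e9)
    worker_objective_sketch(4.0, 4, 1e9, 5e8, 1e8),  # -> (6.0, 1e9)
]
print(min(candidates))  # (0.25, 2000000000.0)
```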
+ """ dts: TaskState - assert not ts._waiting_on - ws: WorkerState = ts._processing_on - assert ws - assert ts in ws._processing - assert not ts._who_has + nbytes: Py_ssize_t + comm_bytes: Py_ssize_t = 0 for dts in ts._dependencies: - assert dts._who_has - assert ts in dts._waiters + if ws not in dts._who_has: + nbytes = dts.get_nbytes() + comm_bytes += nbytes - def validate_memory(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts._who_has - assert not ts._processing_on - assert not ts._waiting_on - assert ts not in self.unrunnable - for dts in ts._dependents: - assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) - assert ts not in dts._waiting_on + bandwidth: double = self._bandwidth + stack_time: double = ws._occupancy / ws._nthreads + start_time: double = stack_time + comm_bytes / bandwidth - def validate_no_worker(self, key): - ts: TaskState = self.tasks[key] - dts: TaskState - assert ts in self.unrunnable - assert not ts._waiting_on - assert ts in self.unrunnable - assert not ts._processing_on - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has + if ts._actor: + return (len(ws._actors), start_time, ws._nbytes) + else: + return (start_time, ws._nbytes) - def validate_erred(self, key): - ts: TaskState = self.tasks[key] - assert ts._exception_blame - assert not ts._who_has - def validate_key(self, key, ts: TaskState = None): - try: - if ts is None: - ts = self.tasks.get(key) - if ts is None: - logger.debug("Key lost: %s", key) - else: - ts.validate() - try: - func = getattr(self, "validate_" + ts._state.replace("-", "_")) - except AttributeError: - logger.error( - "self.validate_%s not found", ts._state.replace("-", "_") - ) - else: - func(key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb +class Scheduler(SchedulerState, ServerNode): + """Dynamic distributed task scheduler - pdb.set_trace() - raise + The scheduler tracks the current state of workers, data, and computations. + The scheduler listens for events and responds by controlling workers + appropriately. It continuously tries to use the workers to execute an ever + growing dask graph. - def validate_state(self, allow_overlap=False): - validate_state(self.tasks, self.workers, self.clients) + All events are handled quickly, in linear time with respect to their input + (which is often of constant size) and generally within a millisecond. To + accomplish this the scheduler tracks a lot of state. Every operation + maintains the consistency of this state. - if not (set(self.workers) == set(self.stream_comms)): - raise ValueError("Workers not the same in all collections") + The scheduler communicates with the outside world through Comm objects. + It maintains a consistent and valid view of the world even when listening + to several clients at once. - ws: WorkerState - for w, ws in self.workers.items(): - assert isinstance(w, str), (type(w), w) - assert isinstance(ws, WorkerState), (type(ws), ws) - assert ws._address == w - if not ws._processing: - assert not ws._occupancy - assert ws._address in cast(dict, self.idle) + A Scheduler is typically started either with the ``dask-scheduler`` + executable:: - ts: TaskState - for k, ts in self.tasks.items(): - assert isinstance(ts, TaskState), (type(ts), ts) - assert ts._key == k - self.validate_key(k, ts) + $ dask-scheduler + Scheduler started at 127.0.0.1:8786 - c: str - cs: ClientState - for c, cs in self.clients.items(): - # client=None is often used in tests... 
- assert c is None or type(c) == str, (type(c), c) - assert type(cs) == ClientState, (type(cs), cs) - assert cs._client_key == c + Or within a LocalCluster a Client starts up without connection + information:: - a = {w: ws._nbytes for w, ws in self.workers.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in self.workers.items() - } - assert a == b, (a, b) + >>> c = Client() # doctest: +SKIP + >>> c.cluster.scheduler # doctest: +SKIP + Scheduler(...) - actual_total_occupancy = 0 - for worker, ws in self.workers.items(): - assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 - actual_total_occupancy += ws._occupancy + Users typically do not interact with the scheduler directly but rather with + the client object ``Client``. - assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( - actual_total_occupancy, - self.total_occupancy, - ) + **State** - ################### - # Manage Messages # - ################### + The scheduler contains the following state variables. Each variable is + listed along with what it stores and a brief description. - def report(self, msg: dict, ts: TaskState = None, client: str = None): - """ - Publish updates to all listening Queues and Comms + * **tasks:** ``{task key: TaskState}`` + Tasks currently known to the scheduler + * **unrunnable:** ``{TaskState}`` + Tasks in the "no-worker" state - If the message contains a key then we only send the message to those - comms that care about the key. - """ - if ts is None: - msg_key = msg.get("key") - if msg_key is not None: - tasks: dict = self.tasks - ts = tasks.get(msg_key) + * **workers:** ``{worker key: WorkerState}`` + Workers currently connected to the scheduler + * **idle:** ``{WorkerState}``: + Set of workers that are not fully utilized + * **saturated:** ``{WorkerState}``: + Set of workers that are not over-utilized - cs: ClientState - client_comms: dict = self.client_comms - client_keys: list - if ts is None: - # Notify all clients - client_keys = list(client_comms) - elif client is None: - # Notify clients interested in key - client_keys = [cs._client_key for cs in ts._who_wants] - else: - # Notify clients interested in key (including `client`) - client_keys = [ - cs._client_key for cs in ts._who_wants if cs._client_key != client - ] - client_keys.append(client) + * **host_info:** ``{hostname: dict}``: + Information about each worker host - k: str - for k in client_keys: - c = client_comms.get(k) - if c is None: - continue - try: - c.send(msg) - # logger.debug("Scheduler sends message to client %s", msg) - except CommClosedError: - if self.status == Status.running: - logger.critical("Tried writing to closed comm: %s", msg) + * **clients:** ``{client key: ClientState}`` + Clients currently connected to the scheduler - async def add_client(self, comm, client=None, versions=None): - """Add client to network + * **services:** ``{str: port}``: + Other services running on this scheduler, like Bokeh + * **loop:** ``IOLoop``: + The running Tornado IOLoop + * **client_comms:** ``{client key: Comm}`` + For each client, a Comm object used to receive task requests and + report task status updates. + * **stream_comms:** ``{worker key: Comm}`` + For each worker, a Comm object from which we both accept stimuli and + report results + * **task_duration:** ``{key-prefix: time}`` + Time we expect certain functions to take, e.g. ``{'sum': 0.25}`` + """ - We listen to all future messages from this Comm. 
- """ - assert client is not None - comm.name = "Scheduler->Client" - logger.info("Receive client connection: %s", client) - self.log_event(["all", client], {"action": "add-client", "client": client}) - self.clients[client] = ClientState(client, versions=versions) + default_port = 8786 + _instances = weakref.WeakSet() - for plugin in self.plugins[:]: - try: - plugin.add_client(scheduler=self, client=client) - except Exception as e: - logger.exception(e) + def __init__( + self, + loop=None, + delete_interval="500ms", + synchronize_worker_interval="60s", + services=None, + service_kwargs=None, + allowed_failures=None, + extensions=None, + validate=None, + scheduler_file=None, + security=None, + worker_ttl=None, + idle_timeout=None, + interface=None, + host=None, + port=0, + protocol=None, + dashboard_address=None, + dashboard=None, + http_prefix="/", + preload=None, + preload_argv=(), + plugins=(), + **kwargs, + ): + self._setup_logging(logger) - try: - bcomm = BatchedSend(interval="2ms", loop=self.loop) - bcomm.start(comm) - self.client_comms[client] = bcomm - msg = {"op": "stream-start"} - ws: WorkerState - version_warning = version_module.error_message( - version_module.get_versions(), - {w: ws._versions for w, ws in self.workers.items()}, - versions, - ) - msg.update(version_warning) - bcomm.send(msg) - - try: - await self.handle_stream(comm=comm, extra={"client": client}) - finally: - self.remove_client(client=client) - logger.debug("Finished handling client %s", client) - finally: - if not comm.closed(): - self.client_comms[client].send({"op": "stream-closed"}) - try: - if not shutting_down(): - await self.client_comms[client].close() - del self.client_comms[client] - if self.status == Status.running: - logger.info("Close client connection: %s", client) - except TypeError: # comm becomes None during GC - pass - - def remove_client(self, client=None): - """ Remove client from network """ - if self.status == Status.running: - logger.info("Remove client %s", client) - self.log_event(["all", client], {"action": "remove-client", "client": client}) - try: - cs: ClientState = self.clients[client] - except KeyError: - # XXX is this a legitimate condition? 
- pass + # Attributes + if allowed_failures is None: + allowed_failures = dask.config.get("distributed.scheduler.allowed-failures") + self.allowed_failures = allowed_failures + if validate is None: + validate = dask.config.get("distributed.scheduler.validate") + self.proc = psutil.Process() + self.delete_interval = parse_timedelta(delete_interval, default="ms") + self.synchronize_worker_interval = parse_timedelta( + synchronize_worker_interval, default="ms" + ) + self.digests = None + self.service_specs = services or {} + self.service_kwargs = service_kwargs or {} + self.services = {} + self.scheduler_file = scheduler_file + worker_ttl = worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") + self.worker_ttl = parse_timedelta(worker_ttl) if worker_ttl else None + idle_timeout = idle_timeout or dask.config.get( + "distributed.scheduler.idle-timeout" + ) + if idle_timeout: + self.idle_timeout = parse_timedelta(idle_timeout) else: - ts: TaskState - self.client_releases_keys( - keys=[ts._key for ts in cs._wants_what], client=cs._client_key - ) - del self.clients[client] + self.idle_timeout = None + self.idle_since = time() + self.time_started = self.idle_since # compatibility for dask-gateway + self._lock = asyncio.Lock() + self.bandwidth_workers = defaultdict(float) + self.bandwidth_types = defaultdict(float) - for plugin in self.plugins[:]: - try: - plugin.remove_client(scheduler=self, client=client) - except Exception as e: - logger.exception(e) + if not preload: + preload = dask.config.get("distributed.scheduler.preload") + if not preload_argv: + preload_argv = dask.config.get("distributed.scheduler.preload-argv") + self.preloads = preloading.process_preloads(self, preload, preload_argv) - def remove_client_from_events(): - # If the client isn't registered anymore after the delay, remove from events - if client not in self.clients and client in self.events: - del self.events[client] + if isinstance(security, dict): + security = Security(**security) + self.security = security or Security() + assert isinstance(self.security, Security) + self.connection_args = self.security.get_connection_args("scheduler") + self.connection_args["handshake_overrides"] = { # common denominator + "pickle-protocol": 4 + } - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") + self._start_address = addresses_from_user_args( + host=host, + port=port, + interface=interface, + protocol=protocol, + security=security, + default_port=self.default_port, ) - self.loop.call_later(cleanup_delay, remove_client_from_events) - - def send_task_to_worker(self, worker, ts: TaskState, duration=None): - """ Send a single computational task to a worker """ - try: - ws: WorkerState - dts: TaskState - if duration is None: - duration = self.get_task_duration(ts) + http_server_modules = dask.config.get("distributed.scheduler.http.routes") + show_dashboard = dashboard or (dashboard is None and dashboard_address) + missing_bokeh = False + # install vanilla route if show_dashboard but bokeh is not installed + if show_dashboard: + try: + import distributed.dashboard.scheduler + except ImportError: + missing_bokeh = True + http_server_modules.append("distributed.http.scheduler.missing_bokeh") + routes = get_handlers( + server=self, modules=http_server_modules, prefix=http_prefix + ) + self.start_http_server(routes, dashboard_address, default_port=8787) + if show_dashboard and not missing_bokeh: + distributed.dashboard.scheduler.connect( + self.http_application, self.http_server, 
self, prefix=http_prefix + ) - msg: dict = { - "op": "compute-task", - "key": ts._key, - "priority": ts._priority, - "duration": duration, - } - if ts._resource_restrictions: - msg["resource_restrictions"] = ts._resource_restrictions - if ts._actor: - msg["actor"] = True + # Communication state + self.loop = loop or IOLoop.current() + self.client_comms = dict() + self.stream_comms = dict() + self._worker_coroutines = [] + self._ipython_kernel = None - deps: set = ts._dependencies - if deps: - msg["who_has"] = { - dts._key: [ws._address for ws in dts._who_has] for dts in deps - } - msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + # Task state + tasks = dict() + self.task_groups = dict() + self.task_prefixes = dict() + for old_attr, new_attr, wrap in [ + ("priority", "priority", None), + ("dependencies", "dependencies", _legacy_task_key_set), + ("dependents", "dependents", _legacy_task_key_set), + ("retries", "retries", None), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(tasks, func)) - if self.validate: - assert all(msg["who_has"].values()) + for old_attr, new_attr, wrap in [ + ("nbytes", "nbytes", None), + ("who_wants", "who_wants", _legacy_client_key_set), + ("who_has", "who_has", _legacy_worker_key_set), + ("waiting", "waiting_on", _legacy_task_key_set), + ("waiting_data", "waiters", _legacy_task_key_set), + ("rprocessing", "processing_on", None), + ("host_restrictions", "host_restrictions", None), + ("worker_restrictions", "worker_restrictions", None), + ("resource_restrictions", "resource_restrictions", None), + ("suspicious_tasks", "suspicious", None), + ("exceptions", "exception", None), + ("tracebacks", "traceback", None), + ("exceptions_blame", "exception_blame", _task_key_or_none), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _OptionalStateLegacyMapping(tasks, func)) - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task + for old_attr, new_attr, wrap in [ + ("loose_restrictions", "loose_restrictions", None) + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacySet(tasks, func)) - self.worker_send(worker, msg) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self.generation = 0 + self._last_client = None + self._last_time = 0 + unrunnable = set() - pdb.set_trace() - raise + self.datasets = dict() - def handle_uncaught_error(self, **msg): - logger.exception(clean_exception(**msg)[1]) + # Prefix-keyed containers - def handle_task_finished(self, key=None, worker=None, **msg): - if worker not in self.workers: - return - validate_key(key) - r = self.stimulus_task_finished(key=key, worker=worker, **msg) - self.transitions(r) + # Client state + clients = dict() + for old_attr, new_attr, wrap in [ + ("wants_what", "wants_what", _legacy_task_key_set) + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(clients, func)) - def handle_task_erred(self, key=None, **msg): - r = self.stimulus_task_erred(key=key, **msg) - self.transitions(r) + # Worker state + workers = sortedcontainers.SortedDict() + for old_attr, new_attr, wrap in [ + ("nthreads", "nthreads", None), + ("worker_bytes", "nbytes", None), + ("worker_resources", "resources", None), + ("used_resources", "used_resources", 
None), + ("occupancy", "occupancy", None), + ("worker_info", "metrics", None), + ("processing", "processing", _legacy_task_key_dict), + ("has_what", "has_what", _legacy_task_key_set), + ]: + func = operator.attrgetter(new_attr) + if wrap is not None: + func = compose(wrap, func) + setattr(self, old_attr, _StateLegacyMapping(workers, func)) - def handle_release_data(self, key=None, worker=None, client=None, **msg): - ts: TaskState = self.tasks.get(key) - if ts is None: - return - ws: WorkerState = self.workers[worker] - if ts._processing_on != ws: - return - r = self.stimulus_missing_data(key=key, ensure=False, **msg) - self.transitions(r) + host_info = defaultdict(dict) + resources = defaultdict(dict) + aliases = dict() - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): - logger.debug("handle missing data key=%s worker=%s", key, errant_worker) - self.log.append(("missing", key, errant_worker)) + self._task_state_collections = [unrunnable] - ts: TaskState = self.tasks.get(key) - if ts is None or not ts._who_has: - return - if errant_worker in self.workers: - ws: WorkerState = self.workers[errant_worker] - if ws in ts._who_has: - ts._who_has.remove(ws) - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - if not ts._who_has: - if ts._run_spec: - self.transitions({key: "released"}) - else: - self.transitions({key: "forgotten"}) + self._worker_collections = [ + workers, + host_info, + resources, + aliases, + ] - def release_worker_data(self, comm=None, keys=None, worker=None): - ws: WorkerState = self.workers[worker] - tasks = {self.tasks[k] for k in keys} - removed_tasks = tasks & ws._has_what - ws._has_what -= removed_tasks + self.plugins = list(plugins) + self.transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) + self.log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) + self.events = defaultdict(lambda: deque(maxlen=100000)) + self.event_counts = defaultdict(int) + self.worker_plugins = [] - ts: TaskState - recommendations: dict = {} - for ts in removed_tasks: - ws._nbytes -= ts.get_nbytes() - wh = ts._who_has - wh.remove(ws) - if not wh: - recommendations[ts._key] = "released" - if recommendations: - self.transitions(recommendations) - - def handle_long_running(self, key=None, worker=None, compute_duration=None): - """A task has seceded from the thread pool - - We stop the task from being stolen in the future, and change task - duration accounting as if the task has stopped. - """ - ts: TaskState = self.tasks[key] - if "stealing" in self.extensions: - self.extensions["stealing"].remove_key_from_stealable(ts) - - ws: WorkerState = ts._processing_on - if ws is None: - logger.debug("Received long-running signal from duplicate task. 
Ignoring.") - return + worker_handlers = { + "task-finished": self.handle_task_finished, + "task-erred": self.handle_task_erred, + "release": self.handle_release_data, + "release-worker-data": self.release_worker_data, + "add-keys": self.add_keys, + "missing-data": self.handle_missing_data, + "long-running": self.handle_long_running, + "reschedule": self.reschedule, + "keep-alive": lambda *args, **kwargs: None, + "log-event": self.log_worker_event, + } - if compute_duration: - old_duration = ts._prefix._duration_average - new_duration = compute_duration - if old_duration < 0: - avg_duration = new_duration - else: - avg_duration = 0.5 * old_duration + 0.5 * new_duration + client_handlers = { + "update-graph": self.update_graph, + "update-graph-hlg": self.update_graph_hlg, + "client-desires-keys": self.client_desires_keys, + "update-data": self.update_data, + "report-key": self.report_on_key, + "client-releases-keys": self.client_releases_keys, + "heartbeat-client": self.client_heartbeat, + "close-client": self.remove_client, + "restart": self.restart, + } - ts._prefix._duration_average = avg_duration + self.handlers = { + "register-client": self.add_client, + "scatter": self.scatter, + "register-worker": self.add_worker, + "unregister": self.remove_worker, + "gather": self.gather, + "cancel": self.stimulus_cancel, + "retry": self.stimulus_retry, + "feed": self.feed, + "terminate": self.close, + "broadcast": self.broadcast, + "proxy": self.proxy, + "ncores": self.get_ncores, + "has_what": self.get_has_what, + "who_has": self.get_who_has, + "processing": self.get_processing, + "call_stack": self.get_call_stack, + "profile": self.get_profile, + "performance_report": self.performance_report, + "get_logs": self.get_logs, + "logs": self.get_logs, + "worker_logs": self.get_worker_logs, + "log_event": self.log_worker_event, + "events": self.get_events, + "nbytes": self.get_nbytes, + "versions": self.versions, + "add_keys": self.add_keys, + "rebalance": self.rebalance, + "replicate": self.replicate, + "start_ipython": self.start_ipython, + "run_function": self.run_function, + "update_data": self.update_data, + "set_resources": self.add_resources, + "retire_workers": self.retire_workers, + "get_metadata": self.get_metadata, + "set_metadata": self.set_metadata, + "heartbeat_worker": self.heartbeat_worker, + "get_task_status": self.get_task_status, + "get_task_stream": self.get_task_stream, + "register_worker_plugin": self.register_worker_plugin, + "adaptive_target": self.adaptive_target, + "workers_to_close": self.workers_to_close, + "subscribe_worker_status": self.subscribe_worker_status, + "start_task_metadata": self.start_task_metadata, + "stop_task_metadata": self.stop_task_metadata, + } - ws._occupancy -= ws._processing[ts] - self.total_occupancy -= ws._processing[ts] - ws._processing[ts] = 0 - self.check_idle_saturated(ws) + self._transitions = { + ("released", "waiting"): self.transition_released_waiting, + ("waiting", "released"): self.transition_waiting_released, + ("waiting", "processing"): self.transition_waiting_processing, + ("waiting", "memory"): self.transition_waiting_memory, + ("processing", "released"): self.transition_processing_released, + ("processing", "memory"): self.transition_processing_memory, + ("processing", "erred"): self.transition_processing_erred, + ("no-worker", "released"): self.transition_no_worker_released, + ("no-worker", "waiting"): self.transition_no_worker_waiting, + ("released", "forgotten"): self.transition_released_forgotten, + ("memory", "forgotten"): 
self.transition_memory_forgotten,
+            ("erred", "forgotten"): self.transition_released_forgotten,
+            ("erred", "released"): self.transition_erred_released,
+            ("memory", "released"): self.transition_memory_released,
+            ("released", "erred"): self.transition_released_erred,
+        }
-    async def handle_worker(self, comm=None, worker=None):
-        """
-        Listen to responses from a single worker
+        connection_limit = get_fileno_limit() / 2
-        This is the main loop for scheduler-worker interaction
+        super().__init__(
+            aliases=aliases,
+            handlers=self.handlers,
+            stream_handlers=merge(worker_handlers, client_handlers),
+            io_loop=self.loop,
+            connection_limit=connection_limit,
+            deserialize=False,
+            connection_args=self.connection_args,
+            clients=clients,
+            workers=workers,
+            host_info=host_info,
+            resources=resources,
+            tasks=tasks,
+            unrunnable=unrunnable,
+            validate=validate,
+            **kwargs,
+        )
-        See Also
-        --------
-        Scheduler.handle_client: Equivalent coroutine for clients
-        """
-        comm.name = "Scheduler connection to worker"
-        worker_comm = self.stream_comms[worker]
-        worker_comm.start(comm)
-        logger.info("Starting worker compute stream, %s", worker)
-        try:
-            await self.handle_stream(comm=comm, extra={"worker": worker})
-        finally:
-            if worker in self.stream_comms:
-                worker_comm.abort()
-                await self.remove_worker(address=worker)
+        if self.worker_ttl:
+            pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl)
+            self.periodic_callbacks["worker-ttl"] = pc
-    def add_plugin(self, plugin=None, idempotent=False, **kwargs):
-        """
-        Add external plugin to scheduler
+        if self.idle_timeout:
+            pc = PeriodicCallback(self.check_idle, self.idle_timeout / 4)
+            self.periodic_callbacks["idle-timeout"] = pc
-        See https://distributed.readthedocs.io/en/latest/plugins.html
-        """
-        if isinstance(plugin, type):
-            plugin = plugin(self, **kwargs)
+        if extensions is None:
+            extensions = list(DEFAULT_EXTENSIONS)
+        if dask.config.get("distributed.scheduler.work-stealing"):
+            extensions.append(WorkStealing)
+        for ext in extensions:
+            ext(self)
-        if idempotent and any(isinstance(p, type(plugin)) for p in self.plugins):
-            return
+        setproctitle("dask-scheduler [not started]")
+        Scheduler._instances.add(self)
+        self.rpc.allow_offload = False
+        self.status = Status.undefined
-        self.plugins.append(plugin)
+    ##################
+    # Administration #
+    ##################
-    def remove_plugin(self, plugin):
-        """ Remove external plugin from scheduler """
-        self.plugins.remove(plugin)
+    def __repr__(self):
+        parent: SchedulerState = cast(SchedulerState, self)
+        return '<Scheduler: "%s" processes: %d cores: %d>' % (
+            self.address,
+            len(parent._workers),
+            parent._total_nthreads,
+        )
-    def worker_send(self, worker, msg):
-        """Send message to worker
+    def identity(self, comm=None):
+        """ Basic information about ourselves and our cluster """
+        parent: SchedulerState = cast(SchedulerState, self)
+        d = {
+            "type": type(self).__name__,
+            "id": str(self.id),
+            "address": self.address,
+            "services": {key: v.port for (key, v) in self.services.items()},
+            "started": self.time_started,
+            "workers": {
+                worker.address: worker.identity() for worker in parent._workers.values()
+            },
+        }
+        return d
-        This also handles connection failures by adding a callback to remove
-        the worker on the next cycle. 
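Because ``Scheduler`` now derives from both ``SchedulerState`` and ``ServerNode``, the single ``super().__init__(...)`` call above relies on cooperative initialization: each class along the MRO consumes the keyword arguments it recognizes and forwards the rest with ``**kwargs``. A toy sketch of that pattern (hypothetical classes, for illustration only):

```python
class State:
    def __init__(self, tasks=None, **kwargs):
        self.tasks = tasks or {}
        super().__init__(**kwargs)  # forward the remaining kwargs down the MRO


class Node:
    def __init__(self, handlers=None, **kwargs):
        self.handlers = handlers or {}
        super().__init__(**kwargs)


class Sched(State, Node):  # MRO: Sched -> State -> Node -> object
    pass


s = Sched(tasks={"x": 1}, handlers={"ping": print})
print(s.tasks, s.handlers)  # {'x': 1} {'ping': <built-in function print>}
```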
+ def get_worker_service_addr(self, worker, service_name, protocol=False): """ - stream_comms: dict = self.stream_comms - try: - stream_comms[worker].send(msg) - except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) - - ############################ - # Less common interactions # - ############################ - - async def scatter( - self, - comm=None, - data=None, - workers=None, - client=None, - broadcast=False, - timeout=2, - ): - """Send data out to workers + Get the (host, port) address of the named service on the *worker*. + Returns None if the service doesn't exist. - See also - -------- - Scheduler.broadcast: + Parameters + ---------- + worker : address + service_name : str + Common services include 'bokeh' and 'nanny' + protocol : boolean + Whether or not to include a full address with protocol (True) + or just a (host, port) pair """ - start = time() - while not self.workers: - await asyncio.sleep(0.2) - if time() > start + timeout: - raise TimeoutError("No workers found") - - if workers is None: - ws: WorkerState - nthreads = {w: ws._nthreads for w, ws in self.workers.items()} + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers[worker] + port = ws._services.get(service_name) + if port is None: + return None + elif protocol: + return "%(protocol)s://%(host)s:%(port)d" % { + "protocol": ws._address.split("://")[0], + "host": ws.host, + "port": port, + } else: - workers = [self.coerce_address(w) for w in workers] - nthreads = {w: self.workers[w].nthreads for w in workers} + return ws.host, port - assert isinstance(data, dict) + async def start(self): + """ Clear out old state and restart all running coroutines """ + await super().start() + assert self.status != Status.running - keys, who_has, nbytes = await scatter_to_workers( - nthreads, data, rpc=self.rpc, report=False - ) + enable_gc_diagnosis() - self.update_data(who_has=who_has, nbytes=nbytes, client=client) + self.clear_task_state() - if broadcast: - if broadcast == True: # noqa: E712 - n = len(nthreads) - else: - n = broadcast - await self.replicate(keys=keys, workers=workers, n=n) + with suppress(AttributeError): + for c in self._worker_coroutines: + c.cancel() - self.log_event( - [client, "all"], {"action": "scatter", "client": client, "count": len(data)} - ) - return keys + for addr in self._start_address: + await self.listen( + addr, + allow_offload=False, + handshake_overrides={"pickle-protocol": 4, "compression": None}, + **self.security.get_listen_args("scheduler"), + ) + self.ip = get_address_host(self.listen_address) + listen_ip = self.ip - async def gather(self, comm=None, keys=None, serializers=None): - """ Collect data in from workers """ - ws: WorkerState - keys = list(keys) - who_has = {} - for key in keys: - ts: TaskState = self.tasks.get(key) - if ts is not None: - who_has[key] = [ws._address for ws in ts._who_has] - else: - who_has[key] = [] + if listen_ip == "0.0.0.0": + listen_ip = "" - data, missing_keys, missing_workers = await gather_from_workers( - who_has, rpc=self.rpc, close=False, serializers=serializers - ) - if not missing_keys: - result = {"status": "OK", "data": data} - else: - missing_states = [ - (self.tasks[key].state if key in self.tasks else None) - for key in missing_keys - ] - logger.exception( - "Couldn't gather keys %s state: %s workers: %s", - missing_keys, - missing_states, - missing_workers, - ) - result = {"status": "error", "keys": missing_keys} - with log_errors(): - # Remove suspicious workers 
from the scheduler but allow them to - # reconnect. - await asyncio.gather( - *[ - self.remove_worker(address=worker, close=False) - for worker in missing_workers - ] - ) - for key, workers in missing_keys.items(): - # Task may already be gone if it was held by a - # `missing_worker` - ts: TaskState = self.tasks.get(key) - logger.exception( - "Workers don't have promised key: %s, %s", - str(workers), - str(key), - ) - if not workers or ts is None: - continue - for worker in workers: - ws = self.workers.get(worker) - if ws is not None and ts in ws._has_what: - ws._has_what.remove(ts) - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() - self.transitions({key: "released"}) + if self.address.startswith("inproc://"): + listen_ip = "localhost" - self.log_event("all", {"action": "gather", "count": len(keys)}) - return result + # Services listen on all addresses + self.start_services(listen_ip) - def clear_task_state(self): - # XXX what about nested state such as ClientState.wants_what - # (see also fire-and-forget...) - logger.info("Clear task state") - for collection in self._task_state_collections: - collection.clear() + for listener in self.listeners: + logger.info(" Scheduler at: %25s", listener.contact_address) + for k, v in self.services.items(): + logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - async def restart(self, client=None, timeout=3): - """ Restart all workers. Reset local state. """ - with log_errors(): + self.loop.add_callback(self.reevaluate_occupancy) - n_workers = len(self.workers) + if self.scheduler_file: + with open(self.scheduler_file, "w") as f: + json.dump(self.identity(), f, indent=2) - logger.info("Send lost future signal to clients") - cs: ClientState - ts: TaskState - for cs in self.clients.values(): - self.client_releases_keys( - keys=[ts._key for ts in cs._wants_what], client=cs._client_key - ) + fn = self.scheduler_file # remove file when we close the process - ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in self.workers.items()} + def del_scheduler_file(): + if os.path.exists(fn): + os.remove(fn) - for addr in list(self.workers): - try: - # Ask the worker to close if it doesn't have a nanny, - # otherwise the nanny will kill it anyway - await self.remove_worker(address=addr, close=addr not in nannies) - except Exception as e: - logger.info( - "Exception while restarting. This is normal", exc_info=True - ) + weakref.finalize(self, del_scheduler_file) - self.clear_task_state() + for preload in self.preloads: + await preload.start() - for plugin in self.plugins[:]: - try: - plugin.restart(self) - except Exception as e: - logger.exception(e) + await asyncio.gather(*[plugin.start(self) for plugin in self.plugins]) - logger.debug("Send kill signal to nannies: %s", nannies) + self.start_periodic_callbacks() - nannies = [ - rpc(nanny_address, connection_args=self.connection_args) - for nanny_address in nannies.values() - if nanny_address is not None - ] + setproctitle("dask-scheduler [%s]" % (self.address,)) + return self - resps = All( - [ - nanny.restart( - close=True, timeout=timeout * 0.8, executor_wait=False - ) - for nanny in nannies - ] - ) - try: - resps = await asyncio.wait_for(resps, timeout) - except TimeoutError: - logger.error( - "Nannies didn't report back restarted within " - "timeout. 
Continuuing with restart process" - ) - else: - if not all(resp == "OK" for resp in resps): - logger.error( - "Not all workers responded positively: %s", resps, exc_info=True - ) - finally: - await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) + async def close(self, comm=None, fast=False, close_workers=False): + """Send cleanup signal to all coroutines then wait until finished - self.clear_task_state() + See Also + -------- + Scheduler.cleanup + """ + parent: SchedulerState = cast(SchedulerState, self) + if self.status in (Status.closing, Status.closed, Status.closing_gracefully): + await self.finished() + return + self.status = Status.closing - with suppress(AttributeError): - for c in self._worker_coroutines: - c.cancel() + logger.info("Scheduler closing...") + setproctitle("dask-scheduler [closing]") - self.log_event([client, "all"], {"action": "restart", "client": client}) - start = time() - while time() < start + 10 and len(self.workers) < n_workers: - await asyncio.sleep(0.01) + for preload in self.preloads: + await preload.teardown() - self.report({"op": "restart"}) + if close_workers: + await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) + for worker in parent._workers: + self.worker_send(worker, {"op": "close"}) + for i in range(20): # wait a second for send signals to clear + if parent._workers: + await asyncio.sleep(0.05) + else: + break - async def broadcast( - self, - comm=None, - msg=None, - workers=None, - hosts=None, - nanny=False, - serializers=None, - ): - """ Broadcast message to workers, return all results """ - if workers is None or workers is True: - if hosts is None: - workers = list(self.workers) - else: - workers = [] - if hosts is not None: - for host in hosts: - if host in self.host_info: - workers.extend(self.host_info[host]["addresses"]) - # TODO replace with worker_list + await asyncio.gather(*[plugin.close() for plugin in self.plugins]) - if nanny: - addresses = [self.workers[w].nanny for w in workers] - else: - addresses = workers + for pc in self.periodic_callbacks.values(): + pc.stop() + self.periodic_callbacks.clear() - async def send_message(addr): - comm = await self.rpc.connect(addr) - comm.name = "Scheduler Broadcast" - try: - resp = await send_recv(comm, close=True, serializers=serializers, **msg) - finally: - self.rpc.reuse(addr, comm) - return resp + self.stop_services() - results = await All( - [send_message(address) for address in addresses if address is not None] - ) + for ext in parent._extensions.values(): + with suppress(AttributeError): + ext.teardown() + logger.info("Scheduler closing all comms") - return dict(zip(workers, results)) + futures = [] + for w, comm in list(self.stream_comms.items()): + if not comm.closed(): + comm.send({"op": "close", "report": False}) + comm.send({"op": "close-stream"}) + with suppress(AttributeError): + futures.append(comm.close()) - async def proxy(self, comm=None, msg=None, worker=None, serializers=None): - """ Proxy a communication through the scheduler to some other worker """ - d = await self.broadcast( - comm=comm, msg=msg, workers=[worker], serializers=serializers - ) - return d[worker] + for future in futures: # TODO: do all at once + await future - async def _delete_worker_data(self, worker_address, keys): - """Delete data from a worker and update the corresponding worker/task states + for comm in self.client_comms.values(): + comm.abort() - Parameters - ---------- - worker_address: str - Worker address to delete keys from - keys: List[str] - List of keys to delete on the 
specified worker - """ - await retry_operation( - self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False - ) + await self.rpc.close() - ws: WorkerState = self.workers[worker_address] - ts: TaskState - tasks: set = {self.tasks[key] for key in keys} - ws._has_what -= tasks - for ts in tasks: - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() - self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) + self.status = Status.closed + self.stop() + await super().close() - async def rebalance(self, comm=None, keys=None, workers=None): - """Rebalance keys so that each worker stores roughly equal bytes + setproctitle("dask-scheduler [closed]") + disable_gc_diagnosis() - **Policy** + async def close_worker(self, comm=None, worker=None, safe=None): + """Remove a worker from the cluster - This orders the workers by what fraction of bytes of the existing keys - they have. It walks down this list from most-to-least. At each worker - it sends the largest results it can find and sends them to the least - occupied worker until either the sender or the recipient are at the - average expected load. + This both removes the worker from our local state and also sends a + signal to the worker to shut down. This works regardless of whether or + not the worker has a nanny process restarting it """ - ts: TaskState + parent: SchedulerState = cast(SchedulerState, self) + logger.info("Closing worker %s", worker) with log_errors(): - async with self._lock: - if keys: - tasks = {self.tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - else: - tasks = set(self.tasks.values()) - - if workers: - workers = {self.workers[w] for w in workers} - workers_by_task = {ts: ts._who_has & workers for ts in tasks} - else: - workers = set(self.workers.values()) - workers_by_task = {ts: ts._who_has for ts in tasks} + self.log_event(worker, {"action": "close-worker"}) + ws: WorkerState = parent._workers[worker] + nanny_addr = ws._nanny + address = nanny_addr or worker - ws: WorkerState - tasks_by_worker = {ws: set() for ws in workers} + self.worker_send(worker, {"op": "close", "report": False}) + await self.remove_worker(address=worker, safe=safe) - for k, v in workers_by_task.items(): - for vv in v: - tasks_by_worker[vv].add(k) + ########### + # Stimuli # + ########### - worker_bytes = { - ws: sum(ts.get_nbytes() for ts in v) - for ws, v in tasks_by_worker.items() - } + def heartbeat_worker( + self, + comm=None, + address=None, + resolve_address=True, + now=None, + resources=None, + host_info=None, + metrics=None, + executing=None, + ): + parent: SchedulerState = cast(SchedulerState, self) + address = self.coerce_address(address, resolve_address) + address = normalize_address(address) + if address not in parent._workers: + return {"status": "missing"} - avg = sum(worker_bytes.values()) / len(worker_bytes) + host = get_address_host(address) + local_now = time() + now = now or time() + assert metrics + host_info = host_info or {} - sorted_workers = list( - map(first, sorted(worker_bytes.items(), key=second, reverse=True)) + parent._host_info[host]["last-seen"] = local_now + frac = 1 / len(parent._workers) + parent._bandwidth = ( + parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + ) + for other, (bw, count) in metrics["bandwidth"]["workers"].items(): + if (address, other) not in self.bandwidth_workers: + self.bandwidth_workers[address, other] = bw / count + else: + 
alpha = (1 - frac) ** count + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) + for typ, (bw, count) in metrics["bandwidth"]["types"].items(): + if typ not in self.bandwidth_types: + self.bandwidth_types[typ] = bw / count + else: + alpha = (1 - frac) ** count + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( + 1 - alpha ) - recipients = iter(reversed(sorted_workers)) - recipient = next(recipients) - msgs = [] # (sender, recipient, key) - for sender in sorted_workers[: len(workers) // 2]: - sender_keys = { - ts: ts.get_nbytes() for ts in tasks_by_worker[sender] - } - sender_keys = iter( - sorted(sender_keys.items(), key=second, reverse=True) - ) + ws: WorkerState = parent._workers[address] - try: - while worker_bytes[sender] > avg: - while ( - worker_bytes[recipient] < avg - and worker_bytes[sender] > avg - ): - ts, nb = next(sender_keys) - if ts not in tasks_by_worker[recipient]: - tasks_by_worker[recipient].add(ts) - # tasks_by_worker[sender].remove(ts) - msgs.append((sender, recipient, ts)) - worker_bytes[sender] -= nb - worker_bytes[recipient] += nb - if worker_bytes[sender] > avg: - recipient = next(recipients) - except StopIteration: - break + ws._last_seen = time() - to_recipients = defaultdict(lambda: defaultdict(list)) - to_senders = defaultdict(list) - for sender, recipient, ts in msgs: - to_recipients[recipient.address][ts._key].append(sender.address) - to_senders[sender.address].append(ts._key) + if executing is not None: + ws._executing = { + self.tasks[key]: duration for key, duration in executing.items() + } - result = await asyncio.gather( - *( - retry_operation(self.rpc(addr=r).gather, who_has=v) - for r, v in to_recipients.items() - ) - ) - for r, v in to_recipients.items(): - self.log_event(r, {"action": "rebalance", "who_has": v}) + if metrics: + ws._metrics = metrics - self.log_event( - "all", - { - "action": "rebalance", - "total-keys": len(tasks), - "senders": valmap(len, to_senders), - "recipients": valmap(len, to_recipients), - "moved_keys": len(msgs), - }, - ) + if host_info: + parent._host_info[host].update(host_info) - if not all(r["status"] == "OK" for r in result): - return { - "status": "missing-data", - "keys": tuple( - concat( - r["keys"].keys() - for r in result - if r["status"] == "missing-data" - ) - ), - } + delay = time() - now + ws._time_delay = delay - for sender, recipient, ts in msgs: - assert ts._state == "memory" - ts._who_has.add(recipient) - recipient.has_what.add(ts) - recipient.nbytes += ts.get_nbytes() - self.log.append( - ( - "rebalance", - ts._key, - time(), - sender.address, - recipient.address, - ) - ) + if resources: + self.add_resources(worker=address, resources=resources) - await asyncio.gather( - *(self._delete_worker_data(r, v) for r, v in to_senders.items()) - ) + self.log_event(address, merge({"action": "heartbeat"}, metrics)) - return {"status": "OK"} + return { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(parent._workers)), + } - async def replicate( + async def add_worker( self, comm=None, - keys=None, - n=None, - workers=None, - branching_factor=2, - delete=True, - lock=True, + address=None, + keys=(), + nthreads=None, + name=None, + resolve_address=True, + nbytes=None, + types=None, + now=None, + resources=None, + host_info=None, + memory_limit=None, + metrics=None, + pid=0, + services=None, + local_directory=None, + versions=None, + nanny=None, + extra=None, ): - """Replicate data throughout cluster - 
- This performs a tree copy of the data throughout the network - individually on each piece of data. + """ Add a new worker to the cluster """ + parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): + address = self.coerce_address(address, resolve_address) + address = normalize_address(address) + host = get_address_host(address) - Parameters - ---------- - keys: Iterable - list of keys to replicate - n: int - Number of replications we expect to see within the cluster - branching_factor: int, optional - The number of workers that can copy data in each generation. - The larger the branching factor, the more data we copy in - a single step, but the more a given worker risks being - swamped by data requests. + ws: WorkerState = parent._workers.get(address) + if ws is not None: + raise ValueError("Worker already exists %s" % ws) - See also - -------- - Scheduler.rebalance - """ - ws: WorkerState - wws: WorkerState - ts: TaskState + if name in parent._aliases: + logger.warning( + "Worker tried to connect with a duplicate name: %s", name + ) + msg = { + "status": "error", + "message": "name taken, %s" % name, + "time": time(), + } + if comm: + await comm.write(msg) + return - assert branching_factor > 0 - async with self._lock if lock else empty_context: - workers = {self.workers[w] for w in self.workers_list(workers)} - if n is None: - n = len(workers) - else: - n = min(n, len(workers)) - if n == 0: - raise ValueError("Can not use replicate to delete data") - - tasks = {self.tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - - # Delete extraneous data - if delete: - del_worker_tasks = defaultdict(set) - for ts in tasks: - del_candidates = ts._who_has & workers - if len(del_candidates) > n: - for ws in random.sample( - del_candidates, len(del_candidates) - n - ): - del_worker_tasks[ws].add(ts) - - await asyncio.gather( - *[ - self._delete_worker_data(ws._address, [t.key for t in tasks]) - for ws, tasks in del_worker_tasks.items() - ] - ) - - # Copy not-yet-filled data - while tasks: - gathers = defaultdict(dict) - for ts in list(tasks): - if ts._state == "forgotten": - # task is no longer needed by any client or dependant task - tasks.remove(ts) - continue - n_missing = n - len(ts._who_has & workers) - if n_missing <= 0: - # Already replicated enough - tasks.remove(ts) - continue - - count = min(n_missing, branching_factor * len(ts._who_has)) - assert count > 0 + parent._workers[address] = ws = WorkerState( + address=address, + pid=pid, + nthreads=nthreads, + memory_limit=memory_limit or 0, + name=name, + local_directory=local_directory, + services=services, + versions=versions, + nanny=nanny, + extra=extra, + ) - for ws in random.sample(workers - ts._who_has, count): - gathers[ws._address][ts._key] = [ - wws._address for wws in ts._who_has - ] + if "addresses" not in parent._host_info[host]: + parent._host_info[host].update({"addresses": set(), "nthreads": 0}) - results = await asyncio.gather( - *( - retry_operation(self.rpc(addr=w).gather, who_has=who_has) - for w, who_has in gathers.items() - ) - ) - for w, v in zip(gathers, results): - if v["status"] == "OK": - self.add_keys(worker=w, keys=list(gathers[w])) - else: - logger.warning("Communication failed during replication: %s", v) + parent._host_info[host]["addresses"].add(address) + parent._host_info[host]["nthreads"] += nthreads - self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) + 
parent._total_nthreads += nthreads + parent._aliases[name] = address - self.log_event( - "all", - { - "action": "replicate", - "workers": list(workers), - "key-count": len(keys), - "branching-factor": branching_factor, - }, + response = self.heartbeat_worker( + address=address, + resolve_address=resolve_address, + now=now, + resources=resources, + host_info=host_info, + metrics=metrics, ) - def workers_to_close( - self, - comm=None, - memory_ratio=None, - n=None, - key=None, - minimum=None, - target=None, - attribute="address", - ): - """ - Find workers that we can close with low cost - - This returns a list of workers that are good candidates to retire. - These workers are not running anything and are storing - relatively little data relative to their peers. If all workers are - idle then we still maintain enough workers to have enough RAM to store - our data, with a comfortable buffer. - - This is for use with systems like ``distributed.deploy.adaptive``. + # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this. + self.check_idle_saturated(ws) - Parameters - ---------- - memory_factor: Number - Amount of extra space we want to have for our stored data. - Defaults two 2, or that we want to have twice as much memory as we - currently have data. - n: int - Number of workers to close - minimum: int - Minimum number of workers to keep around - key: Callable(WorkerState) - An optional callable mapping a WorkerState object to a group - affiliation. Groups will be closed together. This is useful when - closing workers must be done collectively, such as by hostname. - target: int - Target number of workers to have after we close - attribute : str - The attribute of the WorkerState object to return, like "address" - or "name". Defaults to "address". + # for key in keys: # TODO + # self.mark_key_in_memory(key, [address]) - Examples - -------- - >>> scheduler.workers_to_close() - ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234'] + self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) - Group workers by hostname prior to closing + if ws._nthreads > len(ws._processing): + parent._idle[ws._address] = ws - >>> scheduler.workers_to_close(key=lambda ws: ws.host) - ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567'] + for plugin in self.plugins[:]: + try: + result = plugin.add_worker(scheduler=self, worker=address) + if inspect.isawaitable(result): + await result + except Exception as e: + logger.exception(e) - Remove two workers + recommendations: dict + if nbytes: + for key in nbytes: + tasks: dict = parent._tasks + ts: TaskState = tasks.get(key) + if ts is not None and ts._state in ("processing", "waiting"): + recommendations = self.transition( + key, + "memory", + worker=address, + nbytes=nbytes[key], + typename=types[key], + ) + self.transitions(recommendations) - >>> scheduler.workers_to_close(n=2) + recommendations = {} + for ts in list(parent._unrunnable): + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + recommendations[ts._key] = "waiting" - Keep enough workers to have twice as much memory as we we need. 
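# Illustrative aside, not part of the patch: after a worker joins, tasks parked as
# "no-worker" are reconsidered -- any task whose worker restrictions the newcomer can
# satisfy is recommended back to "waiting", as in the loop over parent._unrunnable
# above.  A hypothetical sketch where valid_workers(key) returns the set of acceptable
# workers, or None for "any worker will do":
def reconsider_unrunnable(unrunnable, new_worker, valid_workers):
    recommendations = {}
    for key in unrunnable:
        valid = valid_workers(key)
        if valid is None or new_worker in valid:
            recommendations[key] = "waiting"
    return recommendations

recs = reconsider_unrunnable(
    unrunnable={"x", "y"},
    new_worker="tcp://10.0.0.1:4001",
    valid_workers=lambda key: None if key == "x" else {"tcp://10.0.0.2:4001"},
)
assert recs == {"x": "waiting"}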
+ if recommendations: + self.transitions(recommendations) - >>> scheduler.workers_to_close(memory_ratio=2) + self.log_event(address, {"action": "add-worker"}) + self.log_event("all", {"action": "add-worker", "worker": address}) + logger.info("Register worker %s", ws) - Returns - ------- - to_close: list of worker addresses that are OK to close + msg = { + "status": "OK", + "time": time(), + "heartbeat-interval": heartbeat_interval(len(parent._workers)), + "worker-plugins": self.worker_plugins, + } - See Also - -------- - Scheduler.retire_workers - """ - if target is not None and n is None: - n = len(self.workers) - target - if n is not None: - if n < 0: - n = 0 - target = len(self.workers) - n + cs: ClientState + version_warning = version_module.error_message( + version_module.get_versions(), + merge( + {w: ws._versions for w, ws in parent._workers.items()}, + { + c: cs._versions + for c, cs in parent._clients.items() + if cs._versions + }, + ), + versions, + client_name="This Worker", + ) + msg.update(version_warning) - if n is None and memory_ratio is None: - memory_ratio = 2 + if comm: + await comm.write(msg) + await self.handle_worker(comm=comm, worker=address) - ws: WorkerState - with log_errors(): - if not n and all([ws._processing for ws in self.workers.values()]): - return [] + def update_graph_hlg( + self, + client=None, + hlg=None, + keys=None, + dependencies=None, + restrictions=None, + priority=None, + loose_restrictions=None, + resources=None, + submitting_task=None, + retries=None, + user_priority=0, + actors=None, + fifo_timeout=0, + ): - if key is None: - key = operator.attrgetter("address") - if isinstance(key, bytes) and dask.config.get( - "distributed.scheduler.pickle" - ): - key = pickle.loads(key) + dsk, dependencies, annotations = highlevelgraph_unpack(hlg) - groups = groupby(key, self.workers.values()) + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps - limit_bytes = { - k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() + if priority is None: + # Removing all non-local keys before calling order() + dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys } - group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()} - - limit = sum(limit_bytes.values()) - total = sum(group_bytes.values()) - - def _key(group): - wws: WorkerState - is_idle = not any([wws._processing for wws in groups[group]]) - bytes = -group_bytes[group] - return (is_idle, bytes) + priority = dask.order.order(dsk, dependencies=stripped_deps) - idle = sorted(groups, key=_key) + return self.update_graph( + client, + dsk, + keys, + dependencies, + restrictions, + priority, + loose_restrictions, + resources, + submitting_task, + retries, + user_priority, + actors, + fifo_timeout, + annotations, + ) + + def update_graph( + self, + client=None, + tasks=None, + keys=None, + dependencies=None, + restrictions=None, + priority=None, + loose_restrictions=None, + resources=None, + submitting_task=None, + retries=None, + user_priority=0, + actors=None, + fifo_timeout=0, + annotations=None, + ): + """ + Add new computations to the internal dask graph + + This happens whenever the Client calls submit, map, get, or compute. 
+ """ + parent: SchedulerState = cast(SchedulerState, self) + start = time() + fifo_timeout = parse_timedelta(fifo_timeout) + keys = set(keys) + if len(tasks) > 1: + self.log_event( + ["all", client], {"action": "update_graph", "count": len(tasks)} + ) + + # Remove aliases + for k in list(tasks): + if tasks[k] is k: + del tasks[k] + + dependencies = dependencies or {} + + n = 0 + while len(tasks) != n: # walk through new tasks, cancel any bad deps + n = len(tasks) + for k, deps in list(dependencies.items()): + if any( + dep not in parent._tasks and dep not in tasks for dep in deps + ): # bad key + logger.info("User asked for computation on lost data, %s", k) + del tasks[k] + del dependencies[k] + if k in keys: + keys.remove(k) + self.report({"op": "cancelled-key", "key": k}, client=client) + self.client_releases_keys(keys=[k], client=client) + + # Avoid computation that is already finished + ts: TaskState + already_in_memory = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in parent._tasks: + ts = parent._tasks[k] + if ts._state in ("memory", "erred"): + already_in_memory.add(k) + + dts: TaskState + if already_in_memory: + dependents = dask.core.reverse_dict(dependencies) + stack = list(already_in_memory) + done = set(already_in_memory) + while stack: # remove unnecessary dependencies + key = stack.pop() + ts = parent._tasks[key] + try: + deps = dependencies[key] + except KeyError: + deps = self.dependencies[key] + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + else: + child_deps = self.dependencies[dep] + if all(d in done for d in child_deps): + if dep in parent._tasks and dep not in done: + done.add(dep) + stack.append(dep) + + for d in done: + tasks.pop(d, None) + dependencies.pop(d, None) + + # Get or create task states + stack = list(keys) + touched_keys = set() + touched_tasks = [] + while stack: + k = stack.pop() + if k in touched_keys: + continue + # XXX Have a method get_task_state(self, k) ? 
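# Illustrative aside, not part of the patch: the "walk through new tasks, cancel any
# bad deps" loop above repeatedly drops submitted tasks whose dependencies are neither
# already known to the scheduler nor part of the submission, until the graph stops
# shrinking.  A hypothetical standalone sketch:
def drop_tasks_with_lost_deps(tasks: dict, dependencies: dict, known: set) -> None:
    n = -1
    while len(tasks) != n:  # iterate to a fixed point
        n = len(tasks)
        for key, deps in list(dependencies.items()):
            if key in tasks and any(d not in known and d not in tasks for d in deps):
                del tasks[key]
                del dependencies[key]

tasks = {"a": 1, "b": 2}
deps = {"a": {"lost-key"}, "b": {"a"}}
drop_tasks_with_lost_deps(tasks, deps, known=set())
assert tasks == {}  # "a" depended on lost data, and "b" depended on "a"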
+ ts = parent._tasks.get(k) + if ts is None: + ts = self.new_task(k, tasks.get(k), "released") + elif not ts._run_spec: + ts._run_spec = tasks.get(k) + + touched_keys.add(k) + touched_tasks.append(ts) + stack.extend(dependencies.get(k, ())) + + self.client_desires_keys(keys=keys, client=client) + + # Add dependencies + for key, deps in dependencies.items(): + ts = parent._tasks.get(key) + if ts is None or ts._dependencies: + continue + for dep in deps: + dts = parent._tasks[dep] + ts.add_dependency(dts) + + # Compute priorities + if isinstance(user_priority, Number): + user_priority = {k: user_priority for k in tasks} + + annotations = annotations or {} + restrictions = restrictions or {} + loose_restrictions = loose_restrictions or [] + resources = resources or {} + retries = retries or {} + + # Override existing taxonomy with per task annotations + if annotations: + if "priority" in annotations: + user_priority.update(annotations["priority"]) + + if "workers" in annotations: + restrictions.update(annotations["workers"]) + + if "allow_other_workers" in annotations: + loose_restrictions.extend( + k for k, v in annotations["allow_other_workers"].items() if v + ) + + if "retries" in annotations: + retries.update(annotations["retries"]) + + if "resources" in annotations: + resources.update(annotations["resources"]) + + for a, kv in annotations.items(): + for k, v in kv.items(): + ts = parent._tasks[k] + ts._annotations[a] = v + + # Add actors + if actors is True: + actors = list(keys) + for actor in actors or []: + ts = parent._tasks[actor] + ts._actor = True + + priority = priority or dask.order.order( + tasks + ) # TODO: define order wrt old graph + + if submitting_task: # sub-tasks get better priority than parent tasks + ts = parent._tasks.get(submitting_task) + if ts is not None: + generation = ts._priority[0] - 0.01 + else: # super-task already cleaned up + generation = self.generation + elif self._last_time + fifo_timeout < start: + self.generation += 1 # older graph generations take precedence + generation = self.generation + self._last_time = start + else: + generation = self.generation + + for key in set(priority) & touched_keys: + ts = parent._tasks[key] + if ts._priority is None: + ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) + + # Ensure all runnables have a priority + runnables = [ts for ts in touched_tasks if ts._run_spec] + for ts in runnables: + if ts._priority is None and ts._run_spec: + ts._priority = (self.generation, 0) + + if restrictions: + # *restrictions* is a dict keying task ids to lists of + # restriction specifications (either worker names or addresses) + for k, v in restrictions.items(): + if v is None: + continue + ts = parent._tasks.get(k) + if ts is None: + continue + ts._host_restrictions = set() + ts._worker_restrictions = set() + for w in v: + try: + w = self.coerce_address(w) + except ValueError: + # Not a valid address, but perhaps it's a hostname + ts._host_restrictions.add(w) + else: + ts._worker_restrictions.add(w) + + if loose_restrictions: + for k in loose_restrictions: + ts = parent._tasks[k] + ts._loose_restrictions = True + + if resources: + for k, v in resources.items(): + if v is None: + continue + assert isinstance(v, dict) + ts = parent._tasks.get(k) + if ts is None: + continue + ts._resource_restrictions = v + + if retries: + for k, v in retries.items(): + assert isinstance(v, int) + ts = parent._tasks.get(k) + if ts is None: + continue + ts._retries = v + + # Compute recommendations + recommendations: dict = {} + + 
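# Illustrative aside, not part of the patch: the priority assigned above is the tuple
# (-user_priority, generation, graph_priority), so plain tuple comparison makes higher
# user priority win first, then older (smaller) generations, then dask.order's
# within-graph ordering.  A toy illustration with hypothetical keys:
priorities = {
    "default":     (0, 1, 5),
    "urgent":      (-10, 1, 9),   # user_priority=10
    "older-graph": (0, 0, 7),     # submitted before the fifo timeout rolled over
}
schedule_order = sorted(priorities, key=priorities.get)
assert schedule_order == ["urgent", "older-graph", "default"]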
for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): + if ts._state == "released" and ts._run_spec: + recommendations[ts._key] = "waiting" + + for ts in touched_tasks: + for dts in ts._dependencies: + if dts._exception_blame: + ts._exception_blame = dts._exception_blame + recommendations[ts._key] = "erred" + break + + for plugin in self.plugins[:]: + try: + plugin.update_graph( + self, + client=client, + tasks=tasks, + keys=keys, + restrictions=restrictions or {}, + dependencies=dependencies, + priority=priority, + loose_restrictions=loose_restrictions, + resources=resources, + annotations=annotations, + ) + except Exception as e: + logger.exception(e) + + self.transitions(recommendations) + + for ts in touched_tasks: + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) + + end = time() + if self.digests is not None: + self.digests["update-graph-duration"].add(end - start) + + # TODO: balance workers + + def new_task(self, key, spec, state): + """ Create a new task, and associated states """ + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = TaskState(key, spec) + tp: TaskPrefix + tg: TaskGroup + ts._state = state + prefix_key = key_split(key) + try: + tp = self.task_prefixes[prefix_key] + except KeyError: + self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + ts._prefix = tp + + group_key = ts._group_key + try: + tg = self.task_groups[group_key] + except KeyError: + self.task_groups[group_key] = tg = TaskGroup(group_key) + tg._prefix = tp + tp._groups.append(tg) + tg.add(ts) + parent._tasks[key] = ts + return ts + + def stimulus_task_finished(self, key=None, worker=None, **kwargs): + """ Mark that a task has finished execution on a particular worker """ + parent: SchedulerState = cast(SchedulerState, self) + logger.debug("Stimulus task finished %s, %s", key, worker) + + tasks: dict = parent._tasks + ts: TaskState = tasks.get(key) + if ts is None: + return {} + ws: WorkerState = parent._workers_dv[worker] + ts._metadata.update(kwargs["metadata"]) + + recommendations: dict + if ts._state == "processing": + recommendations = self.transition(key, "memory", worker=worker, **kwargs) + + if ts._state == "memory": + assert ws in ts._who_has + else: + logger.debug( + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", + worker, + ts._state, + key, + ts._who_has, + ) + if ws not in ts._who_has: + self.worker_send(worker, {"op": "release-task", "key": key}) + recommendations = {} - to_close = [] - n_remain = len(self.workers) + return recommendations - while idle: - group = idle.pop() - if n is None and any([ws._processing for ws in groups[group]]): - break + def stimulus_task_erred( + self, key=None, worker=None, exception=None, traceback=None, **kwargs + ): + """ Mark that a task has erred on a particular worker """ + parent: SchedulerState = cast(SchedulerState, self) + logger.debug("Stimulus task erred %s, %s", key, worker) - if minimum and n_remain - len(groups[group]) < minimum: - break + ts: TaskState = parent._tasks.get(key) + if ts is None: + return {} - limit -= limit_bytes[group] + recommendations: dict + if ts._state == "processing": + retries = ts._retries + if retries > 0: + ts._retries = retries - 1 + recommendations = self.transition(key, "waiting") + else: + recommendations = self.transition( + key, + "erred", + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ) + else: + recommendations = {} - if (n is not None and n_remain - 
len(groups[group]) >= target) or ( - memory_ratio is not None and limit >= memory_ratio * total - ): - to_close.append(group) - n_remain -= len(groups[group]) + return recommendations - else: - break + def stimulus_missing_data( + self, cause=None, key=None, worker=None, ensure=True, **kwargs + ): + """ Mark that certain keys have gone missing. Recover. """ + parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): + logger.debug("Stimulus missing data %s, %s", key, worker) - result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] - if result: - logger.debug("Suggest closing workers: %s", result) + ts: TaskState = parent._tasks.get(key) + if ts is None or ts._state == "memory": + return {} + cts: TaskState = parent._tasks.get(cause) - return result + recommendations: dict = {} - async def retire_workers( - self, - comm=None, - workers=None, - remove=True, - close_workers=False, - names=None, - lock=True, - **kwargs, - ) -> dict: - """Gracefully retire workers from cluster + if cts is not None and cts._state == "memory": # couldn't find this + ws: WorkerState + for ws in cts._who_has: # TODO: this behavior is extreme + ws._has_what.remove(cts) + ws._nbytes -= cts.get_nbytes() + cts._who_has.clear() + recommendations[cause] = "released" - Parameters - ---------- - workers: list (optional) - List of worker addresses to retire. - If not provided we call ``workers_to_close`` which finds a good set - workers_names: list (optional) - List of worker names to retire. - remove: bool (defaults to True) - Whether or not to remove the worker metadata immediately or else - wait for the worker to contact us - close_workers: bool (defaults to False) - Whether or not to actually close the worker explicitly from here. - Otherwise we expect some external job scheduler to finish off the - worker. - **kwargs: dict - Extra options to pass to workers_to_close to determine which - workers we should drop + if key: + recommendations[key] = "released" - Returns - ------- - Dictionary mapping worker ID/address to dictionary of information about - that worker for each retired worker. + self.transitions(recommendations) - See Also - -------- - Scheduler.workers_to_close - """ - ws: WorkerState + if parent._validate: + assert cause not in self.who_has + + return {} + + def stimulus_retry(self, comm=None, keys=None, client=None): + parent: SchedulerState = cast(SchedulerState, self) + logger.info("Client %s requests to retry %d keys", client, len(keys)) + if client: + self.log_event(client, {"action": "retry", "count": len(keys)}) + + stack = list(keys) + seen = set() + roots = [] ts: TaskState + dts: TaskState + while stack: + key = stack.pop() + seen.add(key) + ts = parent._tasks[key] + erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] + if erred_deps: + stack.extend(erred_deps) + else: + roots.append(key) + + recommendations: dict = {key: "waiting" for key in roots} + self.transitions(recommendations) + + if parent._validate: + for key in seen: + assert not parent._tasks[key].exception_blame + + return tuple(seen) + + async def remove_worker(self, comm=None, address=None, safe=False, close=True): + """ + Remove worker from cluster + + We do this when a worker reports that it plans to leave or when it + appears to be unresponsive. This may send its tasks back to a released + state. 
+ """ + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): - async with self._lock if lock else empty_context: - if names is not None: - if names: - logger.info("Retire worker names %s", names) - names = set(map(str, names)) - workers = [ - ws._address - for ws in self.workers.values() - if str(ws._name) in names - ] - if workers is None: - while True: - try: - workers = self.workers_to_close(**kwargs) - if workers: - workers = await self.retire_workers( - workers=workers, - remove=remove, - close_workers=close_workers, - lock=False, - ) - return workers - else: - return {} - except KeyError: # keys left during replicate - pass - workers = {self.workers[w] for w in workers if w in self.workers} - if not workers: - return {} - logger.info("Retire workers %s", workers) + if self.status == Status.closed: + return - # Keys orphaned by retiring those workers - keys = set.union(*[w.has_what for w in workers]) - keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} + address = self.coerce_address(address) - other_workers = set(self.workers.values()) - workers - if keys: - if other_workers: - logger.info("Moving %d keys to other workers", len(keys)) - await self.replicate( - keys=keys, - workers=[ws._address for ws in other_workers], - n=1, - delete=False, - lock=False, + if address not in parent._workers_dv: + return "already-removed" + + host = get_address_host(address) + + ws: WorkerState = parent._workers_dv[address] + + self.log_event( + ["all", address], + { + "action": "remove-worker", + "worker": address, + "processing-tasks": dict(ws._processing), + }, + ) + logger.info("Remove worker %s", ws) + if close: + with suppress(AttributeError, CommClosedError): + self.stream_comms[address].send({"op": "close", "report": False}) + + self.remove_resources(address) + + parent._host_info[host]["nthreads"] -= ws._nthreads + parent._host_info[host]["addresses"].remove(address) + parent._total_nthreads -= ws._nthreads + + if not parent._host_info[host]["addresses"]: + del parent._host_info[host] + + self.rpc.remove(address) + del self.stream_comms[address] + del parent._aliases[ws._name] + parent._idle.pop(ws._address, None) + parent._saturated.discard(ws) + del parent._workers[address] + ws.status = Status.closed + parent._total_occupancy -= ws._occupancy + + recommendations: dict = {} + + ts: TaskState + for ts in list(ws._processing): + k = ts._key + recommendations[k] = "released" + if not safe: + ts._suspicious += 1 + ts._prefix._suspicious += 1 + if ts._suspicious > self.allowed_failures: + del recommendations[k] + e = pickle.dumps( + KilledWorker(task=k, last_worker=ws.clean()), protocol=4 ) - else: - return {} + r = self.transition(k, "erred", exception=e, cause=k) + recommendations.update(r) + logger.info( + "Task %s marked as failed because %d workers died" + " while trying to run it", + ts._key, + self.allowed_failures, + ) + + for ts in ws._has_what: + ts._who_has.remove(ws) + if not ts._who_has: + if ts._run_spec: + recommendations[ts._key] = "released" + else: # pure data + recommendations[ts._key] = "forgotten" + ws._has_what.clear() + + self.transitions(recommendations) + + for plugin in self.plugins[:]: + try: + result = plugin.remove_worker(scheduler=self, worker=address) + if inspect.isawaitable(result): + await result + except Exception as e: + logger.exception(e) - worker_keys = {ws._address: ws.identity() for ws in workers} - if close_workers and worker_keys: - await asyncio.gather( - *[self.close_worker(worker=w, safe=True) for w in worker_keys] - ) 
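# Illustrative aside, not part of the patch: remove_worker above sends each task the
# dead worker was processing back to "released", unless that task has now been caught
# on more than allowed_failures dying workers, in which case it is marked "erred"
# (surfaced to clients as KilledWorker).  A hypothetical distillation:
def recommend_after_worker_loss(processing_keys, suspicious, allowed_failures, safe=False):
    recommendations = {}
    for key in processing_keys:
        recommendations[key] = "released"
        if not safe:
            suspicious[key] = suspicious.get(key, 0) + 1
            if suspicious[key] > allowed_failures:
                recommendations[key] = "erred"  # give up on this task
    return recommendations

suspicious = {"x": 3}
recs = recommend_after_worker_loss(["x", "y"], suspicious, allowed_failures=3)
assert recs == {"x": "erred", "y": "released"}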
- if remove: - await asyncio.gather( - *[self.remove_worker(address=w, safe=True) for w in worker_keys] - ) + if not parent._workers_dv: + logger.info("Lost all workers") - self.log_event( - "all", - { - "action": "retire-workers", - "workers": worker_keys, - "moved-keys": len(keys), - }, - ) - self.log_event(list(worker_keys), {"action": "retired"}) + for w in parent._workers_dv: + self.bandwidth_workers.pop((address, w), None) + self.bandwidth_workers.pop((w, address), None) - return worker_keys + def remove_worker_from_events(): + # If the worker isn't registered anymore after the delay, remove from events + if address not in parent._workers_dv and address in self.events: + del self.events[address] - def add_keys(self, comm=None, worker=None, keys=()): - """ - Learn that a worker has certain keys + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) + self.loop.call_later(cleanup_delay, remove_worker_from_events) + logger.debug("Removed worker %s", ws) - This should not be used in practice and is mostly here for legacy - reasons. However, it is sent by workers from time to time. - """ - if worker not in self.workers: - return "not found" - ws: WorkerState = self.workers[worker] + return "OK" + + def stimulus_cancel(self, comm, keys=None, client=None, force=False): + """ Stop execution on a list of keys """ + logger.info("Client %s requests to cancel %d keys", client, len(keys)) + if client: + self.log_event( + client, {"action": "cancel", "count": len(keys), "force": force} + ) for key in keys: - ts: TaskState = self.tasks.get(key) - if ts is not None and ts._state == "memory": - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what.add(ts) - ts._who_has.add(ws) - else: - self.worker_send( - worker, {"op": "delete-data", "keys": [key], "report": False} + self.cancel_key(key, client, force=force) + + def cancel_key(self, key, client, retries=5, force=False): + """ Cancel a particular key and all dependents """ + # TODO: this should be converted to use the transition mechanism + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks.get(key) + dts: TaskState + try: + cs: ClientState = parent._clients[client] + except KeyError: + return + if ts is None or not ts._who_wants: # no key yet, lets try again in a moment + if retries: + self.loop.call_later( + 0.2, lambda: self.cancel_key(key, client, retries - 1) ) + return + if force or ts._who_wants == {cs}: # no one else wants this key + for dts in list(ts._dependents): + self.cancel_key(dts._key, client, force=force) + logger.info("Scheduler cancels key %s. Force=%s", key, force) + self.report({"op": "cancelled-key", "key": key}) + clients = list(ts._who_wants) if force else [cs] + for cs in clients: + self.client_releases_keys(keys=[key], client=cs._client_key) - return "OK" + def client_desires_keys(self, keys=None, client=None): + parent: SchedulerState = cast(SchedulerState, self) + cs: ClientState = parent._clients.get(client) + if cs is None: + # For publish, queues etc. + parent._clients[client] = cs = ClientState(client) + ts: TaskState + for k in keys: + ts = parent._tasks.get(k) + if ts is None: + # For publish, queues etc. 
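# Illustrative aside, not part of the patch: cancel_key above first cancels all
# dependents, then releases the key -- for the requesting client only, or for every
# interested client when force=True.  A hypothetical sketch over a toy task table
# {key: (dependents, who_wants)}:
def cancel(key, client, table, force=False, released=None):
    released = released if released is not None else []
    dependents, who_wants = table[key]
    if force or who_wants == {client}:  # nobody else wants this key
        for dep in list(dependents):
            cancel(dep, client, table, force=force, released=released)
        for c in (list(who_wants) if force else [client]):
            released.append((key, c))
    return released

table = {"a": ({"b"}, {"alice"}), "b": (set(), {"alice"})}
assert cancel("a", "alice", table) == [("b", "alice"), ("a", "alice")]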
+ ts = self.new_task(k, None, "released") + ts._who_wants.add(cs) + cs._wants_what.add(ts) - def update_data( - self, comm=None, who_has=None, nbytes=None, client=None, serializers=None - ): - """ - Learn that new data has entered the network from an external source + if ts._state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) - See Also - -------- - Scheduler.mark_key_in_memory - """ - with log_errors(): - who_has = { - k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() - } - logger.debug("Update data %s", who_has) + def client_releases_keys(self, keys=None, client=None): + """ Remove keys from client desired list """ - for key, workers in who_has.items(): - ts: TaskState = self.tasks.get(key) - if ts is None: - ts: TaskState = self.new_task(key, None, "memory") - ts.state = "memory" - if key in nbytes: - ts.set_nbytes(nbytes[key]) - for w in workers: - ws: WorkerState = self.workers[w] - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what.add(ts) - ts._who_has.add(ws) - self.report( - {"op": "key-in-memory", "key": key, "workers": list(workers)} - ) + parent: SchedulerState = cast(SchedulerState, self) + if not isinstance(keys, list): + keys = list(keys) + cs: ClientState = parent._clients[client] + recommendations: dict = {} - if client: - self.client_desires_keys(keys=list(who_has), client=client) + _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) + self.transitions(recommendations) - def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): - if ts is None: - tasks: dict = self.tasks - ts = tasks.get(key) - elif key is None: - key = ts._key - else: - assert False, (key, ts) - return + def client_heartbeat(self, client=None): + """ Handle heartbeats from Client """ + parent: SchedulerState = cast(SchedulerState, self) + cs: ClientState = parent._clients[client] + cs._last_seen = time() - if ts is None: - self.report({"op": "cancelled-key", "key": key}, client=client) - elif ts._state == "forgotten": - self.report({"op": "cancelled-key", "key": key}, ts=ts, client=client) - elif ts._state == "memory": - self.report({"op": "key-in-memory", "key": key}, ts=ts, client=client) - elif ts._state == "erred": - failing_ts: TaskState = ts._exception_blame - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - }, - ts=ts, - client=client, - ) + ################### + # Task Validation # + ################### - async def feed( - self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs - ): - """ - Provides a data Comm to external requester + def validate_released(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + dts: TaskState + assert ts._state == "released" + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in dts._waiters for dts in ts._dependencies]) + assert ts not in parent._unrunnable - Caution: this runs arbitrary Python code on the scheduler. This should - eventually be phased out. It is mostly used by diagnostics. - """ - if not dask.config.get("distributed.scheduler.pickle"): - logger.warn( - "Tried to call 'feed' route with custom functions, but " - "pickle is disallowed. 
Set the 'distributed.scheduler.pickle'" - "config value to True to use the 'feed' route (this is mostly " - "commonly used with progress bars)" - ) - return + def validate_waiting(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert ts not in parent._unrunnable + for dts in ts._dependencies: + # We are waiting on a dependency iff it's not stored + assert (not not dts._who_has) != (dts in ts._waiting_on) + assert ts in dts._waiters # XXX even if dts._who_has? - interval = parse_timedelta(interval) - with log_errors(): - if function: - function = pickle.loads(function) - if setup: - setup = pickle.loads(setup) - if teardown: - teardown = pickle.loads(teardown) - state = setup(self) if setup else None - if inspect.isawaitable(state): - state = await state - try: - while self.status == Status.running: - if state is None: - response = function(self) - else: - response = function(self, state) - await comm.write(response) - await asyncio.sleep(interval) - except (EnvironmentError, CommClosedError): - pass - finally: - if teardown: - teardown(self, state) + def validate_processing(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on + assert ws + assert ts in ws._processing + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters - def log_worker_event(self, worker=None, topic=None, msg=None): - self.log_event(topic, msg) + def validate_memory(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + dts: TaskState + assert ts._who_has + assert not ts._processing_on + assert not ts._waiting_on + assert ts not in parent._unrunnable + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) + assert ts not in dts._waiting_on - def subscribe_worker_status(self, comm=None): - WorkerStatusPlugin(self, comm) - ident = self.identity() - for v in ident["workers"].values(): - del v["metrics"] - del v["last_seen"] - return ident + def validate_no_worker(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + dts: TaskState + assert ts in parent._unrunnable + assert not ts._waiting_on + assert ts in parent._unrunnable + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + + def validate_erred(self, key): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + assert ts._exception_blame + assert not ts._who_has + + def validate_key(self, key, ts: TaskState = None): + parent: SchedulerState = cast(SchedulerState, self) + try: + if ts is None: + ts = parent._tasks.get(key) + if ts is None: + logger.debug("Key lost: %s", key) + else: + ts.validate() + try: + func = getattr(self, "validate_" + ts._state.replace("-", "_")) + except AttributeError: + logger.error( + "self.validate_%s not found", ts._state.replace("-", "_") + ) + else: + func(key) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - def get_processing(self, comm=None, workers=None): - ws: WorkerState - ts: TaskState - if workers is not None: - workers = set(map(self.coerce_address, workers)) - return {w: [ts._key for ts in self.workers[w].processing] for w in 
workers} - else: - return { - w: [ts._key for ts in ws._processing] for w, ws in self.workers.items() - } + pdb.set_trace() + raise - def get_who_has(self, comm=None, keys=None): - ws: WorkerState - ts: TaskState - if keys is not None: - return { - k: [ws._address for ws in self.tasks[k].who_has] - if k in self.tasks - else [] - for k in keys - } - else: - return { - key: [ws._address for ws in ts._who_has] - for key, ts in self.tasks.items() - } + def validate_state(self, allow_overlap=False): + parent: SchedulerState = cast(SchedulerState, self) + validate_state(parent._tasks, parent._workers, parent._clients) - def get_has_what(self, comm=None, workers=None): - ws: WorkerState - ts: TaskState - if workers is not None: - workers = map(self.coerce_address, workers) - return { - w: [ts._key for ts in self.workers[w].has_what] - if w in self.workers - else [] - for w in workers - } - else: - return { - w: [ts._key for ts in ws._has_what] for w, ws in self.workers.items() - } + if not (set(parent._workers_dv) == set(self.stream_comms)): + raise ValueError("Workers not the same in all collections") - def get_ncores(self, comm=None, workers=None): ws: WorkerState - if workers is not None: - workers = map(self.coerce_address, workers) - return {w: self.workers[w].nthreads for w in workers if w in self.workers} - else: - return {w: ws._nthreads for w, ws in self.workers.items()} + for w, ws in parent._workers_dv.items(): + assert isinstance(w, str), (type(w), w) + assert isinstance(ws, WorkerState), (type(ws), ws) + assert ws._address == w + if not ws._processing: + assert not ws._occupancy + assert ws._address in parent._idle_dv - async def get_call_stack(self, comm=None, keys=None): ts: TaskState - dts: TaskState - if keys is not None: - stack = list(keys) - processing = set() - while stack: - key = stack.pop() - ts = self.tasks[key] - if ts._state == "waiting": - stack.extend([dts._key for dts in ts._dependencies]) - elif ts._state == "processing": - processing.add(ts) - - workers = defaultdict(list) - for ts in processing: - if ts._processing_on: - workers[ts._processing_on.address].append(ts._key) - else: - workers = {w: None for w in self.workers} + for k, ts in parent._tasks.items(): + assert isinstance(ts, TaskState), (type(ts), ts) + assert ts._key == k + self.validate_key(k, ts) - if not workers: - return {} + c: str + cs: ClientState + for c, cs in parent._clients.items(): + # client=None is often used in tests... 
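# Illustrative aside, not part of the patch: validate_key above picks a per-state
# checker by name (e.g. state "no-worker" -> validate_no_worker) via getattr.  A
# hypothetical sketch of that dispatch pattern:
class Validators:
    def validate_memory(self, key):
        return f"{key}: memory invariants checked"

    def validate_no_worker(self, key):
        return f"{key}: no-worker invariants checked"

    def validate_key(self, key, state):
        func = getattr(self, "validate_" + state.replace("-", "_"), None)
        if func is None:
            return f"no validator for state {state!r}"
        return func(key)

v = Validators()
assert v.validate_key("x", "no-worker") == "x: no-worker invariants checked"
assert v.validate_key("x", "processing") == "no validator for state 'processing'"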
+ assert c is None or type(c) == str, (type(c), c) + assert type(cs) == ClientState, (type(cs), cs) + assert cs._client_key == c - results = await asyncio.gather( - *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) - ) - response = {w: r for w, r in zip(workers, results) if r} - return response + a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws._has_what) + for w, ws in parent._workers_dv.items() + } + assert a == b, (a, b) - def get_nbytes(self, comm=None, keys=None, summary=True): - ts: TaskState - with log_errors(): - if keys is not None: - result = {k: self.tasks[k].nbytes for k in keys} - else: - result = { - k: ts._nbytes for k, ts in self.tasks.items() if ts._nbytes >= 0 - } + actual_total_occupancy = 0 + for worker, ws in parent._workers_dv.items(): + assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 + actual_total_occupancy += ws._occupancy - if summary: - out = defaultdict(lambda: 0) - for k, v in result.items(): - out[key_split(k)] += v - result = dict(out) + assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( + actual_total_occupancy, + parent._total_occupancy, + ) - return result + ################### + # Manage Messages # + ################### - def get_comm_cost(self, ts: TaskState, ws: WorkerState): - """ - Get the estimated communication cost (in s.) to compute the task - on the given worker. + def report(self, msg: dict, ts: TaskState = None, client: str = None): """ - dts: TaskState - deps: set = ts._dependencies - ws._has_what - nbytes: Py_ssize_t = 0 - bandwidth: double = self.bandwidth - for dts in deps: - nbytes += dts._nbytes - return nbytes / bandwidth + Publish updates to all listening Queues and Comms - def get_task_duration(self, ts: TaskState, default: double = -1): - """ - Get the estimated computation cost of the given task - (not including any communication cost). + If the message contains a key then we only send the message to those + comms that care about the key. """ - duration: double = ts._prefix._duration_average - if duration < 0: - s: set = self.unknown_durations[ts._prefix._name] - s.add(ts) - if default < 0: - duration = UNKNOWN_TASK_DURATION - else: - duration = default + parent: SchedulerState = cast(SchedulerState, self) + if ts is None: + msg_key = msg.get("key") + if msg_key is not None: + tasks: dict = parent._tasks + ts = tasks.get(msg_key) - return duration + cs: ClientState + client_comms: dict = self.client_comms + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(client_comms) + elif client is None: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] + else: + # Notify clients interested in key (including `client`) + client_keys = [ + cs._client_key for cs in ts._who_wants if cs._client_key != client + ] + client_keys.append(client) - def run_function(self, stream, function, args=(), kwargs={}, wait=True): - """Run a function within this process + k: str + for k in client_keys: + c = client_comms.get(k) + if c is None: + continue + try: + c.send(msg) + # logger.debug("Scheduler sends message to client %s", msg) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msg) - See Also - -------- - Client.run_on_scheduler: + async def add_client(self, comm, client=None, versions=None): + """Add client to network + + We listen to all future messages from this Comm. 
""" - from .worker import run + parent: SchedulerState = cast(SchedulerState, self) + assert client is not None + comm.name = "Scheduler->Client" + logger.info("Receive client connection: %s", client) + self.log_event(["all", client], {"action": "add-client", "client": client}) + parent._clients[client] = ClientState(client, versions=versions) - self.log_event("all", {"action": "run-function", "function": function}) - return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) + for plugin in self.plugins[:]: + try: + plugin.add_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) - def set_metadata(self, comm=None, keys=None, value=None): try: - metadata = self.task_metadata - for key in keys[:-1]: - if key not in metadata or not isinstance(metadata[key], (dict, list)): - metadata[key] = dict() - metadata = metadata[key] - metadata[keys[-1]] = value - except Exception as e: - import pdb + bcomm = BatchedSend(interval="2ms", loop=self.loop) + bcomm.start(comm) + self.client_comms[client] = bcomm + msg = {"op": "stream-start"} + ws: WorkerState + version_warning = version_module.error_message( + version_module.get_versions(), + {w: ws._versions for w, ws in parent._workers_dv.items()}, + versions, + ) + msg.update(version_warning) + bcomm.send(msg) - pdb.set_trace() + try: + await self.handle_stream(comm=comm, extra={"client": client}) + finally: + self.remove_client(client=client) + logger.debug("Finished handling client %s", client) + finally: + if not comm.closed(): + self.client_comms[client].send({"op": "stream-closed"}) + try: + if not shutting_down(): + await self.client_comms[client].close() + del self.client_comms[client] + if self.status == Status.running: + logger.info("Close client connection: %s", client) + except TypeError: # comm becomes None during GC + pass - def get_metadata(self, comm=None, keys=None, default=no_default): - metadata = self.task_metadata - for key in keys[:-1]: - metadata = metadata[key] + def remove_client(self, client=None): + """ Remove client from network """ + parent: SchedulerState = cast(SchedulerState, self) + if self.status == Status.running: + logger.info("Remove client %s", client) + self.log_event(["all", client], {"action": "remove-client", "client": client}) try: - return metadata[keys[-1]] + cs: ClientState = parent._clients[client] except KeyError: - if default != no_default: - return default - else: - raise + # XXX is this a legitimate condition? 
+ pass + else: + ts: TaskState + self.client_releases_keys( + keys=[ts._key for ts in cs._wants_what], client=cs._client_key + ) + del parent._clients[client] + + for plugin in self.plugins[:]: + try: + plugin.remove_client(scheduler=self, client=client) + except Exception as e: + logger.exception(e) + + def remove_client_from_events(): + # If the client isn't registered anymore after the delay, remove from events + if client not in parent._clients and client in self.events: + del self.events[client] + + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) + self.loop.call_later(cleanup_delay, remove_client_from_events) + + def send_task_to_worker(self, worker, ts: TaskState, duration=None): + """ Send a single computational task to a worker """ + parent: SchedulerState = cast(SchedulerState, self) + try: + msg: dict = _task_to_msg(parent, ts, duration) + self.worker_send(worker, msg) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb - def get_task_status(self, comm=None, keys=None): - return { - key: (self.tasks[key].state if key in self.tasks else None) for key in keys - } + pdb.set_trace() + raise - def get_task_stream(self, comm=None, start=None, stop=None, count=None): - from distributed.diagnostics.task_stream import TaskStreamPlugin + def handle_uncaught_error(self, **msg): + logger.exception(clean_exception(**msg)[1]) - self.add_plugin(TaskStreamPlugin, idempotent=True) - tsp = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] - return tsp.collect(start=start, stop=stop, count=count) + def handle_task_finished(self, key=None, worker=None, **msg): + parent: SchedulerState = cast(SchedulerState, self) + if worker not in parent._workers_dv: + return + validate_key(key) + r = self.stimulus_task_finished(key=key, worker=worker, **msg) + self.transitions(r) - def start_task_metadata(self, comm=None, name=None): - plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) + def handle_task_erred(self, key=None, **msg): + r = self.stimulus_task_erred(key=key, **msg) + self.transitions(r) - self.add_plugin(plugin) + def handle_release_data(self, key=None, worker=None, client=None, **msg): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks.get(key) + if ts is None: + return + ws: WorkerState = parent._workers_dv[worker] + if ts._processing_on != ws: + return + r = self.stimulus_missing_data(key=key, ensure=False, **msg) + self.transitions(r) - def stop_task_metadata(self, comm=None, name=None): - plugins = [ - p - for p in self.plugins - if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name - ] - if len(plugins) != 1: - raise ValueError( - "Expected to find exactly one CollectTaskMetaDataPlugin " - f"with name {name} but found {len(plugins)}." 
- ) + def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + parent: SchedulerState = cast(SchedulerState, self) + logger.debug("handle missing data key=%s worker=%s", key, errant_worker) + self.log.append(("missing", key, errant_worker)) - plugin = plugins[0] - self.remove_plugin(plugin) - return {"metadata": plugin.metadata, "state": plugin.state} + ts: TaskState = parent._tasks.get(key) + if ts is None or not ts._who_has: + return + if errant_worker in parent._workers_dv: + ws: WorkerState = parent._workers_dv[errant_worker] + if ws in ts._who_has: + ts._who_has.remove(ws) + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + if not ts._who_has: + if ts._run_spec: + self.transitions({key: "released"}) + else: + self.transitions({key: "forgotten"}) - async def register_worker_plugin(self, comm, plugin, name=None): - """ Registers a setup function, and call it on every worker """ - self.worker_plugins.append({"plugin": plugin, "name": name}) + def release_worker_data(self, comm=None, keys=None, worker=None): + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers_dv[worker] + tasks = {parent._tasks[k] for k in keys} + removed_tasks = tasks & ws._has_what + ws._has_what -= removed_tasks - responses = await self.broadcast( - msg=dict(op="plugin-add", plugin=plugin, name=name) - ) - return responses + ts: TaskState + recommendations: dict = {} + for ts in removed_tasks: + ws._nbytes -= ts.get_nbytes() + wh = ts._who_has + wh.remove(ws) + if not wh: + recommendations[ts._key] = "released" + if recommendations: + self.transitions(recommendations) - ##################### - # State Transitions # - ##################### + def handle_long_running(self, key=None, worker=None, compute_duration=None): + """A task has seceded from the thread pool - def _remove_from_processing(self, ts: TaskState, send_worker_msg=None): - """ - Remove *ts* from the set of processing tasks. + We stop the task from being stolen in the future, and change task + duration accounting as if the task has stopped. """ - workers: dict = cast(dict, self.workers) + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState = parent._tasks[key] + if "stealing" in self._extensions: + self._extensions["stealing"].remove_key_from_stealable(ts) + ws: WorkerState = ts._processing_on - ts._processing_on = None - w: str = ws._address - if w in workers: # may have been removed - duration = ws._processing.pop(ts) - if not ws._processing: - self.total_occupancy -= ws._occupancy - ws._occupancy = 0 - else: - self.total_occupancy -= duration - ws._occupancy -= duration - self.check_idle_saturated(ws) - self.release_resources(ts, ws) - if send_worker_msg: - self.worker_send(w, send_worker_msg) + if ws is None: + logger.debug("Received long-running signal from duplicate task. Ignoring.") + return - def _add_to_memory( - self, - ts: TaskState, - ws: WorkerState, - recommendations: dict, - type=None, - typename=None, - **kwargs, - ): - """ - Add *ts* to the set of in-memory tasks. 
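# Illustrative aside, not part of the patch: handle_missing_data above drops the
# errant worker from a key's who_has set; if no replica is left, the task is
# recomputed ("released") when it has a run_spec, otherwise it was pure scattered
# data and must be "forgotten".  A hypothetical distillation:
def after_replica_lost(who_has: set, errant_worker: str, has_run_spec: bool):
    who_has.discard(errant_worker)
    if who_has:
        return None  # other replicas remain; no transition needed
    return "released" if has_run_spec else "forgotten"

assert after_replica_lost({"w1", "w2"}, "w1", has_run_spec=True) is None
assert after_replica_lost({"w1"}, "w1", has_run_spec=True) == "released"
assert after_replica_lost({"w1"}, "w1", has_run_spec=False) == "forgotten"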
- """ - if self.validate: - assert ts not in ws._has_what + if compute_duration: + old_duration = ts._prefix._duration_average + new_duration = compute_duration + if old_duration < 0: + avg_duration = new_duration + else: + avg_duration = 0.5 * old_duration + 0.5 * new_duration - ts._who_has.add(ws) - ws._has_what.add(ts) - ws._nbytes += ts.get_nbytes() + ts._prefix._duration_average = avg_duration - deps: list = list(ts._dependents) - if len(deps) > 1: - deps.sort(key=operator.attrgetter("priority"), reverse=True) + ws._occupancy -= ws._processing[ts] + parent._total_occupancy -= ws._processing[ts] + ws._processing[ts] = 0 + self.check_idle_saturated(ws) - dts: TaskState - s: set - for dts in deps: - s = dts._waiting_on - if ts in s: - s.discard(ts) - if not s: # new task ready to run - recommendations[dts._key] = "processing" + async def handle_worker(self, comm=None, worker=None): + """ + Listen to responses from a single worker - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" + This is the main loop for scheduler-worker interaction - if not ts._waiters and not ts._who_wants: - recommendations[ts._key] = "released" - else: - msg: dict = {"op": "key-in-memory", "key": ts._key} - if type is not None: - msg["type"] = type - self.report(msg) + See Also + -------- + Scheduler.handle_client: Equivalent coroutine for clients + """ + comm.name = "Scheduler connection to worker" + worker_comm = self.stream_comms[worker] + worker_comm.start(comm) + logger.info("Starting worker compute stream, %s", worker) + try: + await self.handle_stream(comm=comm, extra={"worker": worker}) + finally: + if worker in self.stream_comms: + worker_comm.abort() + await self.remove_worker(address=worker) - ts.state = "memory" - ts._type = typename - ts._group._types.add(typename) + def add_plugin(self, plugin=None, idempotent=False, **kwargs): + """ + Add external plugin to scheduler - cs: ClientState = self.clients["fire-and-forget"] - if ts in cs._wants_what: - self.client_releases_keys(client="fire-and-forget", keys=[ts._key]) + See https://distributed.readthedocs.io/en/latest/plugins.html + """ + if isinstance(plugin, type): + plugin = plugin(self, **kwargs) - def transition_released_waiting(self, key): - try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) - ts: TaskState = tasks[key] - dts: TaskState + if idempotent and any(isinstance(p, type(plugin)) for p in self.plugins): + return - if self.validate: - assert ts._run_spec - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([dts._state == "forgotten" for dts in ts._dependencies]) + self.plugins.append(plugin) - if ts._has_lost_dependencies: - return {key: "forgotten"} + def remove_plugin(self, plugin): + """ Remove external plugin from scheduler """ + self.plugins.remove(plugin) - ts.state = "waiting" + def worker_send(self, worker, msg): + """Send message to worker - recommendations: dict = {} + This also handles connection failures by adding a callback to remove + the worker on the next cycle. 
+ """ + stream_comms: dict = self.stream_comms + try: + stream_comms[worker].send(msg) + except (CommClosedError, AttributeError): + self.loop.add_callback(self.remove_worker, address=worker) - dts: TaskState - for dts in ts._dependencies: - if dts._exception_blame: - ts._exception_blame = dts._exception_blame - recommendations[key] = "erred" - return recommendations + def client_send(self, client, msg): + """Send message to client""" + client_comms: dict = self.client_comms + c = client_comms.get(client) + if c is None: + return + try: + c.send(msg) + except CommClosedError: + if self.status == Status.running: + logger.critical("Tried writing to closed comm: %s", msg) - for dts in ts._dependencies: - dep = dts._key - if not dts._who_has: - ts._waiting_on.add(dts) - if dts._state == "released": - recommendations[dep] = "waiting" - else: - dts._waiters.add(ts) + ############################ + # Less common interactions # + ############################ - ts._waiters = {dts for dts in ts._dependents if dts._state == "waiting"} + async def scatter( + self, + comm=None, + data=None, + workers=None, + client=None, + broadcast=False, + timeout=2, + ): + """Send data out to workers - if not ts._waiting_on: - if workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + See also + -------- + Scheduler.broadcast: + """ + parent: SchedulerState = cast(SchedulerState, self) + start = time() + while not parent._workers_dv: + await asyncio.sleep(0.2) + if time() > start + timeout: + raise TimeoutError("No workers found") - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + if workers is None: + ws: WorkerState + nthreads = {w: ws._nthreads for w, ws in parent._workers_dv.items()} + else: + workers = [self.coerce_address(w) for w in workers] + nthreads = {w: parent._workers_dv[w].nthreads for w in workers} - pdb.set_trace() - raise + assert isinstance(data, dict) - def transition_no_worker_waiting(self, key): - try: - tasks: dict = self.tasks - workers: dict = cast(dict, self.workers) - ts: TaskState = tasks[key] - dts: TaskState + keys, who_has, nbytes = await scatter_to_workers( + nthreads, data, rpc=self.rpc, report=False + ) - if self.validate: - assert ts in self.unrunnable - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on + self.update_data(who_has=who_has, nbytes=nbytes, client=client) - self.unrunnable.remove(ts) + if broadcast: + if broadcast == True: # noqa: E712 + n = len(nthreads) + else: + n = broadcast + await self.replicate(keys=keys, workers=workers, n=n) - if ts._has_lost_dependencies: - return {key: "forgotten"} + self.log_event( + [client, "all"], {"action": "scatter", "client": client, "count": len(data)} + ) + return keys - recommendations: dict = {} + async def gather(self, comm=None, keys=None, serializers=None): + """ Collect data in from workers """ + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + keys = list(keys) + who_has = {} + for key in keys: + ts: TaskState = parent._tasks.get(key) + if ts is not None: + who_has[key] = [ws._address for ws in ts._who_has] + else: + who_has[key] = [] - for dts in ts._dependencies: - dep = dts._key - if not dts._who_has: - ts._waiting_on.add(dts) - if dts._state == "released": - recommendations[dep] = "waiting" - else: - dts._waiters.add(ts) + data, missing_keys, missing_workers = await gather_from_workers( + who_has, rpc=self.rpc, close=False, serializers=serializers + ) + if not 
missing_keys: + result = {"status": "OK", "data": data} + else: + missing_states = [ + (parent._tasks[key].state if key in parent._tasks else None) + for key in missing_keys + ] + logger.exception( + "Couldn't gather keys %s state: %s workers: %s", + missing_keys, + missing_states, + missing_workers, + ) + result = {"status": "error", "keys": missing_keys} + with log_errors(): + # Remove suspicious workers from the scheduler but allow them to + # reconnect. + await asyncio.gather( + *[ + self.remove_worker(address=worker, close=False) + for worker in missing_workers + ] + ) + for key, workers in missing_keys.items(): + # Task may already be gone if it was held by a + # `missing_worker` + ts: TaskState = parent._tasks.get(key) + logger.exception( + "Workers don't have promised key: %s, %s", + str(workers), + str(key), + ) + if not workers or ts is None: + continue + for worker in workers: + ws = parent._workers_dv.get(worker) + if ws is not None and ts in ws._has_what: + ws._has_what.remove(ts) + ts._who_has.remove(ws) + ws._nbytes -= ts.get_nbytes() + self.transitions({key: "released"}) - ts.state = "waiting" + self.log_event("all", {"action": "gather", "count": len(keys)}) + return result - if not ts._waiting_on: - if workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + def clear_task_state(self): + # XXX what about nested state such as ClientState.wants_what + # (see also fire-and-forget...) + logger.info("Clear task state") + for collection in self._task_state_collections: + collection.clear() - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + async def restart(self, client=None, timeout=3): + """ Restart all workers. Reset local state. """ + parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): - pdb.set_trace() - raise + n_workers = len(parent._workers_dv) - def decide_worker(self, ts: TaskState) -> WorkerState: - """ - Decide on a worker for task *ts*. Return a WorkerState. - """ - workers: dict = cast(dict, self.workers) - ws: WorkerState = None - valid_workers: set = self.valid_workers(ts) + logger.info("Send lost future signal to clients") + cs: ClientState + ts: TaskState + for cs in parent._clients.values(): + self.client_releases_keys( + keys=[ts._key for ts in cs._wants_what], client=cs._client_key + ) - if ( - valid_workers is not None - and not valid_workers - and not ts._loose_restrictions - and workers - ): - self.unrunnable.add(ts) - ts.state = "no-worker" - return ws + ws: WorkerState + nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()} - if ts._dependencies or valid_workers is not None: - ws = decide_worker( - ts, - workers.values(), - valid_workers, - partial(self.worker_objective, ts), - ) - else: - worker_pool = self.idle or self.workers - worker_pool_dv = cast(dict, worker_pool) - n_workers: Py_ssize_t = len(worker_pool_dv) - if n_workers < 20: # smart but linear in small case - ws = min(worker_pool.values(), key=operator.attrgetter("occupancy")) - else: # dumb but fast in large case - n_tasks: Py_ssize_t = self.n_tasks - ws = worker_pool.values()[n_tasks % n_workers] + for addr in list(parent._workers_dv): + try: + # Ask the worker to close if it doesn't have a nanny, + # otherwise the nanny will kill it anyway + await self.remove_worker(address=addr, close=addr not in nannies) + except Exception as e: + logger.info( + "Exception while restarting. 
This is normal", exc_info=True + ) - if self.validate: - assert ws is None or isinstance(ws, WorkerState), ( - type(ws), - ws, - ) - assert ws._address in workers + self.clear_task_state() - return ws + for plugin in self.plugins[:]: + try: + plugin.restart(self) + except Exception as e: + logger.exception(e) - def set_duration_estimate(self, ts: TaskState, ws: WorkerState): - """Estimate task duration using worker state and task state. + logger.debug("Send kill signal to nannies: %s", nannies) - If a task takes longer than twice the current average duration we - estimate the task duration to be 2x current-runtime, otherwise we set it - to be the average duration. - """ - exec_time: double = ws._executing.get(ts, 0) - duration: double = self.get_task_duration(ts) - total_duration: double - if exec_time > 2 * duration: - total_duration = 2 * exec_time - else: - comm: double = self.get_comm_cost(ts, ws) - total_duration = duration + comm - ws._processing[ts] = total_duration - return total_duration + nannies = [ + rpc(nanny_address, connection_args=self.connection_args) + for nanny_address in nannies.values() + if nanny_address is not None + ] - def transition_waiting_processing(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState + resps = All( + [ + nanny.restart( + close=True, timeout=timeout * 0.8, executor_wait=False + ) + for nanny in nannies + ] + ) + try: + resps = await asyncio.wait_for(resps, timeout) + except TimeoutError: + logger.error( + "Nannies didn't report back restarted within " + "timeout. Continuuing with restart process" + ) + else: + if not all(resp == "OK" for resp in resps): + logger.error( + "Not all workers responded positively: %s", resps, exc_info=True + ) + finally: + await asyncio.gather(*[nanny.close_rpc() for nanny in nannies]) - if self.validate: - assert not ts._waiting_on - assert not ts._who_has - assert not ts._exception_blame - assert not ts._processing_on - assert not ts._has_lost_dependencies - assert ts not in self.unrunnable - assert all([dts._who_has for dts in ts._dependencies]) + self.clear_task_state() - ws: WorkerState = self.decide_worker(ts) - if ws is None: - return {} - worker = ws._address + with suppress(AttributeError): + for c in self._worker_coroutines: + c.cancel() - duration_estimate = self.set_duration_estimate(ts, ws) - ts._processing_on = ws - ws._occupancy += duration_estimate - self.total_occupancy += duration_estimate - ts.state = "processing" - self.consume_resources(ts, ws) - self.check_idle_saturated(ws) - self.n_tasks += 1 + self.log_event([client, "all"], {"action": "restart", "client": client}) + start = time() + while time() < start + 10 and len(parent._workers_dv) < n_workers: + await asyncio.sleep(0.01) - if ts._actor: - ws._actors.add(ts) + self.report({"op": "restart"}) - # logger.debug("Send job to worker: %s, %s", worker, key) + async def broadcast( + self, + comm=None, + msg=None, + workers=None, + hosts=None, + nanny=False, + serializers=None, + ): + """ Broadcast message to workers, return all results """ + parent: SchedulerState = cast(SchedulerState, self) + if workers is None or workers is True: + if hosts is None: + workers = list(parent._workers_dv) + else: + workers = [] + if hosts is not None: + for host in hosts: + if host in parent._host_info: + workers.extend(parent._host_info[host]["addresses"]) + # TODO replace with worker_list - self.send_task_to_worker(worker, ts) + if nanny: + addresses = [parent._workers_dv[w].nanny for w in workers] + else: + addresses = 
workers - return {} - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + async def send_message(addr): + comm = await self.rpc.connect(addr) + comm.name = "Scheduler Broadcast" + try: + resp = await send_recv(comm, close=True, serializers=serializers, **msg) + finally: + self.rpc.reuse(addr, comm) + return resp - pdb.set_trace() - raise + results = await All( + [send_message(address) for address in addresses if address is not None] + ) - def transition_waiting_memory(self, key, nbytes=None, worker=None, **kwargs): - try: - workers: dict = cast(dict, self.workers) - ws: WorkerState = workers[worker] - tasks: dict = self.tasks - ts: TaskState = tasks[key] + return dict(zip(workers, results)) - if self.validate: - assert not ts._processing_on - assert ts._waiting_on - assert ts._state == "waiting" + async def proxy(self, comm=None, msg=None, worker=None, serializers=None): + """ Proxy a communication through the scheduler to some other worker """ + d = await self.broadcast( + comm=comm, msg=msg, workers=[worker], serializers=serializers + ) + return d[worker] - ts._waiting_on.clear() + async def _delete_worker_data(self, worker_address, keys): + """Delete data from a worker and update the corresponding worker/task states - if nbytes is not None: - ts.set_nbytes(nbytes) + Parameters + ---------- + worker_address: str + Worker address to delete keys from + keys: List[str] + List of keys to delete on the specified worker + """ + parent: SchedulerState = cast(SchedulerState, self) + await retry_operation( + self.rpc(addr=worker_address).delete_data, keys=list(keys), report=False + ) - self.check_idle_saturated(ws) + ws: WorkerState = parent._workers_dv[worker_address] + ts: TaskState + tasks: set = {parent._tasks[key] for key in keys} + ws._has_what -= tasks + for ts in tasks: + ts._who_has.remove(ws) + ws._nbytes -= ts.get_nbytes() + self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) - recommendations: dict = {} + async def rebalance(self, comm=None, keys=None, workers=None): + """Rebalance keys so that each worker stores roughly equal bytes - self._add_to_memory(ts, ws, recommendations, **kwargs) + **Policy** - if self.validate: - assert not ts._processing_on - assert not ts._waiting_on - assert ts._who_has + This orders the workers by what fraction of bytes of the existing keys + they have. It walks down this list from most-to-least. At each worker + it sends the largest results it can find and sends them to the least + occupied worker until either the sender or the recipient are at the + average expected load. 
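+
+        As an illustrative sketch only (it assumes a connected ``Client``
+        named ``c`` holding previously scattered or computed futures
+        ``futures``), a rebalance can be requested from the client API::
+
+            >>> c.rebalance(futures)  # doctest: +SKIP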
+ """ + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState + with log_errors(): + async with self._lock: + if keys: + tasks = {parent._tasks[k] for k in keys} + missing_data = [ts._key for ts in tasks if not ts._who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} + else: + tasks = set(parent._tasks.values()) - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + if workers: + workers = {parent._workers_dv[w] for w in workers} + workers_by_task = {ts: ts._who_has & workers for ts in tasks} + else: + workers = set(parent._workers_dv.values()) + workers_by_task = {ts: ts._who_has for ts in tasks} - pdb.set_trace() - raise + ws: WorkerState + tasks_by_worker = {ws: set() for ws in workers} - def transition_processing_memory( - self, - key, - nbytes=None, - type=None, - typename=None, - worker=None, - startstops=None, - **kwargs, - ): - ws: WorkerState - wws: WorkerState - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - assert worker - assert isinstance(worker, str) + for k, v in workers_by_task.items(): + for vv in v: + tasks_by_worker[vv].add(k) - if self.validate: - assert ts._processing_on - ws = ts._processing_on - assert ts in ws._processing - assert not ts._waiting_on - assert not ts._who_has, (ts, ts._who_has) - assert not ts._exception_blame - assert ts._state == "processing" + worker_bytes = { + ws: sum(ts.get_nbytes() for ts in v) + for ws, v in tasks_by_worker.items() + } - workers: dict = cast(dict, self.workers) - ws = workers.get(worker) - if ws is None: - return {key: "released"} + avg = sum(worker_bytes.values()) / len(worker_bytes) - if ws != ts._processing_on: # someone else has this task - logger.info( - "Unexpected worker completed task, likely due to" - " work stealing. 
Expected: %s, Got: %s, Key: %s", - ts._processing_on, - ws, - key, + sorted_workers = list( + map(first, sorted(worker_bytes.items(), key=second, reverse=True)) ) - return {} - - if startstops: - L = list() - for startstop in startstops: - stop = startstop["stop"] - start = startstop["start"] - action = startstop["action"] - if action == "compute": - L.append((start, stop)) - - # record timings of all actions -- a cheaper way of - # getting timing info compared with get_task_stream() - ts._prefix._all_durations[action] += stop - start - if len(L) > 0: - compute_start, compute_stop = L[0] - else: # This is very rare - compute_start = compute_stop = None - else: - compute_start = compute_stop = None + recipients = iter(reversed(sorted_workers)) + recipient = next(recipients) + msgs = [] # (sender, recipient, key) + for sender in sorted_workers[: len(workers) // 2]: + sender_keys = { + ts: ts.get_nbytes() for ts in tasks_by_worker[sender] + } + sender_keys = iter( + sorted(sender_keys.items(), key=second, reverse=True) + ) - ############################# - # Update Timing Information # - ############################# - if compute_start and ws._processing.get(ts, True): - # Update average task duration for worker - old_duration = ts._prefix._duration_average - new_duration = compute_stop - compute_start - if old_duration < 0: - avg_duration = new_duration - else: - avg_duration = 0.5 * old_duration + 0.5 * new_duration + try: + while worker_bytes[sender] > avg: + while ( + worker_bytes[recipient] < avg + and worker_bytes[sender] > avg + ): + ts, nb = next(sender_keys) + if ts not in tasks_by_worker[recipient]: + tasks_by_worker[recipient].add(ts) + # tasks_by_worker[sender].remove(ts) + msgs.append((sender, recipient, ts)) + worker_bytes[sender] -= nb + worker_bytes[recipient] += nb + if worker_bytes[sender] > avg: + recipient = next(recipients) + except StopIteration: + break - ts._prefix._duration_average = avg_duration - ts._group._duration += new_duration + to_recipients = defaultdict(lambda: defaultdict(list)) + to_senders = defaultdict(list) + for sender, recipient, ts in msgs: + to_recipients[recipient.address][ts._key].append(sender.address) + to_senders[sender.address].append(ts._key) - tts: TaskState - for tts in self.unknown_durations.pop(ts._prefix._name, ()): - if tts._processing_on: - wws = tts._processing_on - old = wws._processing[tts] - comm = self.get_comm_cost(tts, wws) - wws._processing[tts] = avg_duration + comm - wws._occupancy += avg_duration + comm - old - self.total_occupancy += avg_duration + comm - old + result = await asyncio.gather( + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) + ) + for r, v in to_recipients.items(): + self.log_event(r, {"action": "rebalance", "who_has": v}) - ############################ - # Update State Information # - ############################ - if nbytes is not None: - ts.set_nbytes(nbytes) + self.log_event( + "all", + { + "action": "rebalance", + "total-keys": len(tasks), + "senders": valmap(len, to_senders), + "recipients": valmap(len, to_recipients), + "moved_keys": len(msgs), + }, + ) - recommendations: dict = {} + if not all(r["status"] == "OK" for r in result): + return { + "status": "missing-data", + "keys": tuple( + concat( + r["keys"].keys() + for r in result + if r["status"] == "missing-data" + ) + ), + } - self._remove_from_processing(ts) + for sender, recipient, ts in msgs: + assert ts._state == "memory" + ts._who_has.add(recipient) + recipient.has_what.add(ts) + 
recipient.nbytes += ts.get_nbytes() + self.log.append( + ( + "rebalance", + ts._key, + time(), + sender.address, + recipient.address, + ) + ) - self._add_to_memory(ts, ws, recommendations, type=type, typename=typename) + await asyncio.gather( + *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + ) - if self.validate: - assert not ts._processing_on - assert not ts._waiting_on + return {"status": "OK"} - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + async def replicate( + self, + comm=None, + keys=None, + n=None, + workers=None, + branching_factor=2, + delete=True, + lock=True, + ): + """Replicate data throughout cluster - pdb.set_trace() - raise + This performs a tree copy of the data throughout the network + individually on each piece of data. - def transition_memory_released(self, key, safe=False): - ws: WorkerState - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState + Parameters + ---------- + keys: Iterable + list of keys to replicate + n: int + Number of replications we expect to see within the cluster + branching_factor: int, optional + The number of workers that can copy data in each generation. + The larger the branching factor, the more data we copy in + a single step, but the more a given worker risks being + swamped by data requests. - if self.validate: - assert not ts._waiting_on - assert not ts._processing_on - if safe: - assert not ts._waiters + See also + -------- + Scheduler.rebalance + """ + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + wws: WorkerState + ts: TaskState - if ts._actor: - for ws in ts._who_has: - ws._actors.discard(ts) - if ts._who_wants: - ts._exception_blame = ts - ts._exception = "Worker holding Actor was lost" - return {ts._key: "erred"} # don't try to recreate + assert branching_factor > 0 + async with self._lock if lock else empty_context: + workers = {parent._workers_dv[w] for w in self.workers_list(workers)} + if n is None: + n = len(workers) + else: + n = min(n, len(workers)) + if n == 0: + raise ValueError("Can not use replicate to delete data") - recommendations: dict = {} + tasks = {parent._tasks[k] for k in keys} + missing_data = [ts._key for ts in tasks if not ts._who_has] + if missing_data: + return {"status": "missing-data", "keys": missing_data} - for dts in ts._waiters: - if dts._state in ("no-worker", "processing"): - recommendations[dts._key] = "waiting" - elif dts._state == "waiting": - dts._waiting_on.add(ts) + # Delete extraneous data + if delete: + del_worker_tasks = defaultdict(set) + for ts in tasks: + del_candidates = ts._who_has & workers + if len(del_candidates) > n: + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): + del_worker_tasks[ws].add(ts) - # XXX factor this out? 
-            for ws in ts._who_has:
-                ws._has_what.remove(ts)
-                ws._nbytes -= ts.get_nbytes()
-                ts._group._nbytes_in_memory -= ts.get_nbytes()
-                self.worker_send(
-                    ws._address, {"op": "delete-data", "keys": [key], "report": False}
+                await asyncio.gather(
+                    *[
+                        self._delete_worker_data(ws._address, [t.key for t in tasks])
+                        for ws, tasks in del_worker_tasks.items()
+                    ]
                )
-            ts._who_has.clear()
-
-            ts.state = "released"
-            self.report({"op": "lost-data", "key": key})
+            # Copy not-yet-filled data
+            while tasks:
+                gathers = defaultdict(dict)
+                for ts in list(tasks):
+                    if ts._state == "forgotten":
+                        # task is no longer needed by any client or dependent task
+                        tasks.remove(ts)
+                        continue
+                    n_missing = n - len(ts._who_has & workers)
+                    if n_missing <= 0:
+                        # Already replicated enough
+                        tasks.remove(ts)
+                        continue

-            if not ts._run_spec:  # pure data
-                recommendations[key] = "forgotten"
-            elif ts._has_lost_dependencies:
-                recommendations[key] = "forgotten"
-            elif ts._who_wants or ts._waiters:
-                recommendations[key] = "waiting"
+                    count = min(n_missing, branching_factor * len(ts._who_has))
+                    assert count > 0

-            if self.validate:
-                assert not ts._waiting_on
+                    for ws in random.sample(workers - ts._who_has, count):
+                        gathers[ws._address][ts._key] = [
+                            wws._address for wws in ts._who_has
+                        ]

-            return recommendations
-        except Exception as e:
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
+                results = await asyncio.gather(
+                    *(
+                        retry_operation(self.rpc(addr=w).gather, who_has=who_has)
+                        for w, who_has in gathers.items()
+                    )
+                )
+                for w, v in zip(gathers, results):
+                    if v["status"] == "OK":
+                        self.add_keys(worker=w, keys=list(gathers[w]))
+                    else:
+                        logger.warning("Communication failed during replication: %s", v)

-            pdb.set_trace()
-            raise
+                    self.log_event(w, {"action": "replicate-add", "keys": gathers[w]})

-    def transition_released_erred(self, key):
-        try:
-            tasks: dict = self.tasks
-            ts: TaskState = tasks[key]
-            dts: TaskState
-            failing_ts: TaskState
+            self.log_event(
+                "all",
+                {
+                    "action": "replicate",
+                    "workers": list(workers),
+                    "key-count": len(keys),
+                    "branching-factor": branching_factor,
+                },
+            )

-            if self.validate:
-                with log_errors(pdb=LOG_PDB):
-                    assert ts._exception_blame
-                    assert not ts._who_has
-                    assert not ts._waiting_on
-                    assert not ts._waiters
+    def workers_to_close(
+        self,
+        comm=None,
+        memory_ratio=None,
+        n=None,
+        key=None,
+        minimum=None,
+        target=None,
+        attribute="address",
+    ):
+        """
+        Find workers that we can close with low cost

-            recommendations: dict = {}
+        This returns a list of workers that are good candidates to retire.
+        These workers are not running anything and are storing
+        relatively little data compared to their peers.  If all workers are
+        idle then we still maintain enough workers to have enough RAM to store
+        our data, with a comfortable buffer.

-            failing_ts = ts._exception_blame
+        This is for use with systems like ``distributed.deploy.adaptive``.

-            for dts in ts._dependents:
-                dts._exception_blame = failing_ts
-                if not dts._who_has:
-                    recommendations[dts._key] = "erred"
+        Parameters
+        ----------
+        memory_ratio: Number
+            Amount of extra space we want to have for our stored data.
+            Defaults to 2, meaning that we want to have twice as much memory
+            as we currently have data.
+        n: int
+            Number of workers to close
+        minimum: int
+            Minimum number of workers to keep around
+        key: Callable(WorkerState)
+            An optional callable mapping a WorkerState object to a group
+            affiliation. Groups will be closed together. 
This is useful when + closing workers must be done collectively, such as by hostname. + target: int + Target number of workers to have after we close + attribute : str + The attribute of the WorkerState object to return, like "address" + or "name". Defaults to "address". - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } - ) + Examples + -------- + >>> scheduler.workers_to_close() + ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234'] - ts.state = "erred" + Group workers by hostname prior to closing - # TODO: waiting data? - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + >>> scheduler.workers_to_close(key=lambda ws: ws.host) + ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567'] - pdb.set_trace() - raise + Remove two workers - def transition_erred_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState + >>> scheduler.workers_to_close(n=2) - if self.validate: - with log_errors(pdb=LOG_PDB): - assert all([dts._state != "erred" for dts in ts._dependencies]) - assert ts._exception_blame - assert not ts._who_has - assert not ts._waiting_on - assert not ts._waiters + Keep enough workers to have twice as much memory as we we need. - recommendations: dict = {} + >>> scheduler.workers_to_close(memory_ratio=2) - ts._exception = None - ts._exception_blame = None - ts._traceback = None + Returns + ------- + to_close: list of worker addresses that are OK to close - for dts in ts._dependents: - if dts._state == "erred": - recommendations[dts._key] = "waiting" + See Also + -------- + Scheduler.retire_workers + """ + parent: SchedulerState = cast(SchedulerState, self) + if target is not None and n is None: + n = len(parent._workers_dv) - target + if n is not None: + if n < 0: + n = 0 + target = len(parent._workers_dv) - n - self.report({"op": "task-retried", "key": key}) - ts.state = "released" + if n is None and memory_ratio is None: + memory_ratio = 2 - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + ws: WorkerState + with log_errors(): + if not n and all([ws._processing for ws in parent._workers_dv.values()]): + return [] - pdb.set_trace() - raise + if key is None: + key = operator.attrgetter("address") + if isinstance(key, bytes) and dask.config.get( + "distributed.scheduler.pickle" + ): + key = pickle.loads(key) - def transition_waiting_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] + groups = groupby(key, parent._workers.values()) - if self.validate: - assert not ts._who_has - assert not ts._processing_on + limit_bytes = { + k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() + } + group_bytes = {k: sum([ws._nbytes for ws in v]) for k, v in groups.items()} - recommendations: dict = {} + limit = sum(limit_bytes.values()) + total = sum(group_bytes.values()) - dts: TaskState - for dts in ts._dependencies: - s = dts._waiters - if ts in s: - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - ts._waiting_on.clear() + def _key(group): + wws: WorkerState + is_idle = not any([wws._processing for wws in groups[group]]) + bytes = -group_bytes[group] + return (is_idle, bytes) - ts.state = "released" + idle = sorted(groups, key=_key) - if ts._has_lost_dependencies: - recommendations[key] = "forgotten" - elif not ts._exception_blame and (ts._who_wants or ts._waiters): - recommendations[key] = 
"waiting" - else: - ts._waiters.clear() + to_close = [] + n_remain = len(parent._workers_dv) - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + while idle: + group = idle.pop() + if n is None and any([ws._processing for ws in groups[group]]): + break - pdb.set_trace() - raise + if minimum and n_remain - len(groups[group]) < minimum: + break - def transition_processing_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState + limit -= limit_bytes[group] - if self.validate: - assert ts._processing_on - assert not ts._who_has - assert not ts._waiting_on - assert self.tasks[key].state == "processing" + if (n is not None and n_remain - len(groups[group]) >= target) or ( + memory_ratio is not None and limit >= memory_ratio * total + ): + to_close.append(group) + n_remain -= len(groups[group]) - self._remove_from_processing( - ts, send_worker_msg={"op": "release-task", "key": key} - ) + else: + break - ts.state = "released" + result = [getattr(ws, attribute) for g in to_close for ws in groups[g]] + if result: + logger.debug("Suggest closing workers: %s", result) - recommendations: dict = {} + return result - if ts._has_lost_dependencies: - recommendations[key] = "forgotten" - elif ts._waiters or ts._who_wants: - recommendations[key] = "waiting" + async def retire_workers( + self, + comm=None, + workers=None, + remove=True, + close_workers=False, + names=None, + lock=True, + **kwargs, + ) -> dict: + """Gracefully retire workers from cluster - if recommendations.get(key) != "waiting": - for dts in ts._dependencies: - if dts._state != "released": - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" - ts._waiters.clear() + Parameters + ---------- + workers: list (optional) + List of worker addresses to retire. + If not provided we call ``workers_to_close`` which finds a good set + workers_names: list (optional) + List of worker names to retire. + remove: bool (defaults to True) + Whether or not to remove the worker metadata immediately or else + wait for the worker to contact us + close_workers: bool (defaults to False) + Whether or not to actually close the worker explicitly from here. + Otherwise we expect some external job scheduler to finish off the + worker. + **kwargs: dict + Extra options to pass to workers_to_close to determine which + workers we should drop - if self.validate: - assert not ts._processing_on + Returns + ------- + Dictionary mapping worker ID/address to dictionary of information about + that worker for each retired worker. 
- return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + See Also + -------- + Scheduler.workers_to_close + """ + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + ts: TaskState + with log_errors(): + async with self._lock if lock else empty_context: + if names is not None: + if names: + logger.info("Retire worker names %s", names) + names = set(map(str, names)) + workers = [ + ws._address + for ws in parent._workers_dv.values() + if str(ws._name) in names + ] + if workers is None: + while True: + try: + workers = self.workers_to_close(**kwargs) + if workers: + workers = await self.retire_workers( + workers=workers, + remove=remove, + close_workers=close_workers, + lock=False, + ) + return workers + else: + return {} + except KeyError: # keys left during replicate + pass + workers = { + parent._workers_dv[w] for w in workers if w in parent._workers_dv + } + if not workers: + return {} + logger.info("Retire workers %s", workers) - pdb.set_trace() - raise + # Keys orphaned by retiring those workers + keys = set.union(*[w.has_what for w in workers]) + keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} - def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, **kwargs - ): - ws: WorkerState - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState - failing_ts: TaskState + other_workers = set(parent._workers_dv.values()) - workers + if keys: + if other_workers: + logger.info("Moving %d keys to other workers", len(keys)) + await self.replicate( + keys=keys, + workers=[ws._address for ws in other_workers], + n=1, + delete=False, + lock=False, + ) + else: + return {} - if self.validate: - assert cause or ts._exception_blame - assert ts._processing_on - assert not ts._who_has - assert not ts._waiting_on + worker_keys = {ws._address: ws.identity() for ws in workers} + if close_workers and worker_keys: + await asyncio.gather( + *[self.close_worker(worker=w, safe=True) for w in worker_keys] + ) + if remove: + await asyncio.gather( + *[self.remove_worker(address=w, safe=True) for w in worker_keys] + ) - if ts._actor: - ws = ts._processing_on - ws._actors.remove(ts) + self.log_event( + "all", + { + "action": "retire-workers", + "workers": worker_keys, + "moved-keys": len(keys), + }, + ) + self.log_event(list(worker_keys), {"action": "retired"}) - self._remove_from_processing(ts) + return worker_keys - if exception is not None: - ts._exception = exception - if traceback is not None: - ts._traceback = traceback - if cause is not None: - failing_ts = self.tasks[cause] - ts._exception_blame = failing_ts + def add_keys(self, comm=None, worker=None, keys=()): + """ + Learn that a worker has certain keys + + This should not be used in practice and is mostly here for legacy + reasons. However, it is sent by workers from time to time. 
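+
+        Judging from the handler body below, keys that the scheduler does not
+        already track as being in memory are not recorded; the worker is
+        instead asked to delete them.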
+ """ + parent: SchedulerState = cast(SchedulerState, self) + if worker not in parent._workers_dv: + return "not found" + ws: WorkerState = parent._workers_dv[worker] + for key in keys: + ts: TaskState = parent._tasks.get(key) + if ts is not None and ts._state == "memory": + if ts not in ws._has_what: + ws._nbytes += ts.get_nbytes() + ws._has_what.add(ts) + ts._who_has.add(ws) else: - failing_ts = ts._exception_blame + self.worker_send( + worker, {"op": "delete-data", "keys": [key], "report": False} + ) - recommendations: dict = {} + return "OK" - for dts in ts._dependents: - dts._exception_blame = failing_ts - recommendations[dts._key] = "erred" + def update_data( + self, comm=None, who_has=None, nbytes=None, client=None, serializers=None + ): + """ + Learn that new data has entered the network from an external source - for dts in ts._dependencies: - s = dts._waiters - s.discard(ts) - if not s and not dts._who_wants: - recommendations[dts._key] = "released" + See Also + -------- + Scheduler.mark_key_in_memory + """ + parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): + who_has = { + k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() + } + logger.debug("Update data %s", who_has) + + for key, workers in who_has.items(): + ts: TaskState = parent._tasks.get(key) + if ts is None: + ts: TaskState = self.new_task(key, None, "memory") + ts.state = "memory" + if key in nbytes: + ts.set_nbytes(nbytes[key]) + for w in workers: + ws: WorkerState = parent._workers_dv[w] + if ts not in ws._has_what: + ws._nbytes += ts.get_nbytes() + ws._has_what.add(ts) + ts._who_has.add(ws) + self.report( + {"op": "key-in-memory", "key": key, "workers": list(workers)} + ) + + if client: + self.client_desires_keys(keys=list(who_has), client=client) + + def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): + parent: SchedulerState = cast(SchedulerState, self) + if ts is None: + tasks: dict = parent._tasks + ts = tasks.get(key) + elif key is None: + key = ts._key + else: + assert False, (key, ts) + return - ts._waiters.clear() # do anything with this? + report_msg: dict = _task_to_report_msg(parent, ts) + if report_msg is not None: + self.report(report_msg, ts=ts, client=client) - ts.state = "erred" + async def feed( + self, comm, function=None, setup=None, teardown=None, interval="1s", **kwargs + ): + """ + Provides a data Comm to external requester - self.report( - { - "op": "task-erred", - "key": key, - "exception": failing_ts._exception, - "traceback": failing_ts._traceback, - } + Caution: this runs arbitrary Python code on the scheduler. This should + eventually be phased out. It is mostly used by diagnostics. + """ + if not dask.config.get("distributed.scheduler.pickle"): + logger.warn( + "Tried to call 'feed' route with custom functions, but " + "pickle is disallowed. 
Set the 'distributed.scheduler.pickle'" + "config value to True to use the 'feed' route (this is mostly " + "commonly used with progress bars)" ) + return - cs: ClientState = self.clients["fire-and-forget"] - if ts in cs._wants_what: - self.client_releases_keys(client="fire-and-forget", keys=[key]) + interval = parse_timedelta(interval) + with log_errors(): + if function: + function = pickle.loads(function) + if setup: + setup = pickle.loads(setup) + if teardown: + teardown = pickle.loads(teardown) + state = setup(self) if setup else None + if inspect.isawaitable(state): + state = await state + try: + while self.status == Status.running: + if state is None: + response = function(self) + else: + response = function(self, state) + await comm.write(response) + await asyncio.sleep(interval) + except (EnvironmentError, CommClosedError): + pass + finally: + if teardown: + teardown(self, state) - if self.validate: - assert not ts._processing_on + def log_worker_event(self, worker=None, topic=None, msg=None): + self.log_event(topic, msg) - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def subscribe_worker_status(self, comm=None): + WorkerStatusPlugin(self, comm) + ident = self.identity() + for v in ident["workers"].values(): + del v["metrics"] + del v["last_seen"] + return ident - pdb.set_trace() - raise + def get_processing(self, comm=None, workers=None): + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + ts: TaskState + if workers is not None: + workers = set(map(self.coerce_address, workers)) + return { + w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers + } + else: + return { + w: [ts._key for ts in ws._processing] + for w, ws in parent._workers_dv.items() + } - def transition_no_worker_released(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] - dts: TaskState + def get_who_has(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + ts: TaskState + if keys is not None: + return { + k: [ws._address for ws in parent._tasks[k].who_has] + if k in parent._tasks + else [] + for k in keys + } + else: + return { + key: [ws._address for ws in ts._who_has] + for key, ts in parent._tasks.items() + } - if self.validate: - assert self.tasks[key].state == "no-worker" - assert not ts._who_has - assert not ts._waiting_on + def get_has_what(self, comm=None, workers=None): + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + ts: TaskState + if workers is not None: + workers = map(self.coerce_address, workers) + return { + w: [ts._key for ts in parent._workers_dv[w].has_what] + if w in parent._workers_dv + else [] + for w in workers + } + else: + return { + w: [ts._key for ts in ws._has_what] + for w, ws in parent._workers_dv.items() + } - self.unrunnable.remove(ts) - ts.state = "released" + def get_ncores(self, comm=None, workers=None): + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState + if workers is not None: + workers = map(self.coerce_address, workers) + return { + w: parent._workers_dv[w].nthreads + for w in workers + if w in parent._workers_dv + } + else: + return {w: ws._nthreads for w, ws in parent._workers_dv.items()} - for dts in ts._dependencies: - dts._waiters.discard(ts) + async def get_call_stack(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState + dts: TaskState + if keys is not None: + stack = list(keys) + processing = set() + while 
stack: + key = stack.pop() + ts = parent._tasks[key] + if ts._state == "waiting": + stack.extend([dts._key for dts in ts._dependencies]) + elif ts._state == "processing": + processing.add(ts) - ts._waiters.clear() + workers = defaultdict(list) + for ts in processing: + if ts._processing_on: + workers[ts._processing_on.address].append(ts._key) + else: + workers = {w: None for w in parent._workers_dv} + if not workers: return {} - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - - def remove_key(self, key): - tasks: dict = self.tasks - ts: TaskState = tasks.pop(key) - assert ts._state == "forgotten" - self.unrunnable.discard(ts) - cs: ClientState - for cs in ts._who_wants: - cs._wants_what.remove(ts) - ts._who_wants.clear() - ts._processing_on = None - ts._exception_blame = ts._exception = ts._traceback = None - self.task_metadata.pop(key, None) - def _propagate_forgotten(self, ts: TaskState, recommendations: dict): - workers: dict = cast(dict, self.workers) - ts.state = "forgotten" - key: str = ts._key - dts: TaskState - for dts in ts._dependents: - dts._has_lost_dependencies = True - dts._dependencies.remove(ts) - dts._waiting_on.discard(ts) - if dts._state not in ("memory", "erred"): - # Cannot compute task anymore - recommendations[dts._key] = "forgotten" - ts._dependents.clear() - ts._waiters.clear() + results = await asyncio.gather( + *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + ) + response = {w: r for w, r in zip(workers, results) if r} + return response - for dts in ts._dependencies: - dts._dependents.remove(ts) - s: set = dts._waiters - s.discard(ts) - if not dts._dependents and not dts._who_wants: - # Task not needed anymore - assert dts is not ts - recommendations[dts._key] = "forgotten" - ts._dependencies.clear() - ts._waiting_on.clear() + def get_nbytes(self, comm=None, keys=None, summary=True): + parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState + with log_errors(): + if keys is not None: + result = {k: parent._tasks[k].nbytes for k in keys} + else: + result = { + k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0 + } - if ts._who_has: - ts._group._nbytes_in_memory -= ts.get_nbytes() + if summary: + out = defaultdict(lambda: 0) + for k, v in result.items(): + out[key_split(k)] += v + result = dict(out) - ws: WorkerState - for ws in ts._who_has: - ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - w: str = ws._address - if w in workers: # in case worker has died - self.worker_send( - w, {"op": "delete-data", "keys": [key], "report": False} - ) - ts._who_has.clear() + return result - def transition_memory_forgotten(self, key): - tasks: dict - ws: WorkerState - try: - tasks = self.tasks - ts: TaskState = tasks[key] + def run_function(self, stream, function, args=(), kwargs={}, wait=True): + """Run a function within this process - if self.validate: - assert ts._state == "memory" - assert not ts._processing_on - assert not ts._waiting_on - if not ts._run_spec: - # It's ok to forget a pure data task - pass - elif ts._has_lost_dependencies: - # It's ok to forget a task with forgotten dependencies - pass - elif not ts._who_wants and not ts._waiters and not ts._dependents: - # It's ok to forget a task that nobody needs - pass - else: - assert 0, (ts,) + See Also + -------- + Client.run_on_scheduler: + """ + from .worker import run - recommendations: dict = {} + self.log_event("all", {"action": "run-function", "function": function}) + return run(self, stream, 
function=function, args=args, kwargs=kwargs, wait=wait) - if ts._actor: - for ws in ts._who_has: - ws._actors.discard(ts) + def set_metadata(self, comm=None, keys=None, value=None): + parent: SchedulerState = cast(SchedulerState, self) + try: + metadata = parent._task_metadata + for key in keys[:-1]: + if key not in metadata or not isinstance(metadata[key], (dict, list)): + metadata[key] = dict() + metadata = metadata[key] + metadata[keys[-1]] = value + except Exception as e: + import pdb - self._propagate_forgotten(ts, recommendations) + pdb.set_trace() - self.report_on_key(ts=ts) - self.remove_key(key) + def get_metadata(self, comm=None, keys=None, default=no_default): + parent: SchedulerState = cast(SchedulerState, self) + metadata = parent._task_metadata + for key in keys[:-1]: + metadata = metadata[key] + try: + return metadata[keys[-1]] + except KeyError: + if default != no_default: + return default + else: + raise - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def get_task_status(self, comm=None, keys=None): + parent: SchedulerState = cast(SchedulerState, self) + return { + key: (parent._tasks[key].state if key in parent._tasks else None) + for key in keys + } - pdb.set_trace() - raise + def get_task_stream(self, comm=None, start=None, stop=None, count=None): + from distributed.diagnostics.task_stream import TaskStreamPlugin - def transition_released_forgotten(self, key): - try: - tasks: dict = self.tasks - ts: TaskState = tasks[key] + self.add_plugin(TaskStreamPlugin, idempotent=True) + tsp = [p for p in self.plugins if isinstance(p, TaskStreamPlugin)][0] + return tsp.collect(start=start, stop=stop, count=count) - if self.validate: - assert ts._state in ("released", "erred") - assert not ts._who_has - assert not ts._processing_on - assert not ts._waiting_on, (ts, ts._waiting_on) - if not ts._run_spec: - # It's ok to forget a pure data task - pass - elif ts._has_lost_dependencies: - # It's ok to forget a task with forgotten dependencies - pass - elif not ts._who_wants and not ts._waiters and not ts._dependents: - # It's ok to forget a task that nobody needs - pass - else: - assert 0, (ts,) + def start_task_metadata(self, comm=None, name=None): + plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name) - recommendations: dict = {} - self._propagate_forgotten(ts, recommendations) + self.add_plugin(plugin) - self.report_on_key(ts=ts) - self.remove_key(key) + def stop_task_metadata(self, comm=None, name=None): + plugins = [ + p + for p in self.plugins + if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + ] + if len(plugins) != 1: + raise ValueError( + "Expected to find exactly one CollectTaskMetaDataPlugin " + f"with name {name} but found {len(plugins)}." 
+ ) - return recommendations - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + plugin = plugins[0] + self.remove_plugin(plugin) + return {"metadata": plugin.metadata, "state": plugin.state} - pdb.set_trace() - raise + async def register_worker_plugin(self, comm, plugin, name=None): + """ Registers a setup function, and call it on every worker """ + self.worker_plugins.append({"plugin": plugin, "name": name}) + + responses = await self.broadcast( + msg=dict(op="plugin-add", plugin=plugin, name=name) + ) + return responses + + ##################### + # State Transitions # + ##################### def transition(self, key, finish, *args, **kwargs): """Transition a key from its current state to the finish state @@ -5509,10 +5819,13 @@ def transition(self, key, finish, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState + worker_msgs: dict + client_msgs: dict try: try: - ts = self.tasks[key] + ts = parent._tasks[key] except KeyError: return {} start = ts._state @@ -5524,16 +5837,18 @@ def transition(self, key, finish, *args, **kwargs): dependencies = set(ts._dependencies) recommendations: dict = {} + worker_msgs = {} + client_msgs = {} if (start, finish) in self._transitions: func = self._transitions[start, finish] - recommendations = func(key, *args, **kwargs) + recommendations, worker_msgs, client_msgs = func(key, *args, **kwargs) elif "released" not in (start, finish): func = self._transitions["released", finish] assert not args and not kwargs a = self.transition(key, "released") if key in a: func = self._transitions["released", a[key]] - b = func(key) + b, worker_msgs, client_msgs = func(key) a = a.copy() a.update(b) recommendations = a @@ -5543,9 +5858,14 @@ def transition(self, key, finish, *args, **kwargs): "Impossible transition from %r to %r" % (start, finish) ) + for worker, msg in worker_msgs.items(): + self.worker_send(worker, msg) + for client, msg in client_msgs.items(): + self.client_send(client, msg) + finish2 = ts._state self.transition_log.append((key, start, finish2, recommendations, time())) - if self.validate: + if parent._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). 
Consequence: %s", key, @@ -5562,14 +5882,14 @@ def transition(self, key, finish, *args, **kwargs): ts._dependencies = dependencies except KeyError: pass - self.tasks[ts._key] = ts + parent._tasks[ts._key] = ts for plugin in list(self.plugins): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del self.tasks[ts._key] + del parent._tasks[ts._key] if ts._state == "forgotten" and ts._group._name in self.task_groups: # Remove TaskGroup if all tasks are in the forgotten state @@ -5593,6 +5913,7 @@ def transitions(self, recommendations: dict): This includes feedback from previous transitions and continues until we reach a steady state """ + parent: SchedulerState = cast(SchedulerState, self) keys = set() recommendations = recommendations.copy() while recommendations: @@ -5601,7 +5922,7 @@ def transitions(self, recommendations: dict): new = self.transition(key, finish) recommendations.update(new) - if self.validate: + if parent._validate: for key in keys: self.validate_key(key) @@ -5620,9 +5941,10 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState try: - ts = self.tasks[key] + ts = parent._tasks[key] except KeyError: logger.warning( "Attempting to reschedule task {}, which was not " @@ -5635,130 +5957,26 @@ def reschedule(self, key=None, worker=None): return self.transitions({key: "released"}) - ############################## - # Assigning Tasks to Workers # - ############################## - - def check_idle_saturated(self, ws: WorkerState, occ: double = -1.0): - """Update the status of the idle and saturated state - - The scheduler keeps track of workers that are .. - - - Saturated: have enough work to stay busy - - Idle: do not have enough work to stay busy - - They are considered saturated if they both have enough tasks to occupy - all of their threads, and if the expected runtime of those tasks is - large enough. - - This is useful for load balancing and adaptivity. - """ - total_nthreads: Py_ssize_t = self.total_nthreads - if total_nthreads == 0 or ws.status == Status.closed: - return - if occ < 0: - occ = ws._occupancy - - nc: Py_ssize_t = ws._nthreads - p: Py_ssize_t = len(ws._processing) - total_occupancy: double = self.total_occupancy - avg: double = total_occupancy / total_nthreads - - idle = self.idle - saturated: set = self.saturated - if p < nc or occ < nc * avg / 2: - idle[ws._address] = ws - saturated.discard(ws) - else: - idle.pop(ws._address, None) - - if p > nc: - pending: double = occ * (p - nc) / (p * nc) - if 0.4 < pending > 1.9 * avg: - saturated.add(ws) - return - - saturated.discard(ws) - - def valid_workers(self, ts: TaskState) -> set: - """Return set of currently valid workers for key - - If all workers are valid then this returns ``None``. - This checks tracks the following state: - - * worker_restrictions - * host_restrictions - * resource_restrictions - """ - workers: dict = cast(dict, self.workers) - s: set = None - - if ts._worker_restrictions: - s = {w for w in ts._worker_restrictions if w in workers} - - if ts._host_restrictions: - # Resolve the alias here rather than early, for the worker - # may not be connected when host_restrictions is populated - hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] - # XXX need HostState? 
- sl: list = [ - self.host_info[h]["addresses"] for h in hr if h in self.host_info - ] - ss: set = set.union(*sl) if sl else set() - if s is None: - s = ss - else: - s |= ss - - if ts._resource_restrictions: - dw: dict = { - resource: { - w - for w, supplied in self.resources[resource].items() - if supplied >= required - } - for resource, required in ts._resource_restrictions.items() - } - - ww: set = set.intersection(*dw.values()) - if s is None: - s = ww - else: - s &= ww - - if s is not None: - s = {workers[w] for w in s} - - return s - - def consume_resources(self, ts: TaskState, ws: WorkerState): - if ts._resource_restrictions: - for r, required in ts._resource_restrictions.items(): - ws._used_resources[r] += required - - def release_resources(self, ts: TaskState, ws: WorkerState): - if ts._resource_restrictions: - for r, required in ts._resource_restrictions.items(): - ws._used_resources[r] -= required - ##################### # Utility functions # ##################### def add_resources(self, comm=None, worker=None, resources=None): - ws: WorkerState = self.workers[worker] + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers_dv[worker] if resources: ws._resources.update(resources) ws._used_resources = {} for resource, quantity in ws._resources.items(): ws._used_resources[resource] = 0 - self.resources[resource][worker] = quantity + parent._resources[resource][worker] = quantity return "OK" def remove_resources(self, worker): - ws: WorkerState = self.workers[worker] + parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState = parent._workers_dv[worker] for resource, quantity in ws._resources.items(): - del self.resources[resource][worker] + del parent._resources[resource][worker] def coerce_address(self, addr, resolve=True): """ @@ -5768,8 +5986,9 @@ def coerce_address(self, addr, resolve=True): Handles strings, tuples, or aliases. """ # XXX how many address-parsing routines do we have? - if addr in self.aliases: - addr = self.aliases[addr] + parent: SchedulerState = cast(SchedulerState, self) + if addr in parent._aliases: + addr = parent._aliases[addr] if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, str): @@ -5782,15 +6001,6 @@ def coerce_address(self, addr, resolve=True): return addr - def coerce_hostname(self, host): - """ - Coerce the hostname of a worker. - """ - if host in self.aliases: - return self.workers[self.aliases[host]].host - else: - return host - def workers_list(self, workers): """ List of qualifying workers @@ -5798,15 +6008,16 @@ def workers_list(self, workers): Takes a list of worker addresses or hostnames. Returns a list of all worker addresses that match """ + parent: SchedulerState = cast(SchedulerState, self) if workers is None: - return list(self.workers) + return list(parent._workers) out = set() for w in workers: if ":" in w: out.add(w) else: - out.update({ww for ww in self.workers if w in ww}) # TODO: quadratic + out.update({ww for ww in parent._workers if w in ww}) # TODO: quadratic return list(out) def start_ipython(self, comm=None): @@ -5822,29 +6033,6 @@ def start_ipython(self, comm=None): ) return self._ipython_kernel.get_connection_info() - def worker_objective(self, ts: TaskState, ws: WorkerState): - """ - Objective function to determine which worker should get the task - - Minimize expected start time. If a tie then break with data storage. 
- """ - dts: TaskState - nbytes: Py_ssize_t - comm_bytes: Py_ssize_t = 0 - for dts in ts._dependencies: - if ws not in dts._who_has: - nbytes = dts.get_nbytes() - comm_bytes += nbytes - - bandwidth: double = self.bandwidth - stack_time: double = ws._occupancy / ws._nthreads - start_time: double = stack_time + comm_bytes / bandwidth - - if ts._actor: - return (len(ws._actors), start_time, ws._nbytes) - else: - return (start_time, ws._nbytes) - async def get_profile( self, comm=None, @@ -5856,10 +6044,11 @@ async def get_profile( stop=None, key=None, ): + parent: SchedulerState = cast(SchedulerState, self) if workers is None: - workers = self.workers + workers = parent._workers_dv else: - workers = set(self.workers) & set(workers) + workers = set(parent._workers_dv) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -5889,15 +6078,16 @@ async def get_profile_metadata( stop=None, profile_cycle_interval=None, ): + parent: SchedulerState = cast(SchedulerState, self) dt = profile_cycle_interval or dask.config.get( "distributed.worker.profile.cycle" ) dt = parse_timedelta(dt, default="ms") if workers is None: - workers = self.workers + workers = parent._workers_dv else: - workers = set(self.workers) & set(workers) + workers = set(parent._workers_dv) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -5931,6 +6121,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} async def performance_report(self, comm=None, start=None, code=""): + parent: SchedulerState = cast(SchedulerState, self) stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( @@ -6015,10 +6206,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(self.workers), - threads=sum([ws._nthreads for ws in self.workers.values()]), + nworkers=len(parent._workers_dv), + threads=sum([ws._nthreads for ws in parent._workers_dv.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in self.workers.values()]) + sum([ws._memory_limit for ws in parent._workers_dv.values()]) ), code=code, dask_version=dask.__version__, @@ -6107,6 +6298,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): lets us avoid this fringe optimization when we have better things to think about. 
""" + parent: SchedulerState = cast(SchedulerState, self) try: if self.status == Status.closed: return @@ -6115,7 +6307,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): next_time = timedelta(seconds=0.1) if self.proc.cpu_percent() < 50: - workers: list = list(self.workers.values()) + workers: list = list(parent._workers.values()) nworkers: Py_ssize_t = len(workers) i: Py_ssize_t for i in range(nworkers): @@ -6124,7 +6316,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): try: if ws is None or not ws._processing: continue - self._reevaluate_occupancy_worker(ws) + _reevaluate_occupancy_worker(parent, ws) finally: del ws # lose ref @@ -6141,36 +6333,13 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): logger.error("Error in reevaluate occupancy", exc_info=True) raise - def _reevaluate_occupancy_worker(self, ws: WorkerState): - """ See reevaluate_occupancy """ - old: double = ws._occupancy - new: double = 0 - diff: double - ts: TaskState - est: double - for ts in ws._processing: - est = self.set_duration_estimate(ts, ws) - new += est - - ws._occupancy = new - diff = new - old - self.total_occupancy += diff - self.check_idle_saturated(ws) - - # significant increase in duration - if new > old * 1.3: - steal = self.extensions.get("stealing") - if steal is not None: - for ts in ws._processing: - steal.remove_key_from_stealable(ts) - steal.put_key_in_stealable(ts) - async def check_worker_ttl(self): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState now = time() - for ws in self.workers.values(): + for ws in parent._workers_dv.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(self.workers)) + ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv)) ): logger.warning( "Worker failed to heartbeat within %s seconds. 
Closing: %s", @@ -6180,8 +6349,12 @@ async def check_worker_ttl(self): await self.remove_worker(address=ws._address) def check_idle(self): + parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState - if any([ws._processing for ws in self.workers.values()]) or self.unrunnable: + if ( + any([ws._processing for ws in parent._workers_dv.values()]) + or parent._unrunnable + ): self.idle_since = None return elif not self.idle_since: @@ -6210,19 +6383,20 @@ def adaptive_target(self, comm=None, target_duration=None): -------- distributed.deploy.Adaptive """ + parent: SchedulerState = cast(SchedulerState, self) if target_duration is None: target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) # CPU cpu = math.ceil( - self.total_occupancy / target_duration + parent._total_occupancy / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in self.workers.values(): + for ws in parent._workers_dv.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -6230,25 +6404,295 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if self.unrunnable and not self.workers: + if parent._unrunnable and not parent._workers_dv: cpu = max(1, cpu) # Memory - limit_bytes = {addr: ws._memory_limit for addr, ws in self.workers.items()} - worker_bytes = [ws._nbytes for ws in self.workers.values()] + limit_bytes = { + addr: ws._memory_limit for addr, ws in parent._workers_dv.items() + } + worker_bytes = [ws._nbytes for ws in parent._workers_dv.values()] limit = sum(limit_bytes.values()) total = sum(worker_bytes) if total > 0.6 * limit: - memory = 2 * len(self.workers) + memory = 2 * len(parent._workers_dv) else: memory = 0 target = max(memory, cpu) - if target >= len(self.workers): + if target >= len(parent._workers_dv): return target else: # Scale down? to_close = self.workers_to_close() - return len(self.workers) - len(to_close) + return len(parent._workers_dv) - len(to_close) + + +@cfunc +@exceptval(check=False) +def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str: + """ + Remove *ts* from the set of processing tasks. + """ + ws: WorkerState = ts._processing_on + ts._processing_on = None + w: str = ws._address + if w in state._workers_dv: # may have been removed + duration = ws._processing.pop(ts) + if not ws._processing: + state._total_occupancy -= ws._occupancy + ws._occupancy = 0 + else: + state._total_occupancy -= duration + ws._occupancy -= duration + state.check_idle_saturated(ws) + state.release_resources(ts, ws) + return w + else: + return None + + +@cfunc +@exceptval(check=False) +def _add_to_memory( + state: SchedulerState, + ts: TaskState, + ws: WorkerState, + recommendations: dict, + client_msgs: dict, + type=None, + typename: str = None, +): + """ + Add *ts* to the set of in-memory tasks. 
+ """ + if state._validate: + assert ts not in ws._has_what + + ts._who_has.add(ws) + ws._has_what.add(ts) + ws._nbytes += ts.get_nbytes() + + deps: list = list(ts._dependents) + if len(deps) > 1: + deps.sort(key=operator.attrgetter("priority"), reverse=True) + + dts: TaskState + s: set + for dts in deps: + s = dts._waiting_on + if ts in s: + s.discard(ts) + if not s: # new task ready to run + recommendations[dts._key] = "processing" + + for dts in ts._dependencies: + s = dts._waiters + s.discard(ts) + if not s and not dts._who_wants: + recommendations[dts._key] = "released" + + report_msg: dict = {} + cs: ClientState + if not ts._waiters and not ts._who_wants: + recommendations[ts._key] = "released" + else: + report_msg["op"] = "key-in-memory" + report_msg["key"] = ts._key + if type is not None: + report_msg["type"] = type + + for cs in ts._who_wants: + client_msgs[cs._client_key] = report_msg + + ts.state = "memory" + ts._type = typename + ts._group._types.add(typename) + + cs = state._clients["fire-and-forget"] + if ts in cs._wants_what: + _client_releases_keys( + state, + cs=cs, + keys=[ts._key], + recommendations=recommendations, + ) + + +@cfunc +@exceptval(check=False) +def _propagate_forgotten( + state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict +): + ts.state = "forgotten" + key: str = ts._key + dts: TaskState + for dts in ts._dependents: + dts._has_lost_dependencies = True + dts._dependencies.remove(ts) + dts._waiting_on.discard(ts) + if dts._state not in ("memory", "erred"): + # Cannot compute task anymore + recommendations[dts._key] = "forgotten" + ts._dependents.clear() + ts._waiters.clear() + + for dts in ts._dependencies: + dts._dependents.remove(ts) + s: set = dts._waiters + s.discard(ts) + if not dts._dependents and not dts._who_wants: + # Task not needed anymore + assert dts is not ts + recommendations[dts._key] = "forgotten" + ts._dependencies.clear() + ts._waiting_on.clear() + + if ts._who_has: + ts._group._nbytes_in_memory -= ts.get_nbytes() + + ws: WorkerState + for ws in ts._who_has: + ws._has_what.remove(ts) + ws._nbytes -= ts.get_nbytes() + w: str = ws._address + if w in state._workers_dv: # in case worker has died + worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False} + ts._who_has.clear() + + +@cfunc +@exceptval(check=False) +def _client_releases_keys( + state: SchedulerState, keys: list, cs: ClientState, recommendations: dict +): + """ Remove keys from client desired list """ + logger.debug("Client %s releases keys: %s", cs._client_key, keys) + ts: TaskState + tasks2: set = set() + for key in keys: + ts = state._tasks.get(key) + if ts is not None and ts in cs._wants_what: + cs._wants_what.remove(ts) + s: set = ts._who_wants + s.remove(cs) + if not s: + tasks2.add(ts) + + for ts in tasks2: + if not ts._dependents: + # No live dependents, can forget + recommendations[ts._key] = "forgotten" + elif ts._state != "erred" and not ts._waiters: + recommendations[ts._key] = "released" + + +@cfunc +@exceptval(check=False) +def _task_to_msg(state: SchedulerState, ts: TaskState, duration=None) -> dict: + """ Convert a single computational task to a message """ + ws: WorkerState + dts: TaskState + + if duration is None: + duration = state.get_task_duration(ts) + + msg: dict = { + "op": "compute-task", + "key": ts._key, + "priority": ts._priority, + "duration": duration, + } + if ts._resource_restrictions: + msg["resource_restrictions"] = ts._resource_restrictions + if ts._actor: + msg["actor"] = True + + deps: set = 
ts._dependencies + if deps: + msg["who_has"] = { + dts._key: [ws._address for ws in dts._who_has] for dts in deps + } + msg["nbytes"] = {dts._key: dts._nbytes for dts in deps} + + if state._validate: + assert all(msg["who_has"].values()) + + task = ts._run_spec + if type(task) is dict: + msg.update(task) + else: + msg["task"] = task + + return msg + + +@cfunc +@exceptval(check=False) +def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: + if ts is None: + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "forgotten": + return {"op": "cancelled-key", "key": ts._key} + elif ts._state == "memory": + return {"op": "key-in-memory", "key": ts._key} + elif ts._state == "erred": + failing_ts: TaskState = ts._exception_blame + return { + "op": "task-erred", + "key": ts._key, + "exception": failing_ts._exception, + "traceback": failing_ts._traceback, + } + else: + return None + + +@cfunc +@exceptval(check=False) +def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: + cs: ClientState + client_keys: list + if ts is None: + # Notify all clients + client_keys = list(state._clients) + else: + # Notify clients interested in key + client_keys = [cs._client_key for cs in ts._who_wants] + + report_msg: dict = _task_to_report_msg(state, ts) + + client_msgs: dict = {} + for k in client_keys: + client_msgs[k] = report_msg + + return client_msgs + + +@cfunc +@exceptval(check=False) +def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState): + """ See reevaluate_occupancy """ + old: double = ws._occupancy + new: double = 0 + diff: double + ts: TaskState + est: double + for ts in ws._processing: + est = state.set_duration_estimate(ts, ws) + new += est + + ws._occupancy = new + diff = new - old + state._total_occupancy += diff + state.check_idle_saturated(ws) + + # significant increase in duration + if new > old * 1.3: + steal = state._extensions.get("stealing") + if steal is not None: + for ts in ws._processing: + steal.remove_key_from_stealable(ts) + steal.put_key_in_stealable(ts) @cfunc