diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 16e927294b5..4534f6a4f65 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1561,7 +1561,6 @@ def _transitions( """ keys: set = set() recommendations = recommendations.copy() - msgs: list new_msgs: list new: tuple new_recs: dict @@ -1576,13 +1575,13 @@ def _transitions( recommendations.update(new_recs) for c, new_msgs in new_cmsgs.items(): - msgs = client_msgs.get(c) # type: ignore + msgs = client_msgs.get(c) if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs for w, new_msgs in new_wmsgs.items(): - msgs = worker_msgs.get(w) # type: ignore + msgs = worker_msgs.get(w) if msgs is not None: msgs.extend(new_msgs) else: @@ -3547,7 +3546,7 @@ def heartbeat_worker( @log_errors async def add_worker( self, - comm=None, + comm, *, address: str, status: str, @@ -3585,8 +3584,7 @@ async def add_worker( "message": "name taken, %s" % name, "time": time(), } - if comm: - await comm.write(msg) + await comm.write(msg) return self.log_event(address, {"action": "add-worker"}) @@ -3652,19 +3650,18 @@ async def add_worker( except Exception as e: logger.exception(e) - recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} if nbytes: assert isinstance(nbytes, dict) already_released_keys = [] for key in nbytes: - ts: TaskState = self.tasks.get(key) # type: ignore + ts: TaskState | None = self.tasks.get(key) if ts is not None and ts.state != "released": if ts.state == "memory": self.add_keys(worker=address, keys=[key]) else: - t: tuple = self._transition( + recommendations, new_cmsgs, new_wmsgs = self._transition( key, "memory", stimulus_id, @@ -3672,11 +3669,21 @@ async def add_worker( nbytes=nbytes[key], typename=types[key], ) - recommendations, client_msgs, worker_msgs = t + for c, new_msgs in new_cmsgs.items(): + msgs = client_msgs.get(c) + if msgs is not None: + msgs.extend(new_msgs) + else: + client_msgs[c] = new_msgs + for w, new_msgs in new_wmsgs.items(): + msgs = worker_msgs.get(w) + if msgs is not None: + msgs.extend(new_msgs) + else: + worker_msgs[w] = new_msgs self._transitions( recommendations, client_msgs, worker_msgs, stimulus_id ) - recommendations = {} else: already_released_keys.append(key) if already_released_keys: @@ -3691,10 +3698,12 @@ async def add_worker( ) if ws.status == Status.running: - recommendations.update(self.bulk_schedule_after_adding_worker(ws)) - - if recommendations: - self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) + self._transitions( + self.bulk_schedule_after_adding_worker(ws), + client_msgs, + worker_msgs, + stimulus_id, + ) self.send_all(client_msgs, worker_msgs) @@ -3719,10 +3728,9 @@ async def add_worker( ) msg.update(version_warning) - if comm: - await comm.write(msg) + await comm.write(msg) - await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id) + await self.handle_worker(comm, address, stimulus_id=stimulus_id) async def add_nanny(self, comm): msg = { @@ -4803,7 +4811,7 @@ def handle_worker_status_change( else: self.running.discard(ws) - async def handle_worker(self, comm=None, worker=None, stimulus_id=None): + async def handle_worker(self, comm, worker: str, stimulus_id=None): """ Listen to responses from a single worker diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index dbbf5c55ded..66e645ba909 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -370,25 +370,38 @@ async def test_clear_events_client_removal(c, s, a, b): assert time() < start + 2 -@gen_cluster() -async def test_add_worker(s, a, b): - w = Worker(s.address, nthreads=3) - w.data["x-5"] = 6 - w.data["y"] = 1 +@gen_cluster(nthreads=[("", 1)], client=True) +async def test_add_worker(c, s, a): + lock = Lock() - dsk = {("x-%d" % i): (inc, i) for i in range(10)} - s.update_graph( - tasks=valmap(dumps_task, dsk), - keys=list(dsk), - client="client", - dependencies={k: set() for k in dsk}, - ) - s.validate_state() - await w - s.validate_state() + async with lock: + anywhere = c.submit(inc, 0, key="l-0") + l1 = c.submit(lock.acquire, key="l-1") + l2 = c.submit(lock.acquire, key="l-2") + + while not (sum(t.state == "processing" for t in s.tasks.values()) == 3): + await asyncio.sleep(0.01) + + # Simulate a worker joining with necessary and unnecessary data. + w = Worker(s.address, nthreads=1) + w.update_data({"l-1": 2, "l-2": 3, "x": -1, "y": -2}) + # `update_data` queues messages to send; we want to purely test `add_worker` logic + w.batched_stream.buffer.clear() + + s.validate_state() + await w + s.validate_state() + + while not len(s.workers) == 2: + await asyncio.sleep(0.01) + + assert w.ip in s.host_info + assert s.host_info[w.ip]["addresses"] == {a.address, w.address} + + assert await c.gather([anywhere, l1, l2]) == [1, 2, 3] + assert "x" not in w.data + assert "y" not in w.data - assert w.ip in s.host_info - assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address} await w.close()