diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index f3aa8cc6daf..befb228147b 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -188,18 +188,23 @@ async def test_dependent_tasks(c, s, w): @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) async def test_registering_with_name_arg(c, s, w): class FooWorkerPlugin: - def setup(self, worker): - if hasattr(worker, "foo"): - raise RuntimeError(f"Worker {worker.address} already has foo!") + def __init__(self, value): + self.value = value - worker.foo = True + def setup(self, worker): + worker.foo = self.value - responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo") + responses = await c.register_worker_plugin(FooWorkerPlugin(23), name="foo") assert list(responses.values()) == [{"status": "OK"}] + results = await c.run(lambda dask_worker: dask_worker.foo) + assert results[w.address] == 23 async with Worker(s.address, loop=s.loop): - responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo") - assert list(responses.values()) == [{"status": "repeat"}] * 2 + responses = await c.register_worker_plugin(FooWorkerPlugin(42), name="foo") + assert list(responses.values()) == [{"status": "OK"}] * 2 + results = await c.run(lambda dask_worker: dask_worker.foo) + assert len(results) == 2 + assert all(v == 42 for v in results.values()) @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) diff --git a/distributed/worker.py b/distributed/worker.py index c182a84249d..628747dc83d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2721,22 +2721,19 @@ async def plugin_add(self, comm=None, plugin=None, name=None): assert name - if name in self.plugins: - return {"status": "repeat"} - else: - self.plugins[name] = plugin + self.plugins[name] = plugin - logger.info("Starting Worker plugin %s" % name) - if hasattr(plugin, "setup"): - try: - result = plugin.setup(worker=self) - if isawaitable(result): - result = await result - except Exception as e: - msg = error_message(e) - return msg + logger.info("Starting Worker plugin %s" % name) + if hasattr(plugin, "setup"): + try: + result = plugin.setup(worker=self) + if isawaitable(result): + result = await result + except Exception as e: + msg = error_message(e) + return msg - return {"status": "OK"} + return {"status": "OK"} async def plugin_remove(self, comm=None, name=None): with log_errors(pdb=False):