From 94936830b82ba6f249ff27611d66acc6cb43502f Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 22 Aug 2021 08:30:35 -0500 Subject: [PATCH] Overwrite worker plugins This changes behavior for WorkerPlugins to overwrite the previous plugin if there is a colliding name. --- .../diagnostics/tests/test_worker_plugin.py | 62 +++++++++++++------ distributed/worker.py | 35 +++++------ 2 files changed, 58 insertions(+), 39 deletions(-) diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 7ae01f09227..6d62f8103ea 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -185,24 +185,6 @@ async def test_dependent_tasks(c, s, w): await async_wait_for(lambda: not w.tasks, timeout=10) -@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -async def test_registering_with_name_arg(c, s, w): - class FooWorkerPlugin: - def setup(self, worker): - if hasattr(worker, "foo"): - raise RuntimeError(f"Worker {worker.address} already has foo!") - - worker.foo = True - - responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo") - assert list(responses.values()) == [{"status": "OK"}] - - async with Worker(s.address, loop=s.loop): - with pytest.warns(FutureWarning, match="worker plugin will be overwritten"): - responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo") - assert list(responses.values()) == [{"status": "repeat"}] * 2 - - @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) async def test_empty_plugin(c, s, w): class EmptyPlugin: @@ -219,3 +201,47 @@ class MyCustomPlugin(WorkerPlugin): await c.register_worker_plugin(MyCustomPlugin()) assert len(w.plugins) == 1 assert next(iter(w.plugins)).startswith("MyCustomPlugin-") + + +@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) +async def test_WorkerPlugin_overwrite(c, s, w): + class MyCustomPlugin(WorkerPlugin): + name = "custom" + + def setup(self, worker): + self.worker = worker + self.worker.foo = 0 + + def transition(self, *args, **kwargs): + self.worker.foo = 123 + + def teardown(self, worker): + del self.worker.foo + + await c.register_worker_plugin(MyCustomPlugin) + + assert w.foo == 0 + + await c.submit(inc, 0) + assert w.foo == 123 + + class MyCustomPlugin(WorkerPlugin): + name = "custom" + + def setup(self, worker): + self.worker = worker + self.worker.bar = 0 + + def transition(self, *args, **kwargs): + self.worker.bar = 456 + + def teardown(self, worker): + del self.worker.bar + + await c.register_worker_plugin(MyCustomPlugin) + + assert not hasattr(w, "foo") + assert w.bar == 0 + + await c.submit(inc, 0) + assert w.bar == 456 diff --git a/distributed/worker.py b/distributed/worker.py index c1f3a0c47d2..0f0785cc9bb 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2774,28 +2774,21 @@ async def plugin_add(self, comm=None, plugin=None, name=None): assert name if name in self.plugins: - warnings.warn( - "Attempting to add a worker plugin with the same name as an already registered " - f"plugin ({name}). Currently this results in no change and the previously registered " - "plugin is not overwritten. This behavior is deprecated and in a future release " - f"the previously registered {name} worker plugin will be overwritten.", - category=FutureWarning, - ) - return {"status": "repeat"} - else: - self.plugins[name] = plugin + await self.plugin_remove(comm=comm, name=name) - logger.info("Starting Worker plugin %s" % name) - if hasattr(plugin, "setup"): - try: - result = plugin.setup(worker=self) - if isawaitable(result): - result = await result - except Exception as e: - msg = error_message(e) - return msg - - return {"status": "OK"} + self.plugins[name] = plugin + + logger.info("Starting Worker plugin %s" % name) + if hasattr(plugin, "setup"): + try: + result = plugin.setup(worker=self) + if isawaitable(result): + result = await result + except Exception as e: + msg = error_message(e) + return msg + + return {"status": "OK"} async def plugin_remove(self, comm=None, name=None): with log_errors(pdb=False):