Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,6 @@ async def test_dependent_tasks(c, s, w):
await async_wait_for(lambda: not w.tasks, timeout=10)


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_registering_with_name_arg(c, s, w):
class FooWorkerPlugin:
def setup(self, worker):
if hasattr(worker, "foo"):
raise RuntimeError(f"Worker {worker.address} already has foo!")

worker.foo = True

responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo")
assert list(responses.values()) == [{"status": "OK"}]

async with Worker(s.address, loop=s.loop):
with pytest.warns(FutureWarning, match="worker plugin will be overwritten"):
responses = await c.register_worker_plugin(FooWorkerPlugin(), name="foo")
assert list(responses.values()) == [{"status": "repeat"}] * 2


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_empty_plugin(c, s, w):
class EmptyPlugin:
Expand All @@ -219,3 +201,47 @@ class MyCustomPlugin(WorkerPlugin):
await c.register_worker_plugin(MyCustomPlugin())
assert len(w.plugins) == 1
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_WorkerPlugin_overwrite(c, s, w):
class MyCustomPlugin(WorkerPlugin):
name = "custom"

def setup(self, worker):
self.worker = worker
self.worker.foo = 0

def transition(self, *args, **kwargs):
self.worker.foo = 123

def teardown(self, worker):
del self.worker.foo

await c.register_worker_plugin(MyCustomPlugin)

assert w.foo == 0

await c.submit(inc, 0)
assert w.foo == 123

class MyCustomPlugin(WorkerPlugin):
name = "custom"

def setup(self, worker):
self.worker = worker
self.worker.bar = 0

def transition(self, *args, **kwargs):
self.worker.bar = 456

def teardown(self, worker):
del self.worker.bar

await c.register_worker_plugin(MyCustomPlugin)

assert not hasattr(w, "foo")
assert w.bar == 0

await c.submit(inc, 0)
assert w.bar == 456
35 changes: 14 additions & 21 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2774,28 +2774,21 @@ async def plugin_add(self, comm=None, plugin=None, name=None):
assert name

if name in self.plugins:
warnings.warn(
"Attempting to add a worker plugin with the same name as an already registered "
f"plugin ({name}). Currently this results in no change and the previously registered "
"plugin is not overwritten. This behavior is deprecated and in a future release "
f"the previously registered {name} worker plugin will be overwritten.",
category=FutureWarning,
)
return {"status": "repeat"}
else:
self.plugins[name] = plugin
await self.plugin_remove(comm=comm, name=name)

logger.info("Starting Worker plugin %s" % name)
if hasattr(plugin, "setup"):
try:
result = plugin.setup(worker=self)
if isawaitable(result):
result = await result
except Exception as e:
msg = error_message(e)
return msg

return {"status": "OK"}
self.plugins[name] = plugin

logger.info("Starting Worker plugin %s" % name)
if hasattr(plugin, "setup"):
try:
result = plugin.setup(worker=self)
if isawaitable(result):
result = await result
except Exception as e:
msg = error_message(e)
return msg

return {"status": "OK"}

async def plugin_remove(self, comm=None, name=None):
with log_errors(pdb=False):
Expand Down