From c4d3b16aa85279012f6be19a2637b2fdd7b69383 Mon Sep 17 00:00:00 2001 From: Brian Phillips Date: Wed, 5 Jul 2023 10:30:21 -0400 Subject: [PATCH 1/4] Add Client.unregister_scheduler_plugin method --- distributed/client.py | 38 +++++++++++++++++++++++++++++ distributed/scheduler.py | 11 +++++---- distributed/tests/test_scheduler.py | 17 +++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index b03f9671289..f4b8989daa1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4828,6 +4828,44 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False): idempotent=idempotent, ) + async def _unregister_scheduler_plugin(self, name): + return await self.scheduler.unregister_scheduler_plugin(name=name) + + def unregister_scheduler_plugin(self, name): + """Unregisters a scheduler plugin + + See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins + + Parameters + ---------- + name : str + Name of the plugin to unregister. See the :meth:`Client.register_scheduler_plugin` + docstring for more information. + + Examples + -------- + >>> class MyPlugin(SchedulerPlugin): + ... def __init__(self, *args, **kwargs): + ... pass # the constructor is up to you + ... async def start(self, scheduler: Scheduler) -> None: + ... pass + ... async def before_close(self) -> None: + ... pass + ... async def close(self) -> None: + ... pass + ... def restart(self, scheduler: Scheduler) -> None: + ... pass + + >>> plugin = MyPlugin(1, 2, 3) + >>> client.register_scheduler_plugin(plugin, name='foo') + >>> client.unregister_scheduler_plugin(name='foo') + + See Also + -------- + register_scheduler_plugin + """ + return self.sync(self._unregister_scheduler_plugin, name=name) + def register_worker_callbacks(self, setup=None): """ Registers a setup callback function for all current and future workers. diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 399c56d86fa..ceff3207ac2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3700,6 +3700,7 @@ def __init__( "get_task_stream": self.get_task_stream, "get_task_prefix_states": self.get_task_prefix_states, "register_scheduler_plugin": self.register_scheduler_plugin, + "unregister_scheduler_plugin": self.unregister_scheduler_plugin, "register_worker_plugin": self.register_worker_plugin, "unregister_worker_plugin": self.unregister_worker_plugin, "register_nanny_plugin": self.register_nanny_plugin, @@ -5727,11 +5728,7 @@ def add_plugin( self.plugins[name] = plugin - def remove_plugin( - self, - name: str | None = None, - plugin: SchedulerPlugin | None = None, - ) -> None: + def remove_plugin(self, name: str | None = None) -> None: """Remove external plugin from scheduler Parameters @@ -5779,6 +5776,10 @@ async def register_scheduler_plugin( self.add_plugin(plugin, name=name, idempotent=idempotent) + async def unregister_scheduler_plugin(self, name: str) -> None: + """Unregister a plugin on the scheduler.""" + self.remove_plugin(name) + def worker_send(self, worker: str, msg: dict[str, Any]) -> None: """Send message to worker diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 497cc1dd35f..f6fe49c621e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4159,6 +4159,23 @@ def __init__(self, instance=None): assert s.plugins["nonidempotentplugin"].instance == "second" +@gen_cluster(nthreads=[]) +async def test_plugin_removal(s): + class Plugin(SchedulerPlugin): + def __init__(self): + self.name = 'plugin' + + plugin = Plugin() + await s.register_scheduler_plugin(plugin=dumps(plugin)) + assert 'plugin' in s.plugins + + await s.unregister_scheduler_plugin(name='plugin') + assert 'plugin' not in s.plugins + + with pytest.raises(ValueError, match='Could not find plugin'): + await s.unregister_scheduler_plugin(name='plugin') + + @gen_cluster(nthreads=[("", 1)]) async def test_repr(s, a): async with Worker(s.address, nthreads=2) as b: # name = address by default From 205ec73dbade62ed40abb746d8370f4bfc58cf66 Mon Sep 17 00:00:00 2001 From: Brian Phillips Date: Wed, 5 Jul 2023 10:35:49 -0400 Subject: [PATCH 2/4] styling --- distributed/tests/test_scheduler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f6fe49c621e..863a06d8cb8 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4163,17 +4163,17 @@ def __init__(self, instance=None): async def test_plugin_removal(s): class Plugin(SchedulerPlugin): def __init__(self): - self.name = 'plugin' + self.name = "plugin" plugin = Plugin() await s.register_scheduler_plugin(plugin=dumps(plugin)) - assert 'plugin' in s.plugins + assert "plugin" in s.plugins - await s.unregister_scheduler_plugin(name='plugin') - assert 'plugin' not in s.plugins + await s.unregister_scheduler_plugin(name="plugin") + assert "plugin" not in s.plugins - with pytest.raises(ValueError, match='Could not find plugin'): - await s.unregister_scheduler_plugin(name='plugin') + with pytest.raises(ValueError, match="Could not find plugin"): + await s.unregister_scheduler_plugin(name="plugin") @gen_cluster(nthreads=[("", 1)]) From c763363c2feb22f1b1c0e8d0c04b2eb55ced41ec Mon Sep 17 00:00:00 2001 From: Brian Phillips Date: Wed, 5 Jul 2023 14:58:59 -0400 Subject: [PATCH 3/4] Improve coverage --- distributed/tests/test_client.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2f9c252b0c3..2ae01fcd6a8 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -75,7 +75,7 @@ from distributed.comm import CommClosedError from distributed.compatibility import LINUX, WINDOWS from distributed.core import Status, error_message -from distributed.diagnostics.plugin import WorkerPlugin +from distributed.diagnostics.plugin import SchedulerPlugin, WorkerPlugin from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler from distributed.shuffle import check_minimal_arrow_version @@ -6791,6 +6791,18 @@ def setup(self, worker=None): await c.register_worker_plugin(MyPlugin()) +@gen_cluster(client=True) +async def test_scheduler_plugins(c, s, a, b): + class Plugin(SchedulerPlugin): + name = "plugin" + + assert "plugin" not in s.plugins + await c.register_scheduler_plugin(Plugin()) + assert "plugin" in s.plugins + await c.unregister_scheduler_plugin("plugin") + assert "plugin" not in s.plugins + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_log_event(c, s, a): # Log an event from inside a task From 8ece1b749a0a913d30f50bf34642c8918590c734 Mon Sep 17 00:00:00 2001 From: Brian Phillips Date: Mon, 17 Jul 2023 10:22:04 -0400 Subject: [PATCH 4/4] PR comments --- .../tests/test_scheduler_plugin.py | 34 +++++++++++++++++++ distributed/tests/test_client.py | 14 +------- distributed/tests/test_scheduler.py | 17 ---------- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 92f72310dac..4846a24ed6a 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -5,6 +5,7 @@ import pytest from distributed import Scheduler, SchedulerPlugin, Worker, get_worker +from distributed.protocol.pickle import dumps from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @@ -336,6 +337,39 @@ def start(self, scheduler): assert n_plugins == len(s.plugins) +@gen_cluster(nthreads=[]) +async def test_unregister_scheduler_plugin(s): + class Plugin(SchedulerPlugin): + def __init__(self): + self.name = "plugin" + + plugin = Plugin() + await s.register_scheduler_plugin(plugin=dumps(plugin)) + assert "plugin" in s.plugins + + await s.unregister_scheduler_plugin(name="plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await s.unregister_scheduler_plugin(name="plugin") + + +@gen_cluster(client=True) +async def test_unregister_scheduler_plugin_from_client(c, s, a, b): + class Plugin(SchedulerPlugin): + name = "plugin" + + assert "plugin" not in s.plugins + await c.register_scheduler_plugin(Plugin()) + assert "plugin" in s.plugins + + await c.unregister_scheduler_plugin("plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await c.unregister_scheduler_plugin(name="plugin") + + @gen_cluster(client=True) async def test_log_event_plugin(c, s, a, b): class EventPlugin(SchedulerPlugin): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2ae01fcd6a8..2f9c252b0c3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -75,7 +75,7 @@ from distributed.comm import CommClosedError from distributed.compatibility import LINUX, WINDOWS from distributed.core import Status, error_message -from distributed.diagnostics.plugin import SchedulerPlugin, WorkerPlugin +from distributed.diagnostics.plugin import WorkerPlugin from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler from distributed.shuffle import check_minimal_arrow_version @@ -6791,18 +6791,6 @@ def setup(self, worker=None): await c.register_worker_plugin(MyPlugin()) -@gen_cluster(client=True) -async def test_scheduler_plugins(c, s, a, b): - class Plugin(SchedulerPlugin): - name = "plugin" - - assert "plugin" not in s.plugins - await c.register_scheduler_plugin(Plugin()) - assert "plugin" in s.plugins - await c.unregister_scheduler_plugin("plugin") - assert "plugin" not in s.plugins - - @gen_cluster(client=True, nthreads=[("", 1)]) async def test_log_event(c, s, a): # Log an event from inside a task diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 863a06d8cb8..497cc1dd35f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4159,23 +4159,6 @@ def __init__(self, instance=None): assert s.plugins["nonidempotentplugin"].instance == "second" -@gen_cluster(nthreads=[]) -async def test_plugin_removal(s): - class Plugin(SchedulerPlugin): - def __init__(self): - self.name = "plugin" - - plugin = Plugin() - await s.register_scheduler_plugin(plugin=dumps(plugin)) - assert "plugin" in s.plugins - - await s.unregister_scheduler_plugin(name="plugin") - assert "plugin" not in s.plugins - - with pytest.raises(ValueError, match="Could not find plugin"): - await s.unregister_scheduler_plugin(name="plugin") - - @gen_cluster(nthreads=[("", 1)]) async def test_repr(s, a): async with Worker(s.address, nthreads=2) as b: # name = address by default