diff --git a/distributed/client.py b/distributed/client.py index b03f9671289..f4b8989daa1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4828,6 +4828,44 @@ def register_scheduler_plugin(self, plugin, name=None, idempotent=False): idempotent=idempotent, ) + async def _unregister_scheduler_plugin(self, name): + return await self.scheduler.unregister_scheduler_plugin(name=name) + + def unregister_scheduler_plugin(self, name): + """Unregisters a scheduler plugin + + See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins + + Parameters + ---------- + name : str + Name of the plugin to unregister. See the :meth:`Client.register_scheduler_plugin` + docstring for more information. + + Examples + -------- + >>> class MyPlugin(SchedulerPlugin): + ... def __init__(self, *args, **kwargs): + ... pass # the constructor is up to you + ... async def start(self, scheduler: Scheduler) -> None: + ... pass + ... async def before_close(self) -> None: + ... pass + ... async def close(self) -> None: + ... pass + ... def restart(self, scheduler: Scheduler) -> None: + ... pass + + >>> plugin = MyPlugin(1, 2, 3) + >>> client.register_scheduler_plugin(plugin, name='foo') + >>> client.unregister_scheduler_plugin(name='foo') + + See Also + -------- + register_scheduler_plugin + """ + return self.sync(self._unregister_scheduler_plugin, name=name) + def register_worker_callbacks(self, setup=None): """ Registers a setup callback function for all current and future workers. diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 92f72310dac..4846a24ed6a 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -5,6 +5,7 @@ import pytest from distributed import Scheduler, SchedulerPlugin, Worker, get_worker +from distributed.protocol.pickle import dumps from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @@ -336,6 +337,39 @@ def start(self, scheduler): assert n_plugins == len(s.plugins) +@gen_cluster(nthreads=[]) +async def test_unregister_scheduler_plugin(s): + class Plugin(SchedulerPlugin): + def __init__(self): + self.name = "plugin" + + plugin = Plugin() + await s.register_scheduler_plugin(plugin=dumps(plugin)) + assert "plugin" in s.plugins + + await s.unregister_scheduler_plugin(name="plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await s.unregister_scheduler_plugin(name="plugin") + + +@gen_cluster(client=True) +async def test_unregister_scheduler_plugin_from_client(c, s, a, b): + class Plugin(SchedulerPlugin): + name = "plugin" + + assert "plugin" not in s.plugins + await c.register_scheduler_plugin(Plugin()) + assert "plugin" in s.plugins + + await c.unregister_scheduler_plugin("plugin") + assert "plugin" not in s.plugins + + with pytest.raises(ValueError, match="Could not find plugin"): + await c.unregister_scheduler_plugin(name="plugin") + + @gen_cluster(client=True) async def test_log_event_plugin(c, s, a, b): class EventPlugin(SchedulerPlugin): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 399c56d86fa..ceff3207ac2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3700,6 +3700,7 @@ def __init__( "get_task_stream": self.get_task_stream, "get_task_prefix_states": self.get_task_prefix_states, "register_scheduler_plugin": self.register_scheduler_plugin, + "unregister_scheduler_plugin": self.unregister_scheduler_plugin, "register_worker_plugin": self.register_worker_plugin, "unregister_worker_plugin": self.unregister_worker_plugin, "register_nanny_plugin": self.register_nanny_plugin, @@ -5727,11 +5728,7 @@ def add_plugin( self.plugins[name] = plugin - def remove_plugin( - self, - name: str | None = None, - plugin: SchedulerPlugin | None = None, - ) -> None: + def remove_plugin(self, name: str | None = None) -> None: """Remove external plugin from scheduler Parameters @@ -5779,6 +5776,10 @@ async def register_scheduler_plugin( self.add_plugin(plugin, name=name, idempotent=idempotent) + async def unregister_scheduler_plugin(self, name: str) -> None: + """Unregister a plugin on the scheduler.""" + self.remove_plugin(name) + def worker_send(self, worker: str, msg: dict[str, Any]) -> None: """Send message to worker