diff --git a/adaptive/runner.py b/adaptive/runner.py index ff3a137c3..46df5908a 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, Union import loky +import nest_asyncio from adaptive import ( BalancingLearner, @@ -670,7 +671,15 @@ def __init__( raise_if_retries_exceeded=raise_if_retries_exceeded, allow_running_forever=True, ) - self.ioloop = ioloop or asyncio.get_event_loop() + if ioloop is None: + try: + ioloop = asyncio.get_running_loop() + nest_asyncio.apply(ioloop) + except RuntimeError: + ioloop = asyncio.new_event_loop() + asyncio.set_event_loop(ioloop) + + self.ioloop = ioloop self.task = None # When the learned function is 'async def', we run it @@ -847,6 +856,10 @@ async def _saver(): self.saving_task = self.ioloop.create_task(_saver()) return self.saving_task + def block(self) -> None: + """Block until the runner is finished.""" + self.ioloop.run_until_complete(self.task) + # Default runner Runner = AsyncRunner diff --git a/setup.py b/setup.py index 12d6e7a6b..27fe5279a 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def get_version_and_cmdclass(package_name): "sortedcontainers >= 2.0", "cloudpickle", "loky >= 2.9", + "nest_asyncio", ] if sys.version_info < (3, 10): install_requires.append("typing_extensions")