diff --git a/src/scriptworker/client.py b/src/scriptworker/client.py index 42aa2966..3ba944e3 100644 --- a/src/scriptworker/client.py +++ b/src/scriptworker/client.py @@ -136,7 +136,7 @@ def sync_main( config_path: Optional[str] = None, default_config: Optional[Dict[str, Any]] = None, should_validate_task: bool = True, - loop_function: Callable[[], AbstractEventLoop] = asyncio.get_event_loop, + loop_function: Callable[[], AbstractEventLoop] = asyncio.new_event_loop, ) -> None: """Entry point for scripts using scriptworker. @@ -156,7 +156,7 @@ def sync_main( schema. Defaults to True. loop_function (function, optional): the function to call to get the event loop; here for testing purposes. Defaults to - ``asyncio.get_event_loop``. + ``asyncio.new_event_loop``. """ context = _init_context(config_path, default_config) diff --git a/tests/test_client.py b/tests/test_client.py index 50096303..4cc0c801 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -245,6 +245,22 @@ def loop_function(): assert len(async_main_calls) == 1 # async_main was called once +def test_sync_main(config): + copyfile(BASIC_TASK, os.path.join(config["work_dir"], "task.json")) + async_main_calls = [] + + async def async_main(*args): + async_main_calls.append(args) + + with tempfile.NamedTemporaryFile("w+") as f: + json.dump(config, f) + f.seek(0) + + client.sync_main(async_main, should_validate_task=False, config_path=f.name) + + assert len(async_main_calls) == 1 # async_main was called once + + @pytest.mark.parametrize( "does_use_argv, default_config", (