diff --git a/executor/engine/job/process.py b/executor/engine/job/process.py index 31f0963..0ed3e1b 100644 --- a/executor/engine/job/process.py +++ b/executor/engine/job/process.py @@ -1,9 +1,11 @@ import asyncio import functools +import inspect from loky.process_executor import ProcessPoolExecutor from .base import Job +from .utils import _gen_initializer, GeneratorWrapper class ProcessJob(Job): @@ -43,11 +45,16 @@ def release_resource(self) -> bool: async def run(self): """Run job in process pool.""" - self._executor = ProcessPoolExecutor(1) - loop = asyncio.get_running_loop() - func = functools.partial(self.func, **self.kwargs) - fut = loop.run_in_executor(self._executor, func, *self.args) - result = await fut + func = functools.partial(self.func, *self.args, **self.kwargs) + if (inspect.isgeneratorfunction(self.func) + or inspect.isasyncgenfunction(self.func)): + self._executor = ProcessPoolExecutor(1, initializer=_gen_initializer, initargs=(func,)) + result = GeneratorWrapper(self) + else: + self._executor = ProcessPoolExecutor(1) + loop = asyncio.get_running_loop() + fut = loop.run_in_executor(self._executor, func) + result = await fut return result async def cancel(self): diff --git a/executor/engine/job/utils.py b/executor/engine/job/utils.py index 0eb935e..73e992a 100644 --- a/executor/engine/job/utils.py +++ b/executor/engine/job/utils.py @@ -1,4 +1,5 @@ import typing as T +import asyncio from datetime import datetime from ..utils import CheckAttrRange, ExecutorError @@ -34,3 +35,41 @@ def __init__(self, job: "Job", valid_status: T.List[JobStatusType]): super().__init__( f"Invalid state: {job} is in {job.status} state, " f"but should be in {valid_status} state.") + + +_T = T.TypeVar("_T") + +def _gen_initializer(gen_func, args=tuple(), kwargs={}): + global _generator + _generator = gen_func(*args, **kwargs) + + +def _gen_next(): + global _generator + return next(_generator) + + +def _gen_anext(): + global _generator + return asyncio.run(_generator.__anext__()) + + +class GeneratorWrapper(T.Generic[_T]): + """ + wrap a generator in executor pool + """ + def __init__(self, job: "Job"): + self._job = job + + def __iter__(self): + return self + + def __next__(self) -> _T: + return self._job._executor.submit(_gen_next).result() + + def __aiter__(self): + return self + + async def __anext__(self) -> _T: + fut = self._job._executor.submit(_gen_anext) + return (await asyncio.wrap_future(fut))