diff --git a/examples/fanout_sync/test_fanout_sync.py b/examples/fanout_sync/test_fanout_sync.py
new file mode 100644
index 00000000..c0ef04b7
--- /dev/null
+++ b/examples/fanout_sync/test_fanout_sync.py
@@ -0,0 +1,12 @@
+import pytest
+
+from hatchet_sdk import Hatchet, Worker
+
+
+# requires scope module or higher for shared event loop
+@pytest.mark.parametrize("worker", ["fanout_sync"], indirect=True)
+def test_run(hatchet: Hatchet, worker: Worker) -> None:
+    """Run SyncFanoutParent with n=2 and expect one result per spawned child."""
+    run = hatchet.admin.run_workflow("SyncFanoutParent", {"n": 2})
+    result = run.sync_result()
+    assert len(result["spawn"]["results"]) == 2
diff --git a/examples/fanout_sync/trigger.py b/examples/fanout_sync/trigger.py
new file mode 100644
index 00000000..d5ac99b8
--- /dev/null
+++ b/examples/fanout_sync/trigger.py
@@ -0,0 +1,21 @@
+import asyncio
+
+from dotenv import load_dotenv
+
+from hatchet_sdk import new_client
+
+
+async def main() -> None:
+    """Call load_dotenv(), build a client and trigger one SyncFanoutParent run."""
+    load_dotenv()
+    hatchet = new_client()
+
+    hatchet.admin.run_workflow(
+        "SyncFanoutParent",
+        {"test": "test"},  # no "n" key: the parent uses its default fan-out of 5
+        options={"additional_metadata": {"hello": "moon"}},
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/fanout_sync/worker.py b/examples/fanout_sync/worker.py
new file mode 100644
index 00000000..a2020844
--- /dev/null
+++ b/examples/fanout_sync/worker.py
@@ -0,0 +1,59 @@
+from typing import Any
+
+from dotenv import load_dotenv
+
+from hatchet_sdk import Context, Hatchet
+from hatchet_sdk.workflow_run import WorkflowRunRef
+
+load_dotenv()
+
+hatchet = Hatchet(debug=True)
+
+
+@hatchet.workflow(on_events=["parent:create"])
+class SyncFanoutParent:
+    @hatchet.step(timeout="5m")
+    def spawn(self, context: Context) -> dict[str, Any]:
+        """Spawn ``n`` SyncFanoutChild runs (default 5) and collect their results."""
+        print("spawning child")
+
+        n = context.workflow_input().get("n", 5)
+
+        runs: list[WorkflowRunRef] = context.spawn_workflows(
+            [
+                {
+                    "workflow_name": "SyncFanoutChild",
+                    "input": {"a": str(i)},
+                    "key": f"child{i}",
+                    "options": {"additional_metadata": {"hello": "earth"}},
+                }
+                for i in range(n)
+            ]
+        )
+
+        # Each sync_result() call blocks, so results are collected in spawn order.
+        results = [r.sync_result() for r in runs]
+
+        print(f"results {results}")
+
+        return {"results": results}
+
+
+@hatchet.workflow(on_events=["child:create"])
+class SyncFanoutChild:
+    @hatchet.step()
+    def process(self, context: Context) -> dict[str, str]:
+        """Return a status string that echoes the ``a`` field of the input."""
+        return {"status": "success " + context.workflow_input()["a"]}
+
+
+def main() -> None:
+    """Register both workflows on a single worker and start it."""
+    worker = hatchet.worker("sync-fanout-worker", max_runs=40)
+    worker.register_workflow(SyncFanoutParent())
+    worker.register_workflow(SyncFanoutChild())
+    worker.start()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py
index b1131587..8bf71a3c 100644
--- a/hatchet_sdk/clients/workflow_listener.py
+++ b/hatchet_sdk/clients/workflow_listener.py
@@ -75,6 +75,14 @@ class PooledWorkflowRunListener:
     interrupter: asyncio.Task = None
 
     def __init__(self, config: ClientConfig):
+        # When constructed outside a running event loop, install a fresh loop
+        # as this thread's current loop before the connection is created.
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
         conn = new_conn(config, True)
         self.client = DispatcherStub(conn)
         self.token = config.token
@@ -260,12 +268,10 @@ async def _retry_subscribe(self):
                 if self.curr_requester != 0:
                     self.requests.put_nowait(self.curr_requester)
 
-                listener = self.client.SubscribeToWorkflowRuns(
+                return self.client.SubscribeToWorkflowRuns(
                     self._request(),
                     metadata=get_metadata(self.token),
                 )
-
-                return listener
             except grpc.RpcError as e:
                 if e.code() == grpc.StatusCode.UNAVAILABLE:
                     retries = retries + 1
diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py
index f20acd66..2584949b 100644
--- a/hatchet_sdk/context/context.py
+++ b/hatchet_sdk/context/context.py
@@ -403,3 +403,54 @@ def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
             for step_run in job_run.step_runs
             if step_run.error and step_run.step
         ]
+
+    @tenacity_retry
+    def spawn_workflow(
+        self,
+        workflow_name: str,
+        input: dict[str, Any] | None = None,
+        key: str | None = None,
+        options: ChildTriggerWorkflowOptions | None = None,
+    ) -> WorkflowRunRef:
+        """
+        Trigger one child workflow run from synchronous code and return its ref.
+
+        ``input`` defaults to ``None`` and is sent as an empty dict in that case;
+        ``key`` and ``options`` are passed to ``_prepare_workflow_options``.
+        """
+        worker_id = self.worker.id()
+        trigger_options = self._prepare_workflow_options(key, options, worker_id)
+
+        # A fresh dict per call avoids sharing one mutable default across calls.
+        payload = {} if input is None else input
+
+        return self.admin_client.run_workflow(workflow_name, payload, trigger_options)
+
+    @tenacity_retry
+    def spawn_workflows(
+        self, child_workflow_runs: list[ChildWorkflowRunDict]
+    ) -> list[WorkflowRunRef]:
+        """
+        Trigger several child workflow runs from synchronous code through a
+        single ``admin_client.run_workflows`` call.
+
+        Each entry must provide ``workflow_name`` and ``input``; ``key`` and
+        ``options`` are optional. Raises ``Exception`` when the list is empty.
+        """
+        if not child_workflow_runs:
+            raise Exception("no child workflows to spawn")
+
+        worker_id = self.worker.id()
+
+        bulk_trigger_workflow_runs: list[WorkflowRunDict] = [
+            WorkflowRunDict(
+                workflow_name=child["workflow_name"],
+                input=child["input"],
+                options=self._prepare_workflow_options(
+                    child.get("key"), child.get("options", {}), worker_id
+                ),
+            )
+            for child in child_workflow_runs
+        ]
+
+        return self.admin_client.run_workflows(bulk_trigger_workflow_runs)
diff --git a/hatchet_sdk/utils/aio_utils.py b/hatchet_sdk/utils/aio_utils.py
index 3f7ac3f3..459205f1 100644
--- a/hatchet_sdk/utils/aio_utils.py
+++ b/hatchet_sdk/utils/aio_utils.py
@@ -92,7 +92,7 @@ def __init__(self) -> None:
         self.loop = asyncio.new_event_loop()
         self.thread = Thread(target=self.run_loop_in_thread, args=(self.loop,))
 
-    def __enter__(self) -> asyncio.AbstractEventLoop:
+    def __enter__(self, *a: object, **kw: object) -> asyncio.AbstractEventLoop:
         """
         Starts the thread running the event loop when entering the context.
 
@@ -102,7 +102,7 @@ def __enter__(self) -> asyncio.AbstractEventLoop:
         self.thread.start()
         return self.loop
 
-    def __exit__(self) -> None:
+    def __exit__(self, *a: object, **kw: object) -> None:
         """
         Stops the event loop and joins the thread when exiting the context.
         """
diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py
index 51a23821..064f6741 100644
--- a/hatchet_sdk/workflow_run.py
+++ b/hatchet_sdk/workflow_run.py
@@ -32,16 +32,39 @@ def result(self) -> Coroutine:
         return self.workflow_listener.result(self.workflow_run_id)
 
     def sync_result(self) -> dict:
+        """
+        Block until the listener's ``result()`` coroutine completes; return its value.
+
+        Reuses the loop returned by ``get_active_event_loop()`` when there is
+        one; otherwise runs on a temporary loop that is installed as the current
+        loop for the duration of the call and uninstalled afterwards.
+
+        Raises ``RuntimeError`` when called while that loop is already running;
+        use ``await result()`` from async code instead.
+        """
         loop = get_active_event_loop()
+
+        if loop is not None and loop.is_running():
+            # run_until_complete() cannot be nested inside a running loop; fail
+            # before creating the coroutine so it is never left un-awaited.
+            raise RuntimeError(
+                "sync_result() cannot be called from a running event loop; "
+                "await result() instead"
+            )
+
+        coro = self.workflow_listener.result(self.workflow_run_id)
+
         if loop is None:
-            with EventLoopThread() as loop:
-                coro = self.workflow_listener.result(self.workflow_run_id)
-                future = asyncio.run_coroutine_threadsafe(coro, loop)
-                return future.result()
+            # NOTE(review): this temporary loop is never closed; confirm nothing
+            # created by the listener still depends on it before adding close().
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                return loop.run_until_complete(coro)
+            finally:
+                asyncio.set_event_loop(None)
         else:
-            coro = self.workflow_listener.result(self.workflow_run_id)
-            future = asyncio.run_coroutine_threadsafe(coro, loop)
-            return future.result()
+            return loop.run_until_complete(coro)
 
 
 T = TypeVar("T")
diff --git a/pyproject.toml b/pyproject.toml
index 29cc6578..7fbdc77a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "0.46.1"
+version = "0.47.0"
 description = ""
 authors = ["Alexander Belanger "]
 readme = "README.md"
@@ -111,6 +111,7 @@ explicit_package_bases = true
 api = "examples.api.api:main"
 async = "examples.async.worker:main"
 fanout = "examples.fanout.worker:main"
+fanout_sync = "examples.fanout_sync.worker:main"
 cancellation = "examples.cancellation.worker:main"
 concurrency_limit = "examples.concurrency_limit.worker:main"
 concurrency_limit_rr = "examples.concurrency_limit_rr.worker:main"