Skip to content
This repository was archived by the owner on Mar 26, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/fanout_sync/test_fanout_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

from hatchet_sdk import Hatchet, Worker


# requires scope module or higher for shared event loop
@pytest.mark.parametrize("worker", ["fanout_sync"], indirect=True)
def test_run(hatchet: Hatchet, worker: Worker) -> None:
    """Trigger the sync fan-out parent and check that both children report a result."""
    expected_children = 2
    workflow_run = hatchet.admin.run_workflow(
        "SyncFanoutParent", {"n": expected_children}
    )
    outcome = workflow_run.sync_result()
    assert len(outcome["spawn"]["results"]) == expected_children
20 changes: 20 additions & 0 deletions examples/fanout_sync/trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client


async def main() -> None:
    """Fire a single SyncFanoutParent run, tagging it with extra metadata."""
    load_dotenv()
    client = new_client()

    client.admin.run_workflow(
        "SyncFanoutParent",
        {"test": "test"},
        options={"additional_metadata": {"hello": "moon"}},
    )


if __name__ == "__main__":
    asyncio.run(main())
55 changes: 55 additions & 0 deletions examples/fanout_sync/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any

from dotenv import load_dotenv

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.workflow_run import WorkflowRunRef

load_dotenv()

hatchet = Hatchet(debug=True)


@hatchet.workflow(on_events=["parent:create"])
class SyncFanoutParent:
    """Parent workflow: fans out child runs and waits on them synchronously."""

    @hatchet.step(timeout="5m")
    def spawn(self, context: Context) -> dict[str, Any]:
        """Spawn `n` SyncFanoutChild runs and collect their results in spawn order."""
        print("spawning child")

        child_count = context.workflow_input().get("n", 5)

        child_specs = []
        for index in range(child_count):
            child_specs.append(
                {
                    "workflow_name": "SyncFanoutChild",
                    "input": {"a": str(index)},
                    "key": f"child{index}",
                    "options": {"additional_metadata": {"hello": "earth"}},
                }
            )

        refs = context.spawn_workflows(child_specs)

        # Block on each child in turn; result order matches spawn order.
        results = []
        for ref in refs:
            results.append(ref.sync_result())

        print(f"results {results}")

        return {"results": results}


@hatchet.workflow(on_events=["child:create"])
class SyncFanoutChild:
    """Child workflow spawned by SyncFanoutParent."""

    @hatchet.step()
    def process(self, context: Context) -> dict[str, str]:
        """Echo the `a` input value back with a success marker."""
        value = context.workflow_input()["a"]
        return {"status": "success " + value}


def main() -> None:
    """Start a worker hosting both fan-out workflows."""
    fanout_worker = hatchet.worker("sync-fanout-worker", max_runs=40)
    for workflow in (SyncFanoutParent(), SyncFanoutChild()):
        fanout_worker.register_workflow(workflow)
    fanout_worker.start()


if __name__ == "__main__":
    main()
10 changes: 7 additions & 3 deletions hatchet_sdk/clients/workflow_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class PooledWorkflowRunListener:
interrupter: asyncio.Task = None

def __init__(self, config: ClientConfig):
try:
asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

conn = new_conn(config, True)
self.client = DispatcherStub(conn)
self.token = config.token
Expand Down Expand Up @@ -260,12 +266,10 @@ async def _retry_subscribe(self):
if self.curr_requester != 0:
self.requests.put_nowait(self.curr_requester)

listener = self.client.SubscribeToWorkflowRuns(
return self.client.SubscribeToWorkflowRuns(
self._request(),
metadata=get_metadata(self.token),
)

return listener
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
retries = retries + 1
Expand Down
41 changes: 41 additions & 0 deletions hatchet_sdk/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,44 @@ def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
for step_run in job_run.step_runs
if step_run.error and step_run.step
]

@tenacity_retry
def spawn_workflow(
    self,
    workflow_name: str,
    input: dict[str, Any] | None = None,
    key: str | None = None,
    options: ChildTriggerWorkflowOptions | None = None,
) -> WorkflowRunRef:
    """Spawn a single child workflow from this step's context.

    :param workflow_name: name of the workflow to trigger.
    :param input: workflow input payload; defaults to an empty dict.
    :param key: optional child key, forwarded to the trigger options.
    :param options: optional child trigger options (e.g. additional metadata).
    :return: a reference to the spawned workflow run.
    """
    # `None` sentinel instead of a mutable `{}` default: a shared default dict
    # would be reused (and mutable) across every call to this method.
    worker_id = self.worker.id()
    trigger_options = self._prepare_workflow_options(key, options, worker_id)

    return self.admin_client.run_workflow(
        workflow_name, input if input is not None else {}, trigger_options
    )

@tenacity_retry
def spawn_workflows(
    self, child_workflow_runs: list[ChildWorkflowRunDict]
) -> list[WorkflowRunRef]:
    """Spawn several child workflows in one bulk trigger call.

    :param child_workflow_runs: one dict per child with `workflow_name`,
        `input`, and optional `key` / `options` entries.
    :return: references to the spawned runs, in input order.
    :raises Exception: if the list is empty.
    """
    if not child_workflow_runs:
        raise Exception("no child workflows to spawn")

    worker_id = self.worker.id()

    # Build one WorkflowRunDict per requested child, resolving each child's
    # trigger options against this step's worker.
    bulk_trigger_workflow_runs: list[WorkflowRunDict] = [
        WorkflowRunDict(
            workflow_name=child["workflow_name"],
            input=child["input"],
            options=self._prepare_workflow_options(
                child.get("key"), child.get("options", {}), worker_id
            ),
        )
        for child in child_workflow_runs
    ]

    return self.admin_client.run_workflows(bulk_trigger_workflow_runs)
4 changes: 2 additions & 2 deletions hatchet_sdk/utils/aio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self) -> None:
self.loop = asyncio.new_event_loop()
self.thread = Thread(target=self.run_loop_in_thread, args=(self.loop,))

def __enter__(self) -> asyncio.AbstractEventLoop:
def __enter__(self, *a, **kw) -> asyncio.AbstractEventLoop:
"""
Starts the thread running the event loop when entering the context.

Expand All @@ -102,7 +102,7 @@ def __enter__(self) -> asyncio.AbstractEventLoop:
self.thread.start()
return self.loop

def __exit__(self) -> None:
def __exit__(self, *a, **kw) -> None:
"""
Stops the event loop and joins the thread when exiting the context.
"""
Expand Down
16 changes: 9 additions & 7 deletions hatchet_sdk/workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,18 @@ def result(self) -> Coroutine:
return self.workflow_listener.result(self.workflow_run_id)

def sync_result(self) -> dict:
    """Block until the workflow run completes and return its result dict.

    If the current thread already has an event loop (per
    `get_active_event_loop`), run the result coroutine on it; otherwise
    create a temporary loop, drive the coroutine to completion, and unset
    the loop afterwards so no stale loop is left installed.

    NOTE(review): reconstructed from a fused diff — the scrape interleaved
    the removed `EventLoopThread` path with the added `new_event_loop`
    path; this is the post-change implementation. If the active loop is
    already *running*, `run_until_complete` will raise — presumably
    `get_active_event_loop` only returns non-running loops; confirm.
    """
    coro = self.workflow_listener.result(self.workflow_run_id)
    loop = get_active_event_loop()

    if loop is None:
        # No loop in this thread: make one just for this call.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            # Unset (but do not close) the loop; other SDK components may
            # have captured it during this call.
            asyncio.set_event_loop(None)
    else:
        return loop.run_until_complete(coro)
Comment on lines -37 to +46
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure this is safe, but I think the only risk of it causing issues is if there's no event loop, which I can't imagine being super common. It's a little tough to test well though



T = TypeVar("T")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "0.46.1"
version = "0.47.0"
description = ""
authors = ["Alexander Belanger <alexander@hatchet.run>"]
readme = "README.md"
Expand Down Expand Up @@ -111,6 +111,7 @@ explicit_package_bases = true
api = "examples.api.api:main"
async = "examples.async.worker:main"
fanout = "examples.fanout.worker:main"
fanout_sync = "examples.fanout_sync.worker:main"
cancellation = "examples.cancellation.worker:main"
concurrency_limit = "examples.concurrency_limit.worker:main"
concurrency_limit_rr = "examples.concurrency_limit_rr.worker:main"
Expand Down