10 changes: 10 additions & 0 deletions examples/fanout_sync/test_fanout_sync.py
@@ -0,0 +1,10 @@
import pytest

from hatchet_sdk import Hatchet, Worker


@pytest.mark.parametrize("worker", ["fanout_sync"], indirect=True)
def test_run(hatchet: Hatchet, worker: Worker) -> None:
    run = hatchet.admin.run_workflow("SyncFanoutParent", {"n": 2})
    result = run.sync_result()
    assert len(result["spawn"]["results"]) == 2
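The test leans on `hatchet` and `worker` pytest fixtures defined elsewhere in the repo; `indirect=True` routes the string `"fanout_sync"` into the `worker` fixture as `request.param`. A minimal sketch of what such a conftest.py might look like (hypothetical; the fixture bodies and the subprocess approach are assumptions, not the repo's actual fixtures):

```python
# conftest.py (hypothetical sketch)
import subprocess
import time
from collections.abc import Generator

import pytest

from hatchet_sdk import Hatchet


@pytest.fixture
def hatchet() -> Hatchet:
    # Assumes credentials (e.g. HATCHET_CLIENT_TOKEN) are set in the environment.
    return Hatchet()


@pytest.fixture
def worker(request: pytest.FixtureRequest) -> Generator[subprocess.Popen, None, None]:
    # request.param is "fanout_sync" via indirect parametrization;
    # boot the matching example worker for the duration of the test.
    proc = subprocess.Popen(["python", f"examples/{request.param}/worker.py"])
    time.sleep(5)  # crude wait for the worker to register with the engine
    yield proc
    proc.terminate()
```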
20 changes: 20 additions & 0 deletions examples/fanout_sync/trigger.py
@@ -0,0 +1,20 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client


async def main() -> None:
    load_dotenv()
    hatchet = new_client()

    hatchet.admin.run_workflow(
        "SyncFanoutParent",
        {"test": "test"},
        options={"additional_metadata": {"hello": "moon"}},
    )


if __name__ == "__main__":
    asyncio.run(main())
55 changes: 55 additions & 0 deletions examples/fanout_sync/worker.py
@@ -0,0 +1,55 @@
from typing import Any

from dotenv import load_dotenv

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.workflow_run import WorkflowRunRef

load_dotenv()

hatchet = Hatchet(debug=True)


@hatchet.workflow(on_events=["parent:create"])
class SyncFanoutParent:
    @hatchet.step(timeout="5m")
    def spawn(self, context: Context) -> dict[str, Any]:
        print("spawning child")

        n = context.workflow_input().get("n", 5)

        runs = context.spawn_workflows(
            [
                {
                    "workflow_name": "SyncFanoutChild",
                    "input": {"a": str(i)},
                    "key": f"child{i}",
                    "options": {"additional_metadata": {"hello": "earth"}},
                }
                for i in range(n)
            ]
        )

        results = [r.sync_result() for r in runs]

        print(f"results {results}")

        return {"results": results}


@hatchet.workflow(on_events=["child:create"])
class SyncFanoutChild:
    @hatchet.step()
    def process(self, context: Context) -> dict[str, str]:
        return {"status": "success " + context.workflow_input()["a"]}


def main() -> None:
    worker = hatchet.worker("sync-fanout-worker", max_runs=40)
    worker.register_workflow(SyncFanoutParent())
    worker.register_workflow(SyncFanoutChild())
    worker.start()


if __name__ == "__main__":
    main()
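A note on sizing, since it is easy to miss: each `r.sync_result()` in `spawn` blocks that step's run slot until the corresponding child finishes, so the worker needs enough concurrent slots for the parent step plus all of its children. That is presumably why this example starts the worker with `max_runs=40`; with too small a capacity, a large enough fan-out could deadlock waiting on children that cannot be scheduled.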
[file header not captured in this view]
@@ -6,7 +6,11 @@
from hatchet_sdk import Hatchet, Worker
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
from hatchet_sdk.clients.events import PushEventOptions
-from hatchet_sdk.opentelemetry.instrumentor import HatchetInstrumentor
+from hatchet_sdk.opentelemetry.instrumentor import (
+    HatchetInstrumentor,
+    create_traceparent,
+    inject_traceparent_into_metadata,
+)

trace_provider = NoOpTracerProvider()

@@ -17,9 +21,7 @@


def create_additional_metadata() -> dict[str, str]:
-    return instrumentor.inject_traceparent_into_metadata(
-        {"hello": "world"}, instrumentor.create_traceparent()
-    )
+    return inject_traceparent_into_metadata({"hello": "world"})


def create_push_options() -> PushEventOptions:
10 changes: 6 additions & 4 deletions examples/opentelemetry_instrumentation/triggers.py
@@ -4,16 +4,18 @@
from examples.opentelemetry_instrumentation.tracer import trace_provider
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
from hatchet_sdk.clients.events import PushEventOptions
-from hatchet_sdk.opentelemetry.instrumentor import HatchetInstrumentor
+from hatchet_sdk.opentelemetry.instrumentor import (
+    HatchetInstrumentor,
+    create_traceparent,
+    inject_traceparent_into_metadata,
+)

instrumentor = HatchetInstrumentor(tracer_provider=trace_provider)
tracer = trace_provider.get_tracer(__name__)


def create_additional_metadata() -> dict[str, str]:
-    return instrumentor.inject_traceparent_into_metadata(
-        {"hello": "world"}, instrumentor.create_traceparent()
-    )
+    return inject_traceparent_into_metadata({"hello": "world"})


def create_push_options() -> PushEventOptions:
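Both call sites now pass only the metadata dict, which implies the module-level `inject_traceparent_into_metadata` resolves the current traceparent itself rather than taking it as an argument. A rough sketch of how such a helper could be written with OpenTelemetry's propagation API (an illustration of the idea, not the SDK's actual implementation):

```python
from opentelemetry.propagate import inject


def inject_traceparent_into_metadata(metadata: dict[str, str]) -> dict[str, str]:
    # inject() serializes the current span context into the carrier,
    # including a W3C `traceparent` entry when a span is active.
    carrier: dict[str, str] = {}
    inject(carrier)

    if "traceparent" in carrier:
        return {**metadata, "traceparent": carrier["traceparent"]}
    return metadata
```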
10 changes: 7 additions & 3 deletions hatchet_sdk/clients/workflow_listener.py
@@ -75,6 +75,12 @@ class PooledWorkflowRunListener:
    interrupter: asyncio.Task = None

    def __init__(self, config: ClientConfig):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        conn = new_conn(config, True)
        self.client = DispatcherStub(conn)
        self.token = config.token
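The new guard ensures the constructor never runs in a thread without an event loop: `asyncio.get_running_loop()` raises `RuntimeError` when no loop is running, in which case a fresh loop is created and installed for the current thread. This presumably supports the synchronous fan-out path above, where the listener may be constructed from plain (non-async) worker threads. The same pattern in isolation:

```python
import asyncio


def ensure_event_loop() -> asyncio.AbstractEventLoop:
    # Sketch of the guard above: reuse the running loop if one exists,
    # otherwise create and register a new loop for this thread.
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
```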
@@ -260,12 +266,10 @@ async def _retry_subscribe(self):
                if self.curr_requester != 0:
                    self.requests.put_nowait(self.curr_requester)

-                listener = self.client.SubscribeToWorkflowRuns(
+                return self.client.SubscribeToWorkflowRuns(
                    self._request(),
                    metadata=get_metadata(self.token),
                )
-
-                return listener
            except grpc.RpcError as e:
                if e.code() == grpc.StatusCode.UNAVAILABLE:
                    retries = retries + 1
41 changes: 41 additions & 0 deletions hatchet_sdk/context/context.py
@@ -403,3 +403,44 @@ def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
            for step_run in job_run.step_runs
            if step_run.error and step_run.step
        ]

    @tenacity_retry
    def spawn_workflow(
        self,
        workflow_name: str,
        input: dict[str, Any] = {},
        key: str | None = None,
        options: ChildTriggerWorkflowOptions | None = None,
    ) -> WorkflowRunRef:
        worker_id = self.worker.id()
        trigger_options = self._prepare_workflow_options(key, options, worker_id)

        return self.admin_client.run_workflow(workflow_name, input, trigger_options)

    @tenacity_retry
    def spawn_workflows(
        self, child_workflow_runs: list[ChildWorkflowRunDict]
    ) -> list[WorkflowRunRef]:

        if len(child_workflow_runs) == 0:
            raise Exception("no child workflows to spawn")

        worker_id = self.worker.id()

        bulk_trigger_workflow_runs: list[WorkflowRunDict] = []
        for child_workflow_run in child_workflow_runs:
            workflow_name = child_workflow_run["workflow_name"]
            input = child_workflow_run["input"]

            key = child_workflow_run.get("key")
            options = child_workflow_run.get("options", {})

            trigger_options = self._prepare_workflow_options(key, options, worker_id)

            bulk_trigger_workflow_runs.append(
                WorkflowRunDict(
                    workflow_name=workflow_name, input=input, options=trigger_options
                )
            )

        return self.admin_client.run_workflows(bulk_trigger_workflow_runs)
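For completeness, a hedged usage sketch of the singular `spawn_workflow` inside a step, reusing names from the fan-out example above (the exact shape of `options` is an assumption based on the bulk call):

```python
@hatchet.step()
def spawn_one(self, context: Context) -> dict[str, Any]:
    # Spawn a single child and block until it completes, mirroring the
    # bulk spawn_workflows() call in examples/fanout_sync/worker.py.
    child = context.spawn_workflow(
        "SyncFanoutChild",
        {"a": "0"},
        key="child-0",
        options={"additional_metadata": {"hello": "earth"}},
    )
    return {"result": child.sync_result()}
```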
11 changes: 11 additions & 0 deletions hatchet_sdk/loader.py
@@ -42,6 +42,7 @@ def __init__(
        worker_healthcheck_port: int | None = None,
        worker_healthcheck_enabled: bool | None = None,
        worker_preset_labels: dict[str, str] = {},
        enable_force_kill_sync_threads: bool = False,
    ):
        self.tenant_id = tenant_id
        self.tls_config = tls_config
@@ -55,6 +56,7 @@
        self.worker_healthcheck_port = worker_healthcheck_port
        self.worker_healthcheck_enabled = worker_healthcheck_enabled
        self.worker_preset_labels = worker_preset_labels
        self.enable_force_kill_sync_threads = enable_force_kill_sync_threads

        if not self.logInterceptor:
            self.logInterceptor = getLogger()
@@ -174,6 +176,14 @@ def get_config_value(key, env_var):
"The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
)

        enable_force_kill_sync_threads = bool(
            get_config_value(
                "enable_force_kill_sync_threads",
                "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
            )
            == "True"
            or False
        )

[Contributor Author, inline review comment]: Made this opt-in so we can document some potential risks

        return ClientConfig(
            tenant_id=tenant_id,
            tls_config=tls_config,
@@ -188,6 +198,7 @@
            worker_healthcheck_port=worker_healthcheck_port,
            worker_healthcheck_enabled=worker_healthcheck_enabled,
            worker_preset_labels=worker_preset_labels,
            enable_force_kill_sync_threads=enable_force_kill_sync_threads,
        )
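Because the loader compares the raw config value to the string `"True"`, the new flag only activates with that exact spelling; for example:

```python
import os

# Any other value ("true", "1", "") leaves enable_force_kill_sync_threads False.
os.environ["HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS"] = "True"
```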

    def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: