From 1fde9d6ad545f5ca8b3e7bc7fe69ce583e9f3e8a Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Tue, 10 Sep 2024 06:53:38 -0400 Subject: [PATCH 01/12] wip v2 experimentation --- hatchet_sdk/runtime/__init__.py | 0 hatchet_sdk/runtime/admin.py | 20 ++ hatchet_sdk/runtime/registry.py | 19 ++ hatchet_sdk/v2/callable.py | 438 ++++++++++++++++++---------- hatchet_sdk/v2/hatchet.py | 371 +++++++++++------------ hatchet_sdk/worker/runner/runner.py | 55 ++-- hatchet_sdk/worker/worker.py | 33 ++- pyproject.toml | 3 + tests/v2/__init__.py | 0 tests/v2/test_traces.py | 25 ++ 10 files changed, 560 insertions(+), 404 deletions(-) create mode 100644 hatchet_sdk/runtime/__init__.py create mode 100644 hatchet_sdk/runtime/admin.py create mode 100644 hatchet_sdk/runtime/registry.py create mode 100644 tests/v2/__init__.py create mode 100644 tests/v2/test_traces.py diff --git a/hatchet_sdk/runtime/__init__.py b/hatchet_sdk/runtime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hatchet_sdk/runtime/admin.py b/hatchet_sdk/runtime/admin.py new file mode 100644 index 00000000..80ab83ad --- /dev/null +++ b/hatchet_sdk/runtime/admin.py @@ -0,0 +1,20 @@ + +# import hatchet_sdk.v2.callable as sdk +# import hatchet_sdk.clients.admin as client + +# from hatchet_sdk.contracts.workflows_pb2 import ( +# CreateStepRateLimit, +# CreateWorkflowJobOpts, +# CreateWorkflowStepOpts, +# CreateWorkflowVersionOpts, +# DesiredWorkerLabels, +# StickyStrategy, +# WorkflowConcurrencyOpts, +# WorkflowKind, +# ) + +# async def put_workflow(callable: sdk.HatchetCallable, client: client.AdminClient): +# options = callable._.options + +# kind: WorkflowKind = WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION + diff --git a/hatchet_sdk/runtime/registry.py b/hatchet_sdk/runtime/registry.py new file mode 100644 index 00000000..5dce3a7a --- /dev/null +++ b/hatchet_sdk/runtime/registry.py @@ -0,0 +1,19 @@ +from typing import Dict, List + + +class ActionRegistry: + + _registry: Dict[str, "HatchetCallable"] = dict() + + def register(self, callable: "HatchetCallable") -> str: + key = "{namespace}:{name}".format( + namespace=callable._.namespace, name=callable._.name + ) + self._registry[key] = callable + return key + + def list(self) -> List[str]: + return list(self._registry.keys()) + + +global_registry = ActionRegistry() diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 0738c2f2..6e8bb5ac 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -1,7 +1,33 @@ +from __future__ import annotations + import asyncio -from typing import Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union +import inspect +import json +from collections.abc import Awaitable, Callable +from contextvars import ContextVar, copy_context +from dataclasses import dataclass +from datetime import timedelta +from typing import ( + Any, + Dict, + ForwardRef, + Generic, + List, + Literal, + Optional, + ParamSpec, + TypedDict, + TypeVar, + Union, +) + +from google.protobuf.json_format import MessageToDict +from pydantic import BaseModel, ConfigDict, Field, computed_field +from pydantic.json_schema import SkipJsonSchema +from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context +from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl from hatchet_sdk.contracts.workflows_pb2 import ( CreateStepRateLimit, CreateWorkflowJobOpts, @@ -15,153 +41,254 @@ from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.runtime import registry from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.workflow_run import RunRef +# from typing import TYPE_CHECKING + +# if TYPE_CHECKING: +# from hatchet_sdk.v2.hatchet import Hatchet + + T = TypeVar("T") +P = ParamSpec("P") + +# TODO: according to Python, we should just use strings. +Options = ForwardRef("Options", is_class=True) +CallableMetadata = ForwardRef("CallableMetadata", is_class=True) -class HatchetCallable(Generic[T]): +class HatchetCallableBase(Generic[P, T]): + action_name: str + func: Callable[P, T] # note that T can be an Awaitable if func is a coroutine + _: CallableMetadata + def __init__( - self, - func: Callable[[Context], T], - durable: bool = False, - name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - concurrency: ConcurrencyFunction | None = None, - on_failure: Optional["HatchetCallable"] = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - default_priority: int | None = None, + self, *, func: Callable[P, T], name: str, namespace: str, options: Options ): self.func = func + self._ = CallableMetadata( + name=name.lower() or str(func.__name__).lower(), + namespace=namespace, + options=options, + sourceloc=self.sourceloc, + ) + self.action_name = registry.global_registry.register(self) + + @property + def sourceloc(self) -> str: + try: + return "{}:{}".format( + inspect.getsourcefile(self.func), + inspect.getsourcelines(self.func)[1], + ) + except: + return "" + + # def __call__(self, context: Context) -> T: + # return self.func(context) + + # def with_namespace(self, namespace: str): + # if namespace is not None and namespace != "": + # self.function_namespace = namespace + # self.function_name = namespace + self.function_name + + def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: + options = self._.options + + # if self.function_on_failure is not None: + # on_failure_job = CreateWorkflowJobOpts( + # name=self.function_name + "-on-failure", + # steps=[ + # self.function_on_failure.to_step(), + # ], + # ) + # # concurrency: WorkflowConcurrencyOpts | None = None + # if self.function_concurrency is not None: + # self.function_concurrency.set_namespace(self.function_namespace) + # concurrency = WorkflowConcurrencyOpts( + # action=self.function_concurrency.get_action_name(), + # max_runs=self.function_concurrency.max_runs, + # limit_strategy=self.function_concurrency.limit_strategy, + # ) - on_events = on_events or [] - on_crons = on_crons or [] - - limits = None - if rate_limits: - limits = [ - CreateStepRateLimit(key=rate_limit.key, units=rate_limit.units) - for rate_limit in rate_limits or [] - ] - - self.function_desired_worker_labels = {} - - for key, d in desired_worker_labels.items(): - value = d["value"] if "value" in d else None - self.function_desired_worker_labels[key] = DesiredWorkerLabels( + workflow = CreateWorkflowVersionOpts( + name=self._.name, + kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION, + version=options.version, + event_triggers=options.on_events, + cron_triggers=options.on_crons, + schedule_timeout=options.schedule_timeout, + sticky=options.sticky, + on_failure_job=( + options.on_failure._to_job_proto() if options.on_failure else None + ), + concurrency=None, # TODO + jobs=[ + self._to_job_proto() + ], # Note that the failure job is also a HatchetCallable, and it should manage its own name. + default_priority=options.priority, + ) + return workflow + + def _to_job_proto(self) -> CreateWorkflowJobOpts: + job = CreateWorkflowJobOpts(name=self._.name, steps=[self._to_step_proto()]) + return job + + def _to_step_proto(self) -> CreateWorkflowStepOpts: + options = self._.options + step = CreateWorkflowStepOpts( + readable_id=self._.name, + action=self.action_name, + timeout=options.execution_timeout, + inputs="{}", # TODO: not sure that this is, we're defining a step, not running a step + parents=[], # this is a single step workflow, always empty + retries=options.retries, + rate_limits=options.ratelimits, + # worker_labels=self.function_desired_worker_labels, + ) + return step + + def _to_trigger_proto(self) -> Optional[TriggerWorkflowOptions]: + ctx = CallableContext.current() + if not ctx: + return None + trigger: TriggerWorkflowOptions = { + "parent_id": ctx.workflow_run_id, + "parent_step_run_id": ctx.step_run_id, + } + return trigger + + def _debug(self): + data = { + "action_name": self.action_name, + "func": repr(self.func), + "metadata": self._.model_dump(), + "def_proto": MessageToDict(self._to_workflow_proto()), + "call_proto": ( + MessageToDict(self._to_trigger_proto()) + if self._to_trigger_proto() + else None + ), + } + return data + + def _run(self, context: BaseContext): + raise NotImplementedError + + +class HatchetCallable(HatchetCallableBase[P, T]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + input = json.dumps({args: args, kwargs: kwargs}) + client = self._.options.hatchet + ref = client.admin.run( + self.action_name, input=input, options=self._to_trigger_proto() + ) + return asyncio.gather(ref.result()).result + + def _run(self, context: Context) -> T: + input = json.loads(context.workflow_input) + return self.func(*input.args, **input.kwargs) + + +class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + input = json.dumps({args: args, kwargs: kwargs}) + client = self._.options.hatchet + return (await client.admin.run(self.action_name, input)).result() + + async def _run(self, context: ContextAioImpl) -> T: + input = json.loads(context.workflow_input) + return await self.func(*input.args, **input.kwargs) + + +class Options(BaseModel): + # pydantic configuration + model_config = ConfigDict(arbitrary_types_allowed=True) + + hatchet: Any = Field( + default=None, exclude=True + ) # circular dependencies trying to import v2.hatchet.Hatchet + durable: bool = Field(default=False) + auto_register: bool = Field(default=True) + on_failure: Optional[HatchetCallableBase] = Field(default=None, exclude=True) + + # triggering options + on_events: List[str] = Field(default=[]) + on_crons: List[str] = Field(default=[]) + + # metadata + version: str = Field(default="") + + # timeout + execution_timeout: str = Field(default="60m", alias="timeout") + schedule_timeout: str = Field(default="5m") + + # execution + sticky: Optional[StickyStrategy] = Field(default=None) + retries: int = Field(default=0, ge=0) + ratelimits: List[RateLimit] = Field(default=[]) + priority: Optional[int] = Field(default=None, alias="default_priority", ge=1, le=3) + desired_worker_labels: Dict[str, DesiredWorkerLabel] = Field(default=dict()) + concurrency: Optional[ConcurrencyFunction] = Field(default=None) + + @computed_field + @property + def ratelimits_proto(self) -> List[CreateStepRateLimit]: + return [ + CreateStepRateLimit(key=limit.key, units=limit.units) + for limit in self.ratelimits + ] + + @computed_field + @property + def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]: + labels = dict() + for key, d in self.desired_worker_labels.items(): + value = d.get("value", None) + labels[key] = DesiredWorkerLabels( strValue=str(value) if not isinstance(value, int) else None, intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, - ) - self.sticky = sticky - self.default_priority = default_priority - self.durable = durable - self.function_name = name.lower() or str(func.__name__).lower() - self.function_version = version - self.function_on_events = on_events - self.function_on_crons = on_crons - self.function_timeout = timeout - self.function_schedule_timeout = schedule_timeout - self.function_retries = retries - self.function_rate_limits = limits - self.function_concurrency = concurrency - self.function_on_failure = on_failure - self.function_namespace = "default" - self.function_auto_register = auto_register - - self.is_coroutine = False - - if asyncio.iscoroutinefunction(func): - self.is_coroutine = True - - def __call__(self, context: Context) -> T: - return self.func(context) - - def with_namespace(self, namespace: str): - if namespace is not None and namespace != "": - self.function_namespace = namespace - self.function_name = namespace + self.function_name - - def to_workflow_opts(self) -> CreateWorkflowVersionOpts: - kind: WorkflowKind = WorkflowKind.FUNCTION - - if self.durable: - kind = WorkflowKind.DURABLE - - on_failure_job: CreateWorkflowJobOpts | None = None - - if self.function_on_failure is not None: - on_failure_job = CreateWorkflowJobOpts( - name=self.function_name + "-on-failure", - steps=[ - self.function_on_failure.to_step(), - ], + required=d.get("required", None), + weight=d.get("weight", None), + comparator=d.get("comparator", None), ) + return labels - concurrency: WorkflowConcurrencyOpts | None = None - if self.function_concurrency is not None: - self.function_concurrency.set_namespace(self.function_namespace) - concurrency = WorkflowConcurrencyOpts( - action=self.function_concurrency.get_action_name(), - max_runs=self.function_concurrency.max_runs, - limit_strategy=self.function_concurrency.limit_strategy, - ) +class CallableMetadata(BaseModel): + name: str + namespace: str + sourceloc: str # source location of the callable + options: Options - validated_priority = ( - max(1, min(3, self.default_priority)) if self.default_priority else None - ) - if validated_priority != self.default_priority: - logger.warning( - "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." - ) - return CreateWorkflowVersionOpts( - name=self.function_name, - kind=kind, - version=self.function_version, - event_triggers=self.function_on_events, - cron_triggers=self.function_on_crons, - schedule_timeout=self.function_schedule_timeout, - sticky=self.sticky, - on_failure_job=on_failure_job, - concurrency=concurrency, - jobs=[ - CreateWorkflowJobOpts( - name=self.function_name, - steps=[ - self.to_step(), - ], - ) - ], - default_priority=validated_priority, - ) +# Context variable used for propagating hatchet context. +# The type of the variable is CallableContext. +_callable_cv = ContextVar("hatchet.callable") - def to_step(self) -> CreateWorkflowStepOpts: - return CreateWorkflowStepOpts( - readable_id=self.function_name, - action=self.get_action_name(), - timeout=self.function_timeout, - inputs="{}", - parents=[], - retries=self.function_retries, - rate_limits=self.function_rate_limits, - worker_labels=self.function_desired_worker_labels, - ) - def get_action_name(self) -> str: - return self.function_namespace + ":" + self.function_name +# The context object to be propagated between parent/child workflows. +class CallableContext(BaseModel): + # pydantic configuration + model_config = ConfigDict(arbitrary_types_allowed=True) + + caller: Optional["HatchetCallable[P,T]"] = None + workflow_run_id: str # caller's workflow run id + step_run_id: str # caller's step run id + + @staticmethod + def cv() -> ContextVar: + return _callable_cv + + @staticmethod + def current() -> Optional["CallableContext"]: + try: + cv: ContextVar = CallableContext.cv() + return cv.get() + except LookupError: + return None T = TypeVar("T") @@ -173,30 +300,33 @@ class TriggerOptions(TypedDict): class DurableContext(Context): - def run( - self, - function: Union[str, HatchetCallable[T]], - input: dict = {}, - key: str = None, - options: TriggerOptions = None, - ) -> "RunRef[T]": - worker_id = self.worker.id() - - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name - - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) + pass + + +# def run( +# self, +# function: Union[str, HatchetCallable[T]], +# input: dict = {}, +# key: str = None, +# options: TriggerOptions = None, +# ) -> "RunRef[T]": +# worker_id = self.worker.id() + +# workflow_name = function + +# if not isinstance(function, str): +# workflow_name = function.function_name + +# # if ( +# # options is not None +# # and "sticky" in options +# # and options["sticky"] == True +# # and not self.worker.has_workflow(workflow_name) +# # ): +# # raise Exception( +# # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" +# # ) - trigger_options = self._prepare_workflow_options(key, options, worker_id) +# trigger_options = self._prepare_workflow_options(key, options, worker_id) - return self.admin_client.run(function, input, trigger_options) +# return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 9c866ba8..f55a3fdb 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,222 +1,179 @@ -from typing import Callable, List, Optional, TypeVar +import functools +import inspect +from typing import Callable, List, Optional, ParamSpec, TypeVar +import hatchet_sdk.hatchet as v1 +import hatchet_sdk.v2.callable as v2_callable from hatchet_sdk.context import Context from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy -from hatchet_sdk.hatchet import Hatchet as HatchetV1 -from hatchet_sdk.hatchet import workflow + +# import Hatchet as HatchetV1 +# from hatchet_sdk.hatchet import workflow from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker from ..worker import Worker -T = TypeVar("T") +# from hatchet_sdk.v2.concurrency import ConcurrencyFunction +# from hatchet_sdk.worker.worker import register_on_worker -def function( - name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Optional["HatchetCallable"] = None, - default_priority: int | None = None, -): - def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: - return HatchetCallable( - func=func, - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - return inner - - -def durable( - name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: HatchetCallable | None = None, - default_priority: int | None = None, -): - def inner(func: HatchetCallable) -> HatchetCallable: - func.durable = True - - f = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - resp = f(func) - - resp.durable = True - - return resp - - return inner - - -def concurrency( - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -): - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - return inner - - -class Hatchet(HatchetV1): - dag = staticmethod(workflow) - concurrency = staticmethod(concurrency) - - functions: List[HatchetCallable] = [] +T = TypeVar("T") +P = ParamSpec("P") + + +# def durable( +# name: str = "", +# auto_register: bool = True, +# on_events: list | None = None, +# on_crons: list | None = None, +# version: str = "", +# timeout: str = "60m", +# schedule_timeout: str = "5m", +# sticky: StickyStrategy = None, +# retries: int = 0, +# rate_limits: List[RateLimit] | None = None, +# desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, +# concurrency: v2.concurrency.ConcurrencyFunction | None = None, +# on_failure: v2.callable.HatchetCallable | None = None, +# default_priority: int | None = None, +# ): +# def inner(func: v2.callable.HatchetCallable) -> v2.callable.HatchetCallable: +# func.durable = True + +# f = function( +# name=name, +# auto_register=auto_register, +# on_events=on_events, +# on_crons=on_crons, +# version=version, +# timeout=timeout, +# schedule_timeout=schedule_timeout, +# sticky=sticky, +# retries=retries, +# rate_limits=rate_limits, +# desired_worker_labels=desired_worker_labels, +# concurrency=concurrency, +# on_failure=on_failure, +# default_priority=default_priority, +# ) + +# resp = f(func) + +# resp.durable = True + +# return resp + +# return inner + + +# def concurrency( +# name: str = "concurrency", +# max_runs: int = 1, +# limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, +# ): +# def inner(func: Callable[[Context], str]) -> v2.concurrency.ConcurrencyFunction: +# return v2.concurrency.ConcurrencyFunction(func, name, max_runs, limit_strategy) + +# return inner + + +class Hatchet(v1.Hatchet): + dag = staticmethod(v1.workflow) + # concurrency = staticmethod(concurrency) + + functions: List[v2_callable.HatchetCallable] = [] def function( self, name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Optional["HatchetCallable"] = None, - default_priority: int | None = None, - ): - resp = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: Callable[[Context], T]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def durable( - self, - name: str = "", - auto_register: bool = True, - on_events: list | None = None, - on_crons: list | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Optional["HatchetCallable"] = None, - default_priority: int | None = None, - ) -> Callable[[HatchetCallable], HatchetCallable]: - resp = durable( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: Callable[[Context], T]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def worker( - self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + namespace: str = "default", + options: v2_callable.Options = v2_callable.Options(), ): - worker = Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - ) - - for func in self.functions: - register_on_worker(func, worker) - - return worker + options.hatchet = self + + def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: + if inspect.iscoroutine(func): + callable = v2_callable.HatchetAwaitable( + func=func, + name=name, + namespace=namespace, + options=options, + ) + return functools.update_wrapper(callable, func) + elif inspect.isfunction(func): + callable = v2_callable.HatchetCallable( + func=func, + name=name, + namespace=namespace, + options=options, + ) + return functools.update_wrapper(callable, func) + else: + raise TypeError( + "the @function decorator can only be applied to functions (def) and async functions (async def)" + ) + + return inner + + # def durable( + # self, + # name: str = "", + # auto_register: bool = True, + # on_events: list | None = None, + # on_crons: list | None = None, + # version: str = "", + # timeout: str = "60m", + # schedule_timeout: str = "5m", + # sticky: StickyStrategy = None, + # retries: int = 0, + # rate_limits: List[RateLimit] | None = None, + # desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, + # concurrency: v2.concurrency.ConcurrencyFunction | None = None, + # on_failure: Optional["HatchetCallable"] = None, + # default_priority: int | None = None, + # ) -> Callable[[v2.callable.HatchetCallable], v2.callable.HatchetCallable]: + # resp = durable( + # name=name, + # auto_register=auto_register, + # on_events=on_events, + # on_crons=on_crons, + # version=version, + # timeout=timeout, + # schedule_timeout=schedule_timeout, + # sticky=sticky, + # retries=retries, + # rate_limits=rate_limits, + # desired_worker_labels=desired_worker_labels, + # concurrency=concurrency, + # on_failure=on_failure, + # default_priority=default_priority, + # ) + + # def wrapper(func: Callable[[Context], T]) -> v2.callable.HatchetCallable[T]: + # wrapped_resp = resp(func) + + # if wrapped_resp.function_auto_register: + # self.functions.append(wrapped_resp) + + # wrapped_resp.with_namespace(self._client.config.namespace) + + # return wrapped_resp + + # return wrapper + + # def worker( + # self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + # ): + # worker = Worker( + # name=name, + # max_runs=max_runs, + # labels=labels, + # config=self._client.config, + # debug=self._client.debug, + # ) + + # for func in self.functions: + # register_on_worker(func, worker) + + # return worker diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index d37da955..c87de198 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -33,7 +33,8 @@ ) from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger -from hatchet_sdk.v2.callable import DurableContext + +# from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.worker.action_listener_process import ActionEvent wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( @@ -326,32 +327,32 @@ async def handle_start_step_run(self, action: Action): # Find the corresponding action function from the registry action_func = self.action_registry.get(action_name) - context: Context | DurableContext - - if hasattr(action_func, "durable") and action_func.durable: - context = DurableContext( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - ) - else: - context = Context( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - ) + context: Context #| DurableContext + + # if hasattr(action_func, "durable") and action_func.durable: + # context = DurableContext( + # action, + # self.dispatcher_client, + # self.admin_client, + # self.client.event, + # self.client.rest, + # self.client.workflow_listener, + # self.workflow_run_event_listener, + # self.worker_context, + # self.client.config.namespace, + # ) + # else: + context = Context( + action, + self.dispatcher_client, + self.admin_client, + self.client.event, + self.client.rest, + self.client.workflow_listener, + self.workflow_run_event_listener, + self.worker_context, + self.client.config.namespace, + ) self.contexts[action.step_run_id] = context diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 315f2f4a..b426d394 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -13,7 +13,8 @@ from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger -from hatchet_sdk.v2.callable import HatchetCallable + +# from hatchet_sdk.v2.callable import HatchetCallable from hatchet_sdk.worker.action_listener_process import worker_action_listener_process from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager from hatchet_sdk.workflow import WorkflowMeta @@ -62,8 +63,8 @@ def __post_init__(self): self.name = self.client.config.namespace + self.name self._setup_signal_handlers() - def register_function(self, action: str, func: HatchetCallable): - self.action_registry[action] = func + # def register_function(self, action: str, func: HatchetCallable): + # self.action_registry[action] = func def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts): try: @@ -285,20 +286,20 @@ def exit_forcefully(self): ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup -def register_on_worker(callable: HatchetCallable, worker: Worker): - worker.register_function(callable.get_action_name(), callable) +# def register_on_worker(callable: HatchetCallable, worker: Worker): +# worker.register_function(callable.get_action_name(), callable) - if callable.function_on_failure is not None: - worker.register_function( - callable.function_on_failure.get_action_name(), callable.function_on_failure - ) +# if callable.function_on_failure is not None: +# worker.register_function( +# callable.function_on_failure.get_action_name(), callable.function_on_failure +# ) - if callable.function_concurrency is not None: - worker.register_function( - callable.function_concurrency.get_action_name(), - callable.function_concurrency, - ) +# if callable.function_concurrency is not None: +# worker.register_function( +# callable.function_concurrency.get_action_name(), +# callable.function_concurrency, +# ) - opts = callable.to_workflow_opts() +# opts = callable.to_workflow_opts() - worker.register_workflow_from_opts(opts.name, opts) +# worker.register_workflow_from_opts(opts.name, opts) diff --git a/pyproject.toml b/pyproject.toml index 4e94f09c..eb96e1bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ known_third_party = [ "pyyaml", "urllib3", ] +skip = [ + "hatchet_sdk/contracts", +] [tool.poetry.scripts] api = "examples.api.api:main" diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v2/test_traces.py b/tests/v2/test_traces.py new file mode 100644 index 00000000..523ea677 --- /dev/null +++ b/tests/v2/test_traces.py @@ -0,0 +1,25 @@ +def get_client(): + import dotenv + + from hatchet_sdk.v2.hatchet import Hatchet + + dotenv.load_dotenv() + return Hatchet(debug=True) + + +hatchet = get_client() + + +@hatchet.function() +async def foo(a: int): + return bar(b=3) + + +@hatchet.function() +def bar(b: int): + return b + + +def test_trace(): + import json + print(json.dumps(foo._debug(), indent=2)) From f8ca09f5ef3a2cc2205d6ada24bca7cb72330b13 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Tue, 10 Sep 2024 22:34:46 -0400 Subject: [PATCH 02/12] hooking it up to the runner, pening more investigations of the triggering part --- hatchet_sdk/runtime/registry.py | 18 +++++----- hatchet_sdk/v2/callable.py | 16 +++++---- hatchet_sdk/v2/hatchet.py | 52 +++++++++++++++++------------ hatchet_sdk/worker/runner/runner.py | 3 ++ hatchet_sdk/worker/worker.py | 32 +++++++++--------- tests/v2/test_simple.py | 40 ++++++++++++++++++++++ tests/v2/test_traces.py | 25 -------------- 7 files changed, 108 insertions(+), 78 deletions(-) create mode 100644 tests/v2/test_simple.py delete mode 100644 tests/v2/test_traces.py diff --git a/hatchet_sdk/runtime/registry.py b/hatchet_sdk/runtime/registry.py index 5dce3a7a..a79e80ec 100644 --- a/hatchet_sdk/runtime/registry.py +++ b/hatchet_sdk/runtime/registry.py @@ -1,19 +1,19 @@ from typing import Dict, List +import hatchet_sdk.v2.callable as v2 + class ActionRegistry: + """A registry from action names (e.g. 'namespace:func') to Hatchet's callables. + + This is intended to be used per Hatchet client instance. + """ - _registry: Dict[str, "HatchetCallable"] = dict() + registry: Dict[str, v2.HatchetCallableBase] = dict() - def register(self, callable: "HatchetCallable") -> str: + def register(self, callable: v2.HatchetCallableBase) -> str: key = "{namespace}:{name}".format( namespace=callable._.namespace, name=callable._.name ) - self._registry[key] = callable + self.registry[key] = callable return key - - def list(self) -> List[str]: - return list(self._registry.keys()) - - -global_registry = ActionRegistry() diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 6e8bb5ac..8cfed815 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -74,7 +74,6 @@ def __init__( options=options, sourceloc=self.sourceloc, ) - self.action_name = registry.global_registry.register(self) @property def sourceloc(self) -> str: @@ -151,6 +150,7 @@ def _to_step_proto(self) -> CreateWorkflowStepOpts: return step def _to_trigger_proto(self) -> Optional[TriggerWorkflowOptions]: + return None ctx = CallableContext.current() if not ctx: return None @@ -180,25 +180,29 @@ def _run(self, context: BaseContext): class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - input = json.dumps({args: args, kwargs: kwargs}) + print(f"trigering {self.action_name}") + input = json.dumps({"args": args, "kwargs": kwargs}) client = self._.options.hatchet ref = client.admin.run( - self.action_name, input=input, options=self._to_trigger_proto() + self._.name, input=input, options=self._to_trigger_proto() ) - return asyncio.gather(ref.result()).result + return asyncio.run(ref.result()) def _run(self, context: Context) -> T: + print(f"running {self.action_name}") input = json.loads(context.workflow_input) return self.func(*input.args, **input.kwargs) class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - input = json.dumps({args: args, kwargs: kwargs}) + print(f"trigering {self.action_name}") + input = json.dumps({"args": args, "kwargs": kwargs}) client = self._.options.hatchet - return (await client.admin.run(self.action_name, input)).result() + return await client.admin.run(self._.name, input).result() async def _run(self, context: ContextAioImpl) -> T: + print(f"trigering {self.action_name}") input = json.loads(context.workflow_input) return await self.func(*input.args, **input.kwargs) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index f55a3fdb..5f5777b2 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,8 +1,9 @@ import functools import inspect -from typing import Callable, List, Optional, ParamSpec, TypeVar +from typing import Callable, List, Optional, ParamSpec, TypeVar, Dict import hatchet_sdk.hatchet as v1 +import hatchet_sdk.runtime.registry as hatchet_registry import hatchet_sdk.v2.callable as v2_callable from hatchet_sdk.context import Context from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy @@ -15,7 +16,7 @@ from ..worker import Worker # from hatchet_sdk.v2.concurrency import ConcurrencyFunction -# from hatchet_sdk.worker.worker import register_on_worker +from hatchet_sdk.worker.worker import register_on_worker T = TypeVar("T") @@ -79,10 +80,10 @@ class Hatchet(v1.Hatchet): - dag = staticmethod(v1.workflow) + # dag = staticmethod(v1.workflow) # concurrency = staticmethod(concurrency) - functions: List[v2_callable.HatchetCallable] = [] + _registry: hatchet_registry.ActionRegistry = hatchet_registry.ActionRegistry() def function( self, @@ -93,14 +94,16 @@ def function( options.hatchet = self def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: - if inspect.iscoroutine(func): + if inspect.iscoroutinefunction(func): callable = v2_callable.HatchetAwaitable( func=func, name=name, namespace=namespace, options=options, ) - return functools.update_wrapper(callable, func) + callable = functools.update_wrapper(callable, func) + callable.action_name = self._registry.register(callable) + return callable elif inspect.isfunction(func): callable = v2_callable.HatchetCallable( func=func, @@ -108,7 +111,9 @@ def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: namespace=namespace, options=options, ) - return functools.update_wrapper(callable, func) + callable = functools.update_wrapper(callable, func) + callable.action_name = self._registry.register(callable) + return callable else: raise TypeError( "the @function decorator can only be applied to functions (def) and async functions (async def)" @@ -162,18 +167,21 @@ def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: # return wrapper - # def worker( - # self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - # ): - # worker = Worker( - # name=name, - # max_runs=max_runs, - # labels=labels, - # config=self._client.config, - # debug=self._client.debug, - # ) - - # for func in self.functions: - # register_on_worker(func, worker) - - # return worker + def worker( + self, + name: str, + max_runs: Optional[int] = None, + labels: Dict[str, str | int] = {}, + ): + worker = Worker( + name=name, + max_runs=max_runs, + labels=labels, + config=self._client.config, + debug=self._client.debug, + ) + + for func in self._registry.registry.values(): + register_on_worker(func, worker) + + return worker diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index c87de198..54294837 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -279,6 +279,9 @@ async def async_wrapped_action_func( wr.set(context.workflow_run_id()) sr.set(context.step_run_id) + if hasattr(action_func, "_run"): + action_func = functools.partial(action_func._run, action_func) + try: if ( hasattr(action_func, "is_coroutine") and action_func.is_coroutine diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index b426d394..0160dcda 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -63,9 +63,10 @@ def __post_init__(self): self.name = self.client.config.namespace + self.name self._setup_signal_handlers() - # def register_function(self, action: str, func: HatchetCallable): - # self.action_registry[action] = func + def register_function(self, action: str, func): + self.action_registry[action] = func + # TODO: why do it on the worker, it seems unrelated. we should do that on the registry def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts): try: self.client.admin.put_workflow(opts.name, opts) @@ -286,20 +287,19 @@ def exit_forcefully(self): ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup -# def register_on_worker(callable: HatchetCallable, worker: Worker): -# worker.register_function(callable.get_action_name(), callable) +def register_on_worker(callable, worker: Worker): + worker.register_function(callable.action_name, callable) -# if callable.function_on_failure is not None: -# worker.register_function( -# callable.function_on_failure.get_action_name(), callable.function_on_failure -# ) + # if callable.function_on_failure is not None: + # worker.register_function( + # callable.function_on_failure.action_name, callable.function_on_failure + # ) -# if callable.function_concurrency is not None: -# worker.register_function( -# callable.function_concurrency.get_action_name(), -# callable.function_concurrency, -# ) + # if callable.function_concurrency is not None: + # worker.register_function( + # callable.function_concurrency.action_name, + # callable.function_concurrency, + # ) -# opts = callable.to_workflow_opts() - -# worker.register_workflow_from_opts(opts.name, opts) + opts = callable._to_workflow_proto() + worker.register_workflow_from_opts(opts.name, opts) diff --git a/tests/v2/test_simple.py b/tests/v2/test_simple.py new file mode 100644 index 00000000..6fb1df4a --- /dev/null +++ b/tests/v2/test_simple.py @@ -0,0 +1,40 @@ +import asyncio +import pytest + + +def get_client(): + import dotenv + + from hatchet_sdk.v2.hatchet import Hatchet + + dotenv.load_dotenv() + return Hatchet(debug=True) + + +hatchet = get_client() + + +@hatchet.function() +async def foo(a: int): + print(f"in foo: a={a}") + return bar(b=3) + + +@hatchet.function() +def bar(b: int): + print(f"in bar: b={b}") + return b + + +# def test_trace(): +# import json + +# print(json.dumps(foo._debug(), indent=2)) + + +@pytest.mark.asyncio(scope="session") +async def test_run(): + worker = hatchet.worker("worker", max_runs=5) + c = foo(a=1) + worker.start() + print(await c) diff --git a/tests/v2/test_traces.py b/tests/v2/test_traces.py deleted file mode 100644 index 523ea677..00000000 --- a/tests/v2/test_traces.py +++ /dev/null @@ -1,25 +0,0 @@ -def get_client(): - import dotenv - - from hatchet_sdk.v2.hatchet import Hatchet - - dotenv.load_dotenv() - return Hatchet(debug=True) - - -hatchet = get_client() - - -@hatchet.function() -async def foo(a: int): - return bar(b=3) - - -@hatchet.function() -def bar(b: int): - return b - - -def test_trace(): - import json - print(json.dumps(foo._debug(), indent=2)) From 5491960581809a2f3d1ebaa2f9af17ec09c74929 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Thu, 12 Sep 2024 09:01:37 -0400 Subject: [PATCH 03/12] more draft work on the runtime --- hatchet_sdk/clients/admin.py | 21 +- hatchet_sdk/runtime/__init__.py | 0 hatchet_sdk/runtime/admin.py | 20 -- hatchet_sdk/runtime/registry.py | 19 -- hatchet_sdk/v2/callable.py | 209 ++++++++++---------- hatchet_sdk/v2/hatchet.py | 295 +++++++++++++++++------------ hatchet_sdk/v2/runtime/config.py | 1 + hatchet_sdk/v2/runtime/context.py | 114 +++++++++++ hatchet_sdk/v2/runtime/logging.py | 30 +++ hatchet_sdk/v2/runtime/registry.py | 31 +++ hatchet_sdk/v2/runtime/runner.py | 7 + hatchet_sdk/v2/runtime/worker.py | 195 +++++++++++++++++++ tests/v2/test_simple.py | 1 + tests/v2/test_worker.py | 28 +++ 14 files changed, 704 insertions(+), 267 deletions(-) delete mode 100644 hatchet_sdk/runtime/__init__.py delete mode 100644 hatchet_sdk/runtime/admin.py delete mode 100644 hatchet_sdk/runtime/registry.py create mode 100644 hatchet_sdk/v2/runtime/config.py create mode 100644 hatchet_sdk/v2/runtime/context.py create mode 100644 hatchet_sdk/v2/runtime/logging.py create mode 100644 hatchet_sdk/v2/runtime/registry.py create mode 100644 hatchet_sdk/v2/runtime/runner.py create mode 100644 hatchet_sdk/v2/runtime/worker.py create mode 100644 tests/v2/test_worker.py diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 35d1715b..afabe047 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -339,13 +339,9 @@ def run_workflow( if self.namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{self.namespace}{workflow_name}" - request = self._prepare_workflow_request(workflow_name, input, options) - resp: TriggerWorkflowResponse = self.client.TriggerWorkflow( - request, - metadata=get_metadata(self.token), - ) + id = self.trigger_workflow(workflow_name, input, options) return WorkflowRunRef( - workflow_run_id=resp.workflow_run_id, + workflow_run_id=id, workflow_listener=self.pooled_workflow_listener, workflow_run_event_listener=self.listener_client, ) @@ -355,6 +351,19 @@ def run_workflow( raise ValueError(f"gRPC error: {e}") + def trigger_workflow( + self, + workflow_name: str, + input, + options: TriggerWorkflowOptions = None, + ) -> str: + request = self._prepare_workflow_request(workflow_name, input, options) + resp: TriggerWorkflowResponse = self.client.TriggerWorkflow( + request, + metadata=get_metadata(self.token), + ) + return resp.workflow_run_id + def run( self, function: Union[str, Callable[[Any], T]], diff --git a/hatchet_sdk/runtime/__init__.py b/hatchet_sdk/runtime/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hatchet_sdk/runtime/admin.py b/hatchet_sdk/runtime/admin.py deleted file mode 100644 index 80ab83ad..00000000 --- a/hatchet_sdk/runtime/admin.py +++ /dev/null @@ -1,20 +0,0 @@ - -# import hatchet_sdk.v2.callable as sdk -# import hatchet_sdk.clients.admin as client - -# from hatchet_sdk.contracts.workflows_pb2 import ( -# CreateStepRateLimit, -# CreateWorkflowJobOpts, -# CreateWorkflowStepOpts, -# CreateWorkflowVersionOpts, -# DesiredWorkerLabels, -# StickyStrategy, -# WorkflowConcurrencyOpts, -# WorkflowKind, -# ) - -# async def put_workflow(callable: sdk.HatchetCallable, client: client.AdminClient): -# options = callable._.options - -# kind: WorkflowKind = WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION - diff --git a/hatchet_sdk/runtime/registry.py b/hatchet_sdk/runtime/registry.py deleted file mode 100644 index a79e80ec..00000000 --- a/hatchet_sdk/runtime/registry.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Dict, List - -import hatchet_sdk.v2.callable as v2 - - -class ActionRegistry: - """A registry from action names (e.g. 'namespace:func') to Hatchet's callables. - - This is intended to be used per Hatchet client instance. - """ - - registry: Dict[str, v2.HatchetCallableBase] = dict() - - def register(self, callable: v2.HatchetCallableBase) -> str: - key = "{namespace}:{name}".format( - namespace=callable._.namespace, name=callable._.name - ) - self.registry[key] = callable - return key diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 8cfed815..0565781e 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -4,9 +4,11 @@ import inspect import json from collections.abc import Awaitable, Callable -from contextvars import ContextVar, copy_context + +# from contextvars import ContextVar, copy_context from dataclasses import dataclass -from datetime import timedelta + +# from datetime import timedelta from typing import ( Any, Dict, @@ -25,6 +27,7 @@ from pydantic import BaseModel, ConfigDict, Field, computed_field from pydantic.json_schema import SkipJsonSchema +import hatchet_sdk.v2.hatchet as v2hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl @@ -41,8 +44,8 @@ from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.runtime import registry from hatchet_sdk.v2.concurrency import ConcurrencyFunction +from hatchet_sdk.v2.runtime import registry from hatchet_sdk.workflow_run import RunRef # from typing import TYPE_CHECKING @@ -54,47 +57,43 @@ T = TypeVar("T") P = ParamSpec("P") -# TODO: according to Python, we should just use strings. -Options = ForwardRef("Options", is_class=True) -CallableMetadata = ForwardRef("CallableMetadata", is_class=True) + +def _sourceloc(fn) -> str: + try: + return "{}:{}".format( + inspect.getsourcefile(fn), + inspect.getsourcelines(fn)[1], + ) + except: + return "" class HatchetCallableBase(Generic[P, T]): - action_name: str - func: Callable[P, T] # note that T can be an Awaitable if func is a coroutine - _: CallableMetadata def __init__( - self, *, func: Callable[P, T], name: str, namespace: str, options: Options + self, + *, + func: Callable[P, T], + name: str, + namespace: str, + client: v2hatchet.Hatchet, + options: Options, ): - self.func = func - self._ = CallableMetadata( + + self._hatchet = CallableMetadata( + # TODO: maybe use __qualname__ name=name.lower() or str(func.__name__).lower(), namespace=namespace, + sourceloc=_sourceloc(func), options=options, - sourceloc=self.sourceloc, + client=client, + func=func, + action=f"{namespace}:{name}", ) - - @property - def sourceloc(self) -> str: - try: - return "{}:{}".format( - inspect.getsourcefile(self.func), - inspect.getsourcelines(self.func)[1], - ) - except: - return "" - - # def __call__(self, context: Context) -> T: - # return self.func(context) - - # def with_namespace(self, namespace: str): - # if namespace is not None and namespace != "": - # self.function_namespace = namespace - # self.function_name = namespace + self.function_name + client.registry.add(key=self._hatchet.action, callable=self) def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: - options = self._.options + options = self._hatchet.options # if self.function_on_failure is not None: # on_failure_job = CreateWorkflowJobOpts( @@ -113,7 +112,7 @@ def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: # ) workflow = CreateWorkflowVersionOpts( - name=self._.name, + name=self._hatchet.name, kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION, version=options.version, event_triggers=options.on_events, @@ -132,14 +131,16 @@ def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: return workflow def _to_job_proto(self) -> CreateWorkflowJobOpts: - job = CreateWorkflowJobOpts(name=self._.name, steps=[self._to_step_proto()]) + job = CreateWorkflowJobOpts( + name=self._hatchet.name, steps=[self._to_step_proto()] + ) return job def _to_step_proto(self) -> CreateWorkflowStepOpts: - options = self._.options + options = self._hatchet.options step = CreateWorkflowStepOpts( - readable_id=self._.name, - action=self.action_name, + readable_id=self._hatchet.name, + action=self._hatchet.action, timeout=options.execution_timeout, inputs="{}", # TODO: not sure that this is, we're defining a step, not running a step parents=[], # this is a single step workflow, always empty @@ -162,9 +163,8 @@ def _to_trigger_proto(self) -> Optional[TriggerWorkflowOptions]: def _debug(self): data = { - "action_name": self.action_name, - "func": repr(self.func), - "metadata": self._.model_dump(), + "self": repr(self), + "metadata": self._hatchet._debug(), "def_proto": MessageToDict(self._to_workflow_proto()), "call_proto": ( MessageToDict(self._to_trigger_proto()) @@ -180,13 +180,14 @@ def _run(self, context: BaseContext): class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - print(f"trigering {self.action_name}") + self._hatchet.client.logger.info(f"triggering {self._hatchet.action}") input = json.dumps({"args": args, "kwargs": kwargs}) - client = self._.options.hatchet - ref = client.admin.run( - self._.name, input=input, options=self._to_trigger_proto() + client = self._hatchet.client + ref = client.admin.trigger_workflow( + self._hatchet.name, input=input, options=self._to_trigger_proto() ) - return asyncio.run(ref.result()) + self._hatchet.client.logger.info(f"runid: {ref}") + return None def _run(self, context: Context) -> T: print(f"running {self.action_name}") @@ -211,9 +212,6 @@ class Options(BaseModel): # pydantic configuration model_config = ConfigDict(arbitrary_types_allowed=True) - hatchet: Any = Field( - default=None, exclude=True - ) # circular dependencies trying to import v2.hatchet.Hatchet durable: bool = Field(default=False) auto_register: bool = Field(default=True) on_failure: Optional[HatchetCallableBase] = Field(default=None, exclude=True) @@ -261,76 +259,93 @@ def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]: return labels -class CallableMetadata(BaseModel): +@dataclass +class CallableMetadata: + func: Callable[P, T] # the original function + name: str namespace: str + action: str sourceloc: str # source location of the callable + options: Options + client: v2hatchet.Hatchet + + def _debug(self): + return { + "func": repr(self.func), + "name": self.name, + "namespace": self.namespace, + "action": self.action, + "sourceloc": self.sourceloc, + "client": repr(self.client), + "options": self.options.model_dump(), + } -# Context variable used for propagating hatchet context. -# The type of the variable is CallableContext. -_callable_cv = ContextVar("hatchet.callable") +# # Context variable used for propagating hatchet context. +# # The type of the variable is CallableContext. +# _callable_cv = ContextVar("hatchet.callable") -# The context object to be propagated between parent/child workflows. -class CallableContext(BaseModel): - # pydantic configuration - model_config = ConfigDict(arbitrary_types_allowed=True) +# # The context object to be propagated between parent/child workflows. +# class CallableContext(BaseModel): +# # pydantic configuration +# model_config = ConfigDict(arbitrary_types_allowed=True) - caller: Optional["HatchetCallable[P,T]"] = None - workflow_run_id: str # caller's workflow run id - step_run_id: str # caller's step run id +# caller: Optional["HatchetCallable[P,T]"] = None +# workflow_run_id: str # caller's workflow run id +# step_run_id: str # caller's step run id - @staticmethod - def cv() -> ContextVar: - return _callable_cv +# @staticmethod +# def cv() -> ContextVar: +# return _callable_cv - @staticmethod - def current() -> Optional["CallableContext"]: - try: - cv: ContextVar = CallableContext.cv() - return cv.get() - except LookupError: - return None +# @staticmethod +# def current() -> Optional["CallableContext"]: +# try: +# cv: ContextVar = CallableContext.cv() +# return cv.get() +# except LookupError: +# return None -T = TypeVar("T") +# T = TypeVar("T") -class TriggerOptions(TypedDict): - additional_metadata: Dict[str, str] | None = None - sticky: bool | None = None +# class TriggerOptions(TypedDict): +# additional_metadata: Dict[str, str] | None = None +# sticky: bool | None = None -class DurableContext(Context): - pass +# class DurableContext(Context): +# pass -# def run( -# self, -# function: Union[str, HatchetCallable[T]], -# input: dict = {}, -# key: str = None, -# options: TriggerOptions = None, -# ) -> "RunRef[T]": -# worker_id = self.worker.id() +# # def run( +# # self, +# # function: Union[str, HatchetCallable[T]], +# # input: dict = {}, +# # key: str = None, +# # options: TriggerOptions = None, +# # ) -> "RunRef[T]": +# # worker_id = self.worker.id() -# workflow_name = function +# # workflow_name = function -# if not isinstance(function, str): -# workflow_name = function.function_name +# # if not isinstance(function, str): +# # workflow_name = function.function_name -# # if ( -# # options is not None -# # and "sticky" in options -# # and options["sticky"] == True -# # and not self.worker.has_workflow(workflow_name) -# # ): -# # raise Exception( -# # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" -# # ) +# # # if ( +# # # options is not None +# # # and "sticky" in options +# # # and options["sticky"] == True +# # # and not self.worker.has_workflow(workflow_name) +# # # ): +# # # raise Exception( +# # # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" +# # # ) -# trigger_options = self._prepare_workflow_options(key, options, worker_id) +# # trigger_options = self._prepare_workflow_options(key, options, worker_id) -# return self.admin_client.run(function, input, trigger_options) +# # return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 5f5777b2..d21ba3a8 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,119 +1,99 @@ import functools import inspect -from typing import Callable, List, Optional, ParamSpec, TypeVar, Dict +from typing import Callable, Dict, List, Optional, ParamSpec, TypeVar import hatchet_sdk.hatchet as v1 -import hatchet_sdk.runtime.registry as hatchet_registry -import hatchet_sdk.v2.callable as v2_callable -from hatchet_sdk.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy +import hatchet_sdk.v2.runtime.config as config +import hatchet_sdk.v2.runtime.logging as logging +import hatchet_sdk.v2.runtime.registry as registry +import hatchet_sdk.v2.runtime.worker as worker +import hatchet_sdk.v2.callable as callable +import asyncio + + + +# import hatchet_sdk.runtime.registry as hatchet_registry +# import hatchet_sdk.v2.callable as v2_callable +# from hatchet_sdk.context import Context +# from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy # import Hatchet as HatchetV1 # from hatchet_sdk.hatchet import workflow -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.rate_limit import RateLimit - -from ..worker import Worker +# from hatchet_sdk.labels import DesiredWorkerLabel +# from hatchet_sdk.rate_limit import RateLimit # from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker +# from hatchet_sdk.worker.worker import register_on_worker +# from ..worker import Worker T = TypeVar("T") P = ParamSpec("P") -# def durable( -# name: str = "", -# auto_register: bool = True, -# on_events: list | None = None, -# on_crons: list | None = None, -# version: str = "", -# timeout: str = "60m", -# schedule_timeout: str = "5m", -# sticky: StickyStrategy = None, -# retries: int = 0, -# rate_limits: List[RateLimit] | None = None, -# desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, -# concurrency: v2.concurrency.ConcurrencyFunction | None = None, -# on_failure: v2.callable.HatchetCallable | None = None, -# default_priority: int | None = None, -# ): -# def inner(func: v2.callable.HatchetCallable) -> v2.callable.HatchetCallable: -# func.durable = True - -# f = function( -# name=name, -# auto_register=auto_register, -# on_events=on_events, -# on_crons=on_crons, -# version=version, -# timeout=timeout, -# schedule_timeout=schedule_timeout, -# sticky=sticky, -# retries=retries, -# rate_limits=rate_limits, -# desired_worker_labels=desired_worker_labels, -# concurrency=concurrency, -# on_failure=on_failure, -# default_priority=default_priority, -# ) - -# resp = f(func) - -# resp.durable = True - -# return resp - -# return inner +class Hatchet: + def __init__( + self, + config: config.ClientConfig = config.ClientConfig(), + debug=False, + ): + # ensure a event loop is created before gRPC + try: + asyncio.get_event_loop() + finally: + pass + + self.registry = registry.ActionRegistry() + self.v1: v1.Hatchet = v1.Hatchet.from_environment( + defaults=config, + debug=debug, + ) -# def concurrency( -# name: str = "concurrency", -# max_runs: int = 1, -# limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -# ): -# def inner(func: Callable[[Context], str]) -> v2.concurrency.ConcurrencyFunction: -# return v2.concurrency.ConcurrencyFunction(func, name, max_runs, limit_strategy) - -# return inner + @property + def admin(self): + return self.v1.admin + @property + def dispatcher(self): + return self.v1.dispatcher -class Hatchet(v1.Hatchet): - # dag = staticmethod(v1.workflow) - # concurrency = staticmethod(concurrency) + @property + def config(self): + return self.v1.config - _registry: hatchet_registry.ActionRegistry = hatchet_registry.ActionRegistry() + @property + def logger(self): + return logging.logger def function( self, name: str = "", namespace: str = "default", - options: v2_callable.Options = v2_callable.Options(), + options: "callable.Options" = callable.Options(), ): - options.hatchet = self - def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: + def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": if inspect.iscoroutinefunction(func): - callable = v2_callable.HatchetAwaitable( + wrapped = callable.HatchetAwaitable( func=func, name=name, namespace=namespace, + client=self, options=options, ) - callable = functools.update_wrapper(callable, func) - callable.action_name = self._registry.register(callable) - return callable + wrapped = functools.update_wrapper(wrapped, func) + return wrapped elif inspect.isfunction(func): - callable = v2_callable.HatchetCallable( + wrapped = callable.HatchetCallable( func=func, name=name, namespace=namespace, + client=self, options=options, ) - callable = functools.update_wrapper(callable, func) - callable.action_name = self._registry.register(callable) - return callable + wrapped = functools.update_wrapper(wrapped, func) + return wrapped else: raise TypeError( "the @function decorator can only be applied to functions (def) and async functions (async def)" @@ -121,8 +101,11 @@ def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: return inner + def worker(self, options: worker.WorkerOptions) -> worker.Worker: + w = worker.Worker(client=self, options=options) + return w + # def durable( - # self, # name: str = "", # auto_register: bool = True, # on_events: list | None = None, @@ -135,53 +118,115 @@ def inner(func: Callable[P, T]) -> v2_callable.HatchetCallable[P, T]: # rate_limits: List[RateLimit] | None = None, # desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, # concurrency: v2.concurrency.ConcurrencyFunction | None = None, - # on_failure: Optional["HatchetCallable"] = None, + # on_failure: v2.callable.HatchetCallable | None = None, # default_priority: int | None = None, - # ) -> Callable[[v2.callable.HatchetCallable], v2.callable.HatchetCallable]: - # resp = durable( - # name=name, - # auto_register=auto_register, - # on_events=on_events, - # on_crons=on_crons, - # version=version, - # timeout=timeout, - # schedule_timeout=schedule_timeout, - # sticky=sticky, - # retries=retries, - # rate_limits=rate_limits, - # desired_worker_labels=desired_worker_labels, - # concurrency=concurrency, - # on_failure=on_failure, - # default_priority=default_priority, - # ) - - # def wrapper(func: Callable[[Context], T]) -> v2.callable.HatchetCallable[T]: - # wrapped_resp = resp(func) - - # if wrapped_resp.function_auto_register: - # self.functions.append(wrapped_resp) - - # wrapped_resp.with_namespace(self._client.config.namespace) - - # return wrapped_resp - - # return wrapper - - def worker( - self, - name: str, - max_runs: Optional[int] = None, - labels: Dict[str, str | int] = {}, - ): - worker = Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - ) + # ): + # def inner(func: v2.callable.HatchetCallable) -> v2.callable.HatchetCallable: + # func.durable = True + + # f = function( + # name=name, + # auto_register=auto_register, + # on_events=on_events, + # on_crons=on_crons, + # version=version, + # timeout=timeout, + # schedule_timeout=schedule_timeout, + # sticky=sticky, + # retries=retries, + # rate_limits=rate_limits, + # desired_worker_labels=desired_worker_labels, + # concurrency=concurrency, + # on_failure=on_failure, + # default_priority=default_priority, + # ) + + # resp = f(func) + + # resp.durable = True + + # return resp + + # return inner + + # def concurrency( + # name: str = "concurrency", + # max_runs: int = 1, + # limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + # ): + # def inner(func: Callable[[Context], str]) -> v2.concurrency.ConcurrencyFunction: + # return v2.concurrency.ConcurrencyFunction(func, name, max_runs, limit_strategy) + + # return inner + + # class OldHatchet(v1.Hatchet): + # # dag = staticmethod(v1.workflow) + # # concurrency = staticmethod(concurrency) + + # _registry: hatchet_registry.ActionRegistry = hatchet_registry.ActionRegistry() + + +# # def durable( +# # self, +# # name: str = "", +# # auto_register: bool = True, +# # on_events: list | None = None, +# # on_crons: list | None = None, +# # version: str = "", +# # timeout: str = "60m", +# # schedule_timeout: str = "5m", +# # sticky: StickyStrategy = None, +# # retries: int = 0, +# # rate_limits: List[RateLimit] | None = None, +# # desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, +# # concurrency: v2.concurrency.ConcurrencyFunction | None = None, +# # on_failure: Optional["HatchetCallable"] = None, +# # default_priority: int | None = None, +# # ) -> Callable[[v2.callable.HatchetCallable], v2.callable.HatchetCallable]: +# # resp = durable( +# # name=name, +# # auto_register=auto_register, +# # on_events=on_events, +# # on_crons=on_crons, +# # version=version, +# # timeout=timeout, +# # schedule_timeout=schedule_timeout, +# # sticky=sticky, +# # retries=retries, +# # rate_limits=rate_limits, +# # desired_worker_labels=desired_worker_labels, +# # concurrency=concurrency, +# # on_failure=on_failure, +# # default_priority=default_priority, +# # ) + +# # def wrapper(func: Callable[[Context], T]) -> v2.callable.HatchetCallable[T]: +# # wrapped_resp = resp(func) + +# # if wrapped_resp.function_auto_register: +# # self.functions.append(wrapped_resp) + +# # wrapped_resp.with_namespace(self._client.config.namespace) + +# # return wrapped_resp + +# # return wrapper + +# def worker( +# self, +# name: str, +# max_runs: Optional[int] = None, +# labels: Dict[str, str | int] = {}, +# ): +# worker = Worker( +# name=name, +# max_runs=max_runs, +# labels=labels, +# config=self._client.config, +# debug=self._client.debug, +# ) - for func in self._registry.registry.values(): - register_on_worker(func, worker) +# for func in self._registry.registry.values(): +# register_on_worker(func, worker) - return worker +# return worker diff --git a/hatchet_sdk/v2/runtime/config.py b/hatchet_sdk/v2/runtime/config.py new file mode 100644 index 00000000..37684432 --- /dev/null +++ b/hatchet_sdk/v2/runtime/config.py @@ -0,0 +1 @@ +from hatchet_sdk.loader import * diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py new file mode 100644 index 00000000..6a55ea99 --- /dev/null +++ b/hatchet_sdk/v2/runtime/context.py @@ -0,0 +1,114 @@ +import asyncio +import copy +import os +import threading +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Optional + +import hatchet_sdk.v2.hatchet as hatchet + + +def _loopid() -> Optional[int]: + try: + return id(asyncio.get_running_loop()) + except: + return None + + +_ctxvar: ContextVar[Optional["BackgroundContext"]] = ContextVar( + "hatchet_background_context", default=None +) + + +@dataclass +class _RunInfo: + workflow_run_id: Optional[str] = None + step_run_id: Optional[str] = None + + namespace: str = "" + name: str = "" + + pid: int = os.getpid() + tid: int = threading.get_ident() + loopid: Optional[int] = _loopid() + + def copy(self): + return copy.deepcopy(self) + + +@dataclass +class BackgroundContext: + """Background context at function execution time.""" + + current: _RunInfo = _RunInfo() + parent: Optional[_RunInfo] = None + root: _RunInfo = current + + # The Hatchet client is a required property. + client: hatchet.Hatchet + + def set_workflow_run_id(self, id: str): + self.current.workflow_run_id = id + + def set_step_run_id(self, id: str): + self.current.step_run_id = id + + def copy(self): + ret = BackgroundContext( + client=self.client, + current=self.current.copy(), + parent=self.parent.copy() if self.parent else None, + root=self.root.copy(), + ) + return ret + + @staticmethod + def set(ctx: "BackgroundContext"): + global _ctxvar + _ctxvar.set(ctx) + + @staticmethod + def get() -> Optional["BackgroundContext"]: + global _ctxvar + return _ctxvar.get() + + +@contextmanager +def EnsureContext(client: Optional[hatchet.Hatchet] = None): + cleanup = False + ctx = BackgroundContext.get() + if ctx is None: + cleanup = True + assert client is not None + ctx = BackgroundContext(client=client) + BackgroundContext.set(ctx) + try: + yield ctx + finally: + if cleanup: + BackgroundContext.set(None) + + +@contextmanager +def WithContext(ctx: BackgroundContext): + prev = BackgroundContext.get() + BackgroundContext.set(ctx) + try: + yield ctx + finally: + BackgroundContext.set(prev) + + +@contextmanager +def EnterFunc(): + with EnsureContext() as current: + child = current.copy() + child.parent = current.current.copy() + child.current = _RunInfo() + with WithContext(child) as ctx: + try: + yield ctx + finally: + pass diff --git a/hatchet_sdk/v2/runtime/logging.py b/hatchet_sdk/v2/runtime/logging.py new file mode 100644 index 00000000..b6575142 --- /dev/null +++ b/hatchet_sdk/v2/runtime/logging.py @@ -0,0 +1,30 @@ +import hatchet_sdk.logger as v1 +import threading +import asyncio +import os + + +def _loopid(): + try: + return id(asyncio.get_running_loop()) + except: + return -1 + + +class HatchetLogger: + + def log(self, *args, **kwargs): + v1.logger.log(*args, **kwargs) + + def debug(self, *args, **kwargs): + v1.logger.debug(*args, **kwargs) + + def info(self, *args, **kwargs): + pid = str(os.getpid()) + tid = str(threading.get_ident()) + loopid = str(_loopid()) + v1.logger.info(f"{pid}, {tid}, {loopid}") + v1.logger.info(*args, **kwargs) + + +logger = HatchetLogger() diff --git a/hatchet_sdk/v2/runtime/registry.py b/hatchet_sdk/v2/runtime/registry.py new file mode 100644 index 00000000..3e21b314 --- /dev/null +++ b/hatchet_sdk/v2/runtime/registry.py @@ -0,0 +1,31 @@ +import sys +from typing import Dict + +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.logging as logging + + +class ActionRegistry: + """A registry from action names (e.g. 'namespace:func') to Hatchet's callables. + + This is intended to be used per Hatchet client instance. + """ + + def __init__(self): + self.registry: Dict[str, callable.HatchetCallableBase] = dict() + + def add(self, key: str, callable: callable.HatchetCallableBase): + if key in self.registry: + raise KeyError(f"duplicated Hatchet callable: {key}") + self.registry[key] = callable + + def register_all(self, client: "hatchet.Hatchet"): + for callable in self.registry.values(): + proto = callable._to_workflow_proto() + try: + client.admin.put_workflow(proto.name, proto) + except Exception as e: + logging.logger.error(f"failed to register workflow: {proto.name}") + logging.logger.error(e) + sys.exit(1) diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py new file mode 100644 index 00000000..c1583a64 --- /dev/null +++ b/hatchet_sdk/v2/runtime/runner.py @@ -0,0 +1,7 @@ + + + + + + +class BaseRunner \ No newline at end of file diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py new file mode 100644 index 00000000..cc8ff179 --- /dev/null +++ b/hatchet_sdk/v2/runtime/worker.py @@ -0,0 +1,195 @@ +import asyncio +import time +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Set +from concurrent.futures import ThreadPoolExecutor + +import grpc +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToDict, MessageToJson + +import hatchet_sdk.v2.hatchet as hatchet +from hatchet_sdk.contracts.dispatcher_pb2 import ( + ActionType, + AssignedAction, + HeartbeatRequest, + WorkerLabels, + WorkerListenRequest, + WorkerRegisterRequest, + WorkerRegisterResponse, + WorkerUnsubscribeRequest, +) +from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub + +import hatchet_sdk.connection as connection + + +@dataclass +class WorkerOptions: + name: str + actions: List[str] + slots: int = 5 + debug: bool = False + labels: Dict[str, str | int] = field(default_factory=dict) + heartbeat: int = 4 # heartbeat period in seconds + + @property + def labels_proto(self) -> Dict[str, WorkerLabels]: + ret = dict() + for k, v in self.labels.items(): + if isinstance(v, int): + ret[k] = WorkerLabels(intValue=v) + else: + ret[k] = WorkerLabels(strValue=str(v)) + return ret + + +class WorkerStatus(Enum): + UNKNOWN = 1 + REGISTERED = 2 + # STARTING = 2 + HEALTHY = 3 + UNHEALTHY = 4 + + +class _HeartBeater: + def __init__(self, worker: "Worker"): + self.worker = worker + self.last_heartbeat: int = -1 # unix epoch in seconds + self.missed = 0 + self.error = 0 + + async def heartbeat(self): + while not self.worker._shutdown: + now = int(time.time()) + proto = HeartbeatRequest( + workerId=self.worker.id, + heartbeatAt=timestamp_pb2.Timestamp(seconds=now), + ) + try: + resp = self.worker.client.dispatcher.client.Heartbeat( + proto, timeout=5, metadata=self.worker._grpc_metadata() + ) + self.worker.client.logger.info(f"heartbeat: {MessageToJson(resp)}") + except grpc.RpcError as e: + self.error += 1 + + if self.last_heartbeat < 0: + self.last_heartbeat = now + self.status = WorkerStatus.HEALTHY + else: + diff = proto.heartbeatAt.seconds - self.last_heartbeat + if diff > self.worker.options.heartbeat: + self.missed += 1 + + await asyncio.sleep(self.worker.options.heartbeat) + + +class _Listner: + def __init__(self, worker: "Worker"): + self.worker = worker + self.attempt = 0 + + conn = connection.new_conn(self.worker.client.config, aio=True) + self.stub = DispatcherStub(conn) + + async def listen(self) -> AsyncGenerator[AssignedAction]: + resp = None + try: + while not self.worker._shutdown: + proto = WorkerListenRequest(workerId=self.worker.id) + print(repr(asyncio.get_running_loop())) + resp = self.stub.ListenV2( + proto, timeout=5, metadata=self.worker._grpc_metadata() + ) + self.worker.client.logger.info("listening") + async for event in resp: + yield event + if self.worker._shutdown: + resp.cancel() + resp = None + break + resp = None + self.attempt += 1 + except Exception as e: + self.worker.client.logger.info(e) + raise e + finally: + if resp: + resp.cancel() + + +class Worker: + + def __init__( + self, + client: "hatchet.Hatchet", + options: WorkerOptions, + ): + self.options = options + self.client = client + self.status = WorkerStatus.UNKNOWN + self.id: Optional[str] = None + + self._shutdown = False # flag for shutting down + self._heartbeater = _HeartBeater(self) + self._heartbeater_task: Optional[asyncio.Task] = None + self._listener = _Listner(self) + self._listener_task: Optional[asyncio.Task] = None + + def _register(self) -> str: + resp: WorkerRegisterResponse = self.client.dispatcher.client.Register( + self._to_register_proto(), + timeout=30, + metadata=self._grpc_metadata(), + ) + self.client.logger.info(f"registered: {MessageToDict(resp)}") + self.id = resp.workerId + self.status = WorkerStatus.REGISTERED + return resp.workerId + + async def start(self): + self._register() + # self._heartbeat_task = asyncio.create_task( + # self._heartbeater.heartbeat(), name="heartbeat" + # ) + agen = self._listener.listen() + self._listener_task = asyncio.create_task(self._onevent(agen), name="listner") + # while True: + # if self._heartbeater.last_heartbeat > 0: + # return + # await asyncio.sleep(0.1) + + async def shutdown(self): + print("shutting down") + self._shutdown = True + # self._listener_task.cancel() + # self._heartbeat_task.cancel() + await asyncio.gather(self._heartbeat_task, self._listener_task) + + async def _onevent(self, agen: AsyncGenerator[AssignedAction]): + self.client.logger.info(repr(agen)) + try: + async for action in agen: + print(MessageToDict(action)) + except Exception as e: + print(e) + raise + finally: + pass + + def _grpc_metadata(self): + return [("authorization", f"bearer {self.client.config.token}")] + + def _to_register_proto(self) -> WorkerRegisterRequest: + options = self.options + proto = WorkerRegisterRequest( + workerName=options.name, + services=["default"], + actions=list(options.actions), + maxRuns=options.slots, + labels=options.labels_proto, + ) + return proto diff --git a/tests/v2/test_simple.py b/tests/v2/test_simple.py index 6fb1df4a..0db79a25 100644 --- a/tests/v2/test_simple.py +++ b/tests/v2/test_simple.py @@ -1,4 +1,5 @@ import asyncio + import pytest diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py new file mode 100644 index 00000000..5fc0b27a --- /dev/null +++ b/tests/v2/test_worker.py @@ -0,0 +1,28 @@ +import asyncio + +import dotenv +import pytest + +from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk.v2.runtime.worker import WorkerOptions +import logging + +dotenv.load_dotenv() + +hatchet = Hatchet(debug=True) + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +@hatchet.function() +def foo(): + pass + + +@pytest.mark.asyncio +async def test_worker(): + worker = hatchet.worker(WorkerOptions(name="worker", actions=["default:foo"])) + await worker.start() + foo() + await asyncio.sleep(10) + await worker.shutdown() From b5968eee5671ce941663939fb94008dd6bcdfac5 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Thu, 12 Sep 2024 20:55:07 -0400 Subject: [PATCH 04/12] hooking up the runner to actually run the functions --- hatchet_sdk/clients/admin.py | 3 + hatchet_sdk/v2/callable.py | 47 ++++---- hatchet_sdk/v2/hatchet.py | 17 +-- hatchet_sdk/v2/runtime/connection.py | 1 + hatchet_sdk/v2/runtime/context.py | 12 +- hatchet_sdk/v2/runtime/logging.py | 6 +- hatchet_sdk/v2/runtime/messages.py | 33 ++++++ hatchet_sdk/v2/runtime/registry.py | 4 +- hatchet_sdk/v2/runtime/runner.py | 113 ++++++++++++++++++- hatchet_sdk/v2/runtime/worker.py | 158 +++++++++++++++------------ tests/v2/test_worker.py | 10 +- 11 files changed, 291 insertions(+), 113 deletions(-) create mode 100644 hatchet_sdk/v2/runtime/connection.py create mode 100644 hatchet_sdk/v2/runtime/messages.py diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index afabe047..4dbe0959 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -4,6 +4,7 @@ import grpc from google.protobuf import timestamp_pb2 +from loguru import logger from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.clients.run_event_listener import new_listener @@ -358,6 +359,8 @@ def trigger_workflow( options: TriggerWorkflowOptions = None, ) -> str: request = self._prepare_workflow_request(workflow_name, input, options) + + logger.trace("trigger proto: {}", request) resp: TriggerWorkflowResponse = self.client.TriggerWorkflow( request, metadata=get_metadata(self.token), diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 0565781e..512ef9db 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -24,10 +24,12 @@ ) from google.protobuf.json_format import MessageToDict +from loguru import logger from pydantic import BaseModel, ConfigDict, Field, computed_field from pydantic.json_schema import SkipJsonSchema import hatchet_sdk.v2.hatchet as v2hatchet +import hatchet_sdk.v2.runtime.context as context from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl @@ -69,7 +71,6 @@ def _sourceloc(fn) -> str: class HatchetCallableBase(Generic[P, T]): - def __init__( self, *, @@ -79,10 +80,10 @@ def __init__( client: v2hatchet.Hatchet, options: Options, ): - + # TODO: maybe use __qualname__ + name = name.lower() or func.__name__.lower() self._hatchet = CallableMetadata( - # TODO: maybe use __qualname__ - name=name.lower() or str(func.__name__).lower(), + name=name, namespace=namespace, sourceloc=_sourceloc(func), options=options, @@ -151,15 +152,12 @@ def _to_step_proto(self) -> CreateWorkflowStepOpts: return step def _to_trigger_proto(self) -> Optional[TriggerWorkflowOptions]: - return None - ctx = CallableContext.current() - if not ctx: - return None - trigger: TriggerWorkflowOptions = { - "parent_id": ctx.workflow_run_id, - "parent_step_run_id": ctx.step_run_id, - } - return trigger + with context.EnsureContext(self._hatchet.client) as ctx: + trigger: TriggerWorkflowOptions = { + "parent_id": ctx.parent.workflow_run_id if ctx.parent else None, + "parent_step_run_id": ctx.parent.step_run_id if ctx.parent else None, + } + return trigger def _debug(self): data = { @@ -174,25 +172,24 @@ def _debug(self): } return data - def _run(self, context: BaseContext): + def _run(self, ctx: BaseContext): raise NotImplementedError class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - self._hatchet.client.logger.info(f"triggering {self._hatchet.action}") - input = json.dumps({"args": args, "kwargs": kwargs}) + logger.trace("triggering {}", self._to_trigger_proto()) + input = {"args": args, "kwargs": kwargs} client = self._hatchet.client ref = client.admin.trigger_workflow( self._hatchet.name, input=input, options=self._to_trigger_proto() ) - self._hatchet.client.logger.info(f"runid: {ref}") + logger.trace("runid: {}", ref) return None - def _run(self, context: Context) -> T: - print(f"running {self.action_name}") - input = json.loads(context.workflow_input) - return self.func(*input.args, **input.kwargs) + def _run(self, *args: P.args, **kwargs: P.kwargs) -> T: + print(f"running {self._hatchet.action}") + return self._hatchet.func(*args, **kwargs) class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): @@ -202,9 +199,9 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: client = self._.options.hatchet return await client.admin.run(self._.name, input).result() - async def _run(self, context: ContextAioImpl) -> T: + async def _run(self, ctx: ContextAioImpl) -> T: print(f"trigering {self.action_name}") - input = json.loads(context.workflow_input) + input = json.loads(ctx.workflow_input) return await self.func(*input.args, **input.kwargs) @@ -283,6 +280,10 @@ def _debug(self): } +class HatchetContextBase: + pass + + # # Context variable used for propagating hatchet context. # # The type of the variable is CallableContext. # _callable_cv = ContextVar("hatchet.callable") diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index d21ba3a8..42c7869c 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,16 +1,16 @@ +import asyncio import functools import inspect +import multiprocessing as mp from typing import Callable, Dict, List, Optional, ParamSpec, TypeVar import hatchet_sdk.hatchet as v1 +import hatchet_sdk.v2.callable as callable import hatchet_sdk.v2.runtime.config as config import hatchet_sdk.v2.runtime.logging as logging import hatchet_sdk.v2.runtime.registry as registry +import hatchet_sdk.v2.runtime.runner as runner import hatchet_sdk.v2.runtime.worker as worker -import hatchet_sdk.v2.callable as callable -import asyncio - - # import hatchet_sdk.runtime.registry as hatchet_registry # import hatchet_sdk.v2.callable as v2_callable @@ -32,7 +32,6 @@ class Hatchet: - def __init__( self, config: config.ClientConfig = config.ClientConfig(), @@ -49,6 +48,9 @@ def __init__( defaults=config, debug=debug, ) + self._q_action = mp.Queue() + self._q_event = mp.Queue() + self._runner = runner.BaseRunnerLoop(self, self._q_action, self._q_event) @property def admin(self): @@ -72,7 +74,6 @@ def function( namespace: str = "default", options: "callable.Options" = callable.Options(), ): - def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": if inspect.iscoroutinefunction(func): wrapped = callable.HatchetAwaitable( @@ -102,7 +103,9 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": return inner def worker(self, options: worker.WorkerOptions) -> worker.Worker: - w = worker.Worker(client=self, options=options) + w = worker.Worker( + client=self, options=options, inbound=self._q_event, outbound=self._q_action + ) return w # def durable( diff --git a/hatchet_sdk/v2/runtime/connection.py b/hatchet_sdk/v2/runtime/connection.py new file mode 100644 index 00000000..76331dc5 --- /dev/null +++ b/hatchet_sdk/v2/runtime/connection.py @@ -0,0 +1 @@ +from hatchet_sdk.connection import * diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index 6a55ea99..f16299d9 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -4,7 +4,7 @@ import threading from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import hatchet_sdk.v2.hatchet as hatchet @@ -42,13 +42,13 @@ def copy(self): class BackgroundContext: """Background context at function execution time.""" - current: _RunInfo = _RunInfo() + # The Hatchet client is a required property. + client: "hatchet.Hatchet" + + current: _RunInfo = field(default_factory=_RunInfo) parent: Optional[_RunInfo] = None root: _RunInfo = current - # The Hatchet client is a required property. - client: hatchet.Hatchet - def set_workflow_run_id(self, id: str): self.current.workflow_run_id = id @@ -76,7 +76,7 @@ def get() -> Optional["BackgroundContext"]: @contextmanager -def EnsureContext(client: Optional[hatchet.Hatchet] = None): +def EnsureContext(client: Optional["hatchet.Hatchet"] = None): cleanup = False ctx = BackgroundContext.get() if ctx is None: diff --git a/hatchet_sdk/v2/runtime/logging.py b/hatchet_sdk/v2/runtime/logging.py index b6575142..d76da215 100644 --- a/hatchet_sdk/v2/runtime/logging.py +++ b/hatchet_sdk/v2/runtime/logging.py @@ -1,7 +1,8 @@ -import hatchet_sdk.logger as v1 -import threading import asyncio import os +import threading + +import hatchet_sdk.logger as v1 def _loopid(): @@ -12,7 +13,6 @@ def _loopid(): class HatchetLogger: - def log(self, *args, **kwargs): v1.logger.log(*args, **kwargs) diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py new file mode 100644 index 00000000..df2eeeb0 --- /dev/null +++ b/hatchet_sdk/v2/runtime/messages.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from hatchet_sdk.contracts.dispatcher_pb2 import ( + GROUP_KEY_EVENT_TYPE_COMPLETED, + GROUP_KEY_EVENT_TYPE_FAILED, + GROUP_KEY_EVENT_TYPE_STARTED, + STEP_EVENT_TYPE_COMPLETED, + STEP_EVENT_TYPE_FAILED, + STEP_EVENT_TYPE_STARTED, + ActionType, + AssignedAction, + GroupKeyActionEventType, + StepActionEventType, + WorkflowRunEventType, +) +from hatchet_sdk.worker.action_listener_process import Action + + +class MessageType(Enum): + UNKNOWN = 0 + ACTION_RUN = 1 + ACTION_CANCEL = 2 + EVENT_STARTED = 3 + EVENT_FINISHED = 4 + + +@dataclass +class Message: + action: AssignedAction + type: MessageType + payload: Any = None diff --git a/hatchet_sdk/v2/runtime/registry.py b/hatchet_sdk/v2/runtime/registry.py index 3e21b314..248612f9 100644 --- a/hatchet_sdk/v2/runtime/registry.py +++ b/hatchet_sdk/v2/runtime/registry.py @@ -13,9 +13,9 @@ class ActionRegistry: """ def __init__(self): - self.registry: Dict[str, callable.HatchetCallableBase] = dict() + self.registry: Dict[str, "callable.HatchetCallableBase"] = dict() - def add(self, key: str, callable: callable.HatchetCallableBase): + def add(self, key: str, callable: "callable.HatchetCallableBase"): if key in self.registry: raise KeyError(f"duplicated Hatchet callable: {key}") self.registry[key] = callable diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py index c1583a64..c3974488 100644 --- a/hatchet_sdk/v2/runtime/runner.py +++ b/hatchet_sdk/v2/runtime/runner.py @@ -1,7 +1,118 @@ +import asyncio +import json +import multiprocessing as mp +from typing import Any, Dict, Optional, Tuple +from loguru import logger +import hatchet_sdk.contracts.dispatcher_pb2 +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.messages as messages +class _Runner: + def __init__( + self, + registry: Dict[str, "callable.HatchetCallableBase"], + msg: "messages.Message", + ): + self.registry = registry + logger.info(self.registry.keys()) + self.msg = msg + async def run(self) -> Tuple[Any, Exception]: + args = self.input["args"] + kwargs = self.input["kwargs"] + logger.trace("trying to run {} with input {}", self.action, self.input) + logger.trace(repr(self.fn)) + try: + if isinstance(self.fn, callable.HatchetCallable): + return await asyncio.to_thread(self.fn._run, *args, **kwargs), None + else: + return await self.fn._run(*args, **kwargs) + except Exception as e: + logger.exception(e) + return None, e -class BaseRunner \ No newline at end of file + @property + def action(self): + return self.msg.action.get("actionId") + + @property + def parent(self): + return self.msg.action.parent_workflow_run_id + + @property + def input(self): + return json.loads(self.msg.action.get("actionPayload")).get("input") + + @property + def fn(self): + return self.registry[self.action] + + +class BaseRunnerLoop: + def __init__( + self, + client: "hatchet.Hatchet", + inbound: mp.Queue, # inbound queue + outbound: mp.Queue, # outbound queue + ): + self.client = client + self.registry: Dict[str, "callable.HatchetCallableBase"] = ( + client.registry.registry + ) + self.inbound = inbound + self.outbound = outbound + + self.looptask: Optional[asyncio.Task] = None + + def start(self): + self.looptask = asyncio.create_task(self._runner_loop(), name="runner") + + async def shutdown(self): + self.looptask.cancel() + try: + await self.looptask + except asyncio.CancelledError: + logger.info("bye") + + async def _runner_loop(self): + while True: + msg: "messages.Message" = await asyncio.to_thread( + self._next_message_blocking + ) + match msg.type: + case messages.MessageType.ACTION_RUN: + await self._on_run(msg) + case messages.MessageType.ACTION_CANCEL: + await self._on_cancel(msg) + case _: + logger.debug(msg) + + async def _on_run(self, msg: "messages.Message"): + await asyncio.to_thread(self._emit_start_blocking) + result, e = await _Runner(self.registry, msg).run() + if e is None: + await asyncio._to_thread(self._emit_finish_blocking) + + async def _on_cancel(self, msg: "messages.Message"): + pass + + def _next_message_blocking(self) -> messages.Message: + return self.inbound.get() + + def _emit_start_blocking(self): + msg = messages.Message( + action=messages.Action(), + type=messages.MessageType.EVENT_FINISHED, + ) + self.outbound.put(msg) + + def _emit_finish_blocking(self): + msg = messages.Message( + action=messages.Action(), + type=messages.MessageType.EVENT_STARTED, + ) + self.outbound.put(msg) diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index cc8ff179..23d0a9db 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -1,16 +1,21 @@ import asyncio +import multiprocessing as mp import time from collections.abc import AsyncGenerator +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Set -from concurrent.futures import ThreadPoolExecutor import grpc from google.protobuf import timestamp_pb2 from google.protobuf.json_format import MessageToDict, MessageToJson +from loguru import logger +import hatchet_sdk.contracts.dispatcher_pb2 import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.connection as connection +import hatchet_sdk.v2.runtime.messages as messages from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, @@ -23,8 +28,6 @@ ) from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub -import hatchet_sdk.connection as connection - @dataclass class WorkerOptions: @@ -58,82 +61,94 @@ class _HeartBeater: def __init__(self, worker: "Worker"): self.worker = worker self.last_heartbeat: int = -1 # unix epoch in seconds + self.stub = DispatcherStub( + connection.new_conn(self.worker.client.config, aio=False) + ) self.missed = 0 self.error = 0 async def heartbeat(self): - while not self.worker._shutdown: - now = int(time.time()) - proto = HeartbeatRequest( - workerId=self.worker.id, - heartbeatAt=timestamp_pb2.Timestamp(seconds=now), - ) - try: - resp = self.worker.client.dispatcher.client.Heartbeat( - proto, timeout=5, metadata=self.worker._grpc_metadata() + try: + # It will exit the loop when a asyncio.CancelledError is raised + # by calling task.cancel() from outside. + while True: + now = int(time.time()) + proto = HeartbeatRequest( + workerId=self.worker.id, + heartbeatAt=timestamp_pb2.Timestamp(seconds=now), ) - self.worker.client.logger.info(f"heartbeat: {MessageToJson(resp)}") - except grpc.RpcError as e: - self.error += 1 + try: + resp = self.stub.Heartbeat( + proto, timeout=5, metadata=self.worker._grpc_metadata() + ) + logger.trace("heartbeat: {}", MessageToJson(resp)) + except grpc.RpcErrors: + self.error += 1 + + if self.last_heartbeat < 0: + self.last_heartbeat = now + self.status = WorkerStatus.HEALTHY + else: + diff = proto.heartbeatAt.seconds - self.last_heartbeat + if diff > self.worker.options.heartbeat: + self.missed += 1 + await asyncio.sleep(self.worker.options.heartbeat) - if self.last_heartbeat < 0: - self.last_heartbeat = now - self.status = WorkerStatus.HEALTHY - else: - diff = proto.heartbeatAt.seconds - self.last_heartbeat - if diff > self.worker.options.heartbeat: - self.missed += 1 - - await asyncio.sleep(self.worker.options.heartbeat) + finally: + logger.info("shutting down heartbeater") class _Listner: def __init__(self, worker: "Worker"): self.worker = worker self.attempt = 0 - - conn = connection.new_conn(self.worker.client.config, aio=True) - self.stub = DispatcherStub(conn) + self.stub = DispatcherStub( + connection.new_conn(self.worker.client.config, aio=True) + ) async def listen(self) -> AsyncGenerator[AssignedAction]: resp = None try: - while not self.worker._shutdown: + # It will exit the loop when asyncio.CancelledError is + # raised by calling task.cancel() from outside. + while True: proto = WorkerListenRequest(workerId=self.worker.id) - print(repr(asyncio.get_running_loop())) - resp = self.stub.ListenV2( - proto, timeout=5, metadata=self.worker._grpc_metadata() - ) - self.worker.client.logger.info("listening") - async for event in resp: - yield event - if self.worker._shutdown: - resp.cancel() - resp = None - break - resp = None - self.attempt += 1 - except Exception as e: - self.worker.client.logger.info(e) - raise e + try: + resp = self.stub.ListenV2( + proto, metadata=self.worker._grpc_metadata() + ) + logger.trace("listening") + async for event in resp: + yield event + + resp = None + self.attempt += 1 + except grpc.aio.AioRpcError as e: + logger.warning(e) + + # TODO: expotential backoff, retry limit, etc + finally: + logger.info("shutting down listener") if resp: resp.cancel() class Worker: - def __init__( self, client: "hatchet.Hatchet", + inbound: mp.Queue, + outbound: mp.Queue, options: WorkerOptions, ): self.options = options self.client = client self.status = WorkerStatus.UNKNOWN self.id: Optional[str] = None + self.inbound = inbound + self.outbound = outbound - self._shutdown = False # flag for shutting down self._heartbeater = _HeartBeater(self) self._heartbeater_task: Optional[asyncio.Task] = None self._listener = _Listner(self) @@ -145,40 +160,43 @@ def _register(self) -> str: timeout=30, metadata=self._grpc_metadata(), ) - self.client.logger.info(f"registered: {MessageToDict(resp)}") + logger.debug(f"worker registered: {MessageToDict(resp)}") self.id = resp.workerId self.status = WorkerStatus.REGISTERED return resp.workerId async def start(self): self._register() - # self._heartbeat_task = asyncio.create_task( - # self._heartbeater.heartbeat(), name="heartbeat" - # ) - agen = self._listener.listen() - self._listener_task = asyncio.create_task(self._onevent(agen), name="listner") - # while True: - # if self._heartbeater.last_heartbeat > 0: - # return - # await asyncio.sleep(0.1) + self._heartbeat_task = asyncio.create_task( + self._heartbeater.heartbeat(), name="heartbeat" + ) + self._listener_task = asyncio.create_task( + self._onevent(self._listener.listen()), name="listner" + ) + while True: + if self._heartbeater.last_heartbeat > 0: + return + await asyncio.sleep(0.1) async def shutdown(self): - print("shutting down") - self._shutdown = True - # self._listener_task.cancel() - # self._heartbeat_task.cancel() - await asyncio.gather(self._heartbeat_task, self._listener_task) + tg = asyncio.gather(self._heartbeat_task, self._listener_task) + tg.cancel() + try: + await tg + except asyncio.CancelledError: + logger.info("bye") async def _onevent(self, agen: AsyncGenerator[AssignedAction]): - self.client.logger.info(repr(agen)) - try: - async for action in agen: - print(MessageToDict(action)) - except Exception as e: - print(e) - raise - finally: - pass + async for action in agen: + if action.actionType == ActionType.START_STEP_RUN: + await asyncio.to_thread( + self.outbound.put, + messages.Message( + action=MessageToDict(action), + type=messages.MessageType.ACTION_RUN, + ), + ) + logger.trace(MessageToDict(action)) def _grpc_metadata(self): return [("authorization", f"bearer {self.client.config.token}")] diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index 5fc0b27a..4211e28f 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -1,11 +1,16 @@ import asyncio +import logging +import sys import dotenv import pytest +from loguru import logger from hatchet_sdk.v2.hatchet import Hatchet from hatchet_sdk.v2.runtime.worker import WorkerOptions -import logging + +logger.remove() +logger.add(sys.stdout, level="TRACE") dotenv.load_dotenv() @@ -16,6 +21,7 @@ @hatchet.function() def foo(): + print("HAHAHA") pass @@ -23,6 +29,8 @@ def foo(): async def test_worker(): worker = hatchet.worker(WorkerOptions(name="worker", actions=["default:foo"])) await worker.start() + hatchet._runner.start() foo() await asyncio.sleep(10) await worker.shutdown() + await hatchet._runner.shutdown() From 6d81998362895ea6a382347b382928c6871febb5 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Fri, 13 Sep 2024 20:48:44 -0400 Subject: [PATCH 05/12] a bit more work on propagating the events, and on runtime --- hatchet_sdk/v2/callable.py | 22 +++- hatchet_sdk/v2/hatchet.py | 11 +- hatchet_sdk/v2/runtime/context.py | 14 ++- hatchet_sdk/v2/runtime/messages.py | 38 +++++-- hatchet_sdk/v2/runtime/runner.py | 157 ++++++++++++++++++++--------- hatchet_sdk/v2/runtime/runtime.py | 30 ++++++ hatchet_sdk/v2/runtime/worker.py | 13 +-- tests/v2/test_worker.py | 1 + 8 files changed, 205 insertions(+), 81 deletions(-) create mode 100644 hatchet_sdk/v2/runtime/runtime.py diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 512ef9db..49f8b6fa 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -43,8 +43,11 @@ WorkflowConcurrencyOpts, WorkflowKind, ) +from hatchet_sdk.contracts.dispatcher_pb2 import AssignedAction from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.logger import logger + +# from hatchet_sdk.logger import logger +from loguru import logger from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.v2.runtime import registry @@ -172,7 +175,8 @@ def _debug(self): } return data - def _run(self, ctx: BaseContext): + def _run(self, ctx: BaseContext) -> str: + # actually invokes the function, and serializing the output raise NotImplementedError @@ -187,9 +191,17 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: logger.trace("runid: {}", ref) return None - def _run(self, *args: P.args, **kwargs: P.kwargs) -> T: - print(f"running {self._hatchet.action}") - return self._hatchet.func(*args, **kwargs) + def _run(self, action: AssignedAction) -> str: + assert action.actionId == self._hatchet.action + logger.trace("invoking {}", action.actionId) + input = json.loads(action.actionPayload)["input"] + + with context.EnsureContext(self._hatchet.client) as ctx: + ctx.set_step_run_id(action.stepRunId) + ctx.set_workflow_run_id(action.workflowRunId) + ctx.set_parent_workflow_run_id(action.parent_workflow_run_id) + with context.WithContext(ctx): + return self._hatchet.func(*input["args"], **input["kwargs"]) class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 42c7869c..310dceb4 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -11,6 +11,7 @@ import hatchet_sdk.v2.runtime.registry as registry import hatchet_sdk.v2.runtime.runner as runner import hatchet_sdk.v2.runtime.worker as worker +import hatchet_sdk.v2.runtime.runtime as runtime # import hatchet_sdk.runtime.registry as hatchet_registry # import hatchet_sdk.v2.callable as v2_callable @@ -48,9 +49,6 @@ def __init__( defaults=config, debug=debug, ) - self._q_action = mp.Queue() - self._q_event = mp.Queue() - self._runner = runner.BaseRunnerLoop(self, self._q_action, self._q_event) @property def admin(self): @@ -102,11 +100,8 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": return inner - def worker(self, options: worker.WorkerOptions) -> worker.Worker: - w = worker.Worker( - client=self, options=options, inbound=self._q_event, outbound=self._q_action - ) - return w + def worker(self, options: "worker.WorkerOptions") -> "runtime.Runtime": + return runtime.Runtime(self, options) # def durable( # name: str = "", diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index f16299d9..0ae12fd8 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -45,9 +45,9 @@ class BackgroundContext: # The Hatchet client is a required property. client: "hatchet.Hatchet" - current: _RunInfo = field(default_factory=_RunInfo) + current: _RunInfo + root: _RunInfo parent: Optional[_RunInfo] = None - root: _RunInfo = current def set_workflow_run_id(self, id: str): self.current.workflow_run_id = id @@ -55,6 +55,14 @@ def set_workflow_run_id(self, id: str): def set_step_run_id(self, id: str): self.current.step_run_id = id + def set_parent_workflow_run_id(self, id: str): + if self.parent is None: + self.parent = _RunInfo( + workflow_run_id=id, step_run_id=None, pid=None, tid=None, loopid=None + ) + else: + self.parent.workflow_run_id = id + def copy(self): ret = BackgroundContext( client=self.client, @@ -82,7 +90,7 @@ def EnsureContext(client: Optional["hatchet.Hatchet"] = None): if ctx is None: cleanup = True assert client is not None - ctx = BackgroundContext(client=client) + ctx = BackgroundContext(client=client, current=_RunInfo(), root=_RunInfo()) BackgroundContext.set(ctx) try: yield ctx diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py index df2eeeb0..28992ab3 100644 --- a/hatchet_sdk/v2/runtime/messages.py +++ b/hatchet_sdk/v2/runtime/messages.py @@ -14,20 +14,42 @@ GroupKeyActionEventType, StepActionEventType, WorkflowRunEventType, + StepActionEvent, ) from hatchet_sdk.worker.action_listener_process import Action +from google.protobuf.json_format import ParseDict +from typing import Optional, Dict -class MessageType(Enum): +class MessageKind(Enum): UNKNOWN = 0 - ACTION_RUN = 1 - ACTION_CANCEL = 2 - EVENT_STARTED = 3 - EVENT_FINISHED = 4 + ACTION = 1 + STEP_EVENT = 2 @dataclass class Message: - action: AssignedAction - type: MessageType - payload: Any = None + """The runtime IPC message format. Note that it has to be trivially pickle-able.""" + + _action: Optional[Dict] = None + _step_event: Optional[Dict] = None + + @property + def kind(self) -> MessageKind: + if self._action is not None: + return MessageKind.ACTION + if self._step_event is not None: + return MessageKind.STEP_EVENT + return MessageKind.UNKNOWN + + @property + def action(self) -> AssignedAction: + assert self._action is not None + ret = AssignedAction() + return ParseDict(self._action, ret) + + @property + def step_event(self) -> StepActionEvent: + assert self._step_event is not None + ret = StepActionEvent() + return ParseDict(self._step_event, ret) diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py index c3974488..8a02bb5f 100644 --- a/hatchet_sdk/v2/runtime/runner.py +++ b/hatchet_sdk/v2/runtime/runner.py @@ -5,10 +5,28 @@ from loguru import logger -import hatchet_sdk.contracts.dispatcher_pb2 +from hatchet_sdk.contracts.dispatcher_pb2 import ( + ActionType, + StepActionEvent, + StepActionEventType, +) import hatchet_sdk.v2.callable as callable import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.messages as messages +from google.protobuf.timestamp_pb2 import Timestamp +import time +from google.protobuf.json_format import MessageToDict +import traceback + + +def _timestamp(): + ns = time.time_ns() + return Timestamp(seconds=int(ns // 1e9), nanos=int(ns % 1e9)) + + +def _format_exc(e: Exception): + trace = "".join(traceback.format_exception(e)) + return "\n".join[str(e), trace] class _Runner: @@ -18,34 +36,25 @@ def __init__( msg: "messages.Message", ): self.registry = registry - logger.info(self.registry.keys()) + assert msg.kind == messages.MessageKind.ACTION self.msg = msg - async def run(self) -> Tuple[Any, Exception]: - args = self.input["args"] - kwargs = self.input["kwargs"] - logger.trace("trying to run {} with input {}", self.action, self.input) - logger.trace(repr(self.fn)) + async def run(self) -> Tuple[str, Exception]: + logger.trace("runner invoking: {}", repr(self.fn)) try: if isinstance(self.fn, callable.HatchetCallable): - return await asyncio.to_thread(self.fn._run, *args, **kwargs), None + return await asyncio.to_thread(self.fn._run, self.msg.action), None else: - return await self.fn._run(*args, **kwargs) + return await self.fn._run(self.msg.action), None + except asyncio.CancelledError: + raise except Exception as e: logger.exception(e) return None, e @property def action(self): - return self.msg.action.get("actionId") - - @property - def parent(self): - return self.msg.action.parent_workflow_run_id - - @property - def input(self): - return json.loads(self.msg.action.get("actionPayload")).get("input") + return self.msg.action.actionId @property def fn(self): @@ -63,56 +72,108 @@ def __init__( self.registry: Dict[str, "callable.HatchetCallableBase"] = ( client.registry.registry ) + self.worker_id:Optional[str] = None + self.inbound = inbound self.outbound = outbound self.looptask: Optional[asyncio.Task] = None + # a dict from StepRunId to its tasks + self.runners: Dict[str, asyncio.Task] = dict() + def start(self): - self.looptask = asyncio.create_task(self._runner_loop(), name="runner") + self.looptask = asyncio.create_task(self.loop(), name="runnerloop") async def shutdown(self): - self.looptask.cancel() + logger.info("shutting down runner loop") + t = asyncio.gather(*self.runners.values(), self.looptask) + self.outbound.close() + t.cancel() try: - await self.looptask + await t except asyncio.CancelledError: logger.info("bye") - async def _runner_loop(self): + async def loop(self): while True: - msg: "messages.Message" = await asyncio.to_thread( - self._next_message_blocking - ) - match msg.type: - case messages.MessageType.ACTION_RUN: - await self._on_run(msg) - case messages.MessageType.ACTION_CANCEL: - await self._on_cancel(msg) + msg: "messages.Message" = await self.next() + assert msg.kind == messages.MessageKind.ACTION + match msg.action.actionType: + case ActionType.START_STEP_RUN: + self.on_run(msg) + case ActionType.CANCEL_STEP_RUN: + self.on_cancel(msg) case _: logger.debug(msg) - async def _on_run(self, msg: "messages.Message"): - await asyncio.to_thread(self._emit_start_blocking) - result, e = await _Runner(self.registry, msg).run() - if e is None: - await asyncio._to_thread(self._emit_finish_blocking) + def on_run(self, msg: "messages.Message"): + async def task(): + try: + await self.emit_started(msg) + result, e = await _Runner(self.registry, msg).run() + if e is None: + await self.emit_finished(msg, result) + else: + await self.emit_failed(msg, _format_exc(e)) + finally: + del self.runners[msg.action.stepRunId] + + self.runners[msg.action.stepRunId] = asyncio.create_task( + task(), name=msg.action.stepRunId + ) - async def _on_cancel(self, msg: "messages.Message"): + def step_event(self, msg: "messages.Message", **kwargs) -> StepActionEvent: + base = StepActionEvent( + jobId=msg.action.jobId, + jobRunId=msg.action.jobRunId, + stepId=msg.action.stepId, + stepRunId=msg.action.stepRunId, + actionId=msg.action.actionId, + eventTimestamp=_timestamp(), + ) + base.MergeFrom(StepActionEvent(**kwargs)) + return MessageToDict(base) + + def on_cancel(self, msg: "messages.Message"): pass - def _next_message_blocking(self) -> messages.Message: - return self.inbound.get() + async def emit_started(self, msg: "messages.Message"): + await self.send( + messages.Message( + _step_event=self.step_event( + msg, eventType=StepActionEventType.STEP_EVENT_TYPE_STARTED + ) + ) + ) - def _emit_start_blocking(self): - msg = messages.Message( - action=messages.Action(), - type=messages.MessageType.EVENT_FINISHED, + async def emit_finished(self, msg: "messages.Message", payload: str): + await self.send( + messages.Message( + _step_event=self.step_event( + msg, + eventType=StepActionEventType.STEP_EVENT_TYPE_COMPLETED, + eventPayload=payload, + ) + ) ) - self.outbound.put(msg) - def _emit_finish_blocking(self): - msg = messages.Message( - action=messages.Action(), - type=messages.MessageType.EVENT_STARTED, + async def emit_failed(self, msg: "messages.Message", payload: str): + await self.send( + messages.Message( + _step_event=self.step_event( + msg, + eventType=StepActionEventType.STEP_EVENT_TYPE_FAILED, + eventPayload=payload, + ) + ) ) - self.outbound.put(msg) + + async def send(self, msg: "messages.Message"): + logger.trace("sending: {}", msg) + await asyncio.to_thread(self.outbound.put, msg) + + async def next(self) -> "messages.Message": + msg = await asyncio.to_thread(self.inbound.get) # raise EOFError if the queue is closed + logger.trace("recv: {}", msg) + return msg diff --git a/hatchet_sdk/v2/runtime/runtime.py b/hatchet_sdk/v2/runtime/runtime.py new file mode 100644 index 00000000..9dea07c1 --- /dev/null +++ b/hatchet_sdk/v2/runtime/runtime.py @@ -0,0 +1,30 @@ +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.worker as worker +import hatchet_sdk.v2.runtime.runner as runner +import multiprocessing as mp +import asyncio + + +class Runtime: + + def __init__(self, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): + self.events = mp.Queue() + self.actions = mp.Queue() + self.worker = worker.Worker( + client=client, inbound=self.events, outbound=self.actions, options=options + ) + self.runner = runner.BaseRunnerLoop( + client=client, inbound=self.actions, outbound=self.events + ) + + async def start(self): + self.runner.start() + await self.worker.start() + self.runner.worker_id = self.worker.id + return self.worker.id + + async def shutdown(self): + await self.worker.shutdown() + await self.runner.shutdown() + self.actions.close() + self.events.close() diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index 23d0a9db..fa8870d3 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -179,8 +179,9 @@ async def start(self): await asyncio.sleep(0.1) async def shutdown(self): - tg = asyncio.gather(self._heartbeat_task, self._listener_task) + tg: asyncio.Future = asyncio.gather(self._heartbeat_task, self._listener_task) tg.cancel() + self.outbound.close() try: await tg except asyncio.CancelledError: @@ -188,14 +189,8 @@ async def shutdown(self): async def _onevent(self, agen: AsyncGenerator[AssignedAction]): async for action in agen: - if action.actionType == ActionType.START_STEP_RUN: - await asyncio.to_thread( - self.outbound.put, - messages.Message( - action=MessageToDict(action), - type=messages.MessageType.ACTION_RUN, - ), - ) + msg = messages.Message(_action=MessageToDict(action)) + await asyncio.to_thread(self.outbound.put, msg) logger.trace(MessageToDict(action)) def _grpc_metadata(self): diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index 4211e28f..b06a373c 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -34,3 +34,4 @@ async def test_worker(): await asyncio.sleep(10) await worker.shutdown() await hatchet._runner.shutdown() + return None From 4eb729065971d89d5e2f54f7930e6a57831c466c Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Sun, 15 Sep 2024 13:26:08 -0400 Subject: [PATCH 06/12] continue working on context propagation --- hatchet_sdk/v2/callable.py | 114 ++++++++++++++++++++--------- hatchet_sdk/v2/hatchet.py | 2 +- hatchet_sdk/v2/runtime/context.py | 79 ++++++++++++-------- hatchet_sdk/v2/runtime/messages.py | 8 +- hatchet_sdk/v2/runtime/runner.py | 101 ++++++++++++------------- hatchet_sdk/v2/runtime/runtime.py | 20 +++-- hatchet_sdk/v2/runtime/worker.py | 99 ++++++++++++++++++------- tests/v2/test_worker.py | 15 ++-- 8 files changed, 275 insertions(+), 163 deletions(-) diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 49f8b6fa..27fc9cea 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable # from contextvars import ContextVar, copy_context -from dataclasses import dataclass +from dataclasses import dataclass, asdict # from datetime import timedelta from typing import ( @@ -24,6 +24,8 @@ ) from google.protobuf.json_format import MessageToDict + +# from hatchet_sdk.logger import logger from loguru import logger from pydantic import BaseModel, ConfigDict, Field, computed_field from pydantic.json_schema import SkipJsonSchema @@ -33,6 +35,7 @@ from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl +from hatchet_sdk.contracts.dispatcher_pb2 import AssignedAction from hatchet_sdk.contracts.workflows_pb2 import ( CreateStepRateLimit, CreateWorkflowJobOpts, @@ -43,11 +46,7 @@ WorkflowConcurrencyOpts, WorkflowKind, ) -from hatchet_sdk.contracts.dispatcher_pb2 import AssignedAction from hatchet_sdk.labels import DesiredWorkerLabel - -# from hatchet_sdk.logger import logger -from loguru import logger from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.v2.runtime import registry @@ -154,26 +153,56 @@ def _to_step_proto(self) -> CreateWorkflowStepOpts: ) return step - def _to_trigger_proto(self) -> Optional[TriggerWorkflowOptions]: - with context.EnsureContext(self._hatchet.client) as ctx: - trigger: TriggerWorkflowOptions = { - "parent_id": ctx.parent.workflow_run_id if ctx.parent else None, - "parent_step_run_id": ctx.parent.step_run_id if ctx.parent else None, - } - return trigger - - def _debug(self): - data = { - "self": repr(self), - "metadata": self._hatchet._debug(), - "def_proto": MessageToDict(self._to_workflow_proto()), - "call_proto": ( - MessageToDict(self._to_trigger_proto()) - if self._to_trigger_proto() - else None + def _ctx_to_trigger_proto( + self, ctx: "context.BackgroundContext" + ) -> Optional[TriggerWorkflowOptions]: + # We are not in any valid Hatchet context. This means we're the root. + if ctx.current is None: + return None + + # Otherwise, the current context is the parent. + assert ctx.current is not None + trigger: TriggerWorkflowOptions = { + "parent_id": ctx.current.workflow_run_id, + "parent_step_run_id": ctx.current.step_run_id, + "additional_metadata": json.dumps( + {"_hatchet_background_context": ctx.asdict()} ), } - return data + return trigger + + def _ctx_from_action( + self, action: AssignedAction + ) -> Optional["context.BackgroundContext"]: + if not action.additional_metadata: + return None + + d: Optional[Dict] = None + try: + d = json.loads(action.additional_metadata) + except json.JSONDecodeError: + logger.warning("failed to decode additional metadata from assigned action") + return None + + if "_hatchet_background_context" not in d: + return None + + ctx = context.BackgroundContext.fromdict(d["_hatchet_background_context"]) + ctx.client = self._hatchet.client + return ctx + + # def _debug(self): + # data = { + # "self": repr(self), + # "metadata": self._hatchet._debug(), + # "def_proto": MessageToDict(self._to_workflow_proto()), + # "call_proto": ( + # MessageToDict(self._ctx_to_trigger_proto()) + # if self._to_trigger_proto() + # else None + # ), + # } + # return data def _run(self, ctx: BaseContext) -> str: # actually invokes the function, and serializing the output @@ -182,14 +211,17 @@ def _run(self, ctx: BaseContext) -> str: class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - logger.trace("triggering {}", self._to_trigger_proto()) - input = {"args": args, "kwargs": kwargs} - client = self._hatchet.client - ref = client.admin.trigger_workflow( - self._hatchet.name, input=input, options=self._to_trigger_proto() - ) - logger.trace("runid: {}", ref) - return None + with context.EnsureContext(self._hatchet.client) as ctx: + trigger = self._ctx_to_trigger_proto(ctx) + input = {"args": args, "kwargs": kwargs} + client = self._hatchet.client + + logger.trace("triggering {}", trigger) + ref = client.admin.trigger_workflow( + self._hatchet.name, input=input, options=trigger + ) + logger.trace("runid: {}", ref) + return None def _run(self, action: AssignedAction) -> str: assert action.actionId == self._hatchet.action @@ -197,11 +229,21 @@ def _run(self, action: AssignedAction) -> str: input = json.loads(action.actionPayload)["input"] with context.EnsureContext(self._hatchet.client) as ctx: - ctx.set_step_run_id(action.stepRunId) - ctx.set_workflow_run_id(action.workflowRunId) - ctx.set_parent_workflow_run_id(action.parent_workflow_run_id) - with context.WithContext(ctx): - return self._hatchet.func(*input["args"], **input["kwargs"]) + assert ctx.current is None + + ctx: Optional["context.BackgroundContext"] = self._ctx_from_action(action) + if ctx is None: + info = context.RunInfo( + workflow_run_id=action.workflowRunId, + step_run_id=action.stepRunId, + name=self._hatchet.name, + namespace=self._hatchet.namespace, + ) + ctx = context.BackgroundContext( + client=self._hatchet.client, current=info, root=info + ) + with context.WithParentContext(ctx): + return self._hatchet.func(*input["args"], **input["kwargs"]) class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 310dceb4..15500848 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -10,8 +10,8 @@ import hatchet_sdk.v2.runtime.logging as logging import hatchet_sdk.v2.runtime.registry as registry import hatchet_sdk.v2.runtime.runner as runner -import hatchet_sdk.v2.runtime.worker as worker import hatchet_sdk.v2.runtime.runtime as runtime +import hatchet_sdk.v2.runtime.worker as worker # import hatchet_sdk.runtime.registry as hatchet_registry # import hatchet_sdk.v2.callable as v2_callable diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index 0ae12fd8..ed4bed56 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -4,11 +4,13 @@ import threading from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass, field -from typing import Optional +from dataclasses import dataclass, asdict +from typing import Optional, Dict import hatchet_sdk.v2.hatchet as hatchet +from loguru import logger + def _loopid() -> Optional[int]: try: @@ -23,13 +25,14 @@ def _loopid() -> Optional[int]: @dataclass -class _RunInfo: +class RunInfo: workflow_run_id: Optional[str] = None step_run_id: Optional[str] = None namespace: str = "" name: str = "" + # TODO, pid/tid/loopid is not propagated to the engine, we are not able to restore them pid: int = os.getpid() tid: int = threading.get_ident() loopid: Optional[int] = _loopid() @@ -45,30 +48,38 @@ class BackgroundContext: # The Hatchet client is a required property. client: "hatchet.Hatchet" - current: _RunInfo - root: _RunInfo - parent: Optional[_RunInfo] = None - - def set_workflow_run_id(self, id: str): - self.current.workflow_run_id = id - - def set_step_run_id(self, id: str): - self.current.step_run_id = id + current: Optional[RunInfo] = None + root: Optional[RunInfo] = None + parent: Optional[RunInfo] = None + + def asdict(self): + """Return BackgroundContext as a serializable dict.""" + ret = dict() + if self.current: + ret["current"] = asdict(self.current) + if self.root: + ret["root"] = asdict(self.root) + if self.parent: + ret["parent"] = asdict(self.parent) + return ret - def set_parent_workflow_run_id(self, id: str): - if self.parent is None: - self.parent = _RunInfo( - workflow_run_id=id, step_run_id=None, pid=None, tid=None, loopid=None - ) - else: - self.parent.workflow_run_id = id + @staticmethod + def fromdict(d: Dict) -> "BackgroundContext": + ctx = BackgroundContext() + if "current" in d: + ctx.current = RunInfo(**(d["current"])) + if "root" in d: + ctx.root = RunInfo(**(d["root"])) + if "parent" in d: + ctx.parent = RunInfo(**(d["parent"])) + return ctx def copy(self): ret = BackgroundContext( client=self.client, - current=self.current.copy(), + current=self.current.copy() if self.current else None, parent=self.parent.copy() if self.parent else None, - root=self.root.copy(), + root=self.root.copy() if self.root else None, ) return ret @@ -90,9 +101,10 @@ def EnsureContext(client: Optional["hatchet.Hatchet"] = None): if ctx is None: cleanup = True assert client is not None - ctx = BackgroundContext(client=client, current=_RunInfo(), root=_RunInfo()) + ctx = BackgroundContext(client=client) BackgroundContext.set(ctx) try: + logger.trace("using context:\n{}", ctx) yield ctx finally: if cleanup: @@ -104,19 +116,22 @@ def WithContext(ctx: BackgroundContext): prev = BackgroundContext.get() BackgroundContext.set(ctx) try: + logger.trace("using context:\n{}", ctx) yield ctx finally: BackgroundContext.set(prev) @contextmanager -def EnterFunc(): - with EnsureContext() as current: - child = current.copy() - child.parent = current.current.copy() - child.current = _RunInfo() - with WithContext(child) as ctx: - try: - yield ctx - finally: - pass +def WithParentContext(ctx: BackgroundContext): + prev = BackgroundContext.get() + + child = ctx.copy() + child.parent = ctx.current.copy() + child.current = None + BackgroundContext.set(child) + try: + logger.trace("using context:\n{}", child) + yield child + finally: + BackgroundContext.set(prev) diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py index 28992ab3..85ba9bb8 100644 --- a/hatchet_sdk/v2/runtime/messages.py +++ b/hatchet_sdk/v2/runtime/messages.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import Any, Dict, Optional + +from google.protobuf.json_format import ParseDict from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_COMPLETED, @@ -12,13 +14,11 @@ ActionType, AssignedAction, GroupKeyActionEventType, + StepActionEvent, StepActionEventType, WorkflowRunEventType, - StepActionEvent, ) from hatchet_sdk.worker.action_listener_process import Action -from google.protobuf.json_format import ParseDict -from typing import Optional, Dict class MessageKind(Enum): diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py index 8a02bb5f..d0708eab 100644 --- a/hatchet_sdk/v2/runtime/runner.py +++ b/hatchet_sdk/v2/runtime/runner.py @@ -1,22 +1,23 @@ import asyncio import json import multiprocessing as mp +import time +import traceback from typing import Any, Dict, Optional, Tuple +from google.protobuf.json_format import MessageToDict +from google.protobuf.timestamp_pb2 import Timestamp from loguru import logger +import hatchet_sdk.v2.callable as callable +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.messages as messages from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, + AssignedAction, StepActionEvent, StepActionEventType, ) -import hatchet_sdk.v2.callable as callable -import hatchet_sdk.v2.hatchet as hatchet -import hatchet_sdk.v2.runtime.messages as messages -from google.protobuf.timestamp_pb2 import Timestamp -import time -from google.protobuf.json_format import MessageToDict -import traceback def _timestamp(): @@ -26,53 +27,40 @@ def _timestamp(): def _format_exc(e: Exception): trace = "".join(traceback.format_exception(e)) - return "\n".join[str(e), trace] - - -class _Runner: - def __init__( - self, - registry: Dict[str, "callable.HatchetCallableBase"], - msg: "messages.Message", - ): - self.registry = registry - assert msg.kind == messages.MessageKind.ACTION - self.msg = msg - - async def run(self) -> Tuple[str, Exception]: - logger.trace("runner invoking: {}", repr(self.fn)) - try: - if isinstance(self.fn, callable.HatchetCallable): - return await asyncio.to_thread(self.fn._run, self.msg.action), None - else: - return await self.fn._run(self.msg.action), None - except asyncio.CancelledError: - raise - except Exception as e: - logger.exception(e) - return None, e - - @property - def action(self): - return self.msg.action.actionId - - @property - def fn(self): - return self.registry[self.action] + return "\n".join([str(e), trace]) + + +async def _invoke( + action: AssignedAction, registry: Dict[str, "callable.HatchetCallableBase"] +): + key = action.actionId + fn: "callable.HatchetCallableBase" = registry[key] # TODO + logger.trace("invoking: {}", repr(fn)) + try: + if isinstance(fn, callable.HatchetCallable): + return await asyncio.to_thread(fn._run, action), None + else: + return await fn._run(action), None + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(e) + return None, e class BaseRunnerLoop: def __init__( self, client: "hatchet.Hatchet", - inbound: mp.Queue, # inbound queue - outbound: mp.Queue, # outbound queue + inbound: mp.Queue, # inbound queue, not owned + outbound: mp.Queue, # outbound queue, not owned ): + logger.trace("init runner loop") self.client = client self.registry: Dict[str, "callable.HatchetCallableBase"] = ( client.registry.registry ) - self.worker_id:Optional[str] = None + self.worker_id: Optional[str] = None self.inbound = inbound self.outbound = outbound @@ -80,20 +68,20 @@ def __init__( self.looptask: Optional[asyncio.Task] = None # a dict from StepRunId to its tasks - self.runners: Dict[str, asyncio.Task] = dict() + self.tasks: Dict[str, asyncio.Task] = dict() def start(self): - self.looptask = asyncio.create_task(self.loop(), name="runnerloop") + logger.debug("runner loop started") + self.looptask = asyncio.create_task(self.loop(), name="runner loop") async def shutdown(self): - logger.info("shutting down runner loop") - t = asyncio.gather(*self.runners.values(), self.looptask) - self.outbound.close() + logger.trace("shutting down runner loop") + t = asyncio.gather(*self.tasks.values(), self.looptask) t.cancel() try: await t except asyncio.CancelledError: - logger.info("bye") + logger.debug("bye") async def loop(self): while True: @@ -109,17 +97,18 @@ async def loop(self): def on_run(self, msg: "messages.Message"): async def task(): + logger.trace("running {}", msg.action.stepRunId) try: await self.emit_started(msg) - result, e = await _Runner(self.registry, msg).run() + result, e = await _invoke(msg.action, self.registry) if e is None: await self.emit_finished(msg, result) else: await self.emit_failed(msg, _format_exc(e)) finally: - del self.runners[msg.action.stepRunId] + del self.tasks[msg.action.stepRunId] - self.runners[msg.action.stepRunId] = asyncio.create_task( + self.tasks[msg.action.stepRunId] = asyncio.create_task( task(), name=msg.action.stepRunId ) @@ -170,10 +159,12 @@ async def emit_failed(self, msg: "messages.Message", payload: str): ) async def send(self, msg: "messages.Message"): - logger.trace("sending: {}", msg) + logger.trace("send:\n{}", msg) await asyncio.to_thread(self.outbound.put, msg) async def next(self) -> "messages.Message": - msg = await asyncio.to_thread(self.inbound.get) # raise EOFError if the queue is closed - logger.trace("recv: {}", msg) + msg = await asyncio.to_thread( + self.inbound.get + ) # raise EOFError if the queue is closed + logger.trace("recv:\n{}", msg) return msg diff --git a/hatchet_sdk/v2/runtime/runtime.py b/hatchet_sdk/v2/runtime/runtime.py index 9dea07c1..73e9bf5d 100644 --- a/hatchet_sdk/v2/runtime/runtime.py +++ b/hatchet_sdk/v2/runtime/runtime.py @@ -1,13 +1,16 @@ +import asyncio +import multiprocessing as mp + +from loguru import logger + import hatchet_sdk.v2.hatchet as hatchet -import hatchet_sdk.v2.runtime.worker as worker import hatchet_sdk.v2.runtime.runner as runner -import multiprocessing as mp -import asyncio +import hatchet_sdk.v2.runtime.worker as worker class Runtime: - def __init__(self, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): + logger.trace("init runtime") self.events = mp.Queue() self.actions = mp.Queue() self.worker = worker.Worker( @@ -18,13 +21,20 @@ def __init__(self, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): ) async def start(self): + logger.trace("starting runtime") self.runner.start() await self.worker.start() self.runner.worker_id = self.worker.id + logger.debug("runtime started") return self.worker.id async def shutdown(self): + logger.trace("shutting down runtime") await self.worker.shutdown() - await self.runner.shutdown() self.actions.close() + self.actions.join_thread() + + await self.runner.shutdown() self.events.close() + self.events.join_thread() + logger.debug("bye") diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index fa8870d3..cdf2262e 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -1,5 +1,7 @@ import asyncio import multiprocessing as mp +import os +import threading import time from collections.abc import AsyncGenerator from concurrent.futures import ThreadPoolExecutor @@ -20,6 +22,7 @@ ActionType, AssignedAction, HeartbeatRequest, + StepActionEvent, WorkerLabels, WorkerListenRequest, WorkerRegisterRequest, @@ -59,6 +62,7 @@ class WorkerStatus(Enum): class _HeartBeater: def __init__(self, worker: "Worker"): + logger.debug("init heartbeater") self.worker = worker self.last_heartbeat: int = -1 # unix epoch in seconds self.stub = DispatcherStub( @@ -75,14 +79,15 @@ async def heartbeat(self): now = int(time.time()) proto = HeartbeatRequest( workerId=self.worker.id, - heartbeatAt=timestamp_pb2.Timestamp(seconds=now), + heartbeatAt=timestamp_pb2.Timestamp(seconds=now), # TODO ) try: - resp = self.stub.Heartbeat( + _ = self.stub.Heartbeat( proto, timeout=5, metadata=self.worker._grpc_metadata() ) - logger.trace("heartbeat: {}", MessageToJson(resp)) + logger.trace("heartbeat") except grpc.RpcErrors: + # TODO self.error += 1 if self.last_heartbeat < 0: @@ -95,18 +100,19 @@ async def heartbeat(self): await asyncio.sleep(self.worker.options.heartbeat) finally: - logger.info("shutting down heartbeater") + logger.debug("bye") -class _Listner: +class _ActionListner: def __init__(self, worker: "Worker"): + logger.debug("init action listener") self.worker = worker self.attempt = 0 self.stub = DispatcherStub( connection.new_conn(self.worker.client.config, aio=True) ) - async def listen(self) -> AsyncGenerator[AssignedAction]: + async def listen(self): resp = None try: # It will exit the loop when asyncio.CancelledError is @@ -117,9 +123,11 @@ async def listen(self) -> AsyncGenerator[AssignedAction]: resp = self.stub.ListenV2( proto, metadata=self.worker._grpc_metadata() ) - logger.trace("listening") + logger.trace("connection established") async for event in resp: - yield event + msg = messages.Message(_action=MessageToDict(event)) + logger.trace("assigned action:\n{}", msg) + await asyncio.to_thread(self.queue.put, msg) resp = None self.attempt += 1 @@ -129,9 +137,44 @@ async def listen(self) -> AsyncGenerator[AssignedAction]: # TODO: expotential backoff, retry limit, etc finally: - logger.info("shutting down listener") if resp: resp.cancel() + logger.debug("bye") + + @property + def queue(self): + return self.worker.outbound + + +class _EventListner: + def __init__(self, worker: "Worker"): + logger.debug("init event listener") + self.worker = worker + self.stub = DispatcherStub(connection.new_conn(self.worker.client.config)) + + async def listen(self): + try: + while True: + msg: "messages.Message" = await asyncio.to_thread(self.queue.get) + logger.trace("event:\n{}", msg) + assert msg.kind in [messages.MessageKind.STEP_EVENT] + match msg.kind: + case messages.MessageKind.STEP_EVENT: + await self.on_step_event(msg.step_event) + case _: + raise NotImplementedError(msg.kind) + finally: + logger.debug("bye") + + async def on_step_event(self, e: StepActionEvent): + # TODO: need retry + logger.trace("emit step action:\n{}", MessageToDict(e)) + resp = await asyncio.to_thread(self.stub.SendStepActionEvent, e, metadata=self.worker._grpc_metadata()) + logger.trace(resp) + + @property + def queue(self): + return self.worker.inbound class Worker: @@ -142,6 +185,7 @@ def __init__( outbound: mp.Queue, options: WorkerOptions, ): + logger.debug("init worker") self.options = options self.client = client self.status = WorkerStatus.UNKNOWN @@ -151,47 +195,52 @@ def __init__( self._heartbeater = _HeartBeater(self) self._heartbeater_task: Optional[asyncio.Task] = None - self._listener = _Listner(self) - self._listener_task: Optional[asyncio.Task] = None + self._action_listener = _ActionListner(self) + self._action_listener_task: Optional[asyncio.Task] = None + self._event_listner = _EventListner(self) + self._event_listner_task: Optional[asyncio.Task] = None def _register(self) -> str: + req = self._to_register_proto() + logger.trace("registering worker:\n{}", req) resp: WorkerRegisterResponse = self.client.dispatcher.client.Register( - self._to_register_proto(), + req, timeout=30, metadata=self._grpc_metadata(), ) - logger.debug(f"worker registered: {MessageToDict(resp)}") + logger.debug("worker registered:\n{}", MessageToDict(resp)) self.id = resp.workerId self.status = WorkerStatus.REGISTERED return resp.workerId async def start(self): + logger.trace("starting worker") self._register() self._heartbeat_task = asyncio.create_task( - self._heartbeater.heartbeat(), name="heartbeat" + self._heartbeater.heartbeat(), name="heartbeater" + ) + self._event_listner_task = asyncio.create_task( + self._event_listner.listen(), name="event_listener" ) - self._listener_task = asyncio.create_task( - self._onevent(self._listener.listen()), name="listner" + self._action_listener_task = asyncio.create_task( + self._action_listener.listen(), name="action_listener" ) while True: if self._heartbeater.last_heartbeat > 0: + logger.debug("worker started: {}", self.id) return await asyncio.sleep(0.1) async def shutdown(self): - tg: asyncio.Future = asyncio.gather(self._heartbeat_task, self._listener_task) + logger.trace("shutting down worker {}", self.id) + tg: asyncio.Future = asyncio.gather( + self._heartbeat_task, self._action_listener_task, self._event_listner_task + ) tg.cancel() - self.outbound.close() try: await tg except asyncio.CancelledError: - logger.info("bye") - - async def _onevent(self, agen: AsyncGenerator[AssignedAction]): - async for action in agen: - msg = messages.Message(_action=MessageToDict(action)) - await asyncio.to_thread(self.outbound.put, msg) - logger.trace(MessageToDict(action)) + logger.debug("bye") def _grpc_metadata(self): return [("authorization", f"bearer {self.client.config.token}")] diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index b06a373c..586dfe15 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -21,17 +21,22 @@ @hatchet.function() def foo(): - print("HAHAHA") - pass + print("Foo") + bar("from foo") + return "foo" + + +@hatchet.function() +def bar(x): + print(x) + return "bar" @pytest.mark.asyncio async def test_worker(): - worker = hatchet.worker(WorkerOptions(name="worker", actions=["default:foo"])) + worker = hatchet.worker(WorkerOptions(name="worker", actions=["default:foo", "default:bar"])) await worker.start() - hatchet._runner.start() foo() await asyncio.sleep(10) await worker.shutdown() - await hatchet._runner.shutdown() return None From 2c6d6b46476705cd719a2b4989037f5385cf46f4 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Sun, 15 Sep 2024 14:51:54 -0400 Subject: [PATCH 07/12] context propagation works now, working on result listening --- hatchet_sdk/v2/callable.py | 123 ++++++++++++++++++------------ hatchet_sdk/v2/hatchet.py | 7 +- hatchet_sdk/v2/runtime/context.py | 39 ++++++---- hatchet_sdk/v2/runtime/worker.py | 4 +- 4 files changed, 109 insertions(+), 64 deletions(-) diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 27fc9cea..44846803 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable # from contextvars import ContextVar, copy_context -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass, field # from datetime import timedelta from typing import ( @@ -43,6 +43,7 @@ CreateWorkflowVersionOpts, DesiredWorkerLabels, StickyStrategy, + TriggerWorkflowRequest, WorkflowConcurrencyOpts, WorkflowKind, ) @@ -72,6 +73,33 @@ def _sourceloc(fn) -> str: return "" +@dataclass +class _CallableInput: + args: List[Any] = field(default_factory=list) + kwargs: Dict[str, Any] = field(default_factory=dict) + + def dumps(self): + return json.dumps(asdict(self)) + + @staticmethod + def loads(s: str): + # NOTE: AssignedAction.actionPayload looks like the following + # '{"input": , "parents": {}, "overrides": {}, "user_data": {}, "triggered_by": "manual"}' + return _CallableInput(**(json.loads(s)["input"])) + + +@dataclass +class _CallableOutput(Generic[T]): + output: Optional[T] = None + + def dumps(self): + return json.dumps(asdict(self)) + + @staticmethod + def loads(s: str): + return _CallableOutput(**json.loads(s)) + + class HatchetCallableBase(Generic[P, T]): def __init__( self, @@ -97,23 +125,7 @@ def __init__( def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: options = self._hatchet.options - - # if self.function_on_failure is not None: - # on_failure_job = CreateWorkflowJobOpts( - # name=self.function_name + "-on-failure", - # steps=[ - # self.function_on_failure.to_step(), - # ], - # ) - # # concurrency: WorkflowConcurrencyOpts | None = None - # if self.function_concurrency is not None: - # self.function_concurrency.set_namespace(self.function_namespace) - # concurrency = WorkflowConcurrencyOpts( - # action=self.function_concurrency.get_action_name(), - # max_runs=self.function_concurrency.max_runs, - # limit_strategy=self.function_concurrency.limit_strategy, - # ) - + # handle concurrency function and on failure function workflow = CreateWorkflowVersionOpts( name=self._hatchet.name, kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION, @@ -153,25 +165,35 @@ def _to_step_proto(self) -> CreateWorkflowStepOpts: ) return step - def _ctx_to_trigger_proto( + def _encode_context( self, ctx: "context.BackgroundContext" - ) -> Optional[TriggerWorkflowOptions]: + ) -> TriggerWorkflowRequest: + trigger = TriggerWorkflowRequest( + additional_metadata=json.dumps( + {"_hatchet_background_context": ctx.asdict()} + ), + ) + # We are not in any valid Hatchet context. This means we're the root. if ctx.current is None: - return None + return trigger # Otherwise, the current context is the parent. assert ctx.current is not None - trigger: TriggerWorkflowOptions = { - "parent_id": ctx.current.workflow_run_id, - "parent_step_run_id": ctx.current.step_run_id, - "additional_metadata": json.dumps( - {"_hatchet_background_context": ctx.asdict()} - ), - } + trigger.parent_id = ctx.current.workflow_run_id + trigger.parent_step_run_id = ctx.current.step_run_id + trigger.child_index = 0 # TODO: what is this return trigger - def _ctx_from_action( + def _to_trigger_proto( + self, ctx: "context.BackgroundContext", inputs: _CallableInput + ) -> TriggerWorkflowRequest: + # NOTE: serialization error will be raised as TypeError + req = TriggerWorkflowRequest(name=self._hatchet.name, input=inputs.dumps()) + req.MergeFrom(self._encode_context(ctx)) + return req + + def _decode_context( self, action: AssignedAction ) -> Optional["context.BackgroundContext"]: if not action.additional_metadata: @@ -187,7 +209,9 @@ def _ctx_from_action( if "_hatchet_background_context" not in d: return None - ctx = context.BackgroundContext.fromdict(d["_hatchet_background_context"]) + ctx = context.BackgroundContext.fromdict( + client=self._hatchet.client, data=d["_hatchet_background_context"] + ) ctx.client = self._hatchet.client return ctx @@ -212,38 +236,43 @@ def _run(self, ctx: BaseContext) -> str: class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: with context.EnsureContext(self._hatchet.client) as ctx: - trigger = self._ctx_to_trigger_proto(ctx) - input = {"args": args, "kwargs": kwargs} + trigger = self._to_trigger_proto( + ctx, inputs=_CallableInput(args=args, kwargs=kwargs) + ) + logger.trace("triggering\n{}", MessageToDict(trigger)) client = self._hatchet.client - - logger.trace("triggering {}", trigger) - ref = client.admin.trigger_workflow( - self._hatchet.name, input=input, options=trigger + ref = client.admin.client.TriggerWorkflow( + trigger, metadata=self._hatchet.client._grpc_metadata() ) logger.trace("runid: {}", ref) + + # TODO: wait for the run. return None def _run(self, action: AssignedAction) -> str: assert action.actionId == self._hatchet.action - logger.trace("invoking {}", action.actionId) - input = json.loads(action.actionPayload)["input"] - + logger.trace("invoking:\n{}", MessageToDict(action)) with context.EnsureContext(self._hatchet.client) as ctx: assert ctx.current is None - ctx: Optional["context.BackgroundContext"] = self._ctx_from_action(action) - if ctx is None: - info = context.RunInfo( + parent: Optional["context.BackgroundContext"] = self._decode_context(action) + with context.WithParentContext(parent) as ctx: + assert ctx.current is None + ctx.current = context.RunInfo( workflow_run_id=action.workflowRunId, step_run_id=action.stepRunId, name=self._hatchet.name, namespace=self._hatchet.namespace, ) - ctx = context.BackgroundContext( - client=self._hatchet.client, current=info, root=info - ) - with context.WithParentContext(ctx): - return self._hatchet.func(*input["args"], **input["kwargs"]) + if ctx.root is None: + ctx.root = ctx.current.copy() + with context.WithContext(ctx): + inputs = _CallableInput.loads(action.actionPayload) + output = _CallableOutput( + output=self._hatchet.func(*inputs.args, **inputs.kwargs) + ) + logger.trace("output:\n{}", output) + return output.dumps() class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 15500848..8ae057b2 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -2,7 +2,7 @@ import functools import inspect import multiprocessing as mp -from typing import Callable, Dict, List, Optional, ParamSpec, TypeVar +from typing import Callable, Dict, Tuple, List, Optional, ParamSpec, TypeVar import hatchet_sdk.hatchet as v1 import hatchet_sdk.v2.callable as callable @@ -103,7 +103,10 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": def worker(self, options: "worker.WorkerOptions") -> "runtime.Runtime": return runtime.Runtime(self, options) - # def durable( + def _grpc_metadata(self) -> List[Tuple]: + return [("authorization", f"bearer {self.config.token}")] + + # def durable(s # name: str = "", # auto_register: bool = True, # on_events: list | None = None, diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index ed4bed56..d2086140 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -4,13 +4,13 @@ import threading from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass, asdict -from typing import Optional, Dict - -import hatchet_sdk.v2.hatchet as hatchet +from dataclasses import asdict, dataclass +from typing import Dict, Optional from loguru import logger +import hatchet_sdk.v2.hatchet as hatchet + def _loopid() -> Optional[int]: try: @@ -64,14 +64,14 @@ def asdict(self): return ret @staticmethod - def fromdict(d: Dict) -> "BackgroundContext": - ctx = BackgroundContext() - if "current" in d: - ctx.current = RunInfo(**(d["current"])) - if "root" in d: - ctx.root = RunInfo(**(d["root"])) - if "parent" in d: - ctx.parent = RunInfo(**(d["parent"])) + def fromdict(client: "hatchet.Hatchet", data: Dict) -> "BackgroundContext": + ctx = BackgroundContext(client=client) + if "current" in data: + ctx.current = RunInfo(**(data["current"])) + if "root" in data: + ctx.root = RunInfo(**(data["root"])) + if "parent" in data: + ctx.parent = RunInfo(**(data["parent"])) return ctx def copy(self): @@ -124,14 +124,25 @@ def WithContext(ctx: BackgroundContext): @contextmanager def WithParentContext(ctx: BackgroundContext): + """Use the given context as the parent. + + Note that this is to be used in the following pattern: + + with WithParentContext(parent) as ctx: + ctx.current = ... + with WithContext(ctx): + # code in the correct context here + + """ prev = BackgroundContext.get() + # NOTE: ctx.current could be None, which means there's no parent. + child = ctx.copy() - child.parent = ctx.current.copy() + child.parent = ctx.current.copy() if ctx.current else None child.current = None BackgroundContext.set(child) try: - logger.trace("using context:\n{}", child) yield child finally: BackgroundContext.set(prev) diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index cdf2262e..8a110e0d 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -169,7 +169,9 @@ async def listen(self): async def on_step_event(self, e: StepActionEvent): # TODO: need retry logger.trace("emit step action:\n{}", MessageToDict(e)) - resp = await asyncio.to_thread(self.stub.SendStepActionEvent, e, metadata=self.worker._grpc_metadata()) + resp = await asyncio.to_thread( + self.stub.SendStepActionEvent, e, metadata=self.worker._grpc_metadata() + ) logger.trace(resp) @property From d0ea7c609146ac23b077fe763c9df7fd8b641e3f Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Tue, 17 Sep 2024 22:29:08 -0400 Subject: [PATCH 08/12] iterator adapters, more WIP for workflow run listener --- hatchet_sdk/v2/callable.py | 45 +++++-- hatchet_sdk/v2/hatchet.py | 9 +- hatchet_sdk/v2/runtime/connection.py | 48 ++++++- hatchet_sdk/v2/runtime/context.py | 22 ++-- hatchet_sdk/v2/runtime/listeners.py | 188 +++++++++++++++++++++++++++ hatchet_sdk/v2/runtime/messages.py | 22 ++++ hatchet_sdk/v2/runtime/utils.py | 44 +++++++ hatchet_sdk/v2/runtime/worker.py | 54 +++++++- hatchet_sdk/worker/worker.py | 6 +- tests/v2/test_listeners.py | 36 +++++ tests/v2/test_utils.py | 35 +++++ tests/v2/test_worker.py | 13 +- 12 files changed, 489 insertions(+), 33 deletions(-) create mode 100644 hatchet_sdk/v2/runtime/listeners.py create mode 100644 hatchet_sdk/v2/runtime/utils.py create mode 100644 tests/v2/test_listeners.py create mode 100644 tests/v2/test_utils.py diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 44846803..7f220e47 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -3,7 +3,8 @@ import asyncio import inspect import json -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterator +from concurrent.futures.thread import ThreadPoolExecutor # from contextvars import ContextVar, copy_context from dataclasses import asdict, dataclass, field @@ -14,6 +15,7 @@ Dict, ForwardRef, Generic, + Iterable, List, Literal, Optional, @@ -35,7 +37,12 @@ from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl -from hatchet_sdk.contracts.dispatcher_pb2 import AssignedAction +from hatchet_sdk.contracts.dispatcher_pb2 import ( + AssignedAction, + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, + WorkflowRunEventType, +) from hatchet_sdk.contracts.workflows_pb2 import ( CreateStepRateLimit, CreateWorkflowJobOpts, @@ -44,6 +51,7 @@ DesiredWorkerLabels, StickyStrategy, TriggerWorkflowRequest, + TriggerWorkflowResponse, WorkflowConcurrencyOpts, WorkflowKind, ) @@ -125,7 +133,7 @@ def __init__( def _to_workflow_proto(self) -> CreateWorkflowVersionOpts: options = self._hatchet.options - # handle concurrency function and on failure function + # TODO: handle concurrency function and on failure function workflow = CreateWorkflowVersionOpts( name=self._hatchet.name, kind=WorkflowKind.DURABLE if options.durable else WorkflowKind.FUNCTION, @@ -228,7 +236,7 @@ def _decode_context( # } # return data - def _run(self, ctx: BaseContext) -> str: + def _run(self, action: AssignedAction) -> str: # actually invokes the function, and serializing the output raise NotImplementedError @@ -241,13 +249,34 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ) logger.trace("triggering\n{}", MessageToDict(trigger)) client = self._hatchet.client - ref = client.admin.client.TriggerWorkflow( + ref: TriggerWorkflowResponse = client.admin.client.TriggerWorkflow( trigger, metadata=self._hatchet.client._grpc_metadata() ) logger.trace("runid: {}", ref) - - # TODO: wait for the run. - return None + # TODO: look into timeouts for Future.result() + return self._hatchet.client.executor.submit( + self._result, ref.workflow_run_id + ).result() + + def _result(self, run_id: str) -> Optional[WorkflowRunEvent]: + # NOTE: SubscribeToWorkflowRuns is a stream-stream RPC. + request = SubscribeToWorkflowRunsRequest(workflowRunId=run_id) + stream: Iterator[WorkflowRunEvent] = ( + self._hatchet.client.dispatcher.client.SubscribeToWorkflowRuns( + iter((request,)), + metadata=self._hatchet.client._grpc_metadata(), + ) + ) + logger.debug(list(stream)) + for event in stream: + logger.trace("workflow run events:\n{}", MessageToDict(event)) + if event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED: + return event.results + else: + logger.warning( + "unexpected workflow run events:\n{}", MessageToDict(event) + ) + break def _run(self, action: AssignedAction) -> str: assert action.actionId == self._hatchet.action diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 8ae057b2..62bbca5d 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -2,7 +2,8 @@ import functools import inspect import multiprocessing as mp -from typing import Callable, Dict, Tuple, List, Optional, ParamSpec, TypeVar +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, List, Optional, ParamSpec, Tuple, TypeVar import hatchet_sdk.hatchet as v1 import hatchet_sdk.v2.callable as callable @@ -12,6 +13,7 @@ import hatchet_sdk.v2.runtime.runner as runner import hatchet_sdk.v2.runtime.runtime as runtime import hatchet_sdk.v2.runtime.worker as worker +import hatchet_sdk.v2.runtime.context as context # import hatchet_sdk.runtime.registry as hatchet_registry # import hatchet_sdk.v2.callable as v2_callable @@ -28,6 +30,7 @@ # from ..worker import Worker + T = TypeVar("T") P = ParamSpec("P") @@ -37,6 +40,7 @@ def __init__( self, config: config.ClientConfig = config.ClientConfig(), debug=False, + executor: ThreadPoolExecutor = ThreadPoolExecutor(), ): # ensure a event loop is created before gRPC try: @@ -49,6 +53,9 @@ def __init__( defaults=config, debug=debug, ) + self.executor = executor + + context.ensure_background_context(client=self) @property def admin(self): diff --git a/hatchet_sdk/v2/runtime/connection.py b/hatchet_sdk/v2/runtime/connection.py index 76331dc5..cfd313b4 100644 --- a/hatchet_sdk/v2/runtime/connection.py +++ b/hatchet_sdk/v2/runtime/connection.py @@ -1 +1,47 @@ -from hatchet_sdk.connection import * +import hatchet_sdk.connection as v1 +import contextvars as cv +import grpc +import grpc.aio +from typing import Optional + +import hatchet_sdk.v2.runtime.context as context + + +_aio_channel_cv: cv.ContextVar[Optional[grpc.aio.Channel]] = cv.ContextVar( + "hatchet_background_aio_channel", default=None +) +_channel_cv: cv.ContextVar[Optional[grpc.Channel]] = cv.ContextVar( + "hatchet_background_channel", default=None +) + + +def ensure_background_channel() -> grpc.Channel: + ctx = context.ensure_background_context(client=None) + channel: grpc.Channel = _channel_cv.get() + if channel is None: + channel = v1.new_conn(ctx.client.config, aio=False) + _channel_cv.set(channel) + return channel + + +def ensure_background_achannel() -> grpc.aio.Channel: + ctx = context.ensure_background_context(client=None) + achannel: grpc.aio.Channel = _aio_channel_cv.get() + if achannel is None: + achannel = v1.new_conn(ctx.client.config, aio=True) + _aio_channel_cv.set(achannel) + return achannel + + +def reset_background_channel(): + c = _channel_cv.get() + if c is not None: + c.close() + _channel_cv.set(None) + + +async def reset_background_achannel(): + c: grpc.aio.Channel = _aio_channel_cv.get() + if c is not None: + await c.close() + _aio_channel_cv.set(None) diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index d2086140..2c35dacb 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -94,21 +94,15 @@ def get() -> Optional["BackgroundContext"]: return _ctxvar.get() -@contextmanager -def EnsureContext(client: Optional["hatchet.Hatchet"] = None): - cleanup = False +def ensure_background_context( + client: Optional["hatchet.Hatchet"] = None, +) -> BackgroundContext: ctx = BackgroundContext.get() if ctx is None: - cleanup = True assert client is not None ctx = BackgroundContext(client=client) BackgroundContext.set(ctx) - try: - logger.trace("using context:\n{}", ctx) - yield ctx - finally: - if cleanup: - BackgroundContext.set(None) + return ctx @contextmanager @@ -124,15 +118,15 @@ def WithContext(ctx: BackgroundContext): @contextmanager def WithParentContext(ctx: BackgroundContext): - """Use the given context as the parent. - + """Use the given context as the parent. + Note that this is to be used in the following pattern: - + with WithParentContext(parent) as ctx: ctx.current = ... with WithContext(ctx): # code in the correct context here - + """ prev = BackgroundContext.get() diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py new file mode 100644 index 00000000..a6173a9a --- /dev/null +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -0,0 +1,188 @@ +import asyncio +from asyncio.taskgroups import TaskGroup +import multiprocessing as mp +import os +import threading +import time +from collections.abc import AsyncGenerator, Generator, AsyncIterator +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Generic, List, Optional, Set, TypeVar, Literal +from contextlib import suppress + + +import grpc +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToDict, MessageToJson +from loguru import logger + +import hatchet_sdk.contracts.dispatcher_pb2 +import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.connection as connection +import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.worker as worker +from hatchet_sdk.contracts.dispatcher_pb2 import ( + ActionType, + AssignedAction, + HeartbeatRequest, + StepActionEvent, + SubscribeToWorkflowRunsRequest, + WorkerLabels, + WorkerListenRequest, + WorkerRegisterRequest, + WorkerRegisterResponse, + WorkerUnsubscribeRequest, + WorkflowRunEventType, + StepRunResult, + WorkflowRunEvent, +) +from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub + + +T = TypeVar("T") + + +# class _GrpcAioListnerBase(Generic[T]): +# def __init__(self): +# self.attempt = 0 +# self.interrupt = False + +# def channel(self): +# raise NotImplementedError() + +# def stub(self, channel): +# raise NotImplementedError + +# def request(self, stub): +# raise NotImplementedError() + +# def interrupt(self): +# self.interrupt = True + +# async def listen(self) -> AsyncGenerator[T]: +# while True: +# stub: DispatcherStub = self.stub() +# stub.ListenV2 + +# stream = None +# try: +# stream = self.request(stub) +# async for msg in stream: +# if not self.interrupt: +# yield msg +# except grpc.aio.AioRpcError as e: +# logger.warning(e) +# finally: +# if stream is not None: +# stream.cancel() +# self.interrupt = False +# self.attempt += 1 + +# def read(self) -> Generator[T]: +# stub = self.stub() +# stream = self.request(stub) + + + + +class WorkflowRunEventListener: + + @dataclass + class Sub: + id: str + run_id: str + future: asyncio.Future[List[StepRunResult]] + + def __hash__(self): + return hash(self.id) + + def __init__(self): + logger.trace("init workflow run event listener") + self._token = context.ensure_background_context(None).client.config.token + + # the set of active subscriptions + self._subs: Set[WorkflowRunEventListener.Sub] = set() + + # counter used for generating subscription ids + # not thread safe + self._counter = 0 + + # index from run id to subscriptions + self._by_run_id: Dict[str, WorkflowRunEventListener.Sub] = dict() + + # queue used for iterating requests + # must be created inside the loop + self._q_request: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + + self._events_agen: AsyncGenerator[WorkflowRunEvent] = self._events() + + async def loop(self): + await self._events_agen.aclose() + async for event in self._events_agen: + assert ( + event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED + ) + self._by_run_id[event.workflowRunId].future.set_result(list(event.results)) + self._unsubscribe(event.workflowRunId) + + async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: + + # keep trying until asyncio.CancelledError is raised into this coroutine + # TODO: handle retry, backoff, etc. + stub = DispatcherStub(channel=connection.ensure_background_achannel()) + agen = self._requests() + while True: + try: + stream: grpc.aio.StreamStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ] = stub.SubscribeToWorkflowRuns( + agen, + metadata=[("authorization", f"bearer {self._token}")], + ) + logger.trace("stream established") + async for event in stream: + logger.trace(event) + yield event + + except grpc.aio.AioRpcError as e: + logger.exception(e) + pass + + await self._resubscribe() + + async def _requests(self) -> AsyncGenerator[SubscribeToWorkflowRunsRequest]: + while True: + req = await self._q_request.get() + logger.trace("client streaming req to server: {}", MessageToDict(req)) + yield req + + async def _resubscribe(self): + logger.trace("re-subscribing all") + async with asyncio.TaskGroup() as tg: + for id in self._by_run_id.keys(): + tg.create_task( + self._q_request.put( + SubscribeToWorkflowRunsRequest(workflowRunId=id) + ) + ) + + async def subscribe(self, run_id: str) -> "WorkflowRunEventListener.Sub": + if run_id in self._by_run_id: + return + logger.trace("subscribing: {}", run_id) + await self._q_request.put(SubscribeToWorkflowRunsRequest(workflowRunId=run_id)) + sub = self.Sub(id=self._counter, run_id=run_id, future=asyncio.Future()) + self._subs.add(sub) + self._by_run_id[run_id] = sub + self._counter += 1 + return sub + + def _unsubscribe(self, run_id: str): + logger.trace("unsubscribing {}", run_id) + sub = self._by_run_id.get(run_id, None) + if sub is None: + return + self._subs.remove(sub) + del self._by_run_id[run_id] diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py index 85ba9bb8..d4e01ac9 100644 --- a/hatchet_sdk/v2/runtime/messages.py +++ b/hatchet_sdk/v2/runtime/messages.py @@ -16,6 +16,8 @@ GroupKeyActionEventType, StepActionEvent, StepActionEventType, + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, WorkflowRunEventType, ) from hatchet_sdk.worker.action_listener_process import Action @@ -25,6 +27,8 @@ class MessageKind(Enum): UNKNOWN = 0 ACTION = 1 STEP_EVENT = 2 + WORKFLOW_RUN_EVENT = 3 + SUBSCRIBE_TO_WORKFLOW_RUN = 4 @dataclass @@ -33,6 +37,8 @@ class Message: _action: Optional[Dict] = None _step_event: Optional[Dict] = None + _workflow_run_event: Optional[Dict] = None + _subscribe_to_workflow_run: Optional[Dict] = None @property def kind(self) -> MessageKind: @@ -40,6 +46,10 @@ def kind(self) -> MessageKind: return MessageKind.ACTION if self._step_event is not None: return MessageKind.STEP_EVENT + if self._workflow_run_event is not None: + return MessageKind.WORKFLOW_RUN_EVENT + if self._subscribe_to_workflow_run is not None: + return MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN return MessageKind.UNKNOWN @property @@ -53,3 +63,15 @@ def step_event(self) -> StepActionEvent: assert self._step_event is not None ret = StepActionEvent() return ParseDict(self._step_event, ret) + + @property + def workflow_run_event(self) -> WorkflowRunEvent: + assert self._workflow_run_event is not None + ret = WorkflowRunEvent() + return ParseDict(self._workflow_run_event, ret) + + @property + def subscribe_to_workflow_run(self) -> SubscribeToWorkflowRunsRequest: + assert self._subscribe_to_workflow_run is not None + ret = SubscribeToWorkflowRunsRequest() + return ParseDict(self._subscribe_to_workflow_run, ret) diff --git a/hatchet_sdk/v2/runtime/utils.py b/hatchet_sdk/v2/runtime/utils.py new file mode 100644 index 00000000..1eea30a4 --- /dev/null +++ b/hatchet_sdk/v2/runtime/utils.py @@ -0,0 +1,44 @@ +from collections.abc import AsyncGenerator +import asyncio + +from typing import TypeVar + +from contextlib import suppress + +T = TypeVar("T") +I = TypeVar("I") + + +async def InterruptableAgen( + agen: AsyncGenerator[T], + interrupt: asyncio.Queue[I], + timeout: float, +) -> AsyncGenerator[T | I]: + + queue: asyncio.Queue[T | StopAsyncIteration] = asyncio.Queue() + + async def producer(): + async for item in agen: + await queue.put(item) + await queue.put(StopAsyncIteration()) + + producer_task = asyncio.create_task(producer()) + + while True: + with suppress(asyncio.TimeoutError): + item = await asyncio.wait_for(queue.get(), timeout=timeout) + # it is not timeout if we reach this line + if isinstance(item, StopAsyncIteration): + break + else: + yield item + + with suppress(asyncio.QueueEmpty): + v = interrupt.get_nowait() + # we are interrupted if we reach this line + yield v + break + + producer_task.cancel() + with suppress(asyncio.CancelledError): + await producer_task diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index 8a110e0d..ca5c0fc8 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Set +from typing import Dict, Generic, List, Optional, Set, TypeVar import grpc from google.protobuf import timestamp_pb2 @@ -23,11 +23,13 @@ AssignedAction, HeartbeatRequest, StepActionEvent, + SubscribeToWorkflowRunsRequest, WorkerLabels, WorkerListenRequest, WorkerRegisterRequest, WorkerRegisterResponse, WorkerUnsubscribeRequest, + WorkflowRunEvent, ) from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub @@ -103,6 +105,56 @@ async def heartbeat(self): logger.debug("bye") +T = TypeVar("T") + + +class _GrpcAioListnerBase(Generic[T]): + def __init__(self): + self.attempt = 0 + self.interrupt = False + + def stub(self): + raise NotImplementedError() + + def request(self, stub): + raise NotImplementedError() + + def interrupt(self): + self.interrupt = True + + async def listen(self) -> AsyncGenerator[T]: + while True: + stub = self.stub() + stream = None + try: + stream = self.request(stub) + async for msg in stream: + if not self.interrupt: + yield msg + except grpc.aio.AioRpcError as e: + logger.warning(e) + finally: + if stream is not None: + stream.cancel() + self.interrupt = False + self.attempt += 1 + + +class _WorkflowRunListner(_GrpcAioListnerBase[WorkflowRunEvent]): + def __init__(self, worker: "Worker", run_id: str, stub: DispatcherStub): + super().__init__() + self._worker = worker + self._run_id = run_id + self._stub = stub + + def stub(self) -> DispatcherStub: + return self._stub + + def request(self, stub: DispatcherStub): + req = SubscribeToWorkflowRunsRequest(workflowRunId=self._run_id) + return stub.SubscribeToWorkflowRuns(req, metadata=self._worker._grpc_metadata()) + + class _ActionListner: def __init__(self, worker: "Worker"): logger.debug("init action listener") diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 0160dcda..36c6b07d 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -120,7 +120,7 @@ def setup_loop(self, loop: asyncio.AbstractEventLoop = None): def start(self, options: WorkerStartOptions = WorkerStartOptions()): created_loop = self.setup_loop(options.loop) - f = asyncio.run_coroutine_threadsafe( + self.result_f = asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) # start the loop and wait until its closed @@ -129,7 +129,7 @@ def start(self, options: WorkerStartOptions = WorkerStartOptions()): if self.handle_kill: sys.exit(0) - return f + return self.result_f ## Start methods async def async_start( @@ -265,7 +265,7 @@ async def exit_gracefully(self): self.action_listener_process.kill() await self.close() - + # self.result_f.set_result("") if self.loop: self.loop.stop() diff --git a/tests/v2/test_listeners.py b/tests/v2/test_listeners.py new file mode 100644 index 00000000..713b2098 --- /dev/null +++ b/tests/v2/test_listeners.py @@ -0,0 +1,36 @@ +import asyncio +import logging +import sys + +import dotenv +import pytest +from loguru import logger + +from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk.v2.runtime.listeners import WorkflowRunEventListener + +logger.remove() +logger.add(sys.stdout, level="TRACE") + +dotenv.load_dotenv() + +hatchet = Hatchet(debug=True) + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +async def interrupt(listener): + await asyncio.sleep(2) + logger.trace("interupt") + await listener._interrupt() + logger.trace("interrupted") + + +@pytest.mark.asyncio +async def test_listener_shutdown(): + listener = WorkflowRunEventListener() + task = asyncio.create_task(listener.loop()) + task2 = asyncio.create_task(interrupt(listener)) + sub = await listener.subscribe("bar-vj13ex/bar") + await sub.future + await task diff --git a/tests/v2/test_utils.py b/tests/v2/test_utils.py new file mode 100644 index 00000000..1585a503 --- /dev/null +++ b/tests/v2/test_utils.py @@ -0,0 +1,35 @@ +import asyncio +import logging + + +import pytest +from loguru import logger + +from hatchet_sdk.v2.runtime.utils import interuptable + + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +async def producer(): + for i in range(10): + await asyncio.sleep(0.5) + logger.info("yielding {}", i) + yield i + + +async def consumer(agen): + async for item in agen: + logger.info("consuming: {}", item) + + +@pytest.mark.asyncio +async def test_interruptable_agen(): + + q = asyncio.Queue() + agen = interuptable(producer(), q, 1) + + async with asyncio.TaskGroup() as tg: + tg.create_task(consumer(agen)) + await asyncio.sleep(2) + await q.put({}) diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index 586dfe15..8667a8b1 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -21,22 +21,25 @@ @hatchet.function() def foo(): - print("Foo") - bar("from foo") + print("entering Foo") + print("result from bar: ", bar("from foo")) return "foo" @hatchet.function() def bar(x): - print(x) + print("entering Bar") + print("arguments for bar: ", x) return "bar" @pytest.mark.asyncio async def test_worker(): - worker = hatchet.worker(WorkerOptions(name="worker", actions=["default:foo", "default:bar"])) + worker = hatchet.worker( + WorkerOptions(name="worker", actions=["default:foo", "default:bar"]) + ) await worker.start() - foo() + print("result from foo: ", foo()) await asyncio.sleep(10) await worker.shutdown() return None From 9cbd0e34c7f5bd8412f51531cff96ef775b657f9 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Wed, 18 Sep 2024 14:38:11 -0400 Subject: [PATCH 09/12] wip --- hatchet_sdk/v2/runtime/listeners.py | 2 -- hatchet_sdk/v2/runtime/worker.py | 47 ----------------------------- 2 files changed, 49 deletions(-) diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py index a6173a9a..d542f442 100644 --- a/hatchet_sdk/v2/runtime/listeners.py +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -85,8 +85,6 @@ # stream = self.request(stub) - - class WorkflowRunEventListener: @dataclass diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index ca5c0fc8..1dd02826 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -108,53 +108,6 @@ async def heartbeat(self): T = TypeVar("T") -class _GrpcAioListnerBase(Generic[T]): - def __init__(self): - self.attempt = 0 - self.interrupt = False - - def stub(self): - raise NotImplementedError() - - def request(self, stub): - raise NotImplementedError() - - def interrupt(self): - self.interrupt = True - - async def listen(self) -> AsyncGenerator[T]: - while True: - stub = self.stub() - stream = None - try: - stream = self.request(stub) - async for msg in stream: - if not self.interrupt: - yield msg - except grpc.aio.AioRpcError as e: - logger.warning(e) - finally: - if stream is not None: - stream.cancel() - self.interrupt = False - self.attempt += 1 - - -class _WorkflowRunListner(_GrpcAioListnerBase[WorkflowRunEvent]): - def __init__(self, worker: "Worker", run_id: str, stub: DispatcherStub): - super().__init__() - self._worker = worker - self._run_id = run_id - self._stub = stub - - def stub(self) -> DispatcherStub: - return self._stub - - def request(self, stub: DispatcherStub): - req = SubscribeToWorkflowRunsRequest(workflowRunId=self._run_id) - return stub.SubscribeToWorkflowRuns(req, metadata=self._worker._grpc_metadata()) - - class _ActionListner: def __init__(self, worker: "Worker"): logger.debug("init action listener") From f8360a1feb9d7c4794145a3b3b370a7359aec867 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Wed, 18 Sep 2024 16:28:48 -0400 Subject: [PATCH 10/12] more changes to generator adaptors and workers. working on the main thread runtime to handle workflow run results --- hatchet_sdk/v2/runtime/listeners.py | 186 ++++++++++++++++------------ hatchet_sdk/v2/runtime/utils.py | 39 +++++- hatchet_sdk/v2/runtime/worker.py | 146 +++++++++------------- tests/v2/test_utils.py | 6 +- 4 files changed, 203 insertions(+), 174 deletions(-) diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py index d542f442..43f71bdb 100644 --- a/hatchet_sdk/v2/runtime/listeners.py +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -4,11 +4,11 @@ import os import threading import time -from collections.abc import AsyncGenerator, Generator, AsyncIterator +from collections.abc import AsyncGenerator, Generator, AsyncIterator, Callable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Dict, Generic, List, Optional, Set, TypeVar, Literal +from typing import Dict, Generic, List, Optional, Set, TypeVar, Literal, Any from contextlib import suppress @@ -23,6 +23,7 @@ import hatchet_sdk.v2.runtime.context as context import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.worker as worker +import hatchet_sdk.v2.runtime.utils as utils from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, @@ -44,61 +45,19 @@ T = TypeVar("T") -# class _GrpcAioListnerBase(Generic[T]): -# def __init__(self): -# self.attempt = 0 -# self.interrupt = False - -# def channel(self): -# raise NotImplementedError() - -# def stub(self, channel): -# raise NotImplementedError - -# def request(self, stub): -# raise NotImplementedError() - -# def interrupt(self): -# self.interrupt = True - -# async def listen(self) -> AsyncGenerator[T]: -# while True: -# stub: DispatcherStub = self.stub() -# stub.ListenV2 - -# stream = None -# try: -# stream = self.request(stub) -# async for msg in stream: -# if not self.interrupt: -# yield msg -# except grpc.aio.AioRpcError as e: -# logger.warning(e) -# finally: -# if stream is not None: -# stream.cancel() -# self.interrupt = False -# self.attempt += 1 - -# def read(self) -> Generator[T]: -# stub = self.stub() -# stream = self.request(stub) - - class WorkflowRunEventListener: @dataclass class Sub: id: str run_id: str - future: asyncio.Future[List[StepRunResult]] + future: asyncio.Future[WorkflowRunEvent] def __hash__(self): return hash(self.id) def __init__(self): - logger.trace("init workflow run event listener") - self._token = context.ensure_background_context(None).client.config.token + logger.debug("init workflow run event listener") # the set of active subscriptions self._subs: Set[WorkflowRunEventListener.Sub] = set() @@ -114,47 +73,40 @@ def __init__(self): # must be created inside the loop self._q_request: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() - self._events_agen: AsyncGenerator[WorkflowRunEvent] = self._events() - async def loop(self): - await self._events_agen.aclose() - async for event in self._events_agen: - assert ( - event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED - ) - self._by_run_id[event.workflowRunId].future.set_result(list(event.results)) - self._unsubscribe(event.workflowRunId) + try: + agen = utils.ForeverAgen(self._events, exceptions=(grpc.aio.AioRpcError,)) + async for event in agen: + if isinstance(event, grpc.aio.AioRpcError): + logger.trace("encountered error, retrying: {}", event) + await self._resubscribe() + + else: + self._by_run_id[event.workflowRunId].future.set_result(event) + self._unsubscribe(event.workflowRunId) + finally: + logger.debug("bye") async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: # keep trying until asyncio.CancelledError is raised into this coroutine # TODO: handle retry, backoff, etc. stub = DispatcherStub(channel=connection.ensure_background_achannel()) - agen = self._requests() - while True: - try: - stream: grpc.aio.StreamStreamCall[ - SubscribeToWorkflowRunsRequest, WorkflowRunEvent - ] = stub.SubscribeToWorkflowRuns( - agen, - metadata=[("authorization", f"bearer {self._token}")], - ) - logger.trace("stream established") - async for event in stream: - logger.trace(event) - yield event - - except grpc.aio.AioRpcError as e: - logger.exception(e) - pass - - await self._resubscribe() - - async def _requests(self) -> AsyncGenerator[SubscribeToWorkflowRunsRequest]: - while True: - req = await self._q_request.get() - logger.trace("client streaming req to server: {}", MessageToDict(req)) - yield req + requests = utils.QueueAgen(self._q_request) + + stream: grpc.aio.StreamStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ] = stub.SubscribeToWorkflowRuns( + requests, + metadata=context.ensure_background_context().client._grpc_metadata(), + ) + logger.trace("stream established") + async for event in stream: + logger.trace("received workflow run event:\n{}", event) + assert ( + event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED + ) + yield event async def _resubscribe(self): logger.trace("re-subscribing all") @@ -178,9 +130,81 @@ async def subscribe(self, run_id: str) -> "WorkflowRunEventListener.Sub": return sub def _unsubscribe(self, run_id: str): - logger.trace("unsubscribing {}", run_id) + logger.trace("unsubscribing: {}", run_id) sub = self._by_run_id.get(run_id, None) if sub is None: return self._subs.remove(sub) del self._by_run_id[run_id] + + +class AssignedActionListner: + def __init__(self, worker: "worker.Worker", interrupt: asyncio.Queue[T]): + logger.debug("init action listener") + self._worker = worker + self._interrupt = interrupt + + async def _action_stream(self) -> AsyncGenerator[AssignedAction]: + stub = DispatcherStub(connection.ensure_background_achannel()) + proto = WorkerListenRequest(workerId=self._worker.id) + resp = stub.ListenV2( + proto, + metadata=context.ensure_background_context(None).client._grpc_metadata(), + ) + logger.trace("connection established") + async for action in resp: + logger.trace("assigned action:\n{}", MessageToDict(action)) + yield action + + async def listen(self) -> AsyncGenerator[AssignedAction | grpc.aio.AioRpcError | T]: + try: + + def agen_factory(): + return utils.InterruptableAgen( + self._action_stream(), interrupt=self._interrupt, timeout=5 + ) + + agen = utils.ForeverAgen(agen_factory, exceptions=(grpc.aio.AioRpcError,)) + async for action in agen: + if isinstance(action, grpc.aio.AioRpcError): + logger.trace("encountered error, retrying: {}", action) + yield action + else: + yield action + finally: + logger.debug("bye") + + +class StepEventListener: + def __init__(self, inbound: asyncio.Queue["messages.Message"]): + logger.debug("init event listener") + self.inbound = inbound + self.stub = DispatcherStub(connection.ensure_background_channel()) + + async def _message_stream(self) -> AsyncGenerator["messages.Message"]: + while True: + msg: "messages.Message" = await asyncio.to_thread(self.inbound.get) + assert msg.kind in [messages.MessageKind.STEP_EVENT] + logger.trace("event:\n{}", msg) + yield msg + + async def listen(self): + try: + async for msg in self._message_stream(): + match msg.kind: + case messages.MessageKind.STEP_EVENT: + await self.on_step_event(msg.step_event) + case _: + raise NotImplementedError(msg.kind) + finally: + logger.debug("bye") + + async def on_step_event(self, e: StepActionEvent): + # TODO: need retry + logger.trace("emit step action:\n{}", MessageToDict(e)) + resp = await asyncio.to_thread( + self.stub.SendStepActionEvent, + e, + metadata=context.ensure_background_context().client._grpc_metadata(), + ) + logger.trace(resp) diff --git a/hatchet_sdk/v2/runtime/utils.py b/hatchet_sdk/v2/runtime/utils.py index 1eea30a4..4e3958cf 100644 --- a/hatchet_sdk/v2/runtime/utils.py +++ b/hatchet_sdk/v2/runtime/utils.py @@ -1,7 +1,11 @@ -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable import asyncio +import tenacity +import grpc +import multiprocessing as mp +import multiprocessing.queues as mpq +from typing import TypeVar, Tuple -from typing import TypeVar from contextlib import suppress @@ -42,3 +46,34 @@ async def producer(): producer_task.cancel() with suppress(asyncio.CancelledError): await producer_task + + +async def ForeverAgen( + agen_factory: Callable[[], AsyncGenerator[T]], exceptions: Tuple[Exception] +) -> AsyncGenerator[T | Exception]: + """Run a async generator forever until its cancelled. + + Args: + agen_factory: a callable that returns the async generator of type T + exceptions: a tuple of exceptions that should be suppressed and yielded. + Exceptions not listed here will be re-raised. + + Returns: + An async generator that yields T or yields the suppressed exceptions. + """ + while True: + agen = agen_factory() + try: + async for item in agen: + yield item + except Exception as e: + if isinstance(e, exceptions): + yield e + else: + raise + + +async def QueueAgen(inbound: asyncio.Queue[T] | mpq.Queue[T]) -> AsyncGenerator[T]: + while True: + item = await asyncio.to_thread(inbound.get) + yield item diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index 1dd02826..40295944 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -18,6 +18,11 @@ import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.connection as connection import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.listeners as listeners +import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.utils as utils + + from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, @@ -28,6 +33,7 @@ WorkerListenRequest, WorkerRegisterRequest, WorkerRegisterResponse, + StepRunResult, WorkerUnsubscribeRequest, WorkflowRunEvent, ) @@ -108,112 +114,48 @@ async def heartbeat(self): T = TypeVar("T") -class _ActionListner: - def __init__(self, worker: "Worker"): - logger.debug("init action listener") - self.worker = worker - self.attempt = 0 - self.stub = DispatcherStub( - connection.new_conn(self.worker.client.config, aio=True) - ) - - async def listen(self): - resp = None - try: - # It will exit the loop when asyncio.CancelledError is - # raised by calling task.cancel() from outside. - while True: - proto = WorkerListenRequest(workerId=self.worker.id) - try: - resp = self.stub.ListenV2( - proto, metadata=self.worker._grpc_metadata() - ) - logger.trace("connection established") - async for event in resp: - msg = messages.Message(_action=MessageToDict(event)) - logger.trace("assigned action:\n{}", msg) - await asyncio.to_thread(self.queue.put, msg) - - resp = None - self.attempt += 1 - except grpc.aio.AioRpcError as e: - logger.warning(e) - - # TODO: expotential backoff, retry limit, etc - - finally: - if resp: - resp.cancel() - logger.debug("bye") - - @property - def queue(self): - return self.worker.outbound - - -class _EventListner: - def __init__(self, worker: "Worker"): - logger.debug("init event listener") - self.worker = worker - self.stub = DispatcherStub(connection.new_conn(self.worker.client.config)) - - async def listen(self): - try: - while True: - msg: "messages.Message" = await asyncio.to_thread(self.queue.get) - logger.trace("event:\n{}", msg) - assert msg.kind in [messages.MessageKind.STEP_EVENT] - match msg.kind: - case messages.MessageKind.STEP_EVENT: - await self.on_step_event(msg.step_event) - case _: - raise NotImplementedError(msg.kind) - finally: - logger.debug("bye") - - async def on_step_event(self, e: StepActionEvent): - # TODO: need retry - logger.trace("emit step action:\n{}", MessageToDict(e)) - resp = await asyncio.to_thread( - self.stub.SendStepActionEvent, e, metadata=self.worker._grpc_metadata() - ) - logger.trace(resp) - - @property - def queue(self): - return self.worker.inbound - - class Worker: def __init__( self, client: "hatchet.Hatchet", - inbound: mp.Queue, - outbound: mp.Queue, + inbound: mp.Queue["messages.Message"], + outbound: mp.Queue["messages.Message"], options: WorkerOptions, ): logger.debug("init worker") + context.ensure_background_context(client=client) + self.options = options self.client = client - self.status = WorkerStatus.UNKNOWN self.id: Optional[str] = None self.inbound = inbound self.outbound = outbound + self.status = WorkerStatus.UNKNOWN self._heartbeater = _HeartBeater(self) self._heartbeater_task: Optional[asyncio.Task] = None - self._action_listener = _ActionListner(self) - self._action_listener_task: Optional[asyncio.Task] = None - self._event_listner = _EventListner(self) + + self._action_listener_interrupt: asyncio.Queue[StopAsyncIteration] = ( + asyncio.Queue() + ) + self._action_listener = listeners.AssignedActionListner( + worker=self, interrupt=self._action_listener_interrupt + ) + + self._event_listner_q: asyncio.Queue["messages.Message"] = asyncio.Queue() + self._event_listner = listeners.StepEventListener(self._event_listner_q) self._event_listner_task: Optional[asyncio.Task] = None + self._workflow_run_event_listener = listeners.WorkflowRunEventListener() + self._workflow_run_event_listener_task: Optional[asyncio.Task] = None + def _register(self) -> str: req = self._to_register_proto() logger.trace("registering worker:\n{}", req) resp: WorkerRegisterResponse = self.client.dispatcher.client.Register( req, timeout=30, - metadata=self._grpc_metadata(), + metadata=context.ensure_background_context().client._grpc_metadata(), ) logger.debug("worker registered:\n{}", MessageToDict(resp)) self.id = resp.workerId @@ -229,15 +171,44 @@ async def start(self): self._event_listner_task = asyncio.create_task( self._event_listner.listen(), name="event_listener" ) - self._action_listener_task = asyncio.create_task( - self._action_listener.listen(), name="action_listener" - ) while True: if self._heartbeater.last_heartbeat > 0: logger.debug("worker started: {}", self.id) return await asyncio.sleep(0.1) + async def server_message_loop(self): + async for action in self._action_listener.listen(): + if isinstance(action, StopAsyncIteration): + # interrupted, ignore + pass + elif isinstance(action, grpc.aio.AioRpcError): + # errored out, ignored + pass + else: + assert isinstance(action, AssignedAction) + msg = messages.Message(_action=MessageToDict(action)) + await asyncio.to_thread(self.outbound.put, msg) + + async def client_message_loop(self): + async for msg in utils.QueueAgen(self.inbound): + match msg.kind: + case messages.MessageKind.STEP_EVENT: + await asyncio.to_thread(self._event_listner_q.put, msg) + case messages.MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN: + await self.on_workflow_run_subscription(msg) + + async def on_workflow_run_subscription(self, msg: "messages.Message"): + def callback(f: asyncio.Future[WorkflowRunEvent]): + self.outbound.put( + messages.Message(_workflow_run_event=MessageToDict(f.result())) + ) + + sub = await self._workflow_run_event_listener.subscribe( + msg.subscribe_to_workflow_run.workflowRunId + ) + sub.future.add_done_callback(callback) + async def shutdown(self): logger.trace("shutting down worker {}", self.id) tg: asyncio.Future = asyncio.gather( @@ -249,9 +220,6 @@ async def shutdown(self): except asyncio.CancelledError: logger.debug("bye") - def _grpc_metadata(self): - return [("authorization", f"bearer {self.client.config.token}")] - def _to_register_proto(self) -> WorkerRegisterRequest: options = self.options proto = WorkerRegisterRequest( diff --git a/tests/v2/test_utils.py b/tests/v2/test_utils.py index 1585a503..7629f87a 100644 --- a/tests/v2/test_utils.py +++ b/tests/v2/test_utils.py @@ -5,7 +5,7 @@ import pytest from loguru import logger -from hatchet_sdk.v2.runtime.utils import interuptable +from hatchet_sdk.v2.runtime.utils import InterruptableAgen, ForeverAgen logging.getLogger("asyncio").setLevel(logging.DEBUG) @@ -27,7 +27,9 @@ async def consumer(agen): async def test_interruptable_agen(): q = asyncio.Queue() - agen = interuptable(producer(), q, 1) + + agen_factory = lambda: InterruptableAgen(producer(), q, 1) + agen = ForeverAgen(agen_factory) async with asyncio.TaskGroup() as tg: tg.create_task(consumer(agen)) From 43076669cfe462605457dad3a2b461f8b505c5f1 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Sun, 22 Sep 2024 14:56:09 -0400 Subject: [PATCH 11/12] more changes --- hatchet_sdk/v2/callable.py | 79 ++++---- hatchet_sdk/v2/hatchet.py | 141 +-------------- hatchet_sdk/v2/runtime/connection.py | 6 +- hatchet_sdk/v2/runtime/future.py | 181 +++++++++++++++++++ hatchet_sdk/v2/runtime/listeners.py | 83 +++++++-- hatchet_sdk/v2/runtime/messages.py | 5 + hatchet_sdk/v2/runtime/runner.py | 19 +- hatchet_sdk/v2/runtime/runtime.py | 96 ++++++++-- hatchet_sdk/v2/runtime/utils.py | 78 +++++--- hatchet_sdk/v2/runtime/worker.py | 261 ++++++++++++++++++--------- tests/v2/test_broker.py | 61 +++++++ tests/v2/test_worker.py | 34 +++- 12 files changed, 716 insertions(+), 328 deletions(-) create mode 100644 hatchet_sdk/v2/runtime/future.py create mode 100644 tests/v2/test_broker.py diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 7f220e47..d4380da3 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -1,5 +1,6 @@ -from __future__ import annotations +# from __future__ import annotations +import threading import asyncio import inspect import json @@ -34,11 +35,14 @@ import hatchet_sdk.v2.hatchet as v2hatchet import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.utils as utils from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.context import Context from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl from hatchet_sdk.contracts.dispatcher_pb2 import ( AssignedAction, + StepRunResult, SubscribeToWorkflowRunsRequest, WorkflowRunEvent, WorkflowRunEventType, @@ -115,8 +119,8 @@ def __init__( func: Callable[P, T], name: str, namespace: str, - client: v2hatchet.Hatchet, - options: Options, + client: "v2hatchet.Hatchet", + options: "Options", ): # TODO: maybe use __qualname__ name = name.lower() or func.__name__.lower() @@ -236,6 +240,15 @@ def _decode_context( # } # return data + def _decode_output(self, result: WorkflowRunEvent): + steps = list(result.results) + assert len(steps) == 1 + step = steps[0] + if step.error: + raise RuntimeError(step.error) + else: + return _CallableOutput.loads(step.output).output + def _run(self, action: AssignedAction) -> str: # actually invokes the function, and serializing the output raise NotImplementedError @@ -243,46 +256,32 @@ def _run(self, action: AssignedAction) -> str: class HatchetCallable(HatchetCallableBase[P, T]): def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - with context.EnsureContext(self._hatchet.client) as ctx: - trigger = self._to_trigger_proto( - ctx, inputs=_CallableInput(args=args, kwargs=kwargs) - ) - logger.trace("triggering\n{}", MessageToDict(trigger)) - client = self._hatchet.client - ref: TriggerWorkflowResponse = client.admin.client.TriggerWorkflow( - trigger, metadata=self._hatchet.client._grpc_metadata() - ) - logger.trace("runid: {}", ref) - # TODO: look into timeouts for Future.result() - return self._hatchet.client.executor.submit( - self._result, ref.workflow_run_id - ).result() - - def _result(self, run_id: str) -> Optional[WorkflowRunEvent]: - # NOTE: SubscribeToWorkflowRuns is a stream-stream RPC. - request = SubscribeToWorkflowRunsRequest(workflowRunId=run_id) - stream: Iterator[WorkflowRunEvent] = ( - self._hatchet.client.dispatcher.client.SubscribeToWorkflowRuns( - iter((request,)), - metadata=self._hatchet.client._grpc_metadata(), - ) + ctx = context.ensure_background_context() + trigger = self._to_trigger_proto( + ctx, inputs=_CallableInput(args=args, kwargs=kwargs) ) - logger.debug(list(stream)) - for event in stream: - logger.trace("workflow run events:\n{}", MessageToDict(event)) - if event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED: - return event.results - else: - logger.warning( - "unexpected workflow run events:\n{}", MessageToDict(event) - ) - break + logger.trace( + "triggering on {}: {}", threading.get_ident(), MessageToDict(trigger) + ) + client = self._hatchet.client + ref: TriggerWorkflowResponse = client.admin.client.TriggerWorkflow( + trigger, metadata=self._hatchet.client._grpc_metadata() + ) + logger.trace("runid: {}", ref) + # TODO: look into timeouts for Future.result() + + sub = SubscribeToWorkflowRunsRequest(workflowRunId=ref.workflow_run_id) + wfre_future = self._hatchet.client._runtime.wfr_futures.submit(sub) + + return utils.MapFuture( + self._decode_output, wfre_future, self._hatchet.client.executor + ).result() def _run(self, action: AssignedAction) -> str: assert action.actionId == self._hatchet.action logger.trace("invoking:\n{}", MessageToDict(action)) - with context.EnsureContext(self._hatchet.client) as ctx: - assert ctx.current is None + ctx = context.ensure_background_context(client=self._hatchet.client) + assert ctx.current is None parent: Optional["context.BackgroundContext"] = self._decode_context(action) with context.WithParentContext(parent) as ctx: @@ -377,8 +376,8 @@ class CallableMetadata: action: str sourceloc: str # source location of the callable - options: Options - client: v2hatchet.Hatchet + options: "Options" + client: "v2hatchet.Hatchet" def _debug(self): return { diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 62bbca5d..846c7602 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -3,17 +3,18 @@ import inspect import multiprocessing as mp from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from typing import Callable, Dict, List, Optional, ParamSpec, Tuple, TypeVar import hatchet_sdk.hatchet as v1 import hatchet_sdk.v2.callable as callable import hatchet_sdk.v2.runtime.config as config +import hatchet_sdk.v2.runtime.context as context import hatchet_sdk.v2.runtime.logging as logging import hatchet_sdk.v2.runtime.registry as registry import hatchet_sdk.v2.runtime.runner as runner import hatchet_sdk.v2.runtime.runtime as runtime import hatchet_sdk.v2.runtime.worker as worker -import hatchet_sdk.v2.runtime.context as context # import hatchet_sdk.runtime.registry as hatchet_registry # import hatchet_sdk.v2.callable as v2_callable @@ -43,10 +44,9 @@ def __init__( executor: ThreadPoolExecutor = ThreadPoolExecutor(), ): # ensure a event loop is created before gRPC - try: + + with suppress(RuntimeError): asyncio.get_event_loop() - finally: - pass self.registry = registry.ActionRegistry() self.v1: v1.Hatchet = v1.Hatchet.from_environment( @@ -55,6 +55,8 @@ def __init__( ) self.executor = executor + self._runtime: Optional["runtime.Runtime"] = None + context.ensure_background_context(client=self) @property @@ -107,134 +109,11 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": return inner + # TODO: make it 1 worker : 1 client, which means moving the options to the initializer, and cache the result. def worker(self, options: "worker.WorkerOptions") -> "runtime.Runtime": - return runtime.Runtime(self, options) + if self._runtime is None: + self._runtime = runtime.Runtime(self, options) + return self._runtime def _grpc_metadata(self) -> List[Tuple]: return [("authorization", f"bearer {self.config.token}")] - - # def durable(s - # name: str = "", - # auto_register: bool = True, - # on_events: list | None = None, - # on_crons: list | None = None, - # version: str = "", - # timeout: str = "60m", - # schedule_timeout: str = "5m", - # sticky: StickyStrategy = None, - # retries: int = 0, - # rate_limits: List[RateLimit] | None = None, - # desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, - # concurrency: v2.concurrency.ConcurrencyFunction | None = None, - # on_failure: v2.callable.HatchetCallable | None = None, - # default_priority: int | None = None, - # ): - # def inner(func: v2.callable.HatchetCallable) -> v2.callable.HatchetCallable: - # func.durable = True - - # f = function( - # name=name, - # auto_register=auto_register, - # on_events=on_events, - # on_crons=on_crons, - # version=version, - # timeout=timeout, - # schedule_timeout=schedule_timeout, - # sticky=sticky, - # retries=retries, - # rate_limits=rate_limits, - # desired_worker_labels=desired_worker_labels, - # concurrency=concurrency, - # on_failure=on_failure, - # default_priority=default_priority, - # ) - - # resp = f(func) - - # resp.durable = True - - # return resp - - # return inner - - # def concurrency( - # name: str = "concurrency", - # max_runs: int = 1, - # limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - # ): - # def inner(func: Callable[[Context], str]) -> v2.concurrency.ConcurrencyFunction: - # return v2.concurrency.ConcurrencyFunction(func, name, max_runs, limit_strategy) - - # return inner - - # class OldHatchet(v1.Hatchet): - # # dag = staticmethod(v1.workflow) - # # concurrency = staticmethod(concurrency) - - # _registry: hatchet_registry.ActionRegistry = hatchet_registry.ActionRegistry() - - -# # def durable( -# # self, -# # name: str = "", -# # auto_register: bool = True, -# # on_events: list | None = None, -# # on_crons: list | None = None, -# # version: str = "", -# # timeout: str = "60m", -# # schedule_timeout: str = "5m", -# # sticky: StickyStrategy = None, -# # retries: int = 0, -# # rate_limits: List[RateLimit] | None = None, -# # desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, -# # concurrency: v2.concurrency.ConcurrencyFunction | None = None, -# # on_failure: Optional["HatchetCallable"] = None, -# # default_priority: int | None = None, -# # ) -> Callable[[v2.callable.HatchetCallable], v2.callable.HatchetCallable]: -# # resp = durable( -# # name=name, -# # auto_register=auto_register, -# # on_events=on_events, -# # on_crons=on_crons, -# # version=version, -# # timeout=timeout, -# # schedule_timeout=schedule_timeout, -# # sticky=sticky, -# # retries=retries, -# # rate_limits=rate_limits, -# # desired_worker_labels=desired_worker_labels, -# # concurrency=concurrency, -# # on_failure=on_failure, -# # default_priority=default_priority, -# # ) - -# # def wrapper(func: Callable[[Context], T]) -> v2.callable.HatchetCallable[T]: -# # wrapped_resp = resp(func) - -# # if wrapped_resp.function_auto_register: -# # self.functions.append(wrapped_resp) - -# # wrapped_resp.with_namespace(self._client.config.namespace) - -# # return wrapped_resp - -# # return wrapper - -# def worker( -# self, -# name: str, -# max_runs: Optional[int] = None, -# labels: Dict[str, str | int] = {}, -# ): -# worker = Worker( -# name=name, -# max_runs=max_runs, -# labels=labels, -# config=self._client.config, -# debug=self._client.debug, -# ) - -# for func in self._registry.registry.values(): -# register_on_worker(func, worker) - -# return worker diff --git a/hatchet_sdk/v2/runtime/connection.py b/hatchet_sdk/v2/runtime/connection.py index cfd313b4..eca20b1c 100644 --- a/hatchet_sdk/v2/runtime/connection.py +++ b/hatchet_sdk/v2/runtime/connection.py @@ -1,12 +1,12 @@ -import hatchet_sdk.connection as v1 import contextvars as cv +from typing import Optional + import grpc import grpc.aio -from typing import Optional +import hatchet_sdk.connection as v1 import hatchet_sdk.v2.runtime.context as context - _aio_channel_cv: cv.ContextVar[Optional[grpc.aio.Channel]] = cv.ContextVar( "hatchet_background_aio_channel", default=None ) diff --git a/hatchet_sdk/v2/runtime/future.py b/hatchet_sdk/v2/runtime/future.py new file mode 100644 index 00000000..36f0a6d7 --- /dev/null +++ b/hatchet_sdk/v2/runtime/future.py @@ -0,0 +1,181 @@ +import asyncio +import multiprocessing as mp +import multiprocessing.queues as mpq +import queue +import threading +import time +from collections.abc import Callable, MutableSet +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress +from typing import Dict, Generic, Optional, TypeAlias, TypeVar + +from google.protobuf.json_format import MessageToDict +from loguru import logger + +import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.utils as utils +from hatchet_sdk.contracts.dispatcher_pb2 import ( + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, +) + +# TODO: use better generics for Python >= 3.12 +T = TypeVar("T") +RespT = TypeVar("RespT") +ReqT = TypeVar("ReqT") + + +_ThreadSafeQueue: TypeAlias = queue.Queue[T] | mpq.Queue[T] + + +class RequestResponseBroker(Generic[ReqT, RespT]): + def __init__( + self, + *, + inbound: _ThreadSafeQueue[RespT], + outbound: _ThreadSafeQueue[ReqT], + req_key: Callable[[ReqT], str], + resp_key: Callable[[RespT], str], + executor: ThreadPoolExecutor, + ): + """A broker that can send a request and returns a future for the response. + + The broker loop runs forever and quits upon asyncio.CancelledError. + + Args: + outbound: a thread-safe blocking queue to which the request should be forwarded to + inbound: a thread-safe blocking queue from which the responses will come + req_key: a function that computes the key of the request, which is used to match the responses + resp_key: a function that computes the key of the response, which is used to match the requests + executor: a thread pool for running any blocking code + """ + self._inbound = inbound + self._outbound = outbound + self._req_key = req_key + self._resp_key = resp_key + + # NOTE: this is used for running the polling tasks for results. + # The tasks we submit to the (any) executor should NOT wait indefinitely. + # We must provide it with a way to self-cancelling. + self._executor = executor + + # Used to signal to the tasks on the executor to quit + self._shutdown = False + + self._lock = threading.Lock() # lock for self._keys and self._futures + self._keys: MutableSet[str] = set() + self._futures: Dict[str, Optional[RespT]] = dict() + + self._akeys: MutableSet[str] = set() + self._afutures: Dict[str, asyncio.Future[RespT]] = dict() + + self.loop_task: Optional[asyncio.Task] = None + + def start(self): + logger.trace("starting broker on {}", threading.get_native_id()) + self.loop_task = asyncio.create_task(self.loop()) + return + + async def shutdown(self): + self.loop_task.cancel() + with suppress(asyncio.CancelledError): + await self.loop_task + + async def loop(self): + try: + async for resp in utils.QueueAgen(self._inbound): + logger.trace("broker got: {}", resp) + key = self._resp_key(resp) + + def update(): + with self._lock: + if key in self._futures: + self._futures[key] = resp + return True + return False + + if await asyncio.to_thread(update): + continue + + if key in self._afutures: + self._afutures[key].set_result(resp) + self._akeys.remove(key) + del self._afutures[key] + continue + + raise KeyError(f"key not found: {key}") + finally: + self._shutdown = True + + async def asubmit(self, req: ReqT) -> asyncio.Future[RespT]: + key = self._req_key(req) + assert key not in self._keys + + f = None + if key not in self._akeys: + self._afutures[key] = asyncio.Future() + f = self._afutures[key] + self._akeys.add(key) + await asyncio.to_thread(self._outbound.put, key) + + return f + + def submit(self, req: ReqT) -> Future[RespT]: + key = self._req_key(req) + + assert key not in self._akeys + + def poll(): + with self._lock: + if key not in self._keys: + self._futures[key] = None + self._keys.add(key) + self._outbound.put(req) + + resp = None + while resp is None and not self._shutdown: + while self._futures.get(key, None) is None: + time.sleep(1) + with self._lock: + resp = self._futures.get(key, None) + if resp is not None: + self._keys.remove(key) + del self._futures[key] + + return resp + + return self._executor.submit(poll) + + +class WorkflowRunFutures: + def __init__( + self, + broker: RequestResponseBroker["messages.Message", "messages.Message"], + ): + self._broker = broker + self._thread = None + + def start(self): + logger.trace("starting workflow run wrapper on {}", threading.get_native_id()) + self._thread = threading.Thread(target=asyncio.run, args=[self._broker.start()], name="workflow run event broker") + self._thread.start() + + async def shutdown(self): + del self._thread + + def submit(self, req: SubscribeToWorkflowRunsRequest) -> Future[WorkflowRunEvent]: + logger.trace("requesting workflow run result: {}", req) + f = self._broker.submit( + messages.Message(_subscribe_to_workflow_run=MessageToDict(req)) + ) + logger.trace("submitted") + return self._broker._executor.submit(lambda: f.result().workflow_run_event) + + async def asubmit( + self, req: SubscribeToWorkflowRunsRequest + ) -> asyncio.Future[WorkflowRunEvent]: + logger.trace("requesting workflow run result: {}", req) + f = await self._broker.asubmit(req) + event: asyncio.Future[WorkflowRunEvent] = asyncio.Future() + f.add_done_callback(lambda f: event.set_result(f.result().workflow_run_event)) + return event diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py index 43f71bdb..7c3353b6 100644 --- a/hatchet_sdk/v2/runtime/listeners.py +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -1,16 +1,15 @@ import asyncio -from asyncio.taskgroups import TaskGroup import multiprocessing as mp import os import threading import time -from collections.abc import AsyncGenerator, Generator, AsyncIterator, Callable +from asyncio.taskgroups import TaskGroup +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from dataclasses import dataclass, field from enum import Enum -from typing import Dict, Generic, List, Optional, Set, TypeVar, Literal, Any -from contextlib import suppress - +from typing import Any, Dict, Generic, List, Literal, Optional, Set, TypeVar import grpc from google.protobuf import timestamp_pb2 @@ -22,31 +21,29 @@ import hatchet_sdk.v2.runtime.connection as connection import hatchet_sdk.v2.runtime.context as context import hatchet_sdk.v2.runtime.messages as messages -import hatchet_sdk.v2.runtime.worker as worker import hatchet_sdk.v2.runtime.utils as utils +import hatchet_sdk.v2.runtime.worker as worker from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, HeartbeatRequest, StepActionEvent, + StepRunResult, SubscribeToWorkflowRunsRequest, WorkerLabels, WorkerListenRequest, WorkerRegisterRequest, WorkerRegisterResponse, WorkerUnsubscribeRequest, - WorkflowRunEventType, - StepRunResult, WorkflowRunEvent, + WorkflowRunEventType, ) from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub - T = TypeVar("T") class WorkflowRunEventListener: - @dataclass class Sub: id: str @@ -73,6 +70,21 @@ def __init__(self): # must be created inside the loop self._q_request: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + self._task = None + + def start(self): + self._task = asyncio.create_task( + self.loop(), name="workflow run event listener loop" + ) + logger.debug("started workflow run event listener") + + async def shutdown(self): + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + async def loop(self): try: agen = utils.ForeverAgen(self._events, exceptions=(grpc.aio.AioRpcError,)) @@ -85,10 +97,9 @@ async def loop(self): self._by_run_id[event.workflowRunId].future.set_result(event) self._unsubscribe(event.workflowRunId) finally: - logger.debug("bye") + logger.debug("bye: workflow run event listner shuts down") async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: - # keep trying until asyncio.CancelledError is raised into this coroutine # TODO: handle retry, backoff, etc. stub = DispatcherStub(channel=connection.ensure_background_achannel()) @@ -140,10 +151,25 @@ def _unsubscribe(self, run_id: str): class AssignedActionListner: def __init__(self, worker: "worker.Worker", interrupt: asyncio.Queue[T]): - logger.debug("init action listener") + logger.debug("init assigned action listener") self._worker = worker self._interrupt = interrupt + self._task = None + + def start( + self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] + ): + self._task = asyncio.create_task(self.loop(async_on)) + logger.debug("started assigned action listener") + + async def shutdown(self): + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + async def _action_stream(self) -> AsyncGenerator[AssignedAction]: stub = DispatcherStub(connection.ensure_background_achannel()) proto = WorkerListenRequest(workerId=self._worker.id) @@ -172,7 +198,13 @@ def agen_factory(): else: yield action finally: - logger.debug("bye") + logger.debug("bye: assigned action listener") + + async def loop( + self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] + ): + async for event in self.listen(): + await async_on(event) class StepEventListener: @@ -181,9 +213,21 @@ def __init__(self, inbound: asyncio.Queue["messages.Message"]): self.inbound = inbound self.stub = DispatcherStub(connection.ensure_background_channel()) + self.task = None + + def start(self): + self.task = asyncio.create_task(self.listen()) + + async def shutdown(self): + if self.task: + self.task.cancel() + with suppress(asyncio.CancelledError): + await self.task + self.task = None + async def _message_stream(self) -> AsyncGenerator["messages.Message"]: while True: - msg: "messages.Message" = await asyncio.to_thread(self.inbound.get) + msg: "messages.Message" = await self.inbound.get() assert msg.kind in [messages.MessageKind.STEP_EVENT] logger.trace("event:\n{}", msg) yield msg @@ -193,13 +237,16 @@ async def listen(self): async for msg in self._message_stream(): match msg.kind: case messages.MessageKind.STEP_EVENT: - await self.on_step_event(msg.step_event) + await self._on_step_event(msg.step_event) case _: raise NotImplementedError(msg.kind) + except Exception as e: + logger.exception(e) + raise finally: - logger.debug("bye") + logger.debug("bye: step event listener") - async def on_step_event(self, e: StepActionEvent): + async def _on_step_event(self, e: StepActionEvent): # TODO: need retry logger.trace("emit step action:\n{}", MessageToDict(e)) resp = await asyncio.to_thread( diff --git a/hatchet_sdk/v2/runtime/messages.py b/hatchet_sdk/v2/runtime/messages.py index d4e01ac9..5aa4b8e8 100644 --- a/hatchet_sdk/v2/runtime/messages.py +++ b/hatchet_sdk/v2/runtime/messages.py @@ -29,6 +29,7 @@ class MessageKind(Enum): STEP_EVENT = 2 WORKFLOW_RUN_EVENT = 3 SUBSCRIBE_TO_WORKFLOW_RUN = 4 + WORKER_ID = 5 @dataclass @@ -40,6 +41,8 @@ class Message: _workflow_run_event: Optional[Dict] = None _subscribe_to_workflow_run: Optional[Dict] = None + worker_id: Optional[str] = None + @property def kind(self) -> MessageKind: if self._action is not None: @@ -50,6 +53,8 @@ def kind(self) -> MessageKind: return MessageKind.WORKFLOW_RUN_EVENT if self._subscribe_to_workflow_run is not None: return MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN + if self.worker_id: + return MessageKind.WORKER_ID return MessageKind.UNKNOWN @property diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py index d0708eab..9202cdf2 100644 --- a/hatchet_sdk/v2/runtime/runner.py +++ b/hatchet_sdk/v2/runtime/runner.py @@ -1,9 +1,12 @@ +import threading import asyncio import json import multiprocessing as mp +import multiprocessing.queues as mpq +import queue import time import traceback -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, TypeAlias, TypeVar from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp @@ -38,6 +41,7 @@ async def _invoke( logger.trace("invoking: {}", repr(fn)) try: if isinstance(fn, callable.HatchetCallable): + logger.trace("invoking {} on a separate thread", fn._hatchet.name) return await asyncio.to_thread(fn._run, action), None else: return await fn._run(action), None @@ -48,12 +52,17 @@ async def _invoke( return None, e -class BaseRunnerLoop: +# TODO: Use better generics for Python >= 3.12 +T = TypeVar("T") +_ThreadSafeQueue: TypeAlias = queue.Queue[T] | mpq.Queue[T] + + +class RunnerLoop: def __init__( self, client: "hatchet.Hatchet", - inbound: mp.Queue, # inbound queue, not owned - outbound: mp.Queue, # outbound queue, not owned + inbound: _ThreadSafeQueue["messages.Message"], # inbound queue, not owned + outbound: _ThreadSafeQueue["messages.Message"], # outbound queue, not owned ): logger.trace("init runner loop") self.client = client @@ -71,7 +80,7 @@ def __init__( self.tasks: Dict[str, asyncio.Task] = dict() def start(self): - logger.debug("runner loop started") + logger.debug("starting runner loop on {}", threading.get_ident()) self.looptask = asyncio.create_task(self.loop(), name="runner loop") async def shutdown(self): diff --git a/hatchet_sdk/v2/runtime/runtime.py b/hatchet_sdk/v2/runtime/runtime.py index 73e9bf5d..bf6d0739 100644 --- a/hatchet_sdk/v2/runtime/runtime.py +++ b/hatchet_sdk/v2/runtime/runtime.py @@ -1,40 +1,106 @@ import asyncio import multiprocessing as mp +import os +import queue +import sys +import threading +from concurrent.futures import ProcessPoolExecutor +from contextlib import suppress from loguru import logger +import hatchet_sdk.loader as loader import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.future as future +import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.runner as runner +import hatchet_sdk.v2.runtime.utils as utils import hatchet_sdk.v2.runtime.worker as worker +from hatchet_sdk.contracts.dispatcher_pb2 import ( + SubscribeToWorkflowRunsRequest, + WorkflowRunEvent, +) class Runtime: def __init__(self, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): logger.trace("init runtime") - self.events = mp.Queue() - self.actions = mp.Queue() - self.worker = worker.Worker( - client=client, inbound=self.events, outbound=self.actions, options=options + + self.client = client + self.process_pool = ProcessPoolExecutor() + + self.to_worker: mp.Queue["messages.Message"] = mp.Queue() + self.from_worker: mp.Queue["messages.Message"] = mp.Queue() + + self.to_runner = queue.Queue() + self.to_wfr_futures = queue.Queue() + + self.worker = worker.WorkerProcess( + config=client.config, + inbound=self.to_worker, + outbound=self.from_worker, + options=options, ) - self.runner = runner.BaseRunnerLoop( - client=client, inbound=self.actions, outbound=self.events + self.worker_id = None + + self.runner = runner.RunnerLoop( + client=client, inbound=self.to_runner, outbound=self.to_worker + ) + self.wfr_futures = future.WorkflowRunFutures( + # pool=client.executor, + broker=future.RequestResponseBroker( + inbound=self.to_wfr_futures, + outbound=self.to_worker, + req_key=lambda msg: msg.subscribe_to_workflow_run.workflowRunId, + resp_key=lambda msg: msg.workflow_run_event.workflowRunId, + executor=client.executor, + ), ) + self.loop_task = None + + async def loop(self): + async for msg in utils.QueueAgen(self.from_worker): + match msg.kind: + case messages.MessageKind.ACTION: + await asyncio.to_thread(self.to_runner.put, msg) + case messages.MessageKind.WORKFLOW_RUN_EVENT: + await asyncio.to_thread(self.to_wfr_futures.put, msg) + case messages.MessageKind.WORKER_ID: + self.runner.worker_id = msg.worker_id + self.worker_id = msg.worker_id + case _: + raise NotImplementedError + async def start(self): - logger.trace("starting runtime") + logger.trace("starting runtime on {}", threading.get_ident()) self.runner.start() - await self.worker.start() - self.runner.worker_id = self.worker.id + self.wfr_futures.start() + self.worker.start() + + self.client.executor.submit(asyncio.run, self.loop()) + + while self.worker_id is None: + await asyncio.sleep(1) logger.debug("runtime started") - return self.worker.id + return self.worker_id async def shutdown(self): logger.trace("shutting down runtime") - await self.worker.shutdown() - self.actions.close() - self.actions.join_thread() + + self.worker.shutdown() + # await self.worker.shutdown() + self.from_worker.close() + self.from_worker.join_thread() await self.runner.shutdown() - self.events.close() - self.events.join_thread() + self.to_worker.close() + self.to_worker.join_thread() + + await self.wfr_futures.shutdown() + + self.loop_task.cancel() + with suppress(asyncio.CancelledError): + await self.loop_task + logger.debug("bye") diff --git a/hatchet_sdk/v2/runtime/utils.py b/hatchet_sdk/v2/runtime/utils.py index 4e3958cf..e5d7e0a3 100644 --- a/hatchet_sdk/v2/runtime/utils.py +++ b/hatchet_sdk/v2/runtime/utils.py @@ -1,16 +1,18 @@ -from collections.abc import AsyncGenerator, Callable import asyncio -import tenacity -import grpc import multiprocessing as mp import multiprocessing.queues as mpq -from typing import TypeVar, Tuple - - +import queue +from collections.abc import AsyncGenerator, Callable +from concurrent.futures import Future, ThreadPoolExecutor from contextlib import suppress +from typing import Tuple, TypeVar + +import grpc +import tenacity T = TypeVar("T") I = TypeVar("I") +R = TypeVar("R") async def InterruptableAgen( @@ -18,7 +20,6 @@ async def InterruptableAgen( interrupt: asyncio.Queue[I], timeout: float, ) -> AsyncGenerator[T | I]: - queue: asyncio.Queue[T | StopAsyncIteration] = asyncio.Queue() async def producer(): @@ -26,25 +27,25 @@ async def producer(): await queue.put(item) await queue.put(StopAsyncIteration()) - producer_task = asyncio.create_task(producer()) - - while True: - with suppress(asyncio.TimeoutError): - item = await asyncio.wait_for(queue.get(), timeout=timeout) - # it is not timeout if we reach this line - if isinstance(item, StopAsyncIteration): + try: + producer_task = asyncio.create_task(producer()) + while True: + with suppress(asyncio.TimeoutError): + item = await asyncio.wait_for(queue.get(), timeout=timeout) + # it is not timeout if we reach this line + if isinstance(item, StopAsyncIteration): + break + else: + yield item + + with suppress(asyncio.QueueEmpty): + v = interrupt.get_nowait() + # we are interrupted if we reach this line + yield v break - else: - yield item - with suppress(asyncio.QueueEmpty): - v = interrupt.get_nowait() - # we are interrupted if we reach this line - yield v - break - - producer_task.cancel() - with suppress(asyncio.CancelledError): + finally: + producer_task.cancel() await producer_task @@ -73,7 +74,28 @@ async def ForeverAgen( raise -async def QueueAgen(inbound: asyncio.Queue[T] | mpq.Queue[T]) -> AsyncGenerator[T]: - while True: - item = await asyncio.to_thread(inbound.get) - yield item +async def QueueAgen( + inbound: queue.Queue[T] | asyncio.Queue[T] | mpq.Queue[T], +) -> AsyncGenerator[T]: + if isinstance(inbound, asyncio.Queue): + while True: + yield await inbound.get() + inbound.task_done() + elif isinstance(inbound, queue.Queue): + while True: + yield await asyncio.to_thread(inbound.get) + inbound.task_done() + elif isinstance(inbound, mpq.Queue): + while True: + yield await asyncio.to_thread(inbound.get) + else: + raise TypeError(f"unsupported queue type: {type(inbound)}") + + +def MapFuture( + fn: Callable[[T], R], fut: Future[T], pool: ThreadPoolExecutor +) -> Future[R]: + def task(fn: Callable[[T], R], fut: Future[T]): + return fn(fut.result()) + + return pool.submit(task, fn, fut) diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index 40295944..28d7278a 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -1,14 +1,18 @@ +import sys import asyncio import multiprocessing as mp +import multiprocessing.queues as mpq +import multiprocessing.synchronize as mps import os import threading import time from collections.abc import AsyncGenerator from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from dataclasses import dataclass, field from enum import Enum from typing import Dict, Generic, List, Optional, Set, TypeVar - +import logging import grpc from google.protobuf import timestamp_pb2 from google.protobuf.json_format import MessageToDict, MessageToJson @@ -16,24 +20,23 @@ import hatchet_sdk.contracts.dispatcher_pb2 import hatchet_sdk.v2.hatchet as hatchet +import hatchet_sdk.v2.runtime.config as config import hatchet_sdk.v2.runtime.connection as connection -import hatchet_sdk.v2.runtime.messages as messages -import hatchet_sdk.v2.runtime.listeners as listeners import hatchet_sdk.v2.runtime.context as context +import hatchet_sdk.v2.runtime.listeners as listeners +import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.utils as utils - - from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, HeartbeatRequest, StepActionEvent, + StepRunResult, SubscribeToWorkflowRunsRequest, WorkerLabels, WorkerListenRequest, WorkerRegisterRequest, WorkerRegisterResponse, - StepRunResult, WorkerUnsubscribeRequest, WorkflowRunEvent, ) @@ -60,29 +63,31 @@ def labels_proto(self) -> Dict[str, WorkerLabels]: return ret -class WorkerStatus(Enum): - UNKNOWN = 1 - REGISTERED = 2 - # STARTING = 2 - HEALTHY = 3 - UNHEALTHY = 4 - - -class _HeartBeater: +class HeartBeater: def __init__(self, worker: "Worker"): logger.debug("init heartbeater") self.worker = worker self.last_heartbeat: int = -1 # unix epoch in seconds - self.stub = DispatcherStub( - connection.new_conn(self.worker.client.config, aio=False) - ) + self.stub = DispatcherStub(connection.ensure_background_channel()) self.missed = 0 self.error = 0 + self.task = None + + async def start(self): + self.task = asyncio.create_task(self.heartbeat()) + while self.last_heartbeat < 0: + await asyncio.sleep(1) + + async def shutdown(self): + if self.task: + self.task.cancel() + with suppress(asyncio.CancelledError): + await self.task + self.task = None + async def heartbeat(self): try: - # It will exit the loop when a asyncio.CancelledError is raised - # by calling task.cancel() from outside. while True: now = int(time.time()) proto = HeartbeatRequest( @@ -91,21 +96,26 @@ async def heartbeat(self): ) try: _ = self.stub.Heartbeat( - proto, timeout=5, metadata=self.worker._grpc_metadata() + proto, + timeout=5, + metadata=context.ensure_background_context().client._grpc_metadata(), ) - logger.trace("heartbeat") - except grpc.RpcErrors: + logger.debug("heartbeat") + except grpc.RpcError as e: # TODO + logger.exception(e) self.error += 1 if self.last_heartbeat < 0: self.last_heartbeat = now - self.status = WorkerStatus.HEALTHY else: diff = proto.heartbeatAt.seconds - self.last_heartbeat if diff > self.worker.options.heartbeat: self.missed += 1 await asyncio.sleep(self.worker.options.heartbeat) + except Exception as e: + logger.exception(e) + raise finally: logger.debug("bye") @@ -117,89 +127,117 @@ async def heartbeat(self): class Worker: def __init__( self, - client: "hatchet.Hatchet", - inbound: mp.Queue["messages.Message"], - outbound: mp.Queue["messages.Message"], + *, options: WorkerOptions, + client: "hatchet.Hatchet", + inbound: mpq.Queue["messages.Message"], + outbound: mpq.Queue["messages.Message"], ): logger.debug("init worker") context.ensure_background_context(client=client) - self.options = options - self.client = client self.id: Optional[str] = None + + self.options = options self.inbound = inbound self.outbound = outbound - self.status = WorkerStatus.UNKNOWN - self._heartbeater = _HeartBeater(self) - self._heartbeater_task: Optional[asyncio.Task] = None + self._heartbeater = HeartBeater(self) self._action_listener_interrupt: asyncio.Queue[StopAsyncIteration] = ( asyncio.Queue() ) self._action_listener = listeners.AssignedActionListner( - worker=self, interrupt=self._action_listener_interrupt + worker=self, + interrupt=self._action_listener_interrupt, ) - self._event_listner_q: asyncio.Queue["messages.Message"] = asyncio.Queue() - self._event_listner = listeners.StepEventListener(self._event_listner_q) - self._event_listner_task: Optional[asyncio.Task] = None + self._to_event_listner: asyncio.Queue["messages.Message"] = asyncio.Queue() + self._event_listner = listeners.StepEventListener(self._to_event_listner) self._workflow_run_event_listener = listeners.WorkflowRunEventListener() - self._workflow_run_event_listener_task: Optional[asyncio.Task] = None + + self.main_loop_task = None def _register(self) -> str: req = self._to_register_proto() logger.trace("registering worker:\n{}", req) - resp: WorkerRegisterResponse = self.client.dispatcher.client.Register( + resp: ( + WorkerRegisterResponse + ) = context.ensure_background_context().client.dispatcher.client.Register( req, timeout=30, metadata=context.ensure_background_context().client._grpc_metadata(), ) logger.debug("worker registered:\n{}", MessageToDict(resp)) - self.id = resp.workerId - self.status = WorkerStatus.REGISTERED return resp.workerId - async def start(self): + async def start(self) -> str: logger.trace("starting worker") - self._register() - self._heartbeat_task = asyncio.create_task( - self._heartbeater.heartbeat(), name="heartbeater" - ) - self._event_listner_task = asyncio.create_task( - self._event_listner.listen(), name="event_listener" + self.id = self._register() + self._event_listner.start() + self._workflow_run_event_listener.start() + self._action_listener.start(async_on=self.on_assigned_action) + + self.main_loop_task = asyncio.create_task(self.loop()) + + await self._heartbeater.start() + await asyncio.to_thread(self.outbound.put, messages.Message(worker_id=self.id)) + logger.debug("worker started: {}", self.id) + return self.id + + async def shutdown(self): + logger.trace("shutting down worker {}", self.id) + + if self.main_loop_task: + self.main_loop_task.cancel() + with suppress(asyncio.CancelledError): + await self.main_loop_task + self.main_loop_task = None + + tg: asyncio.Future = asyncio.gather( + self._heartbeater.shutdown(), + self._event_listner.shutdown(), + self._action_listener.shutdown(), + self._workflow_run_event_listener.shutdown(), ) - while True: - if self._heartbeater.last_heartbeat > 0: - logger.debug("worker started: {}", self.id) - return - await asyncio.sleep(0.1) - - async def server_message_loop(self): - async for action in self._action_listener.listen(): - if isinstance(action, StopAsyncIteration): - # interrupted, ignore - pass - elif isinstance(action, grpc.aio.AioRpcError): - # errored out, ignored - pass - else: - assert isinstance(action, AssignedAction) - msg = messages.Message(_action=MessageToDict(action)) - await asyncio.to_thread(self.outbound.put, msg) - - async def client_message_loop(self): - async for msg in utils.QueueAgen(self.inbound): - match msg.kind: - case messages.MessageKind.STEP_EVENT: - await asyncio.to_thread(self._event_listner_q.put, msg) - case messages.MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN: - await self.on_workflow_run_subscription(msg) + await tg + logger.debug("bye") + + async def loop(self): + try: + async for msg in utils.QueueAgen(self.inbound): + logger.trace("worker received msg: {}", msg) + match msg.kind: + case messages.MessageKind.STEP_EVENT: + await self._to_event_listner.put(msg) + case messages.MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN: + await self.on_workflow_run_subscription(msg) + case _: + raise NotImplementedError + except Exception as e: + logger.exception(e) + raise + finally: + logger.trace("bye: worker") + + async def on_assigned_action( + self, action: StopAsyncIteration | grpc.aio.AioRpcError | AssignedAction + ): + if isinstance(action, StopAsyncIteration): + # interrupted, ignore + pass + elif isinstance(action, grpc.aio.AioRpcError): + # errored out, ignored + pass + else: + assert isinstance(action, AssignedAction) + msg = messages.Message(_action=MessageToDict(action)) + await asyncio.to_thread(self.outbound.put, msg) async def on_workflow_run_subscription(self, msg: "messages.Message"): def callback(f: asyncio.Future[WorkflowRunEvent]): + logger.trace("workflow run event future resolved") self.outbound.put( messages.Message(_workflow_run_event=MessageToDict(f.result())) ) @@ -209,17 +247,6 @@ def callback(f: asyncio.Future[WorkflowRunEvent]): ) sub.future.add_done_callback(callback) - async def shutdown(self): - logger.trace("shutting down worker {}", self.id) - tg: asyncio.Future = asyncio.gather( - self._heartbeat_task, self._action_listener_task, self._event_listner_task - ) - tg.cancel() - try: - await tg - except asyncio.CancelledError: - logger.debug("bye") - def _to_register_proto(self) -> WorkerRegisterRequest: options = self.options proto = WorkerRegisterRequest( @@ -230,3 +257,69 @@ def _to_register_proto(self) -> WorkerRegisterRequest: labels=options.labels_proto, ) return proto + + +def _worker_process( + config: "config.ClientConfig", + options: WorkerOptions, + inbound: mpq.Queue["messages.Message"], + outbound: mpq.Queue["messages.Message"], + shutdown: mps.Event, +): + client = hatchet.Hatchet(config=config, debug=True) + logger.remove() + logger.add(sys.stdout, level="TRACE") + + async def loop(): + worker = Worker( + client=client, + inbound=inbound, + outbound=outbound, + options=options, + ) + try: + id = await worker.start() + while not await asyncio.to_thread(shutdown.wait, 1): + pass + asyncio.current_task().cancel() + except Exception as e: + logger.exception(e) + raise + finally: + with suppress(asyncio.CancelledError): + await worker.shutdown() + logger.trace("worker process shuts down") + + asyncio.run(loop(), debug=True) + logger.trace("here") + + +class WorkerProcess: + def __init__( + self, + *, + config: "config.ClientConfig", + options: WorkerOptions, + inbound: mpq.Queue["messages.Message"], + outbound: mpq.Queue["messages.Message"], + ): + self.to_worker = inbound + self.shutdown_ev = mp.Event() + self.proc = mp.Process( + target=_worker_process, + kwargs={ + "config": config, + "options": options, + "inbound": inbound, + "outbound": outbound, + "shutdown": self.shutdown_ev, + }, + ) + + def start(self): + logger.debug("starting worker process") + self.proc.start() + + def shutdown(self): + self.shutdown_ev.set() + logger.debug("worker process shuts down") diff --git a/tests/v2/test_broker.py b/tests/v2/test_broker.py new file mode 100644 index 00000000..5e61d593 --- /dev/null +++ b/tests/v2/test_broker.py @@ -0,0 +1,61 @@ +import asyncio +import logging +import sys +import queue +import threading + +from concurrent.futures import ThreadPoolExecutor + +# import dotenv +import pytest +from loguru import logger + +# from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk.v2.runtime.broker import QueueToFutureBroker + +logger.remove() +logger.add(sys.stdout, level="TRACE") + +# dotenv.load_dotenv() + +# hatchet = Hatchet(debug=True) + +logging.getLogger("asyncio").setLevel(logging.DEBUG) + + +to_broker = queue.Queue() +to_server = queue.Queue() +exec = ThreadPoolExecutor() +broker = QueueToFutureBroker( + inbound=to_broker, + outbound=to_server, + req_key=lambda x: x, + resp_key=lambda x: x, + executor=exec, +) + + +def echo(p: queue.Queue, q: queue.Queue): + while True: + item = p.get() + logger.trace("echo {}", item) + q.put(item) + + +echo_f = exec.submit(echo, to_server, to_broker) + + +# def test_broker(): +# fut = exec.submit(asyncio.run, broker.loop()) +# f = broker.submit(1) +# print(f.result()) +# fut.cancel() + + +@pytest.mark.asyncio +async def test_broker_async(): + task = asyncio.create_task(broker.loop()) + f = await broker.asubmit(2) + print(await f) + task.cancel() + diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index 8667a8b1..4885bed2 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -1,16 +1,20 @@ +import time import asyncio import logging import sys - +import multiprocessing as mp import dotenv import pytest from loguru import logger from hatchet_sdk.v2.hatchet import Hatchet -from hatchet_sdk.v2.runtime.worker import WorkerOptions +from hatchet_sdk.v2.runtime.worker import WorkerOptions, WorkerProcess +from concurrent.futures import ThreadPoolExecutor logger.remove() -logger.add(sys.stdout, level="TRACE") +logger.add( + sys.stdout, level="TRACE" +) # , format="{level}\t|{module}:{function}:{line}[{process}:{thread}] - {message}") dotenv.load_dotenv() @@ -39,7 +43,29 @@ async def test_worker(): WorkerOptions(name="worker", actions=["default:foo", "default:bar"]) ) await worker.start() - print("result from foo: ", foo()) + print("result from foo: ", await asyncio.to_thread(foo)) await asyncio.sleep(10) await worker.shutdown() return None + + +# def test_worker_process(): +# to_worker = mp.Queue() +# from_worker = mp.Queue() +# p = WorkerProcess( +# config=hatchet.config, +# options=WorkerOptions(name="worker", actions=[]), +# inbound=to_worker, +# outbound=from_worker, +# ) + +# pool = ThreadPoolExecutor() +# id = pool.submit(from_worker.get) +# print(p.start()) +# print(id.result()) +# time.sleep(10) +# print("shutting down") +# p.shutdown() + +# to_worker.close() +# from_worker.close() From f854d1d8e5fb9d91027549f6a8193774547c4488 Mon Sep 17 00:00:00 2001 From: Hanwen Wu Date: Sun, 22 Sep 2024 22:33:46 -0400 Subject: [PATCH 12/12] now it is a fully working @function decorator --- hatchet_sdk/v2/callable.py | 323 +++++++++++---------------- hatchet_sdk/v2/hatchet.py | 53 ++--- hatchet_sdk/v2/runtime/connection.py | 14 +- hatchet_sdk/v2/runtime/context.py | 2 +- hatchet_sdk/v2/runtime/future.py | 91 +++++--- hatchet_sdk/v2/runtime/listeners.py | 164 ++++++++------ hatchet_sdk/v2/runtime/registry.py | 7 +- hatchet_sdk/v2/runtime/runner.py | 150 +++++++------ hatchet_sdk/v2/runtime/runtime.py | 120 ++++++---- hatchet_sdk/v2/runtime/utils.py | 21 +- hatchet_sdk/v2/runtime/worker.py | 151 +++++++------ tests/v2/test_worker.py | 15 +- 12 files changed, 581 insertions(+), 530 deletions(-) diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index d4380da3..2678e6c9 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -1,51 +1,22 @@ -# from __future__ import annotations - -import threading import asyncio import inspect import json -from collections.abc import Awaitable, Callable, Iterator -from concurrent.futures.thread import ThreadPoolExecutor - -# from contextvars import ContextVar, copy_context +from collections.abc import Awaitable, Callable +from concurrent.futures import Future from dataclasses import asdict, dataclass, field - -# from datetime import timedelta -from typing import ( - Any, - Dict, - ForwardRef, - Generic, - Iterable, - List, - Literal, - Optional, - ParamSpec, - TypedDict, - TypeVar, - Union, -) +from typing import Any, Dict, Generic, List, Optional, ParamSpec, Tuple, TypeVar from google.protobuf.json_format import MessageToDict - -# from hatchet_sdk.logger import logger from loguru import logger from pydantic import BaseModel, ConfigDict, Field, computed_field -from pydantic.json_schema import SkipJsonSchema -import hatchet_sdk.v2.hatchet as v2hatchet +import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.context as context -import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.utils as utils -from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.context import Context -from hatchet_sdk.context.context import BaseContext, Context, ContextAioImpl from hatchet_sdk.contracts.dispatcher_pb2 import ( AssignedAction, - StepRunResult, SubscribeToWorkflowRunsRequest, WorkflowRunEvent, - WorkflowRunEventType, ) from hatchet_sdk.contracts.workflows_pb2 import ( CreateStepRateLimit, @@ -56,20 +27,11 @@ StickyStrategy, TriggerWorkflowRequest, TriggerWorkflowResponse, - WorkflowConcurrencyOpts, WorkflowKind, ) from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.v2.runtime import registry -from hatchet_sdk.workflow_run import RunRef - -# from typing import TYPE_CHECKING - -# if TYPE_CHECKING: -# from hatchet_sdk.v2.hatchet import Hatchet - T = TypeVar("T") P = ParamSpec("P") @@ -85,45 +47,56 @@ def _sourceloc(fn) -> str: return "" +# Note: this should be language independent, and useable by Go/Typescript, etc. @dataclass class _CallableInput: - args: List[Any] = field(default_factory=list) + """The input to a Hatchet callable.""" + + args: Tuple = field(default_factory=tuple) kwargs: Dict[str, Any] = field(default_factory=dict) - def dumps(self): + def dumps(self) -> str: return json.dumps(asdict(self)) @staticmethod - def loads(s: str): + def loads(s: str) -> "_CallableInput": # NOTE: AssignedAction.actionPayload looks like the following # '{"input": , "parents": {}, "overrides": {}, "user_data": {}, "triggered_by": "manual"}' return _CallableInput(**(json.loads(s)["input"])) +# Note: this should be language independent, and usable by Go/Typescript, etc. @dataclass class _CallableOutput(Generic[T]): - output: Optional[T] = None + """The output of a Hatchet callable.""" - def dumps(self): + output: T + + def dumps(self) -> str: return json.dumps(asdict(self)) @staticmethod - def loads(s: str): - return _CallableOutput(**json.loads(s)) + def loads(s: str) -> "_CallableOutput[T]": + ret = _CallableOutput(**json.loads(s)) + return ret class HatchetCallableBase(Generic[P, T]): + """Hatchet callable base.""" + def __init__( self, *, func: Callable[P, T], name: str, namespace: str, - client: "v2hatchet.Hatchet", + client: "hatchet.Hatchet", options: "Options", ): # TODO: maybe use __qualname__ name = name.lower() or func.__name__.lower() + + # hide everything under self._hatchet since the user has access to everything in HatchetCallableBase. self._hatchet = CallableMetadata( name=name, namespace=namespace, @@ -172,14 +145,15 @@ def _to_step_proto(self) -> CreateWorkflowStepOpts: inputs="{}", # TODO: not sure that this is, we're defining a step, not running a step parents=[], # this is a single step workflow, always empty retries=options.retries, - rate_limits=options.ratelimits, - # worker_labels=self.function_desired_worker_labels, + # rate_limits=options.ratelimits, # TODO + # worker_labels=self.function_desired_worker_labels, # TODO ) return step def _encode_context( self, ctx: "context.BackgroundContext" ) -> TriggerWorkflowRequest: + """Encode the given context into the trigger protobuf.""" trigger = TriggerWorkflowRequest( additional_metadata=json.dumps( {"_hatchet_background_context": ctx.asdict()} @@ -192,22 +166,15 @@ def _encode_context( # Otherwise, the current context is the parent. assert ctx.current is not None - trigger.parent_id = ctx.current.workflow_run_id - trigger.parent_step_run_id = ctx.current.step_run_id - trigger.child_index = 0 # TODO: what is this + trigger.parent_id = ctx.current.workflow_run_id or "" + trigger.parent_step_run_id = ctx.current.step_run_id or "" + trigger.child_index = 0 # TODO: this is no longer needed since the user has full control of how they wanna trigger the children return trigger - def _to_trigger_proto( - self, ctx: "context.BackgroundContext", inputs: _CallableInput - ) -> TriggerWorkflowRequest: - # NOTE: serialization error will be raised as TypeError - req = TriggerWorkflowRequest(name=self._hatchet.name, input=inputs.dumps()) - req.MergeFrom(self._encode_context(ctx)) - return req - def _decode_context( self, action: AssignedAction ) -> Optional["context.BackgroundContext"]: + """Reconstruct the background context using the assigned action protobuf.""" if not action.additional_metadata: return None @@ -218,6 +185,7 @@ def _decode_context( logger.warning("failed to decode additional metadata from assigned action") return None + assert isinstance(d, Dict) if "_hatchet_background_context" not in d: return None @@ -227,63 +195,54 @@ def _decode_context( ctx.client = self._hatchet.client return ctx - # def _debug(self): - # data = { - # "self": repr(self), - # "metadata": self._hatchet._debug(), - # "def_proto": MessageToDict(self._to_workflow_proto()), - # "call_proto": ( - # MessageToDict(self._ctx_to_trigger_proto()) - # if self._to_trigger_proto() - # else None - # ), - # } - # return data + def _to_trigger_proto( + self, ctx: "context.BackgroundContext", inputs: _CallableInput + ) -> TriggerWorkflowRequest: + # NOTE: serialization error will be raised as TypeError + req = TriggerWorkflowRequest(name=self._hatchet.name, input=inputs.dumps()) + req.MergeFrom(self._encode_context(ctx)) + return req + # TODO: the return type of decode output needs to be casted. + # For Callable[P, T] the return type is T. + # For Callable[P, Awaitable[T]], the return type is T. def _decode_output(self, result: WorkflowRunEvent): + """Decode the output from a WorkflowRunEvent. + + Note that the WorkflowRunEvent could be, in the future, encoded from a + different language, like Typescript or Go. + """ steps = list(result.results) - assert len(steps) == 1 + assert len(steps) == 1 # assumping single step workflows step = steps[0] if step.error: + # TODO: find a way to be more precise about the type of exception. + # right now everything is a RuntimeError. raise RuntimeError(step.error) else: - return _CallableOutput.loads(step.output).output + ret = _CallableOutput.loads(step.output).output + return ret - def _run(self, action: AssignedAction) -> str: - # actually invokes the function, and serializing the output - raise NotImplementedError - - -class HatchetCallable(HatchetCallableBase[P, T]): - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + def _trigger(self, *args: P.args, **kwargs: P.kwargs) -> TriggerWorkflowResponse: ctx = context.ensure_background_context() trigger = self._to_trigger_proto( ctx, inputs=_CallableInput(args=args, kwargs=kwargs) ) - logger.trace( - "triggering on {}: {}", threading.get_ident(), MessageToDict(trigger) - ) + logger.trace("triggering: {}", MessageToDict(trigger)) client = self._hatchet.client ref: TriggerWorkflowResponse = client.admin.client.TriggerWorkflow( trigger, metadata=self._hatchet.client._grpc_metadata() ) logger.trace("runid: {}", ref) - # TODO: look into timeouts for Future.result() - - sub = SubscribeToWorkflowRunsRequest(workflowRunId=ref.workflow_run_id) - wfre_future = self._hatchet.client._runtime.wfr_futures.submit(sub) - - return utils.MapFuture( - self._decode_output, wfre_future, self._hatchet.client.executor - ).result() + return ref - def _run(self, action: AssignedAction) -> str: - assert action.actionId == self._hatchet.action - logger.trace("invoking:\n{}", MessageToDict(action)) + def _make_ctx(self, action: AssignedAction) -> "context.BackgroundContext": ctx = context.ensure_background_context(client=self._hatchet.client) assert ctx.current is None - parent: Optional["context.BackgroundContext"] = self._decode_context(action) + parent = self._decode_context(action) or context.BackgroundContext( + client=self._hatchet.client + ) with context.WithParentContext(parent) as ctx: assert ctx.current is None ctx.current = context.RunInfo( @@ -294,29 +253,80 @@ def _run(self, action: AssignedAction) -> str: ) if ctx.root is None: ctx.root = ctx.current.copy() - with context.WithContext(ctx): - inputs = _CallableInput.loads(action.actionPayload) - output = _CallableOutput( - output=self._hatchet.func(*inputs.args, **inputs.kwargs) - ) - logger.trace("output:\n{}", output) - return output.dumps() + return ctx + + +class HatchetCallable(HatchetCallableBase[P, T]): + """A Hatchet callable wrapping a non-asyncio free function.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Future[T]: + """Trigger a workflow run and returns the future. + + Note that it is important that we return a Future. We want the user + to trigger multiple calls and decide when to synchronize. Like, + + concurrent.futures.as_completed(wf1(), wf2(), wf3()) + """ + ref = self._trigger(*args, **kwargs) + + # now setup to wait for the result + sub = SubscribeToWorkflowRunsRequest(workflowRunId=ref.workflow_run_id) + + # TODO: expose a better interface on the Hatchet client for waiting on results. + wfre_future = self._hatchet.client.worker()._wfr_futures.submit(sub) + + fut: Future[T] = utils.MapFuture( + self._decode_output, wfre_future, self._hatchet.client.executor + ) + return fut + + def _run(self, action: AssignedAction) -> str: + """Executes the actual code and returns a serialized output.""" + + logger.trace("invoking: {}", MessageToDict(action)) + assert action.actionId == self._hatchet.action + + ctx = self._make_ctx(action) + with context.WithContext(ctx): + inputs = _CallableInput.loads(action.actionPayload) + output = _CallableOutput( + output=self._hatchet.func(*inputs.args, **inputs.kwargs) + ) + logger.trace("output: {}", output) + return output.dumps() class HatchetAwaitable(HatchetCallableBase[P, Awaitable[T]]): + """A Hatchet callable wrapping an asyncio free function.""" + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - print(f"trigering {self.action_name}") - input = json.dumps({"args": args, "kwargs": kwargs}) - client = self._.options.hatchet - return await client.admin.run(self._.name, input).result() + ref = self._trigger(*args, **kwargs) + + # now setup to wait for the result + sub = SubscribeToWorkflowRunsRequest(workflowRunId=ref.workflow_run_id) + + # TODO: expose a better interface on the Hatchet client for waiting on results. + wfre_future = await self._hatchet.client.worker()._wfr_futures.asubmit(sub) - async def _run(self, ctx: ContextAioImpl) -> T: - print(f"trigering {self.action_name}") - input = json.loads(ctx.workflow_input) - return await self.func(*input.args, **input.kwargs) + return self._decode_output(await wfre_future) + + async def _run(self, action: AssignedAction) -> str: + logger.trace("invoking: {}", MessageToDict(action)) + assert action.actionId == self._hatchet.action + + ctx = self._make_ctx(action) + with context.WithContext(ctx): + inputs = _CallableInput.loads(action.actionPayload) + output = _CallableOutput( + output=await self._hatchet.func(*inputs.args, **inputs.kwargs) + ) + logger.trace("output: {}", output) + return output.dumps() class Options(BaseModel): + """The options for a Hatchet function (aka workflow).""" + # pydantic configuration model_config = ConfigDict(arbitrary_types_allowed=True) @@ -354,21 +364,24 @@ def ratelimits_proto(self) -> List[CreateStepRateLimit]: @computed_field @property def desired_worker_labels_proto(self) -> Dict[str, DesiredWorkerLabels]: + # TODO: double check the default values labels = dict() for key, d in self.desired_worker_labels.items(): value = d.get("value", None) labels[key] = DesiredWorkerLabels( strValue=str(value) if not isinstance(value, int) else None, intValue=value if isinstance(value, int) else None, - required=d.get("required", None), - weight=d.get("weight", None), - comparator=d.get("comparator", None), + required=d.get("required") or False, + weight=d.get("weight") or 0, + comparator=str(d.get("comparator")) or None, ) return labels @dataclass -class CallableMetadata: +class CallableMetadata(Generic[P, T]): + """Metadata field for a decorated Hatchet workflow.""" + func: Callable[P, T] # the original function name: str @@ -377,7 +390,7 @@ class CallableMetadata: sourceloc: str # source location of the callable options: "Options" - client: "v2hatchet.Hatchet" + client: "hatchet.Hatchet" def _debug(self): return { @@ -389,75 +402,3 @@ def _debug(self): "client": repr(self.client), "options": self.options.model_dump(), } - - -class HatchetContextBase: - pass - - -# # Context variable used for propagating hatchet context. -# # The type of the variable is CallableContext. -# _callable_cv = ContextVar("hatchet.callable") - - -# # The context object to be propagated between parent/child workflows. -# class CallableContext(BaseModel): -# # pydantic configuration -# model_config = ConfigDict(arbitrary_types_allowed=True) - -# caller: Optional["HatchetCallable[P,T]"] = None -# workflow_run_id: str # caller's workflow run id -# step_run_id: str # caller's step run id - -# @staticmethod -# def cv() -> ContextVar: -# return _callable_cv - -# @staticmethod -# def current() -> Optional["CallableContext"]: -# try: -# cv: ContextVar = CallableContext.cv() -# return cv.get() -# except LookupError: -# return None - - -# T = TypeVar("T") - - -# class TriggerOptions(TypedDict): -# additional_metadata: Dict[str, str] | None = None -# sticky: bool | None = None - - -# class DurableContext(Context): -# pass - - -# # def run( -# # self, -# # function: Union[str, HatchetCallable[T]], -# # input: dict = {}, -# # key: str = None, -# # options: TriggerOptions = None, -# # ) -> "RunRef[T]": -# # worker_id = self.worker.id() - -# # workflow_name = function - -# # if not isinstance(function, str): -# # workflow_name = function.function_name - -# # # if ( -# # # options is not None -# # # and "sticky" in options -# # # and options["sticky"] == True -# # # and not self.worker.has_workflow(workflow_name) -# # # ): -# # # raise Exception( -# # # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" -# # # ) - -# # trigger_options = self._prepare_workflow_options(key, options, worker_id) - -# # return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 846c7602..4998e490 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,10 +1,9 @@ import asyncio import functools import inspect -import multiprocessing as mp -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, Future from contextlib import suppress -from typing import Callable, Dict, List, Optional, ParamSpec, Tuple, TypeVar +from typing import Callable, List, Optional, ParamSpec, Tuple, TypeVar import hatchet_sdk.hatchet as v1 import hatchet_sdk.v2.callable as callable @@ -16,22 +15,6 @@ import hatchet_sdk.v2.runtime.runtime as runtime import hatchet_sdk.v2.runtime.worker as worker -# import hatchet_sdk.runtime.registry as hatchet_registry -# import hatchet_sdk.v2.callable as v2_callable -# from hatchet_sdk.context import Context -# from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy - -# import Hatchet as HatchetV1 -# from hatchet_sdk.hatchet import workflow -# from hatchet_sdk.labels import DesiredWorkerLabel -# from hatchet_sdk.rate_limit import RateLimit - -# from hatchet_sdk.v2.concurrency import ConcurrencyFunction -# from hatchet_sdk.worker.worker import register_on_worker - -# from ..worker import Worker - - T = TypeVar("T") P = ParamSpec("P") @@ -44,7 +27,6 @@ def __init__( executor: ThreadPoolExecutor = ThreadPoolExecutor(), ): # ensure a event loop is created before gRPC - with suppress(RuntimeError): asyncio.get_event_loop() @@ -71,27 +53,30 @@ def dispatcher(self): def config(self): return self.v1.config - @property - def logger(self): - return logging.logger - + # FIXME: consider separating this into @func and @afunc for better type hints. + # Right now, the type hint for the return type is (P -> T) | (P -> Future[T]) and this is because we + # don't statically know whether "func" is a def or an async def. def function( self, + *, name: str = "", namespace: str = "default", options: "callable.Options" = callable.Options(), ): - def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": + # TODO: needs to detect and reject an already decorated free function. + # TODO: needs to detect and reject a classmethod/staticmethod. + def inner(func: Callable[P, T]): if inspect.iscoroutinefunction(func): - wrapped = callable.HatchetAwaitable( + wrapped = callable.HatchetAwaitable[P, T]( func=func, name=name, namespace=namespace, client=self, options=options, ) - wrapped = functools.update_wrapper(wrapped, func) - return wrapped + # TODO: investigate the type error here. + aret: Callable[P, T] = functools.update_wrapper(wrapped, func) + return aret elif inspect.isfunction(func): wrapped = callable.HatchetCallable( func=func, @@ -100,8 +85,8 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": client=self, options=options, ) - wrapped = functools.update_wrapper(wrapped, func) - return wrapped + ret: Callable[P, Future[T]] = functools.update_wrapper(wrapped, func) + return ret else: raise TypeError( "the @function decorator can only be applied to functions (def) and async functions (async def)" @@ -110,9 +95,13 @@ def inner(func: Callable[P, T]) -> "callable.HatchetCallable[P, T]": return inner # TODO: make it 1 worker : 1 client, which means moving the options to the initializer, and cache the result. - def worker(self, options: "worker.WorkerOptions") -> "runtime.Runtime": + # TODO: rename it to runtime + def worker( + self, *, options: Optional["worker.WorkerOptions"] = None + ) -> "runtime.Runtime": if self._runtime is None: - self._runtime = runtime.Runtime(self, options) + assert options is not None + self._runtime = runtime.Runtime(client=self, options=options) return self._runtime def _grpc_metadata(self) -> List[Tuple]: diff --git a/hatchet_sdk/v2/runtime/connection.py b/hatchet_sdk/v2/runtime/connection.py index eca20b1c..2d6007b9 100644 --- a/hatchet_sdk/v2/runtime/connection.py +++ b/hatchet_sdk/v2/runtime/connection.py @@ -17,19 +17,23 @@ def ensure_background_channel() -> grpc.Channel: ctx = context.ensure_background_context(client=None) - channel: grpc.Channel = _channel_cv.get() + channel: Optional[grpc.Channel] = _channel_cv.get() if channel is None: - channel = v1.new_conn(ctx.client.config, aio=False) + # TODO: fix the typing of new_conn + channel = v1.new_conn(ctx.client.config, aio=False) # type: ignore _channel_cv.set(channel) + assert channel is not None return channel def ensure_background_achannel() -> grpc.aio.Channel: ctx = context.ensure_background_context(client=None) - achannel: grpc.aio.Channel = _aio_channel_cv.get() + achannel: Optional[grpc.aio.Channel] = _aio_channel_cv.get() if achannel is None: - achannel = v1.new_conn(ctx.client.config, aio=True) + # TODO: fix the typing of new_conn + achannel = v1.new_conn(ctx.client.config, aio=True) # type: ignore _aio_channel_cv.set(achannel) + assert achannel is not None return achannel @@ -41,7 +45,7 @@ def reset_background_channel(): async def reset_background_achannel(): - c: grpc.aio.Channel = _aio_channel_cv.get() + c: Optional[grpc.aio.Channel] = _aio_channel_cv.get() if c is not None: await c.close() _aio_channel_cv.set(None) diff --git a/hatchet_sdk/v2/runtime/context.py b/hatchet_sdk/v2/runtime/context.py index 2c35dacb..0475b820 100644 --- a/hatchet_sdk/v2/runtime/context.py +++ b/hatchet_sdk/v2/runtime/context.py @@ -84,7 +84,7 @@ def copy(self): return ret @staticmethod - def set(ctx: "BackgroundContext"): + def set(ctx: Optional["BackgroundContext"]): global _ctxvar _ctxvar.set(ctx) diff --git a/hatchet_sdk/v2/runtime/future.py b/hatchet_sdk/v2/runtime/future.py index 36f0a6d7..f83727ef 100644 --- a/hatchet_sdk/v2/runtime/future.py +++ b/hatchet_sdk/v2/runtime/future.py @@ -1,11 +1,10 @@ import asyncio -import multiprocessing as mp import multiprocessing.queues as mpq import queue import threading import time from collections.abc import Callable, MutableSet -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import CancelledError, Future, ThreadPoolExecutor from contextlib import suppress from typing import Dict, Generic, Optional, TypeAlias, TypeVar @@ -38,9 +37,13 @@ def __init__( resp_key: Callable[[RespT], str], executor: ThreadPoolExecutor, ): - """A broker that can send a request and returns a future for the response. + """A broker that can send/forward a request and returns a future for the caller to wait upon. - The broker loop runs forever and quits upon asyncio.CancelledError. + This is to be used in the main process. The broker loop runs forever and quits upon asyncio.CancelledError. + The broker is essentially an adaptor from server-streams to either concurrent.futures.Future or asyncio.Future. + For the blocking case (i.e. concurrent.futures.Future), the broker uses polling. + + The class needs to be thread-safe for the concurrent.futures.Future case. Args: outbound: a thread-safe blocking queue to which the request should be forwarded to @@ -49,13 +52,14 @@ def __init__( resp_key: a function that computes the key of the response, which is used to match the requests executor: a thread pool for running any blocking code """ + logger.trace("init broker") self._inbound = inbound self._outbound = outbound self._req_key = req_key self._resp_key = resp_key # NOTE: this is used for running the polling tasks for results. - # The tasks we submit to the (any) executor should NOT wait indefinitely. + # The tasks we submit to the executor (or any executor) should NOT wait indefinitely. # We must provide it with a way to self-cancelling. self._executor = executor @@ -69,60 +73,81 @@ def __init__( self._akeys: MutableSet[str] = set() self._afutures: Dict[str, asyncio.Future[RespT]] = dict() - self.loop_task: Optional[asyncio.Task] = None + self._task: Optional[asyncio.Task] = None def start(self): - logger.trace("starting broker on {}", threading.get_native_id()) - self.loop_task = asyncio.create_task(self.loop()) + logger.trace("starting broker") + self._task = asyncio.create_task(self._loop()) return async def shutdown(self): - self.loop_task.cancel() - with suppress(asyncio.CancelledError): - await self.loop_task + if self._task: + self._task.cancel() + with suppress(asyncio.CancelledError): + await self._task + self._task = None + + async def _loop(self): + """The main broker loop. - async def loop(self): + The loop listens for any responses and resolves the corresponding futures. + """ + logger.trace("broker started") try: async for resp in utils.QueueAgen(self._inbound): logger.trace("broker got: {}", resp) key = self._resp_key(resp) + # if the response is for a concurrent.futures.Future, + # finds/resolves it and return True. def update(): with self._lock: if key in self._futures: self._futures[key] = resp return True + # NOTE: the clean up happens at submission time + # See self.submit() return False if await asyncio.to_thread(update): continue + # if the previous step didn't find a corresponding future, + # looks for the asyncio.Future instead. if key in self._afutures: self._afutures[key].set_result(resp) + + # clean up self._akeys.remove(key) del self._afutures[key] continue raise KeyError(f"key not found: {key}") finally: + logger.trace("broker shutting down") self._shutdown = True async def asubmit(self, req: ReqT) -> asyncio.Future[RespT]: + """Submits a request for an asyncio.Future.""" key = self._req_key(req) assert key not in self._keys - f = None - if key not in self._akeys: + f = self._afutures.get(key, None) + if f is None: self._afutures[key] = asyncio.Future() f = self._afutures[key] self._akeys.add(key) - await asyncio.to_thread(self._outbound.put, key) + # TODO: pyright can't figure out that both alternatives in the union type is individualy type-checked + await asyncio.to_thread(self._outbound.put, req) # type: ignore return f def submit(self, req: ReqT) -> Future[RespT]: - key = self._req_key(req) + """Submits a request for a concurrent.futures.Future. + The future may raise CancelledError if the broker is shutting down. + """ + key = self._req_key(req) assert key not in self._akeys def poll(): @@ -142,40 +167,54 @@ def poll(): self._keys.remove(key) del self._futures[key] + if self._shutdown: + logger.trace("broker polling task shutting down") + raise CancelledError("shutting down") + + assert resp is not None return resp return self._executor.submit(poll) class WorkflowRunFutures: + """A workflow run listener to be used in the main process. + + It is a high-level interface that wraps a RequestResponseBroker. + """ + def __init__( self, + *, + executor: ThreadPoolExecutor, broker: RequestResponseBroker["messages.Message", "messages.Message"], ): self._broker = broker - self._thread = None + self._executor = executor def start(self): - logger.trace("starting workflow run wrapper on {}", threading.get_native_id()) - self._thread = threading.Thread(target=asyncio.run, args=[self._broker.start()], name="workflow run event broker") - self._thread.start() + logger.trace("starting main-process workflow run listener") + self._broker.start() async def shutdown(self): - del self._thread + logger.trace("shutting down main-process workflow run listener") + await self._broker.shutdown() + logger.trace("bye: main-process workflow run listener") def submit(self, req: SubscribeToWorkflowRunsRequest) -> Future[WorkflowRunEvent]: - logger.trace("requesting workflow run result: {}", req) + logger.trace("requesting workflow run result: {}", MessageToDict(req)) f = self._broker.submit( messages.Message(_subscribe_to_workflow_run=MessageToDict(req)) ) - logger.trace("submitted") - return self._broker._executor.submit(lambda: f.result().workflow_run_event) + return self._executor.submit(lambda: f.result().workflow_run_event) async def asubmit( self, req: SubscribeToWorkflowRunsRequest ) -> asyncio.Future[WorkflowRunEvent]: - logger.trace("requesting workflow run result: {}", req) - f = await self._broker.asubmit(req) + logger.trace("requesting workflow run result: {}", MessageToDict(req)) + f = await self._broker.asubmit( + messages.Message(_subscribe_to_workflow_run=MessageToDict(req)) + ) event: asyncio.Future[WorkflowRunEvent] = asyncio.Future() f.add_done_callback(lambda f: event.set_result(f.result().workflow_run_event)) return event diff --git a/hatchet_sdk/v2/runtime/listeners.py b/hatchet_sdk/v2/runtime/listeners.py index 7c3353b6..d7bcadd5 100644 --- a/hatchet_sdk/v2/runtime/listeners.py +++ b/hatchet_sdk/v2/runtime/listeners.py @@ -1,52 +1,37 @@ import asyncio -import multiprocessing as mp -import os -import threading -import time -from asyncio.taskgroups import TaskGroup -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator -from concurrent.futures import ThreadPoolExecutor +from collections.abc import AsyncGenerator, Callable from contextlib import suppress -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Dict, Generic, List, Literal, Optional, Set, TypeVar +from dataclasses import dataclass +from typing import Any, Dict, Generic, Set, TypeVar import grpc -from google.protobuf import timestamp_pb2 -from google.protobuf.json_format import MessageToDict, MessageToJson +from google.protobuf.json_format import MessageToDict from loguru import logger -import hatchet_sdk.contracts.dispatcher_pb2 -import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.connection as connection import hatchet_sdk.v2.runtime.context as context import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.utils as utils import hatchet_sdk.v2.runtime.worker as worker from hatchet_sdk.contracts.dispatcher_pb2 import ( - ActionType, AssignedAction, - HeartbeatRequest, StepActionEvent, - StepRunResult, SubscribeToWorkflowRunsRequest, - WorkerLabels, WorkerListenRequest, - WorkerRegisterRequest, - WorkerRegisterResponse, - WorkerUnsubscribeRequest, WorkflowRunEvent, WorkflowRunEventType, ) from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub -T = TypeVar("T") - class WorkflowRunEventListener: + """A multiplexing workflow run event listener. It should only be used in the sidecar process.""" + @dataclass class Sub: - id: str + """A subscription for a workflow run. This is only to be used in the sidecar process.""" + + id: str # TODO: the id is not used right now since one can only subscribe a run_id once. run_id: str future: asyncio.Future[WorkflowRunEvent] @@ -54,7 +39,7 @@ def __hash__(self): return hash(self.id) def __init__(self): - logger.debug("init workflow run event listener") + logger.trace("init workflow run event listener") # the set of active subscriptions self._subs: Set[WorkflowRunEventListener.Sub] = set() @@ -73,19 +58,26 @@ def __init__(self): self._task = None def start(self): + logger.trace("starting workflow run event listener") self._task = asyncio.create_task( - self.loop(), name="workflow run event listener loop" + self._loop(), name="workflow run event listener loop" ) - logger.debug("started workflow run event listener") async def shutdown(self): + logger.trace("shutting down workflow run event listener") if self._task: self._task.cancel() with suppress(asyncio.CancelledError): await self._task self._task = None - async def loop(self): + async def _loop(self): + """The main listener loop. + + The loop forwards subscription requests over the grpc stream to the server while giving + out a future to the caller. Then it listens for workflow run events and resolves the futures. + """ + logger.trace("started workflow run event listener") try: agen = utils.ForeverAgen(self._events, exceptions=(grpc.aio.AioRpcError,)) async for event in agen: @@ -97,9 +89,10 @@ async def loop(self): self._by_run_id[event.workflowRunId].future.set_result(event) self._unsubscribe(event.workflowRunId) finally: - logger.debug("bye: workflow run event listner shuts down") + logger.trace("bye: workflow run event listner shuts down") async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: + """The async generator backed by server-streamed WorkflowRunEvents.""" # keep trying until asyncio.CancelledError is raised into this coroutine # TODO: handle retry, backoff, etc. stub = DispatcherStub(channel=connection.ensure_background_achannel()) @@ -113,7 +106,7 @@ async def _events(self) -> AsyncGenerator[WorkflowRunEvent]: ) logger.trace("stream established") async for event in stream: - logger.trace("received workflow run event:\n{}", event) + logger.trace("received workflow run event: {}", MessageToDict(event)) assert ( event.eventType == WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_FINISHED ) @@ -131,10 +124,10 @@ async def _resubscribe(self): async def subscribe(self, run_id: str) -> "WorkflowRunEventListener.Sub": if run_id in self._by_run_id: - return + return self._by_run_id[run_id] logger.trace("subscribing: {}", run_id) await self._q_request.put(SubscribeToWorkflowRunsRequest(workflowRunId=run_id)) - sub = self.Sub(id=self._counter, run_id=run_id, future=asyncio.Future()) + sub = self.Sub(id=str(self._counter), run_id=run_id, future=asyncio.Future()) self._subs.add(sub) self._by_run_id[run_id] = sub self._counter += 1 @@ -149,10 +142,20 @@ def _unsubscribe(self, run_id: str): del self._by_run_id[run_id] -class AssignedActionListner: - def __init__(self, worker: "worker.Worker", interrupt: asyncio.Queue[T]): - logger.debug("init assigned action listener") +# TODO: use better generics with Python >= 3.12 +T = TypeVar("T") + + +class AssignedActionListner(Generic[T]): + """An assigned action listener that runs a callback on every server-streamed assigned actions.""" + + def __init__(self, *, worker: "worker.Worker", interrupt: asyncio.Queue[T]): + logger.trace("init assigned action listener") + + # used to get the worker id, which is not immediately available. self._worker = worker + + # used to interrupt the action listener self._interrupt = interrupt self._task = None @@ -160,10 +163,16 @@ def __init__(self, worker: "worker.Worker", interrupt: asyncio.Queue[T]): def start( self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] ): - self._task = asyncio.create_task(self.loop(async_on)) - logger.debug("started assigned action listener") + """Starts the assigned action listener loop. + + Args: + async_on: the callback to be invoked when an assigned action is received. + """ + logger.trace("starting assigned action listener") + self._task = asyncio.create_task(self._loop(async_on)) async def shutdown(self): + logger.trace("shutting down assigned action listener") if self._task: self._task.cancel() with suppress(asyncio.CancelledError): @@ -171,6 +180,7 @@ async def shutdown(self): self._task = None async def _action_stream(self) -> AsyncGenerator[AssignedAction]: + """The async generator backed by the server-streamed assigend actions.""" stub = DispatcherStub(connection.ensure_background_achannel()) proto = WorkerListenRequest(workerId=self._worker.id) resp = stub.ListenV2( @@ -179,60 +189,70 @@ async def _action_stream(self) -> AsyncGenerator[AssignedAction]: ) logger.trace("connection established") async for action in resp: - logger.trace("assigned action:\n{}", MessageToDict(action)) + logger.trace("assigned action: {}", MessageToDict(action)) yield action - async def listen(self) -> AsyncGenerator[AssignedAction | grpc.aio.AioRpcError | T]: - try: + async def _listen( + self, + ) -> AsyncGenerator[AssignedAction | grpc.aio.AioRpcError | T]: + """The wrapped assigned action async generator that handles retries, etc.""" - def agen_factory(): - return utils.InterruptableAgen( - self._action_stream(), interrupt=self._interrupt, timeout=5 - ) + def agen_factory(): + return utils.InterruptableAgen( + self._action_stream(), interrupt=self._interrupt, timeout=5 + ) - agen = utils.ForeverAgen(agen_factory, exceptions=(grpc.aio.AioRpcError,)) - async for action in agen: - if isinstance(action, grpc.aio.AioRpcError): - logger.trace("encountered error, retrying: {}", action) - yield action - else: - yield action - finally: - logger.debug("bye: assigned action listener") + agen = utils.ForeverAgen(agen_factory, exceptions=(grpc.aio.AioRpcError,)) + async for action in agen: + if isinstance(action, grpc.aio.AioRpcError): + logger.trace("encountered error, retrying: {}", action) + yield action + else: + yield action - async def loop( + async def _loop( self, async_on: Callable[[AssignedAction | grpc.aio.AioRpcError | T], Any] ): - async for event in self.listen(): - await async_on(event) + """The main assigned action listener loop.""" + try: + logger.trace("started assigned action listener") + async for event in self._listen(): + await async_on(event) + finally: + logger.trace("bye: assigned action listener") class StepEventListener: - def __init__(self, inbound: asyncio.Queue["messages.Message"]): - logger.debug("init event listener") - self.inbound = inbound - self.stub = DispatcherStub(connection.ensure_background_channel()) + """A step event listener that forwards the step event from the main process to the server.""" - self.task = None + def __init__(self, *, inbound: asyncio.Queue["messages.Message"]): + logger.trace("init step event listener") + self._inbound = inbound + self._stub = DispatcherStub(connection.ensure_background_channel()) + self._task = None def start(self): - self.task = asyncio.create_task(self.listen()) + logger.trace("starting step event listener") + self._task = asyncio.create_task(self._listen()) async def shutdown(self): - if self.task: - self.task.cancel() + logger.trace("shutting down step event listener") + if self._task: + self._task.cancel() with suppress(asyncio.CancelledError): - await self.task - self.task = None + await self._task + self._task = None async def _message_stream(self) -> AsyncGenerator["messages.Message"]: while True: - msg: "messages.Message" = await self.inbound.get() + msg: "messages.Message" = await self._inbound.get() assert msg.kind in [messages.MessageKind.STEP_EVENT] - logger.trace("event:\n{}", msg) + logger.trace("event: {}", msg) yield msg - async def listen(self): + async def _listen(self): + """The main listener loop.""" + logger.trace("step event listener started") try: async for msg in self._message_stream(): match msg.kind: @@ -248,10 +268,10 @@ async def listen(self): async def _on_step_event(self, e: StepActionEvent): # TODO: need retry - logger.trace("emit step action:\n{}", MessageToDict(e)) + logger.trace("emit step action: {}", MessageToDict(e)) resp = await asyncio.to_thread( - self.stub.SendStepActionEvent, + self._stub.SendStepActionEvent, e, metadata=context.ensure_background_context().client._grpc_metadata(), ) - logger.trace(resp) + logger.trace("resp: {}", MessageToDict(resp)) diff --git a/hatchet_sdk/v2/runtime/registry.py b/hatchet_sdk/v2/runtime/registry.py index 248612f9..724fe68e 100644 --- a/hatchet_sdk/v2/runtime/registry.py +++ b/hatchet_sdk/v2/runtime/registry.py @@ -1,9 +1,10 @@ import sys from typing import Dict +from loguru import logger + import hatchet_sdk.v2.callable as callable import hatchet_sdk.v2.hatchet as hatchet -import hatchet_sdk.v2.runtime.logging as logging class ActionRegistry: @@ -26,6 +27,6 @@ def register_all(self, client: "hatchet.Hatchet"): try: client.admin.put_workflow(proto.name, proto) except Exception as e: - logging.logger.error(f"failed to register workflow: {proto.name}") - logging.logger.error(e) + logger.error("failed to register workflow: {}", proto.name) + logger.exception(e) sys.exit(1) diff --git a/hatchet_sdk/v2/runtime/runner.py b/hatchet_sdk/v2/runtime/runner.py index 9202cdf2..4beaa940 100644 --- a/hatchet_sdk/v2/runtime/runner.py +++ b/hatchet_sdk/v2/runtime/runner.py @@ -1,20 +1,19 @@ -import threading import asyncio -import json -import multiprocessing as mp import multiprocessing.queues as mpq import queue import time import traceback -from typing import Any, Dict, Optional, Tuple, TypeAlias, TypeVar +from contextlib import suppress +from typing import Dict, Optional, Tuple, TypeAlias, TypeVar from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp from loguru import logger import hatchet_sdk.v2.callable as callable -import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.messages as messages +import hatchet_sdk.v2.runtime.registry as registry +import hatchet_sdk.v2.runtime.utils as utils from hatchet_sdk.contracts.dispatcher_pb2 import ( ActionType, AssignedAction, @@ -35,16 +34,19 @@ def _format_exc(e: Exception): async def _invoke( action: AssignedAction, registry: Dict[str, "callable.HatchetCallableBase"] -): +) -> Tuple[str, None] | Tuple[None, Exception]: key = action.actionId - fn: "callable.HatchetCallableBase" = registry[key] # TODO + # TODO: handle cases when it's not registered more gracefully + fn: "callable.HatchetCallableBase" = registry[key] logger.trace("invoking: {}", repr(fn)) try: if isinstance(fn, callable.HatchetCallable): logger.trace("invoking {} on a separate thread", fn._hatchet.name) return await asyncio.to_thread(fn._run, action), None - else: + elif isinstance(fn, callable.HatchetAwaitable): return await fn._run(action), None + else: + raise NotImplementedError(f"unsupported callable case: {type(fn)}") except asyncio.CancelledError: raise except Exception as e: @@ -60,68 +62,76 @@ async def _invoke( class RunnerLoop: def __init__( self, - client: "hatchet.Hatchet", + *, + reg: "registry.ActionRegistry", inbound: _ThreadSafeQueue["messages.Message"], # inbound queue, not owned outbound: _ThreadSafeQueue["messages.Message"], # outbound queue, not owned ): logger.trace("init runner loop") - self.client = client - self.registry: Dict[str, "callable.HatchetCallableBase"] = ( - client.registry.registry - ) self.worker_id: Optional[str] = None - self.inbound = inbound - self.outbound = outbound - - self.looptask: Optional[asyncio.Task] = None + self._registry: Dict[str, "callable.HatchetCallableBase"] = reg.registry + self._inbound = inbound + self._outbound = outbound + self._loop_task: Optional[asyncio.Task] = None # a dict from StepRunId to its tasks - self.tasks: Dict[str, asyncio.Task] = dict() + self._tasks: Dict[str, asyncio.Task] = dict() def start(self): - logger.debug("starting runner loop on {}", threading.get_ident()) - self.looptask = asyncio.create_task(self.loop(), name="runner loop") + logger.trace("starting runner loop") + self._loop_task = asyncio.create_task(self._loop(), name="runner loop") async def shutdown(self): logger.trace("shutting down runner loop") - t = asyncio.gather(*self.tasks.values(), self.looptask) - t.cancel() - try: - await t - except asyncio.CancelledError: - logger.debug("bye") - - async def loop(self): - while True: - msg: "messages.Message" = await self.next() + # finishing all the tasks + t = asyncio.gather(*self._tasks.values()) + await t + + if self._loop_task is not None: + self._loop_task.cancel() + with suppress(asyncio.CancelledError): + await self._loop_task + self._loop_task = None + logger.trace("bye: runner loop") + + async def _loop(self): + """The main runner loop. + + It listens for actions from the sidecar process and executes them. + """ + async for msg in utils.QueueAgen(self._inbound): + logger.trace("received: {}", msg) assert msg.kind == messages.MessageKind.ACTION match msg.action.actionType: case ActionType.START_STEP_RUN: - self.on_run(msg) + self._on_run(msg) case ActionType.CANCEL_STEP_RUN: - self.on_cancel(msg) + self._on_cancel(msg) case _: - logger.debug(msg) + raise NotImplementedError(msg) - def on_run(self, msg: "messages.Message"): + def _on_run(self, msg: "messages.Message"): async def task(): logger.trace("running {}", msg.action.stepRunId) try: - await self.emit_started(msg) - result, e = await _invoke(msg.action, self.registry) + await self._emit_started(msg) + result, e = await _invoke(msg.action, self._registry) if e is None: - await self.emit_finished(msg, result) + assert result is not None + await self._emit_finished(msg, result) else: - await self.emit_failed(msg, _format_exc(e)) + assert result is None + await self._emit_failed(msg, _format_exc(e)) finally: - del self.tasks[msg.action.stepRunId] + del self._tasks[msg.action.stepRunId] - self.tasks[msg.action.stepRunId] = asyncio.create_task( + self._tasks[msg.action.stepRunId] = asyncio.create_task( task(), name=msg.action.stepRunId ) - def step_event(self, msg: "messages.Message", **kwargs) -> StepActionEvent: + def _step_event(self, msg: "messages.Message", **kwargs) -> StepActionEvent: + """Makes a StepActionEvent proto.""" base = StepActionEvent( jobId=msg.action.jobId, jobRunId=msg.action.jobRunId, @@ -131,49 +141,51 @@ def step_event(self, msg: "messages.Message", **kwargs) -> StepActionEvent: eventTimestamp=_timestamp(), ) base.MergeFrom(StepActionEvent(**kwargs)) - return MessageToDict(base) + return base - def on_cancel(self, msg: "messages.Message"): + def _on_cancel(self, msg: "messages.Message"): + # TODO pass - async def emit_started(self, msg: "messages.Message"): - await self.send( + async def _emit_started(self, msg: "messages.Message"): + await self._send( messages.Message( - _step_event=self.step_event( - msg, eventType=StepActionEventType.STEP_EVENT_TYPE_STARTED + _step_event=MessageToDict( + self._step_event( + msg, eventType=StepActionEventType.STEP_EVENT_TYPE_STARTED + ) ) ) ) - async def emit_finished(self, msg: "messages.Message", payload: str): - await self.send( + async def _emit_finished(self, msg: "messages.Message", payload: str): + await self._send( messages.Message( - _step_event=self.step_event( - msg, - eventType=StepActionEventType.STEP_EVENT_TYPE_COMPLETED, - eventPayload=payload, + _step_event=MessageToDict( + self._step_event( + msg, + eventType=StepActionEventType.STEP_EVENT_TYPE_COMPLETED, + eventPayload=payload, + ) ) ) ) - async def emit_failed(self, msg: "messages.Message", payload: str): - await self.send( + async def _emit_failed(self, msg: "messages.Message", payload: str): + await self._send( messages.Message( - _step_event=self.step_event( - msg, - eventType=StepActionEventType.STEP_EVENT_TYPE_FAILED, - eventPayload=payload, + _step_event=MessageToDict( + self._step_event( + msg, + eventType=StepActionEventType.STEP_EVENT_TYPE_FAILED, + eventPayload=payload, + ) ) ) ) - async def send(self, msg: "messages.Message"): - logger.trace("send:\n{}", msg) - await asyncio.to_thread(self.outbound.put, msg) - - async def next(self) -> "messages.Message": - msg = await asyncio.to_thread( - self.inbound.get - ) # raise EOFError if the queue is closed - logger.trace("recv:\n{}", msg) - return msg + async def _send(self, msg: "messages.Message"): + """Sends a message to the sidecar process.""" + logger.trace("send: {}", msg) + # TODO: pyright could not figure this out + await asyncio.to_thread(self._outbound.put, msg) # type: ignore diff --git a/hatchet_sdk/v2/runtime/runtime.py b/hatchet_sdk/v2/runtime/runtime.py index bf6d0739..72470b71 100644 --- a/hatchet_sdk/v2/runtime/runtime.py +++ b/hatchet_sdk/v2/runtime/runtime.py @@ -1,106 +1,130 @@ import asyncio import multiprocessing as mp -import os import queue -import sys import threading -from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import CancelledError from contextlib import suppress from loguru import logger -import hatchet_sdk.loader as loader import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.future as future import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.runner as runner import hatchet_sdk.v2.runtime.utils as utils import hatchet_sdk.v2.runtime.worker as worker -from hatchet_sdk.contracts.dispatcher_pb2 import ( - SubscribeToWorkflowRunsRequest, - WorkflowRunEvent, -) class Runtime: - def __init__(self, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): + """The Hatchet runtime. + + The runtime is managine the runner on the main process, the run event listener on the main process, + and the worker on the sidecar process, together with the queues among them. A Hatchet client should + only contain one Runtime object. The behavior will be undefined if there are multiple Runtime per + Hatchet client. + """ + + # TODO: rename WorkerOptions to RuntimeOptions. + def __init__(self, *, client: "hatchet.Hatchet", options: "worker.WorkerOptions"): logger.trace("init runtime") - self.client = client - self.process_pool = ProcessPoolExecutor() + self._client = client + self._executor = client.executor - self.to_worker: mp.Queue["messages.Message"] = mp.Queue() - self.from_worker: mp.Queue["messages.Message"] = mp.Queue() + # the main queues between the sidecar process and the main process + self._to_worker: mp.Queue["messages.Message"] = mp.Queue() + self._from_worker: mp.Queue["messages.Message"] = mp.Queue() - self.to_runner = queue.Queue() - self.to_wfr_futures = queue.Queue() + # the queue to the runner on the main process + self._to_runner = queue.Queue() - self.worker = worker.WorkerProcess( + # the queue to the workflow run event listener on the main process + self._to_wfr_futures = queue.Queue() + + # the worker on the sidecar process + self._worker = worker.WorkerProcess( config=client.config, - inbound=self.to_worker, - outbound=self.from_worker, + inbound=self._to_worker, + outbound=self._from_worker, options=options, ) self.worker_id = None - self.runner = runner.RunnerLoop( - client=client, inbound=self.to_runner, outbound=self.to_worker + # the runner on the main process + self._runner = runner.RunnerLoop( + reg=client.registry, + inbound=self._to_runner, + outbound=self._to_worker, ) - self.wfr_futures = future.WorkflowRunFutures( - # pool=client.executor, + + # the workflow run event listener on the main process + self._wfr_futures = future.WorkflowRunFutures( + executor=self._executor, broker=future.RequestResponseBroker( - inbound=self.to_wfr_futures, - outbound=self.to_worker, + inbound=self._to_wfr_futures, + outbound=self._to_worker, req_key=lambda msg: msg.subscribe_to_workflow_run.workflowRunId, resp_key=lambda msg: msg.workflow_run_event.workflowRunId, - executor=client.executor, + executor=self._executor, ), ) - self.loop_task = None + # the shutdown signal + self._shutdown = threading.Event() + self._loop_task = None - async def loop(self): - async for msg in utils.QueueAgen(self.from_worker): + async def _loop(self): + async for msg in utils.QueueAgen(self._from_worker): match msg.kind: case messages.MessageKind.ACTION: - await asyncio.to_thread(self.to_runner.put, msg) + await asyncio.to_thread(self._to_runner.put, msg) case messages.MessageKind.WORKFLOW_RUN_EVENT: - await asyncio.to_thread(self.to_wfr_futures.put, msg) + await asyncio.to_thread(self._to_wfr_futures.put, msg) case messages.MessageKind.WORKER_ID: - self.runner.worker_id = msg.worker_id + self._runner.worker_id = msg.worker_id self.worker_id = msg.worker_id case _: raise NotImplementedError + if self._shutdown.is_set(): + break + + logger.trace("bye: runtime") async def start(self): - logger.trace("starting runtime on {}", threading.get_ident()) - self.runner.start() - self.wfr_futures.start() - self.worker.start() + logger.debug("starting runtime") + + # NOTE: the order matters, we should start things in topological order + self._runner.start() + self._wfr_futures.start() - self.client.executor.submit(asyncio.run, self.loop()) + # schedule the runtime on a separate thread + self._loop_task = self._executor.submit(asyncio.run, self._loop()) + self._worker.start() while self.worker_id is None: await asyncio.sleep(1) + logger.debug("runtime started") return self.worker_id async def shutdown(self): logger.trace("shutting down runtime") - self.worker.shutdown() - # await self.worker.shutdown() - self.from_worker.close() - self.from_worker.join_thread() + # NOTE: the order matters, we should shut things down in topological order + self._worker.shutdown() + self._from_worker.close() + self._from_worker.join_thread() + + await self._runner.shutdown() + self._to_worker.close() + self._to_worker.join_thread() - await self.runner.shutdown() - self.to_worker.close() - self.to_worker.join_thread() + await self._wfr_futures.shutdown() - await self.wfr_futures.shutdown() + self._shutdown.set() - self.loop_task.cancel() - with suppress(asyncio.CancelledError): - await self.loop_task + if self._loop_task is not None: + with suppress(CancelledError): + self._loop_task.result(timeout=10) - logger.debug("bye") + logger.debug("bye: runtime") diff --git a/hatchet_sdk/v2/runtime/utils.py b/hatchet_sdk/v2/runtime/utils.py index e5d7e0a3..9112a0e6 100644 --- a/hatchet_sdk/v2/runtime/utils.py +++ b/hatchet_sdk/v2/runtime/utils.py @@ -1,14 +1,10 @@ import asyncio -import multiprocessing as mp import multiprocessing.queues as mpq import queue from collections.abc import AsyncGenerator, Callable from concurrent.futures import Future, ThreadPoolExecutor from contextlib import suppress -from typing import Tuple, TypeVar - -import grpc -import tenacity +from typing import Tuple, Type, TypeVar T = TypeVar("T") I = TypeVar("I") @@ -27,6 +23,7 @@ async def producer(): await queue.put(item) await queue.put(StopAsyncIteration()) + producer_task = None try: producer_task = asyncio.create_task(producer()) while True: @@ -45,13 +42,17 @@ async def producer(): break finally: - producer_task.cancel() - await producer_task + if producer_task: + producer_task.cancel() + await producer_task + + +E = TypeVar("E") async def ForeverAgen( - agen_factory: Callable[[], AsyncGenerator[T]], exceptions: Tuple[Exception] -) -> AsyncGenerator[T | Exception]: + agen_factory: Callable[[], AsyncGenerator[T]], exceptions: Tuple[Type[E]] +) -> AsyncGenerator[T | E]: """Run a async generator forever until its cancelled. Args: @@ -95,7 +96,7 @@ async def QueueAgen( def MapFuture( fn: Callable[[T], R], fut: Future[T], pool: ThreadPoolExecutor ) -> Future[R]: - def task(fn: Callable[[T], R], fut: Future[T]): + def task(fn: Callable[[T], R], fut: Future[T]) -> R: return fn(fut.result()) return pool.submit(task, fn, fut) diff --git a/hatchet_sdk/v2/runtime/worker.py b/hatchet_sdk/v2/runtime/worker.py index 28d7278a..b00365ea 100644 --- a/hatchet_sdk/v2/runtime/worker.py +++ b/hatchet_sdk/v2/runtime/worker.py @@ -1,24 +1,18 @@ -import sys import asyncio import multiprocessing as mp import multiprocessing.queues as mpq import multiprocessing.synchronize as mps -import os -import threading +import sys import time -from collections.abc import AsyncGenerator -from concurrent.futures import ThreadPoolExecutor from contextlib import suppress from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, Generic, List, Optional, Set, TypeVar -import logging +from typing import Dict, List, Optional, TypeVar + import grpc from google.protobuf import timestamp_pb2 -from google.protobuf.json_format import MessageToDict, MessageToJson +from google.protobuf.json_format import MessageToDict from loguru import logger -import hatchet_sdk.contracts.dispatcher_pb2 import hatchet_sdk.v2.hatchet as hatchet import hatchet_sdk.v2.runtime.config as config import hatchet_sdk.v2.runtime.connection as connection @@ -27,24 +21,21 @@ import hatchet_sdk.v2.runtime.messages as messages import hatchet_sdk.v2.runtime.utils as utils from hatchet_sdk.contracts.dispatcher_pb2 import ( - ActionType, AssignedAction, HeartbeatRequest, - StepActionEvent, - StepRunResult, - SubscribeToWorkflowRunsRequest, WorkerLabels, - WorkerListenRequest, WorkerRegisterRequest, WorkerRegisterResponse, - WorkerUnsubscribeRequest, WorkflowRunEvent, ) from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub +# TODO: change it to RuntimeOptions @dataclass class WorkerOptions: + """Options for the runtime behavior of a Runtime.""" + name: str actions: List[str] slots: int = 5 @@ -65,37 +56,41 @@ def labels_proto(self) -> Dict[str, WorkerLabels]: class HeartBeater: def __init__(self, worker: "Worker"): - logger.debug("init heartbeater") - self.worker = worker + logger.trace("init heartbeater") + self._worker = worker # used to access worker id + self._stub = DispatcherStub(connection.ensure_background_channel()) + self.last_heartbeat: int = -1 # unix epoch in seconds - self.stub = DispatcherStub(connection.ensure_background_channel()) self.missed = 0 self.error = 0 - self.task = None + self._task = None async def start(self): - self.task = asyncio.create_task(self.heartbeat()) + logger.trace("starting heart beater") + self._task = asyncio.create_task(self._heartbeat()) while self.last_heartbeat < 0: await asyncio.sleep(1) async def shutdown(self): - if self.task: - self.task.cancel() + logger.trace("shutting down heart beater") + if self._task: + self._task.cancel() with suppress(asyncio.CancelledError): - await self.task - self.task = None + await self._task + self._task = None - async def heartbeat(self): + async def _heartbeat(self): + """The main heart beater loop.""" try: while True: now = int(time.time()) proto = HeartbeatRequest( - workerId=self.worker.id, + workerId=self._worker.id, heartbeatAt=timestamp_pb2.Timestamp(seconds=now), # TODO ) try: - _ = self.stub.Heartbeat( + _ = self._stub.Heartbeat( proto, timeout=5, metadata=context.ensure_background_context().client._grpc_metadata(), @@ -110,9 +105,9 @@ async def heartbeat(self): self.last_heartbeat = now else: diff = proto.heartbeatAt.seconds - self.last_heartbeat - if diff > self.worker.options.heartbeat: + if diff > self._worker.options.heartbeat: self.missed += 1 - await asyncio.sleep(self.worker.options.heartbeat) + await asyncio.sleep(self._worker.options.heartbeat) except Exception as e: logger.exception(e) raise @@ -125,6 +120,8 @@ async def heartbeat(self): class Worker: + """The main worker logic for the sidecar process.""" + def __init__( self, *, @@ -133,17 +130,20 @@ def __init__( inbound: mpq.Queue["messages.Message"], outbound: mpq.Queue["messages.Message"], ): - logger.debug("init worker") + logger.trace("init worker") context.ensure_background_context(client=client) self.id: Optional[str] = None - self.options = options - self.inbound = inbound - self.outbound = outbound + + # the main queues to/from the main process + self._inbound = inbound + self._outbound = outbound self._heartbeater = HeartBeater(self) + # used to interrupt the action listener + # TODO: need to hook this up to the heart beater so that the exceptions from heart beater can interrupt the action listener self._action_listener_interrupt: asyncio.Queue[StopAsyncIteration] = ( asyncio.Queue() ) @@ -152,48 +152,57 @@ def __init__( interrupt=self._action_listener_interrupt, ) + # the step event forwarder self._to_event_listner: asyncio.Queue["messages.Message"] = asyncio.Queue() - self._event_listner = listeners.StepEventListener(self._to_event_listner) + self._event_listner = listeners.StepEventListener( + inbound=self._to_event_listner + ) + # the workflow run listener self._workflow_run_event_listener = listeners.WorkflowRunEventListener() - self.main_loop_task = None + self._main_loop_task = None def _register(self) -> str: req = self._to_register_proto() - logger.trace("registering worker:\n{}", req) - resp: ( - WorkerRegisterResponse - ) = context.ensure_background_context().client.dispatcher.client.Register( - req, - timeout=30, - metadata=context.ensure_background_context().client._grpc_metadata(), + logger.trace("registering worker: {}", MessageToDict(req)) + resp: WorkerRegisterResponse = ( + context.ensure_background_context().client.dispatcher.client.Register( + req, + timeout=30, + metadata=context.ensure_background_context().client._grpc_metadata(), + ) ) - logger.debug("worker registered:\n{}", MessageToDict(resp)) + logger.debug("worker registered: {}", MessageToDict(resp)) return resp.workerId async def start(self) -> str: logger.trace("starting worker") self.id = self._register() + + # NOTE: order matters, we start them in topological order self._event_listner.start() self._workflow_run_event_listener.start() - self._action_listener.start(async_on=self.on_assigned_action) + self._action_listener.start(async_on=self._on_assigned_action) - self.main_loop_task = asyncio.create_task(self.loop()) + self._main_loop_task = asyncio.create_task(self._loop()) await self._heartbeater.start() - await asyncio.to_thread(self.outbound.put, messages.Message(worker_id=self.id)) + + # notify the worker id to the main process + await asyncio.to_thread(self._outbound.put, messages.Message(worker_id=self.id)) + logger.debug("worker started: {}", self.id) return self.id async def shutdown(self): logger.trace("shutting down worker {}", self.id) - if self.main_loop_task: - self.main_loop_task.cancel() + if self._main_loop_task: + self._main_loop_task.cancel() with suppress(asyncio.CancelledError): - await self.main_loop_task - self.main_loop_task = None + await self._main_loop_task + self._main_loop_task = None tg: asyncio.Future = asyncio.gather( self._heartbeater.shutdown(), @@ -202,17 +211,17 @@ async def shutdown(self): self._workflow_run_event_listener.shutdown(), ) await tg - logger.debug("bye") + logger.debug("bye: worker {}", self.id) - async def loop(self): + async def _loop(self): try: - async for msg in utils.QueueAgen(self.inbound): + async for msg in utils.QueueAgen(self._inbound): logger.trace("worker received msg: {}", msg) match msg.kind: case messages.MessageKind.STEP_EVENT: await self._to_event_listner.put(msg) case messages.MessageKind.SUBSCRIBE_TO_WORKFLOW_RUN: - await self.on_workflow_run_subscription(msg) + await self._on_workflow_run_subscription(msg) case _: raise NotImplementedError except Exception as e: @@ -221,7 +230,7 @@ async def loop(self): finally: logger.trace("bye: worker") - async def on_assigned_action( + async def _on_assigned_action( self, action: StopAsyncIteration | grpc.aio.AioRpcError | AssignedAction ): if isinstance(action, StopAsyncIteration): @@ -233,12 +242,12 @@ async def on_assigned_action( else: assert isinstance(action, AssignedAction) msg = messages.Message(_action=MessageToDict(action)) - await asyncio.to_thread(self.outbound.put, msg) + await asyncio.to_thread(self._outbound.put, msg) - async def on_workflow_run_subscription(self, msg: "messages.Message"): + async def _on_workflow_run_subscription(self, msg: "messages.Message"): def callback(f: asyncio.Future[WorkflowRunEvent]): logger.trace("workflow run event future resolved") - self.outbound.put( + self._outbound.put( messages.Message(_workflow_run_event=MessageToDict(f.result())) ) @@ -266,10 +275,18 @@ def _worker_process( outbound: mpq.Queue["messages.Message"], shutdown: mps.Event, ): + """The worker process logic. + + It has to be a top-level function since it needs to be pickled. + """ + # TODO: propagate options, debug, etc. client = hatchet.Hatchet(config=config, debug=True) + + # TODO: re-configure the loggers based on the options, etc. logger.remove() logger.add(sys.stdout, level="TRACE") + # FIXME: the loop is not exiting correctly. It hangs, instead. Investigate why. async def loop(): worker = Worker( client=client, @@ -278,10 +295,10 @@ async def loop(): options=options, ) try: - id = await worker.start() + _ = await worker.start() while not await asyncio.to_thread(shutdown.wait, 1): pass - asyncio.current_task().cancel() + # asyncio.current_task().cancel() except Exception as e: logger.exception(e) raise @@ -291,10 +308,12 @@ async def loop(): logger.trace("worker process shuts down") asyncio.run(loop(), debug=True) - logger.trace("here") + logger.trace("bye: worker process") class WorkerProcess: + """A wrapper to control the sidecar worker process.""" + def __init__( self, *, @@ -303,8 +322,8 @@ def __init__( inbound: mpq.Queue["messages.Message"], outbound: mpq.Queue["messages.Message"], ): - self.to_worker = inbound - self.shutdown_ev = mp.Event() + self._to_worker = inbound + self._shutdown_ev = mp.Event() self.proc = mp.Process( target=_worker_process, kwargs={ @@ -312,7 +331,7 @@ def __init__( "options": options, "inbound": inbound, "outbound": outbound, - "shutdown": self.shutdown_ev, + "shutdown": self._shutdown_ev, }, ) @@ -321,5 +340,5 @@ def start(self): self.proc.start() def shutdown(self): - self.shutdown_ev.set() + self._shutdown_ev.set() logger.debug("worker process shuts down") diff --git a/tests/v2/test_worker.py b/tests/v2/test_worker.py index 4885bed2..38f1f7eb 100644 --- a/tests/v2/test_worker.py +++ b/tests/v2/test_worker.py @@ -11,10 +11,10 @@ from hatchet_sdk.v2.runtime.worker import WorkerOptions, WorkerProcess from concurrent.futures import ThreadPoolExecutor -logger.remove() -logger.add( - sys.stdout, level="TRACE" -) # , format="{level}\t|{module}:{function}:{line}[{process}:{thread}] - {message}") +# logger.remove() +# logger.add( +# sys.stdout, level="TRACE" +# ) # , format="{level}\t|{module}:{function}:{line}[{process}:{thread}] - {message}") dotenv.load_dotenv() @@ -26,7 +26,7 @@ @hatchet.function() def foo(): print("entering Foo") - print("result from bar: ", bar("from foo")) + print("result from bar: ", bar("from foo").result()) return "foo" @@ -39,11 +39,12 @@ def bar(x): @pytest.mark.asyncio async def test_worker(): + worker = hatchet.worker( - WorkerOptions(name="worker", actions=["default:foo", "default:bar"]) + options=WorkerOptions(name="worker", actions=["default:foo", "default:bar"]) ) await worker.start() - print("result from foo: ", await asyncio.to_thread(foo)) + print("result from foo: ", await asyncio.to_thread(foo().result)) await asyncio.sleep(10) await worker.shutdown() return None