From c846a0131bdffdd2404db4933ec50542f862616c Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 31 Jan 2025 15:48:40 -0500 Subject: [PATCH 1/6] feat: initial pass at hatchet function --- examples/simple/worker.py | 17 +-- hatchet_sdk/v2/hatchet.py | 297 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+), 10 deletions(-) create mode 100644 hatchet_sdk/v2/hatchet.py diff --git a/examples/simple/worker.py b/examples/simple/worker.py index c6aab0ad..f886bce0 100644 --- a/examples/simple/worker.py +++ b/examples/simple/worker.py @@ -3,20 +3,17 @@ hatchet = Hatchet(debug=True) -class MyWorkflow(BaseWorkflow): - @hatchet.step(timeout="11s", retries=3) - def step1(self, context: Context) -> dict[str, str]: - print("executed step1") - return { - "step1": "step1", - } +@hatchet.function(timeout="11s") +def step1(context: Context) -> dict[str, str]: + print("executed step1") + return { + "step1": "step1", + } def main() -> None: - wf = MyWorkflow() - worker = hatchet.worker("test-worker", max_runs=1) - worker.register_workflow(wf) + worker.register_function(step1) worker.start() diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py new file mode 100644 index 00000000..ba50c5c0 --- /dev/null +++ b/hatchet_sdk/v2/hatchet.py @@ -0,0 +1,297 @@ +import asyncio +import logging +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + ParamSpec, + Type, + TypeVar, + cast, +) + +from hatchet_sdk.client import Client, new_client, new_client_raw +from hatchet_sdk.clients.admin import AdminClient +from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient +from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.clients.run_event_listener import RunEventListenerClient +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import DesiredWorkerLabels +from hatchet_sdk.features.cron import CronClient +from hatchet_sdk.features.scheduled import ScheduledClient +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger +from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.v2.workflows import ( + BaseWorkflowImpl, + ConcurrencyExpression, + EmptyModel, + Step, + StepType, + StickyStrategy, + TWorkflowInput, + WorkflowConfig, + WorkflowDeclaration, +) + +if TYPE_CHECKING: + from hatchet_sdk.worker.worker import Worker + +P = ParamSpec("P") +R = TypeVar("R") + + +def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d.value + return DesiredWorkerLabels( + strValue=value if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] + ) + + +class Hatchet: + """ + Main client for interacting with the Hatchet SDK. + + This class provides access to various client interfaces and utility methods + for working with Hatchet workers, workflows, and steps. + + Attributes: + cron (CronClient): Interface for cron trigger operations. + + admin (AdminClient): Interface for administrative operations. + dispatcher (DispatcherClient): Interface for dispatching operations. + event (EventClient): Interface for event-related operations. + rest (RestApi): Interface for REST API operations. 
+ """ + + _client: Client + cron: CronClient + scheduled: ScheduledClient + + @classmethod + def from_environment( + cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any + ) -> "Hatchet": + return cls(client=new_client(defaults), **kwargs) + + @classmethod + def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": + return cls(client=new_client_raw(config), **kwargs) + + def __init__( + self, + debug: bool = False, + client: Optional[Client] = None, + config: ClientConfig = ClientConfig(), + ): + """ + Initialize a new Hatchet instance. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. + client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. + config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). + """ + if client is not None: + self._client = client + else: + self._client = new_client(config, debug) + + if debug: + logger.setLevel(logging.DEBUG) + + self.cron = CronClient(self._client) + self.scheduled = ScheduledClient(self._client) + + @property + def admin(self) -> AdminClient: + return self._client.admin + + @property + def dispatcher(self) -> DispatcherClient: + return self._client.dispatcher + + @property + def event(self) -> EventClient: + return self._client.event + + @property + def rest(self) -> RestApi: + return self._client.rest + + @property + def listener(self) -> RunEventListenerClient: + return self._client.listener + + @property + def config(self) -> ClientConfig: + return self._client.config + + @property + def tenant_id(self) -> str: + return self._client.config.tenant_id + + def step( + self, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: + def inner(func: Callable[[Any, Context], R]) -> Step[R]: + return Step( + fn=func, + type=StepType.DEFAULT, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) + + return inner + + def on_failure_step( + self, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: + def inner(func: Callable[[Any, Context], R]) -> Step[R]: + return Step( + fn=func, + type=StepType.ON_FAILURE, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) + + return inner + + def function( + self, + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", 
+ sticky: StickyStrategy | None = None, + default_priority: int = 1, + concurrency: ConcurrencyExpression | None = None, + input_validator: Type[TWorkflowInput] | None = None, + ) -> Callable[[Callable[[Context], R]], BaseWorkflowImpl]: + declaration = WorkflowDeclaration[TWorkflowInput]( + WorkflowConfig( + name=name, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator + or cast(Type[TWorkflowInput], EmptyModel), + ), + self, + ) + + def inner(func: Callable[[Context], R]) -> BaseWorkflowImpl: + class Workflow(BaseWorkflowImpl): + config = declaration.config + + @self.step( + name=name, + timeout=timeout, + retries=0, + rate_limits=[], + backoff_factor=None, + backoff_max_seconds=None, + ) + def fn(self, context: Context) -> R: + return func(context) + + return Workflow() + + return inner + + def worker( + self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + ) -> "Worker": + from hatchet_sdk.worker.worker import Worker + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + return Worker( + name=name, + max_runs=max_runs, + labels=labels, + config=self._client.config, + debug=self._client.debug, + owned_loop=loop is None, + ) + + def declare_workflow( + self, + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy | None = None, + default_priority: int = 1, + concurrency: ConcurrencyExpression | None = None, + input_validator: Type[TWorkflowInput] | None = None, + ) -> WorkflowDeclaration[TWorkflowInput]: + return WorkflowDeclaration[TWorkflowInput]( + WorkflowConfig( + name=name, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator + or cast(Type[TWorkflowInput], EmptyModel), + ), + self, + ) From 5fb9244517e85b6fcaf14b392527fb86cc842a49 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 31 Jan 2025 15:55:02 -0500 Subject: [PATCH 2/6] fix: loader --- hatchet_sdk/loader.py | 150 +++++++++++++++++++++++++++++++++++------- 1 file changed, 125 insertions(+), 25 deletions(-) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index a4252c22..ac58befb 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,4 +1,5 @@ import json +import os from logging import Logger, getLogger from pydantic import Field, field_validator, model_validator @@ -7,19 +8,36 @@ from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt -def create_settings_config(env_prefix: str) -> SettingsConfigDict: - return SettingsConfigDict( - env_prefix=env_prefix, - env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"), - extra="ignore", - ) +class ClientTLSConfig(BaseModel): + tls_strategy: str + cert_file: str | None + key_file: str | None + ca_file: str | None + server_name: str -class ClientTLSConfig(BaseSettings): - model_config = create_settings_config( - env_prefix="HATCHET_CLIENT_TLS_", +def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig: + server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") + + if not server_name and host_port: + server_name = host_port.split(":")[0] + + if not server_name: + server_name = "localhost" 
+ + return ClientTLSConfig( + tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), + cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), + key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), + ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), + server_name=server_name, ) + +def parse_listener_timeout(timeout: str | None) -> int | None: + if timeout is None: + return None + strategy: str = "tls" cert_file: str | None = None key_file: str | None = None @@ -39,16 +57,19 @@ class HealthcheckConfig(BaseSettings): DEFAULT_HOST_PORT = "localhost:7070" -class ClientConfig(BaseSettings): - model_config = create_settings_config( - env_prefix="HATCHET_CLIENT_", - ) +class ClientConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) - token: str = "" + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") logger: Logger = getLogger() + tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") + + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. So, we need to make sure that the + ## host_port is set before we try to load the tls_config and server_url + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT) + tls_config: ClientTLSConfig = _load_tls_config() - tenant_id: str = "" - host_port: str = DEFAULT_HOST_PORT server_url: str = "https://app.dev.hatchet-tools.com" namespace: str = "" @@ -59,19 +80,36 @@ class ClientConfig(BaseSettings): grpc_max_recv_message_length: int = Field( default=4 * 1024 * 1024, description="4MB default" ) - grpc_max_send_message_length: int = Field( - default=4 * 1024 * 1024, description="4MB default" + grpc_max_recv_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + grpc_max_send_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + otel_exporter_oltp_endpoint: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") + otel_exporter_oltp_headers: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" + ) + otel_exporter_oltp_protocol: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" + ) + worker_healthcheck_port: int = int( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) + ) + worker_healthcheck_enabled: bool = ( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" ) - worker_preset_labels: dict[str, str] = Field(default_factory=dict) - - @model_validator(mode="after") - def validate_token_and_tenant(self) -> "ClientConfig": - if not self.token: + @field_validator("token", mode="after") + @classmethod + def validate_token(cls, token: str) -> str: + if not token: raise ValueError("Token must be set") - if not self.tenant_id: - self.tenant_id = get_tenant_id_from_jwt(self.token) + return token return self @@ -108,9 +146,71 @@ def validate_listener_timeout(cls, value: int | None | str) -> int | None: def validate_namespace(cls, namespace: str) -> str: if not namespace: return "" + if not namespace.endswith("_"): namespace = f"{namespace}_" + return namespace.lower() + @field_validator("tenant_id", mode="after") + @classmethod + def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) + + if not tenant_id: + if not token: + raise ValueError("Either the token or tenant_id must be 
set") + + return get_tenant_id_from_jwt(token) + + return tenant_id + + @field_validator("host_port", mode="after") + @classmethod + def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + + token = cast(str, info.data.get("token")) + + if not token: + raise ValueError("Token must be set") + + _, grpc_broadcast_address = get_addresses_from_jwt(token) + + return grpc_broadcast_address + + @field_validator("server_url", mode="after") + @classmethod + def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. So, we need to make sure that the + ## host_port is set before we try to load the server_url + host_port = cast(str, info.data.get("host_port")) + + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + + token = cast(str, info.data.get("token")) + + if not token: + raise ValueError("Token must be set") + + _server_url, _ = get_addresses_from_jwt(token) + + return _server_url + + @field_validator("tls_config", mode="after") + @classmethod + def validate_tls_config( + cls, tls_config: ClientTLSConfig, info: ValidationInfo + ) -> ClientTLSConfig: + ## IMPORTANT: Order matters here. This validator runs in the order + ## that the fields are defined in the model. So, we need to make sure + ## that the host_port is set before we try to load the tls_config + host_port = cast(str, info.data.get("host_port")) + + return _load_tls_config(host_port) + def __hash__(self) -> int: return hash(json.dumps(self.model_dump(), default=str)) From 957ef3f3bf406e96b90c6c251c2bd06c09b8b7fa Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 31 Jan 2025 16:01:58 -0500 Subject: [PATCH 3/6] fix: move declaration inside of `inner` decorator --- examples/simple/worker.py | 3 +++ hatchet_sdk/v2/hatchet.py | 36 ++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/simple/worker.py b/examples/simple/worker.py index f886bce0..277e9784 100644 --- a/examples/simple/worker.py +++ b/examples/simple/worker.py @@ -1,5 +1,8 @@ from hatchet_sdk import BaseWorkflow, Context, Hatchet +from hatchet_sdk import Context +from hatchet_sdk.v2 import Hatchet + hatchet = Hatchet(debug=True) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index ba50c5c0..1538ad46 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -211,29 +211,29 @@ def function( concurrency: ConcurrencyExpression | None = None, input_validator: Type[TWorkflowInput] | None = None, ) -> Callable[[Callable[[Context], R]], BaseWorkflowImpl]: - declaration = WorkflowDeclaration[TWorkflowInput]( - WorkflowConfig( - name=name, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - default_priority=default_priority, - concurrency=concurrency, - input_validator=input_validator - or cast(Type[TWorkflowInput], EmptyModel), - ), - self, - ) - def inner(func: Callable[[Context], R]) -> BaseWorkflowImpl: + declaration = WorkflowDeclaration[TWorkflowInput]( + WorkflowConfig( + name=name or func.__name__, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator + or cast(Type[TWorkflowInput], 
EmptyModel), + ), + self, + ) + class Workflow(BaseWorkflowImpl): config = declaration.config @self.step( - name=name, + name=declaration.config.name, timeout=timeout, retries=0, rate_limits=[], From 661a40e58c049921af22a75a6367a32f06615f2b Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 31 Jan 2025 17:32:56 -0500 Subject: [PATCH 4/6] feat: registration for standalone functions --- examples/simple/worker.py | 11 ++-- hatchet_sdk/v2/hatchet.py | 114 ++++++++++++++++++++++++----------- hatchet_sdk/worker/worker.py | 30 ++++++++- 3 files changed, 115 insertions(+), 40 deletions(-) diff --git a/examples/simple/worker.py b/examples/simple/worker.py index 277e9784..7b05a021 100644 --- a/examples/simple/worker.py +++ b/examples/simple/worker.py @@ -6,12 +6,13 @@ hatchet = Hatchet(debug=True) -@hatchet.function(timeout="11s") +@hatchet.function() def step1(context: Context) -> dict[str, str]: - print("executed step1") - return { - "step1": "step1", - } + message = "Hello from Hatchet!" + + context.log(message) + + return {"message": message} def main() -> None: diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 1538ad46..601e9908 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -4,10 +4,12 @@ TYPE_CHECKING, Any, Callable, + Generic, Optional, ParamSpec, Type, TypeVar, + Union, cast, ) @@ -55,6 +57,56 @@ def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels ) +class Function(Generic[R, TWorkflowInput]): + def __init__( + self, + fn: Callable[[Context], R], + hatchet: "Hatchet", + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy | None = None, + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + concurrency: ConcurrencyExpression | None = None, + on_failure: Union["Function[R]", None] = None, + default_priority: int = 1, + input_validator: Type[TWorkflowInput] | None = None, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> None: + def func(_: Any, context: Context) -> R: + return fn(context) + + self.hatchet = hatchet + self.step: Step[R] = hatchet.step( + name=name or fn.__name__, + timeout=timeout, + retries=retries, + rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + )(func) + self.on_failure_step = on_failure + self.workflow_config = WorkflowConfig( + name=name or fn.__name__, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator or cast(Type[TWorkflowInput], EmptyModel), + ) + + class Hatchet: """ Main client for interacting with the Hatchet SDK. 
@@ -207,44 +259,38 @@ def function( timeout: str = "60m", schedule_timeout: str = "5m", sticky: StickyStrategy | None = None, - default_priority: int = 1, + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, concurrency: ConcurrencyExpression | None = None, + on_failure: Union["Function[Any]", None] = None, + default_priority: int = 1, input_validator: Type[TWorkflowInput] | None = None, - ) -> Callable[[Callable[[Context], R]], BaseWorkflowImpl]: - def inner(func: Callable[[Context], R]) -> BaseWorkflowImpl: - declaration = WorkflowDeclaration[TWorkflowInput]( - WorkflowConfig( - name=name or func.__name__, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - default_priority=default_priority, - concurrency=concurrency, - input_validator=input_validator - or cast(Type[TWorkflowInput], EmptyModel), - ), - self, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Context], R]], Function[R, TWorkflowInput]]: + def inner(func: Callable[[Context], R]) -> Function[R, TWorkflowInput]: + return Function[R, TWorkflowInput]( + func, + hatchet=self, + name=name, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + retries=retries, + rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, + concurrency=concurrency, + on_failure=on_failure, + default_priority=default_priority, + input_validator=input_validator, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, ) - class Workflow(BaseWorkflowImpl): - config = declaration.config - - @self.step( - name=declaration.config.name, - timeout=timeout, - retries=0, - rate_limits=[], - backoff_factor=None, - backoff_max_seconds=None, - ) - def fn(self, context: Context) -> R: - return func(context) - - return Workflow() - return inner def worker( diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 28fd1f27..091b3651 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -108,7 +108,33 @@ def register_workflow_from_opts( logger.error(e) sys.exit(1) - def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None: + def register_function(self, function: "Function[Any]") -> None: + from hatchet_sdk.workflow import BaseWorkflow + + declaration = function.hatchet.declare_workflow( + **function.workflow_config.model_dump() + ) + + class Workflow(BaseWorkflow): + config = declaration.config + + @property + def default_steps(self) -> list[Step[Any]]: + return [function.step] + + @property + def on_failure_steps(self) -> list[Step[Any]]: + if not function.on_failure_step: + return [] + + step = function.on_failure_step.step + step.type = StepType.ON_FAILURE + + return [step] + + self.register_workflow(Workflow()) + + def register_workflow(self, workflow: Union["BaseWorkflowImpl", Any]) -> None: namespace = self.client.config.namespace try: @@ -120,7 +146,9 @@ def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None: logger.error(e) sys.exit(1) + print(workflow.steps) for step in workflow.steps: + print(step) action_name = workflow.create_action_name(namespace, step) self.action_registry[action_name] = step return_type = get_type_hints(step.fn).get("return") From a4eef4e220c66e3a47efeeebfd13fc681426a1f8 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 31 Jan 2025 
17:33:27 -0500 Subject: [PATCH 5/6] fix: rm print cruft --- hatchet_sdk/worker/worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 091b3651..057bd33e 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -146,9 +146,7 @@ def register_workflow(self, workflow: Union["BaseWorkflowImpl", Any]) -> None: logger.error(e) sys.exit(1) - print(workflow.steps) for step in workflow.steps: - print(step) action_name = workflow.create_action_name(namespace, step) self.action_registry[action_name] = step return_type = get_type_hints(step.fn).get("return") From 0a4b7b30cdf9c489c3101fcf44772d5e0b29842f Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 11 Feb 2025 21:58:09 -0500 Subject: [PATCH 6/6] fix: cleanup rebase issues --- examples/simple/worker.py | 5 +- hatchet_sdk/hatchet.py | 95 +++++++++- hatchet_sdk/loader.py | 150 +++------------ hatchet_sdk/v2/hatchet.py | 343 ----------------------------------- hatchet_sdk/worker/worker.py | 6 +- 5 files changed, 123 insertions(+), 476 deletions(-) delete mode 100644 hatchet_sdk/v2/hatchet.py diff --git a/examples/simple/worker.py b/examples/simple/worker.py index 7b05a021..90a069ad 100644 --- a/examples/simple/worker.py +++ b/examples/simple/worker.py @@ -1,7 +1,4 @@ -from hatchet_sdk import BaseWorkflow, Context, Hatchet - -from hatchet_sdk import Context -from hatchet_sdk.v2 import Hatchet +from hatchet_sdk import Context, Hatchet hatchet = Hatchet(debug=True) diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index fca00583..72357523 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast from hatchet_sdk.client import Client, new_client, new_client_raw from hatchet_sdk.clients.admin import AdminClient @@ -44,6 +44,56 @@ def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels ) +class Function(Generic[R, TWorkflowInput]): + def __init__( + self, + fn: Callable[[Context], R], + hatchet: "Hatchet", + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy | None = None, + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + concurrency: ConcurrencyExpression | None = None, + on_failure: Union["Function[R]", None] = None, + default_priority: int = 1, + input_validator: Type[TWorkflowInput] | None = None, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> None: + def func(_: Any, context: Context) -> R: + return fn(context) + + self.hatchet = hatchet + self.step: Step[R] = hatchet.step( + name=name or fn.__name__, + timeout=timeout, + retries=retries, + rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + )(func) + self.on_failure_step = on_failure + self.workflow_config = WorkflowConfig( + name=name or fn.__name__, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator or cast(Type[TWorkflowInput], EmptyModel), + ) + + class 
Hatchet: """ Main client for interacting with the Hatchet SDK. @@ -187,6 +237,49 @@ def inner(func: Callable[[Any, Context], R]) -> Step[R]: return inner + def function( + self, + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy | None = None, + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + concurrency: ConcurrencyExpression | None = None, + on_failure: Union["Function[Any]", None] = None, + default_priority: int = 1, + input_validator: Type[TWorkflowInput] | None = None, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Context], R]], Function[R, TWorkflowInput]]: + def inner(func: Callable[[Context], R]) -> Function[R, TWorkflowInput]: + return Function[R, TWorkflowInput]( + func, + hatchet=self, + name=name, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + retries=retries, + rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, + concurrency=concurrency, + on_failure=on_failure, + default_priority=default_priority, + input_validator=input_validator, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) + + return inner + def worker( self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} ) -> "Worker": diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index ac58befb..a4252c22 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,5 +1,4 @@ import json -import os from logging import Logger, getLogger from pydantic import Field, field_validator, model_validator @@ -8,35 +7,18 @@ from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt -class ClientTLSConfig(BaseModel): - tls_strategy: str - cert_file: str | None - key_file: str | None - ca_file: str | None - server_name: str - - -def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig: - server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") - - if not server_name and host_port: - server_name = host_port.split(":")[0] - - if not server_name: - server_name = "localhost" - - return ClientTLSConfig( - tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), - cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), - key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), - ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), - server_name=server_name, +def create_settings_config(env_prefix: str) -> SettingsConfigDict: + return SettingsConfigDict( + env_prefix=env_prefix, + env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"), + extra="ignore", ) -def parse_listener_timeout(timeout: str | None) -> int | None: - if timeout is None: - return None +class ClientTLSConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_TLS_", + ) strategy: str = "tls" cert_file: str | None = None @@ -57,19 +39,16 @@ class HealthcheckConfig(BaseSettings): DEFAULT_HOST_PORT = "localhost:7070" -class ClientConfig(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) +class ClientConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_", + ) - token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") + token: str = "" logger: Logger = getLogger() - tenant_id: str = 
os.getenv("HATCHET_CLIENT_TENANT_ID", "") - - ## IMPORTANT: Order matters here. The validators run in the order that the - ## fields are defined in the model. So, we need to make sure that the - ## host_port is set before we try to load the tls_config and server_url - host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT) - tls_config: ClientTLSConfig = _load_tls_config() + tenant_id: str = "" + host_port: str = DEFAULT_HOST_PORT server_url: str = "https://app.dev.hatchet-tools.com" namespace: str = "" @@ -80,36 +59,19 @@ class ClientConfig(BaseModel): grpc_max_recv_message_length: int = Field( default=4 * 1024 * 1024, description="4MB default" ) - grpc_max_recv_message_length: int = int( - os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) - ) # 4MB - grpc_max_send_message_length: int = int( - os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) - ) # 4MB - otel_exporter_oltp_endpoint: str | None = os.getenv( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) - otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") - otel_exporter_oltp_headers: str | None = os.getenv( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) - otel_exporter_oltp_protocol: str | None = os.getenv( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - worker_healthcheck_port: int = int( - os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) - ) - worker_healthcheck_enabled: bool = ( - os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" + grpc_max_send_message_length: int = Field( + default=4 * 1024 * 1024, description="4MB default" ) - @field_validator("token", mode="after") - @classmethod - def validate_token(cls, token: str) -> str: - if not token: + worker_preset_labels: dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_token_and_tenant(self) -> "ClientConfig": + if not self.token: raise ValueError("Token must be set") - return token + if not self.tenant_id: + self.tenant_id = get_tenant_id_from_jwt(self.token) return self @@ -146,71 +108,9 @@ def validate_listener_timeout(cls, value: int | None | str) -> int | None: def validate_namespace(cls, namespace: str) -> str: if not namespace: return "" - if not namespace.endswith("_"): namespace = f"{namespace}_" - return namespace.lower() - @field_validator("tenant_id", mode="after") - @classmethod - def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: - token = cast(str | None, info.data.get("token")) - - if not tenant_id: - if not token: - raise ValueError("Either the token or tenant_id must be set") - - return get_tenant_id_from_jwt(token) - - return tenant_id - - @field_validator("host_port", mode="after") - @classmethod - def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: - if host_port and host_port != DEFAULT_HOST_PORT: - return host_port - - token = cast(str, info.data.get("token")) - - if not token: - raise ValueError("Token must be set") - - _, grpc_broadcast_address = get_addresses_from_jwt(token) - - return grpc_broadcast_address - - @field_validator("server_url", mode="after") - @classmethod - def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: - ## IMPORTANT: Order matters here. The validators run in the order that the - ## fields are defined in the model. 
So, we need to make sure that the - ## host_port is set before we try to load the server_url - host_port = cast(str, info.data.get("host_port")) - - if host_port and host_port != DEFAULT_HOST_PORT: - return host_port - - token = cast(str, info.data.get("token")) - - if not token: - raise ValueError("Token must be set") - - _server_url, _ = get_addresses_from_jwt(token) - - return _server_url - - @field_validator("tls_config", mode="after") - @classmethod - def validate_tls_config( - cls, tls_config: ClientTLSConfig, info: ValidationInfo - ) -> ClientTLSConfig: - ## IMPORTANT: Order matters here. This validator runs in the order - ## that the fields are defined in the model. So, we need to make sure - ## that the host_port is set before we try to load the tls_config - host_port = cast(str, info.data.get("host_port")) - - return _load_tls_config(host_port) - def __hash__(self) -> int: return hash(json.dumps(self.model_dump(), default=str)) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py deleted file mode 100644 index 601e9908..00000000 --- a/hatchet_sdk/v2/hatchet.py +++ /dev/null @@ -1,343 +0,0 @@ -import asyncio -import logging -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Optional, - ParamSpec, - Type, - TypeVar, - Union, - cast, -) - -from hatchet_sdk.client import Client, new_client, new_client_raw -from hatchet_sdk.clients.admin import AdminClient -from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient -from hatchet_sdk.clients.events import EventClient -from hatchet_sdk.clients.rest_client import RestApi -from hatchet_sdk.clients.run_event_listener import RunEventListenerClient -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import DesiredWorkerLabels -from hatchet_sdk.features.cron import CronClient -from hatchet_sdk.features.scheduled import ScheduledClient -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig -from hatchet_sdk.logger import logger -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.workflows import ( - BaseWorkflowImpl, - ConcurrencyExpression, - EmptyModel, - Step, - StepType, - StickyStrategy, - TWorkflowInput, - WorkflowConfig, - WorkflowDeclaration, -) - -if TYPE_CHECKING: - from hatchet_sdk.worker.worker import Worker - -P = ParamSpec("P") -R = TypeVar("R") - - -def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: - value = d.value - return DesiredWorkerLabels( - strValue=value if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d.required, - weight=d.weight, - comparator=d.comparator, # type: ignore[arg-type] - ) - - -class Function(Generic[R, TWorkflowInput]): - def __init__( - self, - fn: Callable[[Context], R], - hatchet: "Hatchet", - name: str = "", - on_events: list[str] = [], - on_crons: list[str] = [], - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy | None = None, - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyExpression | None = None, - on_failure: Union["Function[R]", None] = None, - default_priority: int = 1, - input_validator: Type[TWorkflowInput] | None = None, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, - ) -> None: - def func(_: Any, context: Context) -> R: - return fn(context) - - self.hatchet = hatchet - self.step: 
Step[R] = hatchet.step( - name=name or fn.__name__, - timeout=timeout, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - )(func) - self.on_failure_step = on_failure - self.workflow_config = WorkflowConfig( - name=name or fn.__name__, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - default_priority=default_priority, - concurrency=concurrency, - input_validator=input_validator or cast(Type[TWorkflowInput], EmptyModel), - ) - - -class Hatchet: - """ - Main client for interacting with the Hatchet SDK. - - This class provides access to various client interfaces and utility methods - for working with Hatchet workers, workflows, and steps. - - Attributes: - cron (CronClient): Interface for cron trigger operations. - - admin (AdminClient): Interface for administrative operations. - dispatcher (DispatcherClient): Interface for dispatching operations. - event (EventClient): Interface for event-related operations. - rest (RestApi): Interface for REST API operations. - """ - - _client: Client - cron: CronClient - scheduled: ScheduledClient - - @classmethod - def from_environment( - cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any - ) -> "Hatchet": - return cls(client=new_client(defaults), **kwargs) - - @classmethod - def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": - return cls(client=new_client_raw(config), **kwargs) - - def __init__( - self, - debug: bool = False, - client: Optional[Client] = None, - config: ClientConfig = ClientConfig(), - ): - """ - Initialize a new Hatchet instance. - - Args: - debug (bool, optional): Enable debug logging. Defaults to False. - client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. - config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). 
- """ - if client is not None: - self._client = client - else: - self._client = new_client(config, debug) - - if debug: - logger.setLevel(logging.DEBUG) - - self.cron = CronClient(self._client) - self.scheduled = ScheduledClient(self._client) - - @property - def admin(self) -> AdminClient: - return self._client.admin - - @property - def dispatcher(self) -> DispatcherClient: - return self._client.dispatcher - - @property - def event(self) -> EventClient: - return self._client.event - - @property - def rest(self) -> RestApi: - return self._client.rest - - @property - def listener(self) -> RunEventListenerClient: - return self._client.listener - - @property - def config(self) -> ClientConfig: - return self._client.config - - @property - def tenant_id(self) -> str: - return self._client.config.tenant_id - - def step( - self, - name: str = "", - timeout: str = "60m", - parents: list[str] = [], - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, - ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: - def inner(func: Callable[[Any, Context], R]) -> Step[R]: - return Step( - fn=func, - type=StepType.DEFAULT, - name=name.lower() or str(func.__name__).lower(), - timeout=timeout, - parents=parents, - retries=retries, - rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], - desired_worker_labels={ - key: transform_desired_worker_label(d) - for key, d in desired_worker_labels.items() - }, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - ) - - return inner - - def on_failure_step( - self, - name: str = "", - timeout: str = "60m", - parents: list[str] = [], - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, - ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: - def inner(func: Callable[[Any, Context], R]) -> Step[R]: - return Step( - fn=func, - type=StepType.ON_FAILURE, - name=name.lower() or str(func.__name__).lower(), - timeout=timeout, - parents=parents, - retries=retries, - rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], - desired_worker_labels={ - key: transform_desired_worker_label(d) - for key, d in desired_worker_labels.items() - }, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - ) - - return inner - - def function( - self, - name: str = "", - on_events: list[str] = [], - on_crons: list[str] = [], - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy | None = None, - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyExpression | None = None, - on_failure: Union["Function[Any]", None] = None, - default_priority: int = 1, - input_validator: Type[TWorkflowInput] | None = None, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, - ) -> Callable[[Callable[[Context], R]], Function[R, TWorkflowInput]]: - def inner(func: Callable[[Context], R]) -> Function[R, TWorkflowInput]: - return Function[R, TWorkflowInput]( - func, - hatchet=self, - name=name, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - 
desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - input_validator=input_validator, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - ) - - return inner - - def worker( - self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ) -> "Worker": - from hatchet_sdk.worker.worker import Worker - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - return Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - owned_loop=loop is None, - ) - - def declare_workflow( - self, - name: str = "", - on_events: list[str] = [], - on_crons: list[str] = [], - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy | None = None, - default_priority: int = 1, - concurrency: ConcurrencyExpression | None = None, - input_validator: Type[TWorkflowInput] | None = None, - ) -> WorkflowDeclaration[TWorkflowInput]: - return WorkflowDeclaration[TWorkflowInput]( - WorkflowConfig( - name=name, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - default_priority=default_priority, - concurrency=concurrency, - input_validator=input_validator - or cast(Type[TWorkflowInput], EmptyModel), - ), - self, - ) diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 057bd33e..3be3fb05 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -31,10 +31,10 @@ STOP_LOOP_TYPE, WorkerActionRunLoopManager, ) -from hatchet_sdk.workflow import Step +from hatchet_sdk.workflow import BaseWorkflow, Step, StepType if TYPE_CHECKING: - from hatchet_sdk.workflow import BaseWorkflow + from hatchet_sdk.hatchet import Function T = TypeVar("T") @@ -134,7 +134,7 @@ def on_failure_steps(self) -> list[Step[Any]]: self.register_workflow(Workflow()) - def register_workflow(self, workflow: Union["BaseWorkflowImpl", Any]) -> None: + def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None: namespace = self.client.config.namespace try:
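
Taken together, this series replaces the class-based workflow in examples/simple/worker.py with a standalone function registered directly on the worker. A minimal sketch of the example's final state, reconstructed from the hunks above (the trailing `if __name__ == "__main__": main()` guard never appears in the diffs and is assumed to be unchanged):

    from hatchet_sdk import Context, Hatchet

    hatchet = Hatchet(debug=True)


    @hatchet.function()
    def step1(context: Context) -> dict[str, str]:
        # Log a message through the step's context and return it as the step output.
        message = "Hello from Hatchet!"

        context.log(message)

        return {"message": message}


    def main() -> None:
        # register_function wraps the Function produced by @hatchet.function()
        # in a single-step workflow before registering it on the worker.
        worker = hatchet.worker("test-worker", max_runs=1)
        worker.register_function(step1)
        worker.start()


    if __name__ == "__main__":
        main()

Under the hood, `@hatchet.function()` returns a `Function` object rather than a workflow class, and `Worker.register_function` adapts it into a one-step `BaseWorkflow` (plus an optional on-failure step) so the existing `register_workflow` path is reused unchanged.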