From dea9910fb18a3e667eabdf88f1e9483bc7f64033 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:20:57 -0500 Subject: [PATCH 1/2] feat: overhaul client config --- hatchet_sdk/experimental/client.py | 7 +- hatchet_sdk/experimental/hatchet.py | 5 +- hatchet_sdk/experimental/loader.py | 322 ++++++---------------- hatchet_sdk/experimental/utils/tracing.py | 14 +- pyproject.toml | 13 + 5 files changed, 114 insertions(+), 247 deletions(-) diff --git a/hatchet_sdk/experimental/client.py b/hatchet_sdk/experimental/client.py index b513022b..6b1bb511 100644 --- a/hatchet_sdk/experimental/client.py +++ b/hatchet_sdk/experimental/client.py @@ -12,7 +12,7 @@ from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher from .clients.events import EventClient, new_event from .clients.rest_client import RestApi -from .loader import ClientConfig, ConfigLoader +from .loader import ClientConfig class Client: @@ -37,11 +37,10 @@ def from_environment( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - config: ClientConfig = ConfigLoader(".").load_client_config(defaults) for opt_function in opts_functions: - opt_function(config) + opt_function(defaults) - return cls.from_config(config, debug) + return cls.from_config(defaults, debug) @classmethod def from_config( diff --git a/hatchet_sdk/experimental/hatchet.py b/hatchet_sdk/experimental/hatchet.py index 8e121dab..712207f0 100644 --- a/hatchet_sdk/experimental/hatchet.py +++ b/hatchet_sdk/experimental/hatchet.py @@ -16,7 +16,7 @@ from hatchet_sdk.experimental.features.cron import CronClient from hatchet_sdk.experimental.features.scheduled import ScheduledClient from hatchet_sdk.experimental.labels import DesiredWorkerLabel -from hatchet_sdk.experimental.loader import ClientConfig, ConfigLoader +from hatchet_sdk.experimental.loader import ClientConfig from hatchet_sdk.experimental.rate_limit import RateLimit from hatchet_sdk.experimental.v2.callable import HatchetCallable @@ -190,8 +190,7 @@ class HatchetRest: rest: RestApi def __init__(self, config: ClientConfig = ClientConfig()): - _config: ClientConfig = ConfigLoader(".").load_client_config(config) - self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) + self.rest = RestApi(config.server_url, config.token, config.tenant_id) class Hatchet: diff --git a/hatchet_sdk/experimental/loader.py b/hatchet_sdk/experimental/loader.py index d754c2ae..9833b43a 100644 --- a/hatchet_sdk/experimental/loader.py +++ b/hatchet_sdk/experimental/loader.py @@ -1,246 +1,90 @@ -import json import os from logging import Logger, getLogger -from typing import Dict, Optional - -import yaml - -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt - - -class ClientTLSConfig: - def __init__( - self, - tls_strategy: str, - cert_file: str, - key_file: str, - ca_file: str, - server_name: str, - ): - self.tls_strategy = tls_strategy - self.cert_file = cert_file - self.key_file = key_file - self.ca_file = ca_file - self.server_name = server_name - - -class ClientConfig: - logInterceptor: Logger - - def __init__( - self, - tenant_id: str = None, - tls_config: ClientTLSConfig = None, - token: str = None, - host_port: str = "localhost:7070", - server_url: str = "https://app.dev.hatchet-tools.com", - namespace: str = None, - listener_v2_timeout: int = None, - logger: Logger = None, - grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB - grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB - otel_exporter_oltp_endpoint: str | None = None, - otel_service_name: str | None = None, - otel_exporter_oltp_headers: dict[str, str] | None = None, - otel_exporter_oltp_protocol: str | None = None, - worker_healthcheck_port: int | None = None, - worker_healthcheck_enabled: bool | None = None, - ): - self.tenant_id = tenant_id - self.tls_config = tls_config - self.host_port = host_port - self.token = token - self.server_url = server_url - self.namespace = "" - self.logInterceptor = logger - self.grpc_max_recv_message_length = grpc_max_recv_message_length - self.grpc_max_send_message_length = grpc_max_send_message_length - self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint - self.otel_service_name = otel_service_name - self.otel_exporter_oltp_headers = otel_exporter_oltp_headers - self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol - self.worker_healthcheck_port = worker_healthcheck_port - self.worker_healthcheck_enabled = worker_healthcheck_enabled - - if not self.logInterceptor: - self.logInterceptor = getLogger() - - # case on whether the namespace already has a trailing underscore - if namespace and not namespace.endswith("_"): - self.namespace = f"{namespace}_" - elif namespace: - self.namespace = namespace - - self.namespace = self.namespace.lower() - - self.listener_v2_timeout = listener_v2_timeout - - -class ConfigLoader: - def __init__(self, directory: str): - self.directory = directory - - def load_client_config(self, defaults: ClientConfig) -> ClientConfig: - config_file_path = os.path.join(self.directory, "client.yaml") - config_data: object = {"tls": {}} - - # determine if client.yaml exists - if os.path.exists(config_file_path): - with open(config_file_path, "r") as file: - config_data = yaml.safe_load(file) - - def get_config_value(key, env_var): - if key in config_data: - return config_data[key] - - if self._get_env_var(env_var) is not None: - return self._get_env_var(env_var) - - return getattr(defaults, key, None) - - namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") - - tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") - token = get_config_value("token", "HATCHET_CLIENT_TOKEN") - listener_v2_timeout = get_config_value( - "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" - ) - listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None - - if not token: - raise ValueError( - "Token must be set via HATCHET_CLIENT_TOKEN environment variable" - ) - - host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") - server_url: str | None = None - - grpc_max_recv_message_length = get_config_value( - "grpc_max_recv_message_length", - "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", - ) - grpc_max_send_message_length = get_config_value( - "grpc_max_send_message_length", - "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", - ) - - if grpc_max_recv_message_length: - grpc_max_recv_message_length = int(grpc_max_recv_message_length) - - if grpc_max_send_message_length: - grpc_max_send_message_length = int(grpc_max_send_message_length) - - if not host_port: - # extract host and port from token - server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - host_port = grpc_broadcast_address +from typing import cast + +from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator + +from .token import get_tenant_id_from_jwt + + +class ClientTLSConfig(BaseModel): + tls_strategy: str + cert_file: str | None + key_file: str | None + ca_file: str | None + server_name: str + + +def _load_tls_config(host_port: str) -> ClientTLSConfig: + tls_strategy = os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls") + cert_file = os.getenv("HATCHET_CLIENT_TLS_CERT_FILE") + key_file = os.getenv("HATCHET_CLIENT_TLS_KEY_FILE") + ca_file = os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE") + server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0]) + + return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) + + +class ClientConfig(BaseModel): + token: str = os.getenv("HATCHET_CLIENT_TOKEN") + logger: Logger = getLogger() + tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") + tls_config: ClientTLSConfig = _load_tls_config(host_port) + server_url: str = "https://app.dev.hatchet-tools.com" + namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") + listener_v2_timeout: int | None = ( + int(x) if (x := os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT")) else None + ) + grpc_max_recv_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + grpc_max_send_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + otel_exporter_oltp_endpoint: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") + otel_exporter_oltp_headers: str | None = ( + os.getenv("HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"), + ) + otel_exporter_oltp_protocol: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" + ) + worker_healthcheck_port: int = int( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) + ) + worker_healthcheck_enabled: bool = ( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" + ) + + @field_validator(mode="after") + @classmethod + def validate_namespace(cls, namespace: str | None) -> str: + if not namespace.endswith("_"): + namespace = f"{namespace}_" + + return namespace.lower() + + @field_validator(mode="after") + @classmethod + def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) if not tenant_id: - tenant_id = get_tenant_id_from_jwt(token) - - tls_config = self._load_tls_config(config_data["tls"], host_port) - - otel_exporter_oltp_endpoint = get_config_value( - "otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) + if not token: + raise ValidationError( + "Token must be set before attempting to infer tenant ID" + ) - otel_service_name = get_config_value( - "otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME" - ) + return get_tenant_id_from_jwt(token) - _oltp_headers = get_config_value( - "otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) + return tenant_id - if _oltp_headers: - try: - otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1) - otel_exporter_oltp_headers = {otel_header_key: api_key} - except ValueError: - raise ValueError( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`" - ) - else: - otel_exporter_oltp_headers = None - - otel_exporter_oltp_protocol = get_config_value( - "otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - - worker_healthcheck_port = int( - get_config_value( - "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" - ) - or 8001 - ) - - worker_healthcheck_enabled = ( - str( - get_config_value( - "worker_healthcheck_port", - "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", - ) - ) - == "True" - ) - - return ClientConfig( - tenant_id=tenant_id, - tls_config=tls_config, - token=token, - host_port=host_port, - server_url=server_url, - namespace=namespace, - listener_v2_timeout=listener_v2_timeout, - logger=defaults.logInterceptor, - grpc_max_recv_message_length=grpc_max_recv_message_length, - grpc_max_send_message_length=grpc_max_send_message_length, - otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint, - otel_service_name=otel_service_name, - otel_exporter_oltp_headers=otel_exporter_oltp_headers, - otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, - worker_healthcheck_port=worker_healthcheck_port, - worker_healthcheck_enabled=worker_healthcheck_enabled, - ) - - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: - tls_strategy = ( - tls_data["tlsStrategy"] - if "tlsStrategy" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") - ) - - if not tls_strategy: - tls_strategy = "tls" - - cert_file = ( - tls_data["tlsCertFile"] - if "tlsCertFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") - ) - key_file = ( - tls_data["tlsKeyFile"] - if "tlsKeyFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") - ) - ca_file = ( - tls_data["tlsRootCAFile"] - if "tlsRootCAFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - ) - - server_name = ( - tls_data["tlsServerName"] - if "tlsServerName" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") - ) - - # if server_name is not set, use the host from the host_port - if not server_name: - server_name = host_port.split(":")[0] - - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) - - @staticmethod - def _get_env_var(env_var: str, default: Optional[str] = None) -> str: - return os.environ.get(env_var, default) + ## TODO: Fix host port overrides here + ## Old code: + ## if not host_port: + ## ## extract host and port from token + ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) + ## host_port = grpc_broadcast_address diff --git a/hatchet_sdk/experimental/utils/tracing.py b/hatchet_sdk/experimental/utils/tracing.py index 59dc0774..52892979 100644 --- a/hatchet_sdk/experimental/utils/tracing.py +++ b/hatchet_sdk/experimental/utils/tracing.py @@ -16,6 +16,18 @@ OTEL_CARRIER_KEY = "__otel_carrier" +def parse_headers(headers: str | None) -> dict[str, str]: + if headers is None: + return {} + + try: + otel_header_key, api_key = headers.split("=", maxsplit=1) + + return {otel_header_key: api_key} + except ValueError: + raise ValueError("OTLP headers must be in the format `key=value`") + + @cache def create_tracer(config: ClientConfig) -> Tracer: ## TODO: Figure out how to specify protocol here @@ -27,7 +39,7 @@ def create_tracer(config: ClientConfig) -> Tracer: processor = BatchSpanProcessor( OTLPSpanExporter( endpoint=config.otel_exporter_oltp_endpoint, - headers=config.otel_exporter_oltp_headers, + headers=parse_headers(config.otel_exporter_oltp_headers), ), ) diff --git a/pyproject.toml b/pyproject.toml index 69380b67..e4c16240 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,19 @@ files = [ "hatchet_sdk/clients/rest/models/workflow_run.py", "hatchet_sdk/context/worker_context.py", "hatchet_sdk/clients/dispatcher/dispatcher.py", + "hatchet_sdk/experimental/hatchet.py", + "hatchet_sdk/experimental/worker/worker.py", + "hatchet_sdk/experimental/context/context.py", + "hatchet_sdk/experimental/worker/runner/runner.py", + "hatchet_sdk/experimental/workflow.py", + "hatchet_sdk/experimental/utils/serialization.py", + "hatchet_sdk/experimental/utils/tracing.py", + "hatchet_sdk/experimental/utils/types.py", + "hatchet_sdk/experimental/utils/backoff.py", + "hatchet_sdk/experimental/clients/rest/models/workflow_list.py", + "hatchet_sdk/experimental/clients/rest/models/workflow_run.py", + "hatchet_sdk/experimental/context/worker_context.py", + "hatchet_sdk/experimental/clients/dispatcher/dispatcher.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"] From 9614f7a8da9a3504c886c7fe69187d3a640a06b6 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:28:28 -0500 Subject: [PATCH 2/2] fix: couple type issues --- hatchet_sdk/experimental/loader.py | 45 ++++++++++++++++++++---------- pyproject.toml | 1 + 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/hatchet_sdk/experimental/loader.py b/hatchet_sdk/experimental/loader.py index 9833b43a..265aa4f0 100644 --- a/hatchet_sdk/experimental/loader.py +++ b/hatchet_sdk/experimental/loader.py @@ -16,25 +16,34 @@ class ClientTLSConfig(BaseModel): def _load_tls_config(host_port: str) -> ClientTLSConfig: - tls_strategy = os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls") - cert_file = os.getenv("HATCHET_CLIENT_TLS_CERT_FILE") - key_file = os.getenv("HATCHET_CLIENT_TLS_KEY_FILE") - ca_file = os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0]) + return ClientTLSConfig( + tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), + cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), + key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), + ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), + server_name=os.getenv( + "HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0] + ), + ) + + +def parse_listener_timeout(timeout: str | None) -> int | None: + if timeout is None: + return None - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) + return int(timeout) class ClientConfig(BaseModel): - token: str = os.getenv("HATCHET_CLIENT_TOKEN") + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") logger: Logger = getLogger() tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") tls_config: ClientTLSConfig = _load_tls_config(host_port) server_url: str = "https://app.dev.hatchet-tools.com" namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") - listener_v2_timeout: int | None = ( - int(x) if (x := os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT")) else None + listener_v2_timeout: int | None = parse_listener_timeout( + os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT") ) grpc_max_recv_message_length: int = int( os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) @@ -46,8 +55,8 @@ class ClientConfig(BaseModel): "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" ) otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") - otel_exporter_oltp_headers: str | None = ( - os.getenv("HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"), + otel_exporter_oltp_headers: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" ) otel_exporter_oltp_protocol: str | None = os.getenv( "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" @@ -59,15 +68,23 @@ class ClientConfig(BaseModel): os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" ) - @field_validator(mode="after") + @field_validator("token", mode="after") + @classmethod + def validate_token(cls, token: str) -> str: + if not token: + raise ValidationError("Token must be set") + + return token + + @field_validator("namespace", mode="after") @classmethod - def validate_namespace(cls, namespace: str | None) -> str: + def validate_namespace(cls, namespace: str) -> str: if not namespace.endswith("_"): namespace = f"{namespace}_" return namespace.lower() - @field_validator(mode="after") + @field_validator("tenant_id", mode="after") @classmethod def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: token = cast(str | None, info.data.get("token")) diff --git a/pyproject.toml b/pyproject.toml index e4c16240..39261568 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ files = [ "hatchet_sdk/experimental/clients/rest/models/workflow_run.py", "hatchet_sdk/experimental/context/worker_context.py", "hatchet_sdk/experimental/clients/dispatcher/dispatcher.py", + "hatchet_sdk/experimental/loader.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"]