diff --git a/python-sdk/exospherehost/__init__.py b/python-sdk/exospherehost/__init__.py index a16745df..4f905001 100644 --- a/python-sdk/exospherehost/__init__.py +++ b/python-sdk/exospherehost/__init__.py @@ -37,9 +37,10 @@ async def execute(self, inputs: Inputs) -> Outputs: from ._version import version as __version__ from .runtime import Runtime from .node.BaseNode import BaseNode -from .statemanager import StateManager, TriggerState +from .statemanager import StateManager from .signals import PruneSignal, ReQueueAfterSignal +from .models import UnitesStrategyEnum, UnitesModel, GraphNodeModel, RetryStrategyEnum, RetryPolicyModel, StoreConfigModel VERSION = __version__ -__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] +__all__ = ["Runtime", "BaseNode", "StateManager", "VERSION", "PruneSignal", "ReQueueAfterSignal", "UnitesStrategyEnum", "UnitesModel", "GraphNodeModel", "RetryStrategyEnum", "RetryPolicyModel", "StoreConfigModel"] diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index cc4ed116..8c58f5cc 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.2b4" +version = "0.0.2b5" diff --git a/python-sdk/exospherehost/models.py b/python-sdk/exospherehost/models.py new file mode 100644 index 00000000..6ab72e4c --- /dev/null +++ b/python-sdk/exospherehost/models.py @@ -0,0 +1,160 @@ +from pydantic import BaseModel, Field, field_validator +from typing import Any, Optional, List +from enum import Enum + + +class UnitesStrategyEnum(str, Enum): + ALL_SUCCESS = "ALL_SUCCESS" + ALL_DONE = "ALL_DONE" + + +class UnitesModel(BaseModel): + identifier: str = Field(..., description="Identifier of the node") + strategy: UnitesStrategyEnum = Field(default=UnitesStrategyEnum.ALL_SUCCESS, description="Strategy of the unites") + + +class GraphNodeModel(BaseModel): + node_name: str = Field(..., description="Name of the node") + namespace: str = Field(..., description="Namespace of the node") + identifier: str = Field(..., description="Identifier of the node") + inputs: dict[str, Any] = Field(..., description="Inputs of the node") + next_nodes: Optional[List[str]] = Field(None, description="Next nodes to execute") + unites: Optional[UnitesModel] = Field(None, description="Unites of the node") + + @field_validator('node_name') + @classmethod + def validate_node_name(cls, v: str) -> str: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: + raise ValueError("Node name cannot be empty") + return trimmed_v + + @field_validator('identifier') + @classmethod + def validate_identifier(cls, v: str) -> str: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: + raise ValueError("Node identifier cannot be empty") + elif trimmed_v == "store": + raise ValueError("Node identifier cannot be reserved word 'store'") + return trimmed_v + + @field_validator('next_nodes') + @classmethod + def validate_next_nodes(cls, v: Optional[List[str]]) -> Optional[List[str]]: + identifiers = set() + errors = [] + trimmed_v = [] + + if v is not None: + for next_node_identifier in v: + trimmed_next_node_identifier = next_node_identifier.strip() + + if trimmed_next_node_identifier == "" or trimmed_next_node_identifier is None: + errors.append("Next node identifier cannot be empty") + continue + + if trimmed_next_node_identifier in identifiers: + errors.append(f"Next node identifier {trimmed_next_node_identifier} is not unique") + continue + + identifiers.add(trimmed_next_node_identifier) + trimmed_v.append(trimmed_next_node_identifier) + if errors: + raise ValueError("\n".join(errors)) + return trimmed_v + + @field_validator('unites') + @classmethod + def validate_unites(cls, v: Optional[UnitesModel]) -> Optional[UnitesModel]: + trimmed_v = v + if v is not None: + trimmed_v = UnitesModel(identifier=v.identifier.strip(), strategy=v.strategy) + if trimmed_v.identifier == "" or trimmed_v.identifier is None: + raise ValueError("Unites identifier cannot be empty") + return trimmed_v + + +class RetryStrategyEnum(str, Enum): + EXPONENTIAL = "EXPONENTIAL" + EXPONENTIAL_FULL_JITTER = "EXPONENTIAL_FULL_JITTER" + EXPONENTIAL_EQUAL_JITTER = "EXPONENTIAL_EQUAL_JITTER" + + LINEAR = "LINEAR" + LINEAR_FULL_JITTER = "LINEAR_FULL_JITTER" + LINEAR_EQUAL_JITTER = "LINEAR_EQUAL_JITTER" + + FIXED = "FIXED" + FIXED_FULL_JITTER = "FIXED_FULL_JITTER" + FIXED_EQUAL_JITTER = "FIXED_EQUAL_JITTER" + + +class RetryPolicyModel(BaseModel): + max_retries: int = Field(default=3, description="The maximum number of retries", ge=0) + strategy: RetryStrategyEnum = Field(default=RetryStrategyEnum.EXPONENTIAL, description="The method of retry") + backoff_factor: int = Field(default=2000, description="The backoff factor in milliseconds (default: 2000 = 2 seconds)", gt=0) + exponent: int = Field(default=2, description="The exponent for the exponential retry strategy", gt=0) + max_delay: int | None = Field(default=None, description="The maximum delay in milliseconds (no default limit when None)", gt=0) + + +class StoreConfigModel(BaseModel): + required_keys: list[str] = Field(default_factory=list, description="Required keys of the store") + default_values: dict[str, str] = Field(default_factory=dict, description="Default values of the store") + + @field_validator("required_keys") + @classmethod + def validate_required_keys(cls, v: list[str]) -> list[str]: + errors = [] + keys = set() + trimmed_keys = [] + + for key in v: + trimmed_key = key.strip() if key is not None else "" + + if trimmed_key == "": + errors.append("Key cannot be empty or contain only whitespace") + continue + + if '.' in trimmed_key: + errors.append(f"Key '{trimmed_key}' cannot contain '.' character") + continue + + if trimmed_key in keys: + errors.append(f"Key '{trimmed_key}' is duplicated") + continue + + keys.add(trimmed_key) + trimmed_keys.append(trimmed_key) + + if len(errors) > 0: + raise ValueError("\n".join(errors)) + return trimmed_keys + + @field_validator("default_values") + @classmethod + def validate_default_values(cls, v: dict[str, str]) -> dict[str, str]: + errors = [] + keys = set() + normalized_dict = {} + + for key, value in v.items(): + trimmed_key = key.strip() if key is not None else "" + + if trimmed_key == "": + errors.append("Key cannot be empty or contain only whitespace") + continue + + if '.' in trimmed_key: + errors.append(f"Key '{trimmed_key}' cannot contain '.' character") + continue + + if trimmed_key in keys: + errors.append(f"Key '{trimmed_key}' is duplicated") + continue + + keys.add(trimmed_key) + normalized_dict[trimmed_key] = str(value) + + if len(errors) > 0: + raise ValueError("\n".join(errors)) + return normalized_dict \ No newline at end of file diff --git a/python-sdk/exospherehost/statemanager.py b/python-sdk/exospherehost/statemanager.py index 7940831c..1b029fa1 100644 --- a/python-sdk/exospherehost/statemanager.py +++ b/python-sdk/exospherehost/statemanager.py @@ -3,40 +3,7 @@ import asyncio import time -from typing import Any -from pydantic import BaseModel - - -class TriggerState(BaseModel): - """ - Represents a trigger state for graph execution. - - A trigger state contains an identifier and a set of input parameters that - will be passed to the graph when it is triggered for execution. - - Attributes: - identifier (str): A unique identifier for this trigger state. This is used - to distinguish between different trigger states and may be used by the - graph to determine how to process the trigger. - inputs (dict[str, str]): A dictionary of input parameters that will be - passed to the graph. The keys are parameter names and values are - parameter values, both as strings. - - Example: - ```python - # Create a trigger state with identifier and inputs - trigger_state = TriggerState( - identifier="user-login", - inputs={ - "user_id": "12345", - "session_token": "abc123def456", - "timestamp": "2024-01-15T10:30:00Z" - } - ) - ``` - """ - identifier: str - inputs: dict[str, str] +from .models import GraphNodeModel, RetryPolicyModel, StoreConfigModel class StateManager: @@ -67,7 +34,7 @@ def _get_upsert_graph_endpoint(self, graph_name: str): def _get_get_graph_endpoint(self, graph_name: str): return f"{self._state_manager_uri}/{self._state_manager_version}/namespace/{self._namespace}/graph/{graph_name}" - async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, store: dict[str, str] | None = None): + async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, store: dict[str, str] | None = None, start_delay: int = 0): """ Trigger execution of a graph. @@ -82,7 +49,8 @@ async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, s graph. Strings only. store (dict[str, str] | None): Optional key-value store that will be merged into the graph-level store before execution (beta). - + start_delay (int): Optional delay in milliseconds before the graph starts execution. + Returns: dict: JSON payload returned by the state-manager API. @@ -108,6 +76,7 @@ async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, s store = {} body = { + "start_delay": start_delay, "inputs": inputs, "store": store } @@ -156,7 +125,7 @@ async def get_graph(self, graph_name: str): raise Exception(f"Failed to get graph: {response.status} {await response.text()}") return await response.json() - async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], secrets: dict[str, str], retry_policy: dict[str, Any] | None = None, store_config: dict[str, Any] | None = None, validation_timeout: int = 60, polling_interval: int = 1): + async def upsert_graph(self, graph_name: str, graph_nodes: list[GraphNodeModel], secrets: dict[str, str], retry_policy: RetryPolicyModel | None = None, store_config: StoreConfigModel | None = None, validation_timeout: int = 60, polling_interval: int = 1): """ Create or update a graph definition. @@ -169,10 +138,10 @@ async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], Args: graph_name (str): Graph identifier. - graph_nodes (list[dict[str, Any]]): Graph node list. + graph_nodes (list[GraphNodeModel]): List of graph node models defining the workflow. secrets (dict[str, str]): Secrets available to all nodes. - retry_policy (dict[str, Any] | None): Optional per-node retry policy. - store_config (dict[str, Any] | None): Beta configuration for the + retry_policy (RetryPolicyModel | None): Optional per-node retry policy configuration. + store_config (StoreConfigModel | None): Beta configuration for the graph-level store (schema is subject to change). validation_timeout (int): Seconds to wait for validation (default 60). polling_interval (int): Polling interval in seconds (default 1). @@ -189,13 +158,13 @@ async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], } body = { "secrets": secrets, - "nodes": graph_nodes + "nodes": [node.model_dump() for node in graph_nodes] } if retry_policy is not None: - body["retry_policy"] = retry_policy + body["retry_policy"] = retry_policy.model_dump() if store_config is not None: - body["store_config"] = store_config + body["store_config"] = store_config.model_dump() async with aiohttp.ClientSession() as session: async with session.put(endpoint, json=body, headers=headers) as response: # type: ignore diff --git a/python-sdk/tests/test_coverage_additions.py b/python-sdk/tests/test_coverage_additions.py index 04d97cd1..40dce760 100644 --- a/python-sdk/tests/test_coverage_additions.py +++ b/python-sdk/tests/test_coverage_additions.py @@ -43,7 +43,7 @@ async def test_statemanager_trigger_defaults(monkeypatch): # Verify it sent empty inputs/store when omitted mock_session.post.assert_called_once() _, kwargs = mock_session.post.call_args - assert kwargs["json"] == {"inputs": {}, "store": {}} + assert kwargs["json"] == {"inputs": {}, "store": {}, "start_delay": 0} class _DummyNode(BaseNode): diff --git a/python-sdk/tests/test_integration.py b/python-sdk/tests/test_integration.py index 33e87500..9c245331 100644 --- a/python-sdk/tests/test_integration.py +++ b/python-sdk/tests/test_integration.py @@ -2,7 +2,7 @@ import asyncio from unittest.mock import AsyncMock, patch, MagicMock from pydantic import BaseModel -from exospherehost import Runtime, BaseNode, StateManager, TriggerState +from exospherehost import Runtime, BaseNode, StateManager def create_mock_aiohttp_session(): @@ -205,8 +205,16 @@ async def test_state_manager_graph_lifecycle(self, mock_env_vars): sm = StateManager(namespace="test_namespace") # Test graph creation + from exospherehost.models import GraphNodeModel graph_nodes = [ - {"name": "IntegrationTestNode", "type": "test"} + GraphNodeModel( + node_name="IntegrationTestNode", + namespace="test_namespace", + identifier="IntegrationTestNode", + inputs={"type": "test"}, + next_nodes=None, + unites=None + ) ] secrets = {"api_key": "test_key", "database_url": "db://test"} @@ -214,12 +222,9 @@ async def test_state_manager_graph_lifecycle(self, mock_env_vars): assert result["validation_status"] == "VALID" # Test graph triggering - trigger_state = TriggerState( - identifier="test_trigger", - inputs={"user_id": "123", "action": "login"} - ) + trigger_state = {"identifier": "test_trigger", "inputs": {"user_id": "123", "action": "login"}} - trigger_result = await sm.trigger("test_graph", inputs=trigger_state.inputs) + trigger_result = await sm.trigger("test_graph", inputs=trigger_state["inputs"]) assert trigger_result == {"status": "triggered"} @@ -448,10 +453,10 @@ async def test_state_manager_error_propagation(self, mock_env_vars): mock_session_class.return_value = mock_session sm = StateManager(namespace="error_test") - trigger_state = TriggerState(identifier="test", inputs={"key": "value"}) + trigger_state = {"identifier": "test", "inputs": {"key": "value"}} with pytest.raises(Exception, match="Failed to trigger state: 404 Graph not found"): - await sm.trigger("nonexistent_graph", inputs=trigger_state.inputs) + await sm.trigger("nonexistent_graph", inputs=trigger_state["inputs"]) class TestConcurrencyIntegration: diff --git a/python-sdk/tests/test_models_and_statemanager_new.py b/python-sdk/tests/test_models_and_statemanager_new.py new file mode 100644 index 00000000..a35f70f4 --- /dev/null +++ b/python-sdk/tests/test_models_and_statemanager_new.py @@ -0,0 +1,187 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from exospherehost.models import ( + GraphNodeModel, + UnitesModel, + UnitesStrategyEnum, + StoreConfigModel, + RetryPolicyModel, +) +from exospherehost.statemanager import StateManager + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_mock_session_with_status(status: int, json_payload: dict): + """Create an aiohttp-like mock ClientSession returning the given status & payload.""" + mock_session = MagicMock() + mock_resp = MagicMock() + mock_resp.status = status + mock_resp.json = AsyncMock(return_value=json_payload) + + mock_ctx = MagicMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + + # route all verbs to the same context manager + mock_session.post.return_value = mock_ctx + mock_session.get.return_value = mock_ctx + mock_session.put.return_value = mock_ctx + + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + return mock_session, mock_resp + + +# --------------------------------------------------------------------------- +# GraphNodeModel & related validation +# --------------------------------------------------------------------------- + +def test_graph_node_model_trimming_and_defaults(): + model = GraphNodeModel( + node_name=" MyNode ", + namespace="ns", + identifier=" node1 ", + inputs={}, + next_nodes=[" next1 "], + unites=UnitesModel(identifier=" unite1 ") # strategy default should kick in + ) + + # Fields should be stripped + assert model.node_name == "MyNode" + assert model.identifier == "node1" + assert model.next_nodes == ["next1"] + assert model.unites is not None + assert model.unites.identifier == "unite1" + # Default enum value check + assert model.unites.strategy == UnitesStrategyEnum.ALL_SUCCESS + + +@pytest.mark.parametrize( + "field, kwargs, err_msg", + [ + ("node_name", {"node_name": " "}, "Node name cannot be empty"), + ("identifier", {"identifier": "store"}, "reserved word"), + ( + "next_nodes", + {"next_nodes": ["", "id2"]}, + "cannot be empty", + ), + ( + "next_nodes", + {"next_nodes": ["dup", "dup"]}, + "not unique", + ), + ( + "unites", + {"unites": UnitesModel(identifier=" ")}, + "Unites identifier cannot be empty", + ), + ], +) +def test_graph_node_model_invalid(field, kwargs, err_msg): + base_kwargs = dict( + node_name="n", + namespace="ns", + identifier="id1", + inputs={}, + next_nodes=None, + unites=None + ) + base_kwargs.update(kwargs) + with pytest.raises(ValueError) as e: + GraphNodeModel(**base_kwargs) # type: ignore + assert err_msg in str(e.value) + + +# --------------------------------------------------------------------------- +# StoreConfigModel validation +# --------------------------------------------------------------------------- + +def test_store_config_model_valid_and_normalises(): + cfg = StoreConfigModel( + required_keys=[" a ", "b"], + default_values={" c ": "1", "d": "2"}, + ) + # Keys should be trimmed and values stringified + assert cfg.required_keys == ["a", "b"] + assert cfg.default_values == {"c": "1", "d": "2"} + + +@pytest.mark.parametrize( + "kwargs, msg", + [ + ({"required_keys": ["a", "a"]}, "duplicated"), + ({"required_keys": ["a."]}, "cannot contain '.'"), + ({"required_keys": [" "]}, "cannot be empty"), + ({"default_values": {"k.k": "v"}}, "cannot contain '.'"), + ({"default_values": {"": "v"}}, "cannot be empty"), + ], +) +def test_store_config_model_invalid(kwargs, msg): + with pytest.raises(ValueError) as e: + StoreConfigModel(**kwargs) + assert msg in str(e.value) + + + + + +# --------------------------------------------------------------------------- +# RetryPolicyModel defaults (simple smoke test) +# --------------------------------------------------------------------------- + +def test_retry_policy_defaults(): + pol = RetryPolicyModel() + assert pol.max_retries == 3 + assert pol.backoff_factor == 2000 + # Ensure all enum values round-trip via model_dump + dumped = pol.model_dump() + assert dumped["strategy"] == pol.strategy + + +# --------------------------------------------------------------------------- +# StateManager – store_config / store handling logic +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_statemanager_upsert_includes_store_config(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + sm = StateManager(namespace="ns") + + node = GraphNodeModel(node_name="n", namespace="ns", identifier="id1", inputs={}, next_nodes=None, unites=None) # type: ignore + store_cfg = StoreConfigModel(required_keys=["k1"], default_values={"k2": "v"}) + + # Mock ClientSession + mock_session, _ = _make_mock_session_with_status(201, {"validation_status": "VALID"}) + + with patch("exospherehost.statemanager.aiohttp.ClientSession", return_value=mock_session): + await sm.upsert_graph("g", [node], secrets={}, store_config=store_cfg) + + mock_session.put.assert_called_once() + _, kwargs = mock_session.put.call_args + # Ensure the store_config is present and exactly what model_dump produced + assert "store_config" in kwargs["json"] + assert kwargs["json"]["store_config"] == store_cfg.model_dump() + + +@pytest.mark.asyncio +async def test_statemanager_trigger_passes_store(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + sm = StateManager(namespace="ns") + + mock_session, _ = _make_mock_session_with_status(200, {}) + + with patch("exospherehost.statemanager.aiohttp.ClientSession", return_value=mock_session): + await sm.trigger("g", inputs={"a": "1"}, store={"cursor": "0"}, start_delay=123) + + mock_session.post.assert_called_once() + _, kwargs = mock_session.post.call_args + assert kwargs["json"] == {"inputs": {"a": "1"}, "store": {"cursor": "0"}, "start_delay": 123} \ No newline at end of file diff --git a/python-sdk/tests/test_package_init.py b/python-sdk/tests/test_package_init.py index d5406b52..54dba15c 100644 --- a/python-sdk/tests/test_package_init.py +++ b/python-sdk/tests/test_package_init.py @@ -1,5 +1,5 @@ import pytest -from exospherehost import Runtime, BaseNode, StateManager, TriggerState, VERSION +from exospherehost import Runtime, BaseNode, StateManager, VERSION def test_package_imports(): @@ -7,7 +7,6 @@ def test_package_imports(): assert Runtime is not None assert BaseNode is not None assert StateManager is not None - assert TriggerState is not None assert VERSION is not None @@ -15,7 +14,7 @@ def test_package_all_imports(): """Test that __all__ contains all expected exports.""" from exospherehost import __all__ - expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] + expected_exports = ["Runtime", "BaseNode", "StateManager", "VERSION", "PruneSignal", "ReQueueAfterSignal", "UnitesStrategyEnum", "UnitesModel", "GraphNodeModel", "RetryStrategyEnum", "RetryPolicyModel", "StoreConfigModel"] for export in expected_exports: assert export in __all__, f"{export} should be in __all__" @@ -57,19 +56,6 @@ def test_state_manager_class_import(): assert StateManager is StateManagerDirect -def test_trigger_state_class_import(): - """Test that TriggerState class can be imported and instantiated.""" - from exospherehost.statemanager import TriggerState as TriggerStateDirect - - # Test that the imported TriggerState is the same as the one from the package - assert TriggerState is TriggerStateDirect - - # Test that TriggerState can be instantiated - state = TriggerState(identifier="test", inputs={"key": "value"}) - assert state.identifier == "test" - assert state.inputs == {"key": "value"} - - def test_version_import(): """Test that VERSION is properly imported and is a string.""" from exospherehost._version import version as version_direct @@ -121,7 +107,6 @@ def test_package_structure(): assert hasattr(exospherehost, 'Runtime') assert hasattr(exospherehost, 'BaseNode') assert hasattr(exospherehost, 'StateManager') - assert hasattr(exospherehost, 'TriggerState') assert hasattr(exospherehost, 'VERSION') assert hasattr(exospherehost, '__version__') assert hasattr(exospherehost, '__all__') diff --git a/python-sdk/tests/test_statemanager_comprehensive.py b/python-sdk/tests/test_statemanager_comprehensive.py index c8b1c7a2..42330e57 100644 --- a/python-sdk/tests/test_statemanager_comprehensive.py +++ b/python-sdk/tests/test_statemanager_comprehensive.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import AsyncMock, patch, MagicMock -from exospherehost.statemanager import StateManager, TriggerState +from exospherehost.statemanager import StateManager +from exospherehost.models import GraphNodeModel def create_mock_aiohttp_session(): @@ -91,27 +92,6 @@ def test_get_get_graph_endpoint(self, state_manager_config): assert endpoint == expected -class TestTriggerState: - def test_trigger_state_creation(self): - state = TriggerState( - identifier="test_trigger", - inputs={"key1": "value1", "key2": "value2"} - ) - assert state.identifier == "test_trigger" - assert state.inputs == {"key1": "value1", "key2": "value2"} - - def test_trigger_state_model_dump(self): - state = TriggerState( - identifier="test_trigger", - inputs={"key": "value"} - ) - dumped = state.model_dump() - assert dumped == { - "identifier": "test_trigger", - "inputs": {"key": "value"} - } - - class TestStateManagerTrigger: @pytest.mark.asyncio async def test_trigger_single_state_success(self, state_manager_config): @@ -124,9 +104,9 @@ async def test_trigger_single_state_success(self, state_manager_config): mock_session_class.return_value = mock_session sm = StateManager(**state_manager_config) - state = TriggerState(identifier="test", inputs={"key": "value"}) + state = {"identifier": "test", "inputs": {"key": "value"}} - result = await sm.trigger("test_graph", inputs=state.inputs) + result = await sm.trigger("test_graph", inputs=state["inputs"]) assert result == {"status": "success"} @@ -142,11 +122,11 @@ async def test_trigger_multiple_states_success(self, state_manager_config): sm = StateManager(**state_manager_config) states = [ - TriggerState(identifier="test1", inputs={"key1": "value1"}), - TriggerState(identifier="test2", inputs={"key2": "value2"}) + {"identifier": "test1", "inputs": {"key1": "value1"}}, + {"identifier": "test2", "inputs": {"key2": "value2"}} ] - merged_inputs = {**states[0].inputs, **states[1].inputs} + merged_inputs = {**states[0]["inputs"], **states[1]["inputs"]} result = await sm.trigger("test_graph", inputs=merged_inputs) assert result == {"status": "success"} @@ -162,10 +142,10 @@ async def test_trigger_failure(self, state_manager_config): mock_session_class.return_value = mock_session sm = StateManager(**state_manager_config) - state = TriggerState(identifier="test", inputs={"key": "value"}) + state = {"identifier": "test", "inputs": {"key": "value"}} with pytest.raises(Exception, match="Failed to trigger state: 400 Bad request"): - await sm.trigger("test_graph", inputs=state.inputs) + await sm.trigger("test_graph", inputs=state["inputs"]) class TestStateManagerGetGraph: @@ -229,7 +209,14 @@ async def test_upsert_graph_success_201(self, state_manager_config): ] sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} result = await sm.upsert_graph("test_graph", graph_nodes, secrets) @@ -254,7 +241,14 @@ async def test_upsert_graph_success_200(self, state_manager_config): mock_session_class.return_value = mock_session sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} result = await sm.upsert_graph("test_graph", graph_nodes, secrets) @@ -274,7 +268,14 @@ async def test_upsert_graph_put_failure(self, state_manager_config): mock_session_class.return_value = mock_session sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} with pytest.raises(Exception, match="Failed to upsert graph: 500 Internal server error"): @@ -300,7 +301,14 @@ async def test_upsert_graph_validation_timeout(self, state_manager_config): mock_get_graph.return_value = {"validation_status": "PENDING"} sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} with pytest.raises(Exception, match="Graph validation check timed out after 1 seconds"): @@ -332,7 +340,14 @@ async def test_upsert_graph_validation_failed(self, state_manager_config): ] sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} with pytest.raises(Exception, match="Graph validation failed: INVALID and errors: \\[\"Node 'node1' not found\"\\]"): @@ -361,7 +376,14 @@ async def test_upsert_graph_custom_timeout_and_polling(self, state_manager_confi ] sm = StateManager(**state_manager_config) - graph_nodes = [{"name": "node1", "type": "test"}] + graph_nodes = [GraphNodeModel( + node_name="node1", + namespace="test_namespace", + identifier="node1", + inputs={"type": "test"}, + next_nodes=None, + unites=None + )] secrets = {"secret1": "value1"} result = await sm.upsert_graph( diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py index a00823e5..01d1131b 100644 --- a/state-manager/app/controller/trigger_graph.py +++ b/state-manager/app/controller/trigger_graph.py @@ -9,7 +9,9 @@ from app.models.db.graph_template_model import GraphTemplate from app.models.node_template_model import NodeTemplate from app.models.dependent_string import DependentString + import uuid +import time logger = LogsManager().get_logger() @@ -96,6 +98,7 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph graph_name=graph_name, run_id=run_id, status=StateStatusEnum.CREATED, + enqueue_after=int(time.time() * 1000) + body.start_delay, inputs=inputs, outputs={}, error=None diff --git a/state-manager/app/models/trigger_model.py b/state-manager/app/models/trigger_model.py index a61ffceb..119e33d5 100644 --- a/state-manager/app/models/trigger_model.py +++ b/state-manager/app/models/trigger_model.py @@ -4,6 +4,7 @@ class TriggerGraphRequestModel(BaseModel): store: dict[str, str] = Field(default_factory=dict, description="Store for the runtime") inputs: dict[str, str] = Field(default_factory=dict, description="Inputs for the graph execution") + start_delay: int = Field(default=0, ge=0, description="Start delay in milliseconds") class TriggerGraphResponseModel(BaseModel): status: StateStatusEnum = Field(..., description="Status of the states")