diff --git a/docs/docs/exosphere/signals.md b/docs/docs/exosphere/signals.md new file mode 100644 index 00000000..6796b455 --- /dev/null +++ b/docs/docs/exosphere/signals.md @@ -0,0 +1,131 @@ +# Signals + +!!! beta "Beta Feature" + Signals are currently available in beta. The API and functionality may change in future releases. + +Signals are a mechanism in Exosphere for controlling workflow execution flow and state management. They allow nodes to communicate with the state manager to perform specific actions like pruning states or requeuing them after a delay. + +## Overview + +Signals are implemented as exceptions that should be raised from within node execution. When a signal is raised, the runtime automatically handles the communication with the state manager to perform the requested action. + +## Available Signals + +### PruneSignal + +The `PruneSignal` is used to permanently remove a state from the workflow execution. This is typically used when a node determines that the current execution path should be terminated. + +#### Usage + +```python +from exospherehost import PruneSignal + +class MyNode(BaseNode): + class Inputs(BaseModel): + data: str + + class Outputs(BaseModel): + result: str + + async def execute(self, inputs: Inputs) -> Outputs: + if inputs.data == "invalid": + # Prune the state with optional data + raise PruneSignal({"reason": "invalid_data", "error": "Data validation failed"}) + + return self.Outputs(result="processed") +``` + +#### Parameters + +- `data` (dict[str, Any], optional): Additional data to include with the prune operation. Defaults to an empty dictionary. + +### ReQueueAfterSignal + +The `ReQueueAfterSignal` is used to requeue a state for execution after a specified time delay. This is useful for implementing retry logic, scheduled tasks, or rate limiting. + +#### Usage + +```python +from exospherehost import ReQueueAfterSignal +from datetime import timedelta + +class RetryNode(BaseNode): + class Inputs(BaseModel): + retry_count: int + data: str + + class Outputs(BaseModel): + result: str + + async def execute(self, inputs: Inputs) -> Outputs: + if inputs.retry_count < 3: + # Requeue after 5 minutes + raise ReQueueAfterSignal(timedelta(minutes=5)) + + return self.Outputs(result="completed") +``` + +#### Parameters + +- `delay` (timedelta): The amount of time to wait before requeuing the state. Must be greater than 0. + +## Important Notes + +1. **Do not catch signals**: Signals are designed to bubble up to the runtime for handling. Do not catch these exceptions in your node code. + +2. **Automatic handling**: The runtime automatically sends signals to the state manager when they are raised. + +3. **State lifecycle**: Signals affect the state's lifecycle in the state manager: + - `PruneSignal`: Sets state status to `PRUNED` + - `ReQueueAfterSignal`: Sets state status to `CREATED` and schedules requeue + +## Error Handling + +If signal sending fails (e.g., network issues), the runtime will log the error and continue processing other states. The failed signal will not be retried automatically. + +## Examples + +### Conditional Pruning + +```python +class ValidationNode(BaseNode): + class Inputs(BaseModel): + user_id: str + data: dict + + async def execute(self, inputs: Inputs) -> Outputs: + if not self._validate_user(inputs.user_id): + raise PruneSignal({ + "reason": "invalid_user", + "user_id": inputs.user_id, + "timestamp": datetime.now().isoformat() + }) + + return self.Outputs(validated=True) +``` + +### Polling + +```python +class PollingNode(BaseNode): + class Inputs(BaseModel): + job_id: str + + async def execute(self, inputs: Inputs) -> Outputs: + # Check if the job is complete + job_status = await self._check_job_status(inputs.job_id) + + if job_status == "completed": + result = await self._get_job_result(inputs.job_id) + return self.Outputs(result=result) + elif job_status == "failed": + # Job failed, prune the state + raise PruneSignal({ + "reason": "job_failed", + "job_id": inputs.job_id, + "poll_count": inputs.poll_count + }) + else: + # Job still running, poll again in 30 seconds + raise ReQueueAfterSignal(timedelta(seconds=30)) +``` \ No newline at end of file diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index df257a62..51df18d8 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -103,6 +103,7 @@ plugins: - exosphere/create-graph.md - exosphere/trigger-graph.md - exosphere/dashboard.md + - exosphere/signals.md - exosphere/architecture.md markdown_extensions: @@ -130,4 +131,5 @@ nav: - Create Graph: exosphere/create-graph.md - Trigger Graph: exosphere/trigger-graph.md - Dashboard: exosphere/dashboard.md + - Signals: exosphere/signals.md - Architecture: exosphere/architecture.md \ No newline at end of file diff --git a/python-sdk/exospherehost/__init__.py b/python-sdk/exospherehost/__init__.py index 777181b9..a16745df 100644 --- a/python-sdk/exospherehost/__init__.py +++ b/python-sdk/exospherehost/__init__.py @@ -38,7 +38,8 @@ async def execute(self, inputs: Inputs) -> Outputs: from .runtime import Runtime from .node.BaseNode import BaseNode from .statemanager import StateManager, TriggerState +from .signals import PruneSignal, ReQueueAfterSignal VERSION = __version__ -__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION"] +__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index a5ed84bc..9a836ee0 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.2b1" +version = "0.0.2b2" diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 89786c58..de74e459 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from .node.BaseNode import BaseNode from aiohttp import ClientSession +from .signals import PruneSignal, ReQueueAfterSignal logger = logging.getLogger(__name__) @@ -159,6 +160,18 @@ def _get_secrets_endpoint(self, state_id: str): Construct the endpoint URL for getting secrets. """ return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/secrets" + + def _get_prune_endpoint(self, state_id: str): + """ + Construct the endpoint URL for pruning a state. + """ + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/prune" + + def _get_requeue_after_endpoint(self, state_id: str): + """ + Construct the endpoint URL for requeuing a state after a timedelta. + """ + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/re-enqueue-after" async def _register(self): """ @@ -395,6 +408,16 @@ async def _worker(self, idx: int): await self._notify_executed(state["state_id"], outputs) logger.info(f"Notified executed state {state['state_id']} for node {node.__name__ if node else "unknown"}") + + except PruneSignal as prune_signal: + logger.info(f"Pruning state {state['state_id']} for node {node.__name__ if node else "unknown"}") + await prune_signal.send(self._get_prune_endpoint(state["state_id"]), self._key) # type: ignore + logger.info(f"Pruned state {state['state_id']} for node {node.__name__ if node else "unknown"}") + + except ReQueueAfterSignal as requeue_signal: + logger.info(f"Requeuing state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.delay}") + await requeue_signal.send(self._get_requeue_after_endpoint(state["state_id"]), self._key) # type: ignore + logger.info(f"Requeued state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.delay}") except Exception as e: logger.error(f"Error executing state {state['state_id']} for node {node.__name__ if node else "unknown"}: {e}") diff --git a/python-sdk/exospherehost/signals.py b/python-sdk/exospherehost/signals.py new file mode 100644 index 00000000..c7072eb6 --- /dev/null +++ b/python-sdk/exospherehost/signals.py @@ -0,0 +1,71 @@ +from typing import Any +from aiohttp import ClientSession +from datetime import timedelta + +class PruneSignal(Exception): + """ + Exception used to signal that a prune operation should be performed. + + Args: + data (dict[str, Any], optional): Additional data to include with the signal. Defaults to {}. + + Note: + Do not catch this Exception, let it bubble up to Runtime for handling at StateManager. + """ + def __init__(self, data: dict[str, Any] = {}): + self.data = data + super().__init__(f"Prune signal received with data: {data} \n NOTE: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager") + + async def send(self, endpoint: str, key: str): + """ + Sends the prune signal to the specified endpoint. + + Args: + endpoint (str): The URL to send the signal to. + key (str): The API key to include in the request headers. + + Raises: + Exception: If the HTTP request fails (status code != 200). + """ + async with ClientSession() as session: + async with session.post(endpoint, json=self.data, headers={"x-api-key": key}) as response: + if response.status != 200: + raise Exception(f"Failed to send prune signal to {endpoint}") + + +class ReQueueAfterSignal(Exception): + """ + Exception used to signal that a requeue operation should be performed after a specified timedelta. + + Args: + timedelta (timedelta): The amount of time to wait before requeuing. + + Note: + Do not catch this Exception, let it bubble up to Runtime for handling at StateManager. + """ + def __init__(self, delay: timedelta): + self.delay = delay + + if self.delay.total_seconds() <= 0: + raise Exception("Delay must be greater than 0") + + super().__init__(f"ReQueueAfter signal received with timedelta: {timedelta} \n NOTE: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager") + + async def send(self, endpoint: str, key: str): + """ + Sends the requeue-after signal to the specified endpoint. + + Args: + endpoint (str): The URL to send the signal to. + key (str): The API key to include in the request headers. + + Raises: + Exception: If the HTTP request fails (status code != 200). + """ + body = { + "enqueue_after": int(self.delay.total_seconds() * 1000) + } + async with ClientSession() as session: + async with session.post(endpoint, json=body, headers={"x-api-key": key}) as response: + if response.status != 200: + raise Exception(f"Failed to send requeue after signal to {endpoint}") diff --git a/python-sdk/tests/test_package_init.py b/python-sdk/tests/test_package_init.py index 90bbabf7..d5406b52 100644 --- a/python-sdk/tests/test_package_init.py +++ b/python-sdk/tests/test_package_init.py @@ -15,7 +15,7 @@ def test_package_all_imports(): """Test that __all__ contains all expected exports.""" from exospherehost import __all__ - expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION"] + expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] for export in expected_exports: assert export in __all__, f"{export} should be in __all__" diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py new file mode 100644 index 00000000..e6a2222d --- /dev/null +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -0,0 +1,721 @@ +import pytest +import logging +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import timedelta +from pydantic import BaseModel +from exospherehost.signals import PruneSignal, ReQueueAfterSignal +from exospherehost.runtime import Runtime, _setup_default_logging +from exospherehost.node.BaseNode import BaseNode + + +def create_mock_aiohttp_session(): + """Helper function to create a properly mocked aiohttp session.""" + mock_session = AsyncMock() + + # Create mock response objects + mock_post_response = MagicMock() + mock_get_response = MagicMock() + mock_put_response = MagicMock() + + # Create mock context managers for each method + mock_post_context = MagicMock() + mock_post_context.__aenter__ = AsyncMock(return_value=mock_post_response) + mock_post_context.__aexit__ = AsyncMock(return_value=None) + + mock_get_context = MagicMock() + mock_get_context.__aenter__ = AsyncMock(return_value=mock_get_response) + mock_get_context.__aexit__ = AsyncMock(return_value=None) + + mock_put_context = MagicMock() + mock_put_context.__aenter__ = AsyncMock(return_value=mock_put_response) + mock_put_context.__aexit__ = AsyncMock(return_value=None) + + # Set up the session methods to return the context managers using MagicMock + mock_session.post = MagicMock(return_value=mock_post_context) + mock_session.get = MagicMock(return_value=mock_get_context) + mock_session.put = MagicMock(return_value=mock_put_context) + + # Set up session context manager + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + return mock_session, mock_post_response, mock_get_response, mock_put_response + + +class MockTestNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + api_key: str + + async def execute(self): + return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore + + +class TestPruneSignal: + """Test cases for PruneSignal exception class.""" + + def test_prune_signal_initialization_with_data(self): + """Test PruneSignal initialization with custom data.""" + data = {"reason": "test", "custom_field": "value"} + signal = PruneSignal(data) + + assert signal.data == data + assert "Prune signal received with data" in str(signal) + assert "Do not catch this Exception" in str(signal) + + def test_prune_signal_initialization_without_data(self): + """Test PruneSignal initialization without data (default empty dict).""" + signal = PruneSignal() + + assert signal.data == {} + assert "Prune signal received with data" in str(signal) + + def test_prune_signal_inheritance(self): + """Test that PruneSignal properly inherits from Exception.""" + signal = PruneSignal() + assert isinstance(signal, Exception) + + @pytest.mark.asyncio + async def test_prune_signal_send_success(self): + """Test successful sending of prune signal.""" + data = {"reason": "test_prune"} + signal = PruneSignal(data) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/prune", "test-api-key") + + # Verify the request was made correctly + mock_session.post.assert_called_once_with( + "http://test-endpoint/prune", + json=data, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_prune_signal_send_failure(self): + """Test prune signal sending failure.""" + data = {"reason": "test_prune"} + signal = PruneSignal(data) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 500 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + with pytest.raises(Exception, match="Failed to send prune signal"): + await signal.send("http://test-endpoint/prune", "test-api-key") + + +class TestReQueueAfterSignal: + """Test cases for ReQueueAfterSignal exception class.""" + + def test_requeue_signal_initialization(self): + """Test ReQueueAfterSignal initialization.""" + delta = timedelta(seconds=30) + signal = ReQueueAfterSignal(delta) + + assert signal.delay == delta + assert "ReQueueAfter signal received with timedelta" in str(signal) + assert "Do not catch this Exception" in str(signal) + + def test_requeue_signal_inheritance(self): + """Test that ReQueueAfterSignal properly inherits from Exception.""" + delta = timedelta(minutes=5) + signal = ReQueueAfterSignal(delta) + assert isinstance(signal, Exception) + + @pytest.mark.asyncio + async def test_requeue_signal_send_success(self): + """Test successful sending of requeue signal.""" + delta = timedelta(seconds=45) + signal = ReQueueAfterSignal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + # Verify the request was made correctly + expected_body = {"enqueue_after": 45000} # 45 seconds * 1000 + mock_session.post.assert_called_once_with( + "http://test-endpoint/requeue", + json=expected_body, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_send_with_minutes(self): + """Test requeue signal sending with minutes in timedelta.""" + delta = timedelta(minutes=2, seconds=30) + signal = ReQueueAfterSignal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + # Verify the request was made correctly + expected_body = {"enqueue_after": 150000} # (2*60 + 30) seconds * 1000 + mock_session.post.assert_called_once_with( + "http://test-endpoint/requeue", + json=expected_body, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_send_failure(self): + """Test requeue signal sending failure.""" + delta = timedelta(seconds=30) + signal = ReQueueAfterSignal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 400 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + with pytest.raises(Exception, match="Failed to send requeue after signal"): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + +class TestRuntimeSignalHandling: + """Test cases for Runtime signal handling functionality.""" + + def test_runtime_endpoint_construction(self): + """Test that runtime constructs correct endpoints for signal handling.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test prune endpoint construction + prune_endpoint = runtime._get_prune_endpoint("test-state-id") + expected_prune = "http://test-state-manager/v0/namespace/test-namespace/state/test-state-id/prune" + assert prune_endpoint == expected_prune + + # Test requeue after endpoint construction + requeue_endpoint = runtime._get_requeue_after_endpoint("test-state-id") + expected_requeue = "http://test-state-manager/v0/namespace/test-namespace/state/test-state-id/re-enqueue-after" + assert requeue_endpoint == expected_requeue + + @pytest.mark.asyncio + async def test_signal_handling_direct(self): + """Test signal handling by directly calling signal.send() with runtime endpoints.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test PruneSignal with runtime endpoint + prune_signal = PruneSignal({"reason": "direct_test"}) + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await prune_signal.send(runtime._get_prune_endpoint("test-state"), runtime._key) # type: ignore + + # Verify prune endpoint was called correctly + mock_session.post.assert_called_once_with( + runtime._get_prune_endpoint("test-state"), + json={"reason": "direct_test"}, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_handling_direct(self): + """Test requeue signal handling by directly calling signal.send() with runtime endpoints.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test ReQueueAfterSignal with runtime endpoint + requeue_signal = ReQueueAfterSignal(timedelta(minutes=10)) + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await requeue_signal.send(runtime._get_requeue_after_endpoint("test-state"), runtime._key) # type: ignore + + # Verify requeue endpoint was called correctly + expected_body = {"enqueue_after": 600000} # 10 minutes * 60 * 1000 + mock_session.post.assert_called_once_with( + runtime._get_requeue_after_endpoint("test-state"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + def test_need_secrets_function(self): + """Test the _need_secrets function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test with node that has secrets + assert runtime._need_secrets(MockTestNode) + + # Test with node that has no secrets + class MockNodeWithoutSecrets(BaseNode): + class Inputs(BaseModel): + name: str + class Outputs(BaseModel): + message: str + class Secrets(BaseModel): + pass + async def execute(self): + return self.Outputs(message="test") + + assert not runtime._need_secrets(MockNodeWithoutSecrets) + + @pytest.mark.asyncio + async def test_get_secrets_function(self): + """Test the _get_secrets function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful secrets retrieval + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 200 + mock_get_response.json = AsyncMock(return_value={"secrets": {"api_key": "test-secret"}}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {"api_key": "test-secret"} + mock_session.get.assert_called_once_with( + runtime._get_secrets_endpoint("test-state-id"), + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_get_secrets_function_failure(self): + """Test the _get_secrets function when request fails.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock failed secrets retrieval + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 404 + mock_get_response.json = AsyncMock(return_value={"error": "Not found"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {} + + @pytest.mark.asyncio + async def test_get_secrets_function_no_secrets_field(self): + """Test the _get_secrets function when response has no secrets field.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock response without secrets field + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 200 + mock_get_response.json = AsyncMock(return_value={"data": "some other data"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {} + + +class TestRuntimeEndpointFunctions: + """Test cases for Runtime endpoint construction functions.""" + + def test_get_prune_endpoint(self): + """Test _get_prune_endpoint function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + endpoint = runtime._get_prune_endpoint("state-123") + expected = "http://test-state-manager/v0/namespace/test-namespace/state/state-123/prune" + assert endpoint == expected + + def test_get_requeue_after_endpoint(self): + """Test _get_requeue_after_endpoint function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + endpoint = runtime._get_requeue_after_endpoint("state-456") + expected = "http://test-state-manager/v0/namespace/test-namespace/state/state-456/re-enqueue-after" + assert endpoint == expected + + def test_get_prune_endpoint_with_custom_version(self): + """Test _get_prune_endpoint with custom state manager version.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key", + state_manage_version="v1" + ) + + endpoint = runtime._get_prune_endpoint("state-789") + expected = "http://test-state-manager/v1/namespace/test-namespace/state/state-789/prune" + assert endpoint == expected + + def test_get_requeue_after_endpoint_with_custom_version(self): + """Test _get_requeue_after_endpoint with custom state manager version.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key", + state_manage_version="v2" + ) + + endpoint = runtime._get_requeue_after_endpoint("state-101") + expected = "http://test-state-manager/v2/namespace/test-namespace/state/state-101/re-enqueue-after" + assert endpoint == expected + + +class TestSignalIntegration: + """Integration tests for signal handling in the runtime.""" + + @pytest.mark.asyncio + async def test_signal_exception_behavior(self): + """Test that signals are proper exceptions that can be raised and caught.""" + # Test PruneSignal + prune_signal = PruneSignal({"test": "data"}) + + with pytest.raises(PruneSignal) as exc_info: + raise prune_signal + + assert exc_info.value.data == {"test": "data"} + assert isinstance(exc_info.value, Exception) + + # Test ReQueueAfterSignal + requeue_signal = ReQueueAfterSignal(timedelta(seconds=30)) + + with pytest.raises(ReQueueAfterSignal) as exc_info: + raise requeue_signal + + assert exc_info.value.delay == timedelta(seconds=30) + assert isinstance(exc_info.value, Exception) + + @pytest.mark.asyncio + async def test_combined_signal_and_runtime_functionality(self): + """Test that signals work correctly with runtime endpoints.""" + runtime = Runtime( + namespace="production", + name="signal-runtime", + nodes=[MockTestNode], + state_manager_uri="https://api.exosphere.host", + key="prod-api-key", + state_manage_version="v1" + ) + + # Test PruneSignal with production-like endpoint + prune_signal = PruneSignal({"reason": "cleanup", "batch_id": "batch-123"}) + expected_prune_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-456/prune" + actual_prune_endpoint = runtime._get_prune_endpoint("prod-state-456") + assert actual_prune_endpoint == expected_prune_endpoint + + # Test ReQueueAfterSignal with production-like endpoint + requeue_signal = ReQueueAfterSignal(timedelta(hours=2, minutes=30)) + expected_requeue_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-789/re-enqueue-after" + actual_requeue_endpoint = runtime._get_requeue_after_endpoint("prod-state-789") + assert actual_requeue_endpoint == expected_requeue_endpoint + + # Test that signal data is preserved + assert prune_signal.data == {"reason": "cleanup", "batch_id": "batch-123"} + assert requeue_signal.delay == timedelta(hours=2, minutes=30) + + @pytest.mark.asyncio + async def test_signal_send_with_different_endpoints(self): + """Test signal sending with various endpoint configurations.""" + # Test with different URI formats + test_cases = [ + ("http://localhost:8080", "v0", "dev"), + ("https://api.production.com", "v2", "production"), + ("http://staging.internal:3000", "v1", "staging") + ] + + for uri, version, namespace in test_cases: + runtime = Runtime( + namespace=namespace, + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri=uri, + key="test-key", + state_manage_version=version + ) + + # Test prune endpoint construction + prune_endpoint = runtime._get_prune_endpoint("test-state") + expected_prune = f"{uri}/{version}/namespace/{namespace}/state/test-state/prune" + assert prune_endpoint == expected_prune + + # Test requeue endpoint construction + requeue_endpoint = runtime._get_requeue_after_endpoint("test-state") + expected_requeue = f"{uri}/{version}/namespace/{namespace}/state/test-state/re-enqueue-after" + assert requeue_endpoint == expected_requeue + + +class TestSignalEdgeCases: + """Test cases for signal edge cases and error conditions.""" + + def test_prune_signal_with_empty_data(self): + """Test PruneSignal with empty data.""" + signal = PruneSignal({}) + assert signal.data == {} + assert isinstance(signal, Exception) + + def test_prune_signal_with_complex_data(self): + """Test PruneSignal with complex nested data.""" + complex_data = { + "reason": "batch_cleanup", + "metadata": { + "batch_id": "batch-456", + "items": ["item1", "item2", "item3"], + "timestamp": "2023-12-01T10:00:00Z" + }, + "config": { + "force": True, + "notify_users": False + } + } + signal = PruneSignal(complex_data) + assert signal.data == complex_data + + def test_requeue_signal_with_zero_timedelta(self): + """Test ReQueueAfterSignal with zero timedelta.""" + with pytest.raises(Exception): + ReQueueAfterSignal(timedelta(seconds=0)) + + def test_requeue_signal_with_large_timedelta(self): + """Test ReQueueAfterSignal with large timedelta.""" + large_delta = timedelta(days=7, hours=12, minutes=30, seconds=45) + signal = ReQueueAfterSignal(large_delta) + assert signal.delay == large_delta + + @pytest.mark.asyncio + async def test_requeue_signal_timedelta_conversion(self): + """Test that ReQueueAfterSignal correctly converts timedelta to milliseconds.""" + test_cases = [ + (timedelta(seconds=1), 1000), + (timedelta(minutes=1), 60000), + (timedelta(hours=1), 3600000), + (timedelta(days=1), 86400000), + (timedelta(seconds=30, microseconds=500000), 30500), # 30.5 seconds + ] + + for delta, expected_ms in test_cases: + signal = ReQueueAfterSignal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint", "test-key") + + # Verify correct milliseconds conversion + expected_body = {"enqueue_after": expected_ms} + mock_session.post.assert_called_with( + "http://test-endpoint", + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + def test_signal_string_representations(self): + """Test string representations of signals.""" + prune_signal = PruneSignal({"test": "data"}) + prune_str = str(prune_signal) + assert "Prune signal received with data" in prune_str + assert "Do not catch this Exception" in prune_str + assert "{'test': 'data'}" in prune_str + + requeue_signal = ReQueueAfterSignal(timedelta(minutes=5)) + requeue_str = str(requeue_signal) + assert "ReQueueAfter signal received with timedelta" in requeue_str + assert "Do not catch this Exception" in requeue_str + +class TestRuntimeHelperFunctions: + """Test cases for Runtime helper functions.""" + + @pytest.mark.asyncio + async def test_notify_executed_function(self): + """Test the _notify_executed function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value={"status": "success"}) + + # Create test outputs + outputs = [MockTestNode.Outputs(message="output1"), MockTestNode.Outputs(message="output2")] + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + await runtime._notify_executed("test-state-id", outputs) # type: ignore + + # Verify correct endpoint and payload + expected_body = {"outputs": [{"message": "output1"}, {"message": "output2"}]} + mock_session.post.assert_called_once_with( + runtime._get_executed_endpoint("test-state-id"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_notify_errored_function(self): + """Test the _notify_errored function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value={"status": "success"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + await runtime._notify_errored("test-state-id", "Test error message") + + # Verify correct endpoint and payload + expected_body = {"error": "Test error message"} + mock_session.post.assert_called_once_with( + runtime._get_errored_endpoint("test-state-id"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_notify_functions_with_failure(self): + """Test notification functions when HTTP requests fail.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock failed notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 500 + mock_post_response.json = AsyncMock(return_value={"error": "Internal server error"}) + + outputs = [MockTestNode.Outputs(message="test")] + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + # These should not raise exceptions, just log errors + await runtime._notify_executed("test-state-id", outputs) # type: ignore + await runtime._notify_errored("test-state-id", "Test error") + + # Verify both endpoints were called despite failures + assert mock_session.post.call_count == 2 + + +class TestSetupDefaultLogging: + """Test cases for the _setup_default_logging function.""" + + def test_setup_default_logging_with_existing_handlers(self): + """Test that _setup_default_logging doesn't interfere with existing handlers.""" + # Create a logger with existing handlers + test_logger = logging.getLogger("test_logger") + handler = logging.StreamHandler() + test_logger.addHandler(handler) + + # Mock the root logger to have handlers + with patch('logging.getLogger') as mock_get_logger: + mock_root_logger = MagicMock() + mock_root_logger.handlers = [handler] + mock_get_logger.return_value = mock_root_logger + + # This should return early and not configure logging + _setup_default_logging() + + # Verify no basic config was called + mock_root_logger.basicConfig = MagicMock() + assert not mock_root_logger.basicConfig.called + + def test_setup_default_logging_with_disable_env_var(self): + """Test that _setup_default_logging respects the disable environment variable.""" + with patch.dict('os.environ', {'EXOSPHERE_DISABLE_DEFAULT_LOGGING': 'true'}), \ + patch('logging.getLogger') as mock_get_logger: + mock_root_logger = MagicMock() + mock_root_logger.handlers = [] + mock_get_logger.return_value = mock_root_logger + + _setup_default_logging() + + # Should return early due to env var + with patch('logging.basicConfig') as mock_basic_config: + _setup_default_logging() + assert not mock_basic_config.called + + def test_setup_default_logging_with_custom_log_level(self): + """Test that _setup_default_logging respects custom log level.""" + with patch.dict('os.environ', {'EXOSPHERE_LOG_LEVEL': 'DEBUG'}), \ + patch('logging.getLogger') as mock_get_logger, \ + patch('logging.basicConfig') as mock_basic_config: + + mock_root_logger = MagicMock() + mock_root_logger.handlers = [] + mock_get_logger.return_value = mock_root_logger + + _setup_default_logging() + + # Verify basicConfig was called with DEBUG level + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + assert call_args[1]['level'] == logging.DEBUG \ No newline at end of file diff --git a/state-manager/.coverage b/state-manager/.coverage new file mode 100644 index 00000000..c086d6eb Binary files /dev/null and b/state-manager/.coverage differ diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index b27a6bef..a5c36b52 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -1,4 +1,5 @@ import asyncio +import time from ..models.enqueue_request import EnqueueRequestModel from ..models.enqueue_response import EnqueueResponseModel, StateModel @@ -18,7 +19,8 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: "status": StateStatusEnum.CREATED, "node_name": { "$in": nodes - } + }, + "enqueue_after": {"$lte": int(time.time() * 1000)} }, { "$set": {"status": StateStatusEnum.QUEUED} diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py new file mode 100644 index 00000000..122e93f0 --- /dev/null +++ b/state-manager/app/controller/prune_signal.py @@ -0,0 +1,32 @@ +from app.models.signal_models import PruneRequestModel, SignalResponseModel +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager + +logger = LogsManager().get_logger() + +async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: PruneRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: + + try: + logger.info(f"Received prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + state = await State.find_one(State.id == state_id) + + if not state: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + + if state.status != StateStatusEnum.QUEUED: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") + + state.status = StateStatusEnum.PRUNED + state.data = body.data + await state.save() + + return SignalResponseModel(status=StateStatusEnum.PRUNED, enqueue_after=state.enqueue_after) + + except Exception as e: + logger.error(f"Error pruning state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) + raise \ No newline at end of file diff --git a/state-manager/app/controller/re_queue_after_signal.py b/state-manager/app/controller/re_queue_after_signal.py new file mode 100644 index 00000000..009f1424 --- /dev/null +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -0,0 +1,30 @@ +from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel +from fastapi import HTTPException, status +from beanie import PydanticObjectId +import time + +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager + +logger = LogsManager().get_logger() + +async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, body: ReEnqueueAfterRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: + + try: + logger.info(f"Received re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + state = await State.find_one(State.id == state_id) + + if not state: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + + state.status = StateStatusEnum.CREATED + state.enqueue_after = int(time.time() * 1000) + body.enqueue_after + await state.save() + + return SignalResponseModel(status=StateStatusEnum.CREATED, enqueue_after=state.enqueue_after) + + except Exception as e: + logger.error(f"Error re-queueing state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) + raise \ No newline at end of file diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 8d77d967..e37326b4 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -7,7 +7,7 @@ from typing import Any, Optional import hashlib import json - +import time class State(BaseDatabaseModel): node_name: str = Field(..., description="Name of the node of the state") @@ -18,10 +18,12 @@ class State(BaseDatabaseModel): status: StateStatusEnum = Field(..., description="Status of the state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") + data: dict[str, Any] = Field(default_factory=dict, description="Data of the state (could be used to save pruned meta data)") error: Optional[str] = Field(None, description="Error message") parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") does_unites: bool = Field(default=False, description="Whether this state unites other states") state_fingerprint: str = Field(default="", description="Fingerprint of the state") + enqueue_after: int = Field(default_factory=lambda: int(time.time() * 1000), gt=0, description="Unix time in milliseconds after which the state should be enqueued") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): @@ -65,5 +67,14 @@ class Settings: partialFilterExpression={ "does_unites": True } + ), + IndexModel( + [ + ("enqueue_after", 1), + ("status", 1), + ("namespace_name", 1), + ("node_name", 1), + ], + name="idx_enqueue_after" ) ] \ No newline at end of file diff --git a/state-manager/app/models/signal_models.py b/state-manager/app/models/signal_models.py new file mode 100644 index 00000000..40abe6f4 --- /dev/null +++ b/state-manager/app/models/signal_models.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, Field +from .state_status_enum import StateStatusEnum +from typing import Any + + +class SignalResponseModel(BaseModel): + enqueue_after: int = Field(..., description="Unix time in milliseconds after which the state should be re-enqueued") + status: StateStatusEnum = Field(..., description="Status of the state") + +class PruneRequestModel(BaseModel): + data: dict[str, Any] = Field(..., description="Data of the state") + +class ReEnqueueAfterRequestModel(BaseModel): + enqueue_after: int = Field(..., gt=0, description="Duration in milliseconds to delay the re-enqueuing of the state") \ No newline at end of file diff --git a/state-manager/app/models/state_status_enum.py b/state-manager/app/models/state_status_enum.py index 8da97002..7760176d 100644 --- a/state-manager/app/models/state_status_enum.py +++ b/state-manager/app/models/state_status_enum.py @@ -6,10 +6,8 @@ class StateStatusEnum(str, Enum): CREATED = 'CREATED' QUEUED = 'QUEUED' EXECUTED = 'EXECUTED' - NEXT_CREATED = 'NEXT_CREATED' - RETRY_CREATED = 'RETRY_CREATED' - TIMEDOUT = 'TIMEDOUT' ERRORED = 'ERRORED' CANCELLED = 'CANCELLED' SUCCESS = 'SUCCESS' NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR' + PRUNED = 'PRUNED' \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 9e956db2..ee219bb1 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -40,6 +40,14 @@ from .models.graph_structure_models import GraphStructureResponse from .controller.get_graph_structure import get_graph_structure +### signals +from .models.signal_models import SignalResponseModel +from .models.signal_models import PruneRequestModel +from .controller.prune_signal import prune_signal +from .models.signal_models import ReEnqueueAfterRequestModel +from .controller.re_queue_after_signal import re_queue_after_signal + + logger = LogsManager().get_logger() router = APIRouter(prefix="/v0/namespace/{namespace_name}") @@ -145,6 +153,44 @@ async def errored_state_route(namespace_name: str, state_id: str, body: ErroredR return await errored_state(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) +@router.post( + "/states/{state_id}/prune", + response_model=SignalResponseModel, + status_code=status.HTTP_200_OK, + response_description="State pruned successfully", + tags=["state"] +) +async def prune_state_route(namespace_name: str, state_id: str, body: PruneRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await prune_signal(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) + + +@router.post( + "/states/{state_id}/re-enqueue-after", + response_model=SignalResponseModel, + status_code=status.HTTP_200_OK, + response_description="State re-enqueued successfully", + tags=["state"] +) +async def re_enqueue_after_state_route(namespace_name: str, state_id: str, body: ReEnqueueAfterRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await re_queue_after_signal(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) + + @router.put( "/graph/{graph_name}", response_model=UpsertGraphTemplateResponse, diff --git a/state-manager/tests/unit/controller/test_prune_signal.py b/state-manager/tests/unit/controller/test_prune_signal.py new file mode 100644 index 00000000..1c36170a --- /dev/null +++ b/state-manager/tests/unit/controller/test_prune_signal.py @@ -0,0 +1,319 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.controller.prune_signal import prune_signal +from app.models.signal_models import PruneRequestModel +from app.models.state_status_enum import StateStatusEnum + + +class TestPruneSignal: + """Test cases for prune_signal function""" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_state_id(self): + return PydanticObjectId() + + @pytest.fixture + def mock_prune_request(self): + return PruneRequestModel( + data={"key": "value", "nested": {"data": "test"}} + ) + + @pytest.fixture + def mock_state_created(self): + state = MagicMock() + state.id = PydanticObjectId() + state.status = StateStatusEnum.QUEUED + state.enqueue_after = 1234567890 + return state + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_success( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_state_created, + mock_request_id + ): + """Test successful pruning of state""" + # Arrange + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert result.enqueue_after == 1234567890 + assert mock_state_created.status == StateStatusEnum.PRUNED + assert mock_state_created.data == mock_prune_request.data + assert mock_state_created.save.call_count == 1 + assert mock_state_class.find_one.call_count == 1 + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_state_not_found( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is not found""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_created( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in QUEUED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.CREATED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not queued" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_executed( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in EXECUTED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.EXECUTED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not queued" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_errored( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in ERRORED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.ERRORED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not queued" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_pruned( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is already in PRUNED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.PRUNED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not queued" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_database_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_state_class.find_one = MagicMock(side_effect=Exception("Database error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert str(exc_info.value) == "Database error" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_save_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_state_created, + mock_request_id + ): + """Test handling of save errors""" + # Arrange + mock_state_created.save = AsyncMock(side_effect=Exception("Save error")) + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert str(exc_info.value) == "Save error" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_with_empty_data( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_created, + mock_request_id + ): + """Test pruning with empty data""" + # Arrange + prune_request = PruneRequestModel(data={}) + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert mock_state_created.data == {} + assert mock_state_created.save.call_count == 1 + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_with_complex_data( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_created, + mock_request_id + ): + """Test pruning with complex nested data""" + # Arrange + complex_data = { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3], + "nested": { + "object": { + "deep": "value" + } + } + } + prune_request = PruneRequestModel(data=complex_data) + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert mock_state_created.data == complex_data + assert mock_state_created.save.call_count == 1 \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_re_queue_after_signal.py b/state-manager/tests/unit/controller/test_re_queue_after_signal.py new file mode 100644 index 00000000..48f41922 --- /dev/null +++ b/state-manager/tests/unit/controller/test_re_queue_after_signal.py @@ -0,0 +1,313 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.controller.re_queue_after_signal import re_queue_after_signal +from app.models.signal_models import ReEnqueueAfterRequestModel +from app.models.state_status_enum import StateStatusEnum + + +class TestReQueueAfterSignal: + """Test cases for re_queue_after_signal function""" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_state_id(self): + return PydanticObjectId() + + @pytest.fixture + def mock_re_enqueue_request(self): + return ReEnqueueAfterRequestModel( + enqueue_after=5000 # 5 seconds in milliseconds + ) + + @pytest.fixture + def mock_state_any_status(self): + state = MagicMock() + state.id = PydanticObjectId() + state.status = StateStatusEnum.QUEUED # Any status is valid for re-enqueue + state.enqueue_after = 1234567890 + return state + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_success( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test successful re-enqueuing of state""" + # Arrange + mock_time.time.return_value = 1000.0 # Mock current time + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 1005000 # 1000 * 1000 + 5000 + assert mock_state_any_status.status == StateStatusEnum.CREATED + assert mock_state_any_status.enqueue_after == 1005000 + assert mock_state_any_status.save.call_count == 1 + assert mock_state_class.find_one.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + async def test_re_queue_after_signal_state_not_found( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test when state is not found""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_zero_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with zero delay""" + # Arrange + mock_time.time.return_value = 1000.0 + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=1) + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 1000001 # 1000 * 1000 + 0 + assert mock_state_any_status.enqueue_after == 1000001 + assert mock_state_any_status.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_large_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with large delay""" + # Arrange + mock_time.time.return_value = 1000.0 + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=86400000) # 24 hours + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 87400000 # 1000 * 1000 + 86400000 + assert mock_state_any_status.enqueue_after == 87400000 + assert mock_state_any_status.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_negative_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with negative delay (should still work)""" + # Arrange + + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=-5000) # Negative delay + + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=0) + + + @patch('app.controller.re_queue_after_signal.State') + async def test_re_queue_after_signal_database_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_state_class.find_one = MagicMock(side_effect=Exception("Database error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert str(exc_info.value) == "Database error" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_save_error( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test handling of save errors""" + # Arrange + mock_time.time.return_value = 1000.0 + mock_state_any_status.save = AsyncMock(side_effect=Exception("Save error")) + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert str(exc_info.value) == "Save error" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_from_different_statuses( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test re-enqueuing from different initial statuses""" + # Arrange + mock_time.time.return_value = 1000.0 + + test_cases = [ + StateStatusEnum.CREATED, + StateStatusEnum.QUEUED, + StateStatusEnum.EXECUTED, + StateStatusEnum.ERRORED, + StateStatusEnum.CANCELLED, + StateStatusEnum.SUCCESS, + StateStatusEnum.NEXT_CREATED_ERROR, + StateStatusEnum.PRUNED + ] + + for initial_status in test_cases: + # Arrange for this test case + mock_state = MagicMock() + mock_state.id = PydanticObjectId() + mock_state.status = initial_status + mock_state.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert mock_state.status == StateStatusEnum.CREATED + assert mock_state.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_time_precision( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test that time calculation is precise""" + # Arrange + mock_time.time.return_value = 1234.567 # Test with fractional seconds + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + expected_enqueue_after = int(1234.567 * 1000) + 5000 + assert result.enqueue_after == expected_enqueue_after + assert mock_state_any_status.enqueue_after == expected_enqueue_after \ No newline at end of file diff --git a/state-manager/tests/unit/models/test_signal_models.py b/state-manager/tests/unit/models/test_signal_models.py new file mode 100644 index 00000000..4eea9141 --- /dev/null +++ b/state-manager/tests/unit/models/test_signal_models.py @@ -0,0 +1,273 @@ +import pytest +from pydantic import ValidationError + +from app.models.signal_models import PruneRequestModel, ReEnqueueAfterRequestModel, SignalResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestPruneRequestModel: + """Test cases for PruneRequestModel""" + + def test_prune_request_model_valid_data(self): + """Test PruneRequestModel with valid data""" + # Arrange & Act + data = {"key": "value", "nested": {"data": "test"}} + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_empty_data(self): + """Test PruneRequestModel with empty data""" + # Arrange & Act + data = {} + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_complex_data(self): + """Test PruneRequestModel with complex nested data""" + # Arrange & Act + data = { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3], + "nested": { + "object": { + "deep": "value" + } + } + } + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_missing_data(self): + """Test PruneRequestModel with missing data field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + PruneRequestModel() # type: ignore + + assert "data" in str(exc_info.value) + + def test_prune_request_model_none_data(self): + """Test PruneRequestModel with None data""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + PruneRequestModel(data=None) # type: ignore + + assert "data" in str(exc_info.value) + + +class TestReEnqueueAfterRequestModel: + """Test cases for ReEnqueueAfterRequestModel""" + + def test_re_enqueue_after_request_model_valid_delay(self): + """Test ReEnqueueAfterRequestModel with valid delay""" + # Arrange & Act + delay = 5000 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_zero_delay(self): + """Test ReEnqueueAfterRequestModel with zero delay""" + # Arrange & Act + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=0) + + def test_re_enqueue_after_request_model_negative_delay(self): + """Test ReEnqueueAfterRequestModel with negative delay""" + # Arrange & Act + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=-5000) + + def test_re_enqueue_after_request_model_large_delay(self): + """Test ReEnqueueAfterRequestModel with large delay""" + # Arrange & Act + delay = 86400000 # 24 hours + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_missing_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with missing enqueue_after field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ReEnqueueAfterRequestModel() # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_re_enqueue_after_request_model_none_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with None enqueue_after""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ReEnqueueAfterRequestModel(enqueue_after=None) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_re_enqueue_after_request_model_string_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with string enqueue_after (should convert)""" + # Arrange & Act + delay = "5000" + model = ReEnqueueAfterRequestModel(enqueue_after=delay) # type: ignore + + # Assert + assert model.enqueue_after == 5000 + + def test_re_enqueue_after_request_model_float_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with float enqueue_after (should convert)""" + # Arrange & Act + delay = 5000.0 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) # type: ignore + + # Assert + assert model.enqueue_after == 5000 + + +class TestSignalResponseModel: + """Test cases for SignalResponseModel""" + + def test_signal_response_model_valid_data(self): + """Test SignalResponseModel with valid data""" + # Arrange & Act + enqueue_after = 1234567890 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_created_status(self): + """Test SignalResponseModel with CREATED status""" + # Arrange & Act + enqueue_after = 1234567890 + status = StateStatusEnum.CREATED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_zero_enqueue_after(self): + """Test SignalResponseModel with zero enqueue_after""" + # Arrange & Act + enqueue_after = 0 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_large_enqueue_after(self): + """Test SignalResponseModel with large enqueue_after""" + # Arrange & Act + enqueue_after = 9999999999999 + status = StateStatusEnum.CREATED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_missing_enqueue_after(self): + """Test SignalResponseModel with missing enqueue_after field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(status=StateStatusEnum.PRUNED) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_signal_response_model_missing_status(self): + """Test SignalResponseModel with missing status field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=1234567890) # type: ignore + + assert "status" in str(exc_info.value) + + def test_signal_response_model_none_enqueue_after(self): + """Test SignalResponseModel with None enqueue_after""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=None, status=StateStatusEnum.PRUNED) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_signal_response_model_none_status(self): + """Test SignalResponseModel with None status""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=1234567890, status=None) # type: ignore + + assert "status" in str(exc_info.value) + + def test_signal_response_model_string_enqueue_after(self): + """Test SignalResponseModel with string enqueue_after (should convert)""" + # Arrange & Act + enqueue_after = "1234567890" + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) # type: ignore + + # Assert + assert model.enqueue_after == 1234567890 + assert model.status == status + + def test_signal_response_model_all_status_enum_values(self): + """Test SignalResponseModel with all possible status enum values""" + # Arrange + enqueue_after = 1234567890 + all_statuses = [ + StateStatusEnum.CREATED, + StateStatusEnum.QUEUED, + StateStatusEnum.EXECUTED, + StateStatusEnum.ERRORED, + StateStatusEnum.CANCELLED, + StateStatusEnum.SUCCESS, + StateStatusEnum.NEXT_CREATED_ERROR, + StateStatusEnum.PRUNED + ] + + for status in all_statuses: + # Act + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_json_serialization(self): + """Test SignalResponseModel JSON serialization""" + # Arrange + enqueue_after = 1234567890 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Act + json_data = model.model_dump() + + # Assert + assert json_data["enqueue_after"] == enqueue_after + assert json_data["status"] == status.value + + def test_signal_response_model_json_deserialization(self): + """Test SignalResponseModel JSON deserialization""" + # Arrange + json_data = { + "enqueue_after": 1234567890, + "status": "PRUNED" + } + + # Act + model = SignalResponseModel(**json_data) + + # Assert + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.PRUNED \ No newline at end of file diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 823c47d8..82cce20b 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -29,6 +29,8 @@ def test_router_has_correct_routes(self): assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/states/create' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/executed' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/errored' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/states/{state_id}/prune' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/states/{state_id}/re-enqueue-after' in path for path in paths) # Graph template routes (there are two /graph/{graph_name} routes - GET and PUT) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}' in path for path in paths) @@ -109,6 +111,84 @@ def test_create_request_model_validation(self): assert len(model.states) == 1 assert model.states[0].identifier == "node1" + def test_prune_request_model_validation(self): + """Test PruneRequestModel validation""" + from app.models.signal_models import PruneRequestModel + + # Test with valid data + valid_data = { + "data": {"key": "value", "nested": {"data": "test"}} + } + model = PruneRequestModel(**valid_data) + assert model.data == {"key": "value", "nested": {"data": "test"}} + + # Test with empty data + empty_data = {"data": {}} + model = PruneRequestModel(**empty_data) + assert model.data == {} + + # Test with complex data + complex_data = { + "data": { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3] + } + } + model = PruneRequestModel(**complex_data) + assert model.data["string"] == "test" + assert model.data["number"] == 42 + assert model.data["boolean"] is True + assert model.data["list"] == [1, 2, 3] + + def test_re_enqueue_after_request_model_validation(self): + """Test ReEnqueueAfterRequestModel validation""" + from app.models.signal_models import ReEnqueueAfterRequestModel + + # Test with valid data + valid_data = {"enqueue_after": 5000} + model = ReEnqueueAfterRequestModel(**valid_data) + assert model.enqueue_after == 5000 + + # Test with zero delay + zero_data = {"enqueue_after": 0} + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(**zero_data) + + # Test with negative delay + negative_data = {"enqueue_after": -5000} + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(**negative_data) + + # Test with large delay + large_data = {"enqueue_after": 86400000} + model = ReEnqueueAfterRequestModel(**large_data) + assert model.enqueue_after == 86400000 + + def test_signal_response_model_validation(self): + """Test SignalResponseModel validation""" + from app.models.signal_models import SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + + # Test with valid data + valid_data = { + "enqueue_after": 1234567890, + "status": "PRUNED" + } + model = SignalResponseModel(**valid_data) + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.PRUNED + + # Test with CREATED status + created_data = { + "enqueue_after": 1234567890, + "status": "CREATED" + } + model = SignalResponseModel(**created_data) + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.CREATED + def test_executed_request_model_validation(self): """Test ExecutedRequestModel validation""" # Test with valid data @@ -331,7 +411,7 @@ async def test_enqueue_state_with_invalid_api_key(self, mock_enqueue_states, moc # Act & Assert with pytest.raises(HTTPException) as exc_info: - await enqueue_state("test_namespace", body, mock_request, None) + await enqueue_state("test_namespace", body, mock_request, None) # type: ignore assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" @@ -385,7 +465,7 @@ async def test_trigger_graph_route_with_invalid_api_key(self, mock_trigger_graph # Act & Assert with pytest.raises(HTTPException) as exc_info: - await trigger_graph_route("test_namespace", "test_graph", body, mock_request, None) + await trigger_graph_route("test_namespace", "test_graph", body, mock_request, None) # type: ignore assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" @@ -609,4 +689,156 @@ async def test_get_states_by_run_id_route_with_valid_api_key(self, mock_get_stat assert result.namespace == "test_namespace" assert result.run_id == "test_run" assert result.count == 1 - assert len(result.states) == 1 \ No newline at end of file + assert len(result.states) == 1 + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_valid_api_key(self, mock_prune_signal, mock_request): + """Test prune_state_route with valid API key""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data={"key": "value"}) + expected_response = SignalResponseModel( + status=StateStatusEnum.PRUNED, + enqueue_after=1234567890 + ) + mock_prune_signal.return_value = expected_response + + # Act + result = await prune_state_route("test_namespace", state_id, prune_request, mock_request, "valid_key") + + # Assert + mock_prune_signal.assert_called_once_with("test_namespace", PydanticObjectId(state_id), prune_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_invalid_api_key(self, mock_prune_signal, mock_request): + """Test prune_state_route with invalid API key""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel + from fastapi import HTTPException, status + + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data={"key": "value"}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_state_route("test_namespace", state_id, prune_request, mock_request, None) # type: ignore + + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + mock_prune_signal.assert_not_called() + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_valid_api_key(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with valid API key""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=5000) + expected_response = SignalResponseModel( + status=StateStatusEnum.CREATED, + enqueue_after=1234567890 + ) + mock_re_queue_after_signal.return_value = expected_response + + # Act + result = await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, "valid_key") + + # Assert + mock_re_queue_after_signal.assert_called_once_with("test_namespace", PydanticObjectId(state_id), re_enqueue_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_invalid_api_key(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with invalid API key""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel + from fastapi import HTTPException, status + + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=5000) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, None) # type: ignore + + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + mock_re_queue_after_signal.assert_not_called() + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_different_data(self, mock_prune_signal, mock_request): + """Test prune_state_route with different data payloads""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Test cases with different data + test_cases = [ + {"simple": "value"}, + {"nested": {"data": "test"}}, + {"list": [1, 2, 3]}, + {"boolean": True, "number": 42}, + {} # Empty data + ] + + for test_data in test_cases: + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data=test_data) + expected_response = SignalResponseModel( + status=StateStatusEnum.PRUNED, + enqueue_after=1234567890 + ) + mock_prune_signal.return_value = expected_response + + # Act + result = await prune_state_route("test_namespace", state_id, prune_request, mock_request, "valid_key") + + # Assert + mock_prune_signal.assert_called_with("test_namespace", PydanticObjectId(state_id), prune_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_different_delays(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with different delay values""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Test cases with different delays + test_cases = [ + 1000, # 1 second + 60000, # 1 minute + 3600000 # 1 hour + ] + + for delay in test_cases: + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=delay) + expected_response = SignalResponseModel( + status=StateStatusEnum.CREATED, + enqueue_after=1234567890 + ) + mock_re_queue_after_signal.return_value = expected_response + + # Act + result = await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, "valid_key") + + # Assert + mock_re_queue_after_signal.assert_called_with("test_namespace", PydanticObjectId(state_id), re_enqueue_request, "test-request-id") + assert result == expected_response \ No newline at end of file