From c9d18d7febdce087117b8293fd7e6205be71cfb4 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 13:42:37 +0530 Subject: [PATCH 01/19] Add prune and re-enqueue signal functionality - Introduced new routes for pruning states and re-enqueuing states after a specified time. - Added corresponding controller functions to handle the logic for pruning and re-enqueuing states, including validation and error handling. - Created new signal models for request and response structures related to pruning and re-enqueuing operations. - Updated the State model to include a new field for enqueue_after, enhancing state management capabilities. - Enhanced logging for better traceability of operations related to state management. --- .../app/controller/enqueue_states.py | 4 +- state-manager/app/controller/prune_signal.py | 32 +++++++++++++ .../app/controller/re_queue_after_singal.py | 32 +++++++++++++ state-manager/app/models/db/state.py | 28 ++++++++++- state-manager/app/models/signal_models.py | 14 ++++++ state-manager/app/models/state_status_enum.py | 4 +- state-manager/app/routes.py | 46 +++++++++++++++++++ 7 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 state-manager/app/controller/prune_signal.py create mode 100644 state-manager/app/controller/re_queue_after_singal.py create mode 100644 state-manager/app/models/signal_models.py diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index b27a6bef..a5c36b52 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -1,4 +1,5 @@ import asyncio +import time from ..models.enqueue_request import EnqueueRequestModel from ..models.enqueue_response import EnqueueResponseModel, StateModel @@ -18,7 +19,8 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: "status": StateStatusEnum.CREATED, "node_name": { "$in": nodes - } + }, + "enqueue_after": {"$lte": int(time.time() * 1000)} }, { "$set": {"status": StateStatusEnum.QUEUED} diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py new file mode 100644 index 00000000..50a2dd10 --- /dev/null +++ b/state-manager/app/controller/prune_signal.py @@ -0,0 +1,32 @@ +from app.models.signal_models import PruneRequestModel, SignalResponseModel +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager + +logger = LogsManager().get_logger() + +async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: PruneRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: + + try: + logger.info(f"Recieved prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + state = await State.find_one(State.id == state_id) + + if not state: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + + if state.status != StateStatusEnum.CREATED: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not created") + + state.status = StateStatusEnum.PRUNED + state.data = body.data + await state.save() + + return SignalResponseModel(status=StateStatusEnum.PRUNED, enqueue_after=state.enqueue_after) + + except Exception as e: + logger.error(f"Error pruning state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) + raise \ No newline at end of file diff --git a/state-manager/app/controller/re_queue_after_singal.py b/state-manager/app/controller/re_queue_after_singal.py new file mode 100644 index 00000000..f2af585a --- /dev/null +++ b/state-manager/app/controller/re_queue_after_singal.py @@ -0,0 +1,32 @@ +from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager + +logger = LogsManager().get_logger() + +async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, body: ReEnqueueAfterRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: + + try: + logger.info(f"Recieved re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + state = await State.find_one(State.id == state_id) + + if not state: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + + if state.status != StateStatusEnum.CREATED: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not created") + + state.status = StateStatusEnum.CREATED + state.enqueue_after = state.enqueue_after + body.enqueue_after + await state.save() + + return SignalResponseModel(status=StateStatusEnum.CREATED, enqueue_after=state.enqueue_after) + + except Exception as e: + logger.error(f"Error re-queueing state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) + raise \ No newline at end of file diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 8d77d967..b24d2c5a 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -7,7 +7,7 @@ from typing import Any, Optional import hashlib import json - +import time class State(BaseDatabaseModel): node_name: str = Field(..., description="Name of the node of the state") @@ -18,10 +18,12 @@ class State(BaseDatabaseModel): status: StateStatusEnum = Field(..., description="Status of the state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") + data: dict[str, Any] = Field(default_factory=dict, description="Data of the state") error: Optional[str] = Field(None, description="Error message") parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") does_unites: bool = Field(default=False, description="Whether this state unites other states") state_fingerprint: str = Field(default="", description="Fingerprint of the state") + enqueue_after: int = Field(default_factory=lambda: int(time.time() * 1000), description="Unix time in milliseconds after which the state should be enqueued") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): @@ -65,5 +67,29 @@ class Settings: partialFilterExpression={ "does_unites": True } + ), + IndexModel( + [ + ("enqueue_after", 1) + ], + name="idx_enqueue_after" + ), + IndexModel( + [ + ("status", 1) + ], + name="idx_status" + ), + IndexModel( + [ + ("namespace_name", 1), + ], + name="idx_namespace_name" + ), + IndexModel( + [ + ("node_name", 1), + ], + name="idx_node_name" ) ] \ No newline at end of file diff --git a/state-manager/app/models/signal_models.py b/state-manager/app/models/signal_models.py new file mode 100644 index 00000000..491797a3 --- /dev/null +++ b/state-manager/app/models/signal_models.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, Field +from .state_status_enum import StateStatusEnum +from typing import Any + + +class SignalResponseModel(BaseModel): + enqueue_after: int = Field(..., description="Unix time in milliseconds after which the state should be re-enqueued") + status: StateStatusEnum = Field(..., description="Status of the state") + +class PruneRequestModel(BaseModel): + data: dict[str, Any] = Field(..., description="Data of the state") + +class ReEnqueueAfterRequestModel(BaseModel): + enqueue_after: int = Field(..., description="Unix time in milliseconds after which the state should be re-enqueued") \ No newline at end of file diff --git a/state-manager/app/models/state_status_enum.py b/state-manager/app/models/state_status_enum.py index 8da97002..7760176d 100644 --- a/state-manager/app/models/state_status_enum.py +++ b/state-manager/app/models/state_status_enum.py @@ -6,10 +6,8 @@ class StateStatusEnum(str, Enum): CREATED = 'CREATED' QUEUED = 'QUEUED' EXECUTED = 'EXECUTED' - NEXT_CREATED = 'NEXT_CREATED' - RETRY_CREATED = 'RETRY_CREATED' - TIMEDOUT = 'TIMEDOUT' ERRORED = 'ERRORED' CANCELLED = 'CANCELLED' SUCCESS = 'SUCCESS' NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR' + PRUNED = 'PRUNED' \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 9e956db2..9b160280 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -40,6 +40,14 @@ from .models.graph_structure_models import GraphStructureResponse from .controller.get_graph_structure import get_graph_structure +### singnals +from .models.signal_models import SignalResponseModel +from .models.signal_models import PruneRequestModel +from .controller.prune_signal import prune_signal +from .models.signal_models import ReEnqueueAfterRequestModel +from .controller.re_queue_after_singal import re_queue_after_signal + + logger = LogsManager().get_logger() router = APIRouter(prefix="/v0/namespace/{namespace_name}") @@ -145,6 +153,44 @@ async def errored_state_route(namespace_name: str, state_id: str, body: ErroredR return await errored_state(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) +@router.post( + "/states/{state_id}/prune", + response_model=SignalResponseModel, + status_code=status.HTTP_200_OK, + response_description="State skipped successfully", + tags=["state"] +) +async def prune_state_route(namespace_name: str, state_id: str, body: PruneRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await prune_signal(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) + + +@router.post( + "/states/{state_id}/re-enqueue-after", + response_model=SignalResponseModel, + status_code=status.HTTP_200_OK, + response_description="State re-enqueued successfully", + tags=["state"] +) +async def re_enqueue_after_state_route(namespace_name: str, state_id: str, body: ReEnqueueAfterRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await re_queue_after_signal(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) + + @router.put( "/graph/{graph_name}", response_model=UpsertGraphTemplateResponse, From e6367aff49495f8748fdc1bea724924d83724f90 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 13:47:33 +0530 Subject: [PATCH 02/19] Fix typos in logging messages and comments - Corrected the spelling of "Recieved" to "Received" in logging statements within the prune_signal and re_queue_after_signal functions for improved clarity. - Fixed a typo in the comment header from "singnals" to "signals" in the routes.py file, enhancing code readability. --- state-manager/app/controller/prune_signal.py | 2 +- state-manager/app/controller/re_queue_after_singal.py | 2 +- state-manager/app/routes.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py index 50a2dd10..9c37d39a 100644 --- a/state-manager/app/controller/prune_signal.py +++ b/state-manager/app/controller/prune_signal.py @@ -11,7 +11,7 @@ async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: PruneRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: try: - logger.info(f"Recieved prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + logger.info(f"Received prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) state = await State.find_one(State.id == state_id) diff --git a/state-manager/app/controller/re_queue_after_singal.py b/state-manager/app/controller/re_queue_after_singal.py index f2af585a..c67db7f0 100644 --- a/state-manager/app/controller/re_queue_after_singal.py +++ b/state-manager/app/controller/re_queue_after_singal.py @@ -11,7 +11,7 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, body: ReEnqueueAfterRequestModel, x_exosphere_request_id: str) -> SignalResponseModel: try: - logger.info(f"Recieved re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + logger.info(f"Received re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) state = await State.find_one(State.id == state_id) diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 9b160280..01af92a7 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -40,7 +40,7 @@ from .models.graph_structure_models import GraphStructureResponse from .controller.get_graph_structure import get_graph_structure -### singnals +### signals from .models.signal_models import SignalResponseModel from .models.signal_models import PruneRequestModel from .controller.prune_signal import prune_signal From bd0a89800db9f8e36bed63f4fb237c635c2470f0 Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Sat, 30 Aug 2025 13:48:48 +0530 Subject: [PATCH 03/19] Update state-manager/app/models/signal_models.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/models/signal_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/models/signal_models.py b/state-manager/app/models/signal_models.py index 491797a3..3e86d53a 100644 --- a/state-manager/app/models/signal_models.py +++ b/state-manager/app/models/signal_models.py @@ -11,4 +11,4 @@ class PruneRequestModel(BaseModel): data: dict[str, Any] = Field(..., description="Data of the state") class ReEnqueueAfterRequestModel(BaseModel): - enqueue_after: int = Field(..., description="Unix time in milliseconds after which the state should be re-enqueued") \ No newline at end of file + enqueue_after: int = Field(..., description="Duration in milliseconds to delay the re-enqueuing of the state") \ No newline at end of file From fc9c14c0f380cd83e73daa1b93a1207c03f98aa6 Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Sat, 30 Aug 2025 13:52:02 +0530 Subject: [PATCH 04/19] Update state-manager/app/controller/re_queue_after_singal.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/controller/re_queue_after_singal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/controller/re_queue_after_singal.py b/state-manager/app/controller/re_queue_after_singal.py index c67db7f0..5b7f43a6 100644 --- a/state-manager/app/controller/re_queue_after_singal.py +++ b/state-manager/app/controller/re_queue_after_singal.py @@ -22,7 +22,7 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not created") state.status = StateStatusEnum.CREATED - state.enqueue_after = state.enqueue_after + body.enqueue_after + state.enqueue_after = int(time.time() * 1000) + body.enqueue_after await state.save() return SignalResponseModel(status=StateStatusEnum.CREATED, enqueue_after=state.enqueue_after) From 4cc63feeafb018f0647f7b00b81fae03f271004d Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Sat, 30 Aug 2025 13:52:29 +0530 Subject: [PATCH 05/19] Update state-manager/app/controller/re_queue_after_singal.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/controller/re_queue_after_singal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/controller/re_queue_after_singal.py b/state-manager/app/controller/re_queue_after_singal.py index 5b7f43a6..33412ac1 100644 --- a/state-manager/app/controller/re_queue_after_singal.py +++ b/state-manager/app/controller/re_queue_after_singal.py @@ -13,7 +13,7 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, try: logger.info(f"Received re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - state = await State.find_one(State.id == state_id) + state = await State.find_one(State.id == state_id, State.namespace_name == namespace_name) if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") From b6127ffbd6a6e37caccf2aabe98429fa57abece1 Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Sat, 30 Aug 2025 13:52:35 +0530 Subject: [PATCH 06/19] Update state-manager/app/controller/prune_signal.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/controller/prune_signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py index 9c37d39a..14835290 100644 --- a/state-manager/app/controller/prune_signal.py +++ b/state-manager/app/controller/prune_signal.py @@ -13,7 +13,7 @@ async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: Pr try: logger.info(f"Received prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - state = await State.find_one(State.id == state_id) + state = await State.find_one(State.id == state_id, State.namespace_name == namespace_name) if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") From 87870b04c6407f5b0e151d77da779adb1a89e57b Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 13:54:05 +0530 Subject: [PATCH 07/19] Refactor re-queue after signal functionality and update state model - Corrected the import path for the re_queue_after_signal controller in routes.py. - Introduced a new controller function in re_queue_after_signal.py to handle re-queuing of states, including error handling and logging. - Updated the State model to consolidate indexing for enqueue_after and status fields, improving database query performance. --- ...ter_singal.py => re_queue_after_signal.py} | 5 +---- state-manager/app/models/db/state.py | 21 +++---------------- state-manager/app/routes.py | 2 +- 3 files changed, 5 insertions(+), 23 deletions(-) rename state-manager/app/controller/{re_queue_after_singal.py => re_queue_after_signal.py} (88%) diff --git a/state-manager/app/controller/re_queue_after_singal.py b/state-manager/app/controller/re_queue_after_signal.py similarity index 88% rename from state-manager/app/controller/re_queue_after_singal.py rename to state-manager/app/controller/re_queue_after_signal.py index c67db7f0..984fe96d 100644 --- a/state-manager/app/controller/re_queue_after_singal.py +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -17,10 +17,7 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") - - if state.status != StateStatusEnum.CREATED: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not created") - + state.status = StateStatusEnum.CREATED state.enqueue_after = state.enqueue_after + body.enqueue_after await state.save() diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index b24d2c5a..6651b9c0 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -70,26 +70,11 @@ class Settings: ), IndexModel( [ - ("enqueue_after", 1) - ], - name="idx_enqueue_after" - ), - IndexModel( - [ - ("status", 1) - ], - name="idx_status" - ), - IndexModel( - [ + ("enqueue_after", 1), + ("status", 1), ("namespace_name", 1), - ], - name="idx_namespace_name" - ), - IndexModel( - [ ("node_name", 1), ], - name="idx_node_name" + name="idx_enqueue_after" ) ] \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 01af92a7..a108c77c 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -45,7 +45,7 @@ from .models.signal_models import PruneRequestModel from .controller.prune_signal import prune_signal from .models.signal_models import ReEnqueueAfterRequestModel -from .controller.re_queue_after_singal import re_queue_after_signal +from .controller.re_queue_after_signal import re_queue_after_signal logger = LogsManager().get_logger() From c689a8eda5e0d2d68e4eb2a23571f30f48bd70d6 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 13:55:22 +0530 Subject: [PATCH 08/19] Added import for time module in re_queue_after_signal.py to support time-related functionality. --- state-manager/app/controller/re_queue_after_signal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/state-manager/app/controller/re_queue_after_signal.py b/state-manager/app/controller/re_queue_after_signal.py index 8931bf6c..57eb2256 100644 --- a/state-manager/app/controller/re_queue_after_signal.py +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -1,6 +1,7 @@ from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel from fastapi import HTTPException, status from beanie import PydanticObjectId +import time from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum From 55a47544830f5ca5c0a1efa5cc552ce8e6f292e8 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 14:09:11 +0530 Subject: [PATCH 09/19] Add unit tests for prune and re-enqueue signal functionality - Introduced comprehensive tests for the new prune and re-enqueue signal routes, validating request models and response handling. - Implemented model validation tests for PruneRequestModel and ReEnqueueAfterRequestModel, ensuring correct data handling. - Added controller tests for prune_signal and re_queue_after_signal functions, covering success and error scenarios. - Enhanced SignalResponseModel tests to verify correct serialization and deserialization of responses. - Improved overall test coverage for state management operations, ensuring robustness and reliability. --- .../unit/controller/test_prune_signal.py | 319 +++++++++++++++++ .../controller/test_re_queue_after_signal.py | 325 ++++++++++++++++++ .../tests/unit/models/test_signal_models.py | 279 +++++++++++++++ state-manager/tests/unit/test_routes.py | 240 ++++++++++++- 4 files changed, 1160 insertions(+), 3 deletions(-) create mode 100644 state-manager/tests/unit/controller/test_prune_signal.py create mode 100644 state-manager/tests/unit/controller/test_re_queue_after_signal.py create mode 100644 state-manager/tests/unit/models/test_signal_models.py diff --git a/state-manager/tests/unit/controller/test_prune_signal.py b/state-manager/tests/unit/controller/test_prune_signal.py new file mode 100644 index 00000000..17011637 --- /dev/null +++ b/state-manager/tests/unit/controller/test_prune_signal.py @@ -0,0 +1,319 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.controller.prune_signal import prune_signal +from app.models.signal_models import PruneRequestModel, SignalResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestPruneSignal: + """Test cases for prune_signal function""" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_state_id(self): + return PydanticObjectId() + + @pytest.fixture + def mock_prune_request(self): + return PruneRequestModel( + data={"key": "value", "nested": {"data": "test"}} + ) + + @pytest.fixture + def mock_state_created(self): + state = MagicMock() + state.id = PydanticObjectId() + state.status = StateStatusEnum.CREATED + state.enqueue_after = 1234567890 + return state + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_success( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_state_created, + mock_request_id + ): + """Test successful pruning of state""" + # Arrange + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert result.enqueue_after == 1234567890 + assert mock_state_created.status == StateStatusEnum.PRUNED + assert mock_state_created.data == mock_prune_request.data + assert mock_state_created.save.call_count == 1 + assert mock_state_class.find_one.call_count == 1 + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_state_not_found( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is not found""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_queued( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in QUEUED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.QUEUED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not created" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_executed( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in EXECUTED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.EXECUTED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not created" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_errored( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is in ERRORED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.ERRORED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not created" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_invalid_status_pruned( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test when state is already in PRUNED status (invalid for pruning)""" + # Arrange + mock_state = MagicMock() + mock_state.status = StateStatusEnum.PRUNED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not created" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_database_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_state_class.find_one = MagicMock(side_effect=Exception("Database error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert str(exc_info.value) == "Database error" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_save_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_prune_request, + mock_state_created, + mock_request_id + ): + """Test handling of save errors""" + # Arrange + mock_state_created.save = AsyncMock(side_effect=Exception("Save error")) + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await prune_signal( + mock_namespace, + mock_state_id, + mock_prune_request, + mock_request_id + ) + + assert str(exc_info.value) == "Save error" + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_with_empty_data( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_created, + mock_request_id + ): + """Test pruning with empty data""" + # Arrange + prune_request = PruneRequestModel(data={}) + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert mock_state_created.data == {} + assert mock_state_created.save.call_count == 1 + + @patch('app.controller.prune_signal.State') + async def test_prune_signal_with_complex_data( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_created, + mock_request_id + ): + """Test pruning with complex nested data""" + # Arrange + complex_data = { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3], + "nested": { + "object": { + "deep": "value" + } + } + } + prune_request = PruneRequestModel(data=complex_data) + mock_state_created.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_created) + + # Act + result = await prune_signal( + mock_namespace, + mock_state_id, + prune_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.PRUNED + assert mock_state_created.data == complex_data + assert mock_state_created.save.call_count == 1 \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_re_queue_after_signal.py b/state-manager/tests/unit/controller/test_re_queue_after_signal.py new file mode 100644 index 00000000..fb9f6c84 --- /dev/null +++ b/state-manager/tests/unit/controller/test_re_queue_after_signal.py @@ -0,0 +1,325 @@ +import pytest +import time +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException, status +from beanie import PydanticObjectId + +from app.controller.re_queue_after_signal import re_queue_after_signal +from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestReQueueAfterSignal: + """Test cases for re_queue_after_signal function""" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_state_id(self): + return PydanticObjectId() + + @pytest.fixture + def mock_re_enqueue_request(self): + return ReEnqueueAfterRequestModel( + enqueue_after=5000 # 5 seconds in milliseconds + ) + + @pytest.fixture + def mock_state_any_status(self): + state = MagicMock() + state.id = PydanticObjectId() + state.status = StateStatusEnum.QUEUED # Any status is valid for re-enqueue + state.enqueue_after = 1234567890 + return state + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_success( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test successful re-enqueuing of state""" + # Arrange + mock_time.time.return_value = 1000.0 # Mock current time + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 1005000 # 1000 * 1000 + 5000 + assert mock_state_any_status.status == StateStatusEnum.CREATED + assert mock_state_any_status.enqueue_after == 1005000 + assert mock_state_any_status.save.call_count == 1 + assert mock_state_class.find_one.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + async def test_re_queue_after_signal_state_not_found( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test when state is not found""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_zero_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with zero delay""" + # Arrange + mock_time.time.return_value = 1000.0 + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=0) + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 1000000 # 1000 * 1000 + 0 + assert mock_state_any_status.enqueue_after == 1000000 + assert mock_state_any_status.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_large_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with large delay""" + # Arrange + mock_time.time.return_value = 1000.0 + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=86400000) # 24 hours + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 87400000 # 1000 * 1000 + 86400000 + assert mock_state_any_status.enqueue_after == 87400000 + assert mock_state_any_status.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_with_negative_delay( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state_any_status, + mock_request_id + ): + """Test re-enqueuing with negative delay (should still work)""" + # Arrange + mock_time.time.return_value = 1000.0 + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=-5000) # Negative delay + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert result.enqueue_after == 995000 # 1000 * 1000 + (-5000) + assert mock_state_any_status.enqueue_after == 995000 + assert mock_state_any_status.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + async def test_re_queue_after_signal_database_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_state_class.find_one = MagicMock(side_effect=Exception("Database error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert str(exc_info.value) == "Database error" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_save_error( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test handling of save errors""" + # Arrange + mock_time.time.return_value = 1000.0 + mock_state_any_status.save = AsyncMock(side_effect=Exception("Save error")) + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + assert str(exc_info.value) == "Save error" + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_from_different_statuses( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ): + """Test re-enqueuing from different initial statuses""" + # Arrange + mock_time.time.return_value = 1000.0 + + test_cases = [ + StateStatusEnum.CREATED, + StateStatusEnum.QUEUED, + StateStatusEnum.EXECUTED, + StateStatusEnum.ERRORED, + StateStatusEnum.CANCELLED, + StateStatusEnum.SUCCESS, + StateStatusEnum.NEXT_CREATED_ERROR, + StateStatusEnum.PRUNED + ] + + for initial_status in test_cases: + # Arrange for this test case + mock_state = MagicMock() + mock_state.id = PydanticObjectId() + mock_state.status = initial_status + mock_state.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + assert result.status == StateStatusEnum.CREATED + assert mock_state.status == StateStatusEnum.CREATED + assert mock_state.save.call_count == 1 + + @patch('app.controller.re_queue_after_signal.State') + @patch('app.controller.re_queue_after_signal.time') + async def test_re_queue_after_signal_time_precision( + self, + mock_time, + mock_state_class, + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_state_any_status, + mock_request_id + ): + """Test that time calculation is precise""" + # Arrange + mock_time.time.return_value = 1234.567 # Test with fractional seconds + mock_state_any_status.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) + + # Act + result = await re_queue_after_signal( + mock_namespace, + mock_state_id, + mock_re_enqueue_request, + mock_request_id + ) + + # Assert + expected_enqueue_after = int(1234.567 * 1000) + 5000 + assert result.enqueue_after == expected_enqueue_after + assert mock_state_any_status.enqueue_after == expected_enqueue_after \ No newline at end of file diff --git a/state-manager/tests/unit/models/test_signal_models.py b/state-manager/tests/unit/models/test_signal_models.py new file mode 100644 index 00000000..fd14636d --- /dev/null +++ b/state-manager/tests/unit/models/test_signal_models.py @@ -0,0 +1,279 @@ +import pytest +from pydantic import ValidationError + +from app.models.signal_models import PruneRequestModel, ReEnqueueAfterRequestModel, SignalResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestPruneRequestModel: + """Test cases for PruneRequestModel""" + + def test_prune_request_model_valid_data(self): + """Test PruneRequestModel with valid data""" + # Arrange & Act + data = {"key": "value", "nested": {"data": "test"}} + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_empty_data(self): + """Test PruneRequestModel with empty data""" + # Arrange & Act + data = {} + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_complex_data(self): + """Test PruneRequestModel with complex nested data""" + # Arrange & Act + data = { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3], + "nested": { + "object": { + "deep": "value" + } + } + } + model = PruneRequestModel(data=data) + + # Assert + assert model.data == data + + def test_prune_request_model_missing_data(self): + """Test PruneRequestModel with missing data field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + PruneRequestModel() # type: ignore + + assert "data" in str(exc_info.value) + + def test_prune_request_model_none_data(self): + """Test PruneRequestModel with None data""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + PruneRequestModel(data=None) # type: ignore + + assert "data" in str(exc_info.value) + + +class TestReEnqueueAfterRequestModel: + """Test cases for ReEnqueueAfterRequestModel""" + + def test_re_enqueue_after_request_model_valid_delay(self): + """Test ReEnqueueAfterRequestModel with valid delay""" + # Arrange & Act + delay = 5000 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_zero_delay(self): + """Test ReEnqueueAfterRequestModel with zero delay""" + # Arrange & Act + delay = 0 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_negative_delay(self): + """Test ReEnqueueAfterRequestModel with negative delay""" + # Arrange & Act + delay = -5000 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_large_delay(self): + """Test ReEnqueueAfterRequestModel with large delay""" + # Arrange & Act + delay = 86400000 # 24 hours + model = ReEnqueueAfterRequestModel(enqueue_after=delay) + + # Assert + assert model.enqueue_after == delay + + def test_re_enqueue_after_request_model_missing_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with missing enqueue_after field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ReEnqueueAfterRequestModel() # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_re_enqueue_after_request_model_none_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with None enqueue_after""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ReEnqueueAfterRequestModel(enqueue_after=None) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_re_enqueue_after_request_model_string_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with string enqueue_after (should convert)""" + # Arrange & Act + delay = "5000" + model = ReEnqueueAfterRequestModel(enqueue_after=delay) # type: ignore + + # Assert + assert model.enqueue_after == 5000 + + def test_re_enqueue_after_request_model_float_enqueue_after(self): + """Test ReEnqueueAfterRequestModel with float enqueue_after (should convert)""" + # Arrange & Act + delay = 5000.0 + model = ReEnqueueAfterRequestModel(enqueue_after=delay) # type: ignore + + # Assert + assert model.enqueue_after == 5000 + + +class TestSignalResponseModel: + """Test cases for SignalResponseModel""" + + def test_signal_response_model_valid_data(self): + """Test SignalResponseModel with valid data""" + # Arrange & Act + enqueue_after = 1234567890 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_created_status(self): + """Test SignalResponseModel with CREATED status""" + # Arrange & Act + enqueue_after = 1234567890 + status = StateStatusEnum.CREATED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_zero_enqueue_after(self): + """Test SignalResponseModel with zero enqueue_after""" + # Arrange & Act + enqueue_after = 0 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_large_enqueue_after(self): + """Test SignalResponseModel with large enqueue_after""" + # Arrange & Act + enqueue_after = 9999999999999 + status = StateStatusEnum.CREATED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_missing_enqueue_after(self): + """Test SignalResponseModel with missing enqueue_after field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(status=StateStatusEnum.PRUNED) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_signal_response_model_missing_status(self): + """Test SignalResponseModel with missing status field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=1234567890) # type: ignore + + assert "status" in str(exc_info.value) + + def test_signal_response_model_none_enqueue_after(self): + """Test SignalResponseModel with None enqueue_after""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=None, status=StateStatusEnum.PRUNED) # type: ignore + + assert "enqueue_after" in str(exc_info.value) + + def test_signal_response_model_none_status(self): + """Test SignalResponseModel with None status""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + SignalResponseModel(enqueue_after=1234567890, status=None) # type: ignore + + assert "status" in str(exc_info.value) + + def test_signal_response_model_string_enqueue_after(self): + """Test SignalResponseModel with string enqueue_after (should convert)""" + # Arrange & Act + enqueue_after = "1234567890" + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) # type: ignore + + # Assert + assert model.enqueue_after == 1234567890 + assert model.status == status + + def test_signal_response_model_all_status_enum_values(self): + """Test SignalResponseModel with all possible status enum values""" + # Arrange + enqueue_after = 1234567890 + all_statuses = [ + StateStatusEnum.CREATED, + StateStatusEnum.QUEUED, + StateStatusEnum.EXECUTED, + StateStatusEnum.ERRORED, + StateStatusEnum.CANCELLED, + StateStatusEnum.SUCCESS, + StateStatusEnum.NEXT_CREATED_ERROR, + StateStatusEnum.PRUNED + ] + + for status in all_statuses: + # Act + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Assert + assert model.enqueue_after == enqueue_after + assert model.status == status + + def test_signal_response_model_json_serialization(self): + """Test SignalResponseModel JSON serialization""" + # Arrange + enqueue_after = 1234567890 + status = StateStatusEnum.PRUNED + model = SignalResponseModel(enqueue_after=enqueue_after, status=status) + + # Act + json_data = model.model_dump() + + # Assert + assert json_data["enqueue_after"] == enqueue_after + assert json_data["status"] == status.value + + def test_signal_response_model_json_deserialization(self): + """Test SignalResponseModel JSON deserialization""" + # Arrange + json_data = { + "enqueue_after": 1234567890, + "status": "PRUNED" + } + + # Act + model = SignalResponseModel(**json_data) + + # Assert + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.PRUNED \ No newline at end of file diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 823c47d8..92c9ca28 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -29,6 +29,8 @@ def test_router_has_correct_routes(self): assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/states/create' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/executed' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/errored' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/states/{state_id}/prune' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/states/{state_id}/re-enqueue-after' in path for path in paths) # Graph template routes (there are two /graph/{graph_name} routes - GET and PUT) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}' in path for path in paths) @@ -109,6 +111,84 @@ def test_create_request_model_validation(self): assert len(model.states) == 1 assert model.states[0].identifier == "node1" + def test_prune_request_model_validation(self): + """Test PruneRequestModel validation""" + from app.models.signal_models import PruneRequestModel + + # Test with valid data + valid_data = { + "data": {"key": "value", "nested": {"data": "test"}} + } + model = PruneRequestModel(**valid_data) + assert model.data == {"key": "value", "nested": {"data": "test"}} + + # Test with empty data + empty_data = {"data": {}} + model = PruneRequestModel(**empty_data) + assert model.data == {} + + # Test with complex data + complex_data = { + "data": { + "string": "test", + "number": 42, + "boolean": True, + "list": [1, 2, 3] + } + } + model = PruneRequestModel(**complex_data) + assert model.data["string"] == "test" + assert model.data["number"] == 42 + assert model.data["boolean"] is True + assert model.data["list"] == [1, 2, 3] + + def test_re_enqueue_after_request_model_validation(self): + """Test ReEnqueueAfterRequestModel validation""" + from app.models.signal_models import ReEnqueueAfterRequestModel + + # Test with valid data + valid_data = {"enqueue_after": 5000} + model = ReEnqueueAfterRequestModel(**valid_data) + assert model.enqueue_after == 5000 + + # Test with zero delay + zero_data = {"enqueue_after": 0} + model = ReEnqueueAfterRequestModel(**zero_data) + assert model.enqueue_after == 0 + + # Test with negative delay + negative_data = {"enqueue_after": -5000} + model = ReEnqueueAfterRequestModel(**negative_data) + assert model.enqueue_after == -5000 + + # Test with large delay + large_data = {"enqueue_after": 86400000} + model = ReEnqueueAfterRequestModel(**large_data) + assert model.enqueue_after == 86400000 + + def test_signal_response_model_validation(self): + """Test SignalResponseModel validation""" + from app.models.signal_models import SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + + # Test with valid data + valid_data = { + "enqueue_after": 1234567890, + "status": "PRUNED" + } + model = SignalResponseModel(**valid_data) + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.PRUNED + + # Test with CREATED status + created_data = { + "enqueue_after": 1234567890, + "status": "CREATED" + } + model = SignalResponseModel(**created_data) + assert model.enqueue_after == 1234567890 + assert model.status == StateStatusEnum.CREATED + def test_executed_request_model_validation(self): """Test ExecutedRequestModel validation""" # Test with valid data @@ -331,7 +411,7 @@ async def test_enqueue_state_with_invalid_api_key(self, mock_enqueue_states, moc # Act & Assert with pytest.raises(HTTPException) as exc_info: - await enqueue_state("test_namespace", body, mock_request, None) + await enqueue_state("test_namespace", body, mock_request, None) # type: ignore assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" @@ -385,7 +465,7 @@ async def test_trigger_graph_route_with_invalid_api_key(self, mock_trigger_graph # Act & Assert with pytest.raises(HTTPException) as exc_info: - await trigger_graph_route("test_namespace", "test_graph", body, mock_request, None) + await trigger_graph_route("test_namespace", "test_graph", body, mock_request, None) # type: ignore assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" @@ -609,4 +689,158 @@ async def test_get_states_by_run_id_route_with_valid_api_key(self, mock_get_stat assert result.namespace == "test_namespace" assert result.run_id == "test_run" assert result.count == 1 - assert len(result.states) == 1 \ No newline at end of file + assert len(result.states) == 1 + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_valid_api_key(self, mock_prune_signal, mock_request): + """Test prune_state_route with valid API key""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data={"key": "value"}) + expected_response = SignalResponseModel( + status=StateStatusEnum.PRUNED, + enqueue_after=1234567890 + ) + mock_prune_signal.return_value = expected_response + + # Act + result = await prune_state_route("test_namespace", state_id, prune_request, mock_request, "valid_key") + + # Assert + mock_prune_signal.assert_called_once_with("test_namespace", PydanticObjectId(state_id), prune_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_invalid_api_key(self, mock_prune_signal, mock_request): + """Test prune_state_route with invalid API key""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel + from fastapi import HTTPException, status + + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data={"key": "value"}) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await prune_state_route("test_namespace", state_id, prune_request, mock_request, None) # type: ignore + + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + mock_prune_signal.assert_not_called() + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_valid_api_key(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with valid API key""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=5000) + expected_response = SignalResponseModel( + status=StateStatusEnum.CREATED, + enqueue_after=1234567890 + ) + mock_re_queue_after_signal.return_value = expected_response + + # Act + result = await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, "valid_key") + + # Assert + mock_re_queue_after_signal.assert_called_once_with("test_namespace", PydanticObjectId(state_id), re_enqueue_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_invalid_api_key(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with invalid API key""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel + from fastapi import HTTPException, status + + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=5000) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, None) # type: ignore + + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + mock_re_queue_after_signal.assert_not_called() + + @patch('app.routes.prune_signal') + async def test_prune_state_route_with_different_data(self, mock_prune_signal, mock_request): + """Test prune_state_route with different data payloads""" + from app.routes import prune_state_route + from app.models.signal_models import PruneRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Test cases with different data + test_cases = [ + {"simple": "value"}, + {"nested": {"data": "test"}}, + {"list": [1, 2, 3]}, + {"boolean": True, "number": 42}, + {} # Empty data + ] + + for test_data in test_cases: + # Arrange + state_id = "507f1f77bcf86cd799439011" + prune_request = PruneRequestModel(data=test_data) + expected_response = SignalResponseModel( + status=StateStatusEnum.PRUNED, + enqueue_after=1234567890 + ) + mock_prune_signal.return_value = expected_response + + # Act + result = await prune_state_route("test_namespace", state_id, prune_request, mock_request, "valid_key") + + # Assert + mock_prune_signal.assert_called_with("test_namespace", PydanticObjectId(state_id), prune_request, "test-request-id") + assert result == expected_response + + @patch('app.routes.re_queue_after_signal') + async def test_re_enqueue_after_state_route_with_different_delays(self, mock_re_queue_after_signal, mock_request): + """Test re_enqueue_after_state_route with different delay values""" + from app.routes import re_enqueue_after_state_route + from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel + from app.models.state_status_enum import StateStatusEnum + from beanie import PydanticObjectId + + # Test cases with different delays + test_cases = [ + 0, # No delay + 1000, # 1 second + 60000, # 1 minute + 3600000, # 1 hour + -5000 # Negative delay + ] + + for delay in test_cases: + # Arrange + state_id = "507f1f77bcf86cd799439011" + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=delay) + expected_response = SignalResponseModel( + status=StateStatusEnum.CREATED, + enqueue_after=1234567890 + ) + mock_re_queue_after_signal.return_value = expected_response + + # Act + result = await re_enqueue_after_state_route("test_namespace", state_id, re_enqueue_request, mock_request, "valid_key") + + # Assert + mock_re_queue_after_signal.assert_called_with("test_namespace", PydanticObjectId(state_id), re_enqueue_request, "test-request-id") + assert result == expected_response \ No newline at end of file From e5d61d89b7a11d0ff30bf5d8879c23da79c9de91 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 14:09:56 +0530 Subject: [PATCH 10/19] Refactor test imports for prune and re-enqueue signal unit tests - Removed unused SignalResponseModel import from test files for prune_signal and re_queue_after_signal. - Streamlined test dependencies to improve clarity and maintainability. --- state-manager/tests/unit/controller/test_prune_signal.py | 2 +- .../tests/unit/controller/test_re_queue_after_signal.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/state-manager/tests/unit/controller/test_prune_signal.py b/state-manager/tests/unit/controller/test_prune_signal.py index 17011637..66349b26 100644 --- a/state-manager/tests/unit/controller/test_prune_signal.py +++ b/state-manager/tests/unit/controller/test_prune_signal.py @@ -4,7 +4,7 @@ from beanie import PydanticObjectId from app.controller.prune_signal import prune_signal -from app.models.signal_models import PruneRequestModel, SignalResponseModel +from app.models.signal_models import PruneRequestModel from app.models.state_status_enum import StateStatusEnum diff --git a/state-manager/tests/unit/controller/test_re_queue_after_signal.py b/state-manager/tests/unit/controller/test_re_queue_after_signal.py index fb9f6c84..fc437b58 100644 --- a/state-manager/tests/unit/controller/test_re_queue_after_signal.py +++ b/state-manager/tests/unit/controller/test_re_queue_after_signal.py @@ -1,11 +1,10 @@ import pytest -import time from unittest.mock import AsyncMock, MagicMock, patch from fastapi import HTTPException, status from beanie import PydanticObjectId from app.controller.re_queue_after_signal import re_queue_after_signal -from app.models.signal_models import ReEnqueueAfterRequestModel, SignalResponseModel +from app.models.signal_models import ReEnqueueAfterRequestModel from app.models.state_status_enum import StateStatusEnum From d84b6abe6cde274da7e62316baef3a62a273bf68 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 15:06:34 +0530 Subject: [PATCH 11/19] Implement prune and requeue signal functionality - Added new signals: PruneSingal and ReQueueAfterSingal to handle state management operations. - Updated Runtime class to include methods for constructing endpoints for pruning and requeuing states. - Enhanced error handling in the Runtime class to manage new signal exceptions. - Bumped version to 0.0.2b2 to reflect the addition of new features. --- python-sdk/exospherehost/__init__.py | 3 +- python-sdk/exospherehost/_version.py | 2 +- python-sdk/exospherehost/runtime.py | 23 ++++++++++ python-sdk/exospherehost/signals.py | 67 ++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 python-sdk/exospherehost/signals.py diff --git a/python-sdk/exospherehost/__init__.py b/python-sdk/exospherehost/__init__.py index 777181b9..29ef3bab 100644 --- a/python-sdk/exospherehost/__init__.py +++ b/python-sdk/exospherehost/__init__.py @@ -38,7 +38,8 @@ async def execute(self, inputs: Inputs) -> Outputs: from .runtime import Runtime from .node.BaseNode import BaseNode from .statemanager import StateManager, TriggerState +from .signals import PruneSingal, ReQueueAfterSingal VERSION = __version__ -__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION"] +__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSingal", "ReQueueAfterSingal"] diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index a5ed84bc..9a836ee0 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.2b1" +version = "0.0.2b2" diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 89786c58..94c74b7e 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from .node.BaseNode import BaseNode from aiohttp import ClientSession +from .signals import PruneSingal, ReQueueAfterSingal logger = logging.getLogger(__name__) @@ -159,6 +160,18 @@ def _get_secrets_endpoint(self, state_id: str): Construct the endpoint URL for getting secrets. """ return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/secrets" + + def _get_prune_endpoint(self, state_id: str): + """ + Construct the endpoint URL for pruning a state. + """ + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/prune" + + def _get_requeue_after_endpoint(self, state_id: str): + """ + Construct the endpoint URL for requeuing a state after a timedelta. + """ + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/re-enqueue-after" async def _register(self): """ @@ -395,6 +408,16 @@ async def _worker(self, idx: int): await self._notify_executed(state["state_id"], outputs) logger.info(f"Notified executed state {state['state_id']} for node {node.__name__ if node else "unknown"}") + + except PruneSingal as prune_signal: + logger.info(f"Pruning state {state['state_id']} for node {node.__name__ if node else "unknown"}") + await prune_signal.send(self._get_prune_endpoint(state["state_id"]), self._key) # type: ignore + logger.info(f"Pruned state {state['state_id']} for node {node.__name__ if node else "unknown"}") + + except ReQueueAfterSingal as requeue_signal: + logger.info(f"Requeuing state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.timedelta}") + await requeue_signal.send(self._get_requeue_after_endpoint(state["state_id"]), self._key) # type: ignore + logger.info(f"Requeued state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.timedelta}") except Exception as e: logger.error(f"Error executing state {state['state_id']} for node {node.__name__ if node else "unknown"}: {e}") diff --git a/python-sdk/exospherehost/signals.py b/python-sdk/exospherehost/signals.py new file mode 100644 index 00000000..44f21fa4 --- /dev/null +++ b/python-sdk/exospherehost/signals.py @@ -0,0 +1,67 @@ +from typing import Any +from aiohttp import ClientSession +from datetime import timedelta + +class PruneSingal(Exception): + """ + Exception used to signal that a prune operation should be performed. + + Args: + data (dict[str, Any], optional): Additional data to include with the signal. Defaults to {}. + + Note: + Do not catch this Exception, let it bubble up to Runtime for handling at StateManager. + """ + def __init__(self, data: dict[str, Any] = {}): + self.data = data + super().__init__(f"Prune signal received with data: {data} \n NOTE: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager") + + async def send(self, endpoint: str, key: str): + """ + Sends the prune signal to the specified endpoint. + + Args: + endpoint (str): The URL to send the signal to. + key (str): The API key to include in the request headers. + + Raises: + Exception: If the HTTP request fails (status code != 200). + """ + async with ClientSession() as session: + async with session.post(endpoint, json=self.data, headers={"x-api-key": key}) as response: + if response.status != 200: + raise Exception(f"Failed to send prune signal to {endpoint}") + + +class ReQueueAfterSingal(Exception): + """ + Exception used to signal that a requeue operation should be performed after a specified timedelta. + + Args: + timedelta (timedelta): The amount of time to wait before requeuing. + + Note: + Do not catch this Exception, let it bubble up to Runtime for handling at StateManager. + """ + def __init__(self, timedelta: timedelta): + self.timedelta = timedelta + super().__init__(f"ReQueueAfter signal received with timedelta: {timedelta} \n NOTE: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager") + + async def send(self, endpoint: str, key: str): + """ + Sends the requeue-after signal to the specified endpoint. + + Args: + endpoint (str): The URL to send the signal to. + key (str): The API key to include in the request headers. + + Raises: + Exception: If the HTTP request fails (status code != 200). + """ + body = { + "enqueue_after": int(self.timedelta.total_seconds() * 1000) + } + async with ClientSession() as session: + async with session.post(endpoint, json=body, headers={"x-api-key": key}) as response: + if response.status != 200: + raise Exception(f"Failed to send requeue after signal to {endpoint}") From ced39b9a5d906a07ae2dcbf06c287ddae3bbefb8 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 15:39:18 +0530 Subject: [PATCH 12/19] Add tests for PruneSingal and ReQueueAfterSingal functionality - Updated test_package_init.py to include new expected exports for PruneSingal and ReQueueAfterSingal. - Created test_signals_and_runtime_functions.py to implement comprehensive unit tests for PruneSingal and ReQueueAfterSingal, covering initialization, sending, and error handling. - Enhanced test coverage for signal handling in the Runtime class, ensuring robust integration with state management operations. --- python-sdk/tests/test_package_init.py | 2 +- .../test_signals_and_runtime_functions.py | 723 ++++++++++++++++++ 2 files changed, 724 insertions(+), 1 deletion(-) create mode 100644 python-sdk/tests/test_signals_and_runtime_functions.py diff --git a/python-sdk/tests/test_package_init.py b/python-sdk/tests/test_package_init.py index 90bbabf7..926ef43e 100644 --- a/python-sdk/tests/test_package_init.py +++ b/python-sdk/tests/test_package_init.py @@ -15,7 +15,7 @@ def test_package_all_imports(): """Test that __all__ contains all expected exports.""" from exospherehost import __all__ - expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION"] + expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSingal", "ReQueueAfterSingal"] for export in expected_exports: assert export in __all__, f"{export} should be in __all__" diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py new file mode 100644 index 00000000..cbea3289 --- /dev/null +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -0,0 +1,723 @@ +import pytest +import logging +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import timedelta +from pydantic import BaseModel +from exospherehost.signals import PruneSingal, ReQueueAfterSingal +from exospherehost.runtime import Runtime, _setup_default_logging +from exospherehost.node.BaseNode import BaseNode + + +def create_mock_aiohttp_session(): + """Helper function to create a properly mocked aiohttp session.""" + mock_session = AsyncMock() + + # Create mock response objects + mock_post_response = MagicMock() + mock_get_response = MagicMock() + mock_put_response = MagicMock() + + # Create mock context managers for each method + mock_post_context = MagicMock() + mock_post_context.__aenter__ = AsyncMock(return_value=mock_post_response) + mock_post_context.__aexit__ = AsyncMock(return_value=None) + + mock_get_context = MagicMock() + mock_get_context.__aenter__ = AsyncMock(return_value=mock_get_response) + mock_get_context.__aexit__ = AsyncMock(return_value=None) + + mock_put_context = MagicMock() + mock_put_context.__aenter__ = AsyncMock(return_value=mock_put_response) + mock_put_context.__aexit__ = AsyncMock(return_value=None) + + # Set up the session methods to return the context managers using MagicMock + mock_session.post = MagicMock(return_value=mock_post_context) + mock_session.get = MagicMock(return_value=mock_get_context) + mock_session.put = MagicMock(return_value=mock_put_context) + + # Set up session context manager + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + return mock_session, mock_post_response, mock_get_response, mock_put_response + + +class MockTestNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + api_key: str + + async def execute(self): + return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore + + +class TestPruneSingal: + """Test cases for PruneSingal exception class.""" + + def test_prune_signal_initialization_with_data(self): + """Test PruneSingal initialization with custom data.""" + data = {"reason": "test", "custom_field": "value"} + signal = PruneSingal(data) + + assert signal.data == data + assert "Prune signal received with data" in str(signal) + assert "Do not catch this Exception" in str(signal) + + def test_prune_signal_initialization_without_data(self): + """Test PruneSingal initialization without data (default empty dict).""" + signal = PruneSingal() + + assert signal.data == {} + assert "Prune signal received with data" in str(signal) + + def test_prune_signal_inheritance(self): + """Test that PruneSingal properly inherits from Exception.""" + signal = PruneSingal() + assert isinstance(signal, Exception) + + @pytest.mark.asyncio + async def test_prune_signal_send_success(self): + """Test successful sending of prune signal.""" + data = {"reason": "test_prune"} + signal = PruneSingal(data) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/prune", "test-api-key") + + # Verify the request was made correctly + mock_session.post.assert_called_once_with( + "http://test-endpoint/prune", + json=data, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_prune_signal_send_failure(self): + """Test prune signal sending failure.""" + data = {"reason": "test_prune"} + signal = PruneSingal(data) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 500 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + with pytest.raises(Exception, match="Failed to send prune signal"): + await signal.send("http://test-endpoint/prune", "test-api-key") + + +class TestReQueueAfterSingal: + """Test cases for ReQueueAfterSingal exception class.""" + + def test_requeue_signal_initialization(self): + """Test ReQueueAfterSingal initialization.""" + delta = timedelta(seconds=30) + signal = ReQueueAfterSingal(delta) + + assert signal.timedelta == delta + assert "ReQueueAfter signal received with timedelta" in str(signal) + assert "Do not catch this Exception" in str(signal) + + def test_requeue_signal_inheritance(self): + """Test that ReQueueAfterSingal properly inherits from Exception.""" + delta = timedelta(minutes=5) + signal = ReQueueAfterSingal(delta) + assert isinstance(signal, Exception) + + @pytest.mark.asyncio + async def test_requeue_signal_send_success(self): + """Test successful sending of requeue signal.""" + delta = timedelta(seconds=45) + signal = ReQueueAfterSingal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + # Verify the request was made correctly + expected_body = {"enqueue_after": 45000} # 45 seconds * 1000 + mock_session.post.assert_called_once_with( + "http://test-endpoint/requeue", + json=expected_body, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_send_with_minutes(self): + """Test requeue signal sending with minutes in timedelta.""" + delta = timedelta(minutes=2, seconds=30) + signal = ReQueueAfterSingal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + # Verify the request was made correctly + expected_body = {"enqueue_after": 150000} # (2*60 + 30) seconds * 1000 + mock_session.post.assert_called_once_with( + "http://test-endpoint/requeue", + json=expected_body, + headers={"x-api-key": "test-api-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_send_failure(self): + """Test requeue signal sending failure.""" + delta = timedelta(seconds=30) + signal = ReQueueAfterSingal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 400 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + with pytest.raises(Exception, match="Failed to send requeue after signal"): + await signal.send("http://test-endpoint/requeue", "test-api-key") + + +class TestRuntimeSignalHandling: + """Test cases for Runtime signal handling functionality.""" + + def test_runtime_endpoint_construction(self): + """Test that runtime constructs correct endpoints for signal handling.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test prune endpoint construction + prune_endpoint = runtime._get_prune_endpoint("test-state-id") + expected_prune = "http://test-state-manager/v0/namespace/test-namespace/state/test-state-id/prune" + assert prune_endpoint == expected_prune + + # Test requeue after endpoint construction + requeue_endpoint = runtime._get_requeue_after_endpoint("test-state-id") + expected_requeue = "http://test-state-manager/v0/namespace/test-namespace/state/test-state-id/re-enqueue-after" + assert requeue_endpoint == expected_requeue + + @pytest.mark.asyncio + async def test_signal_handling_direct(self): + """Test signal handling by directly calling signal.send() with runtime endpoints.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test PruneSingal with runtime endpoint + prune_signal = PruneSingal({"reason": "direct_test"}) + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await prune_signal.send(runtime._get_prune_endpoint("test-state"), runtime._key) + + # Verify prune endpoint was called correctly + mock_session.post.assert_called_once_with( + runtime._get_prune_endpoint("test-state"), + json={"reason": "direct_test"}, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_requeue_signal_handling_direct(self): + """Test requeue signal handling by directly calling signal.send() with runtime endpoints.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test ReQueueAfterSingal with runtime endpoint + requeue_signal = ReQueueAfterSingal(timedelta(minutes=10)) + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await requeue_signal.send(runtime._get_requeue_after_endpoint("test-state"), runtime._key) + + # Verify requeue endpoint was called correctly + expected_body = {"enqueue_after": 600000} # 10 minutes * 60 * 1000 + mock_session.post.assert_called_once_with( + runtime._get_requeue_after_endpoint("test-state"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + def test_need_secrets_function(self): + """Test the _need_secrets function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Test with node that has secrets + assert runtime._need_secrets(MockTestNode) + + # Test with node that has no secrets + class MockNodeWithoutSecrets(BaseNode): + class Inputs(BaseModel): + name: str + class Outputs(BaseModel): + message: str + class Secrets(BaseModel): + pass + async def execute(self): + return self.Outputs(message="test") + + assert not runtime._need_secrets(MockNodeWithoutSecrets) + + @pytest.mark.asyncio + async def test_get_secrets_function(self): + """Test the _get_secrets function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful secrets retrieval + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 200 + mock_get_response.json = AsyncMock(return_value={"secrets": {"api_key": "test-secret"}}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {"api_key": "test-secret"} + mock_session.get.assert_called_once_with( + runtime._get_secrets_endpoint("test-state-id"), + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_get_secrets_function_failure(self): + """Test the _get_secrets function when request fails.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock failed secrets retrieval + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 404 + mock_get_response.json = AsyncMock(return_value={"error": "Not found"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {} + + @pytest.mark.asyncio + async def test_get_secrets_function_no_secrets_field(self): + """Test the _get_secrets function when response has no secrets field.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock response without secrets field + mock_session, _, mock_get_response, _ = create_mock_aiohttp_session() + mock_get_response.status = 200 + mock_get_response.json = AsyncMock(return_value={"data": "some other data"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + secrets = await runtime._get_secrets("test-state-id") + + assert secrets == {} + + +class TestRuntimeEndpointFunctions: + """Test cases for Runtime endpoint construction functions.""" + + def test_get_prune_endpoint(self): + """Test _get_prune_endpoint function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + endpoint = runtime._get_prune_endpoint("state-123") + expected = "http://test-state-manager/v0/namespace/test-namespace/state/state-123/prune" + assert endpoint == expected + + def test_get_requeue_after_endpoint(self): + """Test _get_requeue_after_endpoint function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + endpoint = runtime._get_requeue_after_endpoint("state-456") + expected = "http://test-state-manager/v0/namespace/test-namespace/state/state-456/re-enqueue-after" + assert endpoint == expected + + def test_get_prune_endpoint_with_custom_version(self): + """Test _get_prune_endpoint with custom state manager version.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key", + state_manage_version="v1" + ) + + endpoint = runtime._get_prune_endpoint("state-789") + expected = "http://test-state-manager/v1/namespace/test-namespace/state/state-789/prune" + assert endpoint == expected + + def test_get_requeue_after_endpoint_with_custom_version(self): + """Test _get_requeue_after_endpoint with custom state manager version.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key", + state_manage_version="v2" + ) + + endpoint = runtime._get_requeue_after_endpoint("state-101") + expected = "http://test-state-manager/v2/namespace/test-namespace/state/state-101/re-enqueue-after" + assert endpoint == expected + + +class TestSignalIntegration: + """Integration tests for signal handling in the runtime.""" + + @pytest.mark.asyncio + async def test_signal_exception_behavior(self): + """Test that signals are proper exceptions that can be raised and caught.""" + # Test PruneSingal + prune_signal = PruneSingal({"test": "data"}) + + with pytest.raises(PruneSingal) as exc_info: + raise prune_signal + + assert exc_info.value.data == {"test": "data"} + assert isinstance(exc_info.value, Exception) + + # Test ReQueueAfterSingal + requeue_signal = ReQueueAfterSingal(timedelta(seconds=30)) + + with pytest.raises(ReQueueAfterSingal) as exc_info: + raise requeue_signal + + assert exc_info.value.timedelta == timedelta(seconds=30) + assert isinstance(exc_info.value, Exception) + + @pytest.mark.asyncio + async def test_combined_signal_and_runtime_functionality(self): + """Test that signals work correctly with runtime endpoints.""" + runtime = Runtime( + namespace="production", + name="signal-runtime", + nodes=[MockTestNode], + state_manager_uri="https://api.exosphere.host", + key="prod-api-key", + state_manage_version="v1" + ) + + # Test PruneSingal with production-like endpoint + prune_signal = PruneSingal({"reason": "cleanup", "batch_id": "batch-123"}) + expected_prune_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-456/prune" + actual_prune_endpoint = runtime._get_prune_endpoint("prod-state-456") + assert actual_prune_endpoint == expected_prune_endpoint + + # Test ReQueueAfterSingal with production-like endpoint + requeue_signal = ReQueueAfterSingal(timedelta(hours=2, minutes=30)) + expected_requeue_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-789/re-enqueue-after" + actual_requeue_endpoint = runtime._get_requeue_after_endpoint("prod-state-789") + assert actual_requeue_endpoint == expected_requeue_endpoint + + # Test that signal data is preserved + assert prune_signal.data == {"reason": "cleanup", "batch_id": "batch-123"} + assert requeue_signal.timedelta == timedelta(hours=2, minutes=30) + + @pytest.mark.asyncio + async def test_signal_send_with_different_endpoints(self): + """Test signal sending with various endpoint configurations.""" + # Test with different URI formats + test_cases = [ + ("http://localhost:8080", "v0", "dev"), + ("https://api.production.com", "v2", "production"), + ("http://staging.internal:3000", "v1", "staging") + ] + + for uri, version, namespace in test_cases: + runtime = Runtime( + namespace=namespace, + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri=uri, + key="test-key", + state_manage_version=version + ) + + # Test prune endpoint construction + prune_endpoint = runtime._get_prune_endpoint("test-state") + expected_prune = f"{uri}/{version}/namespace/{namespace}/state/test-state/prune" + assert prune_endpoint == expected_prune + + # Test requeue endpoint construction + requeue_endpoint = runtime._get_requeue_after_endpoint("test-state") + expected_requeue = f"{uri}/{version}/namespace/{namespace}/state/test-state/re-enqueue-after" + assert requeue_endpoint == expected_requeue + + +class TestSignalEdgeCases: + """Test cases for signal edge cases and error conditions.""" + + def test_prune_signal_with_empty_data(self): + """Test PruneSingal with empty data.""" + signal = PruneSingal({}) + assert signal.data == {} + assert isinstance(signal, Exception) + + def test_prune_signal_with_complex_data(self): + """Test PruneSingal with complex nested data.""" + complex_data = { + "reason": "batch_cleanup", + "metadata": { + "batch_id": "batch-456", + "items": ["item1", "item2", "item3"], + "timestamp": "2023-12-01T10:00:00Z" + }, + "config": { + "force": True, + "notify_users": False + } + } + signal = PruneSingal(complex_data) + assert signal.data == complex_data + + def test_requeue_signal_with_zero_timedelta(self): + """Test ReQueueAfterSingal with zero timedelta.""" + signal = ReQueueAfterSingal(timedelta(seconds=0)) + assert signal.timedelta == timedelta(seconds=0) + + def test_requeue_signal_with_large_timedelta(self): + """Test ReQueueAfterSingal with large timedelta.""" + large_delta = timedelta(days=7, hours=12, minutes=30, seconds=45) + signal = ReQueueAfterSingal(large_delta) + assert signal.timedelta == large_delta + + @pytest.mark.asyncio + async def test_requeue_signal_timedelta_conversion(self): + """Test that ReQueueAfterSingal correctly converts timedelta to milliseconds.""" + test_cases = [ + (timedelta(seconds=1), 1000), + (timedelta(minutes=1), 60000), + (timedelta(hours=1), 3600000), + (timedelta(days=1), 86400000), + (timedelta(seconds=30, microseconds=500000), 30500), # 30.5 seconds + ] + + for delta, expected_ms in test_cases: + signal = ReQueueAfterSingal(delta) + + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + + with patch('exospherehost.signals.ClientSession', return_value=mock_session): + await signal.send("http://test-endpoint", "test-key") + + # Verify correct milliseconds conversion + expected_body = {"enqueue_after": expected_ms} + mock_session.post.assert_called_with( + "http://test-endpoint", + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + def test_signal_string_representations(self): + """Test string representations of signals.""" + prune_signal = PruneSingal({"test": "data"}) + prune_str = str(prune_signal) + assert "Prune signal received with data" in prune_str + assert "Do not catch this Exception" in prune_str + assert "{'test': 'data'}" in prune_str + + requeue_signal = ReQueueAfterSingal(timedelta(minutes=5)) + requeue_str = str(requeue_signal) + assert "ReQueueAfter signal received with timedelta" in requeue_str + assert "Do not catch this Exception" in requeue_str + assert "0:05:00" in requeue_str + + +class TestRuntimeHelperFunctions: + """Test cases for Runtime helper functions.""" + + @pytest.mark.asyncio + async def test_notify_executed_function(self): + """Test the _notify_executed function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value={"status": "success"}) + + # Create test outputs + outputs = [MockTestNode.Outputs(message="output1"), MockTestNode.Outputs(message="output2")] + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + await runtime._notify_executed("test-state-id", outputs) + + # Verify correct endpoint and payload + expected_body = {"outputs": [{"message": "output1"}, {"message": "output2"}]} + mock_session.post.assert_called_once_with( + runtime._get_executed_endpoint("test-state-id"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_notify_errored_function(self): + """Test the _notify_errored function.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock successful notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value={"status": "success"}) + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + await runtime._notify_errored("test-state-id", "Test error message") + + # Verify correct endpoint and payload + expected_body = {"error": "Test error message"} + mock_session.post.assert_called_once_with( + runtime._get_errored_endpoint("test-state-id"), + json=expected_body, + headers={"x-api-key": "test-key"} + ) + + @pytest.mark.asyncio + async def test_notify_functions_with_failure(self): + """Test notification functions when HTTP requests fail.""" + runtime = Runtime( + namespace="test-namespace", + name="test-runtime", + nodes=[MockTestNode], + state_manager_uri="http://test-state-manager", + key="test-key" + ) + + # Mock failed notification + mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() + mock_post_response.status = 500 + mock_post_response.json = AsyncMock(return_value={"error": "Internal server error"}) + + outputs = [MockTestNode.Outputs(message="test")] + + with patch('exospherehost.runtime.ClientSession', return_value=mock_session): + # These should not raise exceptions, just log errors + await runtime._notify_executed("test-state-id", outputs) + await runtime._notify_errored("test-state-id", "Test error") + + # Verify both endpoints were called despite failures + assert mock_session.post.call_count == 2 + + +class TestSetupDefaultLogging: + """Test cases for the _setup_default_logging function.""" + + def test_setup_default_logging_with_existing_handlers(self): + """Test that _setup_default_logging doesn't interfere with existing handlers.""" + # Create a logger with existing handlers + test_logger = logging.getLogger("test_logger") + handler = logging.StreamHandler() + test_logger.addHandler(handler) + + # Mock the root logger to have handlers + with patch('logging.getLogger') as mock_get_logger: + mock_root_logger = MagicMock() + mock_root_logger.handlers = [handler] + mock_get_logger.return_value = mock_root_logger + + # This should return early and not configure logging + _setup_default_logging() + + # Verify no basic config was called + mock_root_logger.basicConfig = MagicMock() + assert not mock_root_logger.basicConfig.called + + def test_setup_default_logging_with_disable_env_var(self): + """Test that _setup_default_logging respects the disable environment variable.""" + with patch.dict('os.environ', {'EXOSPHERE_DISABLE_DEFAULT_LOGGING': 'true'}), \ + patch('logging.getLogger') as mock_get_logger: + mock_root_logger = MagicMock() + mock_root_logger.handlers = [] + mock_get_logger.return_value = mock_root_logger + + _setup_default_logging() + + # Should return early due to env var + with patch('logging.basicConfig') as mock_basic_config: + _setup_default_logging() + assert not mock_basic_config.called + + def test_setup_default_logging_with_custom_log_level(self): + """Test that _setup_default_logging respects custom log level.""" + with patch.dict('os.environ', {'EXOSPHERE_LOG_LEVEL': 'DEBUG'}), \ + patch('logging.getLogger') as mock_get_logger, \ + patch('logging.basicConfig') as mock_basic_config: + + mock_root_logger = MagicMock() + mock_root_logger.handlers = [] + mock_get_logger.return_value = mock_root_logger + + _setup_default_logging() + + # Verify basicConfig was called with DEBUG level + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + assert call_args[1]['level'] == logging.DEBUG \ No newline at end of file From b018b03104af72e321295ed5d9f5c42ba62c9d77 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 16:20:23 +0530 Subject: [PATCH 13/19] Fix signal naming inconsistencies and enhance signal functionality - Renamed `PruneSingal` and `ReQueueAfterSingal` to `PruneSignal` and `ReQueueAfterSignal` across the codebase for consistency. - Updated exception handling in the `Runtime` class to reflect the new signal names. - Adjusted test cases to use the corrected signal names, ensuring comprehensive coverage for the updated functionality. - Enhanced model validation in `ReEnqueueAfterRequestModel` and `State` to enforce positive duration constraints for re-enqueueing operations. --- python-sdk/exospherehost/__init__.py | 4 +- python-sdk/exospherehost/runtime.py | 10 +- python-sdk/exospherehost/signals.py | 14 ++- .../test_signals_and_runtime_functions.py | 110 +++++++++--------- state-manager/app/models/db/state.py | 4 +- state-manager/app/models/signal_models.py | 2 +- state-manager/app/routes.py | 2 +- 7 files changed, 75 insertions(+), 71 deletions(-) diff --git a/python-sdk/exospherehost/__init__.py b/python-sdk/exospherehost/__init__.py index 29ef3bab..a16745df 100644 --- a/python-sdk/exospherehost/__init__.py +++ b/python-sdk/exospherehost/__init__.py @@ -38,8 +38,8 @@ async def execute(self, inputs: Inputs) -> Outputs: from .runtime import Runtime from .node.BaseNode import BaseNode from .statemanager import StateManager, TriggerState -from .signals import PruneSingal, ReQueueAfterSingal +from .signals import PruneSignal, ReQueueAfterSignal VERSION = __version__ -__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSingal", "ReQueueAfterSingal"] +__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 94c74b7e..de74e459 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from .node.BaseNode import BaseNode from aiohttp import ClientSession -from .signals import PruneSingal, ReQueueAfterSingal +from .signals import PruneSignal, ReQueueAfterSignal logger = logging.getLogger(__name__) @@ -409,15 +409,15 @@ async def _worker(self, idx: int): await self._notify_executed(state["state_id"], outputs) logger.info(f"Notified executed state {state['state_id']} for node {node.__name__ if node else "unknown"}") - except PruneSingal as prune_signal: + except PruneSignal as prune_signal: logger.info(f"Pruning state {state['state_id']} for node {node.__name__ if node else "unknown"}") await prune_signal.send(self._get_prune_endpoint(state["state_id"]), self._key) # type: ignore logger.info(f"Pruned state {state['state_id']} for node {node.__name__ if node else "unknown"}") - except ReQueueAfterSingal as requeue_signal: - logger.info(f"Requeuing state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.timedelta}") + except ReQueueAfterSignal as requeue_signal: + logger.info(f"Requeuing state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.delay}") await requeue_signal.send(self._get_requeue_after_endpoint(state["state_id"]), self._key) # type: ignore - logger.info(f"Requeued state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.timedelta}") + logger.info(f"Requeued state {state['state_id']} for node {node.__name__ if node else "unknown"} after {requeue_signal.delay}") except Exception as e: logger.error(f"Error executing state {state['state_id']} for node {node.__name__ if node else "unknown"}: {e}") diff --git a/python-sdk/exospherehost/signals.py b/python-sdk/exospherehost/signals.py index 44f21fa4..c7072eb6 100644 --- a/python-sdk/exospherehost/signals.py +++ b/python-sdk/exospherehost/signals.py @@ -2,7 +2,7 @@ from aiohttp import ClientSession from datetime import timedelta -class PruneSingal(Exception): +class PruneSignal(Exception): """ Exception used to signal that a prune operation should be performed. @@ -33,7 +33,7 @@ async def send(self, endpoint: str, key: str): raise Exception(f"Failed to send prune signal to {endpoint}") -class ReQueueAfterSingal(Exception): +class ReQueueAfterSignal(Exception): """ Exception used to signal that a requeue operation should be performed after a specified timedelta. @@ -43,8 +43,12 @@ class ReQueueAfterSingal(Exception): Note: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager. """ - def __init__(self, timedelta: timedelta): - self.timedelta = timedelta + def __init__(self, delay: timedelta): + self.delay = delay + + if self.delay.total_seconds() <= 0: + raise Exception("Delay must be greater than 0") + super().__init__(f"ReQueueAfter signal received with timedelta: {timedelta} \n NOTE: Do not catch this Exception, let it bubble up to Runtime for handling at StateManager") async def send(self, endpoint: str, key: str): @@ -59,7 +63,7 @@ async def send(self, endpoint: str, key: str): Exception: If the HTTP request fails (status code != 200). """ body = { - "enqueue_after": int(self.timedelta.total_seconds() * 1000) + "enqueue_after": int(self.delay.total_seconds() * 1000) } async with ClientSession() as session: async with session.post(endpoint, json=body, headers={"x-api-key": key}) as response: diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py index cbea3289..cf12168c 100644 --- a/python-sdk/tests/test_signals_and_runtime_functions.py +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, patch, MagicMock from datetime import timedelta from pydantic import BaseModel -from exospherehost.signals import PruneSingal, ReQueueAfterSingal +from exospherehost.signals import PruneSignal, ReQueueAfterSignal from exospherehost.runtime import Runtime, _setup_default_logging from exospherehost.node.BaseNode import BaseNode @@ -56,35 +56,35 @@ async def execute(self): return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore -class TestPruneSingal: - """Test cases for PruneSingal exception class.""" +class TestPruneSignal: + """Test cases for PruneSignal exception class.""" def test_prune_signal_initialization_with_data(self): - """Test PruneSingal initialization with custom data.""" + """Test PruneSignal initialization with custom data.""" data = {"reason": "test", "custom_field": "value"} - signal = PruneSingal(data) + signal = PruneSignal(data) assert signal.data == data assert "Prune signal received with data" in str(signal) assert "Do not catch this Exception" in str(signal) def test_prune_signal_initialization_without_data(self): - """Test PruneSingal initialization without data (default empty dict).""" - signal = PruneSingal() + """Test PruneSignal initialization without data (default empty dict).""" + signal = PruneSignal() assert signal.data == {} assert "Prune signal received with data" in str(signal) def test_prune_signal_inheritance(self): - """Test that PruneSingal properly inherits from Exception.""" - signal = PruneSingal() + """Test that PruneSignal properly inherits from Exception.""" + signal = PruneSignal() assert isinstance(signal, Exception) @pytest.mark.asyncio async def test_prune_signal_send_success(self): """Test successful sending of prune signal.""" data = {"reason": "test_prune"} - signal = PruneSingal(data) + signal = PruneSignal(data) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 @@ -103,7 +103,7 @@ async def test_prune_signal_send_success(self): async def test_prune_signal_send_failure(self): """Test prune signal sending failure.""" data = {"reason": "test_prune"} - signal = PruneSingal(data) + signal = PruneSignal(data) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 500 @@ -113,29 +113,29 @@ async def test_prune_signal_send_failure(self): await signal.send("http://test-endpoint/prune", "test-api-key") -class TestReQueueAfterSingal: - """Test cases for ReQueueAfterSingal exception class.""" +class TestReQueueAfterSignal: + """Test cases for ReQueueAfterSignal exception class.""" def test_requeue_signal_initialization(self): - """Test ReQueueAfterSingal initialization.""" + """Test ReQueueAfterSignal initialization.""" delta = timedelta(seconds=30) - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) - assert signal.timedelta == delta + assert signal.delay == delta assert "ReQueueAfter signal received with timedelta" in str(signal) assert "Do not catch this Exception" in str(signal) def test_requeue_signal_inheritance(self): - """Test that ReQueueAfterSingal properly inherits from Exception.""" + """Test that ReQueueAfterSignal properly inherits from Exception.""" delta = timedelta(minutes=5) - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) assert isinstance(signal, Exception) @pytest.mark.asyncio async def test_requeue_signal_send_success(self): """Test successful sending of requeue signal.""" delta = timedelta(seconds=45) - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 @@ -155,7 +155,7 @@ async def test_requeue_signal_send_success(self): async def test_requeue_signal_send_with_minutes(self): """Test requeue signal sending with minutes in timedelta.""" delta = timedelta(minutes=2, seconds=30) - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 @@ -175,7 +175,7 @@ async def test_requeue_signal_send_with_minutes(self): async def test_requeue_signal_send_failure(self): """Test requeue signal sending failure.""" delta = timedelta(seconds=30) - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 400 @@ -219,13 +219,13 @@ async def test_signal_handling_direct(self): key="test-key" ) - # Test PruneSingal with runtime endpoint - prune_signal = PruneSingal({"reason": "direct_test"}) + # Test PruneSignal with runtime endpoint + prune_signal = PruneSignal({"reason": "direct_test"}) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 with patch('exospherehost.signals.ClientSession', return_value=mock_session): - await prune_signal.send(runtime._get_prune_endpoint("test-state"), runtime._key) + await prune_signal.send(runtime._get_prune_endpoint("test-state"), runtime._key) # type: ignore # Verify prune endpoint was called correctly mock_session.post.assert_called_once_with( @@ -245,13 +245,13 @@ async def test_requeue_signal_handling_direct(self): key="test-key" ) - # Test ReQueueAfterSingal with runtime endpoint - requeue_signal = ReQueueAfterSingal(timedelta(minutes=10)) + # Test ReQueueAfterSignal with runtime endpoint + requeue_signal = ReQueueAfterSignal(timedelta(minutes=10)) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 with patch('exospherehost.signals.ClientSession', return_value=mock_session): - await requeue_signal.send(runtime._get_requeue_after_endpoint("test-state"), runtime._key) + await requeue_signal.send(runtime._get_requeue_after_endpoint("test-state"), runtime._key) # type: ignore # Verify requeue endpoint was called correctly expected_body = {"enqueue_after": 600000} # 10 minutes * 60 * 1000 @@ -423,22 +423,22 @@ class TestSignalIntegration: @pytest.mark.asyncio async def test_signal_exception_behavior(self): """Test that signals are proper exceptions that can be raised and caught.""" - # Test PruneSingal - prune_signal = PruneSingal({"test": "data"}) + # Test PruneSignal + prune_signal = PruneSignal({"test": "data"}) - with pytest.raises(PruneSingal) as exc_info: + with pytest.raises(PruneSignal) as exc_info: raise prune_signal assert exc_info.value.data == {"test": "data"} assert isinstance(exc_info.value, Exception) - # Test ReQueueAfterSingal - requeue_signal = ReQueueAfterSingal(timedelta(seconds=30)) + # Test ReQueueAfterSignal + requeue_signal = ReQueueAfterSignal(timedelta(seconds=30)) - with pytest.raises(ReQueueAfterSingal) as exc_info: + with pytest.raises(ReQueueAfterSignal) as exc_info: raise requeue_signal - assert exc_info.value.timedelta == timedelta(seconds=30) + assert exc_info.value.delay == timedelta(seconds=30) assert isinstance(exc_info.value, Exception) @pytest.mark.asyncio @@ -453,21 +453,21 @@ async def test_combined_signal_and_runtime_functionality(self): state_manage_version="v1" ) - # Test PruneSingal with production-like endpoint - prune_signal = PruneSingal({"reason": "cleanup", "batch_id": "batch-123"}) + # Test PruneSignal with production-like endpoint + prune_signal = PruneSignal({"reason": "cleanup", "batch_id": "batch-123"}) expected_prune_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-456/prune" actual_prune_endpoint = runtime._get_prune_endpoint("prod-state-456") assert actual_prune_endpoint == expected_prune_endpoint - # Test ReQueueAfterSingal with production-like endpoint - requeue_signal = ReQueueAfterSingal(timedelta(hours=2, minutes=30)) + # Test ReQueueAfterSignal with production-like endpoint + requeue_signal = ReQueueAfterSignal(timedelta(hours=2, minutes=30)) expected_requeue_endpoint = "https://api.exosphere.host/v1/namespace/production/state/prod-state-789/re-enqueue-after" actual_requeue_endpoint = runtime._get_requeue_after_endpoint("prod-state-789") assert actual_requeue_endpoint == expected_requeue_endpoint # Test that signal data is preserved assert prune_signal.data == {"reason": "cleanup", "batch_id": "batch-123"} - assert requeue_signal.timedelta == timedelta(hours=2, minutes=30) + assert requeue_signal.delay == timedelta(hours=2, minutes=30) @pytest.mark.asyncio async def test_signal_send_with_different_endpoints(self): @@ -504,13 +504,13 @@ class TestSignalEdgeCases: """Test cases for signal edge cases and error conditions.""" def test_prune_signal_with_empty_data(self): - """Test PruneSingal with empty data.""" - signal = PruneSingal({}) + """Test PruneSignal with empty data.""" + signal = PruneSignal({}) assert signal.data == {} assert isinstance(signal, Exception) def test_prune_signal_with_complex_data(self): - """Test PruneSingal with complex nested data.""" + """Test PruneSignal with complex nested data.""" complex_data = { "reason": "batch_cleanup", "metadata": { @@ -523,23 +523,23 @@ def test_prune_signal_with_complex_data(self): "notify_users": False } } - signal = PruneSingal(complex_data) + signal = PruneSignal(complex_data) assert signal.data == complex_data def test_requeue_signal_with_zero_timedelta(self): - """Test ReQueueAfterSingal with zero timedelta.""" - signal = ReQueueAfterSingal(timedelta(seconds=0)) - assert signal.timedelta == timedelta(seconds=0) + """Test ReQueueAfterSignal with zero timedelta.""" + signal = ReQueueAfterSignal(timedelta(seconds=0)) + assert signal.delay == timedelta(seconds=0) def test_requeue_signal_with_large_timedelta(self): - """Test ReQueueAfterSingal with large timedelta.""" + """Test ReQueueAfterSignal with large timedelta.""" large_delta = timedelta(days=7, hours=12, minutes=30, seconds=45) - signal = ReQueueAfterSingal(large_delta) - assert signal.timedelta == large_delta + signal = ReQueueAfterSignal(large_delta) + assert signal.delay == large_delta @pytest.mark.asyncio async def test_requeue_signal_timedelta_conversion(self): - """Test that ReQueueAfterSingal correctly converts timedelta to milliseconds.""" + """Test that ReQueueAfterSignal correctly converts timedelta to milliseconds.""" test_cases = [ (timedelta(seconds=1), 1000), (timedelta(minutes=1), 60000), @@ -549,7 +549,7 @@ async def test_requeue_signal_timedelta_conversion(self): ] for delta, expected_ms in test_cases: - signal = ReQueueAfterSingal(delta) + signal = ReQueueAfterSignal(delta) mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() mock_post_response.status = 200 @@ -567,13 +567,13 @@ async def test_requeue_signal_timedelta_conversion(self): def test_signal_string_representations(self): """Test string representations of signals.""" - prune_signal = PruneSingal({"test": "data"}) + prune_signal = PruneSignal({"test": "data"}) prune_str = str(prune_signal) assert "Prune signal received with data" in prune_str assert "Do not catch this Exception" in prune_str assert "{'test': 'data'}" in prune_str - requeue_signal = ReQueueAfterSingal(timedelta(minutes=5)) + requeue_signal = ReQueueAfterSignal(timedelta(minutes=5)) requeue_str = str(requeue_signal) assert "ReQueueAfter signal received with timedelta" in requeue_str assert "Do not catch this Exception" in requeue_str @@ -603,7 +603,7 @@ async def test_notify_executed_function(self): outputs = [MockTestNode.Outputs(message="output1"), MockTestNode.Outputs(message="output2")] with patch('exospherehost.runtime.ClientSession', return_value=mock_session): - await runtime._notify_executed("test-state-id", outputs) + await runtime._notify_executed("test-state-id", outputs) # type: ignore # Verify correct endpoint and payload expected_body = {"outputs": [{"message": "output1"}, {"message": "output2"}]} @@ -660,7 +660,7 @@ async def test_notify_functions_with_failure(self): with patch('exospherehost.runtime.ClientSession', return_value=mock_session): # These should not raise exceptions, just log errors - await runtime._notify_executed("test-state-id", outputs) + await runtime._notify_executed("test-state-id", outputs) # type: ignore await runtime._notify_errored("test-state-id", "Test error") # Verify both endpoints were called despite failures diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 6651b9c0..e37326b4 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -18,12 +18,12 @@ class State(BaseDatabaseModel): status: StateStatusEnum = Field(..., description="Status of the state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") - data: dict[str, Any] = Field(default_factory=dict, description="Data of the state") + data: dict[str, Any] = Field(default_factory=dict, description="Data of the state (could be used to save pruned meta data)") error: Optional[str] = Field(None, description="Error message") parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") does_unites: bool = Field(default=False, description="Whether this state unites other states") state_fingerprint: str = Field(default="", description="Fingerprint of the state") - enqueue_after: int = Field(default_factory=lambda: int(time.time() * 1000), description="Unix time in milliseconds after which the state should be enqueued") + enqueue_after: int = Field(default_factory=lambda: int(time.time() * 1000), gt=0, description="Unix time in milliseconds after which the state should be enqueued") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): diff --git a/state-manager/app/models/signal_models.py b/state-manager/app/models/signal_models.py index 3e86d53a..40abe6f4 100644 --- a/state-manager/app/models/signal_models.py +++ b/state-manager/app/models/signal_models.py @@ -11,4 +11,4 @@ class PruneRequestModel(BaseModel): data: dict[str, Any] = Field(..., description="Data of the state") class ReEnqueueAfterRequestModel(BaseModel): - enqueue_after: int = Field(..., description="Duration in milliseconds to delay the re-enqueuing of the state") \ No newline at end of file + enqueue_after: int = Field(..., gt=0, description="Duration in milliseconds to delay the re-enqueuing of the state") \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index a108c77c..ee219bb1 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -157,7 +157,7 @@ async def errored_state_route(namespace_name: str, state_id: str, body: ErroredR "/states/{state_id}/prune", response_model=SignalResponseModel, status_code=status.HTTP_200_OK, - response_description="State skipped successfully", + response_description="State pruned successfully", tags=["state"] ) async def prune_state_route(namespace_name: str, state_id: str, body: PruneRequestModel, request: Request, api_key: str = Depends(check_api_key)): From 5d66297250d6c62068f238235d5a6b96037cda79 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 16:37:22 +0530 Subject: [PATCH 14/19] Correct signal naming in tests and enhance exception handling - Updated expected exports in test_package_init.py to reflect the correct signal name `PruneSignal` and `ReQueueAfterSignal`. - Modified test_requeue_signal_with_zero_timedelta in test_signals_and_runtime_functions.py to raise an exception when initialized with zero timedelta, improving error handling for the ReQueueAfterSignal class. - Removed unnecessary assertions in the test to streamline the test logic. --- python-sdk/tests/test_package_init.py | 2 +- python-sdk/tests/test_signals_and_runtime_functions.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python-sdk/tests/test_package_init.py b/python-sdk/tests/test_package_init.py index 926ef43e..d5406b52 100644 --- a/python-sdk/tests/test_package_init.py +++ b/python-sdk/tests/test_package_init.py @@ -15,7 +15,7 @@ def test_package_all_imports(): """Test that __all__ contains all expected exports.""" from exospherehost import __all__ - expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSingal", "ReQueueAfterSingal"] + expected_exports = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"] for export in expected_exports: assert export in __all__, f"{export} should be in __all__" diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py index cf12168c..e6a2222d 100644 --- a/python-sdk/tests/test_signals_and_runtime_functions.py +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -528,8 +528,8 @@ def test_prune_signal_with_complex_data(self): def test_requeue_signal_with_zero_timedelta(self): """Test ReQueueAfterSignal with zero timedelta.""" - signal = ReQueueAfterSignal(timedelta(seconds=0)) - assert signal.delay == timedelta(seconds=0) + with pytest.raises(Exception): + ReQueueAfterSignal(timedelta(seconds=0)) def test_requeue_signal_with_large_timedelta(self): """Test ReQueueAfterSignal with large timedelta.""" @@ -577,8 +577,6 @@ def test_signal_string_representations(self): requeue_str = str(requeue_signal) assert "ReQueueAfter signal received with timedelta" in requeue_str assert "Do not catch this Exception" in requeue_str - assert "0:05:00" in requeue_str - class TestRuntimeHelperFunctions: """Test cases for Runtime helper functions.""" From ef84f0c683a4d85ffd24b4db10e1c66bd0300ce1 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 16:37:30 +0530 Subject: [PATCH 15/19] Enhance validation in ReEnqueueAfterRequestModel tests - Updated tests for ReEnqueueAfterRequestModel to raise exceptions for zero and negative delay values, ensuring stricter validation. - Adjusted test cases in related signal tests to reflect the new validation rules, improving overall test reliability and coverage. --- .../controller/test_re_queue_after_signal.py | 29 ++++++------------- .../tests/unit/models/test_signal_models.py | 16 ++++------ state-manager/tests/unit/test_routes.py | 12 ++++---- 3 files changed, 19 insertions(+), 38 deletions(-) diff --git a/state-manager/tests/unit/controller/test_re_queue_after_signal.py b/state-manager/tests/unit/controller/test_re_queue_after_signal.py index fc437b58..48f41922 100644 --- a/state-manager/tests/unit/controller/test_re_queue_after_signal.py +++ b/state-manager/tests/unit/controller/test_re_queue_after_signal.py @@ -110,7 +110,7 @@ async def test_re_queue_after_signal_with_zero_delay( """Test re-enqueuing with zero delay""" # Arrange mock_time.time.return_value = 1000.0 - re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=0) + re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=1) mock_state_any_status.save = AsyncMock() mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) @@ -124,8 +124,8 @@ async def test_re_queue_after_signal_with_zero_delay( # Assert assert result.status == StateStatusEnum.CREATED - assert result.enqueue_after == 1000000 # 1000 * 1000 + 0 - assert mock_state_any_status.enqueue_after == 1000000 + assert result.enqueue_after == 1000001 # 1000 * 1000 + 0 + assert mock_state_any_status.enqueue_after == 1000001 assert mock_state_any_status.save.call_count == 1 @patch('app.controller.re_queue_after_signal.State') @@ -173,24 +173,13 @@ async def test_re_queue_after_signal_with_negative_delay( ): """Test re-enqueuing with negative delay (should still work)""" # Arrange - mock_time.time.return_value = 1000.0 - re_enqueue_request = ReEnqueueAfterRequestModel(enqueue_after=-5000) # Negative delay - mock_state_any_status.save = AsyncMock() - mock_state_class.find_one = AsyncMock(return_value=mock_state_any_status) - - # Act - result = await re_queue_after_signal( - mock_namespace, - mock_state_id, - re_enqueue_request, - mock_request_id - ) + + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=-5000) # Negative delay - # Assert - assert result.status == StateStatusEnum.CREATED - assert result.enqueue_after == 995000 # 1000 * 1000 + (-5000) - assert mock_state_any_status.enqueue_after == 995000 - assert mock_state_any_status.save.call_count == 1 + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=0) + @patch('app.controller.re_queue_after_signal.State') async def test_re_queue_after_signal_database_error( diff --git a/state-manager/tests/unit/models/test_signal_models.py b/state-manager/tests/unit/models/test_signal_models.py index fd14636d..4eea9141 100644 --- a/state-manager/tests/unit/models/test_signal_models.py +++ b/state-manager/tests/unit/models/test_signal_models.py @@ -77,21 +77,15 @@ def test_re_enqueue_after_request_model_valid_delay(self): def test_re_enqueue_after_request_model_zero_delay(self): """Test ReEnqueueAfterRequestModel with zero delay""" # Arrange & Act - delay = 0 - model = ReEnqueueAfterRequestModel(enqueue_after=delay) - - # Assert - assert model.enqueue_after == delay + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=0) def test_re_enqueue_after_request_model_negative_delay(self): """Test ReEnqueueAfterRequestModel with negative delay""" # Arrange & Act - delay = -5000 - model = ReEnqueueAfterRequestModel(enqueue_after=delay) - - # Assert - assert model.enqueue_after == delay - + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(enqueue_after=-5000) + def test_re_enqueue_after_request_model_large_delay(self): """Test ReEnqueueAfterRequestModel with large delay""" # Arrange & Act diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 92c9ca28..82cce20b 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -153,13 +153,13 @@ def test_re_enqueue_after_request_model_validation(self): # Test with zero delay zero_data = {"enqueue_after": 0} - model = ReEnqueueAfterRequestModel(**zero_data) - assert model.enqueue_after == 0 + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(**zero_data) # Test with negative delay negative_data = {"enqueue_after": -5000} - model = ReEnqueueAfterRequestModel(**negative_data) - assert model.enqueue_after == -5000 + with pytest.raises(Exception): + ReEnqueueAfterRequestModel(**negative_data) # Test with large delay large_data = {"enqueue_after": 86400000} @@ -821,11 +821,9 @@ async def test_re_enqueue_after_state_route_with_different_delays(self, mock_re_ # Test cases with different delays test_cases = [ - 0, # No delay 1000, # 1 second 60000, # 1 minute - 3600000, # 1 hour - -5000 # Negative delay + 3600000 # 1 hour ] for delay in test_cases: From 74f51a65bafca01bc274f6ecba6be9a56e641c78 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 16:48:58 +0530 Subject: [PATCH 16/19] Add Signals documentation and update navigation in mkdocs.yml - Introduced a new documentation file for Signals, detailing their functionality, usage, and examples for `PruneSignal` and `ReQueueAfterSignal`. - Updated mkdocs.yml to include the new Signals documentation in the navigation structure, enhancing accessibility for users. --- docs/docs/exosphere/signals.md | 131 +++++++++++++++++++++++++++++++++ docs/mkdocs.yml | 2 + 2 files changed, 133 insertions(+) create mode 100644 docs/docs/exosphere/signals.md diff --git a/docs/docs/exosphere/signals.md b/docs/docs/exosphere/signals.md new file mode 100644 index 00000000..6796b455 --- /dev/null +++ b/docs/docs/exosphere/signals.md @@ -0,0 +1,131 @@ +# Signals + +!!! beta "Beta Feature" + Signals are currently available in beta. The API and functionality may change in future releases. + +Signals are a mechanism in Exosphere for controlling workflow execution flow and state management. They allow nodes to communicate with the state manager to perform specific actions like pruning states or requeuing them after a delay. + +## Overview + +Signals are implemented as exceptions that should be raised from within node execution. When a signal is raised, the runtime automatically handles the communication with the state manager to perform the requested action. + +## Available Signals + +### PruneSignal + +The `PruneSignal` is used to permanently remove a state from the workflow execution. This is typically used when a node determines that the current execution path should be terminated. + +#### Usage + +```python +from exospherehost import PruneSignal + +class MyNode(BaseNode): + class Inputs(BaseModel): + data: str + + class Outputs(BaseModel): + result: str + + async def execute(self, inputs: Inputs) -> Outputs: + if inputs.data == "invalid": + # Prune the state with optional data + raise PruneSignal({"reason": "invalid_data", "error": "Data validation failed"}) + + return self.Outputs(result="processed") +``` + +#### Parameters + +- `data` (dict[str, Any], optional): Additional data to include with the prune operation. Defaults to an empty dictionary. + +### ReQueueAfterSignal + +The `ReQueueAfterSignal` is used to requeue a state for execution after a specified time delay. This is useful for implementing retry logic, scheduled tasks, or rate limiting. + +#### Usage + +```python +from exospherehost import ReQueueAfterSignal +from datetime import timedelta + +class RetryNode(BaseNode): + class Inputs(BaseModel): + retry_count: int + data: str + + class Outputs(BaseModel): + result: str + + async def execute(self, inputs: Inputs) -> Outputs: + if inputs.retry_count < 3: + # Requeue after 5 minutes + raise ReQueueAfterSignal(timedelta(minutes=5)) + + return self.Outputs(result="completed") +``` + +#### Parameters + +- `delay` (timedelta): The amount of time to wait before requeuing the state. Must be greater than 0. + +## Important Notes + +1. **Do not catch signals**: Signals are designed to bubble up to the runtime for handling. Do not catch these exceptions in your node code. + +2. **Automatic handling**: The runtime automatically sends signals to the state manager when they are raised. + +3. **State lifecycle**: Signals affect the state's lifecycle in the state manager: + - `PruneSignal`: Sets state status to `PRUNED` + - `ReQueueAfterSignal`: Sets state status to `CREATED` and schedules requeue + +## Error Handling + +If signal sending fails (e.g., network issues), the runtime will log the error and continue processing other states. The failed signal will not be retried automatically. + +## Examples + +### Conditional Pruning + +```python +class ValidationNode(BaseNode): + class Inputs(BaseModel): + user_id: str + data: dict + + async def execute(self, inputs: Inputs) -> Outputs: + if not self._validate_user(inputs.user_id): + raise PruneSignal({ + "reason": "invalid_user", + "user_id": inputs.user_id, + "timestamp": datetime.now().isoformat() + }) + + return self.Outputs(validated=True) +``` + +### Polling + +```python +class PollingNode(BaseNode): + class Inputs(BaseModel): + job_id: str + + async def execute(self, inputs: Inputs) -> Outputs: + # Check if the job is complete + job_status = await self._check_job_status(inputs.job_id) + + if job_status == "completed": + result = await self._get_job_result(inputs.job_id) + return self.Outputs(result=result) + elif job_status == "failed": + # Job failed, prune the state + raise PruneSignal({ + "reason": "job_failed", + "job_id": inputs.job_id, + "poll_count": inputs.poll_count + }) + else: + # Job still running, poll again in 30 seconds + raise ReQueueAfterSignal(timedelta(seconds=30)) +``` \ No newline at end of file diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index df257a62..51df18d8 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -103,6 +103,7 @@ plugins: - exosphere/create-graph.md - exosphere/trigger-graph.md - exosphere/dashboard.md + - exosphere/signals.md - exosphere/architecture.md markdown_extensions: @@ -130,4 +131,5 @@ nav: - Create Graph: exosphere/create-graph.md - Trigger Graph: exosphere/trigger-graph.md - Dashboard: exosphere/dashboard.md + - Signals: exosphere/signals.md - Architecture: exosphere/architecture.md \ No newline at end of file From c13b53043fed85255d5f2f199d4036546c773751 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 16:55:14 +0530 Subject: [PATCH 17/19] Update prune_signal status check to validate against QUEUED state - Modified the status check in the prune_signal function to ensure that the state must be QUEUED before proceeding with pruning, enhancing the validation logic and error handling for state management operations. --- state-manager/app/controller/prune_signal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py index 14835290..74460c13 100644 --- a/state-manager/app/controller/prune_signal.py +++ b/state-manager/app/controller/prune_signal.py @@ -18,8 +18,8 @@ async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: Pr if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") - if state.status != StateStatusEnum.CREATED: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not created") + if state.status != StateStatusEnum.QUEUED: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") state.status = StateStatusEnum.PRUNED state.data = body.data From 49e2eb4e1b184f330f7e16a1505e9f955114b593 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 17:00:06 +0530 Subject: [PATCH 18/19] fixed all failing tests --- state-manager/.coverage | Bin 0 -> 69632 bytes .../tests/unit/controller/test_prune_signal.py | 14 +++++++------- 2 files changed, 7 insertions(+), 7 deletions(-) create mode 100644 state-manager/.coverage diff --git a/state-manager/.coverage b/state-manager/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..c086d6ebbc568c908ef8936e383e094905537430 GIT binary patch literal 69632 zcmeHQ3v?S-nbt_7w`Q&%u@&2~Ex%$re#DLwJ2t^358Daryf!)Qp>Zo@d2FkZB{4H< z?Dn)qO(8v$o^lpWTV7pYdti6N7EaGX%cBXf<+TS2O?fOSNy`?JP-wcOG<~J@{`by| zG?JxuMulYwSI540B+dQi|G)d+|K9&z-I;xR$0M>3n2w8cAvw@zDX}n&KjH@~ohpos@(r+|j{j(AD7+yGeO`6pTzk zDOs2Sh=drK3yBMXgTg{5IYGUeEX>Qw2LK{OW@7k^K$}{;Js=9xf+)nof}}{%7MW`I zv>#sdIQ`}2%n^a;6cWSm=gravKFFxDV;X?Mad8SuKPiS{;aNfI4204ehT|aTysT&| z#^=_`@jx^ZBWlK^h#ZN>0>ZoyPRhbmUIjZbY1#|O(`y4*@5TzfTPdktOUF*8aGsLt z>1yF=5H1$iXw4MFmxjVb=2&uW5-7heIhVkd10|FT%ke@b3ze1QsA+ER36q zgmnnVuNTD7jL@xqtGI=~v>}1%^tv7Xk@`ZWlQw~1Z~D-as&Y8}H8spFZem7iZK;_v z%Y?L=%;QG2ujYi1loT~b@1#glFMMP+_;DePOAb?Mt3c<;t$#_yG_KV*V^Td5m zMqULV^Iqc~&T1#8L-XDfm!t@EtGL^ehO|!ut+OEsdm*VcFd2_Wg-}fC8=%Av39cED z0z`)8z0wsK3~2#l%v|3}WYRhp?B@4Fqtn_K9L&vyVpD2=g}Y$|6?{K#ocKD56M=$( zR%nzGG|?PP4OIOWdn^sVI{}-e_A$7UamB5WuGpVTfTJ z3Q2*mD8QG{m$tXs-pCAaZ3TOAilbx~9?GTa8K=LagGnXT#;Ew2pah-zTu2jJeQV7! zlAJFy_OcrsE{t~cc7nOn(ShG8SHR9qsliEw1??<#tmA3tKuQL^YJorcq6|<5C`HWIu8kbQ3fajlmW^BWq>k38K4YM z1}FoR0m=YnfHLq`$$-ylt3}!`JbFqPTTPYv8-RiC-k$D04`^jY=Pc+Pdf~59K`Jz5 zfHFWCpbSt3Ce_bKl^-mAUUo|B%BdG>hR?r*y1-RoTET%UG@UDeJX zJ8yCJIez1~+p*8#vp-^&?al1t?CtER?FHLW+ch@c`mpr|>jvg&<_NQoL6%=w?n;YQ z;9&CrYe_k`4o!sPQ^JHWAD0rdu(xtHF3A&;47)M9U>gZ+*b*l~iNr)wjzpyi*isQb z7?cFqt|7xlkoEmGOWY`e2-|A}2}*6*02g)}Ej9;}MkqkwAli^5Dqk_9ms19OmqP$>!6 zSBT-py8y2Ct-_Hy0JQC$fL37?O{Z&bQR4bA zyDQDGtCvSb(!?3{1CU{_sfiO)+_$kf(w$TVz!Gk|5*Teqc8=TwX+Zo#73anJKS~LE-Kn;z9ztM!5c_QWIvc6xUsSv;?kh zwpyozga9|0ApSrE8bFF22ypyHSyasfcYNs{OR-six{?Q|UXxJSzBmZ`)FZIV8+LUQ z|B%ELwUYx}!~Rqi5nNUdH2DJa?lZxA25mA#hHS%L5XP|8Wr}OfakB@?8y?kW%Im6r z%ncBR9j^*OuCH<$6yQD0F&kaoW>u zv6>M^2c#3GpRrjaqis_u2eEg_fE7{(J%rwt;d-C)_&&K%x!!lcZ%EVd)#*`Q{sG#EA@WV<8gn>-RJs@W545!uhaW0 z+qE`i|28|ueZ+E-pX6WSzv!BB)mU%z9;Uw?p24H_?$4)2Y`|9yMeR7+9Vh2sBp+t`%1C`6sb z_L{2Wyzzg}9yZlvSPosoDHQ*&-NnK!JB&6XwfeQ8^YZw=dncQ!eamp#(5h<(o2oF3 zrqecCK=6K@#qw4t{_ikn5c%W(HDhe5v>5$z;{WzywJQ|=x9w(A4aEpnApUPH7Oz13 zzuFXW+JHfc*Y(p?Eu(A-6{BZ%{NG%R4^;+p=}||Msg|R&^}O+aqbW3%pn~y#gDEGG zS2lh8A220c4&1W%zuuISP*u$v|JQA2Q{{%O;SJ*d+PB*lP2%XfiB+bIQdKA&w5JWK z#uP6)meWw$2)}xSO>u_x(FDqk|Eo-~tCuGy{$E*4;}_C#dwQ?YYE+vkuQNQ4Dc7JvT8_q+}FDV0*0m=YnfHFWCpbSt3C z47y~2Kl-8!PzERilmW^BWq>k38K4YM1}FoR0m=Yn;4);uX5-v={eQ^!v<1D4{(w%S zpP}!dhtZeNzo1W{d(aVdGung3&;V*f0aS)O{OkOA{tW+1{>S`x`3L#?`OosV^7DL@ z-_LL7hxi`8k+0@?-pXC%p5cDQ{fzq#_b~VG+@NwG-_ETRO^0eaPR^@m#K3;u-eaG|9I9lMC zZ<+QrLtfo{t=01Ov!8gz)dcU#nugCJUnBm!@!H|z!;6cIrvctoHk@91{_Ij^0FJu@ zC*1En^V$^;EDaod=%GtL{M{n&smJHkD_5$+$92O?H~SocS~%WP`}*&6hrfI52Ydhd8%G=t7N5W}r8dZ$ZLeH2{PCtAEnKxU zc52XJ#a~$&JHxu&JcHN&hxn)ky^8*bo<(QSljy(DW9UE8f1t0!Jiz_vb1)O|QS=Y! zc62Mc30;rkXcmRg)o33YMk38K4YYmJHx% zl^Gf$+2A0_Hf|)@zyQhm`$@K81IgB}Cs|(~$=0nSS#K}NdU{BDY|R>ywYQV3t&L=@tt4B$nq)04Bx`OaSyL0q8XHO0&_J?4fMoUcB&(|MY5GEN#^&Htg@116%`~~v4UjfUBBUx!F$x2E{h7ieko@5+H zGM|rRUN6Z!9+J7;By+h)<|H!%i~}YDfO5N?WGqWEn~h{vD+?0xl|Pvd`T4wM1P0A+wOKpCJ6PzERilmW^B zWq>k38MwR|Ap8IM)A;xQ7vT;7zei`$f1^{d2jFq^DEdBn1RY0Tg?#{DL`TtQ(WlWT z(A}^X;0|;fIt=#)SU@uD2e=Li=-uc5x)SY$Jpp5A3wk@+i26`B>k38K4YM1}FpncMRBu;1>6` zL7Z&F$pB9Jak2p?>v7VDlXW=h#YqoN*5afaCtWz{#7PHE*5ITaCv7-s#mQ=%wBV!} zCrvnM#7P5A0ywG1NgYmVak2_0H8`opNfl03;>3@WN}N>SWCc#jaZ-kpQk<0F1mT3o z35OFOPP{nr;KYp+7fzfwap1&`6BZ{noLF%}?*DJS(NfFox1jHPp0FPFd>p+GZSw4O zFSvf|JnUTWxW~cSllB(&C+r7p7j2)n9%J@f`Utao?@Z$GB4OM?D_* zx7>ZM&p7rw&iFdLzp`CxL-ud8W86n9C;3UZC%_k7Q?458jot&6yS+7UY5}I6qNzMQ zH9RYX4+bUjSZjF3`M?HS-0;In#duN{Bq-HyLzZoZixqgLVxR4xafnbP2G?DGfK3e= zFBF-YiVBB9q99EqW3!1D1sg|O!MS(olHyD}9MqFp^ z!J=tlqzT>QGkbw*7Zz*CwqRH~^)bKUxQJ_OsMsgHClit&%E9c1(*odHcFgig!@{W5 zFA6geNfyK)er&F!f1rA|sb`-Sg3}(Y+*RzOm9qr|H9L#tO%#G4O;QMkrm@0OWF{7h z0$_*z+m;ctyeXxFirN6O#?-T2b3c5vILZ?V?bPP<={mLW3gp34*olW>XD@ zwbKbk0JDt@n{8{cc=+M=@rJ`f+ERcGAt*-@YAw2n0j7yUfI&9 zAY9P$V}SPn@Cr3v-P#xY@DCsM7M2XZSo zx^7~XyU#LaRv}?t2t$t!-4?jyC^pKPVx4CiN_)b0wfkF^QPTp6VqDbAQ)P->oe>Et zk|xed7){-2TAY~T#$f=FfboI=1o=%he$|NxX&kMSsw|pk6-x1MAW(&=2af|o(~+48 zNs#48YzEi(iek#^46mFkWg>=&RmPw~EG#Z0plgKdmzgnprMT|urCh1ycZRR7R>zc( z5Mon`KM;WikfP+BY|3vI2X}lyS(r;iL$d1l0192jrh3gnW&7gb^`U5FDuf5Aq$5k> zisDWAQ%%*$i5zJ11;BA;c+a3shDbB;LBR43vr2;h)S9E$ROFJmysqj!{G?@vSrN$f zRZfFKEI9{ExXlgIiL2CHxFQ*VxJ)%8RUw@=PKfnR8rDVvtBndW_@s$_+Tk_NVx-b(#un?s)m}2~>1?Ki(e(v7VXXZ&$qm6& z7;14sn2$?|SwR$L@w5=B3sgY)$fXmd-JO84Anb=`@%w|deWv1hLL5xf^A(si=9t<@2p0Y>G1+ zV5_2NbDL>ls!aJI4JAD(`6r*;EJmDslX9kdI#n(GB$KI0)=YV=?W${AP}_C*_x}Ty zyDaE?aDV@|!LRebz=!w-?w8z$;YocZz7xK~z7FqG-g~@Pd#gPsJskZZo%+t&f zW*>trzktiV$!qffYe_kcPG{up%GblsJmN+X((C_QHo%1q&!c3RjfLi5Dm@g&GwJK# zXC()XUU^Pg|G&Bqt~+SBkZwi&&Da0;tOM+pVz3LX|8MODL~k*Ox*4m@Jpf|3j`NO| zcCQ79CZlrbCe8}2|L^PuzAZxv2kTG`$SXcb1$blPSM2o}1# zqe$Kgt^aRb1AxX2LYZmK9oCGt14wBR`sJ+u?`|tnyF%;#yIKLTp$NeWtpD#?T_j$C z_5aZpGsLM@n`ixhq#1Bf5qf5?|8Hpm7_&Z98Ph1TfrXAY0*2w1qqB7^O7_AKMr$yG zrV>;Gf5fK+bO8r~oj-2^{( zTr`KH>n28OKn%l1sVbCSv`-t<_G&Y{=vYoektqu>rfpS##2M8`6DW86e{7{0cJ=b) ztpAVt0m!h|)WlgnL>{RG9Me3fanhYs1;7}#KfPYL>;J0fa5o5t@+pgujB!$*CbT7FV4IEzmo%8 z!~RsYU~pME(BunfpPS&<(1v->psDNs!(I@^u+?RXYu;hA2g(~B)n>};3~$U05QZJE z3ZYD>W=}S4bpepsQJK0pt-u1E!_FcI(P=a8fbxcW@;95l8?^(P;bx&j%US>5!UB%r zF^UdHr`ebdFbvm|&ei3v|8KFH5k?236Q&=2%r6;jo4Rs1{phaBG1DqXD-%JE?A49vd_>aga5^ll`ld O%^Ly8>||3m#Qy@5H2k#y literal 0 HcmV?d00001 diff --git a/state-manager/tests/unit/controller/test_prune_signal.py b/state-manager/tests/unit/controller/test_prune_signal.py index 66349b26..1c36170a 100644 --- a/state-manager/tests/unit/controller/test_prune_signal.py +++ b/state-manager/tests/unit/controller/test_prune_signal.py @@ -33,7 +33,7 @@ def mock_prune_request(self): def mock_state_created(self): state = MagicMock() state.id = PydanticObjectId() - state.status = StateStatusEnum.CREATED + state.status = StateStatusEnum.QUEUED state.enqueue_after = 1234567890 return state @@ -94,7 +94,7 @@ async def test_prune_signal_state_not_found( assert exc_info.value.detail == "State not found" @patch('app.controller.prune_signal.State') - async def test_prune_signal_invalid_status_queued( + async def test_prune_signal_invalid_status_created( self, mock_state_class, mock_namespace, @@ -105,7 +105,7 @@ async def test_prune_signal_invalid_status_queued( """Test when state is in QUEUED status (invalid for pruning)""" # Arrange mock_state = MagicMock() - mock_state.status = StateStatusEnum.QUEUED + mock_state.status = StateStatusEnum.CREATED mock_state_class.find_one = AsyncMock(return_value=mock_state) # Act & Assert @@ -118,7 +118,7 @@ async def test_prune_signal_invalid_status_queued( ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST - assert exc_info.value.detail == "State is not created" + assert exc_info.value.detail == "State is not queued" @patch('app.controller.prune_signal.State') async def test_prune_signal_invalid_status_executed( @@ -145,7 +145,7 @@ async def test_prune_signal_invalid_status_executed( ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST - assert exc_info.value.detail == "State is not created" + assert exc_info.value.detail == "State is not queued" @patch('app.controller.prune_signal.State') async def test_prune_signal_invalid_status_errored( @@ -172,7 +172,7 @@ async def test_prune_signal_invalid_status_errored( ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST - assert exc_info.value.detail == "State is not created" + assert exc_info.value.detail == "State is not queued" @patch('app.controller.prune_signal.State') async def test_prune_signal_invalid_status_pruned( @@ -199,7 +199,7 @@ async def test_prune_signal_invalid_status_pruned( ) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST - assert exc_info.value.detail == "State is not created" + assert exc_info.value.detail == "State is not queued" @patch('app.controller.prune_signal.State') async def test_prune_signal_database_error( From ea760a1382323d8cc311d2ed3a9744fd7823189f Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Sat, 30 Aug 2025 17:06:55 +0530 Subject: [PATCH 19/19] namespace check would be added as a seprate unit later to take care of permissions --- state-manager/app/controller/prune_signal.py | 2 +- state-manager/app/controller/re_queue_after_signal.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/state-manager/app/controller/prune_signal.py b/state-manager/app/controller/prune_signal.py index 74460c13..122e93f0 100644 --- a/state-manager/app/controller/prune_signal.py +++ b/state-manager/app/controller/prune_signal.py @@ -13,7 +13,7 @@ async def prune_signal(namespace_name: str, state_id: PydanticObjectId, body: Pr try: logger.info(f"Received prune signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - state = await State.find_one(State.id == state_id, State.namespace_name == namespace_name) + state = await State.find_one(State.id == state_id) if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") diff --git a/state-manager/app/controller/re_queue_after_signal.py b/state-manager/app/controller/re_queue_after_signal.py index 57eb2256..009f1424 100644 --- a/state-manager/app/controller/re_queue_after_signal.py +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -14,7 +14,7 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, try: logger.info(f"Received re-queue after signal for state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - state = await State.find_one(State.id == state_id, State.namespace_name == namespace_name) + state = await State.find_one(State.id == state_id) if not state: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found")