diff --git a/state-manager/app/controller/executed_state.py b/state-manager/app/controller/executed_state.py index 7712bc81..b07fe167 100644 --- a/state-manager/app/controller/executed_state.py +++ b/state-manager/app/controller/executed_state.py @@ -1,14 +1,15 @@ from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel from bson import ObjectId -from fastapi import HTTPException, status +from fastapi import HTTPException, status, BackgroundTasks from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager +from app.tasks.create_next_state import create_next_state logger = LogsManager().get_logger() -async def executed_state(namespace_name: str, state_id: ObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str) -> ExecutedResponseModel: +async def executed_state(namespace_name: str, state_id: ObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str, background_tasks: BackgroundTasks) -> ExecutedResponseModel: try: logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -20,9 +21,36 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed if state.status != StateStatusEnum.QUEUED: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") - await State.find_one(State.id == state_id).set( - {"status": StateStatusEnum.EXECUTED, "outputs": body.outputs} - ) + if len(body.outputs) == 0: + await State.find_one(State.id == state_id).set( + {"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: ObjectId(state.id)}} + ) + + background_tasks.add_task(create_next_state, state) + + else: + await State.find_one(State.id == state_id).set( + {"status": StateStatusEnum.EXECUTED, "outputs": body.outputs[0], "parents": {**state.parents, state.identifier: ObjectId(state.id)}} + ) + background_tasks.add_task(create_next_state, state) + + for output in body.outputs[1:]: + new_state = State( + node_name=state.node_name, + namespace_name=state.namespace_name, + identifier=state.identifier, + graph_name=state.graph_name, + status=StateStatusEnum.CREATED, + inputs=state.inputs, + outputs=output, + error=None, + parents={ + **state.parents, + state.identifier: ObjectId(state.id) + } + ) + await new_state.save() + background_tasks.add_task(create_next_state, new_state) return ExecutedResponseModel(status=StateStatusEnum.EXECUTED) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index d55df68a..6989e1a5 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -1,3 +1,4 @@ +from bson import ObjectId from .base import BaseDatabaseModel from ..state_status_enum import StateStatusEnum from pydantic import Field @@ -13,4 +14,5 @@ class State(BaseDatabaseModel): status: StateStatusEnum = Field(..., description="Status of the state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") - error: Optional[str] = Field(None, description="Error message") \ No newline at end of file + error: Optional[str] = Field(None, description="Error message") + parents: dict[str, ObjectId] = Field(default_factory=dict, description="Parents of the state") \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 553f963b..3b9f0e0b 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -81,7 +81,7 @@ async def create_state(namespace_name: str, graph_name: str, body: CreateRequest response_description="State executed successfully", tags=["state"] ) -async def executed_state_route(namespace_name: str, state_id: str, body: ExecutedRequestModel, request: Request, api_key: str = Depends(check_api_key)): +async def executed_state_route(namespace_name: str, state_id: str, body: ExecutedRequestModel, request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(check_api_key)): x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) @@ -91,7 +91,7 @@ async def executed_state_route(namespace_name: str, state_id: str, body: Execute logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") - return await executed_state(namespace_name, ObjectId(state_id), body, x_exosphere_request_id) + return await executed_state(namespace_name, ObjectId(state_id), body, x_exosphere_request_id, background_tasks) @router.post( diff --git a/state-manager/app/tasks/create_next_state.py b/state-manager/app/tasks/create_next_state.py new file mode 100644 index 00000000..268e818f --- /dev/null +++ b/state-manager/app/tasks/create_next_state.py @@ -0,0 +1,124 @@ +import asyncio +import time + +from bson import ObjectId + +from app.models.db.state import State +from app.models.db.graph_template_model import GraphTemplate +from app.models.graph_template_validation_status import GraphTemplateValidationStatus +from app.models.db.registered_node import RegisteredNode +from app.models.state_status_enum import StateStatusEnum + +from json_schema_to_pydantic import create_model + +async def create_next_state(state: State): + graph_template = None + + try: + start_time = time.time() + timeout_seconds = 300 # 5 minutes + + while True: + graph_template = await GraphTemplate.find_one(GraphTemplate.name == state.graph_name, GraphTemplate.namespace == state.namespace_name) + if not graph_template: + raise Exception(f"Graph template {state.graph_name} not found") + if graph_template.validation_status == GraphTemplateValidationStatus.VALID: + break + + # Check if we've exceeded the timeout + if time.time() - start_time > timeout_seconds: + raise Exception(f"Timeout waiting for graph template {state.graph_name} to become valid after {timeout_seconds} seconds") + + await asyncio.sleep(1) + + node_template = graph_template.get_node_by_identifier(state.identifier) + if not node_template: + raise Exception(f"Node template {state.identifier} not found") + + next_node_identifier = node_template.next_nodes + if not next_node_identifier: + raise Exception(f"Node template {state.identifier} has no next nodes") + + cache_states = {} + + for identifier in next_node_identifier: + next_node_template = graph_template.get_node_by_identifier(identifier) + if not next_node_template: + continue + + registered_node = await RegisteredNode.find_one(RegisteredNode.name == next_node_template.node_name, RegisteredNode.namespace == next_node_template.namespace) + + if not registered_node: + raise Exception(f"Registered node {next_node_template.node_name} not found") + + next_node_input_model = create_model(registered_node.inputs_schema) + next_node_input_data = {} + + for field_name, _ in next_node_input_model.model_fields.items(): + temporary_input = next_node_template.inputs[field_name] + splits = temporary_input.split("${{") + + if len(splits) == 0: + next_node_input_data[field_name] = temporary_input + continue + + constructed_string = "" + for split in splits: + if "}}" in split: + placeholder_content = split.split("}}")[0] + parts = [p.strip() for p in placeholder_content.split('.')] + + if len(parts) != 3 or parts[1] != 'outputs': + raise Exception(f"Invalid input placeholder format: '{placeholder_content}' for field {field_name}") + + input_identifier = parts[0] + input_field = parts[2] + + parent_id = state.parents.get(input_identifier) + + if not parent_id: + raise Exception(f"Parent identifier '{input_identifier}' not found in state parents.") + + if parent_id not in cache_states: + dependent_state = await State.get(ObjectId(parent_id)) + if not dependent_state: + raise Exception(f"Dependent state {input_identifier} not found") + cache_states[parent_id] = dependent_state + else: + dependent_state = cache_states[parent_id] + + if input_field not in dependent_state.outputs: + raise Exception(f"Input field {input_field} not found in dependent state {input_identifier}") + + constructed_string += dependent_state.outputs[input_field] + split.split("}}")[1] + + else: + constructed_string += split + + next_node_input_data[field_name] = constructed_string + + new_state = State( + node_name=next_node_template.node_name, + namespace_name=next_node_template.namespace, + identifier=next_node_template.identifier, + graph_name=state.graph_name, + status=StateStatusEnum.CREATED, + inputs=next_node_input_data, + outputs={}, + error=None, + parents={ + **state.parents, + next_node_template.identifier: ObjectId(state.id) + } + ) + + await new_state.save() + + state.status = StateStatusEnum.SUCCESS + await state.save() + + except Exception as e: + state.status = StateStatusEnum.ERRORED + state.error = str(e) + await state.save() + return