diff --git a/state-manager/app/controller/create_states.py b/state-manager/app/controller/create_states.py index 442f0638..435825e7 100644 --- a/state-manager/app/controller/create_states.py +++ b/state-manager/app/controller/create_states.py @@ -1,24 +1,44 @@ +from fastapi import HTTPException + from app.singletons.logs_manager import LogsManager from app.models.create_models import CreateRequestModel, CreateResponseModel, ResponseStateModel from app.models.state_status_enum import StateStatusEnum from app.models.db.state import State +from app.models.db.graph_template_model import GraphTemplate +from app.models.node_template_model import NodeTemplate from beanie.operators import In from bson import ObjectId logger = LogsManager().get_logger() -async def create_states(namespace_name: str, body: CreateRequestModel, x_exosphere_request_id: str) -> CreateResponseModel: + +def get_node_template(graph_template: GraphTemplate, identifier: str) -> NodeTemplate: + node = graph_template.get_node_by_identifier(identifier) + if not node: + raise HTTPException(status_code=404, detail="Node template not found") + return node + + +async def create_states(namespace_name: str, graph_name: str, body: CreateRequestModel, x_exosphere_request_id: str) -> CreateResponseModel: try: states = [] logger.info(f"Creating states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + graph_template = await GraphTemplate.find_one(GraphTemplate.name == graph_name, GraphTemplate.namespace == namespace_name) + if not graph_template: + raise HTTPException(status_code=404, detail="Graph template not found") for state in body.states: + + node_template = get_node_template(graph_template, state.identifier) + states.append( State( - node_name=state.node_name, - namespace_name=namespace_name, - graph_name=state.graph_name, + identifier=state.identifier, + node_name=node_template.node_name, + namespace_name=node_template.namespace, + graph_name=graph_name, status=StateStatusEnum.CREATED, inputs=state.inputs, outputs={}, diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index e028d44a..e7379020 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -40,6 +40,7 @@ async def enqueue_states(namespace_name: str, body: EnqueueRequestModel, x_exosp StateModel( state_id=str(state.id), node_name=state.node_name, + identifier=state.identifier, inputs=state.inputs, created_at=state.created_at ) diff --git a/state-manager/app/models/create_models.py b/state-manager/app/models/create_models.py index ba672e0f..063dbf39 100644 --- a/state-manager/app/models/create_models.py +++ b/state-manager/app/models/create_models.py @@ -5,14 +5,14 @@ class RequestStateModel(BaseModel): - node_name: str = Field(..., description="Name of the node of the state") - graph_name: str = Field(..., description="Name of the graph template for this state") + identifier: str = Field(..., description="Unique identifier of the node template within the graph template") inputs: dict[str, Any] = Field(..., description="Inputs of the state") class ResponseStateModel(BaseModel): state_id: str = Field(..., description="ID of the state") node_name: str = Field(..., description="Name of the node of the state") + identifier: str = Field(..., description="Identifier of the node for which state is created") graph_name: str = Field(..., description="Name of the graph template for this state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") created_at: datetime = Field(..., description="Date and time when the state was created") diff --git a/state-manager/app/models/db/graph_template_model.py b/state-manager/app/models/db/graph_template_model.py index 21f9c839..ee3d8d35 100644 --- a/state-manager/app/models/db/graph_template_model.py +++ b/state-manager/app/models/db/graph_template_model.py @@ -9,6 +9,7 @@ from typing import Dict from app.utils.encrypter import get_encrypter + class GraphTemplate(BaseDatabaseModel): name: str = Field(..., description="Name of the graph") namespace: str = Field(..., description="Namespace of the graph") @@ -26,6 +27,13 @@ class Settings: ) ] + def get_node_by_identifier(self, identifier: str) -> NodeTemplate | None: + """Get a node by its identifier using O(1) dictionary lookup.""" + for node in self.nodes: + if node.identifier == identifier: + return node + return None + @field_validator('secrets') @classmethod def validate_secrets(cls, v: Dict[str, str]) -> Dict[str, str]: diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index c0e687c5..d55df68a 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -8,6 +8,7 @@ class State(BaseDatabaseModel): node_name: str = Field(..., description="Name of the node of the state") namespace_name: str = Field(..., description="Name of the namespace of the state") + identifier: str = Field(..., description="Identifier of the node for which state is created") graph_name: str = Field(..., description="Name of the graph template for this state") status: StateStatusEnum = Field(..., description="Status of the state") inputs: dict[str, Any] = Field(..., description="Inputs of the state") diff --git a/state-manager/app/models/enqueue_response.py b/state-manager/app/models/enqueue_response.py index 65b46aaa..13150dc8 100644 --- a/state-manager/app/models/enqueue_response.py +++ b/state-manager/app/models/enqueue_response.py @@ -6,6 +6,7 @@ class StateModel(BaseModel): state_id: str = Field(..., description="ID of the state") node_name: str = Field(..., description="Name of the node of the state") + identifier: str = Field(..., description="Identifier of the node for which state is created") inputs: dict[str, Any] = Field(..., description="Inputs of the state") created_at: datetime = Field(..., description="Date and time when the state was created") diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 70fbef13..f532687c 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -55,13 +55,13 @@ async def enqueue_state(namespace_name: str, body: EnqueueRequestModel, request: @router.post( - "/states/create", + "/graph/{graph_name}/states/create", response_model=CreateResponseModel, status_code=status.HTTP_200_OK, response_description="States created successfully", tags=["state"] ) -async def create_state(namespace_name: str, body: CreateRequestModel, request: Request, api_key: str = Depends(check_api_key)): +async def create_state(namespace_name: str, graph_name: str, body: CreateRequestModel, request: Request, api_key: str = Depends(check_api_key)): x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) @@ -71,7 +71,7 @@ async def create_state(namespace_name: str, body: CreateRequestModel, request: R logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") - return await create_states(namespace_name, body, x_exosphere_request_id) + return await create_states(namespace_name, graph_name, body, x_exosphere_request_id) @router.post(