diff --git a/state-manager/app/controller/get_graph_template.py b/state-manager/app/controller/get_graph_template.py new file mode 100644 index 00000000..7015bbd4 --- /dev/null +++ b/state-manager/app/controller/get_graph_template.py @@ -0,0 +1,48 @@ +from app.singletons.logs_manager import LogsManager +from app.models.graph_models import UpsertGraphTemplateResponse +from app.models.db.graph_template_model import GraphTemplate +from fastapi import HTTPException, status + +logger = LogsManager().get_logger() + + +async def get_graph_template(namespace_name: str, graph_name: str, x_exosphere_request_id: str) -> UpsertGraphTemplateResponse: + try: + graph_template = await GraphTemplate.find_one( + GraphTemplate.name == graph_name, + GraphTemplate.namespace == namespace_name + ) + + if not graph_template: + logger.error( + "Graph template not found", + graph_name=graph_name, + namespace_name=namespace_name, + x_exosphere_request_id=x_exosphere_request_id, + ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Graph template {graph_name} not found in namespace {namespace_name}") + + logger.info( + "Graph template retrieved", + graph_name=graph_name, + namespace_name=namespace_name, + x_exosphere_request_id=x_exosphere_request_id, + ) + + return UpsertGraphTemplateResponse( + nodes=graph_template.nodes, + validation_status=graph_template.validation_status, + validation_errors=graph_template.validation_errors, + secrets={secret_name: True for secret_name in graph_template.secrets.keys()}, + created_at=graph_template.created_at, + updated_at=graph_template.updated_at, + ) + except Exception as e: + logger.error( + "Error retrieving graph template", + error=e, + graph_name=graph_name, + namespace_name=namespace_name, + x_exosphere_request_id=x_exosphere_request_id, + ) + raise \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 659256d9..553f963b 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -20,6 +20,7 @@ from .models.graph_models import UpsertGraphTemplateRequest, UpsertGraphTemplateResponse from .controller.upsert_graph_template import upsert_graph_template as upsert_graph_template_controller +from .controller.get_graph_template import get_graph_template as get_graph_template_controller from .models.register_nodes_request import RegisterNodesRequestModel from .models.register_nodes_response import RegisterNodesResponseModel @@ -132,6 +133,25 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id, background_tasks) +@router.get( + "/graph/{graph_name}", + response_model=UpsertGraphTemplateResponse, + status_code=status.HTTP_200_OK, + response_description="Graph template retrieved successfully", + tags=["graph"] +) +async def get_graph_template(namespace_name: str, graph_name: str, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await get_graph_template_controller(namespace_name, graph_name, x_exosphere_request_id) + + @router.put( "/nodes/", response_model=RegisterNodesResponseModel, diff --git a/state-manager/app/tasks/verify_graph.py b/state-manager/app/tasks/verify_graph.py index 3ca70d0e..73dcf5d6 100644 --- a/state-manager/app/tasks/verify_graph.py +++ b/state-manager/app/tasks/verify_graph.py @@ -3,6 +3,8 @@ from app.models.db.registered_node import RegisteredNode from app.singletons.logs_manager import LogsManager from beanie.operators import In +from json_schema_to_pydantic import create_model +from collections import deque logger = LogsManager().get_logger() @@ -16,26 +18,11 @@ async def verify_nodes_namespace(nodes: list[NodeTemplate], graph_namespace: str if node.namespace != graph_namespace and node.namespace != "exospherehost": errors.append(f"Node {node.identifier} has invalid namespace '{node.namespace}'. Must match graph namespace '{graph_namespace}' or use universal namespace 'exospherehost'") -async def verify_node_exists(nodes: list[NodeTemplate], graph_namespace: str, errors: list[str]): - graph_namespace_node_names = [ - node.node_name for node in nodes if node.namespace == graph_namespace - ] - graph_namespace_database_nodes = await RegisteredNode.find( - In(RegisteredNode.name, graph_namespace_node_names), - RegisteredNode.namespace == graph_namespace - ).to_list() - exospherehost_node_names = [ - node.node_name for node in nodes if node.namespace == "exospherehost" - ] - exospherehost_database_nodes = await RegisteredNode.find( - In(RegisteredNode.name, exospherehost_node_names), - RegisteredNode.namespace == "exospherehost" - ).to_list() - - template_nodes = set([(node.node_name, node.namespace) for node in nodes]) - database_nodes = set([(node.name, node.namespace) for node in graph_namespace_database_nodes + exospherehost_database_nodes]) +async def verify_node_exists(nodes: list[NodeTemplate], database_nodes: list[RegisteredNode], errors: list[str]): + template_nodes_set = set([(node.node_name, node.namespace) for node in nodes]) + database_nodes_set = set([(node.name, node.namespace) for node in database_nodes]) - nodes_not_found = template_nodes - database_nodes + nodes_not_found = template_nodes_set - database_nodes_set for node in nodes_not_found: errors.append(f"Node {node[0]} in namespace {node[1]} does not exist.") @@ -68,13 +55,169 @@ async def verify_node_identifiers(nodes: list[NodeTemplate], errors: list[str]): if next_node not in valid_identifiers: errors.append(f"Node {node.node_name} in namespace {node.namespace} has a next node {next_node} that does not exist in the graph") +async def verify_secrets(graph_template: GraphTemplate, database_nodes: list[RegisteredNode], errors: list[str]): + required_secrets_set = set() + + for node in database_nodes: + if node.secrets is None: + continue + for secret in node.secrets: + required_secrets_set.add(secret) + + present_secrets_set = set() + for secret_name in graph_template.secrets.keys(): + present_secrets_set.add(secret_name) + + missing_secrets_set = required_secrets_set - present_secrets_set + + for secret_name in missing_secrets_set: + errors.append(f"Secret {secret_name} is required but not present in the graph template") + + +async def get_database_nodes(nodes: list[NodeTemplate], graph_namespace: str): + graph_namespace_node_names = [ + node.node_name for node in nodes if node.namespace == graph_namespace + ] + graph_namespace_database_nodes = await RegisteredNode.find( + In(RegisteredNode.name, graph_namespace_node_names), + RegisteredNode.namespace == graph_namespace + ).to_list() + exospherehost_node_names = [ + node.node_name for node in nodes if node.namespace == "exospherehost" + ] + exospherehost_database_nodes = await RegisteredNode.find( + In(RegisteredNode.name, exospherehost_node_names), + RegisteredNode.namespace == "exospherehost" + ).to_list() + return graph_namespace_database_nodes + exospherehost_database_nodes + + +async def verify_inputs(graph_nodes: list[NodeTemplate], database_nodes: list[RegisteredNode], dependencies_graph: dict[str, set[str]], errors: list[str]): + look_up_table = {} + for node in graph_nodes: + look_up_table[node.identifier] = {"graph_node": node} + + for database_node in database_nodes: + if database_node.name == node.node_name and database_node.namespace == node.namespace: + look_up_table[node.identifier]["database_node"] = database_node + break + + for node in graph_nodes: + try: + model_class = create_model(look_up_table[node.identifier]["database_node"].inputs_schema) + + for field_name, field_info in model_class.model_fields.items(): + if field_info.annotation is not str: + errors.append(f"{node.node_name}.Inputs field '{field_name}' must be of type str, got {field_info.annotation}") + continue + + if field_name not in look_up_table[node.identifier]["graph_node"].inputs.keys(): + errors.append(f"{node.node_name}.Inputs field '{field_name}' not found in graph template") + continue + + # get ${{ identifier.outputs.field_name }} objects from the string + splits = look_up_table[node.identifier]["graph_node"].inputs[field_name].split("${{") + for split in splits[1:]: + if "}}" in split: + + identifier = None + field = None + + syntax_string = split.split("}}")[0].strip() + + if syntax_string.startswith("identifier.") and len(syntax_string.split(".")) == 3: + identifier = syntax_string.split(".")[1].strip() + field = syntax_string.split(".")[2].strip() + else: + errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {syntax_string} which is not a valid output field") + continue + + if identifier is None or field is None: + errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {syntax_string} which is not a valid output field") + continue + + if identifier not in dependencies_graph[node.identifier]: + errors.append(f"{node.node_name}.Inputs field '{field_name}' references node {identifier} which is not a dependency of {node.identifier}") + continue + + output_model_class = create_model(look_up_table[identifier]["database_node"].outputs_schema) + if field not in output_model_class.model_fields.keys(): + errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {field} of node {identifier} which is not a valid output field") + continue + + except Exception as e: + errors.append(f"Error creating input model for node {node.identifier}: {str(e)}") + +async def build_dependencies_graph(graph_nodes: list[NodeTemplate]): + dependency_graph = {} + for node in graph_nodes: + dependency_graph[node.identifier] = set() + if node.next_nodes is None: + continue + for next_node in node.next_nodes: + dependency_graph[next_node].add(node.identifier) + dependency_graph[next_node] = dependency_graph[next_node] | dependency_graph[node.identifier] + return dependency_graph + +async def verify_topology(graph_nodes: list[NodeTemplate], errors: list[str]): + # verify that the graph is a tree + # verify that the graph is connected + dependencies = {} + identifier_to_node = {} + visited = {} + + for node in graph_nodes: + if node.identifier in dependencies.keys(): + errors.append(f"Multiple identifier {node.identifier} incorrect topology") + return + dependencies[node.identifier] = set() + identifier_to_node[node.identifier] = node + visited[node.identifier] = False + + # verify that there exists only one root node + for node in graph_nodes: + if node.next_nodes is None: + continue + for next_node in node.next_nodes: + dependencies[next_node].add(node.identifier) + + # verify that there exists only one root node + root_nodes = [node for node in graph_nodes if len(dependencies[node.identifier]) == 0] + if len(root_nodes) != 1: + errors.append(f"Graph has {len(root_nodes)} root nodes, expected 1") + return + + # verify that the graph is a tree + to_visit = deque([root_nodes[0].identifier]) + + while len(to_visit) > 0: + current_node = to_visit.popleft() + visited[current_node] = True + + if identifier_to_node[current_node].next_nodes is None: + continue + + for next_node in identifier_to_node[current_node].next_nodes: + if visited[next_node]: + errors.append(f"Graph is not a tree at {current_node} -> {next_node}") + else: + to_visit.append(next_node) + + for identifier, visited_value in visited.items(): + if not visited_value: + errors.append(f"Graph is not connected at {identifier}") + async def verify_graph(graph_template: GraphTemplate): try: errors = [] + database_nodes = await get_database_nodes(graph_template.nodes, graph_template.namespace) + await verify_nodes_names(graph_template.nodes, errors) await verify_nodes_namespace(graph_template.nodes, graph_template.namespace, errors) - await verify_node_exists(graph_template.nodes, graph_template.namespace, errors) + await verify_node_exists(graph_template.nodes, database_nodes, errors) await verify_node_identifiers(graph_template.nodes, errors) + await verify_secrets(graph_template, database_nodes, errors) + await verify_topology(graph_template.nodes, errors) if errors: graph_template.validation_status = GraphTemplateValidationStatus.INVALID @@ -82,6 +225,9 @@ async def verify_graph(graph_template: GraphTemplate): await graph_template.save() return + dependencies_graph = await build_dependencies_graph(graph_template.nodes) + await verify_inputs(graph_template.nodes, database_nodes, dependencies_graph, errors) + graph_template.validation_status = GraphTemplateValidationStatus.VALID graph_template.validation_errors = None await graph_template.save() diff --git a/state-manager/pyproject.toml b/state-manager/pyproject.toml index 2764246e..139d831c 100644 --- a/state-manager/pyproject.toml +++ b/state-manager/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "beanie>=2.0.0", "cryptography>=45.0.5", "fastapi>=0.116.1", + "json-schema-to-pydantic>=0.4.1", "python-dotenv>=1.1.1", "structlog>=25.4.0", "uvicorn>=0.35.0", diff --git a/state-manager/uv.lock b/state-manager/uv.lock index d795a305..cb9711cd 100644 --- a/state-manager/uv.lock +++ b/state-manager/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" [[package]] @@ -171,6 +171,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "json-schema-to-pydantic" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f0/25/c9d8590a698a1cef53859b9a6ff32c79a758f16af4ab37118e4529503b2b/json_schema_to_pydantic-0.4.1.tar.gz", hash = "sha256:218df347563ce91d6214614310723db986e9de38f2bd0f683368a78fd0761a7a", size = 40975, upload-time = "2025-07-14T19:05:30.418Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/65/54ac92e3d1346ff21bb7e5b15078046fea552517c2d5d0184e5643074f36/json_schema_to_pydantic-0.4.1-py3-none-any.whl", hash = "sha256:83ecc23c4f44ad013974bd9dfef6475097ea130dc83872d0152f93a953f56564", size = 12969, upload-time = "2025-07-14T19:05:29.289Z" }, +] + [[package]] name = "lazy-model" version = "0.3.0" @@ -351,6 +363,7 @@ dependencies = [ { name = "beanie" }, { name = "cryptography" }, { name = "fastapi" }, + { name = "json-schema-to-pydantic" }, { name = "python-dotenv" }, { name = "structlog" }, { name = "uvicorn" }, @@ -366,6 +379,7 @@ requires-dist = [ { name = "beanie", specifier = ">=2.0.0" }, { name = "cryptography", specifier = ">=45.0.5" }, { name = "fastapi", specifier = ">=0.116.1" }, + { name = "json-schema-to-pydantic", specifier = ">=0.4.1" }, { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "structlog", specifier = ">=25.4.0" }, { name = "uvicorn", specifier = ">=0.35.0" },