Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions state-manager/app/controller/get_graph_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from app.singletons.logs_manager import LogsManager
from app.models.graph_models import UpsertGraphTemplateResponse
from app.models.db.graph_template_model import GraphTemplate
from fastapi import HTTPException, status

logger = LogsManager().get_logger()


async def get_graph_template(namespace_name: str, graph_name: str, x_exosphere_request_id: str) -> UpsertGraphTemplateResponse:
try:
graph_template = await GraphTemplate.find_one(
GraphTemplate.name == graph_name,
GraphTemplate.namespace == namespace_name
)

if not graph_template:
logger.error(
"Graph template not found",
graph_name=graph_name,
namespace_name=namespace_name,
x_exosphere_request_id=x_exosphere_request_id,
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Graph template {graph_name} not found in namespace {namespace_name}")

logger.info(
"Graph template retrieved",
graph_name=graph_name,
namespace_name=namespace_name,
x_exosphere_request_id=x_exosphere_request_id,
)

return UpsertGraphTemplateResponse(
nodes=graph_template.nodes,
validation_status=graph_template.validation_status,
validation_errors=graph_template.validation_errors,
secrets={secret_name: True for secret_name in graph_template.secrets.keys()},
created_at=graph_template.created_at,
updated_at=graph_template.updated_at,
)
except Exception as e:
logger.error(
"Error retrieving graph template",
error=e,
graph_name=graph_name,
namespace_name=namespace_name,
x_exosphere_request_id=x_exosphere_request_id,
)
raise
20 changes: 20 additions & 0 deletions state-manager/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .models.graph_models import UpsertGraphTemplateRequest, UpsertGraphTemplateResponse
from .controller.upsert_graph_template import upsert_graph_template as upsert_graph_template_controller
from .controller.get_graph_template import get_graph_template as get_graph_template_controller

from .models.register_nodes_request import RegisterNodesRequestModel
from .models.register_nodes_response import RegisterNodesResponseModel
Expand Down Expand Up @@ -132,6 +133,25 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse
return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id, background_tasks)


@router.get(
"/graph/{graph_name}",
response_model=UpsertGraphTemplateResponse,
status_code=status.HTTP_200_OK,
response_description="Graph template retrieved successfully",
tags=["graph"]
)
async def get_graph_template(namespace_name: str, graph_name: str, request: Request, api_key: str = Depends(check_api_key)):
x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4()))

if api_key:
logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
else:
logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return await get_graph_template_controller(namespace_name, graph_name, x_exosphere_request_id)


@router.put(
"/nodes/",
response_model=RegisterNodesResponseModel,
Expand Down
186 changes: 166 additions & 20 deletions state-manager/app/tasks/verify_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from app.models.db.registered_node import RegisteredNode
from app.singletons.logs_manager import LogsManager
from beanie.operators import In
from json_schema_to_pydantic import create_model
from collections import deque

logger = LogsManager().get_logger()

Expand All @@ -16,26 +18,11 @@ async def verify_nodes_namespace(nodes: list[NodeTemplate], graph_namespace: str
if node.namespace != graph_namespace and node.namespace != "exospherehost":
errors.append(f"Node {node.identifier} has invalid namespace '{node.namespace}'. Must match graph namespace '{graph_namespace}' or use universal namespace 'exospherehost'")

async def verify_node_exists(nodes: list[NodeTemplate], graph_namespace: str, errors: list[str]):
graph_namespace_node_names = [
node.node_name for node in nodes if node.namespace == graph_namespace
]
graph_namespace_database_nodes = await RegisteredNode.find(
In(RegisteredNode.name, graph_namespace_node_names),
RegisteredNode.namespace == graph_namespace
).to_list()
exospherehost_node_names = [
node.node_name for node in nodes if node.namespace == "exospherehost"
]
exospherehost_database_nodes = await RegisteredNode.find(
In(RegisteredNode.name, exospherehost_node_names),
RegisteredNode.namespace == "exospherehost"
).to_list()

template_nodes = set([(node.node_name, node.namespace) for node in nodes])
database_nodes = set([(node.name, node.namespace) for node in graph_namespace_database_nodes + exospherehost_database_nodes])
async def verify_node_exists(nodes: list[NodeTemplate], database_nodes: list[RegisteredNode], errors: list[str]):
template_nodes_set = set([(node.node_name, node.namespace) for node in nodes])
database_nodes_set = set([(node.name, node.namespace) for node in database_nodes])

nodes_not_found = template_nodes - database_nodes
nodes_not_found = template_nodes_set - database_nodes_set

for node in nodes_not_found:
errors.append(f"Node {node[0]} in namespace {node[1]} does not exist.")
Expand Down Expand Up @@ -68,20 +55,179 @@ async def verify_node_identifiers(nodes: list[NodeTemplate], errors: list[str]):
if next_node not in valid_identifiers:
errors.append(f"Node {node.node_name} in namespace {node.namespace} has a next node {next_node} that does not exist in the graph")

async def verify_secrets(graph_template: GraphTemplate, database_nodes: list[RegisteredNode], errors: list[str]):
required_secrets_set = set()

for node in database_nodes:
if node.secrets is None:
continue
for secret in node.secrets:
required_secrets_set.add(secret)

present_secrets_set = set()
for secret_name in graph_template.secrets.keys():
present_secrets_set.add(secret_name)

missing_secrets_set = required_secrets_set - present_secrets_set

for secret_name in missing_secrets_set:
errors.append(f"Secret {secret_name} is required but not present in the graph template")


async def get_database_nodes(nodes: list[NodeTemplate], graph_namespace: str):
graph_namespace_node_names = [
node.node_name for node in nodes if node.namespace == graph_namespace
]
graph_namespace_database_nodes = await RegisteredNode.find(
In(RegisteredNode.name, graph_namespace_node_names),
RegisteredNode.namespace == graph_namespace
).to_list()
exospherehost_node_names = [
node.node_name for node in nodes if node.namespace == "exospherehost"
]
exospherehost_database_nodes = await RegisteredNode.find(
In(RegisteredNode.name, exospherehost_node_names),
RegisteredNode.namespace == "exospherehost"
).to_list()
return graph_namespace_database_nodes + exospherehost_database_nodes


async def verify_inputs(graph_nodes: list[NodeTemplate], database_nodes: list[RegisteredNode], dependencies_graph: dict[str, set[str]], errors: list[str]):
look_up_table = {}
for node in graph_nodes:
look_up_table[node.identifier] = {"graph_node": node}

for database_node in database_nodes:
if database_node.name == node.node_name and database_node.namespace == node.namespace:
look_up_table[node.identifier]["database_node"] = database_node
break

for node in graph_nodes:
try:
model_class = create_model(look_up_table[node.identifier]["database_node"].inputs_schema)

for field_name, field_info in model_class.model_fields.items():
if field_info.annotation is not str:
errors.append(f"{node.node_name}.Inputs field '{field_name}' must be of type str, got {field_info.annotation}")
continue

if field_name not in look_up_table[node.identifier]["graph_node"].inputs.keys():
errors.append(f"{node.node_name}.Inputs field '{field_name}' not found in graph template")
continue

# get ${{ identifier.outputs.field_name }} objects from the string
splits = look_up_table[node.identifier]["graph_node"].inputs[field_name].split("${{")
for split in splits[1:]:
if "}}" in split:

identifier = None
field = None

syntax_string = split.split("}}")[0].strip()

if syntax_string.startswith("identifier.") and len(syntax_string.split(".")) == 3:
identifier = syntax_string.split(".")[1].strip()
field = syntax_string.split(".")[2].strip()
else:
errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {syntax_string} which is not a valid output field")
continue

if identifier is None or field is None:
errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {syntax_string} which is not a valid output field")
continue

if identifier not in dependencies_graph[node.identifier]:
errors.append(f"{node.node_name}.Inputs field '{field_name}' references node {identifier} which is not a dependency of {node.identifier}")
continue

output_model_class = create_model(look_up_table[identifier]["database_node"].outputs_schema)
if field not in output_model_class.model_fields.keys():
errors.append(f"{node.node_name}.Inputs field '{field_name}' references field {field} of node {identifier} which is not a valid output field")
continue

except Exception as e:
errors.append(f"Error creating input model for node {node.identifier}: {str(e)}")

async def build_dependencies_graph(graph_nodes: list[NodeTemplate]):
dependency_graph = {}
for node in graph_nodes:
dependency_graph[node.identifier] = set()
if node.next_nodes is None:
continue
for next_node in node.next_nodes:
dependency_graph[next_node].add(node.identifier)
dependency_graph[next_node] = dependency_graph[next_node] | dependency_graph[node.identifier]
return dependency_graph

async def verify_topology(graph_nodes: list[NodeTemplate], errors: list[str]):
# verify that the graph is a tree
# verify that the graph is connected
dependencies = {}
identifier_to_node = {}
visited = {}

for node in graph_nodes:
if node.identifier in dependencies.keys():
errors.append(f"Multiple identifier {node.identifier} incorrect topology")
return
dependencies[node.identifier] = set()
identifier_to_node[node.identifier] = node
visited[node.identifier] = False

# verify that there exists only one root node
for node in graph_nodes:
if node.next_nodes is None:
continue
for next_node in node.next_nodes:
dependencies[next_node].add(node.identifier)

# verify that there exists only one root node
root_nodes = [node for node in graph_nodes if len(dependencies[node.identifier]) == 0]
if len(root_nodes) != 1:
errors.append(f"Graph has {len(root_nodes)} root nodes, expected 1")
return

# verify that the graph is a tree
to_visit = deque([root_nodes[0].identifier])

while len(to_visit) > 0:
current_node = to_visit.popleft()
visited[current_node] = True

if identifier_to_node[current_node].next_nodes is None:
continue

for next_node in identifier_to_node[current_node].next_nodes:
if visited[next_node]:
errors.append(f"Graph is not a tree at {current_node} -> {next_node}")
else:
to_visit.append(next_node)

for identifier, visited_value in visited.items():
if not visited_value:
errors.append(f"Graph is not connected at {identifier}")

async def verify_graph(graph_template: GraphTemplate):
try:
errors = []
database_nodes = await get_database_nodes(graph_template.nodes, graph_template.namespace)

await verify_nodes_names(graph_template.nodes, errors)
await verify_nodes_namespace(graph_template.nodes, graph_template.namespace, errors)
await verify_node_exists(graph_template.nodes, graph_template.namespace, errors)
await verify_node_exists(graph_template.nodes, database_nodes, errors)
await verify_node_identifiers(graph_template.nodes, errors)
await verify_secrets(graph_template, database_nodes, errors)
await verify_topology(graph_template.nodes, errors)

if errors:
graph_template.validation_status = GraphTemplateValidationStatus.INVALID
graph_template.validation_errors = errors
await graph_template.save()
return

dependencies_graph = await build_dependencies_graph(graph_template.nodes)
await verify_inputs(graph_template.nodes, database_nodes, dependencies_graph, errors)

graph_template.validation_status = GraphTemplateValidationStatus.VALID
graph_template.validation_errors = None
await graph_template.save()
Expand Down
1 change: 1 addition & 0 deletions state-manager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies = [
"beanie>=2.0.0",
"cryptography>=45.0.5",
"fastapi>=0.116.1",
"json-schema-to-pydantic>=0.4.1",
"python-dotenv>=1.1.1",
"structlog>=25.4.0",
"uvicorn>=0.35.0",
Expand Down
16 changes: 15 additions & 1 deletion state-manager/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.