diff --git a/coordinode/coordinode/__init__.py b/coordinode/coordinode/__init__.py index 93759cc..d36407f 100644 --- a/coordinode/coordinode/__init__.py +++ b/coordinode/coordinode/__init__.py @@ -22,7 +22,11 @@ AsyncCoordinodeClient, CoordinodeClient, EdgeResult, + EdgeTypeInfo, + LabelInfo, NodeResult, + PropertyDefinitionInfo, + TraverseResult, VectorResult, ) @@ -36,4 +40,8 @@ "NodeResult", "EdgeResult", "VectorResult", + "LabelInfo", + "EdgeTypeInfo", + "PropertyDefinitionInfo", + "TraverseResult", ] diff --git a/coordinode/coordinode/client.py b/coordinode/coordinode/client.py index 3d2fac5..b270dd1 100644 --- a/coordinode/coordinode/client.py +++ b/coordinode/coordinode/client.py @@ -85,6 +85,54 @@ def __repr__(self) -> str: return f"VectorResult(distance={self.distance:.4f}, node={self.node})" +class PropertyDefinitionInfo: + """A property definition from the schema (name, type, required, unique).""" + + def __init__(self, proto_def: Any) -> None: + self.name: str = proto_def.name + self.type: int = proto_def.type + self.required: bool = proto_def.required + self.unique: bool = proto_def.unique + + def __repr__(self) -> str: + return f"PropertyDefinitionInfo(name={self.name!r}, type={self.type}, required={self.required}, unique={self.unique})" + + +class LabelInfo: + """A node label returned from the schema registry.""" + + def __init__(self, proto_label: Any) -> None: + self.name: str = proto_label.name + self.version: int = proto_label.version + self.properties: list[PropertyDefinitionInfo] = [PropertyDefinitionInfo(p) for p in proto_label.properties] + + def __repr__(self) -> str: + return f"LabelInfo(name={self.name!r}, version={self.version}, properties={len(self.properties)})" + + +class EdgeTypeInfo: + """An edge type returned from the schema registry.""" + + def __init__(self, proto_edge_type: Any) -> None: + self.name: str = proto_edge_type.name + self.version: int = proto_edge_type.version + self.properties: list[PropertyDefinitionInfo] = [PropertyDefinitionInfo(p) for p in proto_edge_type.properties] + + def __repr__(self) -> str: + return f"EdgeTypeInfo(name={self.name!r}, version={self.version}, properties={len(self.properties)})" + + +class TraverseResult: + """Result of a graph traversal: reached nodes and traversed edges.""" + + def __init__(self, proto_response: Any) -> None: + self.nodes: list[NodeResult] = [NodeResult(n) for n in proto_response.nodes] + self.edges: list[EdgeResult] = [EdgeResult(e) for e in proto_response.edges] + + def __repr__(self) -> str: + return f"TraverseResult(nodes={len(self.nodes)}, edges={len(self.edges)})" + + # ── Async client ───────────────────────────────────────────────────────────── @@ -303,6 +351,72 @@ async def get_schema_text(self) -> str: return "\n".join(lines) + async def get_labels(self) -> list[LabelInfo]: + """Return all node labels defined in the schema.""" + from coordinode._proto.coordinode.v1.graph.schema_pb2 import ListLabelsRequest # type: ignore[import] + + resp = await self._schema_stub.ListLabels(ListLabelsRequest(), timeout=self._timeout) + return [LabelInfo(label) for label in resp.labels] + + async def get_edge_types(self) -> list[EdgeTypeInfo]: + """Return all edge types defined in the schema.""" + from coordinode._proto.coordinode.v1.graph.schema_pb2 import ListEdgeTypesRequest # type: ignore[import] + + resp = await self._schema_stub.ListEdgeTypes(ListEdgeTypesRequest(), timeout=self._timeout) + return [EdgeTypeInfo(et) for et in resp.edge_types] + + async def traverse( + self, + start_node_id: int, + edge_type: str, + direction: str = "outbound", + max_depth: int = 1, + ) -> TraverseResult: + """Traverse the graph from *start_node_id* following *edge_type* edges. + + Args: + start_node_id: ID of the node to start from. + edge_type: Edge type label to follow (e.g. ``"KNOWS"``). + direction: ``"outbound"`` (default), ``"inbound"``, or ``"both"``. + max_depth: Maximum hop count (default 1). + + Returns: + :class:`TraverseResult` with ``nodes`` and ``edges`` lists. + """ + # Validate pure string/int inputs before importing proto stubs — ensures ValueError + # is raised even when proto stubs have not been generated yet. + # Type guards come first so that wrong types raise ValueError, not AttributeError/TypeError. + if not isinstance(direction, str): + raise ValueError(f"direction must be a str, got {type(direction).__name__!r}.") + _valid_directions = {"outbound", "inbound", "both"} + key = direction.lower() + if key not in _valid_directions: + raise ValueError(f"Invalid direction {direction!r}. Must be one of: 'outbound', 'inbound', 'both'.") + # bool is a subclass of int in Python, so `isinstance(True, int)` is True — exclude it. + if not isinstance(max_depth, int) or isinstance(max_depth, bool) or max_depth < 1: + raise ValueError(f"max_depth must be an integer >= 1, got {max_depth!r}.") + + from coordinode._proto.coordinode.v1.graph.graph_pb2 import ( # type: ignore[import] + TraversalDirection, + TraverseRequest, + ) + + _direction_map = { + "outbound": TraversalDirection.TRAVERSAL_DIRECTION_OUTBOUND, + "inbound": TraversalDirection.TRAVERSAL_DIRECTION_INBOUND, + "both": TraversalDirection.TRAVERSAL_DIRECTION_BOTH, + } + direction_value = _direction_map[key] + + req = TraverseRequest( + start_node_id=start_node_id, + edge_type=edge_type, + direction=direction_value, + max_depth=max_depth, + ) + resp = await self._graph_stub.Traverse(req, timeout=self._timeout) + return TraverseResult(resp) + async def health(self) -> bool: from coordinode._proto.coordinode.v1.health.health_pb2 import ( # type: ignore[import] HealthCheckRequest, @@ -422,6 +536,24 @@ def create_edge( def get_schema_text(self) -> str: return self._run(self._async.get_schema_text()) + def get_labels(self) -> list[LabelInfo]: + """Return all node labels defined in the schema.""" + return self._run(self._async.get_labels()) + + def get_edge_types(self) -> list[EdgeTypeInfo]: + """Return all edge types defined in the schema.""" + return self._run(self._async.get_edge_types()) + + def traverse( + self, + start_node_id: int, + edge_type: str, + direction: str = "outbound", + max_depth: int = 1, + ) -> TraverseResult: + """Traverse the graph from *start_node_id* following *edge_type* edges.""" + return self._run(self._async.traverse(start_node_id, edge_type, direction, max_depth)) + def health(self) -> bool: return self._run(self._async.health()) diff --git a/tests/integration/test_sdk.py b/tests/integration/test_sdk.py index c26fbb3..fa14ce2 100644 --- a/tests/integration/test_sdk.py +++ b/tests/integration/test_sdk.py @@ -12,7 +12,7 @@ import pytest -from coordinode import AsyncCoordinodeClient, CoordinodeClient +from coordinode import AsyncCoordinodeClient, CoordinodeClient, EdgeTypeInfo, LabelInfo, TraverseResult ADDR = os.environ.get("COORDINODE_ADDR", "localhost:7080") @@ -208,6 +208,132 @@ def test_get_schema_text(client): client.cypher("MATCH (n:SchemaTestLabel {tag: $tag}) DETACH DELETE n", params={"tag": tag}) +# ── get_labels / get_edge_types / traverse ──────────────────────────────────── + + +def test_get_labels_returns_list(client): + """get_labels() returns a non-empty list of LabelInfo after data is present.""" + tag = uid() + label_name = f"GetLabelsTest{uid()}" + client.cypher(f"CREATE (n:{label_name} {{tag: $tag}})", params={"tag": tag}) + try: + labels = client.get_labels() + assert isinstance(labels, list) + assert len(labels) > 0 + assert all(isinstance(lbl, LabelInfo) for lbl in labels) + names = [lbl.name for lbl in labels] + assert label_name in names, f"{label_name} not in {names}" + finally: + client.cypher(f"MATCH (n:{label_name} {{tag: $tag}}) DETACH DELETE n", params={"tag": tag}) + + +def test_get_labels_has_property_definitions(client): + """LabelInfo.properties is a list (may be empty for schema-free labels).""" + tag = uid() + label_name = f"PropLabel{uid()}" + client.cypher(f"CREATE (n:{label_name} {{tag: $tag}})", params={"tag": tag}) + try: + labels = client.get_labels() + found = next((lbl for lbl in labels if lbl.name == label_name), None) + assert found is not None, f"{label_name} not returned by get_labels()" + # Intentionally only check the type — CoordiNode is schema-free and may return + # an empty properties list even when the node was created with properties. + assert isinstance(found.properties, list) + finally: + client.cypher(f"MATCH (n:{label_name} {{tag: $tag}}) DETACH DELETE n", params={"tag": tag}) + + +def test_get_edge_types_returns_list(client): + """get_edge_types() returns a non-empty list of EdgeTypeInfo after data is present.""" + tag = uid() + edge_type = f"GET_EDGE_TYPE_TEST_{uid()}".upper() + client.cypher( + f"CREATE (a:EdgeTypeTestNode {{tag: $tag}})-[:{edge_type}]->(b:EdgeTypeTestNode {{tag: $tag}})", + params={"tag": tag}, + ) + try: + edge_types = client.get_edge_types() + assert isinstance(edge_types, list) + assert len(edge_types) > 0 + assert all(isinstance(et, EdgeTypeInfo) for et in edge_types) + type_names = [et.name for et in edge_types] + assert edge_type in type_names, f"{edge_type} not in {type_names}" + finally: + client.cypher("MATCH (n:EdgeTypeTestNode {tag: $tag}) DETACH DELETE n", params={"tag": tag}) + + +def test_traverse_returns_neighbours(client): + """traverse() returns adjacent nodes reachable via the given edge type.""" + tag = uid() + client.cypher( + "CREATE (a:TraverseRPC {tag: $tag, role: 'hub'})-[:TRAVERSE_TEST]->(b:TraverseRPC {tag: $tag, role: 'leaf1'})", + params={"tag": tag}, + ) + try: + rows = client.cypher( + "MATCH (a:TraverseRPC {tag: $tag, role: 'hub'}) RETURN a AS node_id", + params={"tag": tag}, + ) + assert len(rows) >= 1, "hub node not found" + start_id = rows[0]["node_id"] + + # Fetch the leaf1 node ID so we can assert it specifically appears in the result. + leaf_rows = client.cypher( + "MATCH (b:TraverseRPC {tag: $tag, role: 'leaf1'}) RETURN b AS node_id", + params={"tag": tag}, + ) + assert len(leaf_rows) >= 1, "leaf1 node not found" + leaf1_id = leaf_rows[0]["node_id"] + + result = client.traverse(start_id, "TRAVERSE_TEST", direction="outbound", max_depth=1) + assert isinstance(result, TraverseResult) + assert len(result.nodes) >= 1, "traverse() returned no neighbour nodes" + node_ids = {n.id for n in result.nodes} + assert leaf1_id in node_ids, f"traverse() did not return the expected leaf1 node ({leaf1_id}); got: {node_ids}" + finally: + client.cypher("MATCH (n:TraverseRPC {tag: $tag}) DETACH DELETE n", params={"tag": tag}) + + +@pytest.mark.xfail( + strict=False, + raises=AssertionError, + # strict=False: XPASS is good news (server gained inbound support), not an error. + # strict=True would break CI exactly when the server improves, which is undesirable. + # The XPASS report in pytest output is the signal to remove this marker. + # raises=AssertionError: narrows xfail to the known failure mode (empty result set → + # assertion fails). Unexpected errors (gRPC RpcError, wrong enum, etc.) are NOT covered + # and will still propagate as CI failures. + reason="CoordiNode Traverse RPC does not yet support inbound direction — server returns empty result set", +) +def test_traverse_inbound_direction(client): + """traverse() with direction='inbound' reaches nodes that point TO start_id.""" + tag = uid() + client.cypher( + "CREATE (src:TraverseIn {tag: $tag})-[:INBOUND_TEST]->(dst:TraverseIn {tag: $tag})", + params={"tag": tag}, + ) + try: + # Capture both src and dst so that when the server gains inbound support + # (XPASS), the assertion verifies the *correct* node was returned, not just any node. + rows = client.cypher( + "MATCH (src:TraverseIn {tag: $tag})-[:INBOUND_TEST]->(dst:TraverseIn {tag: $tag}) " + "RETURN src AS src_id, dst AS dst_id", + params={"tag": tag}, + ) + assert len(rows) >= 1 + src_id = rows[0]["src_id"] + dst_id = rows[0]["dst_id"] + result = client.traverse(dst_id, "INBOUND_TEST", direction="inbound", max_depth=1) + assert isinstance(result, TraverseResult) + assert len(result.nodes) >= 1, "inbound traverse returned no nodes" + node_ids = {n.id for n in result.nodes} + assert src_id in node_ids, ( + f"inbound traverse did not return the expected source node ({src_id}); got: {node_ids}" + ) + finally: + client.cypher("MATCH (n:TraverseIn {tag: $tag}) DETACH DELETE n", params={"tag": tag}) + + # ── Hybrid search ───────────────────────────────────────────────────────────── diff --git a/tests/unit/test_schema_crud.py b/tests/unit/test_schema_crud.py new file mode 100644 index 0000000..3e6495c --- /dev/null +++ b/tests/unit/test_schema_crud.py @@ -0,0 +1,257 @@ +"""Unit tests for R-SDK3 additions: LabelInfo, EdgeTypeInfo, TraverseResult. + +All tests are mock-based — no proto stubs or running server required. +Pattern mirrors test_types.py: fake proto objects with the same attribute +interface that real generated messages provide. +""" + +import asyncio + +import pytest + +from coordinode.client import ( + AsyncCoordinodeClient, + EdgeResult, + EdgeTypeInfo, + LabelInfo, + NodeResult, + PropertyDefinitionInfo, + TraverseResult, +) + +# ── Fake proto stubs ───────────────────────────────────────────────────────── + + +class _FakePropDef: + """Matches proto PropertyDefinition shape.""" + + def __init__(self, name: str, type_: int, required: bool = False, unique: bool = False) -> None: + self.name = name + self.type = type_ + self.required = required + self.unique = unique + + +class _FakeLabel: + """Matches proto Label shape.""" + + def __init__(self, name: str, version: int = 1, properties=None) -> None: + self.name = name + self.version = version + self.properties = properties or [] + + +class _FakeEdgeType: + """Matches proto EdgeType shape.""" + + def __init__(self, name: str, version: int = 1, properties=None) -> None: + self.name = name + self.version = version + self.properties = properties or [] + + +class _FakeNode: + """Matches proto Node shape.""" + + def __init__(self, node_id: int, labels=None, properties=None) -> None: + self.node_id = node_id + self.labels = labels or [] + self.properties = properties or {} + + +class _FakeEdge: + """Matches proto Edge shape.""" + + def __init__(self, edge_id: int, edge_type: str, source: int, target: int, properties=None) -> None: + self.edge_id = edge_id + self.edge_type = edge_type + self.source_node_id = source + self.target_node_id = target + self.properties = properties or {} + + +class _FakeTraverseResponse: + """Matches proto TraverseResponse shape.""" + + def __init__(self, nodes=None, edges=None) -> None: + self.nodes = nodes or [] + self.edges = edges or [] + + +# ── PropertyDefinitionInfo ─────────────────────────────────────────────────── + + +class TestPropertyDefinitionInfo: + def test_fields_are_mapped(self): + # type=3 = PROPERTY_TYPE_STRING (int value from proto enum) + p = PropertyDefinitionInfo(_FakePropDef("name", 3, required=True, unique=False)) + assert p.name == "name" + assert p.type == 3 + assert p.required is True + assert p.unique is False + + def test_repr_contains_name(self): + p = PropertyDefinitionInfo(_FakePropDef("age", 1)) + assert "age" in repr(p) + + def test_optional_flags_default_false(self): + p = PropertyDefinitionInfo(_FakePropDef("x", 2)) + assert p.required is False + assert p.unique is False + + +# ── LabelInfo ──────────────────────────────────────────────────────────────── + + +class TestLabelInfo: + def test_empty_properties(self): + label = LabelInfo(_FakeLabel("Person", version=2)) + assert label.name == "Person" + assert label.version == 2 + assert label.properties == [] + + def test_properties_are_wrapped(self): + props = [_FakePropDef("name", 3), _FakePropDef("age", 1)] + label = LabelInfo(_FakeLabel("User", properties=props)) + assert len(label.properties) == 2 + assert all(isinstance(p, PropertyDefinitionInfo) for p in label.properties) + assert label.properties[0].name == "name" + assert label.properties[1].name == "age" + + def test_repr_contains_name(self): + label = LabelInfo(_FakeLabel("Movie")) + assert "Movie" in repr(label) + + def test_version_zero(self): + # Schema registry may return version=0 for newly created labels. + label = LabelInfo(_FakeLabel("Draft", version=0)) + assert label.version == 0 + + +# ── EdgeTypeInfo ───────────────────────────────────────────────────────────── + + +class TestEdgeTypeInfo: + PROPERTY_TYPE_TIMESTAMP = 6 + + def test_basic_fields(self): + et = EdgeTypeInfo(_FakeEdgeType("KNOWS", version=1)) + assert et.name == "KNOWS" + assert et.version == 1 + assert et.properties == [] + + def test_properties_are_wrapped(self): + props = [_FakePropDef("since", self.PROPERTY_TYPE_TIMESTAMP)] + et = EdgeTypeInfo(_FakeEdgeType("FOLLOWS", properties=props)) + assert len(et.properties) == 1 + assert et.properties[0].name == "since" + + def test_repr_contains_name(self): + et = EdgeTypeInfo(_FakeEdgeType("RATED")) + assert "RATED" in repr(et) + + +# ── TraverseResult ─────────────────────────────────────────────────────────── + + +class TestTraverseResult: + def test_empty_response(self): + result = TraverseResult(_FakeTraverseResponse()) + assert result.nodes == [] + assert result.edges == [] + + def test_nodes_are_wrapped_as_node_results(self): + nodes = [_FakeNode(1, ["Person"]), _FakeNode(2, ["Movie"])] + result = TraverseResult(_FakeTraverseResponse(nodes=nodes)) + assert len(result.nodes) == 2 + assert all(isinstance(n, NodeResult) for n in result.nodes) + assert result.nodes[0].id == 1 + assert result.nodes[1].id == 2 + + def test_edges_are_wrapped_as_edge_results(self): + edges = [_FakeEdge(10, "KNOWS", source=1, target=2)] + result = TraverseResult(_FakeTraverseResponse(edges=edges)) + assert len(result.edges) == 1 + assert isinstance(result.edges[0], EdgeResult) + assert result.edges[0].id == 10 + assert result.edges[0].source_id == 1 + assert result.edges[0].target_id == 2 + assert result.edges[0].type == "KNOWS" + + def test_mixed_nodes_and_edges(self): + nodes = [_FakeNode(1, ["A"]), _FakeNode(2, ["B"]), _FakeNode(3, ["C"])] + edges = [ + _FakeEdge(10, "REL", 1, 2), + _FakeEdge(11, "REL", 2, 3), + ] + result = TraverseResult(_FakeTraverseResponse(nodes=nodes, edges=edges)) + assert len(result.nodes) == 3 + assert len(result.edges) == 2 + + def test_repr_shows_counts(self): + nodes = [_FakeNode(1, [])] + result = TraverseResult(_FakeTraverseResponse(nodes=nodes)) + r = repr(result) + assert "nodes=1" in r + assert "edges=0" in r + + +# ── traverse() input validation ────────────────────────────────────────────── + + +class TestTraverseValidation: + """Unit tests for AsyncCoordinodeClient.traverse() input validation. + + Validation (direction and max_depth checks) runs before any RPC call, so no + running server is required — only the client object needs to be instantiated. + """ + + def test_invalid_direction_raises(self): + """traverse() raises ValueError for an unrecognised direction string.""" + + async def _inner() -> None: + client = AsyncCoordinodeClient("localhost:0") + with pytest.raises(ValueError, match="Invalid direction"): + await client.traverse(1, "KNOWS", direction="sideways") + + asyncio.run(_inner()) + + def test_max_depth_below_one_raises(self): + """traverse() raises ValueError when max_depth is less than 1.""" + + async def _inner() -> None: + client = AsyncCoordinodeClient("localhost:0") + with pytest.raises(ValueError, match="max_depth must be"): + await client.traverse(1, "KNOWS", max_depth=0) + + asyncio.run(_inner()) + + def test_direction_none_raises_value_error(self): + """traverse() raises ValueError (not AttributeError) when direction is None.""" + + async def _inner() -> None: + client = AsyncCoordinodeClient("localhost:0") + with pytest.raises(ValueError, match="direction must be a str"): + await client.traverse(1, "KNOWS", direction=None) # type: ignore[arg-type] + + asyncio.run(_inner()) + + def test_max_depth_string_raises_value_error(self): + """traverse() raises ValueError (not TypeError) when max_depth is a string.""" + + async def _inner() -> None: + client = AsyncCoordinodeClient("localhost:0") + with pytest.raises(ValueError, match="max_depth must be an integer"): + await client.traverse(1, "KNOWS", max_depth="2") # type: ignore[arg-type] + + asyncio.run(_inner()) + + def test_max_depth_bool_raises_value_error(self): + """traverse() raises ValueError for bool max_depth (bool is a subclass of int in Python).""" + + async def _inner() -> None: + client = AsyncCoordinodeClient("localhost:0") + with pytest.raises(ValueError, match="max_depth must be an integer"): + await client.traverse(1, "KNOWS", max_depth=True) # type: ignore[arg-type] + + asyncio.run(_inner())