diff --git a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py index 3f42fcaa..e71e4455 100644 --- a/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py +++ b/graphrag_sdk/src/graphrag_sdk/retrieval/strategies/chunk_retrieval.py @@ -51,17 +51,25 @@ def _add(cid: str, text: str, source: str) -> None: except Exception as exc: logger.debug("Chunk vector search failed: %s", exc) - # Path C: MENTIONED_IN — 3 chunks per entity (batched UNWIND) + # Path C: MENTIONED_IN — top-3 chunks per entity, ranked by cosine + # distance to the query embedding. Hub entities (e.g. the main + # product name) can be MENTIONED_IN hundreds of chunks; the previous + # COLLECT(c)[..3] picked an arbitrary 3, almost never including the + # chunks most relevant to the current query. Ranking by cosine here + # surfaces the chunks closest to the query intent. eids_mention = [eid for eid, _ in entity_list[:15]] if eids_mention: try: result = await graph_store.query_raw( "UNWIND $eids AS eid " "MATCH (e:__Entity__ {id: eid})-[:MENTIONED_IN]->(c:Chunk) " + "WHERE c.embedding IS NOT NULL " + "WITH eid, c, vec.cosineDistance(c.embedding, vecf32($qv)) AS dist " + "ORDER BY eid, dist ASC " "WITH eid, COLLECT(c)[..3] AS chunks " "UNWIND chunks AS c " "RETURN eid, c.id AS id, c.text AS text", - {"eids": eids_mention}, + {"eids": eids_mention, "qv": query_vector}, ) for row in result.result_set: cid = row[1] diff --git a/graphrag_sdk/tests/test_multi_path_retrieval.py b/graphrag_sdk/tests/test_multi_path_retrieval.py index 53092946..b535730e 100644 --- a/graphrag_sdk/tests/test_multi_path_retrieval.py +++ b/graphrag_sdk/tests/test_multi_path_retrieval.py @@ -233,6 +233,55 @@ async def capture_query(cypher, params=None): ] assert len(twohop_chunk_queries) >= 1 + async def test_mentioned_in_ranks_chunks_by_cosine(self, mp_graph_store, mp_vector_store, mp_embedder, mp_llm): + """The MENTIONED_IN path must rank chunks per entity by cosine + distance to the query embedding — not pick an arbitrary 3. + + Hub entities can be MENTIONED_IN hundreds of chunks; arbitrary + selection almost never surfaces the chunks relevant to the + current query (regression: see PR referencing issue #258). + """ + # Seed entity discovery so Path C runs (mirrors the setup in + # ``test_mentioned_in_and_2hop_chunk_paths``). + mp_vector_store.search_relationships = AsyncMock(return_value=[ + {"src_name": "Alice", "type": "WORKS_AT", "tgt_name": "Acme", "fact": "engineer", "score": 0.9}, + ]) + + captured: list[tuple[str, dict]] = [] + + async def capture_query(cypher, params=None): + captured.append((cypher, params or {})) + result = MagicMock() + result.result_set = [] + return result + + mp_graph_store.query_raw = AsyncMock(side_effect=capture_query) + + s = MultiPathRetrieval( + graph_store=mp_graph_store, + vector_store=mp_vector_store, + embedder=mp_embedder, + llm=mp_llm, + ) + await s.search("Who is Alice?") + + # Find the direct MENTIONED_IN query (entity -> chunk, not 2-hop) + direct_mention = [ + (q, p) for q, p in captured + if "MENTIONED_IN" in q and "Chunk" in q and "neighbor" not in q.lower() + ] + assert direct_mention, "expected at least one direct MENTIONED_IN chunk query" + + cypher, params = direct_mention[0] + # The fix: rank by cosine distance to the query vector before + # COLLECT, so per-entity chunk selection is query-relevant. + assert "vec.cosineDistance" in cypher, ( + "MENTIONED_IN chunk query must rank by cosine distance to the " + "query vector (regression of issue #258)" + ) + assert "ORDER BY" in cypher, "expected ORDER BY to make COLLECT[..3] meaningful" + assert "qv" in params, "query vector must be passed as a parameter" + async def test_format_produces_sections(self, mp_graph_store, mp_vector_store, mp_embedder, mp_llm): """Output should include structured sections when data is available.""" mp_vector_store.search_relationships = AsyncMock(return_value=[