From befab25002c0f1f4582ef91604f88fbfece6d7db Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Fri, 23 Jan 2026 11:48:20 -0300 Subject: [PATCH 1/7] write dataframe --- .../tables/parquet_table_provider.py | 76 +++++++++++++++++++ .../graphrag_storage/tables/table_provider.py | 52 +++++++++++++ packages/graphrag/graphrag/index/run/utils.py | 6 +- .../graphrag/graphrag/index/typing/context.py | 4 +- .../index/workflows/create_base_text_units.py | 5 +- packages/graphrag/graphrag/utils/storage.py | 15 +--- .../storage/test_parquet_table_provider.py | 71 +++++++++++++++++ 7 files changed, 212 insertions(+), 17 deletions(-) create mode 100644 packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py create mode 100644 packages/graphrag-storage/graphrag_storage/tables/table_provider.py create mode 100644 tests/unit/storage/test_parquet_table_provider.py diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py new file mode 100644 index 0000000000..b71a6be467 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parquet-based table provider implementation.""" + +from io import BytesIO +import logging +from typing import Any +import pandas as pd +from graphrag_storage.storage import Storage +from graphrag_storage.tables.table_provider import TableProvider + +logger = logging.getLogger(__name__) + +class ParquetTableProvider(TableProvider): + """Table provider that stores tables as Parquet files using an underlying Storage instance. + + This provider converts between pandas DataFrames and Parquet format, + storing the data through a Storage backend (file, blob, cosmos, etc.). + """ + + def __init__(self, storage: Storage, **kwargs) -> None: + """Initialize the Parquet table provider with an underlying storage instance. + + Args + ---- + storage: Storage + The storage instance to use for reading and writing Parquet files. + **kwargs: Any + Additional keyword arguments (currently unused). + """ + self._storage = storage + + async def read_dataframe(self, table_name: str) -> pd.DataFrame: + """Read a table from storage as a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to read. The file will be accessed as '{table_name}.parquet'. + + Returns + ------- + pd.DataFrame: + The table data loaded from the Parquet file. + + Raises + ------ + ValueError: + If the table file does not exist in storage. + Exception: + If there is an error reading or parsing the Parquet file. + """ + filename = f"{table_name}.parquet" + if not await self._storage.has(filename): + msg = f"Could not find {filename} in storage!" + raise ValueError(msg) + try: + logger.info("reading table from storage: %s", filename) + return pd.read_parquet(BytesIO(await self._storage.get(filename, as_bytes=True))) + except Exception: + logger.exception("error loading table from storage: %s", filename) + raise + + + async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: + """Write a pandas DataFrame to storage as a Parquet file. + + Args + ---- + table_name: str + The name of the table to write. The file will be saved as '{table_name}.parquet'. + df: pd.DataFrame + The DataFrame to write to storage. 
+ """ + await self._storage.set(f"{table_name}.parquet", df.to_parquet()) \ No newline at end of file diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py new file mode 100644 index 0000000000..056c07ed31 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Abstract base class for table providers.""" + +from abc import ABC, abstractmethod +from typing import Any + +import pandas as pd + +class TableProvider(ABC): + """Provide a table-based storage interface with support for DataFrames and row dictionaries.""" + + @abstractmethod + def __init__(self, **kwargs: Any) -> None: + """Create a table provider instance. + + Args + ---- + **kwargs: Any + Keyword arguments for initialization, may include underlying Storage instance. + """ + + @abstractmethod + async def read_dataframe(self, table_name: str) -> pd.DataFrame: + """Read entire table as a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to read. + + Returns + ------- + pd.DataFrame: + The table data as a DataFrame. + """ + + @abstractmethod + async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: + """Write entire table from a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to write. + df: pd.DataFrame + The DataFrame to write as a table. + """ + + + diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index be6914a6d6..93fcdfbded 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -5,7 +5,7 @@ from graphrag_cache import Cache from graphrag_cache.memory_cache import MemoryCache -from graphrag_storage import Storage, create_storage +from graphrag_storage import ParquetTableProvider, Storage, create_storage from graphrag_storage.memory_storage import MemoryStorage from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks @@ -27,9 +27,11 @@ def create_run_context( state: PipelineState | None = None, ) -> PipelineRunContext: """Create the run context for the pipeline.""" + output_storage = output_storage or MemoryStorage() return PipelineRunContext( input_storage=input_storage or MemoryStorage(), - output_storage=output_storage or MemoryStorage(), + output_storage=output_storage, + output_table_provider=ParquetTableProvider(storage=output_storage), previous_storage=previous_storage or MemoryStorage(), cache=cache or MemoryCache(), callbacks=callbacks or NoopWorkflowCallbacks(), diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py index 95e7f898f9..1f25ec8f47 100644 --- a/packages/graphrag/graphrag/index/typing/context.py +++ b/packages/graphrag/graphrag/index/typing/context.py @@ -10,7 +10,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag_storage import Storage +from graphrag_storage import ParquetTableProvider, Storage @dataclass @@ -22,6 +22,8 @@ class PipelineRunContext: "Storage for input documents." output_storage: Storage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." 
+ output_table_provider: ParquetTableProvider + "Table provider for reading and writing output tables." previous_storage: Storage "Storage for previous pipeline run when running in update mode." cache: Cache diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index ec6abc2578..2d53fd8e6f 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -20,7 +20,6 @@ from graphrag.index.utils.hashing import gen_sha512_hash from graphrag.logger.progress import progress_ticker from graphrag.tokenizer.get_tokenizer import get_tokenizer -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform base text_units.""" logger.info("Workflow started: create_base_text_units") - documents = await load_table_from_storage("documents", context.output_storage) + documents = await context.output_table_provider.read_dataframe("documents") tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model) chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode) @@ -43,7 +42,7 @@ async def run_workflow( prepend_metadata=config.chunking.prepend_metadata, ) - await write_table_to_storage(output, "text_units", context.output_storage) + await context.output_table_provider.write_dataframe("text_units", output) logger.info("Workflow completed: create_base_text_units") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/utils/storage.py b/packages/graphrag/graphrag/utils/storage.py index 852d066091..e4d1566427 100644 --- a/packages/graphrag/graphrag/utils/storage.py +++ b/packages/graphrag/graphrag/utils/storage.py @@ -14,23 +14,16 @@ async def load_table_from_storage(name: str, storage: Storage) -> pd.DataFrame: """Load a parquet from the storage instance.""" - filename = f"{name}.parquet" - if not await storage.has(filename): - msg = f"Could not find {filename} in storage!" - raise ValueError(msg) - try: - logger.info("reading table from storage: %s", filename) - return pd.read_parquet(BytesIO(await storage.get(filename, as_bytes=True))) - except Exception: - logger.exception("error loading table from storage: %s", filename) - raise + table_provider = ParquetTableProvider(storage) + return await table_provider.read_dataframe(name) async def write_table_to_storage( table: pd.DataFrame, name: str, storage: Storage ) -> None: """Write a table to storage.""" - await storage.set(f"{name}.parquet", table.to_parquet()) + table_provider = ParquetTableProvider(storage) + await table_provider.write_dataframe(name, table) async def delete_table_from_storage(name: str, storage: Storage) -> None: diff --git a/tests/unit/storage/test_parquet_table_provider.py b/tests/unit/storage/test_parquet_table_provider.py new file mode 100644 index 0000000000..27b66228b9 --- /dev/null +++ b/tests/unit/storage/test_parquet_table_provider.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +import unittest +from io import BytesIO + +import pandas as pd +import pytest +from graphrag_storage import ParquetTableProvider, StorageConfig, StorageType, create_storage + + +class TestParquetTableProvider(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.storage = create_storage( + StorageConfig( + type=StorageType.Memory, + ) + ) + self.table_provider = ParquetTableProvider(storage=self.storage) + + async def asyncTearDown(self): + await self.storage.clear() + + async def test_write_and_read(self): + df = pd.DataFrame({ + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [30, 25, 35] + }) + + await self.table_provider.write_dataframe("users", df) + result = await self.table_provider.read_dataframe("users") + + pd.testing.assert_frame_equal(result, df) + + async def test_read_nonexistent_table_raises_error(self): + with pytest.raises(ValueError, match="Could not find nonexistent.parquet in storage!"): + await self.table_provider.read_dataframe("nonexistent") + + async def test_empty_dataframe(self): + df = pd.DataFrame() + + await self.table_provider.write_dataframe("empty", df) + result = await self.table_provider.read_dataframe("empty") + + pd.testing.assert_frame_equal(result, df) + + async def test_dataframe_with_multiple_types(self): + df = pd.DataFrame({ + "int_col": [1, 2, 3], + "float_col": [1.1, 2.2, 3.3], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True] + }) + + await self.table_provider.write_dataframe("mixed", df) + result = await self.table_provider.read_dataframe("mixed") + + pd.testing.assert_frame_equal(result, df) + + async def test_storage_persistence(self): + df = pd.DataFrame({"x": [1, 2, 3]}) + + await self.table_provider.write_dataframe("test", df) + + assert await self.storage.has("test.parquet") + + parquet_bytes = await self.storage.get("test.parquet", as_bytes=True) + loaded_df = pd.read_parquet(BytesIO(parquet_bytes)) + + pd.testing.assert_frame_equal(loaded_df, df) From 5a0e5ea9fd3dafafd7abfeb7d167cd271a1e16f2 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Fri, 23 Jan 2026 13:03:13 -0300 Subject: [PATCH 2/7] changed some workflows --- .../index/workflows/create_communities.py | 8 +++--- .../index/workflows/create_final_documents.py | 7 +++-- .../workflows/create_final_text_units.py | 27 +++++++------------ .../index/workflows/finalize_graph.py | 13 +++------ .../graphrag/index/workflows/prune_graph.py | 13 +++------ 5 files changed, 24 insertions(+), 44 deletions(-) diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index 4394593e99..f90889bac8 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -28,10 +28,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final communities.""" logger.info("Workflow started: create_communities") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc @@ -45,7 +43,7 @@ async def run_workflow( seed=seed, ) - await write_table_to_storage(output, 
"communities", context.output_storage) + await context.output_table_provider.write_dataframe("communities", output) logger.info("Workflow completed: create_communities") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_final_documents.py b/packages/graphrag/graphrag/index/workflows/create_final_documents.py index 554fbc4254..c799d1bb44 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_documents.py @@ -11,7 +11,6 @@ from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -22,12 +21,12 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final documents.""" logger.info("Workflow started: create_final_documents") - documents = await load_table_from_storage("documents", context.output_storage) - text_units = await load_table_from_storage("text_units", context.output_storage) + documents = await context.output_table_provider.read_dataframe("documents") + text_units = await context.output_table_provider.read_dataframe("text_units") output = create_final_documents(documents, text_units) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) logger.info("Workflow completed: create_final_documents") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py index c16e08bb7c..91b5885751 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py @@ -11,11 +11,6 @@ from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) logger = logging.getLogger(__name__) @@ -26,18 +21,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform the text units.""" logger.info("Workflow started: create_final_text_units") - text_units = await load_table_from_storage("text_units", context.output_storage) - final_entities = await load_table_from_storage("entities", context.output_storage) - final_relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + text_units = await context.output_table_provider.read_dataframe("text_units") + final_entities = await context.output_table_provider.read_dataframe("entities") + final_relationships = await context.output_table_provider.read_dataframe("relationships") + final_covariates = None - if config.extract_claims.enabled and await storage_has_table( - "covariates", context.output_storage - ): - final_covariates = await load_table_from_storage( - "covariates", context.output_storage - ) + if config.extract_claims.enabled: + try: + final_covariates = await context.output_table_provider.read_dataframe("covariates") + except Exception: + pass output = create_final_text_units( text_units, @@ -46,7 +39,7 @@ 
async def run_workflow( final_covariates, ) - await write_table_to_storage(output, "text_units", context.output_storage) + await context.output_table_provider.write_dataframe("text_units", output) logger.info("Workflow completed: create_final_text_units") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/finalize_graph.py b/packages/graphrag/graphrag/index/workflows/finalize_graph.py index 49529aea3a..6edcb7ba93 100644 --- a/packages/graphrag/graphrag/index/workflows/finalize_graph.py +++ b/packages/graphrag/graphrag/index/workflows/finalize_graph.py @@ -14,7 +14,6 @@ from graphrag.index.operations.snapshot_graphml import snapshot_graphml from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -25,20 +24,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: finalize_graph") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") final_entities, final_relationships = finalize_graph( entities, relationships, ) - await write_table_to_storage(final_entities, "entities", context.output_storage) - await write_table_to_storage( - final_relationships, "relationships", context.output_storage - ) + await context.output_table_provider.write_dataframe("entities", final_entities) + await context.output_table_provider.write_dataframe("relationships", final_relationships) if config.snapshots.graphml: graph = create_graph(final_relationships, edge_attr=["weight"]) diff --git a/packages/graphrag/graphrag/index/workflows/prune_graph.py b/packages/graphrag/graphrag/index/workflows/prune_graph.py index 5653eef49b..7a0fc73702 100644 --- a/packages/graphrag/graphrag/index/workflows/prune_graph.py +++ b/packages/graphrag/graphrag/index/workflows/prune_graph.py @@ -14,7 +14,6 @@ from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -25,10 +24,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: prune_graph") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") pruned_entities, pruned_relationships = prune_graph( entities, @@ -36,10 +33,8 @@ async def run_workflow( pruning_config=config.prune_graph, ) - await write_table_to_storage(pruned_entities, "entities", context.output_storage) - await write_table_to_storage( - pruned_relationships, "relationships", context.output_storage - ) + await context.output_table_provider.write_dataframe("entities", pruned_entities) + 
await context.output_table_provider.write_dataframe("relationships", pruned_relationships) logger.info("Workflow completed: prune_graph") return WorkflowFunctionOutput( From 277977e8764bd89388fe1cffe9dff16af906723b Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Mon, 26 Jan 2026 11:51:06 -0300 Subject: [PATCH 3/7] 1a --- .../tables/parquet_table_provider.py | 5 ++- .../graphrag_storage/tables/table_provider.py | 15 +++++++- .../workflows/create_community_reports.py | 23 ++++++------- .../create_community_reports_text.py | 10 +++--- .../index/workflows/extract_covariates.py | 5 ++- .../index/workflows/extract_graph_nlp.py | 7 ++-- .../workflows/generate_text_embeddings.py | 17 +++------- .../index/workflows/load_input_documents.py | 3 +- .../index/workflows/load_update_documents.py | 3 +- .../index/workflows/update_communities.py | 21 +++++++----- .../workflows/update_community_reports.py | 27 ++++++++------- .../index/workflows/update_covariates.py | 33 +++++++++--------- .../update_entities_relationships.py | 34 +++++++++---------- .../index/workflows/update_text_embeddings.py | 10 +++--- .../index/workflows/update_text_units.py | 12 ++++--- 15 files changed, 116 insertions(+), 109 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py index b71a6be467..9476202d48 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -73,4 +73,7 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: df: pd.DataFrame The DataFrame to write to storage. """ - await self._storage.set(f"{table_name}.parquet", df.to_parquet()) \ No newline at end of file + await self._storage.set(f"{table_name}.parquet", df.to_parquet()) + + async def has_dataframe(self, table_name): + return await self._storage.has(f"{table_name}.parquet") \ No newline at end of file diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py index 056c07ed31..06f5f760eb 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -47,6 +47,19 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: df: pd.DataFrame The DataFrame to write as a table. """ - + @abstractmethod + async def has_dataframe(self, table_name: str) -> bool: + """Check if a table exists in the provider. + + Args + ---- + table_name: str + The name of the table to check. + + Returns + ------- + bool: + True if the table exists, False otherwise. 
+ """ diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports.py b/packages/graphrag/graphrag/index/workflows/create_community_reports.py index abfdeca45a..d96fb76b4b 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports.py @@ -30,11 +30,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -48,14 +43,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports") - edges = await load_table_from_storage("relationships", context.output_storage) - entities = await load_table_from_storage("entities", context.output_storage) - communities = await load_table_from_storage("communities", context.output_storage) + edges = await context.output_table_provider.read_dataframe("relationships") + entities = await context.output_table_provider.read_dataframe("entities") + communities = await context.output_table_provider.read_dataframe("communities") + claims = None - if config.extract_claims.enabled and await storage_has_table( - "covariates", context.output_storage - ): - claims = await load_table_from_storage("covariates", context.output_storage) + if config.extract_claims.enabled: + try: + claims = await context.output_table_provider.read_dataframe("covariates") + except Exception: + pass model_config = config.get_completion_model_config( config.community_reports.completion_model_id @@ -85,7 +82,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "community_reports", context.output_storage) + await context.output_table_provider.write_dataframe("community_reports", output) logger.info("Workflow completed: create_community_reports") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py index 8a6be96e68..52cb4b0f8e 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py @@ -29,7 +29,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -43,10 +42,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports_text") - entities = await load_table_from_storage("entities", context.output_storage) - communities = await load_table_from_storage("communities", context.output_storage) - - text_units = await load_table_from_storage("text_units", context.output_storage) + entities = await context.output_table_provider.read_dataframe("entities") + communities = await context.output_table_provider.read_dataframe("communities") + text_units = await context.output_table_provider.read_dataframe("text_units") model_config = config.get_completion_model_config( 
config.community_reports.completion_model_id @@ -75,7 +73,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "community_reports", context.output_storage) + await context.output_table_provider.write_dataframe("community_reports", output) logger.info("Workflow completed: create_community_reports_text") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/extract_covariates.py b/packages/graphrag/graphrag/index/workflows/extract_covariates.py index 18b470a8b1..f27d8590d1 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/extract_covariates.py @@ -21,7 +21,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -37,7 +36,7 @@ async def run_workflow( logger.info("Workflow started: extract_covariates") output = None if config.extract_claims.enabled: - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") model_config = config.get_completion_model_config( config.extract_claims.completion_model_id @@ -64,7 +63,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "covariates", context.output_storage) + await context.output_table_provider.write_dataframe("covariates", output) logger.info("Workflow completed: extract_covariates") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py index 38810e5de4..c0cd069ac6 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py @@ -19,7 +19,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -30,7 +29,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph_nlp") - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") text_analyzer_config = config.extract_graph_nlp.text_analyzer text_analyzer = create_noun_phrase_extractor(text_analyzer_config) @@ -44,8 +43,8 @@ async def run_workflow( async_type=config.extract_graph_nlp.async_mode, ) - await write_table_to_storage(entities, "entities", context.output_storage) - await write_table_to_storage(relationships, "relationships", context.output_storage) + await context.output_table_provider.write_dataframe("entities", entities) + await context.output_table_provider.write_dataframe("relationships", relationships) logger.info("Workflow completed: extract_graph_nlp") diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index 16b726028e..feb4a1e445 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ 
b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -25,10 +25,6 @@ from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - write_table_to_storage, -) if TYPE_CHECKING: from graphrag_llm.embedding import LLMEmbedding @@ -48,13 +44,11 @@ async def run_workflow( entities = None community_reports = None if text_unit_text_embedding in embedded_fields: - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") if entity_description_embedding in embedded_fields: - entities = await load_table_from_storage("entities", context.output_storage) + entities = await context.output_table_provider.read_dataframe("entities") if community_full_content_embedding in embedded_fields: - community_reports = await load_table_from_storage( - "community_reports", context.output_storage - ) + community_reports = await context.output_table_provider.read_dataframe("community_reports") model_config = config.get_embedding_model_config( config.embed_text.embedding_model_id @@ -84,10 +78,9 @@ async def run_workflow( if config.snapshots.embeddings: for name, table in output.items(): - await write_table_to_storage( - table, + await context.output_table_provider.write_dataframe( f"embeddings.{name}", - context.output_storage, + table, ) logger.info("Workflow completed: generate_text_embeddings") diff --git a/packages/graphrag/graphrag/index/workflows/load_input_documents.py b/packages/graphrag/graphrag/index/workflows/load_input_documents.py index 0a5aa65454..ed7f83c8e2 100644 --- a/packages/graphrag/graphrag/index/workflows/load_input_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_input_documents.py @@ -11,7 +11,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -33,7 +32,7 @@ async def run_workflow( logger.info("Final # of rows loaded: %s", len(output)) context.stats.num_documents = len(output) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index 1cab6cabfe..2116b337c5 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -14,7 +14,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import get_delta_docs -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -37,7 +36,7 @@ async def run_workflow( logger.warning("No new update documents found.") return WorkflowFunctionOutput(result=None, stop=True) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) return 
WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index da4fdef147..3e6369fbdd 100644 --- a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -5,14 +5,13 @@ import logging -from graphrag_storage import Storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_communities -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -26,9 +25,13 @@ async def run_workflow( output_storage, previous_storage, delta_storage = get_update_storages( config, context.state["update_timestamp"] ) + + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + output_table_provider = ParquetTableProvider(output_storage) community_id_mapping = await _update_communities( - previous_storage, delta_storage, output_storage + previous_table_provider, delta_table_provider, output_table_provider ) context.state["incremental_update_community_id_mapping"] = community_id_mapping @@ -38,17 +41,17 @@ async def run_workflow( async def _update_communities( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: ParquetTableProvider, + delta_table_provider: ParquetTableProvider, + output_table_provider: ParquetTableProvider, ) -> dict: """Update the communities output.""" - old_communities = await load_table_from_storage("communities", previous_storage) - delta_communities = await load_table_from_storage("communities", delta_storage) + old_communities = await previous_table_provider.read_dataframe("communities") + delta_communities = await delta_table_provider.read_dataframe("communities") merged_communities, community_id_mapping = _update_and_merge_communities( old_communities, delta_communities ) - await write_table_to_storage(merged_communities, "communities", output_storage) + await output_table_provider.write_dataframe("communities", merged_communities) return community_id_mapping diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 790f9fc296..0929455c0a 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -6,14 +6,13 @@ import logging import pandas as pd -from graphrag_storage import Storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_community_reports -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -27,11 +26,15 @@ async def run_workflow( 
output_storage, previous_storage, delta_storage = get_update_storages( config, context.state["update_timestamp"] ) + + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + output_table_provider = ParquetTableProvider(output_storage) community_id_mapping = context.state["incremental_update_community_id_mapping"] merged_community_reports = await _update_community_reports( - previous_storage, delta_storage, output_storage, community_id_mapping + previous_table_provider, delta_table_provider, output_table_provider, community_id_mapping ) context.state["incremental_update_merged_community_reports"] = ( @@ -43,24 +46,24 @@ async def run_workflow( async def _update_community_reports( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: ParquetTableProvider, + delta_table_provider: ParquetTableProvider, + output_table_provider: ParquetTableProvider, community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" - old_community_reports = await load_table_from_storage( - "community_reports", previous_storage + old_community_reports = await previous_table_provider.read_dataframe( + "community_reports" ) - delta_community_reports = await load_table_from_storage( - "community_reports", delta_storage + delta_community_reports = await delta_table_provider.read_dataframe( + "community_reports" ) merged_community_reports = _update_and_merge_community_reports( old_community_reports, delta_community_reports, community_id_mapping ) - await write_table_to_storage( - merged_community_reports, "community_reports", output_storage + await output_table_provider.write_dataframe( + "community_reports", merged_community_reports ) return merged_community_reports diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index 09f8b4053d..c2bdcc2b44 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -7,17 +7,12 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) logger = logging.getLogger(__name__) @@ -31,28 +26,32 @@ async def run_workflow( output_storage, previous_storage, delta_storage = get_update_storages( config, context.state["update_timestamp"] ) - - if await storage_has_table( - "covariates", previous_storage - ) and await storage_has_table("covariates", delta_storage): + + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + output_table_provider = ParquetTableProvider(output_storage) + + if await previous_table_provider.has_dataframe( + "covariates" + ) and await delta_table_provider.has_dataframe("covariates"): logger.info("Updating Covariates") - await _update_covariates(previous_storage, delta_storage, output_storage) + await _update_covariates(previous_table_provider, delta_table_provider, output_table_provider) 
logger.info("Workflow completed: update_covariates") return WorkflowFunctionOutput(result=None) async def _update_covariates( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: ParquetTableProvider, + delta_table_provider: ParquetTableProvider, + output_table_provider: ParquetTableProvider, ) -> None: """Update the covariates output.""" - old_covariates = await load_table_from_storage("covariates", previous_storage) - delta_covariates = await load_table_from_storage("covariates", delta_storage) + old_covariates = await previous_table_provider.read_dataframe("covariates") + delta_covariates = await delta_table_provider.read_dataframe("covariates") merged_covariates = _merge_covariates(old_covariates, delta_covariates) - await write_table_to_storage(merged_covariates, "covariates", output_storage) + await output_table_provider.write_dataframe("covariates", merged_covariates) def _merge_covariates( diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 225c12d9b9..300a69ab9b 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -8,7 +8,7 @@ import pandas as pd from graphrag_cache import Cache from graphrag_llm.completion import create_completion -from graphrag_storage import Storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -19,7 +19,6 @@ from graphrag.index.update.entities import _group_and_resolve_entities from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -33,15 +32,19 @@ async def run_workflow( output_storage, previous_storage, delta_storage = get_update_storages( config, context.state["update_timestamp"] ) + + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + output_table_provider = ParquetTableProvider(output_storage) ( merged_entities_df, merged_relationships_df, entity_id_mapping, ) = await _update_entities_and_relationships( - previous_storage, - delta_storage, - output_storage, + previous_table_provider, + delta_table_provider, + output_table_provider, config, context.cache, context.callbacks, @@ -56,24 +59,24 @@ async def run_workflow( async def _update_entities_and_relationships( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: ParquetTableProvider, + delta_table_provider: ParquetTableProvider, + output_table_provider: ParquetTableProvider, config: GraphRagConfig, cache: Cache, callbacks: WorkflowCallbacks, ) -> tuple[pd.DataFrame, pd.DataFrame, dict]: """Update Final Entities and Relationships output.""" - old_entities = await load_table_from_storage("entities", previous_storage) - delta_entities = await load_table_from_storage("entities", delta_storage) + old_entities = await previous_table_provider.read_dataframe("entities") + delta_entities = await delta_table_provider.read_dataframe("entities") merged_entities_df, entity_id_mapping = 
_group_and_resolve_entities( old_entities, delta_entities ) # Update Relationships - old_relationships = await load_table_from_storage("relationships", previous_storage) - delta_relationships = await load_table_from_storage("relationships", delta_storage) + old_relationships = await previous_table_provider.read_dataframe("relationships") + delta_relationships = await delta_table_provider.read_dataframe("relationships") merged_relationships_df = _update_and_merge_relationships( old_relationships, delta_relationships, @@ -104,10 +107,7 @@ async def _update_entities_and_relationships( ) # Save the updated entities back to storage - await write_table_to_storage(merged_entities_df, "entities", output_storage) - - await write_table_to_storage( - merged_relationships_df, "relationships", output_storage - ) + await output_table_provider.write_dataframe("entities", merged_entities_df) + await output_table_provider.write_dataframe("relationships", merged_relationships_df) return merged_entities_df, merged_relationships_df, entity_id_mapping diff --git a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py index 375bb69df4..1c3a1ca0ac 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py @@ -13,7 +13,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings -from graphrag.utils.storage import write_table_to_storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider logger = logging.getLogger(__name__) @@ -27,6 +27,8 @@ async def run_workflow( output_storage, _, _ = get_update_storages( config, context.state["update_timestamp"] ) + output_table_provider = ParquetTableProvider(output_storage) + merged_text_units = context.state["incremental_update_merged_text_units"] merged_entities_df = context.state["incremental_update_merged_entities"] merged_community_reports = context.state[ @@ -62,11 +64,7 @@ async def run_workflow( ) if config.snapshots.embeddings: for name, table in result.items(): - await write_table_to_storage( - table, - f"embeddings.{name}", - output_storage, - ) + await output_table_provider.write_dataframe(f"embeddings.{name}", table) logger.info("Workflow completed: update_text_embeddings") return WorkflowFunctionOutput(result=None) diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index c97f89ce7a..2d113709f5 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -8,12 +8,12 @@ import numpy as np import pandas as pd from graphrag_storage import Storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -46,13 +46,17 @@ async def _update_text_units( entity_id_mapping: dict, ) -> pd.DataFrame: """Update the text units output.""" - 
old_text_units = await load_table_from_storage("text_units", previous_storage) - delta_text_units = await load_table_from_storage("text_units", delta_storage) + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + output_table_provider = ParquetTableProvider(output_storage) + + old_text_units = await previous_table_provider.read_dataframe("text_units") + delta_text_units = await delta_table_provider.read_dataframe("text_units") merged_text_units = _update_and_merge_text_units( old_text_units, delta_text_units, entity_id_mapping ) - await write_table_to_storage(merged_text_units, "text_units", output_storage) + await output_table_provider.write_dataframe("text_units", merged_text_units) return merged_text_units From 6c2e04f3fe38d100f4d4f93f9ffad7a116a354a1 Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Mon, 26 Jan 2026 19:38:18 -0300 Subject: [PATCH 4/7] add fixed files --- .../tables/parquet_table_provider.py | 53 ++++++++++++++----- .../graphrag_storage/tables/table_provider.py | 24 ++++++--- .../graphrag/index/run/run_pipeline.py | 39 +++++++++----- packages/graphrag/graphrag/index/run/utils.py | 28 ++++++---- .../graphrag/graphrag/index/typing/context.py | 8 +-- .../index/update/incremental_index.py | 29 +++++----- .../index/workflows/create_communities.py | 1 - .../workflows/create_community_reports.py | 12 ++--- .../workflows/create_final_text_units.py | 18 ++++--- .../graphrag/index/workflows/extract_graph.py | 15 +++--- .../index/workflows/finalize_graph.py | 4 +- .../workflows/generate_text_embeddings.py | 4 +- .../index/workflows/load_update_documents.py | 14 +++-- .../graphrag/index/workflows/prune_graph.py | 4 +- .../index/workflows/update_communities.py | 18 +++---- .../workflows/update_community_reports.py | 23 ++++---- .../index/workflows/update_covariates.py | 22 ++++---- .../update_entities_relationships.py | 22 ++++---- .../index/workflows/update_final_documents.py | 11 ++-- .../index/workflows/update_text_embeddings.py | 6 +-- .../index/workflows/update_text_units.py | 24 ++++----- .../storage/test_parquet_table_provider.py | 33 +++++++++--- 22 files changed, 246 insertions(+), 166 deletions(-) diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py index 9476202d48..75805be23a 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -3,25 +3,28 @@ """Parquet-based table provider implementation.""" -from io import BytesIO import logging -from typing import Any +import re +from io import BytesIO + import pandas as pd + from graphrag_storage.storage import Storage from graphrag_storage.tables.table_provider import TableProvider logger = logging.getLogger(__name__) + class ParquetTableProvider(TableProvider): """Table provider that stores tables as Parquet files using an underlying Storage instance. - + This provider converts between pandas DataFrames and Parquet format, storing the data through a Storage backend (file, blob, cosmos, etc.). """ def __init__(self, storage: Storage, **kwargs) -> None: """Initialize the Parquet table provider with an underlying storage instance. 
- + Args ---- storage: Storage @@ -33,17 +36,17 @@ def __init__(self, storage: Storage, **kwargs) -> None: async def read_dataframe(self, table_name: str) -> pd.DataFrame: """Read a table from storage as a pandas DataFrame. - + Args ---- table_name: str The name of the table to read. The file will be accessed as '{table_name}.parquet'. - + Returns ------- pd.DataFrame: The table data loaded from the Parquet file. - + Raises ------ ValueError: @@ -57,15 +60,16 @@ async def read_dataframe(self, table_name: str) -> pd.DataFrame: raise ValueError(msg) try: logger.info("reading table from storage: %s", filename) - return pd.read_parquet(BytesIO(await self._storage.get(filename, as_bytes=True))) + return pd.read_parquet( + BytesIO(await self._storage.get(filename, as_bytes=True)) + ) except Exception: logger.exception("error loading table from storage: %s", filename) raise - async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: """Write a pandas DataFrame to storage as a Parquet file. - + Args ---- table_name: str @@ -75,5 +79,30 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: """ await self._storage.set(f"{table_name}.parquet", df.to_parquet()) - async def has_dataframe(self, table_name): - return await self._storage.has(f"{table_name}.parquet") \ No newline at end of file + async def has_dataframe(self, table_name: str) -> bool: + """Check if a table exists in storage. + + Args + ---- + table_name: str + The name of the table to check. + + Returns + ------- + bool: + True if the table exists, False otherwise. + """ + return await self._storage.has(f"{table_name}.parquet") + + def find_tables(self) -> list[str]: + """Find all table names in storage. + + Returns + ------- + list[str]: + List of table names (without .parquet extension). + """ + return [ + file.replace(".parquet", "") + for file in self._storage.find(re.compile(r"\.parquet$")) + ] diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py index 06f5f760eb..0d48480892 100644 --- a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -8,13 +8,14 @@ import pandas as pd + class TableProvider(ABC): """Provide a table-based storage interface with support for DataFrames and row dictionaries.""" @abstractmethod def __init__(self, **kwargs: Any) -> None: """Create a table provider instance. - + Args ---- **kwargs: Any @@ -24,22 +25,22 @@ def __init__(self, **kwargs: Any) -> None: @abstractmethod async def read_dataframe(self, table_name: str) -> pd.DataFrame: """Read entire table as a pandas DataFrame. - + Args ---- table_name: str The name of the table to read. - + Returns ------- pd.DataFrame: The table data as a DataFrame. """ - + @abstractmethod async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: """Write entire table from a pandas DataFrame. - + Args ---- table_name: str @@ -47,19 +48,28 @@ async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: df: pd.DataFrame The DataFrame to write as a table. """ + @abstractmethod async def has_dataframe(self, table_name: str) -> bool: """Check if a table exists in the provider. - + Args ---- table_name: str The name of the table to check. - + Returns ------- bool: True if the table exists, False otherwise. """ + @abstractmethod + def find_tables(self) -> list[str]: + """Find all table names in the provider. 
+ Returns + ------- + list[str]: + List of table names (without file extensions). + """ diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index a4ce17582c..dadeb9bbf4 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -5,7 +5,6 @@ import json import logging -import re import time from collections.abc import AsyncIterable from dataclasses import asdict @@ -13,7 +12,9 @@ import pandas as pd from graphrag_cache import create_cache -from graphrag_storage import Storage, create_storage +from graphrag_storage import create_storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -21,7 +22,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage +from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -36,7 +37,9 @@ async def run_pipeline( ) -> AsyncIterable[PipelineRunResult]: """Run all workflows using a simplified pipeline.""" input_storage = create_storage(config.input_storage) - output_storage = create_storage(config.output_storage) + input_table_provider = ParquetTableProvider(input_storage) + + output_storage = create_storage(config.output) cache = create_cache(config.cache) # load existing state in case any workflows are stateful @@ -54,22 +57,28 @@ async def run_pipeline( update_timestamp = time.strftime("%Y%m%d-%H%M%S") timestamped_storage = update_storage.child(update_timestamp) delta_storage = timestamped_storage.child("delta") + delta_table_provider = ParquetTableProvider(delta_storage) # copy the previous output to a backup folder, so we can replace it with the update # we'll read from this later when we merge the old and new indexes previous_storage = timestamped_storage.child("previous") - await _copy_previous_output(output_storage, previous_storage) + previous_table_provider = ParquetTableProvider(previous_storage) + + output_table_provider = ParquetTableProvider(output_storage) + await _copy_previous_output(output_table_provider, previous_table_provider) state["update_timestamp"] = update_timestamp # if the user passes in a df directly, write directly to storage so we can skip finding/parsing later if input_documents is not None: - await write_table_to_storage(input_documents, "documents", delta_storage) + await delta_table_provider.write_dataframe("documents", input_documents) pipeline.remove("load_update_documents") context = create_run_context( input_storage=input_storage, + input_table_provider=input_table_provider, output_storage=delta_storage, - previous_storage=previous_storage, + output_table_provider=delta_table_provider, + previous_table_provider=previous_table_provider, cache=cache, callbacks=callbacks, state=state, @@ -85,7 +94,9 @@ async def run_pipeline( context = create_run_context( input_storage=input_storage, + input_table_provider=input_table_provider, output_storage=output_storage, + output_table_provider=ParquetTableProvider(storage=output_storage), cache=cache, callbacks=callbacks, state=state, @@ -156,10 +167,10 @@ async def 
_dump_json(context: PipelineRunContext) -> None: async def _copy_previous_output( - storage: Storage, - copy_storage: Storage, -): - for file in storage.find(re.compile(r"\.parquet$")): - base_name = file.replace(".parquet", "") - table = await load_table_from_storage(base_name, storage) - await write_table_to_storage(table, base_name, copy_storage) + output_table_provider: TableProvider, + previous_table_provider: TableProvider, +) -> None: + """Copy all parquet tables from output to previous storage for backup.""" + for table_name in output_table_provider.find_tables(): + table = await output_table_provider.read_dataframe(table_name) + await previous_table_provider.write_dataframe(table_name, table) diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index 93fcdfbded..95b0b399f2 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -19,20 +19,26 @@ def create_run_context( input_storage: Storage | None = None, + input_table_provider: ParquetTableProvider | None = None, output_storage: Storage | None = None, - previous_storage: Storage | None = None, + output_table_provider: ParquetTableProvider | None = None, + previous_table_provider: ParquetTableProvider | None = None, cache: Cache | None = None, callbacks: WorkflowCallbacks | None = None, stats: PipelineRunStats | None = None, state: PipelineState | None = None, ) -> PipelineRunContext: """Create the run context for the pipeline.""" + input_storage = input_storage or MemoryStorage() output_storage = output_storage or MemoryStorage() return PipelineRunContext( - input_storage=input_storage or MemoryStorage(), + input_storage=input_storage, + input_table_provider=input_table_provider + or ParquetTableProvider(storage=input_storage), output_storage=output_storage, - output_table_provider=ParquetTableProvider(storage=output_storage), - previous_storage=previous_storage or MemoryStorage(), + output_table_provider=output_table_provider + or ParquetTableProvider(storage=output_storage), + previous_table_provider=previous_table_provider, cache=cache or MemoryCache(), callbacks=callbacks or NoopWorkflowCallbacks(), stats=stats or PipelineRunStats(), @@ -50,14 +56,18 @@ def create_callback_chain( return manager -def get_update_storages( +def get_update_table_providers( config: GraphRagConfig, timestamp: str -) -> tuple[Storage, Storage, Storage]: - """Get storage objects for the update index run.""" - output_storage = create_storage(config.output_storage) +) -> tuple[ParquetTableProvider, ParquetTableProvider, ParquetTableProvider]: + """Get table providers for the update index run.""" + output_storage = create_storage(config.output) update_storage = create_storage(config.update_output_storage) timestamped_storage = update_storage.child(timestamp) delta_storage = timestamped_storage.child("delta") previous_storage = timestamped_storage.child("previous") - return output_storage, previous_storage, delta_storage + output_table_provider = ParquetTableProvider(output_storage) + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + + return output_table_provider, previous_table_provider, delta_table_provider diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py index 1f25ec8f47..48c66e16c4 100644 --- a/packages/graphrag/graphrag/index/typing/context.py +++ 
b/packages/graphrag/graphrag/index/typing/context.py @@ -19,13 +19,15 @@ class PipelineRunContext: stats: PipelineRunStats input_storage: Storage - "Storage for input documents." + "Storage for reading input documents." + input_table_provider: ParquetTableProvider + "Table provider for reading input tables." output_storage: Storage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." output_table_provider: ParquetTableProvider "Table provider for reading and writing output tables." - previous_storage: Storage - "Storage for previous pipeline run when running in update mode." + previous_table_provider: ParquetTableProvider | None + "Table provider for reading previous pipeline run when running in update mode." cache: Cache "Cache instance for reading previous LLM responses." callbacks: WorkflowCallbacks diff --git a/packages/graphrag/graphrag/index/update/incremental_index.py b/packages/graphrag/graphrag/index/update/incremental_index.py index 81f917e187..0e7eb34684 100644 --- a/packages/graphrag/graphrag/index/update/incremental_index.py +++ b/packages/graphrag/graphrag/index/update/incremental_index.py @@ -7,12 +7,7 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage - -from graphrag.utils.storage import ( - load_table_from_storage, - write_table_to_storage, -) +from graphrag_storage.tables.table_provider import TableProvider @dataclass @@ -31,22 +26,24 @@ class InputDelta: deleted_inputs: pd.DataFrame -async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> InputDelta: +async def get_delta_docs( + input_dataset: pd.DataFrame, table_provider: TableProvider +) -> InputDelta: """Get the delta between the input dataset and the final documents. Parameters ---------- input_dataset : pd.DataFrame The input dataset. - storage : Storage - The Pipeline storage. + table_provider : TableProvider + The table provider for reading previous documents. Returns ------- InputDelta The input delta. With new inputs and deleted inputs. 
""" - final_docs = await load_table_from_storage("documents", storage) + final_docs = await table_provider.read_dataframe("documents") # Select distinct title from final docs and from dataset previous_docs: list[str] = final_docs["title"].unique().tolist() @@ -63,19 +60,19 @@ async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> Input async def concat_dataframes( name: str, - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> pd.DataFrame: """Concatenate dataframes.""" - old_df = await load_table_from_storage(name, previous_storage) - delta_df = await load_table_from_storage(name, delta_storage) + old_df = await previous_table_provider.read_dataframe(name) + delta_df = await delta_table_provider.read_dataframe(name) # Merge the final documents initial_id = old_df["human_readable_id"].max() + 1 delta_df["human_readable_id"] = np.arange(initial_id, initial_id + len(delta_df)) final_df = pd.concat([old_df, delta_df], ignore_index=True, copy=False) - await write_table_to_storage(final_df, name, output_storage) + await output_table_provider.write_dataframe(name, final_df) return final_df diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index f90889bac8..7c3d7a6b33 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -17,7 +17,6 @@ from graphrag.index.operations.create_graph import create_graph from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports.py b/packages/graphrag/graphrag/index/workflows/create_community_reports.py index d96fb76b4b..6f8b061a30 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports.py @@ -46,13 +46,13 @@ async def run_workflow( edges = await context.output_table_provider.read_dataframe("relationships") entities = await context.output_table_provider.read_dataframe("entities") communities = await context.output_table_provider.read_dataframe("communities") - + claims = None - if config.extract_claims.enabled: - try: - claims = await context.output_table_provider.read_dataframe("covariates") - except Exception: - pass + if ( + config.extract_claims.enabled + and await context.output_table_provider.has_dataframe("covariates") + ): + claims = await context.output_table_provider.read_dataframe("covariates") model_config = config.get_completion_model_config( config.community_reports.completion_model_id diff --git a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py index 91b5885751..9c897b28f3 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py @@ -23,14 +23,18 @@ async def run_workflow( logger.info("Workflow started: create_final_text_units") text_units = await context.output_table_provider.read_dataframe("text_units") final_entities = await 
context.output_table_provider.read_dataframe("entities") - final_relationships = await context.output_table_provider.read_dataframe("relationships") - + final_relationships = await context.output_table_provider.read_dataframe( + "relationships" + ) + final_covariates = None - if config.extract_claims.enabled: - try: - final_covariates = await context.output_table_provider.read_dataframe("covariates") - except Exception: - pass + if ( + config.extract_claims.enabled + and await context.output_table_provider.has_dataframe("covariates") + ): + final_covariates = await context.output_table_provider.read_dataframe( + "covariates" + ) output = create_final_text_units( text_units, diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index 6d6520e401..237bbe16cc 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -21,7 +21,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -35,7 +34,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") extraction_model_config = config.get_completion_model_config( config.extract_graph.completion_model_id @@ -73,15 +72,15 @@ async def run_workflow( summarization_num_threads=config.concurrent_requests, ) - await write_table_to_storage(entities, "entities", context.output_storage) - await write_table_to_storage(relationships, "relationships", context.output_storage) + await context.output_table_provider.write_dataframe("entities", entities) + await context.output_table_provider.write_dataframe("relationships", relationships) if config.snapshots.raw_graph: - await write_table_to_storage( - raw_entities, "raw_entities", context.output_storage + await context.output_table_provider.write_dataframe( + "raw_entities", raw_entities ) - await write_table_to_storage( - raw_relationships, "raw_relationships", context.output_storage + await context.output_table_provider.write_dataframe( + "raw_relationships", raw_relationships ) logger.info("Workflow completed: extract_graph") diff --git a/packages/graphrag/graphrag/index/workflows/finalize_graph.py b/packages/graphrag/graphrag/index/workflows/finalize_graph.py index 6edcb7ba93..64029a8cb6 100644 --- a/packages/graphrag/graphrag/index/workflows/finalize_graph.py +++ b/packages/graphrag/graphrag/index/workflows/finalize_graph.py @@ -33,7 +33,9 @@ async def run_workflow( ) await context.output_table_provider.write_dataframe("entities", final_entities) - await context.output_table_provider.write_dataframe("relationships", final_relationships) + await context.output_table_provider.write_dataframe( + "relationships", final_relationships + ) if config.snapshots.graphml: graph = create_graph(final_relationships, edge_attr=["weight"]) diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index feb4a1e445..c1e42969ee 100644 --- 
a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -48,7 +48,9 @@ async def run_workflow( if entity_description_embedding in embedded_fields: entities = await context.output_table_provider.read_dataframe("entities") if community_full_content_embedding in embedded_fields: - community_reports = await context.output_table_provider.read_dataframe("community_reports") + community_reports = await context.output_table_provider.read_dataframe( + "community_reports" + ) model_config = config.get_embedding_model_config( config.embed_text.embedding_model_id diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index 2116b337c5..3f4417d3e1 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -8,7 +8,7 @@ import pandas as pd from graphrag_input.input_reader import InputReader from graphrag_input.input_reader_factory import create_input_reader -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.typing.context import PipelineRunContext @@ -23,10 +23,14 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """Load and parse update-only input documents into a standard format.""" + if context.previous_table_provider is None: + msg = "previous_table_provider is required for update workflows" + raise ValueError(msg) + input_reader = create_input_reader(config.input, context.input_storage) output = await load_update_documents( input_reader, - context.previous_storage, + context.previous_table_provider, ) logger.info("Final # of update rows loaded: %s", len(output)) @@ -43,11 +47,11 @@ async def run_workflow( async def load_update_documents( input_reader: InputReader, - previous_storage: Storage, + previous_table_provider: TableProvider, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" input_documents = pd.DataFrame(await input_reader.read_files()) - # previous storage is the output of the previous run + # previous table provider has the output of the previous run # we'll use this to diff the input from the prior - delta_documents = await get_delta_docs(input_documents, previous_storage) + delta_documents = await get_delta_docs(input_documents, previous_table_provider) return delta_documents.new_inputs diff --git a/packages/graphrag/graphrag/index/workflows/prune_graph.py b/packages/graphrag/graphrag/index/workflows/prune_graph.py index 7a0fc73702..483c9b18b3 100644 --- a/packages/graphrag/graphrag/index/workflows/prune_graph.py +++ b/packages/graphrag/graphrag/index/workflows/prune_graph.py @@ -34,7 +34,9 @@ async def run_workflow( ) await context.output_table_provider.write_dataframe("entities", pruned_entities) - await context.output_table_provider.write_dataframe("relationships", pruned_relationships) + await context.output_table_provider.write_dataframe( + "relationships", pruned_relationships + ) logger.info("Workflow completed: prune_graph") return WorkflowFunctionOutput( diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index 3e6369fbdd..7887706a86 100644 --- 
a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -5,10 +5,10 @@ import logging -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_communities @@ -22,13 +22,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the communities from a incremental index run.""" logger.info("Workflow started: update_communities") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) - - previous_table_provider = ParquetTableProvider(previous_storage) - delta_table_provider = ParquetTableProvider(delta_storage) - output_table_provider = ParquetTableProvider(output_storage) community_id_mapping = await _update_communities( previous_table_provider, delta_table_provider, output_table_provider @@ -41,9 +37,9 @@ async def run_workflow( async def _update_communities( - previous_table_provider: ParquetTableProvider, - delta_table_provider: ParquetTableProvider, - output_table_provider: ParquetTableProvider, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> dict: """Update the communities output.""" old_communities = await previous_table_provider.read_dataframe("communities") diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 0929455c0a..9c9b0f2fec 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -6,10 +6,10 @@ import logging import pandas as pd -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_community_reports @@ -23,18 +23,17 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the community reports from a incremental index run.""" logger.info("Workflow started: update_community_reports") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) - - previous_table_provider = ParquetTableProvider(previous_storage) - delta_table_provider = ParquetTableProvider(delta_storage) - output_table_provider = ParquetTableProvider(output_storage) 
community_id_mapping = context.state["incremental_update_community_id_mapping"] merged_community_reports = await _update_community_reports( - previous_table_provider, delta_table_provider, output_table_provider, community_id_mapping + previous_table_provider, + delta_table_provider, + output_table_provider, + community_id_mapping, ) context.state["incremental_update_merged_community_reports"] = ( @@ -46,9 +45,9 @@ async def run_workflow( async def _update_community_reports( - previous_table_provider: ParquetTableProvider, - delta_table_provider: ParquetTableProvider, - output_table_provider: ParquetTableProvider, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index c2bdcc2b44..a2c1a834fb 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -7,10 +7,10 @@ import numpy as np import pandas as pd -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -23,28 +23,26 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the covariates from a incremental index run.""" logger.info("Workflow started: update_covariates") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) - - previous_table_provider = ParquetTableProvider(previous_storage) - delta_table_provider = ParquetTableProvider(delta_storage) - output_table_provider = ParquetTableProvider(output_storage) if await previous_table_provider.has_dataframe( "covariates" ) and await delta_table_provider.has_dataframe("covariates"): logger.info("Updating Covariates") - await _update_covariates(previous_table_provider, delta_table_provider, output_table_provider) + await _update_covariates( + previous_table_provider, delta_table_provider, output_table_provider + ) logger.info("Workflow completed: update_covariates") return WorkflowFunctionOutput(result=None) async def _update_covariates( - previous_table_provider: ParquetTableProvider, - delta_table_provider: ParquetTableProvider, - output_table_provider: ParquetTableProvider, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> None: """Update the covariates output.""" old_covariates = await previous_table_provider.read_dataframe("covariates") diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 300a69ab9b..c7d1bcc416 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ 
-8,12 +8,12 @@ import pandas as pd from graphrag_cache import Cache from graphrag_llm.completion import create_completion -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.entities import _group_and_resolve_entities @@ -29,13 +29,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the entities and relationships from a incremental index run.""" logger.info("Workflow started: update_entities_relationships") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) - - previous_table_provider = ParquetTableProvider(previous_storage) - delta_table_provider = ParquetTableProvider(delta_storage) - output_table_provider = ParquetTableProvider(output_storage) ( merged_entities_df, @@ -59,9 +55,9 @@ async def run_workflow( async def _update_entities_and_relationships( - previous_table_provider: ParquetTableProvider, - delta_table_provider: ParquetTableProvider, - output_table_provider: ParquetTableProvider, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, config: GraphRagConfig, cache: Cache, callbacks: WorkflowCallbacks, @@ -108,6 +104,8 @@ async def _update_entities_and_relationships( # Save the updated entities back to storage await output_table_provider.write_dataframe("entities", merged_entities_df) - await output_table_provider.write_dataframe("relationships", merged_relationships_df) + await output_table_provider.write_dataframe( + "relationships", merged_relationships_df + ) return merged_entities_df, merged_relationships_df, entity_id_mapping diff --git a/packages/graphrag/graphrag/index/workflows/update_final_documents.py b/packages/graphrag/graphrag/index/workflows/update_final_documents.py index b684beba94..7f473096d3 100644 --- a/packages/graphrag/graphrag/index/workflows/update_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/update_final_documents.py @@ -6,7 +6,7 @@ import logging from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import concat_dataframes @@ -20,12 +20,15 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the documents from a incremental index run.""" logger.info("Workflow started: update_final_documents") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) 
final_documents = await concat_dataframes( - "documents", previous_storage, delta_storage, output_storage + "documents", + previous_table_provider, + delta_table_provider, + output_table_provider, ) context.state["incremental_update_final_documents"] = final_documents diff --git a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py index 1c3a1ca0ac..4a3cf1a673 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py @@ -9,11 +9,10 @@ from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider logger = logging.getLogger(__name__) @@ -24,10 +23,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the text embeddings from a incremental index run.""" logger.info("Workflow started: update_text_embeddings") - output_storage, _, _ = get_update_storages( + output_table_provider, _, _ = get_update_table_providers( config, context.state["update_timestamp"] ) - output_table_provider = ParquetTableProvider(output_storage) merged_text_units = context.state["incremental_update_merged_text_units"] merged_entities_df = context.state["incremental_update_merged_entities"] diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index 2d113709f5..02592b8aa4 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -7,11 +7,10 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage -from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -24,13 +23,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the text units from a incremental index run.""" logger.info("Workflow started: update_text_units") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) entity_id_mapping = context.state["incremental_update_entity_id_mapping"] merged_text_units = await _update_text_units( - previous_storage, delta_storage, output_storage, entity_id_mapping + previous_table_provider, + delta_table_provider, + output_table_provider, + entity_id_mapping, ) context.state["incremental_update_merged_text_units"] = merged_text_units @@ -40,16 +42,12 @@ async def run_workflow( async def _update_text_units( - previous_storage: Storage, - 
delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, entity_id_mapping: dict, ) -> pd.DataFrame: """Update the text units output.""" - previous_table_provider = ParquetTableProvider(previous_storage) - delta_table_provider = ParquetTableProvider(delta_storage) - output_table_provider = ParquetTableProvider(output_storage) - old_text_units = await previous_table_provider.read_dataframe("text_units") delta_text_units = await delta_table_provider.read_dataframe("text_units") merged_text_units = _update_and_merge_text_units( diff --git a/tests/unit/storage/test_parquet_table_provider.py b/tests/unit/storage/test_parquet_table_provider.py index 27b66228b9..e9aa20065d 100644 --- a/tests/unit/storage/test_parquet_table_provider.py +++ b/tests/unit/storage/test_parquet_table_provider.py @@ -6,7 +6,12 @@ import pandas as pd import pytest -from graphrag_storage import ParquetTableProvider, StorageConfig, StorageType, create_storage +from graphrag_storage import ( + ParquetTableProvider, + StorageConfig, + StorageType, + create_storage, +) class TestParquetTableProvider(unittest.IsolatedAsyncioTestCase): @@ -25,7 +30,7 @@ async def test_write_and_read(self): df = pd.DataFrame({ "id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], - "age": [30, 25, 35] + "age": [30, 25, 35], }) await self.table_provider.write_dataframe("users", df) @@ -34,7 +39,9 @@ async def test_write_and_read(self): pd.testing.assert_frame_equal(result, df) async def test_read_nonexistent_table_raises_error(self): - with pytest.raises(ValueError, match="Could not find nonexistent.parquet in storage!"): + with pytest.raises( + ValueError, match=r"Could not find nonexistent\.parquet in storage!" 
+ ): await self.table_provider.read_dataframe("nonexistent") async def test_empty_dataframe(self): @@ -50,7 +57,7 @@ async def test_dataframe_with_multiple_types(self): "int_col": [1, 2, 3], "float_col": [1.1, 2.2, 3.3], "str_col": ["a", "b", "c"], - "bool_col": [True, False, True] + "bool_col": [True, False, True], }) await self.table_provider.write_dataframe("mixed", df) @@ -62,10 +69,22 @@ async def test_storage_persistence(self): df = pd.DataFrame({"x": [1, 2, 3]}) await self.table_provider.write_dataframe("test", df) - + assert await self.storage.has("test.parquet") - + parquet_bytes = await self.storage.get("test.parquet", as_bytes=True) loaded_df = pd.read_parquet(BytesIO(parquet_bytes)) - + pd.testing.assert_frame_equal(loaded_df, df) + + async def test_has_dataframe(self): + df = pd.DataFrame({"a": [1, 2, 3]}) + + # Table doesn't exist yet + assert not await self.table_provider.has_dataframe("test_table") + + # Write the table + await self.table_provider.write_dataframe("test_table", df) + + # Now it exists + assert await self.table_provider.has_dataframe("test_table") From 0ccc8d6d128fe93a27afeb209c90a513fe0a836d Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Mon, 26 Jan 2026 19:47:25 -0300 Subject: [PATCH 5/7] add versioning --- .semversioner/next-release/minor-20260126224712110537.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/minor-20260126224712110537.json diff --git a/.semversioner/next-release/minor-20260126224712110537.json b/.semversioner/next-release/minor-20260126224712110537.json new file mode 100644 index 0000000000..3f83bfd77d --- /dev/null +++ b/.semversioner/next-release/minor-20260126224712110537.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Add TableProvider abstraction for table-based storage operations" +} From 501887124c77eff1d943fd049ec7caba1bbcb76a Mon Sep 17 00:00:00 2001 From: Dayenne Souza Date: Tue, 27 Jan 2026 11:12:03 -0300 Subject: [PATCH 6/7] add patch and remove utility --- ...7.json => patch-20260127131016120694.json} | 2 +- .../index_migration_to_v1.ipynb | 39 ++++++------ .../index_migration_to_v2.ipynb | 61 +++++++++---------- .../index_migration_to_v3.ipynb | 12 ++-- .../graphrag_storage/__init__.py | 2 + .../graphrag_storage/tables/__init__.py | 8 +++ packages/graphrag/graphrag/cli/query.py | 13 ++-- .../graphrag/index/run/run_pipeline.py | 4 +- packages/graphrag/graphrag/index/run/utils.py | 3 +- .../graphrag/graphrag/index/typing/context.py | 3 +- packages/graphrag/graphrag/utils/storage.py | 36 ----------- .../storage/test_parquet_table_provider.py | 2 +- tests/verbs/test_create_base_text_units.py | 3 +- tests/verbs/test_create_communities.py | 3 +- tests/verbs/test_create_community_reports.py | 3 +- tests/verbs/test_create_final_documents.py | 3 +- tests/verbs/test_create_final_text_units.py | 3 +- tests/verbs/test_extract_covariates.py | 3 +- tests/verbs/test_extract_graph.py | 11 +--- tests/verbs/test_extract_graph_nlp.py | 11 +--- tests/verbs/test_finalize_graph.py | 15 ++--- tests/verbs/test_generate_text_embeddings.py | 5 +- tests/verbs/test_prune_graph.py | 7 +-- tests/verbs/util.py | 5 +- 24 files changed, 102 insertions(+), 155 deletions(-) rename .semversioner/next-release/{minor-20260126224712110537.json => patch-20260127131016120694.json} (82%) create mode 100644 packages/graphrag-storage/graphrag_storage/tables/__init__.py delete mode 100644 packages/graphrag/graphrag/utils/storage.py diff --git a/.semversioner/next-release/minor-20260126224712110537.json 
b/.semversioner/next-release/patch-20260127131016120694.json similarity index 82% rename from .semversioner/next-release/minor-20260126224712110537.json rename to .semversioner/next-release/patch-20260127131016120694.json index 3f83bfd77d..516466de8f 100644 --- a/.semversioner/next-release/minor-20260126224712110537.json +++ b/.semversioner/next-release/patch-20260127131016120694.json @@ -1,4 +1,4 @@ { - "type": "minor", + "type": "patch", "description": "Add TableProvider abstraction for table-based storage operations" } diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index c5b582d38d..3fd8f264bc 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb +++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -103,21 +103,22 @@ "source": [ "from uuid import uuid4\n", "\n", - "from graphrag.utils.storage import load_table_from_storage, write_table_to_storage\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", + "\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", "\n", "# First we'll go through any parquet files that had model changes and update them\n", "# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n", "\n", - "final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n", - "final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n", - "final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n", - "final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n", - "final_relationships = await load_table_from_storage(\n", - " \"create_final_relationships\", storage\n", - ")\n", - "final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n", - "final_community_reports = await load_table_from_storage(\n", - " \"create_final_community_reports\", storage\n", + "final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n", + "final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n", + "final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n", + "final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n", + "final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n", + "final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n", + "final_community_reports = await table_provider.read_dataframe(\n", + " \"create_final_community_reports\"\n", ")\n", "\n", "\n", @@ -187,14 +188,14 @@ " parent_df, on=\"community\", how=\"left\"\n", " )\n", "\n", - "await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n", - "await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n", - "await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n", - "await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n", - "await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n", - "await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n", - "await write_table_to_storage(\n", - " final_community_reports, \"create_final_community_reports\", storage\n", + "await table_provider.write_dataframe(\"create_final_documents\", 
final_documents)\n", + "await table_provider.write_dataframe(\"create_final_text_units\", final_text_units)\n", + "await table_provider.write_dataframe(\"create_final_entities\", final_entities)\n", + "await table_provider.write_dataframe(\"create_final_nodes\", final_nodes)\n", + "await table_provider.write_dataframe(\"create_final_relationships\", final_relationships)\n", + "await table_provider.write_dataframe(\"create_final_communities\", final_communities)\n", + "await table_provider.write_dataframe(\n", + " \"create_final_community_reports\", final_community_reports\n", ")" ] }, diff --git a/docs/examples_notebooks/index_migration_to_v2.ipynb b/docs/examples_notebooks/index_migration_to_v2.ipynb index 0681d1a0b2..c71e27d945 100644 --- a/docs/examples_notebooks/index_migration_to_v2.ipynb +++ b/docs/examples_notebooks/index_migration_to_v2.ipynb @@ -65,28 +65,25 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", - "from graphrag.utils.storage import (\n", - " delete_table_from_storage,\n", - " load_table_from_storage,\n", - " write_table_to_storage,\n", - ")\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", "\n", - "final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n", - "final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n", - "final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n", - "final_covariates = await load_table_from_storage(\"create_final_covariates\", storage)\n", - "final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n", - "final_relationships = await load_table_from_storage(\n", - " \"create_final_relationships\", storage\n", - ")\n", - "final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n", - "final_community_reports = await load_table_from_storage(\n", - " \"create_final_community_reports\", storage\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", + "\n", + "final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n", + "final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n", + "final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n", + "final_covariates = await table_provider.read_dataframe(\"create_final_covariates\")\n", + "final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n", + "final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n", + "final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n", + "final_community_reports = await table_provider.read_dataframe(\n", + " \"create_final_community_reports\"\n", ")\n", "\n", "# we've renamed document attributes as metadata\n", @@ -126,23 +123,23 @@ ")\n", "\n", "# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n", - "await write_table_to_storage(final_documents, \"documents\", storage)\n", - "await write_table_to_storage(final_text_units, \"text_units\", storage)\n", - "await write_table_to_storage(final_entities, \"entities\", storage)\n", - "await write_table_to_storage(final_relationships, \"relationships\", storage)\n", - "await write_table_to_storage(final_covariates, \"covariates\", storage)\n", - "await 
write_table_to_storage(final_communities, \"communities\", storage)\n", - "await write_table_to_storage(final_community_reports, \"community_reports\", storage)\n", + "await table_provider.write_dataframe(\"documents\", final_documents)\n", + "await table_provider.write_dataframe(\"text_units\", final_text_units)\n", + "await table_provider.write_dataframe(\"entities\", final_entities)\n", + "await table_provider.write_dataframe(\"relationships\", final_relationships)\n", + "await table_provider.write_dataframe(\"covariates\", final_covariates)\n", + "await table_provider.write_dataframe(\"communities\", final_communities)\n", + "await table_provider.write_dataframe(\"community_reports\", final_community_reports)\n", "\n", "# delete all the old versions\n", - "await delete_table_from_storage(\"create_final_documents\", storage)\n", - "await delete_table_from_storage(\"create_final_text_units\", storage)\n", - "await delete_table_from_storage(\"create_final_entities\", storage)\n", - "await delete_table_from_storage(\"create_final_nodes\", storage)\n", - "await delete_table_from_storage(\"create_final_relationships\", storage)\n", - "await delete_table_from_storage(\"create_final_covariates\", storage)\n", - "await delete_table_from_storage(\"create_final_communities\", storage)\n", - "await delete_table_from_storage(\"create_final_community_reports\", storage)" + "await storage.delete(\"create_final_documents.parquet\")\n", + "await storage.delete(\"create_final_text_units.parquet\")\n", + "await storage.delete(\"create_final_entities.parquet\")\n", + "await storage.delete(\"create_final_nodes.parquet\")\n", + "await storage.delete(\"create_final_relationships.parquet\")\n", + "await storage.delete(\"create_final_covariates.parquet\")\n", + "await storage.delete(\"create_final_communities.parquet\")\n", + "await storage.delete(\"create_final_community_reports.parquet\")" ] } ], diff --git a/docs/examples_notebooks/index_migration_to_v3.ipynb b/docs/examples_notebooks/index_migration_to_v3.ipynb index a0e50be432..7f94dedee6 100644 --- a/docs/examples_notebooks/index_migration_to_v3.ipynb +++ b/docs/examples_notebooks/index_migration_to_v3.ipynb @@ -66,17 +66,17 @@ "metadata": {}, "outputs": [], "source": [ - "from graphrag.utils.storage import (\n", - " load_table_from_storage,\n", - " write_table_to_storage,\n", - ")\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", "\n", - "text_units = await load_table_from_storage(\"text_units\", storage)\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", + "\n", + "text_units = await table_provider.read_dataframe(\"text_units\")\n", "\n", "text_units[\"document_id\"] = text_units[\"document_ids\"].apply(lambda ids: ids[0])\n", "remove_columns(text_units, [\"document_ids\"])\n", "\n", - "await write_table_to_storage(text_units, \"text_units\", storage)" + "await table_provider.write_dataframe(\"text_units\", text_units)" ] }, { diff --git a/packages/graphrag-storage/graphrag_storage/__init__.py b/packages/graphrag-storage/graphrag_storage/__init__.py index 2ae67be741..454842eecc 100644 --- a/packages/graphrag-storage/graphrag_storage/__init__.py +++ b/packages/graphrag-storage/graphrag_storage/__init__.py @@ -10,11 +10,13 @@ register_storage, ) from graphrag_storage.storage_type import StorageType +from graphrag_storage.tables import TableProvider __all__ = [ "Storage", "StorageConfig", "StorageType", + "TableProvider", "create_storage", "register_storage", ] 
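For orientation, a minimal usage sketch of the table API exported above (not part of the patch itself): it writes, checks, and reads a table through ParquetTableProvider over the in-memory backend. The in-memory storage and the "documents" table name are illustrative choices, not anything this patch prescribes.

import asyncio

import pandas as pd
from graphrag_storage.memory_storage import MemoryStorage
from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider


async def main() -> None:
    # Illustrative only: any Storage implementation could back the provider.
    provider = ParquetTableProvider(MemoryStorage())
    await provider.write_dataframe("documents", pd.DataFrame({"id": [1, 2]}))
    if await provider.has_dataframe("documents"):
        df = await provider.read_dataframe("documents")
        print(df)


asyncio.run(main())

The same pattern is what the pipeline context relies on: callers hold a TableProvider and never touch the underlying parquet filenames directly.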
diff --git a/packages/graphrag-storage/graphrag_storage/tables/__init__.py b/packages/graphrag-storage/graphrag_storage/tables/__init__.py new file mode 100644 index 0000000000..0210d935f3 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Table provider module for GraphRAG storage.""" + +from .table_provider import TableProvider + +__all__ = ["TableProvider"] diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index ae06a88c95..1606a3201b 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -9,12 +9,12 @@ from typing import TYPE_CHECKING, Any from graphrag_storage import create_storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider import graphrag.api as api from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.utils.storage import load_table_from_storage, storage_has_table if TYPE_CHECKING: import pandas as pd @@ -377,19 +377,18 @@ def _resolve_output_files( ) -> dict[str, Any]: """Read indexing output files to a dataframe dict.""" dataframe_dict = {} - storage_obj = create_storage(config.output_storage) + storage_obj = create_storage(config.output) + table_provider = ParquetTableProvider(storage_obj) for name in output_list: - df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) + df_value = asyncio.run(table_provider.read_dataframe(name)) dataframe_dict[name] = df_value # for optional output files, set the dict entry to None instead of erroring out if it does not exist if optional_list: for optional_file in optional_list: - file_exists = asyncio.run(storage_has_table(optional_file, storage_obj)) + file_exists = asyncio.run(table_provider.has_dataframe(optional_file)) if file_exists: - df_value = asyncio.run( - load_table_from_storage(name=optional_file, storage=storage_obj) - ) + df_value = asyncio.run(table_provider.read_dataframe(optional_file)) dataframe_dict[optional_file] = df_value else: dataframe_dict[optional_file] = None diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index dadeb9bbf4..d55b2f3214 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -22,7 +22,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -89,7 +88,8 @@ async def run_pipeline( # if the user passes in a df directly, write directly to storage so we can skip finding/parsing later if input_documents is not None: - await write_table_to_storage(input_documents, "documents", output_storage) + output_table_provider = ParquetTableProvider(output_storage) + await output_table_provider.write_dataframe("documents", input_documents) pipeline.remove("load_input_documents") context = create_run_context( diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index 95b0b399f2..ad8b6ff11b 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ 
b/packages/graphrag/graphrag/index/run/utils.py
@@ -5,8 +5,9 @@
 from graphrag_cache import Cache
 from graphrag_cache.memory_cache import MemoryCache
-from graphrag_storage import ParquetTableProvider, Storage, create_storage
+from graphrag_storage import Storage, create_storage
 from graphrag_storage.memory_storage import MemoryStorage
+from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
 
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py
index 48c66e16c4..27f9280d26 100644
--- a/packages/graphrag/graphrag/index/typing/context.py
+++ b/packages/graphrag/graphrag/index/typing/context.py
@@ -10,7 +10,8 @@
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.index.typing.state import PipelineState
 from graphrag.index.typing.stats import PipelineRunStats
-from graphrag_storage import ParquetTableProvider, Storage
+from graphrag_storage import Storage
+from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
 
 
 @dataclass
diff --git a/packages/graphrag/graphrag/utils/storage.py b/packages/graphrag/graphrag/utils/storage.py
deleted file mode 100644
index e4d1566427..0000000000
--- a/packages/graphrag/graphrag/utils/storage.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Storage functions for the GraphRAG run module."""
-
-import logging
-from io import BytesIO
-
-import pandas as pd
-from graphrag_storage import Storage
-
-logger = logging.getLogger(__name__)
-
-
-async def load_table_from_storage(name: str, storage: Storage) -> pd.DataFrame:
-    """Load a parquet from the storage instance."""
-    table_provider = ParquetTableProvider(storage)
-    return await table_provider.read_dataframe(name)
-
-
-async def write_table_to_storage(
-    table: pd.DataFrame, name: str, storage: Storage
-) -> None:
-    """Write a table to storage."""
-    table_provider = ParquetTableProvider(storage)
-    await table_provider.write_dataframe(name, table)
-
-
-async def delete_table_from_storage(name: str, storage: Storage) -> None:
-    """Delete a table to storage."""
-    await storage.delete(f"{name}.parquet")
-
-
-async def storage_has_table(name: str, storage: Storage) -> bool:
-    """Check if a table exists in storage."""
-    return await storage.has(f"{name}.parquet")
diff --git a/tests/unit/storage/test_parquet_table_provider.py b/tests/unit/storage/test_parquet_table_provider.py
index e9aa20065d..781735224b 100644
--- a/tests/unit/storage/test_parquet_table_provider.py
+++ b/tests/unit/storage/test_parquet_table_provider.py
@@ -7,11 +7,11 @@
 import pandas as pd
 import pytest
 from graphrag_storage import (
-    ParquetTableProvider,
     StorageConfig,
     StorageType,
     create_storage,
 )
+from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
 
 
 class TestParquetTableProvider(unittest.IsolatedAsyncioTestCase):
diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py
index 34bad99dc7..b7ad0543ed 100644
--- a/tests/verbs/test_create_base_text_units.py
+++ b/tests/verbs/test_create_base_text_units.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT License
 
 from graphrag.index.workflows.create_base_text_units import run_workflow
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -23,7 +22,7 @@ async def test_create_base_text_units():
     await run_workflow(config, context)
 
-    actual = await load_table_from_storage("text_units", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("text_units")
 
     print("EXPECTED")
     print(expected.columns)
diff --git a/tests/verbs/test_create_communities.py b/tests/verbs/test_create_communities.py
index d5505d7a31..072e878e2c 100644
--- a/tests/verbs/test_create_communities.py
+++ b/tests/verbs/test_create_communities.py
@@ -5,7 +5,6 @@
 from graphrag.index.workflows.create_communities import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -33,7 +32,7 @@ async def test_create_communities():
         context,
     )
 
-    actual = await load_table_from_storage("communities", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("communities")
 
     columns = list(expected.columns.values)
     # don't compare period since it is created with the current date each time
diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py
index a36b6c7a66..68d8d1be9c 100644
--- a/tests/verbs/test_create_community_reports.py
+++ b/tests/verbs/test_create_community_reports.py
@@ -10,7 +10,6 @@
 from graphrag.index.workflows.create_community_reports import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -56,7 +55,7 @@ async def test_create_community_reports():
     await run_workflow(config, context)
 
-    actual = await load_table_from_storage("community_reports", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("community_reports")
 
     assert len(actual.columns) == len(expected.columns)
diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py
index 6031ccd09c..586ad5b31c 100644
--- a/tests/verbs/test_create_final_documents.py
+++ b/tests/verbs/test_create_final_documents.py
@@ -5,7 +5,6 @@
 from graphrag.index.workflows.create_final_documents import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -27,7 +26,7 @@ async def test_create_final_documents():
     await run_workflow(config, context)
 
-    actual = await load_table_from_storage("documents", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("documents")
 
     compare_outputs(actual, expected)
diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py
index 56c0e72a81..c97cba2bcd 100644
--- a/tests/verbs/test_create_final_text_units.py
+++ b/tests/verbs/test_create_final_text_units.py
@@ -5,7 +5,6 @@
 from graphrag.index.workflows.create_final_text_units import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -33,7 +32,7 @@ async def test_create_final_text_units():
     await run_workflow(config, context)
 
-    actual = await load_table_from_storage("text_units", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("text_units")
 
     for column in TEXT_UNITS_FINAL_COLUMNS:
         assert column in actual.columns
diff --git a/tests/verbs/test_extract_covariates.py b/tests/verbs/test_extract_covariates.py
index 5a87c121b3..4cf3a79d77 100644
--- a/tests/verbs/test_extract_covariates.py
+++ b/tests/verbs/test_extract_covariates.py
@@ -5,7 +5,6 @@
 from graphrag.index.workflows.extract_covariates import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from graphrag_llm.config import LLMProviderType
 from pandas.testing import assert_series_equal
@@ -41,7 +40,7 @@ async def test_extract_covariates():
     await run_workflow(config, context)
 
-    actual = await load_table_from_storage("covariates", context.output_storage)
+    actual = await context.output_table_provider.read_dataframe("covariates")
 
     for column in COVARIATES_FINAL_COLUMNS:
         assert column in actual.columns
diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py
index b62bd9c77f..504baaac31 100644
--- a/tests/verbs/test_extract_graph.py
+++ b/tests/verbs/test_extract_graph.py
@@ -1,10 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-from graphrag.index.workflows.extract_graph import (
-    run_workflow,
-)
-from graphrag.utils.storage import load_table_from_storage
+from graphrag.index.workflows.extract_graph import run_workflow
 from tests.unit.config.utils import get_default_graphrag_config
@@ -54,10 +51,8 @@ async def test_extract_graph():
     await run_workflow(config, context)
 
-    nodes_actual = await load_table_from_storage("entities", context.output_storage)
-    edges_actual = await load_table_from_storage(
-        "relationships", context.output_storage
-    )
+    nodes_actual = await context.output_table_provider.read_dataframe("entities")
+    edges_actual = await context.output_table_provider.read_dataframe("relationships")
 
     assert len(nodes_actual.columns) == 5
     assert len(edges_actual.columns) == 5
diff --git a/tests/verbs/test_extract_graph_nlp.py b/tests/verbs/test_extract_graph_nlp.py
index 9c758dda61..55ab376689 100644
--- a/tests/verbs/test_extract_graph_nlp.py
+++ b/tests/verbs/test_extract_graph_nlp.py
@@ -1,10 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-from graphrag.index.workflows.extract_graph_nlp import (
-    run_workflow,
-)
-from graphrag.utils.storage import load_table_from_storage
+from graphrag.index.workflows.extract_graph_nlp import run_workflow
 from tests.unit.config.utils import get_default_graphrag_config
@@ -22,10 +19,8 @@ async def test_extract_graph_nlp():
     await run_workflow(config, context)
 
-    nodes_actual = await load_table_from_storage("entities", context.output_storage)
-    edges_actual = await load_table_from_storage(
-        "relationships", context.output_storage
-    )
+    nodes_actual = await context.output_table_provider.read_dataframe("entities")
+    edges_actual = await context.output_table_provider.read_dataframe("relationships")
 
     # this will be the raw count of entities and edges with no pruning
     # with NLP it is deterministic, so we can assert exact row counts
diff --git a/tests/verbs/test_finalize_graph.py b/tests/verbs/test_finalize_graph.py
index 055ec76768..72513a8293 100644
--- a/tests/verbs/test_finalize_graph.py
+++ b/tests/verbs/test_finalize_graph.py
@@ -5,10 +5,7 @@
     ENTITIES_FINAL_COLUMNS,
     RELATIONSHIPS_FINAL_COLUMNS,
 )
-from graphrag.index.workflows.finalize_graph import (
-    run_workflow,
-)
-from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+from graphrag.index.workflows.finalize_graph import run_workflow
 from tests.unit.config.utils import get_default_graphrag_config
@@ -25,10 +22,8 @@ async def test_finalize_graph():
     await run_workflow(config, context)
 
-    nodes_actual = await load_table_from_storage("entities", context.output_storage)
-    edges_actual = await load_table_from_storage(
-        "relationships", context.output_storage
-    )
+    nodes_actual = await context.output_table_provider.read_dataframe("entities")
+    edges_actual = await context.output_table_provider.read_dataframe("relationships")
 
     for column in ENTITIES_FINAL_COLUMNS:
         assert column in nodes_actual.columns
@@ -44,8 +39,8 @@ async def _prep_tables():
     # edit the tables to eliminate final fields that wouldn't be on the inputs
     entities = load_test_table("entities")
    entities.drop(columns=["degree"], inplace=True)
-    await write_table_to_storage(entities, "entities", context.output_storage)
+    await context.output_table_provider.write_dataframe("entities", entities)
     relationships = load_test_table("relationships")
     relationships.drop(columns=["combined_degree"], inplace=True)
-    await write_table_to_storage(relationships, "relationships", context.output_storage)
+    await context.output_table_provider.write_dataframe("relationships", relationships)
     return context
diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py
index 14fd163d87..f25ed52d34 100644
--- a/tests/verbs/test_generate_text_embeddings.py
+++ b/tests/verbs/test_generate_text_embeddings.py
@@ -7,7 +7,6 @@
 from graphrag.index.workflows.generate_text_embeddings import (
     run_workflow,
 )
-from graphrag.utils.storage import load_table_from_storage
 from tests.unit.config.utils import get_default_graphrag_config
@@ -45,8 +44,8 @@ async def test_generate_text_embeddings():
         assert f"embeddings.{field}.parquet" in parquet_files
 
     # entity description should always be here, let's assert its format
-    entity_description_embeddings = await load_table_from_storage(
-        "embeddings.entity_description", context.output_storage
+    entity_description_embeddings = await context.output_table_provider.read_dataframe(
+        "embeddings.entity_description"
     )
 
     assert len(entity_description_embeddings.columns) == 2
diff --git a/tests/verbs/test_prune_graph.py b/tests/verbs/test_prune_graph.py
index fa66f98bde..1df5cadebb 100644
--- a/tests/verbs/test_prune_graph.py
+++ b/tests/verbs/test_prune_graph.py
@@ -2,10 +2,7 @@
 # Licensed under the MIT License
 
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
-from graphrag.index.workflows.prune_graph import (
-    run_workflow,
-)
-from graphrag.utils.storage import load_table_from_storage
+from graphrag.index.workflows.prune_graph import run_workflow
 from tests.unit.config.utils import get_default_graphrag_config
@@ -26,6 +23,6 @@ async def test_prune_graph():
     await run_workflow(config, context)
 
-    nodes_actual = await load_table_from_storage("entities", context.output_storage)
+    nodes_actual = await context.output_table_provider.read_dataframe("entities")
 
     assert len(nodes_actual) == 29
diff --git a/tests/verbs/util.py b/tests/verbs/util.py
index 65b6b906b7..741e8e3b1a 100644
--- a/tests/verbs/util.py
+++ b/tests/verbs/util.py
@@ -4,7 +4,6 @@
 import pandas as pd
 from graphrag.index.run.utils import create_run_context
 from graphrag.index.typing.context import PipelineRunContext
-from graphrag.utils.storage import write_table_to_storage
 from pandas.testing import assert_series_equal
 
 pd.set_option("display.max_columns", None)
@@ -17,12 +16,12 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo
     # always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input
     input = load_test_table("documents")
     input.drop(columns=["text_unit_ids"], inplace=True)
-    await write_table_to_storage(input, "documents", context.output_storage)
+    await context.output_table_provider.write_dataframe("documents", input)
 
     if storage:
         for name in storage:
             table = load_test_table(name)
-            await write_table_to_storage(table, name, context.output_storage)
+            await context.output_table_provider.write_dataframe(name, table)
 
     return context

From 914fb637ed50c7600b7c44deee6e42a89c28ad94 Mon Sep 17 00:00:00 2001
From: Dayenne Souza
Date: Tue, 27 Jan 2026 21:53:17 +0000
Subject: [PATCH 7/7] pr changes

---
 packages/graphrag/graphrag/cli/query.py              | 2 +-
 packages/graphrag/graphrag/index/run/run_pipeline.py | 2 +-
 packages/graphrag/graphrag/index/run/utils.py        | 2 +-
 packages/graphrag/graphrag/index/typing/context.py   | 9 ++++-----
 tests/unit/storage/__init__.py                       | 2 ++
 5 files changed, 9 insertions(+), 8 deletions(-)
 create mode 100644 tests/unit/storage/__init__.py

diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py
index 1606a3201b..1f808420d4 100644
--- a/packages/graphrag/graphrag/cli/query.py
+++ b/packages/graphrag/graphrag/cli/query.py
@@ -377,7 +377,7 @@ def _resolve_output_files(
 ) -> dict[str, Any]:
     """Read indexing output files to a dataframe dict."""
     dataframe_dict = {}
-    storage_obj = create_storage(config.output)
+    storage_obj = create_storage(config.output_storage)
     table_provider = ParquetTableProvider(storage_obj)
     for name in output_list:
         df_value = asyncio.run(table_provider.read_dataframe(name))
diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py
index d55b2f3214..24ff39cc07 100644
--- a/packages/graphrag/graphrag/index/run/run_pipeline.py
+++ b/packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -38,7 +38,7 @@ async def run_pipeline(
     input_storage = create_storage(config.input_storage)
     input_table_provider = ParquetTableProvider(input_storage)
 
-    output_storage = create_storage(config.output)
+    output_storage = create_storage(config.output_storage)
     cache = create_cache(config.cache)
 
     # load existing state in case any workflows are stateful
diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py
index ad8b6ff11b..207e9561a0 100644
--- a/packages/graphrag/graphrag/index/run/utils.py
+++ b/packages/graphrag/graphrag/index/run/utils.py
@@ -61,7 +61,7 @@ def get_update_table_providers(
     config: GraphRagConfig, timestamp: str
 ) -> tuple[ParquetTableProvider, ParquetTableProvider, ParquetTableProvider]:
     """Get table providers for the update index run."""
-    output_storage = create_storage(config.output)
+    output_storage = create_storage(config.output_storage)
     update_storage = create_storage(config.update_output_storage)
     timestamped_storage = update_storage.child(timestamp)
     delta_storage = timestamped_storage.child("delta")
diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py
index 27f9280d26..f606218dd2 100644
--- a/packages/graphrag/graphrag/index/typing/context.py
+++ b/packages/graphrag/graphrag/index/typing/context.py
@@ -10,8 +10,7 @@
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.index.typing.state import PipelineState
 from graphrag.index.typing.stats import PipelineRunStats
-from graphrag_storage import Storage
-from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider
+from graphrag_storage import Storage, TableProvider
 
 
 @dataclass
@@ -21,13 +20,13 @@ class PipelineRunContext:
     stats: PipelineRunStats
     input_storage: Storage
     "Storage for reading input documents."
-    input_table_provider: ParquetTableProvider
+    input_table_provider: TableProvider
     "Table provider for reading input tables."
     output_storage: Storage
     "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
-    output_table_provider: ParquetTableProvider
+    output_table_provider: TableProvider
     "Table provider for reading and writing output tables."
-    previous_table_provider: ParquetTableProvider | None
+    previous_table_provider: TableProvider | None
     "Table provider for reading previous pipeline run when running in update mode."
     cache: Cache
     "Cache instance for reading previous LLM responses."
diff --git a/tests/unit/storage/__init__.py b/tests/unit/storage/__init__.py
new file mode 100644
index 0000000000..0a3e38adfb
--- /dev/null
+++ b/tests/unit/storage/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License