From 1f9b25a767c18f121e6517366bfef8f8cb8cbb13 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Sat, 31 Jan 2026 20:02:19 -0500 Subject: [PATCH] Parallel Xarray record batch reading via blocks partition factory. --- Cargo.lock | 26 +++++----- src/lib.rs | 105 +++++++++++++++++++++++--------------- xarray_sql/reader.py | 59 ++++++++++++++++----- xarray_sql/reader_test.py | 46 ++++++++--------- xarray_sql/sql_test.py | 2 - 5 files changed, 145 insertions(+), 93 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a4bde02f..a07ddf39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -176,7 +176,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "num-complex", "num-integer", "num-traits", @@ -541,9 +541,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "bzip2" @@ -1751,9 +1751,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "heck" @@ -1917,12 +1917,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown 0.16.0", + "hashbrown 0.16.1", ] [[package]] @@ -2261,7 +2261,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "lz4_flex", "num-bigint", "num-integer", @@ -2948,9 +2948,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.16" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -3103,9 +3103,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.18.1" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" dependencies = [ "getrandom 0.3.3", "js-sys", diff --git a/src/lib.rs b/src/lib.rs index d61e9186..0a432e89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,14 +24,12 @@ //! incrementally. This enables processing of larger-than-memory datasets when combined //! with DataFusion's streaming execution. //! -//! ## Parallel Execution Note +//! ## Parallel Execution //! -//! When using DataFusion's parallel execution (multiple partitions), aggregation queries -//! without ORDER BY may return partial results due to how our stream interacts with -//! DataFusion's async runtime. To ensure complete results: -//! - Add ORDER BY to aggregation queries, or -//! - Use `SessionConfig().with_target_partitions(1)` for single-threaded execution -//! TODO(#106): Implement proper parallelism and partition handling. +//! Each xarray chunk becomes a separate partition, enabling parallel execution across +//! multiple cores. Due to a bug in DataFusion v51.0.0's `collect()` method, aggregation +//! queries should use `to_arrow_table()` instead to ensure complete results. +//! TODO(#107): Upgrading to the latest datafusion-python (52+) should fix this. use std::ffi::CString; use std::fmt::Debug; @@ -140,40 +138,44 @@ impl PartitionStream for PyArrowStreamPartition { } } -/// A lazy table provider that wraps a Python stream factory. +/// A lazy table provider that wraps Python stream factory functions. /// /// This class implements the `__datafusion_table_provider__` protocol, allowing /// it to be registered with DataFusion's `SessionContext.register_table()`. /// /// Data is NOT read until query execution time - this enables true lazy evaluation. -/// The factory function is called on each query execution to create a fresh stream, -/// allowing the same table to be queried multiple times. +/// Each partition has its own factory function that is called on query execution +/// to create a fresh stream, enabling true parallelism in DataFusion. +/// +/// # Note +/// +/// Due to a bug in DataFusion v51.0.0's `collect()` method, use `to_arrow_table()` +/// instead for aggregation queries to ensure complete results. /// /// # Example /// /// ```python /// from datafusion import SessionContext -/// from xarray_sql import LazyArrowStreamTable, XarrayRecordBatchReader -/// -/// # Create a factory that produces lazy readers -/// def make_reader(): -/// return XarrayRecordBatchReader(ds, chunks={'time': 240}) +/// from xarray_sql import LazyArrowStreamTable +/// import pyarrow as pa /// -/// # Get schema from a sample reader -/// sample = make_reader() -/// schema = sample.schema +/// # Create factories for each partition (chunk) +/// factories = [ +/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_0), +/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_1), +/// ] /// -/// # Wrap factory in lazy table - NO DATA LOADED -/// table = LazyArrowStreamTable(make_reader, schema) +/// # Wrap factories in lazy table - NO DATA LOADED +/// table = LazyArrowStreamTable(factories, schema) /// /// # Register with DataFusion - STILL NO DATA LOADED /// ctx = SessionContext() /// ctx.register_table("air", table) /// -/// # Data only loaded HERE during collect() -/// # Each query creates a fresh stream via the factory -/// result = ctx.sql("SELECT AVG(air) FROM air").collect() -/// result2 = ctx.sql("SELECT * FROM air LIMIT 10").collect() # Works! +/// # Data only loaded HERE during query execution +/// # Each partition runs in parallel with its own factory +/// # Use to_arrow_table() for aggregation queries +/// result = ctx.sql("SELECT AVG(air) FROM air").to_arrow_table() /// ``` #[pyclass(name = "LazyArrowStreamTable")] struct LazyArrowStreamTable { @@ -183,19 +185,21 @@ struct LazyArrowStreamTable { #[pymethods] impl LazyArrowStreamTable { - /// Create a new LazyArrowStreamTable from a stream factory function. + /// Create a new LazyArrowStreamTable from stream factory functions. /// /// Args: - /// stream_factory: A callable that returns a Python object implementing - /// the Arrow PyCapsule interface (`__arrow_c_stream__`). - /// Called on each query execution to create a fresh stream. - /// schema: A PyArrow Schema for the table. Required since the factory - /// hasn't been called yet. + /// stream_factories: A list of callables, each returning a Python object + /// implementing the Arrow PyCapsule interface (`__arrow_c_stream__`). + /// Each factory represents one partition, enabling parallel execution. + /// Called on each query execution to create fresh streams. + /// schema: A PyArrow Schema for the table. Required since the factories + /// haven't been called yet. /// /// Raises: /// TypeError: If the schema is not a valid PyArrow Schema. + /// ValueError: If stream_factories is empty. #[new] - fn new(stream_factory: &Bound<'_, PyAny>, schema: &Bound<'_, PyAny>) -> PyResult { + fn new(stream_factories: &Bound<'_, PyAny>, schema: &Bound<'_, PyAny>) -> PyResult { // Convert the PyArrow schema to Arrow schema use arrow::datatypes::Schema; use arrow::pyarrow::FromPyArrow; @@ -205,17 +209,34 @@ impl LazyArrowStreamTable { })?; let schema_ref = Arc::new(arrow_schema); - // Create the partition stream with the factory - let partition = - PyArrowStreamPartition::new(stream_factory.clone().unbind(), schema_ref.clone()); + // Extract factories from the Python list + let factories: Vec> = stream_factories.extract().map_err(|e| { + pyo3::exceptions::PyTypeError::new_err(format!( + "stream_factories must be a list of callables: {e}" + )) + })?; - // Create the StreamingTable - let table = - StreamingTable::try_new(schema_ref, vec![Arc::new(partition)]).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "Failed to create StreamingTable: {e}" - )) - })?; + if factories.is_empty() { + return Err(pyo3::exceptions::PyValueError::new_err( + "stream_factories must not be empty", + )); + } + + // Create one partition per factory + let partitions: Vec> = factories + .into_iter() + .map(|factory| { + Arc::new(PyArrowStreamPartition::new(factory, schema_ref.clone())) + as Arc + }) + .collect(); + + // Create the StreamingTable with multiple partitions + let table = StreamingTable::try_new(schema_ref, partitions).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Failed to create StreamingTable: {e}" + )) + })?; Ok(Self { table: Arc::new(table), @@ -236,7 +257,7 @@ impl LazyArrowStreamTable { // Try to get the current tokio runtime handle (available when called from DataFusion context) let runtime = Handle::try_current().ok(); - // Create FFI wrapper (v49 API takes 3 arguments) + // Create FFI wrapper let ffi_provider = FFI_TableProvider::new( provider, false, // can_support_pushdown_filters runtime, diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py index f89cd276..3b57559d 100644 --- a/xarray_sql/reader.py +++ b/xarray_sql/reader.py @@ -154,7 +154,7 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader: """Pivots an Xarray Dataset into a PyArrow Table, partitioned by chunks. Args: - ds: An Xarray Dataset. All `data_vars` mush share the same dimensions. + ds: An Xarray Dataset. All `data_vars` must share the same dimensions. chunks: Xarray-like chunks. If not provided, will default to the Dataset's chunks. The product of the chunk sizes becomes the standard length of each dataframe partition. @@ -176,8 +176,24 @@ def read_xarray_table( """Create a lazy DataFusion table from an xarray Dataset. This is the simplest way to register xarray data with DataFusion. - Data is only read when queries are executed (during collect()), - not during registration. The table can be queried multiple times. + Data is only read when queries are executed, not during registration. + The table can be queried multiple times. + + Each chunk becomes a separate partition, enabling DataFusion's parallel + execution across multiple cores. + + Note: + Due to a bug in DataFusion v51.0.0's collect() method, use + `to_arrow_table()` instead of `collect()` for aggregation queries + to ensure complete results:: + + # Correct - use to_arrow_table() + result = ctx.sql('SELECT lat, AVG(temp) FROM t GROUP BY lat').to_arrow_table() + + # May return partial results with collect() + result = ctx.sql('SELECT lat, AVG(temp) FROM t GROUP BY lat').collect() + + This should be fixed when we upgrade datafusion-python to 52 (#107). Args: ds: An xarray Dataset. All data_vars must share the same dimensions. @@ -200,21 +216,38 @@ def read_xarray_table( >>> ctx = SessionContext() >>> ctx.register_table('air', table) >>> - >>> # Data is only read here, during collect() - >>> result = ctx.sql('SELECT AVG(air) FROM air').collect() + >>> # Data is only read here, during query execution + >>> result = ctx.sql('SELECT AVG(air) FROM air').to_arrow_table() >>> # Can query again - each query creates a fresh stream - >>> result2 = ctx.sql('SELECT * FROM air LIMIT 10').collect() + >>> result2 = ctx.sql('SELECT * FROM air LIMIT 10').to_arrow_table() """ from ._native import LazyArrowStreamTable # Get schema from dataset without creating a stream schema = _parse_schema(ds) - # Create a factory function that produces fresh RecordBatchReaders on each call - def make_stream() -> pa.RecordBatchReader: - stream = XarrayRecordBatchReader( - ds, chunks, _iteration_callback=_iteration_callback - ) - return pa.RecordBatchReader.from_stream(stream) + blocks = block_slices(ds, chunks) + + # Create a factory function for each block (partition) + # Each factory produces a RecordBatchReader for its specific chunk + def make_partition_factory( + block: Block, + ) -> t.Callable[[], pa.RecordBatchReader]: + """Create a factory function for a specific block/chunk.""" + + def make_stream() -> pa.RecordBatchReader: + # Call the iteration callback if provided (for testing) + if _iteration_callback is not None: + _iteration_callback(block) + + # Extract just this block from the dataset and convert to Arrow + df = pivot(ds.isel(block)) + batch = pa.RecordBatch.from_pandas(df, schema=schema) + return pa.RecordBatchReader.from_batches(schema, [batch]) + + return make_stream + + # Create one factory per block + factories = [make_partition_factory(block) for block in blocks] - return LazyArrowStreamTable(make_stream, schema) + return LazyArrowStreamTable(factories, schema) diff --git a/xarray_sql/reader_test.py b/xarray_sql/reader_test.py index 1527ba61..37304258 100644 --- a/xarray_sql/reader_test.py +++ b/xarray_sql/reader_test.py @@ -511,8 +511,8 @@ def test_batches_processed_incrementally(self, small_ds): tracker.batch_count == 4 ), f"Expected 4 batches, got {tracker.batch_count}" - def test_streaming_preserves_order(self, small_ds): - """Verify that streaming preserves the order of batches.""" + def test_all_partitions_processed(self, small_ds): + """Verify that all partitions are processed (order may vary with parallelism).""" blocks_seen = [] def track_order(block): @@ -529,17 +529,16 @@ def track_order(block): ctx.register_table("test_table", table) ctx.sql("SELECT * FROM test_table").collect() - # Should have 4 blocks + # Should have 4 blocks/partitions assert len(blocks_seen) == 4 - # Blocks should be in order (each slice should start after previous) - for i in range(1, len(blocks_seen)): - prev_end = blocks_seen[i - 1].stop - curr_start = blocks_seen[i].start - assert curr_start == prev_end, ( - f"Block {i} starts at {curr_start}, expected {prev_end}. " - f"Blocks are out of order!" - ) + # All blocks should be present (though order may vary due to parallelism) + # Extract start positions and verify they cover all expected ranges + starts = sorted([b.start for b in blocks_seen]) + expected_starts = [0, 25, 50, 75] + assert ( + starts == expected_starts + ), f"Expected partition starts {expected_starts}, got {starts}" def test_large_dataset_streams_correctly(self): """Test streaming with a larger dataset to verify memory behavior. @@ -733,9 +732,10 @@ def test_aggregation_with_many_batches(self): GROUP BY queries require processing all data, making them a good test for streaming behavior. - Note: ORDER BY is used to ensure deterministic results. Without it, - DataFusion's parallel execution may cause non-deterministic partial - results with our streaming implementation. + Note: We use to_arrow_table() instead of collect() due to a bug in + DataFusion v51.0.0 where collect() returns partial results for + parallel aggregation queries. + # TODO(#107): Upgrade to latest datafusion-python, which has the fix. """ np.random.seed(789) time_coord = pd.date_range("2020-01-01", periods=120, freq="h") @@ -751,7 +751,7 @@ def test_aggregation_with_many_batches(self): tracker = StreamingTracker() - # 12 batches + # 12 partitions (one per chunk) table = read_xarray_table( ds, chunks={"time": 10}, @@ -762,20 +762,19 @@ def test_aggregation_with_many_batches(self): ctx.register_table("test_table", table) # GROUP BY requires scanning all data - # ORDER BY ensures all partial aggregates are collected before returning - # TODO(#106): Fix the underlying partitioning issue. + # Use to_arrow_table() to avoid DataFusion collect() bug result = ctx.sql( - "SELECT lat, AVG(temperature) as avg_temp FROM test_table GROUP BY lat ORDER BY lat" - ).collect() + "SELECT lat, AVG(temperature) as avg_temp FROM test_table GROUP BY lat" + ).to_arrow_table() # Should have result for each lat value - df = result[0].to_pandas() + df = result.to_pandas() assert len(df) == 5, f"Expected 5 lat groups, got {len(df)}" - # All batches processed + # All partitions processed assert ( tracker.batch_count == 12 - ), f"Expected 12 batches, got {tracker.batch_count}" + ), f"Expected 12 partitions, got {tracker.batch_count}" class TestErrorPropagation: @@ -792,7 +791,8 @@ def failing_factory(): raise ValueError("Factory intentionally failed") schema = pa.schema([("value", pa.int64())]) - table = LazyArrowStreamTable(failing_factory, schema) + # API now requires a list of factories (one per partition) + table = LazyArrowStreamTable([failing_factory], schema) ctx = SessionContext() ctx.register_table("test_table", table) diff --git a/xarray_sql/sql_test.py b/xarray_sql/sql_test.py index 9374cb64..8feedeec 100644 --- a/xarray_sql/sql_test.py +++ b/xarray_sql/sql_test.py @@ -1,7 +1,5 @@ """SQL functionality tests for xarray-sql using pytest.""" -import numpy as np -import pandas as pd import pytest import xarray as xr