Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 63 additions & 42 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@
//! incrementally. This enables processing of larger-than-memory datasets when combined
//! with DataFusion's streaming execution.
//!
//! ## Parallel Execution Note
//! ## Parallel Execution
//!
//! When using DataFusion's parallel execution (multiple partitions), aggregation queries
//! without ORDER BY may return partial results due to how our stream interacts with
//! DataFusion's async runtime. To ensure complete results:
//! - Add ORDER BY to aggregation queries, or
//! - Use `SessionConfig().with_target_partitions(1)` for single-threaded execution
//! TODO(#106): Implement proper parallelism and partition handling.
//! Each xarray chunk becomes a separate partition, enabling parallel execution across
//! multiple cores. Due to a bug in DataFusion v51.0.0's `collect()` method, aggregation
//! queries should use `to_arrow_table()` instead to ensure complete results.
//! TODO(#107): Upgrading to the latest datafusion-python (52+) should fix this.

use std::ffi::CString;
use std::fmt::Debug;
Expand Down Expand Up @@ -140,40 +138,44 @@ impl PartitionStream for PyArrowStreamPartition {
}
}

/// A lazy table provider that wraps a Python stream factory.
/// A lazy table provider that wraps Python stream factory functions.
///
/// This class implements the `__datafusion_table_provider__` protocol, allowing
/// it to be registered with DataFusion's `SessionContext.register_table()`.
///
/// Data is NOT read until query execution time - this enables true lazy evaluation.
/// The factory function is called on each query execution to create a fresh stream,
/// allowing the same table to be queried multiple times.
/// Each partition has its own factory function that is called on query execution
/// to create a fresh stream, enabling true parallelism in DataFusion.
///
/// # Note
///
/// Due to a bug in DataFusion v51.0.0's `collect()` method, use `to_arrow_table()`
/// instead for aggregation queries to ensure complete results.
///
/// # Example
///
/// ```python
/// from datafusion import SessionContext
/// from xarray_sql import LazyArrowStreamTable, XarrayRecordBatchReader
///
/// # Create a factory that produces lazy readers
/// def make_reader():
/// return XarrayRecordBatchReader(ds, chunks={'time': 240})
/// from xarray_sql import LazyArrowStreamTable
/// import pyarrow as pa
///
/// # Get schema from a sample reader
/// sample = make_reader()
/// schema = sample.schema
/// # Create factories for each partition (chunk)
/// factories = [
/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_0),
/// lambda: pa.RecordBatchReader.from_batches(schema, batches_chunk_1),
/// ]
///
/// # Wrap factory in lazy table - NO DATA LOADED
/// table = LazyArrowStreamTable(make_reader, schema)
/// # Wrap factories in lazy table - NO DATA LOADED
/// table = LazyArrowStreamTable(factories, schema)
///
/// # Register with DataFusion - STILL NO DATA LOADED
/// ctx = SessionContext()
/// ctx.register_table("air", table)
///
/// # Data only loaded HERE during collect()
/// # Each query creates a fresh stream via the factory
/// result = ctx.sql("SELECT AVG(air) FROM air").collect()
/// result2 = ctx.sql("SELECT * FROM air LIMIT 10").collect() # Works!
/// # Data only loaded HERE during query execution
/// # Each partition runs in parallel with its own factory
/// # Use to_arrow_table() for aggregation queries
/// result = ctx.sql("SELECT AVG(air) FROM air").to_arrow_table()
/// ```
#[pyclass(name = "LazyArrowStreamTable")]
struct LazyArrowStreamTable {
Expand All @@ -183,19 +185,21 @@ struct LazyArrowStreamTable {

#[pymethods]
impl LazyArrowStreamTable {
/// Create a new LazyArrowStreamTable from a stream factory function.
/// Create a new LazyArrowStreamTable from stream factory functions.
///
/// Args:
/// stream_factory: A callable that returns a Python object implementing
/// the Arrow PyCapsule interface (`__arrow_c_stream__`).
/// Called on each query execution to create a fresh stream.
/// schema: A PyArrow Schema for the table. Required since the factory
/// hasn't been called yet.
/// stream_factories: A list of callables, each returning a Python object
/// implementing the Arrow PyCapsule interface (`__arrow_c_stream__`).
/// Each factory represents one partition, enabling parallel execution.
/// Called on each query execution to create fresh streams.
/// schema: A PyArrow Schema for the table. Required since the factories
/// haven't been called yet.
///
/// Raises:
/// TypeError: If the schema is not a valid PyArrow Schema.
/// ValueError: If stream_factories is empty.
#[new]
fn new(stream_factory: &Bound<'_, PyAny>, schema: &Bound<'_, PyAny>) -> PyResult<Self> {
fn new(stream_factories: &Bound<'_, PyAny>, schema: &Bound<'_, PyAny>) -> PyResult<Self> {
// Convert the PyArrow schema to Arrow schema
use arrow::datatypes::Schema;
use arrow::pyarrow::FromPyArrow;
Expand All @@ -205,17 +209,34 @@ impl LazyArrowStreamTable {
})?;
let schema_ref = Arc::new(arrow_schema);

// Create the partition stream with the factory
let partition =
PyArrowStreamPartition::new(stream_factory.clone().unbind(), schema_ref.clone());
// Extract factories from the Python list
let factories: Vec<Py<PyAny>> = stream_factories.extract().map_err(|e| {
pyo3::exceptions::PyTypeError::new_err(format!(
"stream_factories must be a list of callables: {e}"
))
})?;

// Create the StreamingTable
let table =
StreamingTable::try_new(schema_ref, vec![Arc::new(partition)]).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to create StreamingTable: {e}"
))
})?;
if factories.is_empty() {
return Err(pyo3::exceptions::PyValueError::new_err(
"stream_factories must not be empty",
));
}

// Create one partition per factory
let partitions: Vec<Arc<dyn PartitionStream>> = factories
.into_iter()
.map(|factory| {
Arc::new(PyArrowStreamPartition::new(factory, schema_ref.clone()))
as Arc<dyn PartitionStream>
})
.collect();

// Create the StreamingTable with multiple partitions
let table = StreamingTable::try_new(schema_ref, partitions).map_err(|e| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep looking at the docs I think this makes sense, this creates a TableProvider with all your partitions, which implements scan, which produces an ExecutionPlan (without you explicitly building the ExecutionPlan).

One thing though, the code isn't in this PR, but in your PartitionStream implementation for PyArrowStreamPartition, inside the try_stream!, you are calling Python::attach, which basically acquires the GIL, but the GIL can only be acquired by one thread at a time. I think it's essentially acting as a mutex_guard. So I think, and I'm not sure about this but you can probably test on a decent sized zarr store and check the core usage, that while you are indeed creating partitions that will run in parallel with this, in practice you might only have one partition reading data at a time, with the other ones waiting.

Now, even if it is the case, I think what you have here, with partitions, would have a big performance advantage over a single partition, because if you create multiple partitions from the start, downstream execution plans can leverage that. For example, say you have 4 partitions, say partition 1 reads some data, streams it, then partition 2 reads data, streams it, sequentially, but while partition 2 is reading data, whatever operations you have downstream can start working on the data streamed from partition 1. So you might have a "slow" start, but as data is being read, the query could start leveraging the partitions (assuming you have more operations after just reading the data).

On that note though, looking at the PartitionStream implementation, I think right now, one partition would acquire the GIL, read all its batches, and only then let another partition acquire the GIL. It would probably be better to allow partitions to be "interleaved", i.e. partition 1 reads one batch, partition 2 one batch, ..., back to partition 1 for its second batch, and so on. I think you can accomplish that by using allow_threads, https://pyo3.rs/v0.9.2/parallelism. Even if just a little bit of Rust code runs within an allow_threads, I think it might allow a different partition to acquire the GIL.

Okay, done with my overly long ramblings! I'm not 100% sure about all of the above, but I think that if you have some decent sized data to test this on easily, it might be worth just trying with allow_threads to see if it impacts performance (I think you'd have to include some compute-heavy operations that can run on individual batches to see the difference though).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'll look into your GIL/parallelism enhancement in a future PR! Thank you so much for the suggestion!

I will add an acceptance test to see if the GIL acquisition is really slowing us down. From my experience so far, we acquire the GIL just to schedule the partitions, but since they are arrow streams ("send"-like), the processing happens in parallel in DF.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yeah I'm not sure anymore then... in any case let me know what you find out if you test it out!

pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to create StreamingTable: {e}"
))
})?;

Ok(Self {
table: Arc::new(table),
Expand All @@ -236,7 +257,7 @@ impl LazyArrowStreamTable {
// Try to get the current tokio runtime handle (available when called from DataFusion context)
let runtime = Handle::try_current().ok();

// Create FFI wrapper (v49 API takes 3 arguments)
// Create FFI wrapper
let ffi_provider = FFI_TableProvider::new(
provider, false, // can_support_pushdown_filters
runtime,
Expand Down
59 changes: 46 additions & 13 deletions xarray_sql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def read_xarray(ds: xr.Dataset, chunks: Chunks = None) -> pa.RecordBatchReader:
"""Pivots an Xarray Dataset into a PyArrow Table, partitioned by chunks.

Args:
ds: An Xarray Dataset. All `data_vars` mush share the same dimensions.
ds: An Xarray Dataset. All `data_vars` must share the same dimensions.
chunks: Xarray-like chunks. If not provided, will default to the Dataset's
chunks. The product of the chunk sizes becomes the standard length of each
dataframe partition.
Expand All @@ -176,8 +176,24 @@ def read_xarray_table(
"""Create a lazy DataFusion table from an xarray Dataset.

This is the simplest way to register xarray data with DataFusion.
Data is only read when queries are executed (during collect()),
not during registration. The table can be queried multiple times.
Data is only read when queries are executed, not during registration.
The table can be queried multiple times.

Each chunk becomes a separate partition, enabling DataFusion's parallel
execution across multiple cores.

Note:
Due to a bug in DataFusion v51.0.0's collect() method, use
`to_arrow_table()` instead of `collect()` for aggregation queries
to ensure complete results::

# Correct - use to_arrow_table()
result = ctx.sql('SELECT lat, AVG(temp) FROM t GROUP BY lat').to_arrow_table()

# May return partial results with collect()
result = ctx.sql('SELECT lat, AVG(temp) FROM t GROUP BY lat').collect()

This should be fixed when we upgrade datafusion-python to 52 (#107).

Args:
ds: An xarray Dataset. All data_vars must share the same dimensions.
Expand All @@ -200,21 +216,38 @@ def read_xarray_table(
>>> ctx = SessionContext()
>>> ctx.register_table('air', table)
>>>
>>> # Data is only read here, during collect()
>>> result = ctx.sql('SELECT AVG(air) FROM air').collect()
>>> # Data is only read here, during query execution
>>> result = ctx.sql('SELECT AVG(air) FROM air').to_arrow_table()
>>> # Can query again - each query creates a fresh stream
>>> result2 = ctx.sql('SELECT * FROM air LIMIT 10').collect()
>>> result2 = ctx.sql('SELECT * FROM air LIMIT 10').to_arrow_table()
"""
from ._native import LazyArrowStreamTable

# Get schema from dataset without creating a stream
schema = _parse_schema(ds)

# Create a factory function that produces fresh RecordBatchReaders on each call
def make_stream() -> pa.RecordBatchReader:
stream = XarrayRecordBatchReader(
ds, chunks, _iteration_callback=_iteration_callback
)
return pa.RecordBatchReader.from_stream(stream)
blocks = block_slices(ds, chunks)

# Create a factory function for each block (partition)
# Each factory produces a RecordBatchReader for its specific chunk
def make_partition_factory(
block: Block,
) -> t.Callable[[], pa.RecordBatchReader]:
"""Create a factory function for a specific block/chunk."""

def make_stream() -> pa.RecordBatchReader:
# Call the iteration callback if provided (for testing)
if _iteration_callback is not None:
_iteration_callback(block)

# Extract just this block from the dataset and convert to Arrow
df = pivot(ds.isel(block))
batch = pa.RecordBatch.from_pandas(df, schema=schema)
return pa.RecordBatchReader.from_batches(schema, [batch])

return make_stream

# Create one factory per block
factories = [make_partition_factory(block) for block in blocks]

return LazyArrowStreamTable(make_stream, schema)
return LazyArrowStreamTable(factories, schema)
Loading
Loading