From 194f1adb47b0714ddf42adbe325d954372dcb773 Mon Sep 17 00:00:00 2001
From: Bryce Mecum
Date: Fri, 6 Jun 2025 16:29:27 +0000
Subject: [PATCH] Enable creating InMemoryDataset from RBR

---
 python/pyarrow/_dataset.pyx          | 6 +++---
 python/pyarrow/dataset.py            | 2 +-
 python/pyarrow/tests/test_dataset.py | 3 ++-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 9e5edee5742..478c6b3f7c1 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -1011,7 +1011,7 @@ cdef class InMemoryDataset(Dataset):
         if isinstance(source, (pa.RecordBatch, pa.Table)):
             source = [source]
 
-        if isinstance(source, (list, tuple)):
+        if isinstance(source, (list, tuple, pa.RecordBatchReader)):
             batches = []
             for item in source:
                 if isinstance(item, pa.RecordBatch):
@@ -1036,8 +1036,8 @@ cdef class InMemoryDataset(Dataset):
                 pyarrow_unwrap_table(table))
         else:
             raise TypeError(
-                'Expected a table, batch, or list of tables/batches '
-                'instead of the given type: ' +
+                'Expected a Table, RecordBatch, list of Table/RecordBatch, '
+                'or RecordBatchReader instead of the given type: ' +
                 type(source).__name__
             )
 
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 26602c1e175..ef4f7288723 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -804,7 +804,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
                 'of batches or tables. The given list contains the following '
                 f'types: {type_names}'
             )
-    elif isinstance(source, (pa.RecordBatch, pa.Table)):
+    elif isinstance(source, (pa.RecordBatch, pa.Table, pa.RecordBatchReader)):
         return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 4af0f914eb6..c17e038713a 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -2558,13 +2558,14 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
 
 def test_construct_in_memory(dataset_reader):
     batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
     table = pa.Table.from_batches([batch])
 
     dataset_table = ds.dataset([], format='ipc', schema=pa.schema([])
                                ).to_table()
     assert dataset_table == pa.table([])
 
-    for source in (batch, table, [batch], [table]):
+    for source in (batch, table, [batch], [table], reader):
         dataset = ds.dataset(source)
         assert dataset_reader.to_table(dataset) == table
         assert len(list(dataset.get_fragments())) == 1