diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index df155784924..2df34145cd9 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -66,8 +66,11 @@ InMemoryFragment::InMemoryFragment(std::shared_ptr<Schema> schema,
 
 InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches,
                                    Expression partition_expression)
-    : InMemoryFragment(record_batches.empty() ? schema({}) : record_batches[0]->schema(),
-                       std::move(record_batches), std::move(partition_expression)) {}
+    : Fragment(std::move(partition_expression), /*schema=*/nullptr),
+      record_batches_(std::move(record_batches)) {
+  // Order of argument evaluation is undefined, so compute physical_schema here
+  physical_schema_ = record_batches_.empty() ? schema({}) : record_batches_[0]->schema();
+}
 
 Result<ScanTaskIterator> InMemoryFragment::Scan(std::shared_ptr<ScanOptions> options) {
   // Make an explicit copy of record_batches_ to ensure Scan can be called
@@ -148,6 +151,28 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
     : Dataset(table->schema()),
       get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
 
+struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
+      : reader_(std::move(reader)), consumed_(false) {}
+
+  RecordBatchIterator Get() const final {
+    if (consumed_) {
+      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
+          "RecordBatchReader-backed InMemoryDataset was already consumed"));
+    }
+    consumed_ = true;
+    auto reader = reader_;
+    return MakeFunctionIterator([reader] { return reader->Next(); });
+  }
+
+  std::shared_ptr<RecordBatchReader> reader_;
+  mutable bool consumed_;
+};
+
+InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
+    : Dataset(reader->schema()),
+      get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
+
 Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
     std::shared_ptr<Schema> schema) const {
   RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index a28b79840d6..6be83059fc1 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -184,6 +184,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
   InMemoryDataset(std::shared_ptr<Schema> schema, RecordBatchVector batches);
 
   explicit InMemoryDataset(std::shared_ptr<Table> table);
+  explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
 
   std::string type_name() const override { return "in-memory"; }
 
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index a3603558924..1db96b8b5c3 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,6 +79,23 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
                 .status());
 }
 
+TEST_F(TestInMemoryDataset, FromReader) {
+  constexpr int64_t kBatchSize = 1024;
+  constexpr int64_t kNumberBatches = 16;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+  auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+
+  auto dataset = std::make_shared<InMemoryDataset>(source_reader);
+
+  AssertDatasetEquals(target_reader.get(), dataset.get());
+  // Such datasets can only be scanned once
+  ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
+  ASSERT_RAISES(Invalid, fragments.Next());
+}
+
 TEST_F(TestInMemoryDataset, GetFragments) {
   constexpr int64_t kBatchSize = 1024;
   constexpr int64_t kNumberBatches = 16;
@@ -93,6 +110,20 @@ TEST_F(TestInMemoryDataset, GetFragments) {
   AssertDatasetEquals(reader.get(), dataset.get());
 }
 
+TEST_F(TestInMemoryDataset, InMemoryFragment) {
+  constexpr int64_t kBatchSize = 1024;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  RecordBatchVector batches{batch};
+
+  // Regression test: previously this constructor relied on undefined behavior (order of
+  // evaluation of arguments) leading to fragments being constructed with empty schemas
+  auto fragment = std::make_shared<InMemoryFragment>(batches);
+  ASSERT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
+  AssertSchemaEqual(batch->schema(), schema);
+}
+
 class TestUnionDataset : public DatasetFixtureMixin {};
 
 TEST_F(TestUnionDataset, ReplaceSchema) {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 12ddcee5343..387471185a1 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -26,11 +26,11 @@ import os
 
 import pyarrow as pa
 
 from pyarrow.lib cimport *
-from pyarrow.lib import frombytes, tobytes
+from pyarrow.lib import ArrowTypeError, frombytes, tobytes
 from pyarrow.includes.libarrow_dataset cimport *
 from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
 from pyarrow._csv cimport ConvertOptions, ParseOptions, ReadOptions
-from pyarrow.util import _is_path_like, _stringify_path
+from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
 from pyarrow._parquet cimport (
     _create_writer_properties, _create_arrow_writer_properties,
@@ -441,6 +441,76 @@ cdef class Dataset(_Weakrefable):
         return pyarrow_wrap_schema(self.dataset.schema())
 
 
+cdef class InMemoryDataset(Dataset):
+    """A Dataset wrapping in-memory data.
+
+    Parameters
+    ----------
+    source
+        The data for this dataset. Can be a RecordBatch, Table, list of
+        RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader.
+        If an iterable is provided, the schema must also be provided.
+    schema : Schema, optional
+        Only required if passing an iterable as the source.
+ """ + + cdef: + CInMemoryDataset* in_memory_dataset + + def __init__(self, source, Schema schema=None): + cdef: + RecordBatchReader reader + shared_ptr[CInMemoryDataset] in_memory_dataset + + if isinstance(source, (pa.RecordBatch, pa.Table)): + source = [source] + + if isinstance(source, (list, tuple)): + batches = [] + for item in source: + if isinstance(item, pa.RecordBatch): + batches.append(item) + elif isinstance(item, pa.Table): + batches.extend(item.to_batches()) + else: + raise TypeError( + 'Expected a list of tables or batches. The given list ' + 'contains a ' + type(item).__name__) + if schema is None: + schema = item.schema + elif not schema.equals(item.schema): + raise ArrowTypeError( + f'Item has schema\n{item.schema}\nwhich does not ' + f'match expected schema\n{schema}') + if not batches and schema is None: + raise ValueError('Must provide schema to construct in-memory ' + 'dataset from an empty list') + table = pa.Table.from_batches(batches, schema=schema) + in_memory_dataset = make_shared[CInMemoryDataset]( + pyarrow_unwrap_table(table)) + elif isinstance(source, pa.ipc.RecordBatchReader): + reader = source + in_memory_dataset = make_shared[CInMemoryDataset](reader.reader) + elif _is_iterable(source): + if schema is None: + raise ValueError('Must provide schema to construct in-memory ' + 'dataset from an iterable') + reader = pa.ipc.RecordBatchReader.from_batches(schema, source) + in_memory_dataset = make_shared[CInMemoryDataset](reader.reader) + else: + raise TypeError( + 'Expected a table, batch, iterable of tables/batches, or a ' + 'record batch reader instead of the given type: ' + + type(source).__name__ + ) + + self.init( in_memory_dataset) + + cdef void init(self, const shared_ptr[CDataset]& sp): + Dataset.init(self, sp) + self.in_memory_dataset = sp.get() + + cdef class UnionDataset(Dataset): """A Dataset wrapping child datasets. @@ -2600,10 +2670,14 @@ def _get_partition_keys(Expression partition_expression): def _filesystemdataset_write( - data not None, object base_dir not None, str basename_template not None, - Schema schema not None, FileSystem filesystem not None, + Dataset data not None, + object base_dir not None, + str basename_template not None, + Schema schema not None, + FileSystem filesystem not None, Partitioning partitioning not None, - FileWriteOptions file_options not None, bint use_threads, + FileWriteOptions file_options not None, + bint use_threads, int max_partitions, ): """ @@ -2621,21 +2695,7 @@ def _filesystemdataset_write( c_options.max_partitions = max_partitions c_options.basename_template = tobytes(basename_template) - if isinstance(data, Dataset): - scanner = data._scanner(use_threads=use_threads) - else: - # data is list of batches/tables - for table in data: - if isinstance(table, Table): - for batch in table.to_batches(): - c_batches.push_back(( batch).sp_batch) - else: - c_batches.push_back(( table).sp_batch) - - data = Fragment.wrap(shared_ptr[CFragment]( - new CInMemoryFragment(move(c_batches), _true.unwrap()))) - - scanner = Scanner.from_fragment(data, schema, use_threads=use_threads) + scanner = data._scanner(use_threads=use_threads) c_scanner = ( scanner).unwrap() with nogil: diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py index 195d414b047..0c65070d872 100644 --- a/python/pyarrow/dataset.py +++ b/python/pyarrow/dataset.py @@ -18,7 +18,7 @@ """Dataset is currently unstable. 
APIs subject to change without notice.""" import pyarrow as pa -from pyarrow.util import _stringify_path, _is_path_like +from pyarrow.util import _is_iterable, _stringify_path, _is_path_like from pyarrow._dataset import ( # noqa CsvFileFormat, @@ -37,6 +37,7 @@ HivePartitioning, IpcFileFormat, IpcFileWriteOptions, + InMemoryDataset, ParquetDatasetFactory, ParquetFactoryOptions, ParquetFileFormat, @@ -408,6 +409,13 @@ def _filesystem_dataset(source, schema=None, filesystem=None, return factory.finish(schema) +def _in_memory_dataset(source, schema=None, **kwargs): + if any(v is not None for v in kwargs.values()): + raise ValueError( + "For in-memory datasets, you cannot pass any additional arguments") + return InMemoryDataset(source, schema) + + def _union_dataset(children, schema=None, **kwargs): if any(v is not None for v in kwargs.values()): raise ValueError( @@ -508,7 +516,8 @@ def dataset(source, schema=None, format=None, filesystem=None, Parameters ---------- - source : path, list of paths, dataset, list of datasets or URI + source : path, list of paths, dataset, list of datasets, (list of) batches + or tables, iterable of batches, RecordBatchReader, or URI Path pointing to a single file: Open a FileSystemDataset from a single file. Path pointing to a directory: @@ -524,6 +533,11 @@ def dataset(source, schema=None, format=None, filesystem=None, A nested UnionDataset gets constructed, it allows arbitrary composition of other datasets. Note that additional keyword arguments are not allowed. + (List of) batches or tables, iterable of batches, or RecordBatchReader: + Create an InMemoryDataset. If an iterable or empty list is given, + a schema must also be given. If an iterable or RecordBatchReader + is given, the resulting dataset can only be scanned once; further + attempts will raise an error. schema : Schema, optional Optionally provide the Schema for the Dataset, in which case it will not be inferred from the source. @@ -636,7 +650,6 @@ def dataset(source, schema=None, format=None, filesystem=None, selector_ignore_prefixes=ignore_prefixes ) - # TODO(kszucs): support InMemoryDataset for a table input if _is_path_like(source): return _filesystem_dataset(source, **kwargs) elif isinstance(source, (tuple, list)): @@ -644,13 +657,22 @@ def dataset(source, schema=None, format=None, filesystem=None, return _filesystem_dataset(source, **kwargs) elif all(isinstance(elem, Dataset) for elem in source): return _union_dataset(source, **kwargs) + elif all(isinstance(elem, (pa.RecordBatch, pa.Table)) + for elem in source): + return _in_memory_dataset(source, **kwargs) else: unique_types = set(type(elem).__name__ for elem in source) type_names = ', '.join('{}'.format(t) for t in unique_types) raise TypeError( - 'Expected a list of path-like or dataset objects. The given ' - 'list contains the following types: {}'.format(type_names) + 'Expected a list of path-like or dataset objects, or a list ' + 'of batches or tables. 
The given list contains the following ' + 'types: {}'.format(type_names) ) + elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader, + pa.Table)): + return _in_memory_dataset(source, **kwargs) + elif _is_iterable(source): + return _in_memory_dataset(source, **kwargs) else: raise TypeError( 'Expected a path-like, list of path-likes or a list of Datasets ' @@ -676,9 +698,11 @@ def write_dataset(data, base_dir, basename_template=None, format=None, Parameters ---------- - data : Dataset, Table/RecordBatch, or list of Table/RecordBatch + data : Dataset, Table/RecordBatch, RecordBatchReader, list of + Table/RecordBatch, or iterable of RecordBatch The data to write. This can be a Dataset instance or - in-memory Arrow data. + in-memory Arrow data. If an iterable is given, the schema must + also be given. base_dir : str The root directory where to write the dataset. basename_template : str, optional @@ -710,15 +734,17 @@ def write_dataset(data, base_dir, basename_template=None, format=None, if isinstance(data, Dataset): schema = schema or data.schema - elif isinstance(data, (pa.Table, pa.RecordBatch)): - schema = schema or data.schema - data = [data] - elif isinstance(data, list): + elif isinstance(data, (list, tuple)): schema = schema or data[0].schema + data = InMemoryDataset(data, schema=schema) + elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader, + pa.Table)) or _is_iterable(data): + data = InMemoryDataset(data, schema=schema) + schema = schema or data.schema else: raise ValueError( - "Only Dataset, Table/RecordBatch or a list of Table/RecordBatch " - "objects are supported." + "Only Dataset, Table/RecordBatch, RecordBatchReader, a list " + "of Tables/RecordBatches, or iterable of batches are supported." ) if format is None and isinstance(data, FileSystemDataset): diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index f7f2a142001..db2e73acdff 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -124,6 +124,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CScannerBuilder]] NewScan() + cdef cppclass CInMemoryDataset "arrow::dataset::InMemoryDataset"( + CDataset): + CInMemoryDataset(shared_ptr[CRecordBatchReader]) + CInMemoryDataset(shared_ptr[CTable]) + cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"( CDataset): @staticmethod diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 36cff9958f9..a7dd1520168 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -1622,13 +1622,16 @@ def test_construct_from_invalid_sources_raise(multisourcefs): fs.FileSelector('/schema'), format=ds.ParquetFileFormat() ) + batch1 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"]) + batch2 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["b"]) with pytest.raises(TypeError, match='Expected.*FileSystemDatasetFactory'): ds.dataset([child1, child2]) expected = ( - "Expected a list of path-like or dataset objects. The given list " - "contains the following types: int" + "Expected a list of path-like or dataset objects, or a list " + "of batches or tables. 
The given list contains the following " + "types: int" ) with pytest.raises(TypeError, match=expected): ds.dataset([1, 2, 3]) @@ -1640,6 +1643,85 @@ def test_construct_from_invalid_sources_raise(multisourcefs): with pytest.raises(TypeError, match=expected): ds.dataset(None) + expected = ( + "Must provide schema to construct in-memory dataset from an iterable" + ) + with pytest.raises(ValueError, match=expected): + ds.dataset((batch1 for _ in range(3))) + + expected = ( + "Must provide schema to construct in-memory dataset from an empty list" + ) + with pytest.raises(ValueError, match=expected): + ds.InMemoryDataset([]) + + expected = ( + "Item has schema\nb: int64\nwhich does not match expected schema\n" + "a: int64" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset([batch1, batch2]) + + expected = ( + "Expected a list of path-like or dataset objects, or a list of " + "batches or tables. The given list contains the following types:" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset([batch1, 0]) + + expected = ( + "Expected a list of tables or batches. The given list contains a int" + ) + with pytest.raises(TypeError, match=expected): + ds.InMemoryDataset([batch1, 0]) + + +def test_construct_in_memory(): + batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"]) + table = pa.Table.from_batches([batch]) + reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch]) + iterable = (batch for _ in range(1)) + + for source in (batch, table, reader, [batch], [table]): + dataset = ds.dataset(source) + assert dataset.to_table() == table + + assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table) + assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([]) + + # When constructed from batches/tables, should be reusable + for source in (batch, table, [batch], [table]): + dataset = ds.dataset(source) + assert len(list(dataset.get_fragments())) == 1 + assert len(list(dataset.get_fragments())) == 1 + assert dataset.to_table() == table + assert dataset.to_table() == table + assert next(dataset.get_fragments()).to_table() == table + + # When constructed from readers/iterators, should be one-shot + match = "InMemoryDataset was already consumed" + for factory in ( + lambda: pa.ipc.RecordBatchReader.from_batches( + batch.schema, [batch]), + lambda: (batch for _ in range(1)), + ): + dataset = ds.dataset(factory(), schema=batch.schema) + # Getting fragments consumes the underlying iterator + fragments = list(dataset.get_fragments()) + assert len(fragments) == 1 + assert fragments[0].to_table() == table + with pytest.raises(pa.ArrowInvalid, match=match): + list(dataset.get_fragments()) + with pytest.raises(pa.ArrowInvalid, match=match): + dataset.to_table() + # Materializing consumes the underlying iterator + dataset = ds.dataset(factory(), schema=batch.schema) + assert dataset.to_table() == table + with pytest.raises(pa.ArrowInvalid, match=match): + list(dataset.get_fragments()) + with pytest.raises(pa.ArrowInvalid, match=match): + dataset.to_table() + @pytest.mark.parquet def test_open_dataset_partitioned_directory(tempdir): @@ -2856,6 +2938,28 @@ def test_write_table_multiple_fragments(tempdir): ) +def test_write_iterable(tempdir): + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + + base_dir = tempdir / 'inmemory_iterable' + ds.write_dataset((batch for batch in table.to_batches()), base_dir, + schema=table.schema, + 
basename_template='dat_{i}.arrow', format="feather") + result = ds.dataset(base_dir, format="ipc").to_table() + assert result.equals(table) + + base_dir = tempdir / 'inmemory_reader' + reader = pa.ipc.RecordBatchReader.from_batches(table.schema, + table.to_batches()) + ds.write_dataset(reader, base_dir, + basename_template='dat_{i}.arrow', format="feather") + result = ds.dataset(base_dir, format="ipc").to_table() + assert result.equals(table) + + def test_write_table_partitioned_dict(tempdir): # ensure writing table partitioned on a dictionary column works without # specifying the dictionary values explicitly diff --git a/python/pyarrow/util.py b/python/pyarrow/util.py index e91294a3a1b..446e6733351 100644 --- a/python/pyarrow/util.py +++ b/python/pyarrow/util.py @@ -62,6 +62,14 @@ def __instancecheck__(self, other): return _DeprecatedMeta(old_name, (new_class,), {}) +def _is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + def _is_path_like(path): # PEP519 filesystem path protocol is available from python 3.6, so pathlib # doesn't implement __fspath__ for earlier versions
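A minimal usage sketch of the Python API this patch exposes (illustrative only, not part of the diff); it mirrors the behavior exercised by test_construct_in_memory and test_write_iterable, and the output directory is just a temporary placeholder:

import tempfile

import pyarrow as pa
import pyarrow.dataset as ds

batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])

# Materialized sources (a batch, a table, or a list of either) produce a
# reusable InMemoryDataset that can be scanned any number of times.
dataset = ds.dataset([batch])
assert dataset.to_table().num_rows == 10
assert dataset.to_table().num_rows == 10

# A RecordBatchReader (or any iterable of batches, which additionally
# requires schema=) produces a one-shot dataset; scanning it a second time
# raises ArrowInvalid ("... was already consumed").
reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
one_shot = ds.dataset(reader)
one_shot.to_table()  # consumes the reader

# Iterables of batches can also be written out directly, without first
# materializing them into a Table.
out_dir = tempfile.mkdtemp()
ds.write_dataset((b for b in [batch]), out_dir,
                 schema=batch.schema, format="feather")
print(ds.dataset(out_dir, format="ipc").to_table())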