diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index ab0600dd1a8..5afb6a06c16 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -199,28 +199,6 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
     : Dataset(table->schema()),
       get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
 
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
-  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
-      : reader_(std::move(reader)), consumed_(false) {}
-
-  RecordBatchIterator Get() const final {
-    if (consumed_) {
-      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
-          "RecordBatchReader-backed InMemoryDataset was already consumed"));
-    }
-    consumed_ = true;
-    auto reader = reader_;
-    return MakeFunctionIterator([reader] { return reader->Next(); });
-  }
-
-  std::shared_ptr<RecordBatchReader> reader_;
-  mutable bool consumed_;
-};
-
-InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
-    : Dataset(reader->schema()),
-      get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
-
 Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
     std::shared_ptr<Schema> schema) const {
   RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 40a60ffd48e..164b3ec17aa 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -206,7 +206,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
 
   /// Convenience constructor taking a Table
   explicit InMemoryDataset(std::shared_ptr<Table> table);
-  explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
 
   std::string type_name() const override { return "in-memory"; }
 
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 7aa0e1a2413..66d69c30c82 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,23 +79,6 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
                   .status());
 }
 
-TEST_F(TestInMemoryDataset, FromReader) {
-  constexpr int64_t kBatchSize = 1024;
-  constexpr int64_t kNumberBatches = 16;
-
-  SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
-  auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-  auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-
-  auto dataset = std::make_shared<InMemoryDataset>(source_reader);
-
-  AssertDatasetEquals(target_reader.get(), dataset.get());
-  // Such datasets can only be scanned once
-  ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
-  ASSERT_RAISES(Invalid, fragments.Next());
-}
-
 TEST_F(TestInMemoryDataset, GetFragments) {
   constexpr int64_t kBatchSize = 1024;
   constexpr int64_t kNumberBatches = 16;
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 41fa7ec5c77..3c2c7f65ee7 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,64 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
   DCHECK_OK(Filter(scan_options_->filter));
 }
 
+namespace {
+class OneShotScanTask : public ScanTask {
+ public:
+  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+                  std::shared_ptr<Fragment> fragment)
+      : ScanTask(std::move(options), std::move(fragment)),
+        batch_it_(std::move(batch_it)) {}
+  Result<RecordBatchIterator> Execute() override {
+    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+    return std::move(batch_it_);
+  }
+
+ private:
+  RecordBatchIterator batch_it_;
+};
+
+class OneShotFragment : public Fragment {
+ public:
+  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+      : Fragment(compute::literal(true), std::move(schema)),
+        batch_it_(std::move(batch_it)) {
+    DCHECK_NE(physical_schema_, nullptr);
+  }
+  Status CheckConsumed() {
+    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+    return Status::OK();
+  }
+  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+        std::move(batch_it_), std::move(options), shared_from_this())};
+    return MakeVectorIterator(std::move(tasks));
+  }
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+  }
+  std::string type_name() const override { return "one-shot"; }
+
+ protected:
+  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+    return physical_schema_;
+  }
+
+  RecordBatchIterator batch_it_;
+};
+}  // namespace
+
+std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
+    std::shared_ptr<RecordBatchReader> reader) {
+  auto batch_it = MakeIteratorFromReader(reader);
+  auto fragment =
+      std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
+  return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
+                                          std::make_shared<ScanOptions>());
+}
+
 const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
   return scan_options_->dataset_schema;
 }
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 15bd27ab4f3..bbb79ee474f 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -312,6 +312,14 @@ class ARROW_DS_EXPORT ScannerBuilder {
   ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
                  std::shared_ptr<ScanOptions> scan_options);
 
+  /// \brief Make a scanner from a record batch reader.
+  ///
+  /// The resulting scanner can be scanned only once. This is intended
+  /// to support writing data from streaming sources or other sources
+  /// that can be iterated only once.
+  static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
+      std::shared_ptr<RecordBatchReader> reader);
+
   /// \brief Set the subset of columns to materialize.
   ///
   /// Columns which are not referenced may not be read from fragments.
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 17f4e079ae4..c44f28afb3b 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -507,6 +507,30 @@ TEST_P(TestScanner, Head) {
   AssertTablesEqual(*expected, *actual);
 }
 
+TEST_P(TestScanner, FromReader) {
+  if (GetParam().use_async) {
+    GTEST_SKIP() << "Async scanner does not support construction from reader";
+  }
+  auto batch_size = GetParam().items_per_batch;
+  auto num_batches = GetParam().num_batches;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+  auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+  auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+
+  auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
+  ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
+  ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+  AssertScannerEquals(target_reader.get(), scanner.get());
+
+  // Such datasets can only be scanned once (but you can get fragments multiple times)
+  ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+      batch_it.Next());
+}
+
 INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
                          ::testing::ValuesIn(TestScannerParams::Values()));
 
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index cf076f6536b..2a78a7fe712 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -298,6 +298,7 @@ cdef class Dataset(_Weakrefable):
         classes = {
             'union': UnionDataset,
             'filesystem': FileSystemDataset,
+            'in-memory': InMemoryDataset,
         }
 
         class_ = classes.get(type_name, None)
@@ -511,19 +512,10 @@ cdef class InMemoryDataset(Dataset):
             table = pa.Table.from_batches(batches, schema=schema)
            in_memory_dataset = make_shared[CInMemoryDataset](
                pyarrow_unwrap_table(table))
-        elif isinstance(source, pa.ipc.RecordBatchReader):
-            reader = source
-            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
-        elif _is_iterable(source):
-            if schema is None:
-                raise ValueError('Must provide schema to construct in-memory '
-                                 'dataset from an iterable')
-            reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
-            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
         else:
             raise TypeError(
-                'Expected a table, batch, iterable of tables/batches, or a '
-                'record batch reader instead of the given type: ' +
+                'Expected a table, batch, or list of tables/batches '
+                'instead of the given type: ' +
                 type(source).__name__
             )
 
@@ -2751,6 +2743,46 @@ cdef class Scanner(_Weakrefable):
         scanner = GetResultValue(builder.get().Finish())
         return Scanner.wrap(scanner)
 
+    @staticmethod
+    def from_batches(source, Schema schema=None, bint use_threads=True,
+                     MemoryPool memory_pool=None, object columns=None,
+                     Expression filter=None,
+                     int batch_size=_DEFAULT_BATCH_SIZE,
+                     FragmentScanOptions fragment_scan_options=None):
+        """Create a Scanner from an iterator of batches.
+
+        This creates a scanner which can be used only once. It is
+        intended to support writing a dataset (which takes a scanner)
+        from a source which can be read only once (e.g. a
+        RecordBatchReader or generator).
+        """
+        cdef:
+            shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+            shared_ptr[CScannerBuilder] builder
+            shared_ptr[CScanner] scanner
+            RecordBatchReader reader
+        if isinstance(source, pa.ipc.RecordBatchReader):
+            if schema:
+                raise ValueError('Cannot specify a schema when providing '
+                                 'a RecordBatchReader')
+            reader = source
+        elif _is_iterable(source):
+            if schema is None:
+                raise ValueError('Must provide schema to construct scanner '
+                                 'from an iterable')
+            reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+        else:
+            raise TypeError('Expected a RecordBatchReader or an iterable of '
+                            'batches instead of the given type: ' +
+                            type(source).__name__)
+        builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
+        _populate_builder(builder, columns=columns, filter=filter,
+                          batch_size=batch_size, use_threads=use_threads,
+                          memory_pool=memory_pool,
+                          fragment_scan_options=fragment_scan_options)
+        scanner = GetResultValue(builder.get().Finish())
+        return Scanner.wrap(scanner)
+
     @property
     def dataset_schema(self):
         """The schema with which batches will be read from fragments."""
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 97d08844f27..e80de1688e7 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -669,10 +669,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
             'of batches or tables. The given list contains the following '
             'types: {}'.format(type_names)
         )
-    elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
-                             pa.Table)):
-        return _in_memory_dataset(source, **kwargs)
-    elif _is_iterable(source):
+    elif isinstance(source, (pa.RecordBatch, pa.Table)):
         return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
@@ -736,9 +733,12 @@ def write_dataset(data, base_dir, basename_template=None, format=None,
     if isinstance(data, (list, tuple)):
         schema = schema or data[0].schema
         data = InMemoryDataset(data, schema=schema)
-    elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
-                           pa.Table)) or _is_iterable(data):
+    elif isinstance(data, (pa.RecordBatch, pa.Table)):
+        schema = schema or data.schema
         data = InMemoryDataset(data, schema=schema)
+    elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+        data = Scanner.from_batches(data, schema=schema)
+        schema = None
     elif not isinstance(data, (Dataset, Scanner)):
         raise ValueError(
             "Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index bff1a2bbb54..9105abcf7a0 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -121,6 +121,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
                         shared_ptr[CScanOptions] scan_options)
         CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
                         shared_ptr[CScanOptions] scan_options)
+
+        @staticmethod
+        shared_ptr[CScannerBuilder] FromRecordBatchReader(
+            shared_ptr[CRecordBatchReader] reader)
         CStatus ProjectColumns "Project"(const vector[c_string]& columns)
         CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
         CStatus Filter(CExpression filter)
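Note: a minimal, illustrative sketch of how the Python entry point declared above is meant to be called; it is not taken from the patch itself and only uses APIs that appear elsewhere in this diff (Scanner.from_batches, RecordBatchReader.from_batches, Scanner.to_table).

import pyarrow as pa
import pyarrow.dataset as ds

batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])

# From a RecordBatchReader the schema is taken from the reader itself;
# passing schema= as well raises ValueError.
reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
scanner = ds.Scanner.from_batches(reader)
assert scanner.to_table().num_rows == 10

# From a plain iterable of batches an explicit schema is required.
scanner = ds.Scanner.from_batches((batch for _ in range(3)), schema=batch.schema)
assert scanner.to_table().num_rows == 30

Each scanner above is one-shot: materializing it a second time raises ArrowInvalid, as exercised in test_scan_iterator below.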
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 8791c22f103..20a0a607c1d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1682,9 +1682,10 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
         ds.dataset(None)
 
     expected = (
-        "Must provide schema to construct in-memory dataset from an iterable"
+        "Expected a path-like, list of path-likes or a list of Datasets "
+        "instead of the given type: generator"
     )
-    with pytest.raises(ValueError, match=expected):
+    with pytest.raises(TypeError, match=expected):
         ds.dataset((batch1 for _ in range(3)))
 
     expected = (
@@ -1717,49 +1718,32 @@ def test_construct_in_memory():
     batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
     table = pa.Table.from_batches([batch])
-    reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
-    iterable = (batch for _ in range(1))
 
-    for source in (batch, table, reader, [batch], [table]):
-        dataset = ds.dataset(source)
-        assert dataset.to_table() == table
-
-    assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
     assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
 
-    # When constructed from batches/tables, should be reusable
     for source in (batch, table, [batch], [table]):
         dataset = ds.dataset(source)
-        assert len(list(dataset.get_fragments())) == 1
-        assert len(list(dataset.get_fragments())) == 1
-        assert dataset.to_table() == table
         assert dataset.to_table() == table
+        assert len(list(dataset.get_fragments())) == 1
         assert next(dataset.get_fragments()).to_table() == table
         assert pa.Table.from_batches(list(dataset.to_batches())) == table
 
+
+def test_scan_iterator():
+    batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    table = pa.Table.from_batches([batch])
     # When constructed from readers/iterators, should be one-shot
-    match = "InMemoryDataset was already consumed"
-    for factory in (
-        lambda: pa.ipc.RecordBatchReader.from_batches(
-            batch.schema, [batch]),
-        lambda: (batch for _ in range(1)),
+    match = "OneShotFragment was already scanned"
+    for factory, schema in (
+        (lambda: pa.ipc.RecordBatchReader.from_batches(
+            batch.schema, [batch]), None),
+        (lambda: (batch for _ in range(1)), batch.schema),
     ):
-        dataset = ds.dataset(factory(), schema=batch.schema)
-        # Getting fragments consumes the underlying iterator
-        fragments = list(dataset.get_fragments())
-        assert len(fragments) == 1
-        assert fragments[0].to_table() == table
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            list(dataset.get_fragments())
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            dataset.to_table()
-        # Materializing consumes the underlying iterator
-        dataset = ds.dataset(factory(), schema=batch.schema)
-        assert dataset.to_table() == table
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            list(dataset.get_fragments())
+        # Scanning the fragment consumes the underlying iterator
+        scanner = ds.Scanner.from_batches(factory(), schema=schema)
+        assert scanner.to_table() == table
         with pytest.raises(pa.ArrowInvalid, match=match):
-            dataset.to_table()
+            scanner.to_table()
 
 
 @pytest.mark.parquet
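Note: a minimal, illustrative sketch of the streaming write path that the dataset.py change above enables; the output directory and the "ipc" format choice are placeholders, not anything the patch prescribes.

import pyarrow as pa
import pyarrow.dataset as ds

schema = pa.schema([("a", pa.int64())])

def produce_batches():
    # Stand-in for a source that can be iterated only once (e.g. a stream).
    for i in range(3):
        yield pa.record_batch([pa.array([i], type=pa.int64())], schema=schema)

# write_dataset now routes readers and iterables through Scanner.from_batches,
# so the batches are consumed once and written out without building a Table first.
ds.write_dataset(produce_batches(), "/tmp/streamed_dataset",
                 schema=schema, format="ipc")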