From 2f8c651a201f0faef46b7ac9ae5fa6d66f37edc4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 15 Apr 2021 11:23:27 -0400
Subject: [PATCH 1/4] ARROW-12231: [C++][Python][Dataset] Differentiate one-shot datasets

---
 cpp/src/arrow/dataset/dataset.cc             | 97 +++++++++++++++-----
 cpp/src/arrow/dataset/dataset.h              | 17 +++-
 cpp/src/arrow/dataset/dataset_test.cc        | 21 ++++-
 python/pyarrow/_dataset.pyx                  |  8 +-
 python/pyarrow/includes/libarrow_dataset.pxd |  4 +
 python/pyarrow/tests/test_dataset.py         | 13 ++-
 6 files changed, 127 insertions(+), 33 deletions(-)

diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index ab0600dd1a8..7ec71fdc99b 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -199,28 +199,6 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
     : Dataset(table->schema()),
       get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
 
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
-  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
-      : reader_(std::move(reader)), consumed_(false) {}
-
-  RecordBatchIterator Get() const final {
-    if (consumed_) {
-      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
-          "RecordBatchReader-backed InMemoryDataset was already consumed"));
-    }
-    consumed_ = true;
-    auto reader = reader_;
-    return MakeFunctionIterator([reader] { return reader->Next(); });
-  }
-
-  std::shared_ptr<RecordBatchReader> reader_;
-  mutable bool consumed_;
-};
-
-InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
-    : Dataset(reader->schema()),
-      get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
-
 Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
     std::shared_ptr<Schema> schema) const {
   RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
@@ -244,6 +222,81 @@ Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(compute::Expression)
   return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
 }
 
+struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
+      : reader_(std::move(reader)), consumed_(false) {}
+
+  RecordBatchIterator Get() const final {
+    if (consumed_) {
+      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(
+          Status::Invalid("OneShotDataset was already consumed"));
+    }
+    consumed_ = true;
+    auto reader = reader_;
+    return MakeFunctionIterator([reader] { return reader->Next(); });
+  }
+
+  std::shared_ptr<RecordBatchReader> reader_;
+  mutable bool consumed_;
+};
+
+OneShotDataset::OneShotDataset(std::shared_ptr<RecordBatchReader> reader)
+    : InMemoryDataset(reader->schema(),
+                      std::make_shared<ReaderRecordBatchGenerator>(reader)) {}
+
+class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+ public:
+  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+                  std::shared_ptr<Fragment> fragment)
+      : ScanTask(std::move(options), std::move(fragment)),
+        batch_it_(std::move(batch_it)) {}
+  Result<RecordBatchIterator> Execute() override {
+    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+    return std::move(batch_it_);
+  }
+
+ private:
+  RecordBatchIterator batch_it_;
+};
+
+class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+ public:
+  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+      : Fragment(compute::literal(true), std::move(schema)),
+        batch_it_(std::move(batch_it)) {
+    DCHECK_NE(physical_schema_, nullptr);
+  }
+  Status CheckConsumed() {
+    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+    return Status::OK();
+  }
+  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+        std::move(batch_it_), std::move(options), shared_from_this())};
+    return MakeVectorIterator(std::move(tasks));
+  }
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+  }
+  std::string type_name() const override { return "one-shot"; }
+
+ protected:
+  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+    return physical_schema_;
+  }
+
+  RecordBatchIterator batch_it_;
+};
+
+Result<FragmentIterator> OneShotDataset::GetFragmentsImpl(compute::Expression predicate) {
+  FragmentVector fragments{
+      std::make_shared<OneShotFragment>(schema(), get_batches_->Get())};
+  return MakeVectorIterator(std::move(fragments));
+}
+
 Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
                                                          DatasetVector children) {
   for (const auto& child : children) {
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 40a60ffd48e..958d94d728a 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -206,7 +206,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
 
   /// Convenience constructor taking a Table
   explicit InMemoryDataset(std::shared_ptr<Table> table);
-  explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
 
   std::string type_name() const override { return "in-memory"; }
 
@@ -219,6 +218,22 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
   std::shared_ptr<RecordBatchGenerator> get_batches_;
 };
 
+/// \brief A Dataset which yields fragments wrapping a one-shot stream
+/// of record batches.
+///
+/// Unlike other datasets, this can be scanned only once. This is
+/// intended to support writing data from streaming sources or other
+/// sources that can be iterated only once.
+class ARROW_DS_EXPORT OneShotDataset : public InMemoryDataset {
+ public:
+  /// Construct a dataset from a reader
+  explicit OneShotDataset(std::shared_ptr<RecordBatchReader> reader);
+  std::string type_name() const override { return "one-shot"; }
+
+ protected:
+  Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
+};
+
 /// \brief A Dataset wrapping child Datasets.
 class ARROW_DS_EXPORT UnionDataset : public Dataset {
  public:
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 7aa0e1a2413..5049bb09634 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -88,12 +88,27 @@ TEST_F(TestInMemoryDataset, FromReader) {
   auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
   auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
 
-  auto dataset = std::make_shared<InMemoryDataset>(source_reader);
+  auto dataset = std::make_shared<OneShotDataset>(source_reader);
 
   AssertDatasetEquals(target_reader.get(), dataset.get());
-  // Such datasets can only be scanned once
+  // Such datasets can only be scanned once (but you can get fragments multiple times)
   ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
-  ASSERT_RAISES(Invalid, fragments.Next());
+  ASSERT_OK_AND_ASSIGN(auto fragment, fragments.Next());
+  ASSERT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(options_));
+  ASSERT_OK_AND_ASSIGN(auto scan_task, scan_task_it.Next());
+  ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotDataset was already consumed"),
+      batch_it.Next());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotScanTask was already scanned"),
+      scan_task->Execute());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+      fragment->Scan(options_));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+      fragment->ScanBatchesAsync(options_));
 }
 
 TEST_F(TestInMemoryDataset, GetFragments) {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index cf076f6536b..9414fa7861e 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -298,6 +298,8 @@ cdef class Dataset(_Weakrefable):
         classes = {
             'union': UnionDataset,
             'filesystem': FileSystemDataset,
+            'in-memory': InMemoryDataset,
+            'one-shot': InMemoryDataset,
         }
 
         class_ = classes.get(type_name, None)
@@ -513,13 +515,15 @@ cdef class InMemoryDataset(Dataset):
                 pyarrow_unwrap_table(table))
         elif isinstance(source, pa.ipc.RecordBatchReader):
             reader = source
-            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+            in_memory_dataset = \
+                make_shared[COneShotDataset](reader.reader)
         elif _is_iterable(source):
             if schema is None:
                 raise ValueError('Must provide schema to construct in-memory '
                                  'dataset from an iterable')
             reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
-            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+            in_memory_dataset = \
+                make_shared[COneShotDataset](reader.reader)
         else:
             raise TypeError(
                 'Expected a table, batch, iterable of tables/batches, or a '
                 'record batch reader instead of the given type: ' +
                 type(source).__name__
             )
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index bff1a2bbb54..f802c36d34f 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -151,6 +151,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
         CInMemoryDataset(shared_ptr[CRecordBatchReader])
         CInMemoryDataset(shared_ptr[CTable])
 
+    cdef cppclass COneShotDataset "arrow::dataset::OneShotDataset"(
+            CInMemoryDataset):
+        COneShotDataset(shared_ptr[CRecordBatchReader])
+
     cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
             CDataset):
         @staticmethod
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 8791c22f103..2e8a93a8ac2 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1738,26 +1738,29 @@ def test_construct_in_memory():
         assert pa.Table.from_batches(list(dataset.to_batches())) == table
 
     # When constructed from readers/iterators, should be one-shot
-    match = "InMemoryDataset was already consumed"
+    match = "OneShotDataset was already consumed"
     for factory in (
        lambda: pa.ipc.RecordBatchReader.from_batches(
            batch.schema, [batch]),
        lambda: (batch for _ in range(1)),
    ):
+        # Scanning the fragment consumes the underlying iterator
         dataset = ds.dataset(factory(), schema=batch.schema)
-        # Getting fragments consumes the underlying iterator
         fragments = list(dataset.get_fragments())
         assert len(fragments) == 1
         assert fragments[0].to_table() == table
+        # But you can still get fragments
+        fragments = list(dataset.get_fragments())
         with pytest.raises(pa.ArrowInvalid, match=match):
-            list(dataset.get_fragments())
+            fragments[0].to_table()
         with pytest.raises(pa.ArrowInvalid, match=match):
             dataset.to_table()
-        # Materializing consumes the underlying iterator
+        # So does scanning the dataset
         dataset = ds.dataset(factory(), schema=batch.schema)
         assert dataset.to_table() == table
+        fragments = list(dataset.get_fragments())
         with pytest.raises(pa.ArrowInvalid, match=match):
-            list(dataset.get_fragments())
+            fragments[0].to_table()
         with pytest.raises(pa.ArrowInvalid, match=match):
             dataset.to_table()
From c62ff07f385aaf0c7c149df2392d7c2d0d5e9e23 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 14:40:12 -0400
Subject: [PATCH 2/4] ARROW-12231: [C++][Python][Dataset] Hide one-shotness in the scanner

---
 cpp/src/arrow/dataset/dataset.cc             | 75 --------------------
 cpp/src/arrow/dataset/dataset.h              | 16 -----
 cpp/src/arrow/dataset/dataset_test.cc        | 32 ---------
 cpp/src/arrow/dataset/scanner.cc             | 56 +++++++++++++++
 cpp/src/arrow/dataset/scanner.h              |  8 +++
 cpp/src/arrow/dataset/scanner_test.cc        | 24 +++++++
 python/pyarrow/_dataset.pyx                  | 56 +++++++++++----
 python/pyarrow/dataset.py                    | 12 ++--
 python/pyarrow/includes/libarrow_dataset.pxd |  8 +--
 python/pyarrow/tests/test_dataset.py         | 51 +++++--------
 10 files changed, 156 insertions(+), 182 deletions(-)

diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index 7ec71fdc99b..5afb6a06c16 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -222,81 +222,6 @@ Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(compute::Expression)
   return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
 }
 
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
-  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
-      : reader_(std::move(reader)), consumed_(false) {}
-
-  RecordBatchIterator Get() const final {
-    if (consumed_) {
-      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(
-          Status::Invalid("OneShotDataset was already consumed"));
-    }
-    consumed_ = true;
-    auto reader = reader_;
-    return MakeFunctionIterator([reader] { return reader->Next(); });
-  }
-
-  std::shared_ptr<RecordBatchReader> reader_;
-  mutable bool consumed_;
-};
-
-OneShotDataset::OneShotDataset(std::shared_ptr<RecordBatchReader> reader)
-    : InMemoryDataset(reader->schema(),
-                      std::make_shared<ReaderRecordBatchGenerator>(reader)) {}
-
-class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
- public:
-  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
-                  std::shared_ptr<Fragment> fragment)
-      : ScanTask(std::move(options), std::move(fragment)),
-        batch_it_(std::move(batch_it)) {}
-  Result<RecordBatchIterator> Execute() override {
-    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
-    return std::move(batch_it_);
-  }
-
- private:
-  RecordBatchIterator batch_it_;
-};
-
-class ARROW_DS_EXPORT OneShotFragment : public Fragment {
- public:
-  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
-      : Fragment(compute::literal(true), std::move(schema)),
-        batch_it_(std::move(batch_it)) {
-    DCHECK_NE(physical_schema_, nullptr);
-  }
-  Status CheckConsumed() {
-    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
-    return Status::OK();
-  }
-  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
-    RETURN_NOT_OK(CheckConsumed());
-    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
-        std::move(batch_it_), std::move(options), shared_from_this())};
-    return MakeVectorIterator(std::move(tasks));
-  }
-  Result<RecordBatchGenerator> ScanBatchesAsync(
-      const std::shared_ptr<ScanOptions>& options) override {
-    RETURN_NOT_OK(CheckConsumed());
-    return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
-  }
-  std::string type_name() const override { return "one-shot"; }
-
- protected:
-  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
-    return physical_schema_;
-  }
-
-  RecordBatchIterator batch_it_;
-};
-
-Result<FragmentIterator> OneShotDataset::GetFragmentsImpl(compute::Expression predicate) {
-  FragmentVector fragments{
-      std::make_shared<OneShotFragment>(schema(), get_batches_->Get())};
-  return MakeVectorIterator(std::move(fragments));
-}
-
 Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
                                                          DatasetVector children) {
   for (const auto& child : children) {
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 958d94d728a..164b3ec17aa 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -218,22 +218,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
   std::shared_ptr<RecordBatchGenerator> get_batches_;
 };
 
-/// \brief A Dataset which yields fragments wrapping a one-shot stream
-/// of record batches.
-///
-/// Unlike other datasets, this can be scanned only once. This is
-/// intended to support writing data from streaming sources or other
-/// sources that can be iterated only once.
-class ARROW_DS_EXPORT OneShotDataset : public InMemoryDataset {
- public:
-  /// Construct a dataset from a reader
-  explicit OneShotDataset(std::shared_ptr<RecordBatchReader> reader);
-  std::string type_name() const override { return "one-shot"; }
-
- protected:
-  Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
-};
-
 /// \brief A Dataset wrapping child Datasets.
 class ARROW_DS_EXPORT UnionDataset : public Dataset {
  public:
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 5049bb09634..66d69c30c82 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,38 +79,6 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
                   .status());
 }
 
-TEST_F(TestInMemoryDataset, FromReader) {
-  constexpr int64_t kBatchSize = 1024;
-  constexpr int64_t kNumberBatches = 16;
-
-  SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
-  auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-  auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-
-  auto dataset = std::make_shared<OneShotDataset>(source_reader);
-
-  AssertDatasetEquals(target_reader.get(), dataset.get());
-  // Such datasets can only be scanned once (but you can get fragments multiple times)
-  ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
-  ASSERT_OK_AND_ASSIGN(auto fragment, fragments.Next());
-  ASSERT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(options_));
-  ASSERT_OK_AND_ASSIGN(auto scan_task, scan_task_it.Next());
-  ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("OneShotDataset was already consumed"),
-      batch_it.Next());
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("OneShotScanTask was already scanned"),
-      scan_task->Execute());
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
-      fragment->Scan(options_));
-  EXPECT_RAISES_WITH_MESSAGE_THAT(
-      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
-      fragment->ScanBatchesAsync(options_));
-}
-
 TEST_F(TestInMemoryDataset, GetFragments) {
   constexpr int64_t kBatchSize = 1024;
   constexpr int64_t kNumberBatches = 16;
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 41fa7ec5c77..246ed4b0a01 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,62 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
   DCHECK_OK(Filter(scan_options_->filter));
 }
 
+class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+ public:
+  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+                  std::shared_ptr<Fragment> fragment)
+      : ScanTask(std::move(options), std::move(fragment)),
+        batch_it_(std::move(batch_it)) {}
+  Result<RecordBatchIterator> Execute() override {
+    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+    return std::move(batch_it_);
+  }
+
+ private:
+  RecordBatchIterator batch_it_;
+};
+
+class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+ public:
+  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+      : Fragment(compute::literal(true), std::move(schema)),
+        batch_it_(std::move(batch_it)) {
+    DCHECK_NE(physical_schema_, nullptr);
+  }
+  Status CheckConsumed() {
+    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+    return Status::OK();
+  }
+  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+        std::move(batch_it_), std::move(options), shared_from_this())};
+    return MakeVectorIterator(std::move(tasks));
+  }
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+  }
+  std::string type_name() const override { return "one-shot"; }
+
+ protected:
+  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+    return physical_schema_;
+  }
+
+  RecordBatchIterator batch_it_;
+};
+
+std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
+    std::shared_ptr<RecordBatchReader> reader) {
+  auto batch_it = MakeFunctionIterator([reader] { return reader->Next(); });
+  auto fragment =
+      std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
+  return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
+                                          std::make_shared<ScanOptions>());
+}
+
 const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
   return scan_options_->dataset_schema;
 }
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 15bd27ab4f3..bbb79ee474f 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -312,6 +312,14 @@ class ARROW_DS_EXPORT ScannerBuilder {
   ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
                  std::shared_ptr<ScanOptions> scan_options);
 
+  /// \brief Make a scanner from a record batch reader.
+  ///
+  /// The resulting scanner can be scanned only once. This is intended
+  /// to support writing data from streaming sources or other sources
+  /// that can be iterated only once.
+  static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
+      std::shared_ptr<RecordBatchReader> reader);
+
   /// \brief Set the subset of columns to materialize.
   ///
   /// Columns which are not referenced may not be read from fragments.
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 17f4e079ae4..c44f28afb3b 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -507,6 +507,30 @@ TEST_P(TestScanner, Head) {
   AssertTablesEqual(*expected, *actual);
 }
 
+TEST_P(TestScanner, FromReader) {
+  if (GetParam().use_async) {
+    GTEST_SKIP() << "Async scanner does not support construction from reader";
+  }
+  auto batch_size = GetParam().items_per_batch;
+  auto num_batches = GetParam().num_batches;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+  auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+  auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+
+  auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
+  ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
+  ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+  AssertScannerEquals(target_reader.get(), scanner.get());
+
+  // Such scanners can only be scanned once
+  ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+      batch_it.Next());
+}
+
 INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
                          ::testing::ValuesIn(TestScannerParams::Values()));
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 9414fa7861e..2a78a7fe712 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -299,7 +299,6 @@ cdef class Dataset(_Weakrefable):
             'union': UnionDataset,
             'filesystem': FileSystemDataset,
             'in-memory': InMemoryDataset,
-            'one-shot': InMemoryDataset,
         }
 
         class_ = classes.get(type_name, None)
@@ -513,21 +512,10 @@ cdef class InMemoryDataset(Dataset):
                 table = pa.Table.from_batches(batches, schema=schema)
             in_memory_dataset = make_shared[CInMemoryDataset](
                 pyarrow_unwrap_table(table))
-        elif isinstance(source, pa.ipc.RecordBatchReader):
-            reader = source
-            in_memory_dataset = \
-                make_shared[COneShotDataset](reader.reader)
-        elif _is_iterable(source):
-            if schema is None:
-                raise ValueError('Must provide schema to construct in-memory '
-                                 'dataset from an iterable')
-            reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
-            in_memory_dataset = \
-                make_shared[COneShotDataset](reader.reader)
         else:
             raise TypeError(
-                'Expected a table, batch, iterable of tables/batches, or a '
-                'record batch reader instead of the given type: ' +
+                'Expected a table, batch, or list of tables/batches '
+                'instead of the given type: ' +
                 type(source).__name__
             )
@@ -2755,6 +2743,46 @@ cdef class Scanner(_Weakrefable):
             scanner = GetResultValue(builder.get().Finish())
         return Scanner.wrap(scanner)
 
+    @staticmethod
+    def from_batches(source, Schema schema=None, bint use_threads=True,
+                     MemoryPool memory_pool=None, object columns=None,
+                     Expression filter=None,
+                     int batch_size=_DEFAULT_BATCH_SIZE,
+                     FragmentScanOptions fragment_scan_options=None):
+        """Create a Scanner from an iterator of batches.
+
+        This creates a scanner which can be used only once. It is
+        intended to support writing a dataset (which takes a scanner)
+        from a source which can be read only once (e.g. a
+        RecordBatchReader or generator).
+        """
+        cdef:
+            shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+            shared_ptr[CScannerBuilder] builder
+            shared_ptr[CScanner] scanner
+            RecordBatchReader reader
+        if isinstance(source, pa.ipc.RecordBatchReader):
+            if schema:
+                raise ValueError('Cannot specify a schema when providing '
+                                 'a RecordBatchReader')
+            reader = source
+        elif _is_iterable(source):
+            if schema is None:
+                raise ValueError('Must provide schema to construct scanner '
+                                 'from an iterable')
+            reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+        else:
+            raise TypeError('Expected a RecordBatchReader or an iterable of '
+                            'batches instead of the given type: ' +
+                            type(source).__name__)
+        builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
+        _populate_builder(builder, columns=columns, filter=filter,
+                          batch_size=batch_size, use_threads=use_threads,
+                          memory_pool=memory_pool,
+                          fragment_scan_options=fragment_scan_options)
+        scanner = GetResultValue(builder.get().Finish())
+        return Scanner.wrap(scanner)
+
     @property
     def dataset_schema(self):
         """The schema with which batches will be read from fragments."""
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 97d08844f27..e80de1688e7 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -669,10 +669,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
             'of batches or tables. The given list contains the following '
             'types: {}'.format(type_names)
         )
-    elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
-                             pa.Table)):
-        return _in_memory_dataset(source, **kwargs)
-    elif _is_iterable(source):
+    elif isinstance(source, (pa.RecordBatch, pa.Table)):
         return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
@@ -736,9 +733,12 @@ def write_dataset(data, base_dir, basename_template=None, format=None,
     if isinstance(data, (list, tuple)):
         schema = schema or data[0].schema
         data = InMemoryDataset(data, schema=schema)
-    elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
-                           pa.Table)) or _is_iterable(data):
+    elif isinstance(data, (pa.RecordBatch, pa.Table)):
+        schema = schema or data.schema
         data = InMemoryDataset(data, schema=schema)
+    elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+        data = Scanner.from_batches(data, schema=schema)
+        schema = None
     elif not isinstance(data, (Dataset, Scanner)):
         raise ValueError(
             "Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index f802c36d34f..9105abcf7a0 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -121,6 +121,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
                         shared_ptr[CScanOptions] scan_options)
         CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
                         shared_ptr[CScanOptions] scan_options)
+
+        @staticmethod
+        shared_ptr[CScannerBuilder] FromRecordBatchReader(
+            shared_ptr[CRecordBatchReader] reader)
         CStatus ProjectColumns "Project"(const vector[c_string]& columns)
         CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
         CStatus Filter(CExpression filter)
@@ -151,10 +155,6 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
         CInMemoryDataset(shared_ptr[CRecordBatchReader])
         CInMemoryDataset(shared_ptr[CTable])
 
-    cdef cppclass COneShotDataset "arrow::dataset::OneShotDataset"(
-            CInMemoryDataset):
-        COneShotDataset(shared_ptr[CRecordBatchReader])
-
     cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
             CDataset):
         @staticmethod
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 2e8a93a8ac2..20a0a607c1d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1682,9 +1682,10 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
         ds.dataset(None)
 
     expected = (
-        "Must provide schema to construct in-memory dataset from an iterable"
+        "Expected a path-like, list of path-likes or a list of Datasets "
+        "instead of the given type: generator"
    )
-    with pytest.raises(ValueError, match=expected):
+    with pytest.raises(TypeError, match=expected):
        ds.dataset((batch1 for _ in range(3)))
 
     expected = (
@@ -1717,52 +1718,32 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
 def test_construct_in_memory():
     batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
     table = pa.Table.from_batches([batch])
-    reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
-    iterable = (batch for _ in range(1))
-
-    for source in (batch, table, reader, [batch], [table]):
-        dataset = ds.dataset(source)
-        assert dataset.to_table() == table
-    assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
 
     assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
 
-    # When constructed from batches/tables, should be reusable
     for source in (batch, table, [batch], [table]):
         dataset = ds.dataset(source)
-        assert len(list(dataset.get_fragments())) == 1
-        assert len(list(dataset.get_fragments())) == 1
-        assert dataset.to_table() == table
         assert dataset.to_table() == table
+        assert len(list(dataset.get_fragments())) == 1
         assert next(dataset.get_fragments()).to_table() == table
         assert pa.Table.from_batches(list(dataset.to_batches())) == table
 
+
+def test_scan_iterator():
+    batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    table = pa.Table.from_batches([batch])
     # When constructed from readers/iterators, should be one-shot
-    match = "OneShotDataset was already consumed"
-    for factory in (
-        lambda: pa.ipc.RecordBatchReader.from_batches(
-            batch.schema, [batch]),
-        lambda: (batch for _ in range(1)),
+    match = "OneShotFragment was already scanned"
+    for factory, schema in (
+        (lambda: pa.ipc.RecordBatchReader.from_batches(
+            batch.schema, [batch]), None),
+        (lambda: (batch for _ in range(1)), batch.schema),
     ):
         # Scanning the fragment consumes the underlying iterator
-        dataset = ds.dataset(factory(), schema=batch.schema)
-        fragments = list(dataset.get_fragments())
-        assert len(fragments) == 1
-        assert fragments[0].to_table() == table
-        # But you can still get fragments
-        fragments = list(dataset.get_fragments())
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            fragments[0].to_table()
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            dataset.to_table()
-        # So does scanning the dataset
-        dataset = ds.dataset(factory(), schema=batch.schema)
-        assert dataset.to_table() == table
-        fragments = list(dataset.get_fragments())
-        with pytest.raises(pa.ArrowInvalid, match=match):
-            fragments[0].to_table()
+        scanner = ds.Scanner.from_batches(factory(), schema=schema)
+        assert scanner.to_table() == table
         with pytest.raises(pa.ArrowInvalid, match=match):
-            dataset.to_table()
+            scanner.to_table()
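With this patch the one-shot behavior moves out of the dataset and into the scanner. A minimal sketch of the new Scanner.from_batches entry point, distilled from test_scan_iterator above (illustrative only, not part of the patch):

    import pyarrow as pa
    import pyarrow.dataset as ds

    batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])

    # From a plain iterable of batches, a schema must be given explicitly.
    scanner = ds.Scanner.from_batches(
        (batch for _ in range(1)), schema=batch.schema)
    table = scanner.to_table()

    # From a RecordBatchReader, the schema comes from the reader itself
    # (passing schema= here raises ValueError).
    reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
    scanner = ds.Scanner.from_batches(reader)

    # Either way the scanner is one-shot: a second to_table() raises
    # ArrowInvalid ("OneShotFragment was already scanned").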
From 8804db785ef273f23b22e74804df954b2fb941d2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 21:01:14 -0400
Subject: [PATCH 3/4] ARROW-12231: [C++][Dataset] Address review feedback

---
 cpp/src/arrow/dataset/scanner.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 246ed4b0a01..6c5d9e61b21 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
   DCHECK_OK(Filter(scan_options_->filter));
 }
 
+namespace {
 class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
  public:
   OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
@@ -719,10 +720,11 @@ class ARROW_DS_EXPORT OneShotFragment : public Fragment {
 
   RecordBatchIterator batch_it_;
 };
+}  // namespace
 
 std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
     std::shared_ptr<RecordBatchReader> reader) {
-  auto batch_it = MakeFunctionIterator([reader] { return reader->Next(); });
+  auto batch_it = MakeIteratorFromReader(reader);
   auto fragment =
       std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
   return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),

From f0162bc1e4969b9194d5838511d1ed0beaed8947 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 21:15:39 -0400
Subject: [PATCH 4/4] ARROW-12231: [C++][Dataset] Make MSVC happy

---
 cpp/src/arrow/dataset/scanner.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 6c5d9e61b21..3c2c7f65ee7 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -674,7 +674,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
 }
 
 namespace {
-class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+class OneShotScanTask : public ScanTask {
  public:
   OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
                   std::shared_ptr<Fragment> fragment)
@@ -689,7 +689,7 @@ class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
   RecordBatchIterator batch_it_;
 };
 
-class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+class OneShotFragment : public Fragment {
  public:
   OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
       : Fragment(compute::literal(true), std::move(schema)),
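The end-to-end payoff of the series is that write_dataset can now consume a stream in a single pass without first materializing it. A hedged sketch of that usage (the output directory, format choice, and generator below are illustrative, not taken from the patch):

    import pyarrow as pa
    import pyarrow.dataset as ds

    def batches():
        # Hypothetical streaming source; the batches never coexist in memory.
        for i in range(10):
            yield pa.RecordBatch.from_arrays([pa.array([i])], names=["a"])

    schema = pa.schema([("a", pa.int64())])

    # Internally this wraps the iterable in Scanner.from_batches (see patch 2),
    # so the data is written as it is produced.
    ds.write_dataset(batches(), "/tmp/streamed", schema=schema, format="ipc")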