From 2f8c651a201f0faef46b7ac9ae5fa6d66f37edc4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 15 Apr 2021 11:23:27 -0400
Subject: [PATCH 1/4] ARROW-12231: [C++][Python][Dataset] Differentiate
one-shot datasets
---
cpp/src/arrow/dataset/dataset.cc | 97 +++++++++++++++-----
cpp/src/arrow/dataset/dataset.h | 17 +++-
cpp/src/arrow/dataset/dataset_test.cc | 21 ++++-
python/pyarrow/_dataset.pyx | 8 +-
python/pyarrow/includes/libarrow_dataset.pxd | 4 +
python/pyarrow/tests/test_dataset.py | 13 ++-
6 files changed, 127 insertions(+), 33 deletions(-)
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index ab0600dd1a8..7ec71fdc99b 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -199,28 +199,6 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
: Dataset(table->schema()),
get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
- explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
- : reader_(std::move(reader)), consumed_(false) {}
-
- RecordBatchIterator Get() const final {
- if (consumed_) {
- return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
- "RecordBatchReader-backed InMemoryDataset was already consumed"));
- }
- consumed_ = true;
- auto reader = reader_;
- return MakeFunctionIterator([reader] { return reader->Next(); });
- }
-
- std::shared_ptr<RecordBatchReader> reader_;
- mutable bool consumed_;
-};
-
-InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
- : Dataset(reader->schema()),
- get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
-
Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
std::shared_ptr<Schema> schema) const {
RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
@@ -244,6 +222,81 @@ Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(compute::Expression)
return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
}
+struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+ explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
+ : reader_(std::move(reader)), consumed_(false) {}
+
+ RecordBatchIterator Get() const final {
+ if (consumed_) {
+ return MakeErrorIterator<std::shared_ptr<RecordBatch>>(
+ Status::Invalid("OneShotDataset was already consumed"));
+ }
+ consumed_ = true;
+ auto reader = reader_;
+ return MakeFunctionIterator([reader] { return reader->Next(); });
+ }
+
+ std::shared_ptr<RecordBatchReader> reader_;
+ mutable bool consumed_;
+};
+
+OneShotDataset::OneShotDataset(std::shared_ptr<RecordBatchReader> reader)
+ : InMemoryDataset(reader->schema(),
+ std::make_shared<ReaderRecordBatchGenerator>(reader)) {}
+
+class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+ public:
+ OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ batch_it_(std::move(batch_it)) {}
+ Result<RecordBatchIterator> Execute() override {
+ if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+ return std::move(batch_it_);
+ }
+
+ private:
+ RecordBatchIterator batch_it_;
+};
+
+class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+ public:
+ OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+ : Fragment(compute::literal(true), std::move(schema)),
+ batch_it_(std::move(batch_it)) {
+ DCHECK_NE(physical_schema_, nullptr);
+ }
+ Status CheckConsumed() {
+ if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+ return Status::OK();
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+ std::move(batch_it_), std::move(options), shared_from_this())};
+ return MakeVectorIterator(std::move(tasks));
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+ }
+ std::string type_name() const override { return "one-shot"; }
+
+ protected:
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ return physical_schema_;
+ }
+
+ RecordBatchIterator batch_it_;
+};
+
+Result<FragmentIterator> OneShotDataset::GetFragmentsImpl(compute::Expression predicate) {
+ FragmentVector fragments{
+ std::make_shared<OneShotFragment>(schema(), get_batches_->Get())};
+ return MakeVectorIterator(std::move(fragments));
+}
+
Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
DatasetVector children) {
for (const auto& child : children) {
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 40a60ffd48e..958d94d728a 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -206,7 +206,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
/// Convenience constructor taking a Table
explicit InMemoryDataset(std::shared_ptr<Table> table);
- explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
std::string type_name() const override { return "in-memory"; }
@@ -219,6 +218,22 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
std::shared_ptr<RecordBatchGenerator> get_batches_;
};
+/// \brief A Dataset which yields fragments wrapping a one-shot stream
+/// of record batches.
+///
+/// Unlike other datasets, this can be scanned only once. This is
+/// intended to support writing data from streaming sources or other
+/// sources that can be iterated only once.
+class ARROW_DS_EXPORT OneShotDataset : public InMemoryDataset {
+ public:
+ /// Construct a dataset from a reader
+ explicit OneShotDataset(std::shared_ptr<RecordBatchReader> reader);
+ std::string type_name() const override { return "one-shot"; }
+
+ protected:
+ Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
+};
+
/// \brief A Dataset wrapping child Datasets.
class ARROW_DS_EXPORT UnionDataset : public Dataset {
public:
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 7aa0e1a2413..5049bb09634 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -88,12 +88,27 @@ TEST_F(TestInMemoryDataset, FromReader) {
auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
- auto dataset = std::make_shared<InMemoryDataset>(source_reader);
+ auto dataset = std::make_shared<OneShotDataset>(source_reader);
AssertDatasetEquals(target_reader.get(), dataset.get());
- // Such datasets can only be scanned once
+ // Such datasets can only be scanned once (but you can get fragments multiple times)
ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
- ASSERT_RAISES(Invalid, fragments.Next());
+ ASSERT_OK_AND_ASSIGN(auto fragment, fragments.Next());
+ ASSERT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(options_));
+ ASSERT_OK_AND_ASSIGN(auto scan_task, scan_task_it.Next());
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotDataset was already consumed"),
+ batch_it.Next());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotScanTask was already scanned"),
+ scan_task->Execute());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ fragment->Scan(options_));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ fragment->ScanBatchesAsync(options_));
}
TEST_F(TestInMemoryDataset, GetFragments) {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index cf076f6536b..9414fa7861e 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -298,6 +298,8 @@ cdef class Dataset(_Weakrefable):
classes = {
'union': UnionDataset,
'filesystem': FileSystemDataset,
+ 'in-memory': InMemoryDataset,
+ 'one-shot': InMemoryDataset,
}
class_ = classes.get(type_name, None)
@@ -513,13 +515,15 @@ cdef class InMemoryDataset(Dataset):
pyarrow_unwrap_table(table))
elif isinstance(source, pa.ipc.RecordBatchReader):
reader = source
- in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+ in_memory_dataset = \
+ make_shared[COneShotDataset](reader.reader)
elif _is_iterable(source):
if schema is None:
raise ValueError('Must provide schema to construct in-memory '
'dataset from an iterable')
reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
- in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+ in_memory_dataset = \
+ make_shared[COneShotDataset](reader.reader)
else:
raise TypeError(
'Expected a table, batch, iterable of tables/batches, or a '
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index bff1a2bbb54..f802c36d34f 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -151,6 +151,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
CInMemoryDataset(shared_ptr[CRecordBatchReader])
CInMemoryDataset(shared_ptr[CTable])
+ cdef cppclass COneShotDataset "arrow::dataset::OneShotDataset"(
+ CInMemoryDataset):
+ COneShotDataset(shared_ptr[CRecordBatchReader])
+
cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
CDataset):
@staticmethod
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 8791c22f103..2e8a93a8ac2 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1738,26 +1738,29 @@ def test_construct_in_memory():
assert pa.Table.from_batches(list(dataset.to_batches())) == table
# When constructed from readers/iterators, should be one-shot
- match = "InMemoryDataset was already consumed"
+ match = "OneShotDataset was already consumed"
for factory in (
lambda: pa.ipc.RecordBatchReader.from_batches(
batch.schema, [batch]),
lambda: (batch for _ in range(1)),
):
+ # Scanning the fragment consumes the underlying iterator
dataset = ds.dataset(factory(), schema=batch.schema)
- # Getting fragments consumes the underlying iterator
fragments = list(dataset.get_fragments())
assert len(fragments) == 1
assert fragments[0].to_table() == table
+ # But you can still get fragments
+ fragments = list(dataset.get_fragments())
with pytest.raises(pa.ArrowInvalid, match=match):
- list(dataset.get_fragments())
+ fragments[0].to_table()
with pytest.raises(pa.ArrowInvalid, match=match):
dataset.to_table()
- # Materializing consumes the underlying iterator
+ # So does scanning the dataset
dataset = ds.dataset(factory(), schema=batch.schema)
assert dataset.to_table() == table
+ fragments = list(dataset.get_fragments())
with pytest.raises(pa.ArrowInvalid, match=match):
- list(dataset.get_fragments())
+ fragments[0].to_table()
with pytest.raises(pa.ArrowInvalid, match=match):
dataset.to_table()
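
A minimal sketch of how the OneShotDataset added above is meant to be used from C++. This is illustrative only: it assumes a std::shared_ptr<arrow::RecordBatchReader> reader, the GTest-style assertion macros from Arrow's test utilities, and the pre-existing Dataset::NewScan()/Scanner::ToTable() scan entry points.

#include "arrow/dataset/dataset.h"
#include "arrow/dataset/scanner.h"

// Sketch only: `reader` is a one-shot RecordBatchReader.
auto dataset = std::make_shared<arrow::dataset::OneShotDataset>(reader);

// The first scan drains the reader.
ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish());
ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());

// A second scan fails with Status::Invalid: the underlying
// reader was already consumed.
ASSERT_OK_AND_ASSIGN(scan_builder, dataset->NewScan());
ASSERT_OK_AND_ASSIGN(scanner, scan_builder->Finish());
ASSERT_RAISES(Invalid, scanner->ToTable());
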
From c62ff07f385aaf0c7c149df2392d7c2d0d5e9e23 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 14:40:12 -0400
Subject: [PATCH 2/4] ARROW-12231: [C++][Python][Dataset] Hide one-shotness in
the scanner
---
cpp/src/arrow/dataset/dataset.cc | 75 --------------------
cpp/src/arrow/dataset/dataset.h | 16 -----
cpp/src/arrow/dataset/dataset_test.cc | 32 ---------
cpp/src/arrow/dataset/scanner.cc | 56 +++++++++++++++
cpp/src/arrow/dataset/scanner.h | 8 +++
cpp/src/arrow/dataset/scanner_test.cc | 24 +++++++
python/pyarrow/_dataset.pyx | 56 +++++++++++----
python/pyarrow/dataset.py | 12 ++--
python/pyarrow/includes/libarrow_dataset.pxd | 8 +--
python/pyarrow/tests/test_dataset.py | 51 +++++--------
10 files changed, 156 insertions(+), 182 deletions(-)
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index 7ec71fdc99b..5afb6a06c16 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -222,81 +222,6 @@ Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(compute::Expression)
return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
}
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
- explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
- : reader_(std::move(reader)), consumed_(false) {}
-
- RecordBatchIterator Get() const final {
- if (consumed_) {
- return MakeErrorIterator<std::shared_ptr<RecordBatch>>(
- Status::Invalid("OneShotDataset was already consumed"));
- }
- consumed_ = true;
- auto reader = reader_;
- return MakeFunctionIterator([reader] { return reader->Next(); });
- }
-
- std::shared_ptr<RecordBatchReader> reader_;
- mutable bool consumed_;
-};
-
-OneShotDataset::OneShotDataset(std::shared_ptr<RecordBatchReader> reader)
- : InMemoryDataset(reader->schema(),
- std::make_shared<ReaderRecordBatchGenerator>(reader)) {}
-
-class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
- public:
- OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
- std::shared_ptr<Fragment> fragment)
- : ScanTask(std::move(options), std::move(fragment)),
- batch_it_(std::move(batch_it)) {}
- Result<RecordBatchIterator> Execute() override {
- if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
- return std::move(batch_it_);
- }
-
- private:
- RecordBatchIterator batch_it_;
-};
-
-class ARROW_DS_EXPORT OneShotFragment : public Fragment {
- public:
- OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
- : Fragment(compute::literal(true), std::move(schema)),
- batch_it_(std::move(batch_it)) {
- DCHECK_NE(physical_schema_, nullptr);
- }
- Status CheckConsumed() {
- if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
- return Status::OK();
- }
- Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
- RETURN_NOT_OK(CheckConsumed());
- ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
- std::move(batch_it_), std::move(options), shared_from_this())};
- return MakeVectorIterator(std::move(tasks));
- }
- Result<RecordBatchGenerator> ScanBatchesAsync(
- const std::shared_ptr<ScanOptions>& options) override {
- RETURN_NOT_OK(CheckConsumed());
- return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
- }
- std::string type_name() const override { return "one-shot"; }
-
- protected:
- Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
- return physical_schema_;
- }
-
- RecordBatchIterator batch_it_;
-};
-
-Result<FragmentIterator> OneShotDataset::GetFragmentsImpl(compute::Expression predicate) {
- FragmentVector fragments{
- std::make_shared<OneShotFragment>(schema(), get_batches_->Get())};
- return MakeVectorIterator(std::move(fragments));
-}
-
Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
DatasetVector children) {
for (const auto& child : children) {
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 958d94d728a..164b3ec17aa 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -218,22 +218,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
std::shared_ptr<RecordBatchGenerator> get_batches_;
};
-/// \brief A Dataset which yields fragments wrapping a one-shot stream
-/// of record batches.
-///
-/// Unlike other datasets, this can be scanned only once. This is
-/// intended to support writing data from streaming sources or other
-/// sources that can be iterated only once.
-class ARROW_DS_EXPORT OneShotDataset : public InMemoryDataset {
- public:
- /// Construct a dataset from a reader
- explicit OneShotDataset(std::shared_ptr<RecordBatchReader> reader);
- std::string type_name() const override { return "one-shot"; }
-
- protected:
- Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) override;
-};
-
/// \brief A Dataset wrapping child Datasets.
class ARROW_DS_EXPORT UnionDataset : public Dataset {
public:
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 5049bb09634..66d69c30c82 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,38 +79,6 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
.status());
}
-TEST_F(TestInMemoryDataset, FromReader) {
- constexpr int64_t kBatchSize = 1024;
- constexpr int64_t kNumberBatches = 16;
-
- SetSchema({field("i32", int32()), field("f64", float64())});
- auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
- auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
- auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-
- auto dataset = std::make_shared<OneShotDataset>(source_reader);
-
- AssertDatasetEquals(target_reader.get(), dataset.get());
- // Such datasets can only be scanned once (but you can get fragments multiple times)
- ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto fragment, fragments.Next());
- ASSERT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(options_));
- ASSERT_OK_AND_ASSIGN(auto scan_task, scan_task_it.Next());
- ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("OneShotDataset was already consumed"),
- batch_it.Next());
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("OneShotScanTask was already scanned"),
- scan_task->Execute());
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
- fragment->Scan(options_));
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
- fragment->ScanBatchesAsync(options_));
-}
-
TEST_F(TestInMemoryDataset, GetFragments) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 41fa7ec5c77..246ed4b0a01 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,62 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
DCHECK_OK(Filter(scan_options_->filter));
}
+class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+ public:
+ OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ batch_it_(std::move(batch_it)) {}
+ Result<RecordBatchIterator> Execute() override {
+ if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+ return std::move(batch_it_);
+ }
+
+ private:
+ RecordBatchIterator batch_it_;
+};
+
+class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+ public:
+ OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+ : Fragment(compute::literal(true), std::move(schema)),
+ batch_it_(std::move(batch_it)) {
+ DCHECK_NE(physical_schema_, nullptr);
+ }
+ Status CheckConsumed() {
+ if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+ return Status::OK();
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+ std::move(batch_it_), std::move(options), shared_from_this())};
+ return MakeVectorIterator(std::move(tasks));
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+ }
+ std::string type_name() const override { return "one-shot"; }
+
+ protected:
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ return physical_schema_;
+ }
+
+ RecordBatchIterator batch_it_;
+};
+
+std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader) {
+ auto batch_it = MakeFunctionIterator([reader] { return reader->Next(); });
+ auto fragment =
+ std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
+ return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
+ std::make_shared<ScanOptions>());
+}
+
const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
return scan_options_->dataset_schema;
}
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 15bd27ab4f3..bbb79ee474f 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -312,6 +312,14 @@ class ARROW_DS_EXPORT ScannerBuilder {
ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
std::shared_ptr<ScanOptions> scan_options);
+ /// \brief Make a scanner from a record batch reader.
+ ///
+ /// The resulting scanner can be scanned only once. This is intended
+ /// to support writing data from streaming sources or other sources
+ /// that can be iterated only once.
+ static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader);
+
/// \brief Set the subset of columns to materialize.
///
/// Columns which are not referenced may not be read from fragments.
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 17f4e079ae4..c44f28afb3b 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -507,6 +507,30 @@ TEST_P(TestScanner, Head) {
AssertTablesEqual(*expected, *actual);
}
+TEST_P(TestScanner, FromReader) {
+ if (GetParam().use_async) {
+ GTEST_SKIP() << "Async scanner does not support construction from reader";
+ }
+ auto batch_size = GetParam().items_per_batch;
+ auto num_batches = GetParam().num_batches;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+ auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+ auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+
+ auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
+ ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+ AssertScannerEquals(target_reader.get(), scanner.get());
+
+ // Such a scanner can only be scanned once
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ batch_it.Next());
+}
+
INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
::testing::ValuesIn(TestScannerParams::Values()));
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 9414fa7861e..2a78a7fe712 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -299,7 +299,6 @@ cdef class Dataset(_Weakrefable):
'union': UnionDataset,
'filesystem': FileSystemDataset,
'in-memory': InMemoryDataset,
- 'one-shot': InMemoryDataset,
}
class_ = classes.get(type_name, None)
@@ -513,21 +512,10 @@ cdef class InMemoryDataset(Dataset):
table = pa.Table.from_batches(batches, schema=schema)
in_memory_dataset = make_shared[CInMemoryDataset](
pyarrow_unwrap_table(table))
- elif isinstance(source, pa.ipc.RecordBatchReader):
- reader = source
- in_memory_dataset = \
- make_shared[COneShotDataset](reader.reader)
- elif _is_iterable(source):
- if schema is None:
- raise ValueError('Must provide schema to construct in-memory '
- 'dataset from an iterable')
- reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
- in_memory_dataset = \
- make_shared[COneShotDataset](reader.reader)
else:
raise TypeError(
- 'Expected a table, batch, iterable of tables/batches, or a '
- 'record batch reader instead of the given type: ' +
+ 'Expected a table, batch, or list of tables/batches '
+ 'instead of the given type: ' +
type(source).__name__
)
@@ -2755,6 +2743,46 @@ cdef class Scanner(_Weakrefable):
scanner = GetResultValue(builder.get().Finish())
return Scanner.wrap(scanner)
+ @staticmethod
+ def from_batches(source, Schema schema=None, bint use_threads=True,
+ MemoryPool memory_pool=None, object columns=None,
+ Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ FragmentScanOptions fragment_scan_options=None):
+ """Create a Scanner from an iterator of batches.
+
+ This creates a scanner which can be used only once. It is
+ intended to support writing a dataset (which takes a scanner)
+ from a source which can be read only once (e.g. a
+ RecordBatchReader or generator).
+ """
+ cdef:
+ shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+ shared_ptr[CScannerBuilder] builder
+ shared_ptr[CScanner] scanner
+ RecordBatchReader reader
+ if isinstance(source, pa.ipc.RecordBatchReader):
+ if schema:
+ raise ValueError('Cannot specify a schema when providing '
+ 'a RecordBatchReader')
+ reader = source
+ elif _is_iterable(source):
+ if schema is None:
+ raise ValueError('Must provide schema to construct scanner '
+ 'from an iterable')
+ reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+ else:
+ raise TypeError('Expected a RecordBatchReader or an iterable of '
+ 'batches instead of the given type: ' +
+ type(source).__name__)
+ builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
+ _populate_builder(builder, columns=columns, filter=filter,
+ batch_size=batch_size, use_threads=use_threads,
+ memory_pool=memory_pool,
+ fragment_scan_options=fragment_scan_options)
+ scanner = GetResultValue(builder.get().Finish())
+ return Scanner.wrap(scanner)
+
@property
def dataset_schema(self):
"""The schema with which batches will be read from fragments."""
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 97d08844f27..e80de1688e7 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -669,10 +669,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
'of batches or tables. The given list contains the following '
'types: {}'.format(type_names)
)
- elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
- pa.Table)):
- return _in_memory_dataset(source, **kwargs)
- elif _is_iterable(source):
+ elif isinstance(source, (pa.RecordBatch, pa.Table)):
return _in_memory_dataset(source, **kwargs)
else:
raise TypeError(
@@ -736,9 +733,12 @@ def write_dataset(data, base_dir, basename_template=None, format=None,
if isinstance(data, (list, tuple)):
schema = schema or data[0].schema
data = InMemoryDataset(data, schema=schema)
- elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
- pa.Table)) or _is_iterable(data):
+ elif isinstance(data, (pa.RecordBatch, pa.Table)):
+ schema = schema or data.schema
data = InMemoryDataset(data, schema=schema)
+ elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+ data = Scanner.from_batches(data, schema=schema)
+ schema = None
elif not isinstance(data, (Dataset, Scanner)):
raise ValueError(
"Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index f802c36d34f..9105abcf7a0 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -121,6 +121,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
shared_ptr[CScanOptions] scan_options)
CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
shared_ptr[CScanOptions] scan_options)
+
+ @staticmethod
+ shared_ptr[CScannerBuilder] FromRecordBatchReader(
+ shared_ptr[CRecordBatchReader] reader)
CStatus ProjectColumns "Project"(const vector[c_string]& columns)
CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
CStatus Filter(CExpression filter)
@@ -151,10 +155,6 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
CInMemoryDataset(shared_ptr[CRecordBatchReader])
CInMemoryDataset(shared_ptr[CTable])
- cdef cppclass COneShotDataset "arrow::dataset::OneShotDataset"(
- CInMemoryDataset):
- COneShotDataset(shared_ptr[CRecordBatchReader])
-
cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
CDataset):
@staticmethod
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 2e8a93a8ac2..20a0a607c1d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1682,9 +1682,10 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
ds.dataset(None)
expected = (
- "Must provide schema to construct in-memory dataset from an iterable"
+ "Expected a path-like, list of path-likes or a list of Datasets "
+ "instead of the given type: generator"
)
- with pytest.raises(ValueError, match=expected):
+ with pytest.raises(TypeError, match=expected):
ds.dataset((batch1 for _ in range(3)))
expected = (
@@ -1717,52 +1718,32 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
def test_construct_in_memory():
batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
table = pa.Table.from_batches([batch])
- reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
- iterable = (batch for _ in range(1))
-
- for source in (batch, table, reader, [batch], [table]):
- dataset = ds.dataset(source)
- assert dataset.to_table() == table
- assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
- # When constructed from batches/tables, should be reusable
for source in (batch, table, [batch], [table]):
dataset = ds.dataset(source)
- assert len(list(dataset.get_fragments())) == 1
- assert len(list(dataset.get_fragments())) == 1
- assert dataset.to_table() == table
assert dataset.to_table() == table
+ assert len(list(dataset.get_fragments())) == 1
assert next(dataset.get_fragments()).to_table() == table
assert pa.Table.from_batches(list(dataset.to_batches())) == table
+
+def test_scan_iterator():
+ batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+ table = pa.Table.from_batches([batch])
# When constructed from readers/iterators, should be one-shot
- match = "OneShotDataset was already consumed"
- for factory in (
- lambda: pa.ipc.RecordBatchReader.from_batches(
- batch.schema, [batch]),
- lambda: (batch for _ in range(1)),
+ match = "OneShotFragment was already scanned"
+ for factory, schema in (
+ (lambda: pa.ipc.RecordBatchReader.from_batches(
+ batch.schema, [batch]), None),
+ (lambda: (batch for _ in range(1)), batch.schema),
):
# Scanning the fragment consumes the underlying iterator
- dataset = ds.dataset(factory(), schema=batch.schema)
- fragments = list(dataset.get_fragments())
- assert len(fragments) == 1
- assert fragments[0].to_table() == table
- # But you can still get fragments
- fragments = list(dataset.get_fragments())
- with pytest.raises(pa.ArrowInvalid, match=match):
- fragments[0].to_table()
- with pytest.raises(pa.ArrowInvalid, match=match):
- dataset.to_table()
- # So does scanning the dataset
- dataset = ds.dataset(factory(), schema=batch.schema)
- assert dataset.to_table() == table
- fragments = list(dataset.get_fragments())
- with pytest.raises(pa.ArrowInvalid, match=match):
- fragments[0].to_table()
+ scanner = ds.Scanner.from_batches(factory(), schema=schema)
+ assert scanner.to_table() == table
with pytest.raises(pa.ArrowInvalid, match=match):
- dataset.to_table()
+ scanner.to_table()
@pytest.mark.parquet
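
After this patch the one-shot behavior lives entirely at the scanner level, via ScannerBuilder::FromRecordBatchReader. A minimal C++ sketch under the same assumptions as the note after patch 1 (a one-shot `reader`, Arrow's GTest helpers):

auto builder = arrow::dataset::ScannerBuilder::FromRecordBatchReader(reader);
ASSERT_OK(builder->UseThreads(true));
ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());  // drains the reader

// Any further scan reports that the OneShotFragment was already scanned.
ASSERT_RAISES(Invalid, scanner->ToTable());

On the Python side the same capability is exposed as ds.Scanner.from_batches(), which write_dataset() now uses for readers and iterables, so a stream can be written without first materializing it as a table.
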
From 8804db785ef273f23b22e74804df954b2fb941d2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 21:01:14 -0400
Subject: [PATCH 3/4] ARROW-12231: [C++][Dataset] Address review feedback
---
cpp/src/arrow/dataset/scanner.cc | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 246ed4b0a01..6c5d9e61b21 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
DCHECK_OK(Filter(scan_options_->filter));
}
+namespace {
class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
public:
OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
@@ -719,10 +720,11 @@ class ARROW_DS_EXPORT OneShotFragment : public Fragment {
RecordBatchIterator batch_it_;
};
+} // namespace
std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
std::shared_ptr<RecordBatchReader> reader) {
- auto batch_it = MakeFunctionIterator([reader] { return reader->Next(); });
+ auto batch_it = MakeIteratorFromReader(reader);
auto fragment =
std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
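
The review change above replaces the hand-rolled lambda with an existing helper. Both forms yield the same Iterator<std::shared_ptr<RecordBatch>>; this sketch assumes MakeIteratorFromReader is the helper declared alongside RecordBatchReader:

// Before: keep the reader alive via an explicit lambda capture.
auto batch_it = MakeFunctionIterator([reader] { return reader->Next(); });
// After: the helper owns the reader internally and yields the same batches.
auto batch_it2 = MakeIteratorFromReader(reader);
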
From f0162bc1e4969b9194d5838511d1ed0beaed8947 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 5 May 2021 21:15:39 -0400
Subject: [PATCH 4/4] ARROW-12231: [C++][Dataset] Make MSVC happy
---
cpp/src/arrow/dataset/scanner.cc | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 6c5d9e61b21..3c2c7f65ee7 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -674,7 +674,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
}
namespace {
-class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
+class OneShotScanTask : public ScanTask {
public:
OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
std::shared_ptr<Fragment> fragment)
@@ -689,7 +689,7 @@ class ARROW_DS_EXPORT OneShotScanTask : public ScanTask {
RecordBatchIterator batch_it_;
};
-class ARROW_DS_EXPORT OneShotFragment : public Fragment {
+class OneShotFragment : public Fragment {
public:
OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
: Fragment(compute::literal(true), std::move(schema)),