diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index ab0600dd1a8..5afb6a06c16 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -199,28 +199,6 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
: Dataset(table->schema()),
get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
-struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
- explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
- : reader_(std::move(reader)), consumed_(false) {}
-
- RecordBatchIterator Get() const final {
- if (consumed_) {
- return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
- "RecordBatchReader-backed InMemoryDataset was already consumed"));
- }
- consumed_ = true;
- auto reader = reader_;
- return MakeFunctionIterator([reader] { return reader->Next(); });
- }
-
- std::shared_ptr<RecordBatchReader> reader_;
- mutable bool consumed_;
-};
-
-InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
- : Dataset(reader->schema()),
- get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
-
Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
    std::shared_ptr<Schema> schema) const {
RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 40a60ffd48e..164b3ec17aa 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -206,7 +206,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
/// Convenience constructor taking a Table
  explicit InMemoryDataset(std::shared_ptr<Table> table);
-  explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
std::string type_name() const override { return "in-memory"; }
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 7aa0e1a2413..66d69c30c82 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,23 +79,6 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
.status());
}
-TEST_F(TestInMemoryDataset, FromReader) {
- constexpr int64_t kBatchSize = 1024;
- constexpr int64_t kNumberBatches = 16;
-
- SetSchema({field("i32", int32()), field("f64", float64())});
- auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
- auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
- auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
-
- auto dataset = std::make_shared<InMemoryDataset>(source_reader);
-
- AssertDatasetEquals(target_reader.get(), dataset.get());
- // Such datasets can only be scanned once
- ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
- ASSERT_RAISES(Invalid, fragments.Next());
-}
-
TEST_F(TestInMemoryDataset, GetFragments) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 41fa7ec5c77..3c2c7f65ee7 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,64 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
DCHECK_OK(Filter(scan_options_->filter));
}
+namespace {
+class OneShotScanTask : public ScanTask {
+ public:
+ OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+ std::shared_ptr<Fragment> fragment)
+ : ScanTask(std::move(options), std::move(fragment)),
+ batch_it_(std::move(batch_it)) {}
+ Result<RecordBatchIterator> Execute() override {
+ if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+ return std::move(batch_it_);
+ }
+
+ private:
+ RecordBatchIterator batch_it_;
+};
+
+class OneShotFragment : public Fragment {
+ public:
+ OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+ : Fragment(compute::literal(true), std::move(schema)),
+ batch_it_(std::move(batch_it)) {
+ DCHECK_NE(physical_schema_, nullptr);
+ }
+ Status CheckConsumed() {
+ if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+ return Status::OK();
+ }
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+ std::move(batch_it_), std::move(options), shared_from_this())};
+ return MakeVectorIterator(std::move(tasks));
+ }
+ Result<RecordBatchGenerator> ScanBatchesAsync(
+ const std::shared_ptr<ScanOptions>& options) override {
+ RETURN_NOT_OK(CheckConsumed());
+ return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
+ }
+ std::string type_name() const override { return "one-shot"; }
+
+ protected:
+ Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+ return physical_schema_;
+ }
+
+ RecordBatchIterator batch_it_;
+};
+} // namespace
+
+std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader) {
+ auto batch_it = MakeIteratorFromReader(reader);
+ auto fragment =
+ std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
+ return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
+ std::make_shared<ScanOptions>());
+}
+
const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
return scan_options_->dataset_schema;
}
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 15bd27ab4f3..bbb79ee474f 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -312,6 +312,14 @@ class ARROW_DS_EXPORT ScannerBuilder {
  ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
                 std::shared_ptr<ScanOptions> scan_options);
+ /// \brief Make a scanner from a record batch reader.
+ ///
+ /// The resulting scanner can be scanned only once. This is intended
+ /// to support writing data from streaming sources or other sources
+ /// that can be iterated only once.
+ static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
+ std::shared_ptr<RecordBatchReader> reader);
+
/// \brief Set the subset of columns to materialize.
///
/// Columns which are not referenced may not be read from fragments.
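
For context (not part of the patch): a minimal sketch of how the one-shot contract documented above surfaces through the Python binding added later in this diff; the data and column name are illustrative.

import pyarrow as pa
import pyarrow.dataset as ds

batch = pa.record_batch([pa.array(range(10))], names=["a"])
reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])

# ScannerBuilder::FromRecordBatchReader is exposed as ds.Scanner.from_batches.
scanner = ds.Scanner.from_batches(reader)
table = scanner.to_table()  # consumes the underlying reader
# Scanning again raises ArrowInvalid: "OneShotFragment was already scanned".
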
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 17f4e079ae4..c44f28afb3b 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -507,6 +507,30 @@ TEST_P(TestScanner, Head) {
AssertTablesEqual(*expected, *actual);
}
+TEST_P(TestScanner, FromReader) {
+ if (GetParam().use_async) {
+ GTEST_SKIP() << "Async scanner does not support construction from reader";
+ }
+ auto batch_size = GetParam().items_per_batch;
+ auto num_batches = GetParam().num_batches;
+
+ SetSchema({field("i32", int32()), field("f64", float64())});
+ auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
+ auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+ auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
+
+ auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
+ ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
+ AssertScannerEquals(target_reader.get(), scanner.get());
+
+ // Such datasets can only be scanned once (but you can get fragments multiple times)
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
+ batch_it.Next());
+}
+
INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
::testing::ValuesIn(TestScannerParams::Values()));
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index cf076f6536b..2a78a7fe712 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -298,6 +298,7 @@ cdef class Dataset(_Weakrefable):
classes = {
'union': UnionDataset,
'filesystem': FileSystemDataset,
+ 'in-memory': InMemoryDataset,
}
class_ = classes.get(type_name, None)
@@ -511,19 +512,10 @@ cdef class InMemoryDataset(Dataset):
table = pa.Table.from_batches(batches, schema=schema)
in_memory_dataset = make_shared[CInMemoryDataset](
pyarrow_unwrap_table(table))
- elif isinstance(source, pa.ipc.RecordBatchReader):
- reader = source
- in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
- elif _is_iterable(source):
- if schema is None:
- raise ValueError('Must provide schema to construct in-memory '
- 'dataset from an iterable')
- reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
- in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
else:
raise TypeError(
- 'Expected a table, batch, iterable of tables/batches, or a '
- 'record batch reader instead of the given type: ' +
+ 'Expected a table, batch, or list of tables/batches '
+ 'instead of the given type: ' +
type(source).__name__
)
@@ -2751,6 +2743,46 @@ cdef class Scanner(_Weakrefable):
scanner = GetResultValue(builder.get().Finish())
return Scanner.wrap(scanner)
+ @staticmethod
+ def from_batches(source, Schema schema=None, bint use_threads=True,
+ MemoryPool memory_pool=None, object columns=None,
+ Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ FragmentScanOptions fragment_scan_options=None):
+ """Create a Scanner from an iterator of batches.
+
+ This creates a scanner which can be used only once. It is
+ intended to support writing a dataset (which takes a scanner)
+ from a source which can be read only once (e.g. a
+ RecordBatchReader or generator).
+ """
+ cdef:
+ shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+ shared_ptr[CScannerBuilder] builder
+ shared_ptr[CScanner] scanner
+ RecordBatchReader reader
+ if isinstance(source, pa.ipc.RecordBatchReader):
+ if schema:
+ raise ValueError('Cannot specify a schema when providing '
+ 'a RecordBatchReader')
+ reader = source
+ elif _is_iterable(source):
+ if schema is None:
+ raise ValueError('Must provide schema to construct scanner '
+ 'from an iterable')
+ reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+ else:
+ raise TypeError('Expected a RecordBatchReader or an iterable of '
+ 'batches instead of the given type: ' +
+ type(source).__name__)
+ builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
+ _populate_builder(builder, columns=columns, filter=filter,
+ batch_size=batch_size, use_threads=use_threads,
+ memory_pool=memory_pool,
+ fragment_scan_options=fragment_scan_options)
+ scanner = GetResultValue(builder.get().Finish())
+ return Scanner.wrap(scanner)
+
@property
def dataset_schema(self):
"""The schema with which batches will be read from fragments."""
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 97d08844f27..e80de1688e7 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -669,10 +669,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
'of batches or tables. The given list contains the following '
'types: {}'.format(type_names)
)
- elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
- pa.Table)):
- return _in_memory_dataset(source, **kwargs)
- elif _is_iterable(source):
+ elif isinstance(source, (pa.RecordBatch, pa.Table)):
return _in_memory_dataset(source, **kwargs)
else:
raise TypeError(
@@ -736,9 +733,12 @@ def write_dataset(data, base_dir, basename_template=None, format=None,
if isinstance(data, (list, tuple)):
schema = schema or data[0].schema
data = InMemoryDataset(data, schema=schema)
- elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
- pa.Table)) or _is_iterable(data):
+ elif isinstance(data, (pa.RecordBatch, pa.Table)):
+ schema = schema or data.schema
data = InMemoryDataset(data, schema=schema)
+ elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+ data = Scanner.from_batches(data, schema=schema)
+ schema = None
elif not isinstance(data, (Dataset, Scanner)):
raise ValueError(
"Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index bff1a2bbb54..9105abcf7a0 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -121,6 +121,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
shared_ptr[CScanOptions] scan_options)
CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
shared_ptr[CScanOptions] scan_options)
+
+ @staticmethod
+ shared_ptr[CScannerBuilder] FromRecordBatchReader(
+ shared_ptr[CRecordBatchReader] reader)
CStatus ProjectColumns "Project"(const vector[c_string]& columns)
CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
CStatus Filter(CExpression filter)
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 8791c22f103..20a0a607c1d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1682,9 +1682,10 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
ds.dataset(None)
expected = (
- "Must provide schema to construct in-memory dataset from an iterable"
+ "Expected a path-like, list of path-likes or a list of Datasets "
+ "instead of the given type: generator"
)
- with pytest.raises(ValueError, match=expected):
+ with pytest.raises(TypeError, match=expected):
ds.dataset((batch1 for _ in range(3)))
expected = (
@@ -1717,49 +1718,32 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
def test_construct_in_memory():
batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
table = pa.Table.from_batches([batch])
- reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
- iterable = (batch for _ in range(1))
- for source in (batch, table, reader, [batch], [table]):
- dataset = ds.dataset(source)
- assert dataset.to_table() == table
-
- assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
- # When constructed from batches/tables, should be reusable
for source in (batch, table, [batch], [table]):
dataset = ds.dataset(source)
- assert len(list(dataset.get_fragments())) == 1
- assert len(list(dataset.get_fragments())) == 1
- assert dataset.to_table() == table
assert dataset.to_table() == table
+ assert len(list(dataset.get_fragments())) == 1
assert next(dataset.get_fragments()).to_table() == table
assert pa.Table.from_batches(list(dataset.to_batches())) == table
+
+def test_scan_iterator():
+ batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+ table = pa.Table.from_batches([batch])
# When constructed from readers/iterators, should be one-shot
- match = "InMemoryDataset was already consumed"
- for factory in (
- lambda: pa.ipc.RecordBatchReader.from_batches(
- batch.schema, [batch]),
- lambda: (batch for _ in range(1)),
+ match = "OneShotFragment was already scanned"
+ for factory, schema in (
+ (lambda: pa.ipc.RecordBatchReader.from_batches(
+ batch.schema, [batch]), None),
+ (lambda: (batch for _ in range(1)), batch.schema),
):
- dataset = ds.dataset(factory(), schema=batch.schema)
- # Getting fragments consumes the underlying iterator
- fragments = list(dataset.get_fragments())
- assert len(fragments) == 1
- assert fragments[0].to_table() == table
- with pytest.raises(pa.ArrowInvalid, match=match):
- list(dataset.get_fragments())
- with pytest.raises(pa.ArrowInvalid, match=match):
- dataset.to_table()
- # Materializing consumes the underlying iterator
- dataset = ds.dataset(factory(), schema=batch.schema)
- assert dataset.to_table() == table
- with pytest.raises(pa.ArrowInvalid, match=match):
- list(dataset.get_fragments())
+ # Scanning the fragment consumes the underlying iterator
+ scanner = ds.Scanner.from_batches(factory(), schema=schema)
+ assert scanner.to_table() == table
with pytest.raises(pa.ArrowInvalid, match=match):
- dataset.to_table()
+ scanner.to_table()
@pytest.mark.parquet