22 changes: 0 additions & 22 deletions cpp/src/arrow/dataset/dataset.cc
@@ -199,28 +199,6 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
: Dataset(table->schema()),
get_batches_(new TableRecordBatchGenerator(std::move(table))) {}

struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
: reader_(std::move(reader)), consumed_(false) {}

RecordBatchIterator Get() const final {
if (consumed_) {
return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
"RecordBatchReader-backed InMemoryDataset was already consumed"));
}
consumed_ = true;
auto reader = reader_;
return MakeFunctionIterator([reader] { return reader->Next(); });
}

std::shared_ptr<RecordBatchReader> reader_;
mutable bool consumed_;
};

InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
: Dataset(reader->schema()),
get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}

Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
std::shared_ptr<Schema> schema) const {
RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
1 change: 0 additions & 1 deletion cpp/src/arrow/dataset/dataset.h
@@ -206,7 +206,6 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {

/// Convenience constructor taking a Table
explicit InMemoryDataset(std::shared_ptr<Table> table);
explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);

std::string type_name() const override { return "in-memory"; }

17 changes: 0 additions & 17 deletions cpp/src/arrow/dataset/dataset_test.cc
@@ -79,23 +79,6 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
.status());
}

TEST_F(TestInMemoryDataset, FromReader) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;

SetSchema({field("i32", int32()), field("f64", float64())});
auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);

auto dataset = std::make_shared<InMemoryDataset>(source_reader);

AssertDatasetEquals(target_reader.get(), dataset.get());
// Such datasets can only be scanned once
ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
ASSERT_RAISES(Invalid, fragments.Next());
}

TEST_F(TestInMemoryDataset, GetFragments) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;
58 changes: 58 additions & 0 deletions cpp/src/arrow/dataset/scanner.cc
@@ -673,6 +673,64 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
DCHECK_OK(Filter(scan_options_->filter));
}

namespace {
class OneShotScanTask : public ScanTask {
public:
OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
std::shared_ptr<Fragment> fragment)
: ScanTask(std::move(options), std::move(fragment)),
batch_it_(std::move(batch_it)) {}
Result<RecordBatchIterator> Execute() override {
if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
return std::move(batch_it_);
}

private:
RecordBatchIterator batch_it_;
};

class OneShotFragment : public Fragment {
public:
OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
: Fragment(compute::literal(true), std::move(schema)),
batch_it_(std::move(batch_it)) {
DCHECK_NE(physical_schema_, nullptr);
}
Status CheckConsumed() {
if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
return Status::OK();
}
Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
RETURN_NOT_OK(CheckConsumed());
ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
std::move(batch_it_), std::move(options), shared_from_this())};
return MakeVectorIterator(std::move(tasks));
}
Result<RecordBatchGenerator> ScanBatchesAsync(
const std::shared_ptr<ScanOptions>& options) override {
RETURN_NOT_OK(CheckConsumed());
return MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor());
}
std::string type_name() const override { return "one-shot"; }

protected:
Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
return physical_schema_;
}

RecordBatchIterator batch_it_;
};
} // namespace

std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
std::shared_ptr<RecordBatchReader> reader) {
auto batch_it = MakeIteratorFromReader(reader);
auto fragment =
std::make_shared<OneShotFragment>(reader->schema(), std::move(batch_it));
return std::make_shared<ScannerBuilder>(reader->schema(), std::move(fragment),
std::make_shared<ScanOptions>());
}

const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
return scan_options_->dataset_schema;
}
8 changes: 8 additions & 0 deletions cpp/src/arrow/dataset/scanner.h
@@ -312,6 +312,14 @@ class ARROW_DS_EXPORT ScannerBuilder {
ScannerBuilder(std::shared_ptr<Schema> schema, std::shared_ptr<Fragment> fragment,
std::shared_ptr<ScanOptions> scan_options);

/// \brief Make a scanner from a record batch reader.
///
/// The resulting scanner can be scanned only once. This is intended
/// to support writing data from streaming sources or other sources
/// that can be iterated only once.
static std::shared_ptr<ScannerBuilder> FromRecordBatchReader(
std::shared_ptr<RecordBatchReader> reader);

/// \brief Set the subset of columns to materialize.
///
/// Columns which are not referenced may not be read from fragments.
24 changes: 24 additions & 0 deletions cpp/src/arrow/dataset/scanner_test.cc
@@ -507,6 +507,30 @@ TEST_P(TestScanner, Head) {
AssertTablesEqual(*expected, *actual);
}

TEST_P(TestScanner, FromReader) {
if (GetParam().use_async) {
GTEST_SKIP() << "Async scanner does not support construction from reader";
}
auto batch_size = GetParam().items_per_batch;
auto num_batches = GetParam().num_batches;

SetSchema({field("i32", int32()), field("f64", float64())});
auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
auto source_reader = ConstantArrayGenerator::Repeat(num_batches, batch);
auto target_reader = ConstantArrayGenerator::Repeat(num_batches, batch);

auto builder = ScannerBuilder::FromRecordBatchReader(source_reader);
ARROW_EXPECT_OK(builder->UseThreads(GetParam().use_threads));
ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());
AssertScannerEquals(target_reader.get(), scanner.get());

// Such scanners can only be scanned once (but fragments can be requested multiple times)
ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("OneShotFragment was already scanned"),
batch_it.Next());
}

INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
::testing::ValuesIn(TestScannerParams::Values()));

54 changes: 43 additions & 11 deletions python/pyarrow/_dataset.pyx
@@ -298,6 +298,7 @@ cdef class Dataset(_Weakrefable):
classes = {
'union': UnionDataset,
'filesystem': FileSystemDataset,
'in-memory': InMemoryDataset,
}

class_ = classes.get(type_name, None)
@@ -511,19 +512,10 @@ cdef class InMemoryDataset(Dataset):
table = pa.Table.from_batches(batches, schema=schema)
in_memory_dataset = make_shared[CInMemoryDataset](
pyarrow_unwrap_table(table))
elif isinstance(source, pa.ipc.RecordBatchReader):
reader = source
in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
elif _is_iterable(source):
if schema is None:
raise ValueError('Must provide schema to construct in-memory '
'dataset from an iterable')
reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
else:
raise TypeError(
'Expected a table, batch, iterable of tables/batches, or a '
'record batch reader instead of the given type: ' +
'Expected a table, batch, or list of tables/batches '
'instead of the given type: ' +
type(source).__name__
)

@@ -2751,6 +2743,46 @@ cdef class Scanner(_Weakrefable):
scanner = GetResultValue(builder.get().Finish())
return Scanner.wrap(scanner)

@staticmethod
def from_batches(source, Schema schema=None, bint use_threads=True,
MemoryPool memory_pool=None, object columns=None,
Expression filter=None,
int batch_size=_DEFAULT_BATCH_SIZE,
FragmentScanOptions fragment_scan_options=None):
"""Create a Scanner from an iterator of batches.

This creates a scanner which can be used only once. It is
intended to support writing a dataset (which takes a scanner)
from a source which can be read only once (e.g. a
RecordBatchReader or generator).
"""
cdef:
shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
shared_ptr[CScannerBuilder] builder
shared_ptr[CScanner] scanner
RecordBatchReader reader
if isinstance(source, pa.ipc.RecordBatchReader):
if schema:
raise ValueError('Cannot specify a schema when providing '
'a RecordBatchReader')
reader = source
elif _is_iterable(source):
if schema is None:
raise ValueError('Must provide schema to construct scanner '
'from an iterable')
reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
else:
raise TypeError('Expected a RecordBatchReader or an iterable of '
'batches instead of the given type: ' +
type(source).__name__)
builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
_populate_builder(builder, columns=columns, filter=filter,
batch_size=batch_size, use_threads=use_threads,
memory_pool=memory_pool,
fragment_scan_options=fragment_scan_options)
scanner = GetResultValue(builder.get().Finish())
return Scanner.wrap(scanner)

@property
def dataset_schema(self):
"""The schema with which batches will be read from fragments."""
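As a usage note, here is a minimal sketch of how the `Scanner.from_batches` binding above is meant to be exercised; the schema, field names, and batch contents are illustrative assumptions, not part of this diff. An iterable source requires an explicit `schema`, a `RecordBatchReader` carries its own, and in both cases the resulting scanner is one-shot.

```python
import pyarrow as pa
import pyarrow.dataset as ds

schema = pa.schema([("i32", pa.int32()), ("f64", pa.float64())])

def batch_gen():
    # A source that can be iterated only once, e.g. a streaming producer.
    for i in range(16):
        yield pa.record_batch(
            [pa.array([i] * 1024, pa.int32()),
             pa.array([float(i)] * 1024, pa.float64())],
            schema=schema)

# Iterables need an explicit schema; a RecordBatchReader would not.
scanner = ds.Scanner.from_batches(batch_gen(), schema=schema)
table = scanner.to_table()  # consumes the underlying iterator
assert table.num_rows == 16 * 1024

# Scanning again raises ArrowInvalid: "OneShotFragment was already scanned".
```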
12 changes: 6 additions & 6 deletions python/pyarrow/dataset.py
@@ -669,10 +669,7 @@ def dataset(source, schema=None, format=None, filesystem=None,
'of batches or tables. The given list contains the following '
'types: {}'.format(type_names)
)
elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
pa.Table)):
return _in_memory_dataset(source, **kwargs)
elif _is_iterable(source):
elif isinstance(source, (pa.RecordBatch, pa.Table)):
return _in_memory_dataset(source, **kwargs)
else:
raise TypeError(
@@ -736,9 +733,12 @@ def write_dataset(data, base_dir, basename_template=None, format=None,
if isinstance(data, (list, tuple)):
schema = schema or data[0].schema
data = InMemoryDataset(data, schema=schema)
elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
pa.Table)) or _is_iterable(data):
elif isinstance(data, (pa.RecordBatch, pa.Table)):
schema = schema or data.schema
data = InMemoryDataset(data, schema=schema)
elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
data = Scanner.from_batches(data, schema=schema)
schema = None
elif not isinstance(data, (Dataset, Scanner)):
raise ValueError(
"Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
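And a sketch of the streaming-write path the `write_dataset` change above enables; the output path and format here are placeholder assumptions. A `RecordBatchReader` (or a bare iterable plus `schema=`) is now routed through `Scanner.from_batches` rather than an `InMemoryDataset`, so batches are pulled from the source as they are written instead of being materialized up front.

```python
import pyarrow as pa
import pyarrow.dataset as ds

schema = pa.schema([("a", pa.int64())])
batches = (pa.record_batch([pa.array(range(i, i + 5))], schema=schema)
           for i in range(0, 20, 5))

# The reader wraps the one-shot source and carries the schema with it.
reader = pa.ipc.RecordBatchReader.from_batches(schema, batches)

# Internally becomes Scanner.from_batches(reader): scanned once, streamed out.
ds.write_dataset(reader, "/tmp/streamed_dataset", format="ipc")
```

Because the scanner is one-shot, the reader cannot be reused after the write completes.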
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow_dataset.pxd
@@ -121,6 +121,10 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
shared_ptr[CScanOptions] scan_options)
CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
shared_ptr[CScanOptions] scan_options)

@staticmethod
shared_ptr[CScannerBuilder] FromRecordBatchReader(
shared_ptr[CRecordBatchReader] reader)
CStatus ProjectColumns "Project"(const vector[c_string]& columns)
CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
CStatus Filter(CExpression filter)
50 changes: 17 additions & 33 deletions python/pyarrow/tests/test_dataset.py
@@ -1682,9 +1682,10 @@ def test_construct_from_invalid_sources_raise(multisourcefs):
ds.dataset(None)

expected = (
"Must provide schema to construct in-memory dataset from an iterable"
"Expected a path-like, list of path-likes or a list of Datasets "
"instead of the given type: generator"
)
with pytest.raises(ValueError, match=expected):
with pytest.raises(TypeError, match=expected):
ds.dataset((batch1 for _ in range(3)))

expected = (
@@ -1717,49 +1718,32 @@
def test_construct_in_memory():
batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
table = pa.Table.from_batches([batch])
reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
iterable = (batch for _ in range(1))

for source in (batch, table, reader, [batch], [table]):
dataset = ds.dataset(source)
assert dataset.to_table() == table

assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])

# When constructed from batches/tables, should be reusable
for source in (batch, table, [batch], [table]):
dataset = ds.dataset(source)
assert len(list(dataset.get_fragments())) == 1
assert len(list(dataset.get_fragments())) == 1
assert dataset.to_table() == table
assert dataset.to_table() == table
assert len(list(dataset.get_fragments())) == 1
assert next(dataset.get_fragments()).to_table() == table
assert pa.Table.from_batches(list(dataset.to_batches())) == table


def test_scan_iterator():
batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
table = pa.Table.from_batches([batch])
# When constructed from readers/iterators, should be one-shot
match = "InMemoryDataset was already consumed"
for factory in (
lambda: pa.ipc.RecordBatchReader.from_batches(
batch.schema, [batch]),
lambda: (batch for _ in range(1)),
match = "OneShotFragment was already scanned"
for factory, schema in (
(lambda: pa.ipc.RecordBatchReader.from_batches(
batch.schema, [batch]), None),
(lambda: (batch for _ in range(1)), batch.schema),
):
dataset = ds.dataset(factory(), schema=batch.schema)
# Getting fragments consumes the underlying iterator
fragments = list(dataset.get_fragments())
assert len(fragments) == 1
assert fragments[0].to_table() == table
with pytest.raises(pa.ArrowInvalid, match=match):
list(dataset.get_fragments())
with pytest.raises(pa.ArrowInvalid, match=match):
dataset.to_table()
# Materializing consumes the underlying iterator
dataset = ds.dataset(factory(), schema=batch.schema)
assert dataset.to_table() == table
with pytest.raises(pa.ArrowInvalid, match=match):
list(dataset.get_fragments())
# Scanning the fragment consumes the underlying iterator
scanner = ds.Scanner.from_batches(factory(), schema=schema)
assert scanner.to_table() == table
with pytest.raises(pa.ArrowInvalid, match=match):
dataset.to_table()
scanner.to_table()


@pytest.mark.parquet