29 changes: 27 additions & 2 deletions cpp/src/arrow/dataset/dataset.cc
@@ -66,8 +66,11 @@ InMemoryFragment::InMemoryFragment(std::shared_ptr<Schema> schema,

InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches,
Expression partition_expression)
-    : InMemoryFragment(record_batches.empty() ? schema({}) : record_batches[0]->schema(),
-                       std::move(record_batches), std::move(partition_expression)) {}
+    : Fragment(std::move(partition_expression), /*schema=*/nullptr),
+      record_batches_(std::move(record_batches)) {
+  // Order of argument evaluation is undefined, so compute physical_schema here
+  physical_schema_ = record_batches_.empty() ? schema({}) : record_batches_[0]->schema();
+}

Result<ScanTaskIterator> InMemoryFragment::Scan(std::shared_ptr<ScanOptions> options) {
// Make an explicit copy of record_batches_ to ensure Scan can be called
@@ -148,6 +151,28 @@ InMemoryDataset::InMemoryDataset(std::shared_ptr<Table> table)
: Dataset(table->schema()),
get_batches_(new TableRecordBatchGenerator(std::move(table))) {}

struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
: reader_(std::move(reader)), consumed_(false) {}

RecordBatchIterator Get() const final {
if (consumed_) {
return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
"RecordBatchReader-backed InMemoryDataset was already consumed"));
}
consumed_ = true;
auto reader = reader_;
return MakeFunctionIterator([reader] { return reader->Next(); });
}

std::shared_ptr<RecordBatchReader> reader_;
mutable bool consumed_;
};

InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
: Dataset(reader->schema()),
get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}

Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
std::shared_ptr<Schema> schema) const {
RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
1 change: 1 addition & 0 deletions cpp/src/arrow/dataset/dataset.h
@@ -184,6 +184,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset {
InMemoryDataset(std::shared_ptr<Schema> schema, RecordBatchVector batches);

explicit InMemoryDataset(std::shared_ptr<Table> table);
explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);

std::string type_name() const override { return "in-memory"; }

31 changes: 31 additions & 0 deletions cpp/src/arrow/dataset/dataset_test.cc
@@ -79,6 +79,23 @@ TEST_F(TestInMemoryDataset, ReplaceSchema) {
.status());
}

TEST_F(TestInMemoryDataset, FromReader) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;

SetSchema({field("i32", int32()), field("f64", float64())});
auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);

auto dataset = std::make_shared<InMemoryDataset>(source_reader);

AssertDatasetEquals(target_reader.get(), dataset.get());
// Such datasets can only be scanned once
ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
ASSERT_RAISES(Invalid, fragments.Next());
}
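
The same one-shot semantics surface in Python once the bindings below land. A minimal sketch, assuming the failure surfaces as an Arrow Invalid error at the Python level (the exact exception class is an assumption):

    import pyarrow as pa
    import pyarrow.dataset as ds

    schema = pa.schema([("i32", pa.int32()), ("f64", pa.float64())])
    batch = pa.record_batch([pa.array([0], pa.int32()),
                             pa.array([0.0], pa.float64())], schema=schema)
    reader = pa.ipc.RecordBatchReader.from_batches(schema, [batch])

    dataset = ds.InMemoryDataset(reader)
    dataset.to_table()  # first scan drains the underlying reader
    dataset.to_table()  # second scan fails: the reader is already consumed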

TEST_F(TestInMemoryDataset, GetFragments) {
constexpr int64_t kBatchSize = 1024;
constexpr int64_t kNumberBatches = 16;
@@ -93,6 +110,20 @@ TEST_F(TestInMemoryDataset, GetFragments) {
AssertDatasetEquals(reader.get(), dataset.get());
}

TEST_F(TestInMemoryDataset, InMemoryFragment) {
constexpr int64_t kBatchSize = 1024;

SetSchema({field("i32", int32()), field("f64", float64())});
auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
RecordBatchVector batches{batch};

// Regression test: previously this constructor relied on undefined behavior (order of
// evaluation of arguments) leading to fragments being constructed with empty schemas
auto fragment = std::make_shared<InMemoryFragment>(batches);
ASSERT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
AssertSchemaEqual(batch->schema(), schema);
}

class TestUnionDataset : public DatasetFixtureMixin {};

TEST_F(TestUnionDataset, ReplaceSchema) {
100 changes: 80 additions & 20 deletions python/pyarrow/_dataset.pyx
@@ -26,11 +26,11 @@ import os

import pyarrow as pa
from pyarrow.lib cimport *
-from pyarrow.lib import frombytes, tobytes
+from pyarrow.lib import ArrowTypeError, frombytes, tobytes
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
from pyarrow._csv cimport ConvertOptions, ParseOptions, ReadOptions
-from pyarrow.util import _is_path_like, _stringify_path
+from pyarrow.util import _is_iterable, _is_path_like, _stringify_path

from pyarrow._parquet cimport (
_create_writer_properties, _create_arrow_writer_properties,
@@ -441,6 +441,76 @@ cdef class Dataset(_Weakrefable):
return pyarrow_wrap_schema(self.dataset.schema())


cdef class InMemoryDataset(Dataset):
"""A Dataset wrapping in-memory data.

Parameters
----------
source
The data for this dataset. Can be a RecordBatch, Table, list of
RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader.
If an iterable is provided, the schema must also be provided.
schema : Schema, optional
Only required if passing an iterable as the source.
"""

cdef:
CInMemoryDataset* in_memory_dataset

def __init__(self, source, Schema schema=None):
cdef:
RecordBatchReader reader
shared_ptr[CInMemoryDataset] in_memory_dataset

if isinstance(source, (pa.RecordBatch, pa.Table)):
source = [source]

if isinstance(source, (list, tuple)):
batches = []
for item in source:
if isinstance(item, pa.RecordBatch):
batches.append(item)
elif isinstance(item, pa.Table):
batches.extend(item.to_batches())
else:
raise TypeError(
'Expected a list of tables or batches. The given list '
'contains a ' + type(item).__name__)
if schema is None:
schema = item.schema
elif not schema.equals(item.schema):
raise ArrowTypeError(
f'Item has schema\n{item.schema}\nwhich does not '
f'match expected schema\n{schema}')
if not batches and schema is None:
raise ValueError('Must provide schema to construct in-memory '
'dataset from an empty list')
table = pa.Table.from_batches(batches, schema=schema)
in_memory_dataset = make_shared[CInMemoryDataset](
pyarrow_unwrap_table(table))
elif isinstance(source, pa.ipc.RecordBatchReader):
reader = source
in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
elif _is_iterable(source):
if schema is None:
raise ValueError('Must provide schema to construct in-memory '
'dataset from an iterable')
reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
else:
raise TypeError(
'Expected a table, batch, iterable of tables/batches, or a '
'record batch reader instead of the given type: ' +
type(source).__name__
)

self.init(<shared_ptr[CDataset]> in_memory_dataset)

cdef void init(self, const shared_ptr[CDataset]& sp):
Dataset.init(self, sp)
self.in_memory_dataset = <CInMemoryDataset*> sp.get()
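
For orientation, a minimal sketch of the inputs this constructor accepts, using only names introduced in this diff:

    import pyarrow as pa
    import pyarrow.dataset as ds

    schema = pa.schema([("i32", pa.int32())])
    batch = pa.record_batch([pa.array([1, 2, 3], pa.int32())], schema=schema)
    table = pa.Table.from_batches([batch])

    ds.InMemoryDataset(table)                  # a single Table or RecordBatch
    ds.InMemoryDataset([batch, table])         # list mixing batches and tables
    ds.InMemoryDataset(iter([batch]), schema)  # iterable: schema is required
    ds.InMemoryDataset([], schema=schema)      # empty list: schema is required

Note the list/tuple path eagerly concatenates everything into a Table, while the reader and iterable paths stream and are therefore single-scan.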


cdef class UnionDataset(Dataset):
"""A Dataset wrapping child datasets.

@@ -2600,10 +2670,14 @@ def _get_partition_keys(Expression partition_expression):


def _filesystemdataset_write(
-        data not None, object base_dir not None, str basename_template not None,
-        Schema schema not None, FileSystem filesystem not None,
+        Dataset data not None,
+        object base_dir not None,
+        str basename_template not None,
+        Schema schema not None,
+        FileSystem filesystem not None,
         Partitioning partitioning not None,
-        FileWriteOptions file_options not None, bint use_threads,
+        FileWriteOptions file_options not None,
+        bint use_threads,
         int max_partitions,
):
"""
@@ -2621,21 +2695,7 @@
c_options.max_partitions = max_partitions
c_options.basename_template = tobytes(basename_template)

-    if isinstance(data, Dataset):
-        scanner = data._scanner(use_threads=use_threads)
-    else:
-        # data is list of batches/tables
-        for table in data:
-            if isinstance(table, Table):
-                for batch in table.to_batches():
-                    c_batches.push_back((<RecordBatch> batch).sp_batch)
-            else:
-                c_batches.push_back((<RecordBatch> table).sp_batch)
-
-        data = Fragment.wrap(shared_ptr[CFragment](
-            new CInMemoryFragment(move(c_batches), _true.unwrap())))
-
-        scanner = Scanner.from_fragment(data, schema, use_threads=use_threads)
+    scanner = data._scanner(use_threads=use_threads)

c_scanner = (<Scanner> scanner).unwrap()
with nogil:
52 changes: 39 additions & 13 deletions python/pyarrow/dataset.py
@@ -18,7 +18,7 @@
"""Dataset is currently unstable. APIs subject to change without notice."""

import pyarrow as pa
-from pyarrow.util import _stringify_path, _is_path_like
+from pyarrow.util import _is_iterable, _stringify_path, _is_path_like

from pyarrow._dataset import ( # noqa
CsvFileFormat,
@@ -37,6 +37,7 @@
HivePartitioning,
IpcFileFormat,
IpcFileWriteOptions,
InMemoryDataset,
ParquetDatasetFactory,
ParquetFactoryOptions,
ParquetFileFormat,
@@ -408,6 +409,13 @@ def _filesystem_dataset(source, schema=None, filesystem=None,
return factory.finish(schema)


def _in_memory_dataset(source, schema=None, **kwargs):
if any(v is not None for v in kwargs.values()):
raise ValueError(
"For in-memory datasets, you cannot pass any additional arguments")
return InMemoryDataset(source, schema)
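
This matches the _union_dataset convention below: in-memory sources take no further options. A hedged illustration of the rejection:

    import pyarrow as pa
    import pyarrow.dataset as ds

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
    ds.dataset(batch)                    # fine: wrapped in an InMemoryDataset
    ds.dataset(batch, format="parquet")  # ValueError: extra arguments rejected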


def _union_dataset(children, schema=None, **kwargs):
if any(v is not None for v in kwargs.values()):
raise ValueError(
@@ -508,7 +516,8 @@ def dataset(source, schema=None, format=None, filesystem=None,

Parameters
----------
-    source : path, list of paths, dataset, list of datasets or URI
+    source : path, list of paths, dataset, list of datasets, (list of) batches
+        or tables, iterable of batches, RecordBatchReader, or URI
Path pointing to a single file:
Open a FileSystemDataset from a single file.
Path pointing to a directory:
@@ -524,6 +533,11 @@
A nested UnionDataset gets constructed; it allows arbitrary
composition of other datasets.
Note that additional keyword arguments are not allowed.
(List of) batches or tables, iterable of batches, or RecordBatchReader:
Create an InMemoryDataset. If an iterable or empty list is given,
a schema must also be given. If an iterable or RecordBatchReader
is given, the resulting dataset can only be scanned once; further
attempts will raise an error.
schema : Schema, optional
Optionally provide the Schema for the Dataset, in which case it will
not be inferred from the source.
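
A short, hedged sketch of the new in-memory forms dataset() accepts after this change (file- and URI-based usage is unchanged):

    import pyarrow as pa
    import pyarrow.dataset as ds

    schema = pa.schema([("i32", pa.int32())])
    batch = pa.record_batch([pa.array([1, 2, 3], pa.int32())], schema=schema)

    ds.dataset(batch)                         # a single batch or table
    ds.dataset([batch, batch])                # a list of batches/tables
    ds.dataset(iter([batch]), schema=schema)  # iterable: schema is required
                                              # and it can be scanned only once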
@@ -636,21 +650,29 @@
selector_ignore_prefixes=ignore_prefixes
)

-    # TODO(kszucs): support InMemoryDataset for a table input
     if _is_path_like(source):
         return _filesystem_dataset(source, **kwargs)
     elif isinstance(source, (tuple, list)):
         if all(_is_path_like(elem) for elem in source):
             return _filesystem_dataset(source, **kwargs)
         elif all(isinstance(elem, Dataset) for elem in source):
             return _union_dataset(source, **kwargs)
+        elif all(isinstance(elem, (pa.RecordBatch, pa.Table))
+                 for elem in source):
+            return _in_memory_dataset(source, **kwargs)
         else:
             unique_types = set(type(elem).__name__ for elem in source)
             type_names = ', '.join('{}'.format(t) for t in unique_types)
             raise TypeError(
-                'Expected a list of path-like or dataset objects. The given '
-                'list contains the following types: {}'.format(type_names)
+                'Expected a list of path-like or dataset objects, or a list '
+                'of batches or tables. The given list contains the following '
+                'types: {}'.format(type_names)
             )
+    elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
+                             pa.Table)):
+        return _in_memory_dataset(source, **kwargs)
+    elif _is_iterable(source):
+        return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
             'Expected a path-like, list of path-likes or a list of Datasets '
@@ -676,9 +698,11 @@ def write_dataset(data, base_dir, basename_template=None, format=None,

Parameters
----------
-    data : Dataset, Table/RecordBatch, or list of Table/RecordBatch
+    data : Dataset, Table/RecordBatch, RecordBatchReader, list of
+        Table/RecordBatch, or iterable of RecordBatch
         The data to write. This can be a Dataset instance or
-        in-memory Arrow data.
+        in-memory Arrow data. If an iterable is given, the schema must
+        also be given.
also be given.
base_dir : str
The root directory where to write the dataset.
basename_template : str, optional
@@ -710,15 +734,17 @@

     if isinstance(data, Dataset):
         schema = schema or data.schema
-    elif isinstance(data, (pa.Table, pa.RecordBatch)):
-        schema = schema or data.schema
-        data = [data]
-    elif isinstance(data, list):
+    elif isinstance(data, (list, tuple)):
         schema = schema or data[0].schema
+        data = InMemoryDataset(data, schema=schema)
+    elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
+                           pa.Table)) or _is_iterable(data):
+        data = InMemoryDataset(data, schema=schema)
+        schema = schema or data.schema
     else:
         raise ValueError(
-            "Only Dataset, Table/RecordBatch or a list of Table/RecordBatch "
-            "objects are supported."
+            "Only Dataset, Table/RecordBatch, RecordBatchReader, a list "
+            "of Tables/RecordBatches, or iterable of batches are supported."
         )

if format is None and isinstance(data, FileSystemDataset):
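Taken together, the dataset.py changes enable streaming writes. A minimal sketch, assuming Parquet support is built and using an illustrative output path; a reader or iterable is drained as it is written, so it can be written only once:

    import pyarrow as pa
    import pyarrow.dataset as ds

    schema = pa.schema([("i32", pa.int32())])
    batch = pa.record_batch([pa.array([1, 2, 3], pa.int32())], schema=schema)
    reader = pa.ipc.RecordBatchReader.from_batches(schema, [batch])

    ds.write_dataset(reader, "/tmp/streamed_dataset", format="parquet")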
5 changes: 5 additions & 0 deletions python/pyarrow/includes/libarrow_dataset.pxd
@@ -124,6 +124,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:

CResult[shared_ptr[CScannerBuilder]] NewScan()

cdef cppclass CInMemoryDataset "arrow::dataset::InMemoryDataset"(
CDataset):
CInMemoryDataset(shared_ptr[CRecordBatchReader])
CInMemoryDataset(shared_ptr[CTable])

cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
CDataset):
@staticmethod