diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index 2df34145cd9..60d9bd73073 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -95,6 +95,55 @@ Result<ScanTaskIterator> InMemoryFragment::Scan(std::shared_ptr<ScanOptions> options)
   return MakeMapIterator(fn, std::move(batches_it));
 }
 
+Result<RecordBatchGenerator> InMemoryFragment::ScanBatchesAsync(
+    const std::shared_ptr<ScanOptions>& options) {
+  struct State {
+    State(std::shared_ptr<InMemoryFragment> fragment, int64_t batch_size)
+        : fragment(std::move(fragment)),
+          batch_index(0),
+          offset(0),
+          batch_size(batch_size) {}
+
+    std::shared_ptr<RecordBatch> Next() {
+      const auto& next_parent = fragment->record_batches_[batch_index];
+      if (offset < next_parent->num_rows()) {
+        auto next = next_parent->Slice(offset, batch_size);
+        offset += batch_size;
+        return next;
+      }
+      batch_index++;
+      offset = 0;
+      return nullptr;
+    }
+
+    bool Finished() { return batch_index >= fragment->record_batches_.size(); }
+
+    std::shared_ptr<InMemoryFragment> fragment;
+    std::size_t batch_index;
+    int64_t offset;
+    int64_t batch_size;
+  };
+
+  struct Generator {
+    Generator(std::shared_ptr<InMemoryFragment> fragment, int64_t batch_size)
+        : state(std::make_shared<State>(std::move(fragment), batch_size)) {}
+
+    Future<std::shared_ptr<RecordBatch>> operator()() {
+      while (!state->Finished()) {
+        auto next = state->Next();
+        if (next) {
+          return Future<std::shared_ptr<RecordBatch>>::MakeFinished(std::move(next));
+        }
+      }
+      return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+    }
+
+    std::shared_ptr<State> state;
+  };
+  return Generator(internal::checked_pointer_cast<InMemoryFragment>(shared_from_this()),
+                   options->batch_size);
+}
+
 Dataset::Dataset(std::shared_ptr<Schema> schema, Expression partition_expression)
     : schema_(std::move(schema)),
       partition_expression_(std::move(partition_expression)) {}
@@ -189,11 +238,11 @@ Result<FragmentIterator> InMemoryDataset::GetFragmentsImpl(Expression) {
                              " which did not match InMemorySource's: ", *schema);
     }
 
-    RecordBatchVector batches{batch};
-    return std::make_shared<InMemoryFragment>(std::move(batches));
+    return std::make_shared<InMemoryFragment>(RecordBatchVector{std::move(batch)});
   };
 
-  return MakeMaybeMapIterator(std::move(create_fragment), get_batches_->Get());
+  auto batches_it = get_batches_->Get();
+  return MakeMaybeMapIterator(std::move(create_fragment), std::move(batches_it));
 }
 
 Result<std::shared_ptr<UnionDataset>> UnionDataset::Make(std::shared_ptr<Schema> schema,
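Note on the contract implemented above (a sketch, not part of the patch): RecordBatchGenerator follows the AsyncGenerator convention from arrow/util/async_generator.h, where each call yields a future for the next item and a finished future holding a null batch marks the end of the stream. A minimal consumption sketch, assuming a std::shared_ptr<InMemoryFragment> named fragment (variable names here are illustrative):

    // Sketch only: drain the generator into a vector, blocking on the result.
    auto options = std::make_shared<ScanOptions>();
    ARROW_ASSIGN_OR_RAISE(auto gen, fragment->ScanBatchesAsync(options));
    // CollectAsyncGenerator pulls until the end token and returns
    // Future<std::vector<std::shared_ptr<RecordBatch>>>.
    ARROW_ASSIGN_OR_RAISE(auto batches, CollectAsyncGenerator(std::move(gen)).result());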
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 12c199dc210..c5c22d731fc 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -34,6 +34,8 @@ namespace arrow {
 
 namespace dataset {
 
+using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>()>;
+
 /// \brief A granular piece of a Dataset, such as an individual file.
 ///
 /// A Fragment can be read/scanned separately from other fragments. It yields a
@@ -64,6 +66,10 @@ class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this<Fragment> {
   /// To receive a record batch stream which is fully filtered and projected, use Scanner.
   virtual Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) = 0;
 
+  /// An asynchronous version of Scan
+  virtual Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) = 0;
+
   virtual std::string type_name() const = 0;
 
   virtual std::string ToString() const { return type_name(); }
@@ -113,6 +119,8 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment {
   explicit InMemoryFragment(RecordBatchVector record_batches, Expression = literal(true));
 
   Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override;
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override;
 
   std::string type_name() const override { return "in-memory"; }
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index c3b4433b6de..bf4e17da4b7 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -102,13 +102,70 @@ Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
       std::move(partition_expression), std::move(physical_schema)));
 }
 
+// TODO(ARROW-12355[CSV], ARROW-11772[IPC], ARROW-11843[Parquet]) The following
+// implementation of ScanBatchesAsync is both ugly and terribly inefficient. Each of
+// the formats should provide its own efficient implementation.
+Result<RecordBatchGenerator> FileFormat::ScanBatchesAsync(
+    const std::shared_ptr<ScanOptions>& scan_options,
+    const std::shared_ptr<FileFragment>& file) {
+  ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanFile(scan_options, file));
+  struct State {
+    State(std::shared_ptr<ScanOptions> scan_options, ScanTaskIterator scan_task_it)
+        : scan_options(std::move(scan_options)),
+          scan_task_it(std::move(scan_task_it)),
+          current_rb_it(),
+          finished(false) {}
+
+    std::shared_ptr<ScanOptions> scan_options;
+    ScanTaskIterator scan_task_it;
+    RecordBatchIterator current_rb_it;
+    bool finished;
+  };
+  struct Generator {
+    Future<std::shared_ptr<RecordBatch>> operator()() {
+      while (!state->finished) {
+        if (!state->current_rb_it) {
+          RETURN_NOT_OK(PumpScanTask());
+          if (state->finished) {
+            return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+          }
+        }
+        ARROW_ASSIGN_OR_RAISE(auto next_batch, state->current_rb_it.Next());
+        if (IsIterationEnd(next_batch)) {
+          state->current_rb_it = RecordBatchIterator();
+        } else {
+          return Future<std::shared_ptr<RecordBatch>>::MakeFinished(next_batch);
+        }
+      }
+      return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>();
+    }
+    Status PumpScanTask() {
+      ARROW_ASSIGN_OR_RAISE(auto next_task, state->scan_task_it.Next());
+      if (IsIterationEnd(next_task)) {
+        state->finished = true;
+      } else {
+        ARROW_ASSIGN_OR_RAISE(state->current_rb_it, next_task->Execute());
+      }
+      return Status::OK();
+    }
+    std::shared_ptr<State> state;
+  };
+  return Generator{std::make_shared<State>(scan_options, std::move(scan_task_it))};
+}
+
 Result<std::shared_ptr<Schema>> FileFragment::ReadPhysicalSchemaImpl() {
   return format_->Inspect(source_);
 }
 
 Result<ScanTaskIterator> FileFragment::Scan(std::shared_ptr<ScanOptions> options) {
   auto self = std::dynamic_pointer_cast<FileFragment>(shared_from_this());
-  return format_->ScanFile(std::move(options), self);
+  return format_->ScanFile(options, self);
+}
+
+Result<RecordBatchGenerator> FileFragment::ScanBatchesAsync(
+    const std::shared_ptr<ScanOptions>& options) {
+  auto self = std::dynamic_pointer_cast<FileFragment>(shared_from_this());
+  return format_->ScanBatchesAsync(options, self);
 }
 
 struct FileSystemDataset::FragmentSubtrees {
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index c4c70d65d2f..08359881a20 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -149,9 +149,13 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
   /// \brief Open a file for scanning
   virtual Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& file) const = 0;
 
+  virtual Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options,
+      const std::shared_ptr<FileFragment>& file);
+
   /// \brief Open a fragment
   virtual Result<std::shared_ptr<FileFragment>> MakeFragment(
       FileSource source, Expression partition_expression,
@@ -178,6 +182,8 @@ class ARROW_DS_EXPORT FileFragment : public Fragment {
   Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override;
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override;
 
   std::string type_name() const override { return format_->type_name(); }
   std::string ToString() const override { return source_.path(); };
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index 8ba6505524c..a8274a545c4 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -191,11 +191,10 @@ Result<std::shared_ptr<Schema>> CsvFileFormat::Inspect(const FileSource& source) {
 }
 
 Result<ScanTaskIterator> CsvFileFormat::ScanFile(
-    std::shared_ptr<ScanOptions> options,
+    const std::shared_ptr<ScanOptions>& options,
     const std::shared_ptr<FileFragment>& fragment) const {
   auto this_ = checked_pointer_cast<CsvFileFormat>(shared_from_this());
-  auto task =
-      std::make_shared<CsvScanTask>(std::move(this_), std::move(options), fragment);
+  auto task = std::make_shared<CsvScanTask>(std::move(this_), options, fragment);
   return MakeVectorIterator<std::shared_ptr<ScanTask>>({std::move(task)});
 }
diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h
index 7232f37658c..9289c016afb 100644
--- a/cpp/src/arrow/dataset/file_csv.h
+++ b/cpp/src/arrow/dataset/file_csv.h
@@ -54,7 +54,7 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
 
   /// \brief Open a file for scanning
   Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& fragment) const override;
 
   Result<std::shared_ptr<FileWriter>> MakeWriter(
diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc
index 24ea6e36ff2..49893cde6d9 100644
--- a/cpp/src/arrow/dataset/file_ipc.cc
+++ b/cpp/src/arrow/dataset/file_ipc.cc
@@ -168,9 +168,9 @@ Result<std::shared_ptr<Schema>> IpcFileFormat::Inspect(const FileSource& source) {
 }
 
 Result<ScanTaskIterator> IpcFileFormat::ScanFile(
-    std::shared_ptr<ScanOptions> options,
+    const std::shared_ptr<ScanOptions>& options,
     const std::shared_ptr<FileFragment>& fragment) const {
-  return IpcScanTaskIterator::Make(std::move(options), std::move(fragment));
+  return IpcScanTaskIterator::Make(options, fragment);
 }
 
 //
diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h
index aa3444eefa4..2c65078c754 100644
--- a/cpp/src/arrow/dataset/file_ipc.h
+++ b/cpp/src/arrow/dataset/file_ipc.h
@@ -53,7 +53,7 @@ class ARROW_DS_EXPORT IpcFileFormat : public FileFormat {
 
   /// \brief Open a file for scanning
   Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& fragment) const override;
 
   Result<std::shared_ptr<FileWriter>> MakeWriter(
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 8caae949784..497e4128fdf 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -326,7 +326,7 @@ Result<std::unique_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader
 }
 
 Result<ScanTaskIterator> ParquetFileFormat::ScanFile(
-    std::shared_ptr<ScanOptions> options,
+    const std::shared_ptr<ScanOptions>& options,
     const std::shared_ptr<FileFragment>& fragment) const {
   auto* parquet_fragment = checked_cast<ParquetFileFragment*>(fragment.get());
   std::vector<int> row_groups;
diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h
index 734917e6384..790e89c24c2 100644
--- a/cpp/src/arrow/dataset/file_parquet.h
+++ b/cpp/src/arrow/dataset/file_parquet.h
@@ -96,7 +96,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat {
 
   /// \brief Open a file for scanning
   Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
      const std::shared_ptr<FileFragment>& file) const override;
 
   using FileFormat::MakeFragment;
 
diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc
index c7ce5154d0a..dbddb5b385b 100644
--- a/cpp/src/arrow/dataset/file_test.cc
+++ b/cpp/src/arrow/dataset/file_test.cc
@@ -30,6 +30,7 @@
 #include "arrow/filesystem/path_util.h"
 #include "arrow/filesystem/test_util.h"
 #include "arrow/status.h"
+#include "arrow/testing/future_util.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/io_util.h"
 
@@ -82,6 +83,51 @@ TEST(FileSource, BufferBased) {
   ASSERT_EQ(source1.buffer(), source3.buffer());
 }
 
+constexpr int kNumScanTasks = 2;
+constexpr int kBatchesPerScanTask = 2;
+constexpr int kRowsPerBatch = 1024;
+class MockFileFormat : public FileFormat {
+  virtual std::string type_name() const { return "mock"; }
+  virtual bool Equals(const FileFormat& other) const { return false; }
+  virtual Result<bool> IsSupported(const FileSource& source) const { return true; }
+  virtual Result<std::shared_ptr<Schema>> Inspect(const FileSource& source) const {
+    return Status::NotImplemented("Not needed for test");
+  }
+  virtual Result<std::shared_ptr<FileWriter>> MakeWriter(
+      std::shared_ptr<io::OutputStream> destination, std::shared_ptr<Schema> schema,
+      std::shared_ptr<FileWriteOptions> options) const {
+    return Status::NotImplemented("Not needed for test");
+  }
+  virtual std::shared_ptr<FileWriteOptions> DefaultWriteOptions() { return nullptr; }
+
+  virtual Result<ScanTaskIterator> ScanFile(
+      const std::shared_ptr<ScanOptions>& options,
+      const std::shared_ptr<FileFragment>& file) const {
+    auto sch = schema({field("i32", int32())});
+    ScanTaskVector scan_tasks;
+    for (int i = 0; i < kNumScanTasks; i++) {
+      RecordBatchVector batches;
+      for (int j = 0; j < kBatchesPerScanTask; j++) {
+        batches.push_back(ConstantArrayGenerator::Zeroes(kRowsPerBatch, sch));
+      }
+      scan_tasks.push_back(std::make_shared<InMemoryScanTask>(
+          std::move(batches), std::make_shared<ScanOptions>(), nullptr));
+    }
+    return MakeVectorIterator(std::move(scan_tasks));
+  }
+};
+
+TEST(FileFormat, ScanAsync) {
+  MockFileFormat format;
+  auto scan_options = std::make_shared<ScanOptions>();
+  ASSERT_OK_AND_ASSIGN(auto batch_gen, format.ScanBatchesAsync(scan_options, nullptr));
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto batches, CollectAsyncGenerator(batch_gen));
+  ASSERT_EQ(kNumScanTasks * kBatchesPerScanTask, static_cast<int>(batches.size()));
+  for (int i = 0; i < kNumScanTasks * kBatchesPerScanTask; i++) {
+    ASSERT_EQ(kRowsPerBatch, batches[i]->num_rows());
+  }
+}
+
 TEST_F(TestFileSystemDataset, Basic) {
   MakeDataset({});
   AssertFragmentsAreFromPath(*dataset_->GetFragments(), {});
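A property of the fallback FileFormat::ScanBatchesAsync exercised by this test (and the reason the TODO above asks each format for a real implementation): every future it returns is created with MakeFinished, so the generator is asynchronous in interface only -- each pull runs the wrapped ScanTask synchronously on the calling thread. A sketch of what a caller would observe, reusing batch_gen from the test above:

    // Sketch only: the returned future is already completed when the call returns.
    auto fut = batch_gen();
    ASSERT_TRUE(fut.is_finished());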
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index f7bd3c063e5..43c024768ea 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -40,6 +40,8 @@ namespace arrow {
 
 namespace dataset {
 
+using FragmentGenerator = std::function<Future<std::shared_ptr<Fragment>>()>;
+
 std::vector<FieldRef> ScanOptions::MaterializedFields() const {
   std::vector<FieldRef> fields;
 
@@ -242,6 +244,31 @@ struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState> {
   size_t pop_cursor = 0;
 };
 
+class ARROW_DS_EXPORT SyncScanner : public Scanner {
+ public:
+  SyncScanner(std::shared_ptr<Dataset> dataset, std::shared_ptr<ScanOptions> scan_options)
+      : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {}
+
+  SyncScanner(std::shared_ptr<Fragment> fragment,
+              std::shared_ptr<ScanOptions> scan_options)
+      : Scanner(std::move(scan_options)), fragment_(std::move(fragment)) {}
+
+  Result<TaggedRecordBatchIterator> ScanBatches() override;
+  Result<ScanTaskIterator> Scan() override;
+  Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
+  Result<std::shared_ptr<Table>> ToTable() override;
+
+ protected:
+  /// \brief GetFragments returns an iterator over all Fragments in this scan.
+  Result<FragmentIterator> GetFragments();
+  Future<std::shared_ptr<Table>> ToTableInternal(internal::Executor* cpu_executor);
+  Result<ScanTaskIterator> ScanInternal();
+
+  std::shared_ptr<Dataset> dataset_;
+  // TODO(ARROW-8065) remove fragment_ after a Dataset is constructible from fragments
+  std::shared_ptr<Fragment> fragment_;
+};
+
 Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
   ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
   auto task_group = scan_options_->TaskGroup();
@@ -311,6 +338,269 @@ Result<ScanTaskIterator> ScanTaskIteratorFromRecordBatch(
   return fragment->Scan(std::move(options));
 }
 
+class ARROW_DS_EXPORT AsyncScanner : public Scanner,
+                                     public std::enable_shared_from_this<AsyncScanner> {
+ public:
+  AsyncScanner(std::shared_ptr<Dataset> dataset,
+               std::shared_ptr<ScanOptions> scan_options)
+      : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {}
+
+  Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
+  Result<TaggedRecordBatchIterator> ScanBatches() override;
+  Result<EnumeratedRecordBatchIterator> ScanBatchesUnordered() override;
+  Result<std::shared_ptr<Table>> ToTable() override;
+
+ private:
+  Result<TaggedRecordBatchGenerator> ScanBatchesAsync(internal::Executor* executor);
+  Future<> VisitBatchesAsync(std::function<Status(TaggedRecordBatch)> visitor,
+                             internal::Executor* executor);
+  Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsync(
+      internal::Executor* executor);
+  Future<std::shared_ptr<Table>> ToTableAsync(internal::Executor* executor);
+
+  Result<FragmentGenerator> GetFragments() const;
+
+  std::shared_ptr<Dataset> dataset_;
+};
+
+namespace {
+
+inline Result<EnumeratedRecordBatch> DoFilterAndProjectRecordBatchAsync(
+    const std::shared_ptr<AsyncScanner>& scanner, const EnumeratedRecordBatch& in) {
+  ARROW_ASSIGN_OR_RAISE(Expression simplified_filter,
+                        SimplifyWithGuarantee(scanner->options()->filter,
+                                              in.fragment.value->partition_expression()));
+
+  compute::ExecContext exec_context{scanner->options()->pool};
+  ARROW_ASSIGN_OR_RAISE(
+      Datum mask, ExecuteScalarExpression(simplified_filter, Datum(in.record_batch.value),
+                                          &exec_context));
+
+  Datum filtered;
+  if (mask.is_scalar()) {
+    const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+    if (mask_scalar.is_valid && mask_scalar.value) {
+      // filter matches entire table
+      filtered = in.record_batch.value;
+    } else {
+      // Filter matches nothing
+      filtered = in.record_batch.value->Slice(0, 0);
+    }
+  } else {
+    ARROW_ASSIGN_OR_RAISE(
+        filtered, compute::Filter(in.record_batch.value, mask,
+                                  compute::FilterOptions::Defaults(), &exec_context));
+  }
+
+  ARROW_ASSIGN_OR_RAISE(Expression simplified_projection,
+                        SimplifyWithGuarantee(scanner->options()->projection,
+                                              in.fragment.value->partition_expression()));
+  ARROW_ASSIGN_OR_RAISE(
+      Datum projected,
+      ExecuteScalarExpression(simplified_projection, filtered, &exec_context));
+
+  DCHECK_EQ(projected.type()->id(), Type::STRUCT);
+  if (projected.shape() == ValueDescr::SCALAR) {
+    // Only virtual columns are projected. Broadcast to an array
+    ARROW_ASSIGN_OR_RAISE(
+        projected,
+        MakeArrayFromScalar(*projected.scalar(), filtered.record_batch()->num_rows(),
+                            scanner->options()->pool));
+  }
+  ARROW_ASSIGN_OR_RAISE(auto out,
+                        RecordBatch::FromStructArray(projected.array_as<StructArray>()));
+  auto projected_batch =
+      out->ReplaceSchemaMetadata(in.record_batch.value->schema()->metadata());
+
+  return EnumeratedRecordBatch{
+      {std::move(projected_batch), in.record_batch.index, in.record_batch.last},
+      in.fragment};
+}
+
+inline EnumeratedRecordBatchGenerator FilterAndProjectRecordBatchAsync(
+    const std::shared_ptr<AsyncScanner>& scanner, EnumeratedRecordBatchGenerator rbs) {
+  auto mapper = [scanner](const EnumeratedRecordBatch& in) {
+    return DoFilterAndProjectRecordBatchAsync(scanner, in);
+  };
+  return MakeMappedGenerator<EnumeratedRecordBatch>(std::move(rbs), mapper);
+}
+
+Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
+    std::shared_ptr<AsyncScanner> scanner,
+    const Enumerated<std::shared_ptr<Fragment>>& fragment) {
+  ARROW_ASSIGN_OR_RAISE(auto batch_gen,
+                        fragment.value->ScanBatchesAsync(scanner->options()));
+  auto enumerated_batch_gen = MakeEnumeratedGenerator(std::move(batch_gen));
+
+  auto combine_fn =
+      [fragment](const Enumerated<std::shared_ptr<RecordBatch>>& record_batch) {
+        return EnumeratedRecordBatch{record_batch, fragment};
+      };
+
+  auto combined_gen = MakeMappedGenerator<EnumeratedRecordBatch>(enumerated_batch_gen,
+                                                                 std::move(combine_fn));
+
+  return FilterAndProjectRecordBatchAsync(scanner, std::move(combined_gen));
+}
+
+Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
+    std::shared_ptr<AsyncScanner> scanner, FragmentGenerator fragment_gen) {
+  auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
+  return MakeMappedGenerator<EnumeratedRecordBatchGenerator>(
+      std::move(enumerated_fragment_gen),
+      [scanner](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
+        return FragmentToBatches(scanner, fragment);
+      });
+}
+
+}  // namespace
+
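The helpers in this anonymous namespace compose into a generator of per-fragment generators; the flattening into a single stream happens in ScanBatchesUnorderedAsync below. The shape of the pipeline, as an annotation rather than code from the patch:

    // FragmentGenerator: one future per fragment, in discovery order
    //   -> MakeEnumeratedGenerator: tag each fragment with {index, last}
    //   -> MakeMappedGenerator<EnumeratedRecordBatchGenerator> (FragmentToBatches):
    //      a filtered and projected batch generator per fragment
    //   -> MakeConcatenatedGenerator: flatten into one unordered
    //      EnumeratedRecordBatchGenerator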
+Result<FragmentGenerator> AsyncScanner::GetFragments() const {
+  // TODO(ARROW-8163): Async fragment scanning will return AsyncGenerator<Fragment>
+  // here. Current iterator based versions are all fast & sync so we will just ToVector
+  // it
+  ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter));
+  ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
+  return MakeVectorGenerator(std::move(fragments_vec));
+}
+
+Result<TaggedRecordBatchIterator> AsyncScanner::ScanBatches() {
+  ARROW_ASSIGN_OR_RAISE(auto batches_gen, ScanBatchesAsync(internal::GetCpuThreadPool()));
+  return MakeGeneratorIterator(std::move(batches_gen));
+}
+
+Result<EnumeratedRecordBatchIterator> AsyncScanner::ScanBatchesUnordered() {
+  ARROW_ASSIGN_OR_RAISE(auto batches_gen,
+                        ScanBatchesUnorderedAsync(internal::GetCpuThreadPool()));
+  return MakeGeneratorIterator(std::move(batches_gen));
+}
+
+Result<std::shared_ptr<Table>> AsyncScanner::ToTable() {
+  auto table_fut = ToTableAsync(internal::GetCpuThreadPool());
+  return table_fut.result();
+}
+
+Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
+    internal::Executor* cpu_executor) {
+  auto self = shared_from_this();
+  ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments());
+  ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen,
+                        FragmentsToBatches(self, std::move(fragment_gen)));
+  return MakeConcatenatedGenerator(std::move(batch_gen_gen));
+}
+
+Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync(
+    internal::Executor* cpu_executor) {
+  ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync(cpu_executor));
+  auto left_after_right = [](const EnumeratedRecordBatch& left,
+                             const EnumeratedRecordBatch& right) {
+    // Before any comes first
+    if (left.fragment.value == nullptr) {
+      return false;
+    }
+    if (right.fragment.value == nullptr) {
+      return true;
+    }
+    // Compare batches if fragment is the same
+    if (left.fragment.index == right.fragment.index) {
+      return left.record_batch.index > right.record_batch.index;
+    }
+    // Otherwise compare fragment
+    return left.fragment.index > right.fragment.index;
+  };
+  auto is_next = [](const EnumeratedRecordBatch& prev,
+                    const EnumeratedRecordBatch& next) {
+    // Only true if next is the first batch
+    if (prev.fragment.value == nullptr) {
+      return next.fragment.index == 0 && next.record_batch.index == 0;
+    }
+    // If same fragment, compare batch index
+    if (prev.fragment.index == next.fragment.index) {
+      return next.record_batch.index == prev.record_batch.index + 1;
+    }
+    // Else only if next first batch of next fragment and prev is last batch of previous
+    return next.fragment.index == prev.fragment.index + 1 && prev.record_batch.last &&
+           next.record_batch.index == 0;
+  };
+  auto before_any = EnumeratedRecordBatch{{nullptr, -1, false}, {nullptr, -1, false}};
+  auto sequenced = MakeSequencingGenerator(std::move(unordered), left_after_right,
+                                           is_next, before_any);
+
+  auto unenumerate_fn = [](const EnumeratedRecordBatch& enumerated_batch) {
+    return TaggedRecordBatch{enumerated_batch.record_batch.value,
+                             enumerated_batch.fragment.value};
+  };
+  return MakeMappedGenerator<TaggedRecordBatch>(std::move(sequenced), unenumerate_fn);
+}
+
+struct AsyncTableAssemblyState {
+  /// Protecting mutating accesses to batches
+  std::mutex mutex{};
+  std::vector<RecordBatchVector> batches{};
+
+  void Emplace(const EnumeratedRecordBatch& batch) {
+    std::lock_guard<std::mutex> lock(mutex);
+    auto fragment_index = batch.fragment.index;
+    auto batch_index = batch.record_batch.index;
+    if (static_cast<int>(batches.size()) <= fragment_index) {
+      batches.resize(fragment_index + 1);
+    }
+    if (static_cast<int>(batches[fragment_index].size()) <= batch_index) {
+      batches[fragment_index].resize(batch_index + 1);
+    }
+    batches[fragment_index][batch_index] = batch.record_batch.value;
+  }
+
+  RecordBatchVector Finish() {
+    RecordBatchVector all_batches;
+    for (auto& fragment_batches : batches) {
+      auto end = std::make_move_iterator(fragment_batches.end());
+      for (auto it = std::make_move_iterator(fragment_batches.begin()); it != end; it++) {
+        all_batches.push_back(*it);
+      }
+    }
+    return all_batches;
+  }
+};
+
+Status AsyncScanner::Scan(std::function<Status(TaggedRecordBatch)> visitor) {
+  return internal::RunSynchronouslyVoid(
+      [this, &visitor](Executor* executor) {
+        return VisitBatchesAsync(visitor, executor);
+      },
+      scan_options_->use_threads);
+}
+
+Future<> AsyncScanner::VisitBatchesAsync(std::function<Status(TaggedRecordBatch)> visitor,
+                                         internal::Executor* executor) {
+  ARROW_ASSIGN_OR_RAISE(auto batches_gen, ScanBatchesAsync(executor));
+  return VisitAsyncGenerator(std::move(batches_gen), visitor);
+}
+
+Future<std::shared_ptr<Table>> AsyncScanner::ToTableAsync(
+    internal::Executor* cpu_executor) {
+  auto scan_options = scan_options_;
+  ARROW_ASSIGN_OR_RAISE(auto positioned_batch_gen,
+                        ScanBatchesUnorderedAsync(cpu_executor));
+  /// Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
+  /// invalidate concurrently running tasks when Finish() early returns
+  /// and the mutex/batches fall out of scope.
+  auto state = std::make_shared<AsyncTableAssemblyState>();
+
+  auto table_building_task = [state](const EnumeratedRecordBatch& batch) {
+    state->Emplace(batch);
+    return batch;
+  };
+
+  auto table_building_gen = MakeMappedGenerator<EnumeratedRecordBatch>(
+      positioned_batch_gen, table_building_task);
+
+  return DiscardAllFromAsyncGenerator(table_building_gen)
+      .Then([state, scan_options](const detail::Empty&) {
+        return Table::FromRecordBatches(scan_options->projected_schema, state->Finish());
+      });
+}
+
 ScannerBuilder::ScannerBuilder(std::shared_ptr<Dataset> dataset)
     : ScannerBuilder(std::move(dataset), std::make_shared<ScanOptions>()) {}
@@ -359,6 +649,11 @@ Status ScannerBuilder::UseThreads(bool use_threads) {
   return Status::OK();
 }
 
+Status ScannerBuilder::UseAsync(bool use_async) {
+  scan_options_->use_async = use_async;
+  return Status::OK();
+}
+
 Status ScannerBuilder::BatchSize(int64_t batch_size) {
   if (batch_size <= 0) {
     return Status::Invalid("BatchSize must be greater than 0, got ", batch_size);
@@ -388,8 +683,7 @@ Result<std::shared_ptr<Scanner>> ScannerBuilder::Finish() {
     return std::make_shared<SyncScanner>(fragment_, scan_options_);
   }
   if (scan_options_->use_async) {
-    // TODO(ARROW-12289)
-    return Status::NotImplemented("The asynchronous scanner is not yet available");
+    return std::make_shared<AsyncScanner>(dataset_, scan_options_);
   } else {
     return std::make_shared<SyncScanner>(dataset_, scan_options_);
   }
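End-to-end, the opt-in added by this patch looks like the following sketch, where dataset stands for any std::shared_ptr<Dataset> (illustrative, not from the patch):

    ScannerBuilder builder(dataset);
    ARROW_RETURN_NOT_OK(builder.UseAsync(true));  // experimental; defaults to false
    ARROW_ASSIGN_OR_RAISE(auto scanner, builder.Finish());  // yields an AsyncScanner
    ARROW_ASSIGN_OR_RAISE(auto table, scanner->ToTable());  // drives ToTableAsync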
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 956fbbb2ee3..6315cf922d0 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -33,6 +33,7 @@
 #include "arrow/io/interfaces.h"
 #include "arrow/memory_pool.h"
 #include "arrow/type_fwd.h"
+#include "arrow/util/async_generator.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/type_fwd.h"
@@ -98,11 +99,6 @@ struct ARROW_DS_EXPORT ScanOptions {
   /// A pool from which materialized and scanned arrays will be allocated.
   MemoryPool* pool = arrow::default_memory_pool();
 
-  /// Executor on which to run any CPU tasks
-  ///
-  /// Note: Will be ignored if use_threads is set to false
-  internal::Executor* cpu_executor = internal::GetCpuThreadPool();
-
   /// IOContext for any IO tasks
   ///
   /// Note: The IOContext executor will be ignored if use_threads is set to false
@@ -166,13 +162,6 @@ class ARROW_DS_EXPORT ScanTask {
   std::shared_ptr<Fragment> fragment_;
 };
 
-template <typename T>
-struct Enumerated {
-  T value;
-  int index;
-  bool last;
-};
-
 /// \brief Combines a record batch with the fragment that the record batch originated
 /// from
 ///
@@ -305,34 +294,6 @@ class ARROW_DS_EXPORT Scanner {
   const std::shared_ptr<ScanOptions> scan_options_;
 };
 
-class ARROW_DS_EXPORT SyncScanner : public Scanner {
- public:
-  SyncScanner(std::shared_ptr<Dataset> dataset, std::shared_ptr<ScanOptions> scan_options)
-      : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {}
-
-  SyncScanner(std::shared_ptr<Fragment> fragment,
-              std::shared_ptr<ScanOptions> scan_options)
-      : Scanner(std::move(scan_options)), fragment_(std::move(fragment)) {}
-
-  Result<TaggedRecordBatchIterator> ScanBatches() override;
-
-  Result<ScanTaskIterator> Scan() override;
-
-  Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
-
-  Result<std::shared_ptr<Table>> ToTable() override;
-
- protected:
-  /// \brief GetFragments returns an iterator over all Fragments in this scan.
-  Result<FragmentIterator> GetFragments();
-  Future<std::shared_ptr<Table>> ToTableInternal(internal::Executor* cpu_executor);
-  Result<ScanTaskIterator> ScanInternal();
-
-  std::shared_ptr<Dataset> dataset_;
-  // TODO(ARROW-8065) remove fragment_ after a Dataset is constuctible from fragments
-  std::shared_ptr<Fragment> fragment_;
-};
-
 /// \brief ScannerBuilder is a factory class to construct a Scanner. It is used
 /// to pass information, notably a potential filter expression and a subset of
 /// columns to materialize.
@@ -386,6 +347,12 @@ class ARROW_DS_EXPORT ScannerBuilder {
   /// ThreadPool found in ScanOptions;
   Status UseThreads(bool use_threads = true);
 
+  /// \brief Indicate if the Scanner should run in experimental "async" mode
+  ///
+  /// This mode should have considerably better performance on high-latency or parallel
+  /// filesystems but is still experimental
+  Status UseAsync(bool use_async = true);
+
   /// \brief Set the maximum number of rows per RecordBatch.
   ///
   /// \param[in] batch_size the maximum number of rows.
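The ordered/unordered split above is the API consequence of the sequencing logic in AsyncScanner::ScanBatchesAsync: ScanBatchesUnordered exposes the Enumerated position data so callers that do not need delivery order can avoid the reordering cost. A sketch of inspecting those positions (illustrative only; logging assumes arrow/util/logging.h):

    ARROW_ASSIGN_OR_RAISE(auto it, scanner->ScanBatchesUnordered());
    ARROW_RETURN_NOT_OK(it.Visit([](EnumeratedRecordBatch batch) -> Status {
      // fragment.index / record_batch.index / record_batch.last identify the
      // batch's logical position; arrival order is unspecified in async mode.
      ARROW_LOG(INFO) << "fragment " << batch.fragment.index << " batch "
                      << batch.record_batch.index
                      << (batch.record_batch.last ? " (last)" : "");
      return Status::OK();
    }));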
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index b4e374a7795..552102b3eda 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -18,9 +18,11 @@
 #include "arrow/dataset/scanner.h"
 
 #include <memory>
+#include <utility>
 
 #include <gmock/gmock.h>
 
+#include "arrow/compute/api.h"
 #include "arrow/compute/api_scalar.h"
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/cast.h"
@@ -39,33 +41,56 @@ using testing::IsEmpty;
 
 namespace arrow {
 namespace dataset {
 
-constexpr int64_t kNumberChildDatasets = 2;
-constexpr int64_t kNumberBatches = 16;
-constexpr int64_t kBatchSize = 1024;
+struct TestScannerParams {
+  bool use_async;
+  bool use_threads;
+  int num_child_datasets;
+  int num_batches;
+  int items_per_batch;
+
+  static std::vector<TestScannerParams> Values() {
+    std::vector<TestScannerParams> values;
+    for (int sync = 0; sync < 2; sync++) {
+      for (int use_threads = 0; use_threads < 2; use_threads++) {
+        values.push_back(
+            {static_cast<bool>(sync), static_cast<bool>(use_threads), 1, 1, 1024});
+        values.push_back(
+            {static_cast<bool>(sync), static_cast<bool>(use_threads), 2, 16, 1024});
+      }
+    }
+    return values;
+  }
+};
 
-class TestScanner : public DatasetFixtureMixin,
-                    public ::testing::WithParamInterface<bool> {
- protected:
-  bool UseThreads() { return GetParam(); }
+std::ostream& operator<<(std::ostream& out, const TestScannerParams& params) {
+  out << (params.use_async ? "async-" : "sync-")
+      << (params.use_threads ? "threaded-" : "serial-") << params.num_child_datasets
+      << "d-" << params.num_batches << "b-" << params.items_per_batch << "i";
+  return out;
+}
 
+class TestScanner : public DatasetFixtureMixinWithParam<TestScannerParams> {
+ protected:
   std::shared_ptr<Scanner> MakeScanner(std::shared_ptr<RecordBatch> batch) {
-    std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
-                                                      batch};
+    std::vector<std::shared_ptr<RecordBatch>> batches{
+        static_cast<size_t>(GetParam().num_batches), batch};
 
-    DatasetVector children{static_cast<size_t>(kNumberChildDatasets),
+    DatasetVector children{static_cast<size_t>(GetParam().num_child_datasets),
                            std::make_shared<InMemoryDataset>(batch->schema(), batches)};
 
     EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children));
 
     ScannerBuilder builder(dataset, options_);
-    ARROW_EXPECT_OK(builder.UseThreads(UseThreads()));
+    ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
+    ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
     EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
     return scanner;
   }
 
   void AssertScannerEqualsRepetitionsOf(
       std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
-      const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
+      const int64_t total_batches = GetParam().num_child_datasets *
+                                    GetParam().num_batches) {
     auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
 
     // Verifies that the unified BatchReader is equivalent to flattening all the
@@ -75,7 +100,8 @@ class TestScanner : public DatasetFixtureMixin,
 
   void AssertScanBatchesEqualRepetitionsOf(
       std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
-      const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
+      const int64_t total_batches = GetParam().num_child_datasets *
+                                    GetParam().num_batches) {
     auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
 
     AssertScanBatchesEquals(expected.get(), scanner.get());
@@ -83,38 +109,40 @@ class TestScanner : public DatasetFixtureMixin,
 
   void AssertScanBatchesUnorderedEqualRepetitionsOf(
       std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
-      const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
+      const int64_t total_batches = GetParam().num_child_datasets *
+                                    GetParam().num_batches) {
     auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
 
-    AssertScanBatchesUnorderedEquals(expected.get(), scanner.get());
+    AssertScanBatchesUnorderedEquals(expected.get(), scanner.get(), 1);
   }
 };
 
 TEST_P(TestScanner, Scan) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
-  AssertScannerEqualsRepetitionsOf(MakeScanner(batch), batch);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+  AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch), batch);
 }
 
 TEST_P(TestScanner, ScanBatches) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
   AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch), batch);
 }
 
 TEST_P(TestScanner, ScanBatchesUnordered) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
   AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch), batch);
 }
 
 TEST_P(TestScanner, ScanWithCappedBatchSize) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
-  options_->batch_size = kBatchSize / 2;
-  auto expected = batch->Slice(kBatchSize / 2);
-  AssertScannerEqualsRepetitionsOf(MakeScanner(batch), expected,
-                                   kNumberChildDatasets * kNumberBatches * 2);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+  options_->batch_size = GetParam().items_per_batch / 2;
+  auto expected = batch->Slice(GetParam().items_per_batch / 2);
+  AssertScanBatchesEqualRepetitionsOf(
+      MakeScanner(batch), expected,
+      GetParam().num_child_datasets * GetParam().num_batches * 2);
 }
 
 TEST_P(TestScanner, FilteredScan) {
@@ -122,7 +150,8 @@ TEST_P(TestScanner, FilteredScan) {
 
   double value = 0.5;
   ASSERT_OK_AND_ASSIGN(auto f64,
-                       ArrayFromBuilderVisitor(float64(), kBatchSize, kBatchSize / 2,
+                       ArrayFromBuilderVisitor(float64(), GetParam().items_per_batch,
+                                               GetParam().items_per_batch / 2,
                                                [&](DoubleBuilder* builder) {
                                                  builder->UnsafeAppend(value);
                                                  builder->UnsafeAppend(-value);
@@ -134,47 +163,58 @@ TEST_P(TestScanner, FilteredScan) {
   auto batch = RecordBatch::Make(schema_, f64->length(), {f64});
 
   value = 0.5;
-  ASSERT_OK_AND_ASSIGN(
-      auto f64_filtered,
-      ArrayFromBuilderVisitor(float64(), kBatchSize / 2, [&](DoubleBuilder* builder) {
-        builder->UnsafeAppend(value);
-        value += 1.0;
-      }));
+  ASSERT_OK_AND_ASSIGN(auto f64_filtered,
+                       ArrayFromBuilderVisitor(float64(), GetParam().items_per_batch / 2,
+                                               [&](DoubleBuilder* builder) {
+                                                 builder->UnsafeAppend(value);
+                                                 value += 1.0;
+                                               }));
 
   auto filtered_batch =
       RecordBatch::Make(schema_, f64_filtered->length(), {f64_filtered});
 
-  AssertScannerEqualsRepetitionsOf(MakeScanner(batch), filtered_batch);
+  AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch), filtered_batch);
+}
+
+TEST_P(TestScanner, ProjectedScan) {
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  SetProjectedColumns({"i32"});
+  auto batch_in = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+  auto batch_out = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch,
+                                                  schema({field("i32", int32())}));
+  AssertScanBatchesUnorderedEqualRepetitionsOf(MakeScanner(batch_in), batch_out);
 }
 
 TEST_P(TestScanner, MaterializeMissingColumn) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch_missing_f64 =
-      ConstantArrayGenerator::Zeroes(kBatchSize, schema({field("i32", int32())}));
+  auto batch_missing_f64 = ConstantArrayGenerator::Zeroes(
+      GetParam().items_per_batch, schema({field("i32", int32())}));
 
   auto fragment_missing_f64 = std::make_shared<InMemoryFragment>(
-      RecordBatchVector{static_cast<size_t>(kNumberChildDatasets * kNumberBatches),
-                        batch_missing_f64},
+      RecordBatchVector{
+          static_cast<size_t>(GetParam().num_child_datasets * GetParam().num_batches),
+          batch_missing_f64},
       equal(field_ref("f64"), literal(2.5)));
 
-  ASSERT_OK_AND_ASSIGN(auto f64, ArrayFromBuilderVisitor(float64(), kBatchSize,
-                                                         [&](DoubleBuilder* builder) {
-                                                           builder->UnsafeAppend(2.5);
-                                                         }));
+  ASSERT_OK_AND_ASSIGN(auto f64,
+                       ArrayFromBuilderVisitor(
+                           float64(), GetParam().items_per_batch,
+                           [&](DoubleBuilder* builder) { builder->UnsafeAppend(2.5); }));
 
   auto batch_with_f64 =
       RecordBatch::Make(schema_, f64->length(), {batch_missing_f64->column(0), f64});
 
   ScannerBuilder builder{schema_, fragment_missing_f64, options_};
   ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
-  AssertScannerEqualsRepetitionsOf(scanner, batch_with_f64);
+  AssertScanBatchesEqualRepetitionsOf(scanner, batch_with_f64);
 }
 
 TEST_P(TestScanner, ToTable) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
-  std::vector<std::shared_ptr<RecordBatch>> batches{kNumberBatches * kNumberChildDatasets,
-                                                    batch};
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
+  std::vector<std::shared_ptr<RecordBatch>> batches{
+      static_cast<size_t>(GetParam().num_batches * GetParam().num_child_datasets),
+      batch};
 
   ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));
 
@@ -189,7 +229,7 @@ TEST_P(TestScanner, ToTable) {
 
 TEST_P(TestScanner, ScanWithVisitor) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
   auto scanner = MakeScanner(batch);
 
   ASSERT_OK(scanner->Scan([batch](TaggedRecordBatch scanned_batch) {
     AssertBatchesEqual(*batch, *scanned_batch.record_batch);
@@ -198,21 +238,24 @@ TEST_P(TestScanner, ScanWithVisitor) {
 }
 
 TEST_P(TestScanner, TakeIndices) {
+  auto batch_size = GetParam().items_per_batch;
+  auto num_batches = GetParam().num_batches;
+  auto num_datasets = GetParam().num_child_datasets;
   SetSchema({field("i32", int32()), field("f64", float64())});
   ArrayVector arrays(2);
-  ArrayFromVector<Int32Type>(internal::Iota<int32_t>(kBatchSize), &arrays[0]);
-  ArrayFromVector<DoubleType>(internal::Iota<double>(static_cast<double>(kBatchSize)),
+  ArrayFromVector<Int32Type>(internal::Iota<int32_t>(batch_size), &arrays[0]);
+  ArrayFromVector<DoubleType>(internal::Iota<double>(static_cast<double>(batch_size)),
                               &arrays[1]);
-  auto batch = RecordBatch::Make(schema_, kBatchSize, arrays);
+  auto batch = RecordBatch::Make(schema_, batch_size, arrays);
 
   auto scanner = MakeScanner(batch);
 
   std::shared_ptr<Array> indices;
   {
-    ArrayFromVector<Int64Type>(internal::Iota<int64_t>(kBatchSize), &indices);
+    ArrayFromVector<Int64Type>(internal::Iota<int64_t>(batch_size), &indices);
     ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
     ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches({batch}));
-    ASSERT_EQ(expected->num_rows(), kBatchSize);
+    ASSERT_EQ(expected->num_rows(), batch_size);
     AssertTablesEqual(*expected, *taken);
   }
   {
@@ -223,16 +266,16 @@ TEST_P(TestScanner, TakeIndices) {
     ASSERT_EQ(expected.table()->num_rows(), 4);
     AssertTablesEqual(*expected.table(), *taken);
   }
-  {
-    ArrayFromVector<Int64Type>({kBatchSize + 2, kBatchSize + 1}, &indices);
+  if (num_batches > 1) {
+    ArrayFromVector<Int64Type>({batch_size + 2, batch_size + 1}, &indices);
     ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
     ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
     ASSERT_OK_AND_ASSIGN(auto expected, compute::Take(table, *indices));
     ASSERT_EQ(expected.table()->num_rows(), 2);
     AssertTablesEqual(*expected.table(), *taken);
   }
-  {
-    ArrayFromVector<Int64Type>({1, 3, 5, 7, kBatchSize + 1, 2 * kBatchSize + 2},
+  if (num_batches > 1) {
+    ArrayFromVector<Int64Type>({1, 3, 5, 7, batch_size + 1, 2 * batch_size + 2},
                                &indices);
     ASSERT_OK_AND_ASSIGN(auto taken, scanner->TakeRows(*indices));
     ASSERT_OK_AND_ASSIGN(auto table, scanner->ToTable());
@@ -241,19 +284,23 @@ TEST_P(TestScanner, TakeIndices) {
     AssertTablesEqual(*expected.table(), *taken);
   }
   {
-    auto base = kNumberChildDatasets * kNumberBatches * kBatchSize;
+    auto base = num_datasets * num_batches * batch_size;
     ArrayFromVector<Int64Type>({base + 1}, &indices);
     EXPECT_RAISES_WITH_MESSAGE_THAT(
-        IndexError, ::testing::HasSubstr("Some indices were out of bounds: 32769"),
+        IndexError,
+        ::testing::HasSubstr("Some indices were out of bounds: " +
+                             std::to_string(base + 1)),
         scanner->TakeRows(*indices));
   }
   {
-    auto base = kNumberChildDatasets * kNumberBatches * kBatchSize;
+    auto base = num_datasets * num_batches * batch_size;
    ArrayFromVector<Int64Type>(
         {1, 2, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6}, &indices);
     EXPECT_RAISES_WITH_MESSAGE_THAT(
         IndexError,
-        ::testing::HasSubstr("Some indices were out of bounds: 32769, 32770, 32771, ..."),
+        ::testing::HasSubstr(
+            "Some indices were out of bounds: " + std::to_string(base + 1) + ", " +
+            std::to_string(base + 2) + ", " + std::to_string(base + 3) + ", ..."),
         scanner->TakeRows(*indices));
   }
 }
@@ -276,11 +323,11 @@ class FailingFragment : public InMemoryFragment {
 
 TEST_P(TestScanner, ScanBatchesFailure) {
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
   RecordBatchVector batches = {batch, batch, batch, batch};
   ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_);
-  ASSERT_OK(builder.UseThreads(UseThreads()));
+  ASSERT_OK(builder.UseThreads(GetParam().use_threads));
   ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
 
   ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
@@ -302,8 +349,11 @@ TEST_P(TestScanner, ScanBatchesFailure) {
 }
 
 TEST_P(TestScanner, Head) {
+  auto batch_size = GetParam().items_per_batch;
+  auto num_batches = GetParam().num_batches;
+  auto num_datasets = GetParam().num_child_datasets;
   SetSchema({field("i32", int32()), field("f64", float64())});
-  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto batch = ConstantArrayGenerator::Zeroes(batch_size, schema_);
   auto scanner = MakeScanner(batch);
 
   std::shared_ptr<Table> expected, actual;
@@ -313,30 +363,32 @@ TEST_P(TestScanner, Head) {
   AssertTablesEqual(*expected, *actual);
 
   ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch}));
-  ASSERT_OK_AND_ASSIGN(actual, scanner->Head(kBatchSize));
+  ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size));
   AssertTablesEqual(*expected, *actual);
 
   ASSERT_OK_AND_ASSIGN(expected, Table::FromRecordBatches(schema_, {batch->Slice(0, 1)}));
   ASSERT_OK_AND_ASSIGN(actual, scanner->Head(1));
   AssertTablesEqual(*expected, *actual);
 
-  ASSERT_OK_AND_ASSIGN(expected,
-                       Table::FromRecordBatches(schema_, {batch, batch->Slice(0, 1)}));
-  ASSERT_OK_AND_ASSIGN(actual, scanner->Head(kBatchSize + 1));
-  AssertTablesEqual(*expected, *actual);
+  if (num_batches > 1) {
+    ASSERT_OK_AND_ASSIGN(expected,
+                         Table::FromRecordBatches(schema_, {batch, batch->Slice(0, 1)}));
+    ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size + 1));
+    AssertTablesEqual(*expected, *actual);
+  }
 
   ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable());
-  ASSERT_OK_AND_ASSIGN(actual,
-                       scanner->Head(kBatchSize * kNumberBatches * kNumberChildDatasets));
+  ASSERT_OK_AND_ASSIGN(actual, scanner->Head(batch_size * num_batches * num_datasets));
   AssertTablesEqual(*expected, *actual);
 
   ASSERT_OK_AND_ASSIGN(expected, scanner->ToTable());
-  ASSERT_OK_AND_ASSIGN(
-      actual, scanner->Head(kBatchSize * kNumberBatches * kNumberChildDatasets + 100));
+  ASSERT_OK_AND_ASSIGN(actual,
+                       scanner->Head(batch_size * num_batches * num_datasets + 100));
   AssertTablesEqual(*expected, *actual);
 }
 
-INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner, ::testing::Bool());
+INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner,
+                         ::testing::ValuesIn(TestScannerParams::Values()));
 
 class TestScannerBuilder : public ::testing::Test {
   void SetUp() override {
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index 1d1266de671..b94441e178a 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -216,23 +216,32 @@ class DatasetFixtureMixin : public ::testing::Test {
   }
 
   /// \brief Ensure that record batches found in reader are equals to the
-  /// record batches yielded by a scanner. Each fragment in the scanner is
-  /// expected to have a single batch.
+  /// record batches yielded by a scanner.
   void AssertScanBatchesUnorderedEquals(RecordBatchReader* expected, Scanner* scanner,
+                                        int expected_batches_per_fragment,
                                         bool ensure_drained = true) {
     ASSERT_OK_AND_ASSIGN(auto it, scanner->ScanBatchesUnordered());
 
     int fragment_counter = 0;
     bool saw_last_fragment = false;
-    ARROW_EXPECT_OK(it.Visit([&](EnumeratedRecordBatch batch) -> Status {
-      EXPECT_EQ(0, batch.record_batch.index);
-      EXPECT_EQ(true, batch.record_batch.last);
-      EXPECT_EQ(fragment_counter++, batch.fragment.index);
-      EXPECT_FALSE(saw_last_fragment);
+    int batch_counter = 0;
+    auto visitor = [&](EnumeratedRecordBatch batch) -> Status {
+      if (batch_counter == 0) {
+        EXPECT_FALSE(saw_last_fragment);
+      }
+      EXPECT_EQ(batch_counter++, batch.record_batch.index);
+      auto last_batch = batch_counter == expected_batches_per_fragment;
+      EXPECT_EQ(last_batch, batch.record_batch.last);
+      EXPECT_EQ(fragment_counter, batch.fragment.index);
+      if (last_batch) {
+        fragment_counter++;
+        batch_counter = 0;
+      }
       saw_last_fragment = batch.fragment.last;
       AssertBatchEquals(expected, *batch.record_batch.value);
       return Status::OK();
-    }));
+    };
+    ARROW_EXPECT_OK(it.Visit(visitor));
 
     if (ensure_drained) {
       EnsureRecordBatchReaderDrained(expected);
@@ -265,10 +274,18 @@ class DatasetFixtureMixin : public ::testing::Test {
     ASSERT_OK_AND_ASSIGN(options_->filter, filter.Bind(*schema_));
   }
 
+  void SetProjectedColumns(std::vector<std::string> column_names) {
+    ASSERT_OK(SetProjection(options_.get(), std::move(column_names)));
+  }
+
   std::shared_ptr<Schema> schema_;
   std::shared_ptr<ScanOptions> options_;
 };
 
+template <typename P>
+class DatasetFixtureMixinWithParam : public DatasetFixtureMixin,
+                                     public ::testing::WithParamInterface<P> {};
+
 /// \brief A dummy FileFormat implementation
 class DummyFileFormat : public FileFormat {
  public:
@@ -290,7 +307,7 @@ class DummyFileFormat : public FileFormat {
 
   /// \brief Open a file for scanning (always returns an empty iterator)
   Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& fragment) const override {
     return MakeEmptyIterator<std::shared_ptr<ScanTask>>();
   }
@@ -330,7 +347,7 @@ class JSONRecordBatchFileFormat : public FileFormat {
 
   /// \brief Open a file for scanning
   Result<ScanTaskIterator> ScanFile(
-      std::shared_ptr<ScanOptions> options,
+      const std::shared_ptr<ScanOptions>& options,
      const std::shared_ptr<FileFragment>& fragment) const override {
     ARROW_ASSIGN_OR_RAISE(auto file, fragment->source().Open());
     ARROW_ASSIGN_OR_RAISE(int64_t size, file->GetSize());
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index f274478fd75..fd5d0d28e9d 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -110,7 +110,7 @@ Future<> VisitAsyncGenerator(AsyncGenerator<T> generator,
 /// \brief Waits for an async generator to complete, discarding results.
 template <typename T>
 Future<> DiscardAllFromAsyncGenerator(AsyncGenerator<T> generator) {
-  std::function<Status(T)> visitor = [](...) { return Status::OK(); };
+  std::function<Status(const T&)> visitor = [](const T&) { return Status::OK(); };
   return VisitAsyncGenerator(generator, visitor);
 }
 
@@ -280,6 +280,23 @@ AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator,
   return MappingGenerator<T, V>(std::move(source_generator), std::move(map));
 }
 
+template <typename V, typename T, typename MapFunc>
+AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator, MapFunc map) {
+  struct MapCallback {
+    MapFunc map;
+
+    Future<V> operator()(const T& val) { return EnsureFuture(map(val)); }
+
+    Future<V> EnsureFuture(Result<V> val) {
+      return Future<V>::MakeFinished(std::move(val));
+    }
+    Future<V> EnsureFuture(V val) { return Future<V>::MakeFinished(std::move(val)); }
+    Future<V> EnsureFuture(Future<V> val) { return val; }
+  };
+  std::function<Future<V>(const T&)> map_fn = MapCallback{map};
+  return MappingGenerator<T, V>(std::move(source_generator), map_fn);
+}
+
 /// \see MakeSequencingGenerator
 template <typename T>
 class SequencingGenerator {