diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 70efe1d36af..ae3f355efb7 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" @@ -26,6 +27,8 @@ #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" +#include "arrow/util/task_group.h" +#include "arrow/util/thread_pool.h" namespace arrow { namespace dataset { @@ -135,6 +138,11 @@ std::shared_ptr SchemaFromColumnNames( return std::make_shared(columns); } +Status ScannerBuilder::UseThreads(bool use_threads) { + scan_options_->use_threads = use_threads; + return Status::OK(); +} + Status ScannerBuilder::Finish(std::unique_ptr* out) const { if (has_projection_ && !project_columns_.empty()) { scan_options_->projector = std::make_shared( @@ -151,19 +159,54 @@ Status ScannerBuilder::Finish(std::unique_ptr* out) const { return Status::OK(); } -Status Scanner::ToTable(std::shared_ptr* out) { +using arrow::internal::TaskGroup; + +std::shared_ptr Scanner::TaskGroup() const { + return options_->use_threads ? TaskGroup::MakeThreaded(context_->thread_pool) + : TaskGroup::MakeSerial(); +} + +struct TableAggregator { + void Append(std::shared_ptr batch) { + std::lock_guard lock(m); + batches.emplace_back(std::move(batch)); + } + + Status Finish(std::shared_ptr
* out) { + return Table::FromRecordBatches(batches, out); + } + + std::mutex m; std::vector> batches; +}; + +struct ScanTaskPromise { + Status operator()() { + for (auto maybe_batch : task->Scan()) { + ARROW_ASSIGN_OR_RAISE(auto batch, std::move(maybe_batch)); + aggregator.Append(std::move(batch)); + } + + return Status::OK(); + } + + TableAggregator& aggregator; + std::shared_ptr task; +}; + +Status Scanner::ToTable(std::shared_ptr
* out) { + auto task_group = TaskGroup(); + + TableAggregator aggregator; + for (auto maybe_scan_task : Scan()) { + ARROW_ASSIGN_OR_RAISE(auto scan_task, std::move(maybe_scan_task)); + task_group->Append(ScanTaskPromise{aggregator, std::move(scan_task)}); + } - auto it_scantasks = Scan(); - RETURN_NOT_OK(it_scantasks.Visit([&batches](std::unique_ptr task) -> Status { - auto it = task->Scan(); - return it.Visit([&batches](std::shared_ptr batch) { - batches.push_back(batch); - return Status::OK(); - }); - })); + // Wait for all tasks to complete, or the first error. + RETURN_NOT_OK(task_group->Finish()); - return Table::FromRecordBatches(batches, out); + return aggregator.Finish(out); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 3bd5e4fff6f..a2404ab0635 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -27,16 +27,22 @@ #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/memory_pool.h" +#include "arrow/util/thread_pool.h" namespace arrow { class Table; +namespace internal { +class TaskGroup; +}; + namespace dataset { /// \brief Shared state for a Scan operation struct ARROW_DS_EXPORT ScanContext { MemoryPool* pool = arrow::default_memory_pool(); + internal::ThreadPool* thread_pool = arrow::internal::GetCpuThreadPool(); }; class RecordBatchProjector; @@ -47,6 +53,10 @@ class ARROW_DS_EXPORT ScanOptions { static std::shared_ptr Defaults(); + // Indicate if the Scanner should make use of the ThreadPool found in the + // ScanContext. 
+ bool use_threads = false; + // Filter std::shared_ptr filter; // Evaluator for Filter @@ -109,18 +119,34 @@ Status ScanTaskIteratorFromRecordBatch(std::vector> /// yield scan_task class ARROW_DS_EXPORT Scanner { public: + Scanner(DataSourceVector sources, std::shared_ptr options, + std::shared_ptr context) + : sources_(std::move(sources)), + options_(std::move(options)), + context_(std::move(context)) {} + + virtual ~Scanner() = default; + /// \brief The Scan operator returns a stream of ScanTask. The caller is /// responsible to dispatch/schedule said tasks. Tasks should be safe to run /// in a concurrent fashion and outlive the iterator. virtual ScanTaskIterator Scan() = 0; - virtual ~Scanner() = default; - /// \brief Convert a Scanner into a Table. /// + /// \param[out] out output parameter + /// - /// Use this convenience utility with care. This will serially materialize the - /// Scan result in memory before creating the Table. + /// Use this convenience utility with care. This will materialize the whole Scan + /// result in memory (concurrently if use_threads is set) before creating the Table. Status ToTable(std::shared_ptr
* out); + + protected: + /// \brief Return a threaded (ScanContext pool) or serial TaskGroup per use_threads. + std::shared_ptr TaskGroup() const; + + DataSourceVector sources_; + std::shared_ptr options_; + std::shared_ptr context_; }; /// \brief SimpleScanner is a trivial Scanner implementation that flattens @@ -141,16 +167,9 @@ class ARROW_DS_EXPORT SimpleScanner : public Scanner { SimpleScanner(std::vector> sources, std::shared_ptr options, std::shared_ptr context) - : sources_(std::move(sources)), - options_(std::move(options)), - context_(std::move(context)) {} + : Scanner(std::move(sources), std::move(options), std::move(context)) {} ScanTaskIterator Scan() override; - - private: - std::vector> sources_; - std::shared_ptr options_; - std::shared_ptr context_; }; /// \brief ScannerBuilder is a factory class to construct a Scanner. It is used @@ -188,6 +207,10 @@ class ARROW_DS_EXPORT ScannerBuilder { Status Filter(std::shared_ptr filter); Status Filter(const Expression& filter); + /// \brief Indicate if the Scanner should make use of the available + /// ThreadPool found in ScanContext. + Status UseThreads(bool use_threads = true); + /// \brief Return the constructed now-immutable Scanner object Status Finish(std::unique_ptr* out) const; diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 67b150d8ba7..164491058e3 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -125,8 +125,14 @@ TEST_F(TestSimpleScanner, ToTable) { auto scanner = std::make_shared(sources, options_, ctx_); std::shared_ptr
actual; + ASSERT_OK(scanner->ToTable(&actual)); + AssertTablesEqual(*expected, *actual); + // There is no guarantee on the ordering when using multiple threads, but + // since the RecordBatch is always the same it will pass. + options_->use_threads = true; + ASSERT_OK(scanner->ToTable(&actual)); AssertTablesEqual(*expected, *actual); }