Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions cpp/src/arrow/dataset/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@

#include <algorithm>
#include <memory>
#include <mutex>

#include "arrow/dataset/dataset.h"
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/filter.h"
#include "arrow/dataset/scanner_internal.h"
#include "arrow/table.h"
#include "arrow/util/iterator.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"

namespace arrow {
namespace dataset {
Expand Down Expand Up @@ -135,6 +138,11 @@ std::shared_ptr<Schema> SchemaFromColumnNames(
return std::make_shared<Schema>(columns);
}

Status ScannerBuilder::UseThreads(bool use_threads) {
scan_options_->use_threads = use_threads;
return Status::OK();
}

Status ScannerBuilder::Finish(std::unique_ptr<Scanner>* out) const {
if (has_projection_ && !project_columns_.empty()) {
scan_options_->projector = std::make_shared<RecordBatchProjector>(
Expand All @@ -151,19 +159,54 @@ Status ScannerBuilder::Finish(std::unique_ptr<Scanner>* out) const {
return Status::OK();
}

Status Scanner::ToTable(std::shared_ptr<Table>* out) {
using arrow::internal::TaskGroup;

std::shared_ptr<TaskGroup> Scanner::TaskGroup() const {
return options_->use_threads ? TaskGroup::MakeThreaded(context_->thread_pool)
: TaskGroup::MakeSerial();
}

struct TableAggregator {
void Append(std::shared_ptr<RecordBatch> batch) {
std::lock_guard<std::mutex> lock(m);
batches.emplace_back(std::move(batch));
}

Status Finish(std::shared_ptr<Table>* out) {
return Table::FromRecordBatches(batches, out);
}

std::mutex m;
std::vector<std::shared_ptr<RecordBatch>> batches;
};

struct ScanTaskPromise {
Status operator()() {
for (auto maybe_batch : task->Scan()) {
ARROW_ASSIGN_OR_RAISE(auto batch, std::move(maybe_batch));
aggregator.Append(std::move(batch));
}

return Status::OK();
}

TableAggregator& aggregator;
std::shared_ptr<ScanTask> task;
};

Status Scanner::ToTable(std::shared_ptr<Table>* out) {
auto task_group = TaskGroup();

TableAggregator aggregator;
for (auto maybe_scan_task : Scan()) {
ARROW_ASSIGN_OR_RAISE(auto scan_task, std::move(maybe_scan_task));
task_group->Append(ScanTaskPromise{aggregator, std::move(scan_task)});
}

auto it_scantasks = Scan();
RETURN_NOT_OK(it_scantasks.Visit([&batches](std::unique_ptr<ScanTask> task) -> Status {
auto it = task->Scan();
return it.Visit([&batches](std::shared_ptr<RecordBatch> batch) {
batches.push_back(batch);
return Status::OK();
});
}));
// Wait for all tasks to complete, or the first error.
RETURN_NOT_OK(task_group->Finish());

return Table::FromRecordBatches(batches, out);
return aggregator.Finish(out);
}

} // namespace dataset
Expand Down
43 changes: 33 additions & 10 deletions cpp/src/arrow/dataset/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,22 @@
#include "arrow/dataset/type_fwd.h"
#include "arrow/dataset/visibility.h"
#include "arrow/memory_pool.h"
#include "arrow/util/thread_pool.h"

namespace arrow {

class Table;

namespace internal {
class TaskGroup;
};

namespace dataset {

/// \brief Shared state for a Scan operation
struct ARROW_DS_EXPORT ScanContext {
MemoryPool* pool = arrow::default_memory_pool();
internal::ThreadPool* thread_pool = arrow::internal::GetCpuThreadPool();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's time to expose a common ResourceContext class that has a MemoryPool and a ThreadPool?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a task group instead of a thread pool. Then users can pass a serial task group to signal single threaded operation

};

class RecordBatchProjector;
Expand All @@ -47,6 +53,10 @@ class ARROW_DS_EXPORT ScanOptions {

static std::shared_ptr<ScanOptions> Defaults();

// Indicate if the Scanner should make use of the ThreadPool found in the
// ScanContext.
bool use_threads = false;

// Filter
std::shared_ptr<Expression> filter;
// Evaluator for Filter
Expand Down Expand Up @@ -109,18 +119,34 @@ Status ScanTaskIteratorFromRecordBatch(std::vector<std::shared_ptr<RecordBatch>>
/// yield scan_task
class ARROW_DS_EXPORT Scanner {
public:
Scanner(DataSourceVector sources, std::shared_ptr<ScanOptions> options,
std::shared_ptr<ScanContext> context)
: sources_(std::move(sources)),
options_(std::move(options)),
context_(std::move(context)) {}

virtual ~Scanner() = default;

/// \brief The Scan operator returns a stream of ScanTask. The caller is
/// responsible to dispatch/schedule said tasks. Tasks should be safe to run
/// in a concurrent fashion and outlive the iterator.
virtual ScanTaskIterator Scan() = 0;

virtual ~Scanner() = default;

/// \brief Convert a Scanner into a Table.
///
/// \param[out] out output parameter
///
/// Use this convenience utility with care. This will serially materialize the
/// Scan result in memory before creating the Table.
Status ToTable(std::shared_ptr<Table>* out);

protected:
/// \brief Return a TaskGroup according to ScanContext thread rules.
std::shared_ptr<internal::TaskGroup> TaskGroup() const;

DataSourceVector sources_;
std::shared_ptr<ScanOptions> options_;
std::shared_ptr<ScanContext> context_;
};

/// \brief SimpleScanner is a trivial Scanner implementation that flattens
Expand All @@ -141,16 +167,9 @@ class ARROW_DS_EXPORT SimpleScanner : public Scanner {
SimpleScanner(std::vector<std::shared_ptr<DataSource>> sources,
std::shared_ptr<ScanOptions> options,
std::shared_ptr<ScanContext> context)
: sources_(std::move(sources)),
options_(std::move(options)),
context_(std::move(context)) {}
: Scanner(std::move(sources), std::move(options), std::move(context)) {}

ScanTaskIterator Scan() override;

private:
std::vector<std::shared_ptr<DataSource>> sources_;
std::shared_ptr<ScanOptions> options_;
std::shared_ptr<ScanContext> context_;
};

/// \brief ScannerBuilder is a factory class to construct a Scanner. It is used
Expand Down Expand Up @@ -188,6 +207,10 @@ class ARROW_DS_EXPORT ScannerBuilder {
Status Filter(std::shared_ptr<Expression> filter);
Status Filter(const Expression& filter);

/// \brief Indicate if the Scanner should make use of the available
/// ThreadPool found in ScanContext;
Status UseThreads(bool use_threads = true);

/// \brief Return the constructed now-immutable Scanner object
Status Finish(std::unique_ptr<Scanner>* out) const;

Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/dataset/scanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,14 @@ TEST_F(TestSimpleScanner, ToTable) {

auto scanner = std::make_shared<SimpleScanner>(sources, options_, ctx_);
std::shared_ptr<Table> actual;

ASSERT_OK(scanner->ToTable(&actual));
AssertTablesEqual(*expected, *actual);

// There is no guarantee on the ordering when using multiple threads, but
// since the RecordBatch is always the same it will pass.
options_->use_threads = true;
ASSERT_OK(scanner->ToTable(&actual));
AssertTablesEqual(*expected, *actual);
}

Expand Down