132 changes: 79 additions & 53 deletions cpp/src/arrow/dataset/file_base.cc
@@ -360,6 +360,64 @@ class WriteQueue {
std::shared_ptr<Schema> schema_;
};

struct WriteState {
explicit WriteState(FileSystemDatasetWriteOptions write_options)
: write_options(std::move(write_options)) {}

FileSystemDatasetWriteOptions write_options;
util::Mutex mutex;
std::unordered_map<std::string, std::unique_ptr<WriteQueue>> queues;
};

Status WriteNextBatch(WriteState& state, const std::shared_ptr<ScanTask>& scan_task,
std::shared_ptr<RecordBatch> batch) {
ARROW_ASSIGN_OR_RAISE(auto groups, state.write_options.partitioning->Partition(batch));
batch.reset(); // drop to hopefully conserve memory

if (groups.batches.size() > static_cast<size_t>(state.write_options.max_partitions)) {
return Status::Invalid("Fragment would be written into ", groups.batches.size(),
" partitions. This exceeds the maximum of ",
state.write_options.max_partitions);
}

std::unordered_set<WriteQueue*> need_flushed;
for (size_t i = 0; i < groups.batches.size(); ++i) {
auto partition_expression = and_(std::move(groups.expressions[i]),
scan_task->fragment()->partition_expression());
auto batch = std::move(groups.batches[i]);

ARROW_ASSIGN_OR_RAISE(auto part,
state.write_options.partitioning->Format(partition_expression));

WriteQueue* queue;
{
// lookup the queue to which batch should be appended
auto queues_lock = state.mutex.Lock();

queue = internal::GetOrInsertGenerated(
&state.queues, std::move(part),
[&](const std::string& emplaced_part) {
// no queue exists for this partition yet;
// generate a new WriteQueue
size_t queue_index = state.queues.size() - 1;

return internal::make_unique<WriteQueue>(emplaced_part, queue_index,
batch->schema());
})
->second.get();
}

queue->Push(std::move(batch));
need_flushed.insert(queue);
}

// flush all touched WriteQueues
for (auto queue : need_flushed) {
RETURN_NOT_OK(queue->Flush(state.write_options));
}
return Status::OK();
}

} // namespace
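The locking in `WriteNextBatch` is a get-or-create lookup: the first batch for a partition creates that partition's `WriteQueue` under the mutex, and later batches reuse it. A minimal standard-library sketch of the same pattern (simplified stand-in types, not the Arrow helpers):

```cpp
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Simplified stand-in for a per-partition write queue.
struct Queue {
  std::string partition;
  size_t index;
};

class QueueRegistry {
 public:
  // Returns the queue for `partition`, creating it under the lock if absent.
  Queue* GetOrCreate(const std::string& partition) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = queues_.find(partition);
    if (it == queues_.end()) {
      auto queue = std::make_unique<Queue>();
      queue->partition = partition;
      queue->index = queues_.size();  // stable index assigned at creation
      it = queues_.emplace(partition, std::move(queue)).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::unique_ptr<Queue>> queues_;
};
```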

Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options,
@@ -382,6 +440,7 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options,
ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments());
ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector());
ScanTaskVector scan_tasks;
std::vector<Future<>> scan_futs;

for (const auto& fragment : fragments) {
auto options = std::make_shared<ScanOptions>(*scanner->options());
@@ -399,68 +458,35 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options,
// to a WriteQueue which flushes batches into that partition's output file. In principle
// any thread could produce a batch for any partition, so each task alternates between
// pushing batches and flushing them to disk.
util::Mutex queues_mutex;
std::unordered_map<std::string, std::unique_ptr<WriteQueue>> queues;
WriteState state(write_options);

for (const auto& scan_task : scan_tasks) {
task_group->Append([&, scan_task] {
ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute());

for (auto maybe_batch : batches) {
ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch));
batch.reset(); // drop to hopefully conserve memory

if (groups.batches.size() > static_cast<size_t>(write_options.max_partitions)) {
return Status::Invalid("Fragment would be written into ", groups.batches.size(),
" partitions. This exceeds the maximum of ",
write_options.max_partitions);
}

std::unordered_set<WriteQueue*> need_flushed;
for (size_t i = 0; i < groups.batches.size(); ++i) {
auto partition_expression = and_(std::move(groups.expressions[i]),
scan_task->fragment()->partition_expression());
auto batch = std::move(groups.batches[i]);

ARROW_ASSIGN_OR_RAISE(auto part,
write_options.partitioning->Format(partition_expression));

WriteQueue* queue;
{
// lookup the queue to which batch should be appended
auto queues_lock = queues_mutex.Lock();

queue = internal::GetOrInsertGenerated(
&queues, std::move(part),
[&](const std::string& emplaced_part) {
// lookup in `queues` also failed,
// generate a new WriteQueue
size_t queue_index = queues.size() - 1;

return internal::make_unique<WriteQueue>(
emplaced_part, queue_index, batch->schema());
})
->second.get();
}

queue->Push(std::move(batch));
need_flushed.insert(queue);
}
if (scan_task->supports_async()) {
ARROW_ASSIGN_OR_RAISE(auto batches_gen, scan_task->ExecuteAsync());
std::function<Status(std::shared_ptr<RecordBatch> batch)> batch_visitor =
[&, scan_task](std::shared_ptr<RecordBatch> batch) {
return WriteNextBatch(state, scan_task, std::move(batch));
};
scan_futs.push_back(VisitAsyncGenerator(batches_gen, batch_visitor));
} else {
task_group->Append([&, scan_task] {
ARROW_ASSIGN_OR_RAISE(auto batches, scan_task->Execute());

// flush all touched WriteQueues
for (auto queue : need_flushed) {
RETURN_NOT_OK(queue->Flush(write_options));
for (auto maybe_batch : batches) {
ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
RETURN_NOT_OK(WriteNextBatch(state, scan_task, std::move(batch)));
}
}

return Status::OK();
});
return Status::OK();
});
}
}
RETURN_NOT_OK(task_group->Finish());
auto scan_futs_all_done = AllComplete(scan_futs);
RETURN_NOT_OK(scan_futs_all_done.status());

task_group = scanner->options()->TaskGroup();
for (const auto& part_queue : queues) {
for (const auto& part_queue : state.queues) {
task_group->Append([&] { return part_queue.second->writer()->Finish(); });
}
return task_group->Finish();
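With this change `Write` drives two kinds of work at once: synchronous scan tasks still run on the scanner's task group, while async-capable tasks contribute futures that are gathered by `AllComplete`. A rough standard-library analogue of that dual wait (using `std::async` in place of Arrow's task group and futures):

```cpp
#include <future>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::future<void>> group_tasks;  // stands in for the TaskGroup
  std::vector<std::future<void>> scan_futs;    // stands in for scan_futs

  for (int i = 0; i < 4; ++i) {
    group_tasks.push_back(std::async(std::launch::async, [] {
      // body of a synchronous scan task
    }));
    scan_futs.push_back(std::async(std::launch::async, [] {
      // work an async-capable scan task has already scheduled
    }));
  }

  // Both families must finish before the writers can be finalized;
  // get() rethrows the first failure, mirroring RETURN_NOT_OK.
  for (auto& fut : group_tasks) fut.get();
  for (auto& fut : scan_futs) fut.get();
  std::cout << "all scan tasks done; safe to Finish() writers\n";
}
```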
2 changes: 1 addition & 1 deletion cpp/src/arrow/dataset/file_base.h
@@ -285,7 +285,7 @@ class ARROW_DS_EXPORT FileWriter {

Status Write(RecordBatchReader* batches);

Status Finish();
virtual Status Finish();

const std::shared_ptr<FileFormat>& format() const { return options_->format(); }
const std::shared_ptr<Schema>& schema() const { return schema_; }
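Making `Finish()` virtual lets tests substitute a writer that observes the finalize step; the new dataset write test relies on this with its discarding, row-counting writer. A hedged sketch of the idea (the class names here are illustrative, not the actual test classes):

```cpp
#include <atomic>
#include <cstdint>
#include <memory>

// Simplified stand-in for arrow::Status.
struct Status { static Status OK() { return {}; } };

// With Finish() virtual, test doubles can observe writer lifecycle events.
class FileWriterBase {
 public:
  virtual ~FileWriterBase() = default;
  virtual Status Finish() { return Status::OK(); }  // default behavior
};

// A hypothetical writer that discards data but counts rows, in the spirit of
// the DiscardingRowCountingFileWriteOptions used by the new test.
class RowCountingWriter : public FileWriterBase {
 public:
  explicit RowCountingWriter(std::shared_ptr<std::atomic<int64_t>> rows)
      : rows_(std::move(rows)) {}

  void Write(int64_t num_rows) { pending_ += num_rows; }

  Status Finish() override {
    *rows_ += pending_;  // publish counts only when the file is finalized
    pending_ = 0;
    return Status::OK();
  }

 private:
  std::shared_ptr<std::atomic<int64_t>> rows_;
  int64_t pending_ = 0;
};
```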
66 changes: 45 additions & 21 deletions cpp/src/arrow/dataset/file_csv.cc
@@ -34,6 +34,7 @@
#include "arrow/io/compressed.h"
#include "arrow/result.h"
#include "arrow/type.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"

@@ -42,6 +43,7 @@ namespace dataset {

using internal::checked_cast;
using internal::checked_pointer_cast;
using RecordBatchGenerator = AsyncGenerator<std::shared_ptr<RecordBatch>>;

Result<std::unordered_set<std::string>> GetColumnNames(
const csv::ParseOptions& parse_options, util::string_view first_block,
@@ -110,35 +112,47 @@ static inline Result<csv::ReadOptions> GetReadOptions(
return read_options;
}

static inline Result<std::shared_ptr<csv::StreamingReader>> OpenReader(
static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
const FileSource& source, const CsvFileFormat& format,
const std::shared_ptr<ScanOptions>& scan_options = nullptr,
MemoryPool* pool = default_memory_pool()) {
ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));

util::string_view first_block;
ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
ARROW_ASSIGN_OR_RAISE(
input, io::BufferedInputStream::Create(reader_options.block_size,
default_memory_pool(), std::move(input)));
ARROW_ASSIGN_OR_RAISE(first_block, input->Peek(reader_options.block_size));

const auto& parse_options = format.parse_options;
auto convert_options = csv::ConvertOptions::Defaults();
if (scan_options != nullptr) {
ARROW_ASSIGN_OR_RAISE(convert_options,
GetConvertOptions(format, scan_options, first_block, pool));
}

auto maybe_reader =
csv::StreamingReader::Make(io::IOContext(pool), std::move(input), reader_options,
parse_options, convert_options);
if (!maybe_reader.ok()) {
return maybe_reader.status().WithMessage("Could not open CSV input source '",
source.path(), "': ", maybe_reader.status());
}
auto peek_fut = DeferNotOk(input->io_context().executor()->Submit(
[input, reader_options] { return input->Peek(reader_options.block_size); }));

return peek_fut.Then([=](const util::string_view& first_block)
-> Future<std::shared_ptr<csv::StreamingReader>> {
const auto& parse_options = format.parse_options;
auto convert_options = csv::ConvertOptions::Defaults();
if (scan_options != nullptr) {
ARROW_ASSIGN_OR_RAISE(convert_options,
GetConvertOptions(format, scan_options, first_block, pool));
}

return csv::StreamingReader::MakeAsync(io::default_io_context(), std::move(input),
reader_options, parse_options, convert_options)
.Then(
[](const std::shared_ptr<csv::StreamingReader>& maybe_reader)
-> Result<std::shared_ptr<csv::StreamingReader>> { return maybe_reader; },
[source](const Status& err) -> Result<std::shared_ptr<csv::StreamingReader>> {
return err.WithMessage("Could not open CSV input source '", source.path(),
"': ", err);
});
});
}

return std::move(maybe_reader).ValueOrDie();
static inline Result<std::shared_ptr<csv::StreamingReader>> OpenReader(
const FileSource& source, const CsvFileFormat& format,
const std::shared_ptr<ScanOptions>& scan_options = nullptr,
MemoryPool* pool = default_memory_pool()) {
auto open_reader_fut = OpenReaderAsync(source, format, scan_options, pool);
return open_reader_fut.result();
}

/// \brief A ScanTask backed by a CSV file.
@@ -152,9 +166,19 @@ class CsvScanTask : public ScanTask {
source_(fragment->source()) {}

Result<RecordBatchIterator> Execute() override {
ARROW_ASSIGN_OR_RAISE(auto reader,
OpenReader(source_, *format_, options(), options()->pool));
return IteratorFromReader(std::move(reader));
ARROW_ASSIGN_OR_RAISE(auto gen, ExecuteAsync());
return MakeGeneratorIterator(std::move(gen));
}

bool supports_async() const override { return true; }

Result<RecordBatchGenerator> ExecuteAsync() override {
auto reader_fut = OpenReaderAsync(source_, *format_, options(), options()->pool);
auto generator_fut = reader_fut.Then(
[](const std::shared_ptr<csv::StreamingReader>& reader) -> RecordBatchGenerator {
return [reader]() { return reader->ReadNextAsync(); };
});
return MakeFromFuture(generator_fut);
}

private:
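The core move in `CsvScanTask::ExecuteAsync` is wrapping a shared reader in a lambda so each pull yields the next batch. A synchronous, standard-library analogue of that generator shape (simplified `int` batches in place of record batches):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <vector>

// A pull-based generator: each call yields the next item, nullopt when done.
template <typename T>
using Generator = std::function<std::optional<T>()>;

// Hypothetical blocking reader over a fixed set of "batches".
struct Reader {
  std::vector<int> batches{1, 2, 3};
  size_t next = 0;
  std::optional<int> ReadNext() {
    if (next == batches.size()) return std::nullopt;
    return batches[next++];
  }
};

// Wrap the reader in a generator, as CsvScanTask::ExecuteAsync wraps the
// StreamingReader: the lambda owns the reader and forwards each pull.
Generator<int> MakeGenerator(std::shared_ptr<Reader> reader) {
  return [reader]() { return reader->ReadNext(); };
}

int main() {
  auto gen = MakeGenerator(std::make_shared<Reader>());
  while (auto batch = gen()) std::cout << *batch << "\n";  // prints 1 2 3
}
```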
29 changes: 29 additions & 0 deletions cpp/src/arrow/dataset/file_test.cc
@@ -249,6 +249,35 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) {
});
}

class TestFilesystemDatasetNestedParallelism : public NestedParallelismMixin {};

TEST_F(TestFilesystemDatasetNestedParallelism, Write) {
constexpr int NUM_BATCHES = 32;
RecordBatchVector batches;
for (int i = 0; i < NUM_BATCHES; i++) {
batches.push_back(ConstantArrayGenerator::Zeroes(/*size=*/1, schema_));
}
auto dataset = std::make_shared<NestedParallelismDataset>(schema_, std::move(batches));
ScannerBuilder builder{dataset, options_};
ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());

ASSERT_OK_AND_ASSIGN(auto output_dir, TemporaryDir::Make("nested-parallel-dataset"));

auto format = std::make_shared<DiscardingRowCountingFormat>();
auto rows_written = std::make_shared<std::atomic<int>>(0);
std::shared_ptr<FileWriteOptions> file_write_options =
std::make_shared<DiscardingRowCountingFileWriteOptions>(rows_written);
FileSystemDatasetWriteOptions dataset_write_options;
dataset_write_options.file_write_options = file_write_options;
dataset_write_options.basename_template = "{i}";
dataset_write_options.partitioning = std::make_shared<HivePartitioning>(schema({}));
dataset_write_options.base_dir = output_dir->path().ToString();
dataset_write_options.filesystem = std::make_shared<fs::LocalFileSystem>();

ASSERT_OK(FileSystemDataset::Write(dataset_write_options, scanner));
ASSERT_EQ(NUM_BATCHES, rows_written->load());
}

// Tests of subtree pruning

struct TestPathTree {
32 changes: 25 additions & 7 deletions cpp/src/arrow/dataset/scanner.cc
@@ -61,6 +61,12 @@ Result<RecordBatchIterator> InMemoryScanTask::Execute() {
return MakeVectorIterator(record_batches_);
}

Result<RecordBatchGenerator> ScanTask::ExecuteAsync() {
return Status::NotImplemented("Async is not implemented for this scan task yet");
}

bool ScanTask::supports_async() const { return false; }

Result<FragmentIterator> Scanner::GetFragments() {
if (fragment_ != nullptr) {
return MakeVectorIterator(FragmentVector{fragment_});
@@ -203,19 +209,31 @@ Result<std::shared_ptr<Table>> Scanner::ToTable() {
auto state = std::make_shared<TableAssemblyState>();

size_t scan_task_id = 0;
std::vector<Future<>> scan_futures;
for (auto maybe_scan_task : scan_task_it) {
ARROW_ASSIGN_OR_RAISE(auto scan_task, maybe_scan_task);

auto id = scan_task_id++;
task_group->Append([state, id, scan_task] {
ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
ARROW_ASSIGN_OR_RAISE(auto local, batch_it.ToVector());
state->Emplace(std::move(local), id);
return Status::OK();
});
if (scan_task->supports_async()) {
ARROW_ASSIGN_OR_RAISE(auto scan_gen, scan_task->ExecuteAsync());
auto scan_fut = CollectAsyncGenerator(std::move(scan_gen))
.Then([state, id](const RecordBatchVector& rbs) {
state->Emplace(rbs, id);
});
scan_futures.push_back(std::move(scan_fut));
} else {
task_group->Append([state, id, scan_task] {
ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
ARROW_ASSIGN_OR_RAISE(auto local, batch_it.ToVector());
state->Emplace(std::move(local), id);
return Status::OK();
});
}
}
// Wait for all async tasks to complete, or the first error
RETURN_NOT_OK(AllComplete(scan_futures).status());

// Wait for all tasks to complete, or the first error.
// Wait for all sync tasks to complete, or the first error.
RETURN_NOT_OK(task_group->Finish());

return Table::FromRecordBatches(scan_options_->projected_schema,
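Whether batches arrive via the sync iterator path or the async generator path, `ToTable` preserves scan order by giving each scan task a fixed slot id in the shared assembly state. A compact sketch of that slot-based pattern (simplified types; the real `TableAssemblyState` stores record batches):

```cpp
#include <mutex>
#include <vector>

// Scan tasks complete in any order, but each writes its batches into a slot
// fixed by its task id, so the final concatenation preserves scan order.
struct AssemblyState {
  std::mutex mutex;
  std::vector<std::vector<int>> slots;  // one slot of "batches" per task

  void Emplace(std::vector<int> batches, size_t id) {
    std::lock_guard<std::mutex> lock(mutex);
    if (slots.size() <= id) slots.resize(id + 1);
    slots[id] = std::move(batches);
  }
};
```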
4 changes: 4 additions & 0 deletions cpp/src/arrow/dataset/scanner.h
@@ -31,9 +31,11 @@
#include "arrow/dataset/visibility.h"
#include "arrow/memory_pool.h"
#include "arrow/type_fwd.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/type_fwd.h"

namespace arrow {
using RecordBatchGenerator = AsyncGenerator<std::shared_ptr<RecordBatch>>;
namespace dataset {

constexpr int64_t kDefaultBatchSize = 1 << 20;
@@ -101,6 +103,8 @@ class ARROW_DS_EXPORT ScanTask {
/// resulting from the Scan. Execution semantics are encapsulated in the
/// particular ScanTask implementation
virtual Result<RecordBatchIterator> Execute() = 0;
virtual Result<RecordBatchGenerator> ExecuteAsync();
virtual bool supports_async() const;

virtual ~ScanTask() = default;

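A subclass opts in to the new contract by overriding both `supports_async()` and `ExecuteAsync()`; callers check the former before calling the latter, and the base defaults keep existing ScanTask implementations working unchanged. A self-contained sketch with simplified stand-in types (not the Arrow signatures):

```cpp
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <vector>

// Simplified stand-ins; not the Arrow types.
using Batch = std::vector<int>;
using BatchGenerator = std::function<std::optional<Batch>()>;

class ScanTaskBase {
 public:
  virtual ~ScanTaskBase() = default;
  // Opt-in async contract: the defaults claim no support, in the spirit of
  // the NotImplemented default that the new ScanTask methods return.
  virtual bool supports_async() const { return false; }
  virtual BatchGenerator ExecuteAsync() {
    throw std::runtime_error("Async is not implemented for this scan task");
  }
};

class VectorScanTask : public ScanTaskBase {
 public:
  explicit VectorScanTask(std::vector<Batch> batches)
      : batches_(std::move(batches)) {}

  bool supports_async() const override { return true; }

  BatchGenerator ExecuteAsync() override {
    // Moves the batches into shared state; this sketch supports one call.
    auto state = std::make_shared<std::vector<Batch>>(std::move(batches_));
    auto index = std::make_shared<size_t>(0);
    return [state, index]() -> std::optional<Batch> {
      if (*index == state->size()) return std::nullopt;
      return (*state)[(*index)++];
    };
  }

 private:
  std::vector<Batch> batches_;
};
```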