2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/exec_plan.h
@@ -316,7 +316,7 @@ class ARROW_EXPORT MapNode : public ExecNode {
  protected:
   void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);

-  void Finish(Status finish_st = Status::OK());
+  virtual void Finish(Status finish_st = Status::OK());

  protected:
   // Counter for the number of batches received
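Note: making MapNode::Finish virtual is the enabler for the TeeNode change in file_base.cc below, which needs to flush a writer before the node is marked finished. A minimal sketch of the pattern this unlocks (FlushingMapNode and FlushPending are hypothetical names, not part of this diff):

    class FlushingMapNode : public compute::MapNode {
     protected:
      void Finish(Status finish_st) override {
        // Combine this node's own cleanup status with the incoming plan
        // status, then let the base class end the task group and mark the
        // node's finished future.
        compute::MapNode::Finish(FlushPending() & finish_st);
      }

     private:
      Status FlushPending() { return Status::OK(); }  // hypothetical cleanup
    };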
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/sink_node.cc
@@ -363,7 +363,7 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
   }

  protected:
-  virtual void Finish(const Status& finish_st) {
+  void Finish(const Status& finish_st) {
     consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
       // Prefer the plan error over the consumer error
       Status final_status = finish_st & st;
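Note: `finish_st & st` uses arrow::Status::operator&, which returns the left-hand status when it is an error and the right-hand status otherwise; that is what "prefer the plan error" means here. Roughly:

    Status plan_st = Status::Invalid("plan failed");
    Status consumer_st = Status::IOError("consumer failed");
    Status combined = plan_st & consumer_st;  // -> Invalid("plan failed")
    // Had plan_st been Status::OK(), combined would be the consumer error.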
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/source_node.cc
@@ -231,7 +231,7 @@ struct TableSourceNode : public SourceNode {
   static arrow::Status ValidateTableSourceNodeInput(const std::shared_ptr<Table> table,
                                                     const int64_t batch_size) {
     if (table == nullptr) {
-      return Status::Invalid("TableSourceNode node requires table which is not null");
+      return Status::Invalid("TableSourceNode requires table which is not null");
     }

     if (batch_size <= 0) {
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/exec/test_util.cc
@@ -165,6 +165,12 @@ ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
   return batch;
 }

+Future<> StartAndFinish(ExecPlan* plan) {
+  RETURN_NOT_OK(plan->Validate());
+  RETURN_NOT_OK(plan->StartProducing());
+  return plan->finished();
+}
+
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
   RETURN_NOT_OK(plan->Validate());
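Note: StartAndFinish complements StartAndCollect for plans whose terminal node is a consuming sink with no output generator, so there is nothing to collect. The new Write test in file_test.cc below uses it exactly that way:

    // From file_test.cc in this diff: validate the plan, start it, and
    // assert that its finished future resolves OK.
    ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));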
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/exec/test_util.h
@@ -82,6 +82,9 @@ struct BatchesWithSchema {
   }
 };

+ARROW_TESTING_EXPORT
+Future<> StartAndFinish(ExecPlan* plan);
+
 ARROW_TESTING_EXPORT
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/util.cc
@@ -287,7 +287,7 @@ namespace compute {
 Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
                               int expected_num_inputs, const char* kind_name) {
   if (static_cast<int>(inputs.size()) != expected_num_inputs) {
-    return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+    return Status::Invalid(kind_name, " requires ", expected_num_inputs,
                            " inputs but got ", inputs.size());
   }

10 changes: 10 additions & 0 deletions cpp/src/arrow/dataset/file_base.cc
@@ -460,6 +460,16 @@ class TeeNode : public compute::MapNode {

   const char* kind_name() const override { return "TeeNode"; }

+  void Finish(Status finish_st) override {
+    dataset_writer_->Finish().AddCallback([this, finish_st](const Status& dw_status) {
+      // Need to wait for the task group to complete regardless of dw_status
+      task_group_.End().AddCallback(
+          [this, dw_status, finish_st](const Status& tg_status) {
+            finished_.MarkFinished(dw_status & finish_st & tg_status);
+          });
+    });
+  }
+
   Result<compute::ExecBatch> DoTee(const compute::ExecBatch& batch) {
     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
                           batch.ToRecordBatch(output_schema()));
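Note: the chained callbacks guarantee that TeeNode's finished future is marked only after both the dataset writer and the task group complete, with the three statuses combined left-to-right so earlier failures take precedence. A blocking sketch of the same ordering (illustration only, assuming the same members; Future::status() waits for completion):

    Status dw_status = dataset_writer_->Finish().status();  // 1. writer done
    Status tg_status = task_group_.End().status();          // 2. task group done
    finished_.MarkFinished(dw_status & finish_st & tg_status);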
107 changes: 107 additions & 0 deletions cpp/src/arrow/dataset/file_test.cc
@@ -18,14 +18,17 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <vector>

 #include <gmock/gmock.h>
 #include <gtest/gtest.h>

 #include "arrow/array/array_primitive.h"
+#include "arrow/compute/exec/test_util.h"
 #include "arrow/dataset/api.h"
 #include "arrow/dataset/partition.h"
+#include "arrow/dataset/plan.h"
 #include "arrow/dataset/test_util.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/filesystem/test_util.h"
@@ -34,6 +37,8 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/util/io_util.h"

namespace cp = arrow::compute;

namespace arrow {

using internal::TemporaryDir;
@@ -342,5 +347,107 @@ TEST_F(TestFileSystemDataset, WriteProjected) {
     }
   }
 }

+class FileSystemWriteTest : public testing::TestWithParam<std::tuple<bool, bool>> {
+  using PlanFactory = std::function<std::vector<cp::Declaration>(
+      const FileSystemDatasetWriteOptions&,
+      std::function<Future<util::optional<cp::ExecBatch>>()>*)>;
+
+ protected:
+  bool IsParallel() { return std::get<0>(GetParam()); }
+  bool IsSlow() { return std::get<1>(GetParam()); }
+
+  FileSystemWriteTest() { dataset::internal::Initialize(); }
+
+  void TestDatasetWriteRoundTrip(PlanFactory plan_factory, bool has_output) {
+    // Runs in-memory data through the plan and then scans out the written
+    // data to ensure it matches the source data
+    auto format = std::make_shared<IpcFileFormat>();
+    auto fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+    FileSystemDatasetWriteOptions write_options;
+    write_options.file_write_options = format->DefaultWriteOptions();
+    write_options.filesystem = fs;
+    write_options.base_dir = "root";
+    write_options.partitioning = std::make_shared<HivePartitioning>(schema({}));
+    write_options.basename_template = "{i}.feather";
+    const std::string kExpectedFilename = "root/0.feather";
+
+    cp::BatchesWithSchema source_data;
+    source_data.batches = {
+        cp::ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+        cp::ExecBatchFromJSON({int32(), boolean()},
+                              "[[5, null], [6, false], [7, false]]")};
+    source_data.schema = schema({field("i32", int32()), field("bool", boolean())});
+
+    AsyncGenerator<util::optional<cp::ExecBatch>> sink_gen;
+
+    ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make());
+    auto source_decl = cp::Declaration::Sequence(
+        {{"source", cp::SourceNodeOptions{source_data.schema,
+                                          source_data.gen(IsParallel(), IsSlow())}}});
+    auto declarations = plan_factory(write_options, &sink_gen);
+    declarations.insert(declarations.begin(), std::move(source_decl));
+    ASSERT_OK(cp::Declaration::Sequence(std::move(declarations)).AddToPlan(plan.get()));
+
+    if (has_output) {
+      ASSERT_FINISHES_OK_AND_ASSIGN(auto out_batches,
+                                    cp::StartAndCollect(plan.get(), sink_gen));
+      cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, out_batches);
+    } else {
+      ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));
+    }
+
+    // Read written dataset and make sure it matches
+    ASSERT_OK_AND_ASSIGN(auto dataset_factory, FileSystemDatasetFactory::Make(
+                                                   fs, {kExpectedFilename}, format, {}));
+    ASSERT_OK_AND_ASSIGN(auto written_dataset, dataset_factory->Finish(FinishOptions{}));
+    AssertSchemaEqual(*source_data.schema, *written_dataset->schema());

+    ASSERT_OK_AND_ASSIGN(plan, cp::ExecPlan::Make());
+    ASSERT_OK_AND_ASSIGN(auto scanner_builder, written_dataset->NewScan());
+    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+    ASSERT_OK(cp::Declaration::Sequence(
+                  {
+                      {"scan", ScanNodeOptions{written_dataset, scanner->options()}},
+                      {"sink", cp::SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto written_batches,
+                                  cp::StartAndCollect(plan.get(), sink_gen));
+    cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, written_batches);
+  }
+};
+
+TEST_P(FileSystemWriteTest, Write) {
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{{"write", WriteNodeOptions{write_options}}};
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/false);
+}
+
+TEST_P(FileSystemWriteTest, TeeWrite) {
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{
+            {"tee", WriteNodeOptions{write_options}},
+            {"sink", cp::SinkNodeOptions{sink_gen}},
+        };
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/true);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    FileSystemWrite, FileSystemWriteTest,
+    testing::Combine(testing::Values(false, true), testing::Values(false, true)),
+    [](const testing::TestParamInfo<FileSystemWriteTest::ParamType>& info) {
+      std::string parallel_desc = std::get<0>(info.param) ? "parallel" : "serial";
+      std::string speed_desc = std::get<1>(info.param) ? "slow" : "fast";
+      return parallel_desc + "_" + speed_desc;
+    });

 }  // namespace dataset
 }  // namespace arrow
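Note: with the name generator passed to INSTANTIATE_TEST_SUITE_P above, googletest expands the four parameter combinations into test names such as:

    FileSystemWrite/FileSystemWriteTest.Write/serial_fast
    FileSystemWrite/FileSystemWriteTest.TeeWrite/parallel_slow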