Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,7 @@ class HashJoinNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
finished_ = Future<>::Make();
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
END_SPAN_ON_FUTURE_COMPLETION(span_, finished(), this);

bool use_sync_execution = !(plan_->exec_context()->executor());
size_t num_threads = use_sync_execution ? 1 : thread_indexer_.Capacity();
Expand Down Expand Up @@ -609,11 +608,11 @@ class HashJoinNode : public ExecNode {
for (auto&& input : inputs_) {
input->StopProducing(this);
}
impl_->Abort([this]() { finished_.MarkFinished(); });
impl_->Abort([this]() { ARROW_UNUSED(task_group_.End()); });
}
}

Future<> finished() override { return finished_; }
Future<> finished() override { return task_group_.OnFinished(); }

private:
void OutputBatchCallback(ExecBatch batch) {
Expand All @@ -624,14 +623,14 @@ class HashJoinNode : public ExecNode {
bool expected = false;
if (complete_.compare_exchange_strong(expected, true)) {
outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
finished_.MarkFinished();
ARROW_UNUSED(task_group_.End());
}
}

Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
auto executor = plan_->exec_context()->executor();
if (executor) {
RETURN_NOT_OK(executor->Spawn([this, func] {
ARROW_ASSIGN_OR_RAISE(auto task_fut, executor->Submit([this, func] {
size_t thread_index = thread_indexer_();
Status status = func(thread_index);
if (!status.ok()) {
Expand All @@ -640,6 +639,7 @@ class HashJoinNode : public ExecNode {
return;
}
}));
return task_group_.AddTask(task_fut);
} else {
// We should not get here in serial execution mode
ARROW_DCHECK(false);
Expand All @@ -656,6 +656,7 @@ class HashJoinNode : public ExecNode {
ThreadIndexer thread_indexer_;
std::unique_ptr<HashJoinSchema> schema_mgr_;
std::unique_ptr<HashJoinImpl> impl_;
util::AsyncTaskGroup task_group_;
};

namespace internal {
Expand Down
53 changes: 31 additions & 22 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -883,43 +883,49 @@ std::shared_ptr<Table> HashJoinSimple(
return Table::Make(schema, result, result[0]->length());
}

void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,
const HashJoinNodeOptions& join_options,
const std::shared_ptr<Schema>& output_schema,
const std::vector<std::shared_ptr<Array>>& l,
const std::vector<std::shared_ptr<Array>>& r, int num_batches_l,
int num_batches_r, std::shared_ptr<Table>* output) {
Result<std::vector<ExecBatch>> HashJoinWithExecPlan(
Random64Bit& rng, bool parallel, const HashJoinNodeOptions& join_options,
const std::shared_ptr<Schema>& output_schema,
const std::vector<std::shared_ptr<Array>>& l,
const std::vector<std::shared_ptr<Array>>& r, int num_batches_l, int num_batches_r) {
auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(exec_ctx.get()));

// add left source
BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_");
ASSERT_OK_AND_ASSIGN(
ARROW_ASSIGN_OR_RAISE(
ExecNode * l_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
/*slow=*/false)}));
/*slow=*/true)}));

// add right source
BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_");
ASSERT_OK_AND_ASSIGN(
ARROW_ASSIGN_OR_RAISE(
ExecNode * r_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
/*slow=*/false)}));

ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
{l_source, r_source}, join_options));
ARROW_ASSIGN_OR_RAISE(
ExecNode * join,
MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options));

AsyncGenerator<util::optional<ExecBatch>> sink_gen;
ASSERT_OK_AND_ASSIGN(
ARROW_ASSIGN_OR_RAISE(
std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));

ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));

ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res));
auto batches_fut = StartAndCollect(plan.get(), sink_gen);
if (!batches_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)) {
plan->StopProducing();
// If this second wait fails then there isn't much we can do. We will abort
// and probably get a segmentation fault.
plan->finished().Wait(::arrow::kDefaultAssertFinishesWaitSeconds);
return Status::Invalid("Plan did not finish in a reasonable amount of time");
}
return batches_fut.result();
}

TEST(HashJoin, Suffix) {
Expand Down Expand Up @@ -1161,12 +1167,15 @@ TEST(HashJoin, Random) {
}
std::shared_ptr<Schema> output_schema =
std::make_shared<Schema>(std::move(output_schema_fields));
std::shared_ptr<Table> output_rows_test;
HashJoinWithExecPlan(rng, parallel, join_options, output_schema,
shuffled_input_arrays[0], shuffled_input_arrays[1],
static_cast<int>(bit_util::CeilDiv(num_rows_l, batch_size)),
static_cast<int>(bit_util::CeilDiv(num_rows_r, batch_size)),
&output_rows_test);
ASSERT_OK_AND_ASSIGN(
auto batches, HashJoinWithExecPlan(
rng, parallel, join_options, output_schema,
shuffled_input_arrays[0], shuffled_input_arrays[1],
static_cast<int>(bit_util::CeilDiv(num_rows_l, batch_size)),
static_cast<int>(bit_util::CeilDiv(num_rows_r, batch_size))));

ASSERT_OK_AND_ASSIGN(auto output_rows_test,
TableFromExecBatches(output_schema, batches));

// Compare results
AssertTablesEqual(output_rows_ref, output_rows_test);
Expand Down