diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 93e54c6400e..060d40a078c 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -576,8 +576,7 @@ class HashJoinNode : public ExecNode {
                        {{"node.label", label()},
                         {"node.detail", ToString()},
                         {"node.kind", kind_name()}});
-    finished_ = Future<>::Make();
-    END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
+    END_SPAN_ON_FUTURE_COMPLETION(span_, finished(), this);
 
     bool use_sync_execution = !(plan_->exec_context()->executor());
     size_t num_threads = use_sync_execution ? 1 : thread_indexer_.Capacity();
@@ -609,11 +608,11 @@ class HashJoinNode : public ExecNode {
       for (auto&& input : inputs_) {
         input->StopProducing(this);
       }
-      impl_->Abort([this]() { finished_.MarkFinished(); });
+      impl_->Abort([this]() { ARROW_UNUSED(task_group_.End()); });
     }
   }
 
-  Future<> finished() override { return finished_; }
+  Future<> finished() override { return task_group_.OnFinished(); }
 
  private:
   void OutputBatchCallback(ExecBatch batch) {
@@ -624,14 +623,14 @@ class HashJoinNode : public ExecNode {
     bool expected = false;
     if (complete_.compare_exchange_strong(expected, true)) {
       outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches));
-      finished_.MarkFinished();
+      ARROW_UNUSED(task_group_.End());
     }
   }
 
   Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
     auto executor = plan_->exec_context()->executor();
     if (executor) {
-      RETURN_NOT_OK(executor->Spawn([this, func] {
+      ARROW_ASSIGN_OR_RAISE(auto task_fut, executor->Submit([this, func] {
        size_t thread_index = thread_indexer_();
        Status status = func(thread_index);
        if (!status.ok()) {
@@ -640,6 +639,7 @@ class HashJoinNode : public ExecNode {
          return;
        }
      }));
+      return task_group_.AddTask(task_fut);
     } else {
       // We should not get here in serial execution mode
       ARROW_DCHECK(false);
@@ -656,6 +656,7 @@ class HashJoinNode : public ExecNode {
   ThreadIndexer thread_indexer_;
   std::unique_ptr<HashJoinSchema> schema_mgr_;
   std::unique_ptr<HashJoinImpl> impl_;
+  util::AsyncTaskGroup task_group_;
 };
 
 namespace internal {
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index 96469a78ab2..03a49e3378b 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -883,43 +883,49 @@ std::shared_ptr<Table> HashJoinSimple(
   return Table::Make(schema, result, result[0]->length());
 }
 
-void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,
-                          const HashJoinNodeOptions& join_options,
-                          const std::shared_ptr<Schema>& output_schema,
-                          const std::vector<std::shared_ptr<Array>>& l,
-                          const std::vector<std::shared_ptr<Array>>& r, int num_batches_l,
-                          int num_batches_r, std::shared_ptr<Table>* output) {
+Result<std::vector<ExecBatch>> HashJoinWithExecPlan(
+    Random64Bit& rng, bool parallel, const HashJoinNodeOptions& join_options,
+    const std::shared_ptr<Schema>& output_schema,
+    const std::vector<std::shared_ptr<Array>>& l,
+    const std::vector<std::shared_ptr<Array>>& r, int num_batches_l, int num_batches_r) {
   auto exec_ctx = arrow::internal::make_unique<ExecContext>(
       default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
 
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+  ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(exec_ctx.get()));
 
   // add left source
   BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_");
-  ASSERT_OK_AND_ASSIGN(
+  ARROW_ASSIGN_OR_RAISE(
       ExecNode * l_source,
       MakeExecNode("source", plan.get(), {},
                    SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
-                                                                     /*slow=*/false)}));
+                                                                     /*slow=*/true)}));
 
   // add right source
   BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_");
-  ASSERT_OK_AND_ASSIGN(
+  ARROW_ASSIGN_OR_RAISE(
       ExecNode * r_source,
       MakeExecNode("source", plan.get(), {},
                    SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
                                                                      /*slow=*/false)}));
 
-  ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
-                                                     {l_source, r_source}, join_options));
+  ARROW_ASSIGN_OR_RAISE(
+      ExecNode * join,
+      MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options));
 
   AsyncGenerator<util::optional<ExecBatch>> sink_gen;
-  ASSERT_OK_AND_ASSIGN(
+  ARROW_ASSIGN_OR_RAISE(
       std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));
 
-  ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
-
-  ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res));
+  auto batches_fut = StartAndCollect(plan.get(), sink_gen);
+  if (!batches_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)) {
+    plan->StopProducing();
+    // If this second wait fails then there isn't much we can do.  We will abort
+    // and probably get a segmentation fault.
+    plan->finished().Wait(::arrow::kDefaultAssertFinishesWaitSeconds);
+    return Status::Invalid("Plan did not finish in a reasonable amount of time");
+  }
+  return batches_fut.result();
 }
 
 TEST(HashJoin, Suffix) {
@@ -1161,12 +1167,15 @@ TEST(HashJoin, Random) {
     }
     std::shared_ptr<Schema> output_schema =
        std::make_shared<Schema>(std::move(output_schema_fields));
-    std::shared_ptr<Table> output_rows_test;
-    HashJoinWithExecPlan(rng, parallel, join_options, output_schema,
-                         shuffled_input_arrays[0], shuffled_input_arrays[1],
-                         static_cast<int>(bit_util::CeilDiv(num_rows_l, batch_size)),
-                         static_cast<int>(bit_util::CeilDiv(num_rows_r, batch_size)),
-                         &output_rows_test);
+    ASSERT_OK_AND_ASSIGN(
+        auto batches, HashJoinWithExecPlan(
+                          rng, parallel, join_options, output_schema,
+                          shuffled_input_arrays[0], shuffled_input_arrays[1],
+                          static_cast<int>(bit_util::CeilDiv(num_rows_l, batch_size)),
+                          static_cast<int>(bit_util::CeilDiv(num_rows_r, batch_size))));
+
+    ASSERT_OK_AND_ASSIGN(auto output_rows_test,
+                         TableFromExecBatches(output_schema, batches));
 
     // Compare results
     AssertTablesEqual(output_rows_ref, output_rows_test);
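Note (not part of the patch): the change above replaces the hash join node's manually managed `finished_` future with a `util::AsyncTaskGroup`, so `finished()` is derived from the set of scheduled tasks rather than being marked by hand. The sketch below illustrates that pattern in isolation. Only the `AddTask`/`End`/`OnFinished` calls mirror what the diff itself uses; the header paths, executor setup, and function name are assumptions made purely for the sake of a self-contained example.

```cpp
// Illustrative sketch of the AsyncTaskGroup pattern adopted above; not part of
// the patch. Header locations and the executor setup are assumptions; only the
// AddTask/End/OnFinished calls mirror the diff.
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/async_util.h"   // util::AsyncTaskGroup (assumed header)
#include "arrow/util/macros.h"       // ARROW_UNUSED
#include "arrow/util/thread_pool.h"  // internal::GetCpuThreadPool

#include <atomic>

arrow::Status RunTrackedTasks() {
  arrow::util::AsyncTaskGroup task_group;
  arrow::internal::Executor* executor = arrow::internal::GetCpuThreadPool();
  std::atomic<int> completed{0};

  // Each spawned task is registered with the group, analogous to
  // ScheduleTaskCallback submitting work and calling task_group_.AddTask.
  for (int i = 0; i < 4; ++i) {
    ARROW_ASSIGN_OR_RAISE(auto task_fut,
                          executor->Submit([&completed] { completed++; }));
    ARROW_RETURN_NOT_OK(task_group.AddTask(task_fut));
  }

  // Declare that no further tasks will be added, analogous to the
  // task_group_.End() calls on the normal-completion and abort paths.
  ARROW_UNUSED(task_group.End());

  // Block until every registered task has completed; this is what the node's
  // finished() future now reports.
  return task_group.OnFinished().status();
}
```

The property the node relies on is that `End()` only declares that no further tasks will be added; the future from `OnFinished()` completes once every task registered via `AddTask` has finished, which is why `finished()` can simply forward it and why `End()` is safe to call from both the completion callback and the abort callback.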