From eb58ec0c432bb2fc68dca1d9f8fe8e13f6770a93 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 14 Apr 2022 14:20:35 -1000 Subject: [PATCH] ARROW-14911: Under certain circumstances it was possible for the hash join node to mark itself finished too early when task scheduler tasks were still winding down and attempts to access its own state would fail as the node was deleted. --- cpp/src/arrow/compute/exec/hash_join_node.cc | 13 ++--- .../arrow/compute/exec/hash_join_node_test.cc | 53 +++++++++++-------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 93e54c6400e..060d40a078c 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -576,8 +576,7 @@ class HashJoinNode : public ExecNode { {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); - finished_ = Future<>::Make(); - END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this); + END_SPAN_ON_FUTURE_COMPLETION(span_, finished(), this); bool use_sync_execution = !(plan_->exec_context()->executor()); size_t num_threads = use_sync_execution ? 
1 : thread_indexer_.Capacity(); @@ -609,11 +608,11 @@ class HashJoinNode : public ExecNode { for (auto&& input : inputs_) { input->StopProducing(this); } - impl_->Abort([this]() { finished_.MarkFinished(); }); + impl_->Abort([this]() { ARROW_UNUSED(task_group_.End()); }); } } - Future<> finished() override { return finished_; } + Future<> finished() override { return task_group_.OnFinished(); } private: void OutputBatchCallback(ExecBatch batch) { @@ -624,14 +623,14 @@ class HashJoinNode : public ExecNode { bool expected = false; if (complete_.compare_exchange_strong(expected, true)) { outputs_[0]->InputFinished(this, static_cast<int>(total_num_batches)); - finished_.MarkFinished(); + ARROW_UNUSED(task_group_.End()); } } Status ScheduleTaskCallback(std::function<Status(size_t)> func) { auto executor = plan_->exec_context()->executor(); if (executor) { - RETURN_NOT_OK(executor->Spawn([this, func] { + ARROW_ASSIGN_OR_RAISE(auto task_fut, executor->Submit([this, func] { size_t thread_index = thread_indexer_(); Status status = func(thread_index); if (!status.ok()) { @@ -640,6 +639,7 @@ class HashJoinNode : public ExecNode { return; } })); + return task_group_.AddTask(task_fut); } else { // We should not get here in serial execution mode ARROW_DCHECK(false); @@ -656,6 +656,7 @@ class HashJoinNode : public ExecNode { ThreadIndexer thread_indexer_; std::unique_ptr<HashJoinSchema> schema_mgr_; std::unique_ptr<HashJoinImpl> impl_; + util::AsyncTaskGroup task_group_; }; namespace internal { diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 96469a78ab2..03a49e3378b 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -883,43 +883,49 @@ std::shared_ptr<Table> HashJoinSimple( return Table::Make(schema, result, result[0]->length()); } -void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, - const HashJoinNodeOptions& join_options, - const std::shared_ptr<Schema>& output_schema, - const std::vector<std::shared_ptr<Array>>& l, 
- const std::vector<std::shared_ptr<Array>>& r, int num_batches_l, - int num_batches_r, std::shared_ptr<Table>
* output) { +Result<std::vector<ExecBatch>> HashJoinWithExecPlan( + Random64Bit& rng, bool parallel, const HashJoinNodeOptions& join_options, + const std::shared_ptr<Schema>& output_schema, + const std::vector<std::shared_ptr<Array>>& l, + const std::vector<std::shared_ptr<Array>>& r, int num_batches_l, int num_batches_r) { auto exec_ctx = arrow::internal::make_unique<ExecContext>( default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(exec_ctx.get())); // add left source BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_"); - ASSERT_OK_AND_ASSIGN( + ARROW_ASSIGN_OR_RAISE( ExecNode * l_source, MakeExecNode("source", plan.get(), {}, SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, - /*slow=*/false)})); + /*slow=*/true)})); // add right source BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_"); - ASSERT_OK_AND_ASSIGN( + ARROW_ASSIGN_OR_RAISE( ExecNode * r_source, MakeExecNode("source", plan.get(), {}, SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, /*slow=*/false)})); - ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(), - {l_source, r_source}, join_options)); + ARROW_ASSIGN_OR_RAISE( + ExecNode * join, + MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options)); AsyncGenerator<util::optional<ExecBatch>> sink_gen; - ASSERT_OK_AND_ASSIGN( + ARROW_ASSIGN_OR_RAISE( std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen})); - ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); - - ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res)); + auto batches_fut = StartAndCollect(plan.get(), sink_gen); + if (!batches_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)) { + plan->StopProducing(); + // If this second wait fails then there isn't much we can do. We will abort + // and probably get a segmentation fault. 
+ plan->finished().Wait(::arrow::kDefaultAssertFinishesWaitSeconds); + return Status::Invalid("Plan did not finish in a reasonable amount of time"); + } + return batches_fut.result(); } TEST(HashJoin, Suffix) { @@ -1161,12 +1167,15 @@ TEST(HashJoin, Random) { } std::shared_ptr<Schema> output_schema = std::make_shared<Schema>(std::move(output_schema_fields)); - std::shared_ptr<Table>
output_rows_test; - HashJoinWithExecPlan(rng, parallel, join_options, output_schema, - shuffled_input_arrays[0], shuffled_input_arrays[1], - static_cast(bit_util::CeilDiv(num_rows_l, batch_size)), - static_cast(bit_util::CeilDiv(num_rows_r, batch_size)), - &output_rows_test); + ASSERT_OK_AND_ASSIGN( + auto batches, HashJoinWithExecPlan( + rng, parallel, join_options, output_schema, + shuffled_input_arrays[0], shuffled_input_arrays[1], + static_cast(bit_util::CeilDiv(num_rows_l, batch_size)), + static_cast(bit_util::CeilDiv(num_rows_r, batch_size)))); + + ASSERT_OK_AND_ASSIGN(auto output_rows_test, + TableFromExecBatches(output_schema, batches)); // Compare results AssertTablesEqual(output_rows_ref, output_rows_test);