From 4c8673437db3de3108b9431204b16fffc19b8514 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Tue, 20 Jul 2021 15:20:44 -0400
Subject: [PATCH 01/27] init

---
 cpp/src/arrow/CMakeLists.txt            |   1 +
 cpp/src/arrow/compute/exec/exec_plan.cc | 295 ++++++++++++++++++++++++
 cpp/src/arrow/compute/exec/hash_join.cc |  25 ++
 cpp/src/arrow/compute/exec/hash_join.h  |  29 +++
 4 files changed, 350 insertions(+)
 create mode 100644 cpp/src/arrow/compute/exec/hash_join.cc
 create mode 100644 cpp/src/arrow/compute/exec/hash_join.h

diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 308ee49972c..0145ec83472 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -411,6 +411,7 @@ if(ARROW_COMPUTE)
     compute/kernels/vector_replace.cc
     compute/kernels/vector_selection.cc
     compute/kernels/vector_sort.cc
+    compute/exec/hash_join.cc
     compute/exec/key_hash.cc
     compute/exec/key_map.cc
     compute/exec/key_compare.cc
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 5047c7a58d6..a7f568be58f 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -346,5 +346,300 @@ ExecFactoryRegistry* default_exec_factory_registry() {
   return &instance;
 }
 
+struct HashSemiIndexJoinNode : ExecNode {
+  HashSemiIndexJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label,
+                        std::shared_ptr<Schema> output_schema, ExecContext* ctx,
+                        const std::vector<int>&& index_field_ids)
+      : ExecNode(left_input->plan(), std::move(label), {left_input, right_input},
+                 {"hashsemiindexjoin"}, std::move(output_schema), /*num_outputs=*/1),
+        ctx_(ctx),
+        index_field_ids_(std::move(index_field_ids)) {
+    // num_input_batches_processed_.store(0);
+    // num_input_batches_total_.store(-1);
+    // num_output_batches_processed_.store(0);
+    output_started_.store(false);
+  }
+
+  const char* kind_name() override { return "HashSemiIndexJoinNode"; }
+
+ private:
+  struct ThreadLocalState;
+
+ public:
+  Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+    // Get input schema
+    auto input_schema = inputs_[0]->output_schema();
+
+    if (!state->grouper) {
+      // Build vector of key field data types
+      std::vector<ValueDescr> key_descrs(key_field_ids_.size());
+      for (size_t i = 0; i < key_field_ids_.size(); ++i) {
+        auto key_field_id = key_field_ids_[i];
+        key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
+      }
+
+      // Construct grouper
+      ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+    }
+    if (state->agg_states.empty()) {
+      // Build vector of aggregate source field data types
+      std::vector<ValueDescr> agg_src_descrs(agg_kernels_.size());
+      for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+        auto agg_src_field_id = agg_src_field_ids_[i];
+        agg_src_descrs[i] =
+            ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY);
+      }
+      for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+        ARROW_ASSIGN_OR_RAISE(
+            state->agg_states,
+            internal::InitKernels(agg_kernels_, ctx_, aggs_, agg_src_descrs));
+
+        ARROW_ASSIGN_OR_RAISE(
+            FieldVector agg_result_fields,
+            internal::ResolveKernels(aggs_, agg_kernels_, state->agg_states, ctx_,
+                                     agg_src_descrs));
+      }
+    }
+    return Status::OK();
+  }
+
+  Status ProcessInputBatch(const ExecBatch& batch) {
+    SmallUniqueIdHolder id_holder(&local_state_id_assignment_);
+    int id = id_holder.get();
+    ThreadLocalState* state = local_states_.get(id);
+    RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(key_field_ids_.size());
+    for (size_t i = 0; i < key_field_ids_.size(); ++i) {
+      keys[i] = batch.values[key_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Create a batch with group ids
+    ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
+
+    // Execute aggregate kernels
+    for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+      KernelContext kernel_ctx{ctx_};
+      kernel_ctx.SetState(state->agg_states[i].get());
+
+      ARROW_ASSIGN_OR_RAISE(
+          auto agg_batch,
+          ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
+
+      RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
+      RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
+    }
+
+    return Status::OK();
+  }
+
+  Status Merge() {
+    int num_local_states = local_state_id_assignment_.num_ids();
+    ThreadLocalState* state0 = local_states_.get(0);
+    for (int i = 1; i < num_local_states; ++i) {
+      ThreadLocalState* state = local_states_.get(i);
+      ARROW_DCHECK(state);
+      ARROW_DCHECK(state->grouper);
+      ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+      ARROW_ASSIGN_OR_RAISE(Datum transposition, state0->grouper->Consume(other_keys));
+      state->grouper.reset();
+
+      for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+        KernelContext batch_ctx{ctx_};
+        ARROW_DCHECK(state0->agg_states[i]);
+        batch_ctx.SetState(state0->agg_states[i].get());
+
+        RETURN_NOT_OK(agg_kernels_[i]->resize(&batch_ctx, state0->grouper->num_groups()));
+        RETURN_NOT_OK(agg_kernels_[i]->merge(&batch_ctx, std::move(*state->agg_states[i]),
+                                             *transposition.array()));
+        state->agg_states[i].reset();
+      }
+    }
+    return Status::OK();
+  }
+
+  Status Finalize() {
+    out_data_.resize(agg_kernels_.size() + key_field_ids_.size());
+    auto it = out_data_.begin();
+
+    ThreadLocalState* state = local_states_.get(0);
+    num_out_groups_ = state->grouper->num_groups();
+
+    // Aggregate fields come before key fields to match the behavior of GroupBy function
+
+    for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+      KernelContext batch_ctx{ctx_};
+      batch_ctx.SetState(state->agg_states[i].get());
+      Datum out;
+      RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out));
+      *it++ = out.array();
+      state->agg_states[i].reset();
+    }
+
+    ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
+    for (const auto& key : out_keys.values) {
+      *it++ = key.array();
+    }
+    state->grouper.reset();
+
+    return Status::OK();
+  }
+
+  Status OutputNthBatch(int n) {
+    ARROW_DCHECK(output_started_.load());
+
+    // Check finished flag
+    if (finished_.is_finished()) {
+      return Status::OK();
+    }
+
+    // Slice arrays
+    int64_t batch_size = output_batch_size();
+    int64_t batch_start = n * batch_size;
+    int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start);
+    std::vector<Datum> output_slices(out_data_.size());
+    for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) {
+      output_slices[out_field_id] =
+          out_data_[out_field_id]->Slice(batch_start, batch_length);
+    }
+
+    ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices));
+    outputs_[0]->InputReceived(this, n, output_batch);
+
+    uint32_t num_output_batches_processed =
+        1 + num_output_batches_processed_.fetch_add(1);
+    if (num_output_batches_processed * batch_size >= num_out_groups_) {
+      finished_.MarkFinished();
+    }
+
+    return Status::OK();
+  }
+
+  Status OutputResult() {
+    bool expected = false;
+    if (!output_started_.compare_exchange_strong(expected, true)) {
+      return Status::OK();
+    }
+
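+    // Only the first thread to flip output_started_ wins the exchange above; every
+    // other caller returns early, so the merge/finalize/output sequence runs once.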
+    RETURN_NOT_OK(Merge());
+    RETURN_NOT_OK(Finalize());
+
+    int batch_size = output_batch_size();
+    int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size;
+    outputs_[0]->InputFinished(this, num_result_batches);
+
+    auto executor = arrow::internal::GetCpuThreadPool();
+    for (int i = 0; i < num_result_batches; ++i) {
+      // Check finished flag
+      if (finished_.is_finished()) {
+        break;
+      }
+
+      RETURN_NOT_OK(executor->Spawn([this, i]() {
+        Status status = OutputNthBatch(i);
+        if (!status.ok()) {
+          ErrorReceived(inputs_[0], status);
+        }
+      }));
+    }
+
+    return Status::OK();
+  }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    if (finished_.is_finished()) {
+      return;
+    }
+
+    ARROW_DCHECK(num_input_batches_processed_.load() != num_input_batches_total_.load());
+
+    Status status = ProcessInputBatch(batch);
+    if (!status.ok()) {
+      ErrorReceived(input, status);
+      return;
+    }
+
+    num_input_batches_processed_.fetch_add(1);
+    if (num_input_batches_processed_.load() == num_input_batches_total_.load()) {
+      status = OutputResult();
+      if (!status.ok()) {
+        ErrorReceived(input, status);
+        return;
+      }
+    }
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    outputs_[0]->ErrorReceived(this, std::move(error));
+    StopProducing();
+  }
+
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    num_input_batches_total_.store(seq);
+    if (num_input_batches_processed_.load() == num_input_batches_total_.load()) {
+      Status status = OutputResult();
+
+      if (!status.ok()) {
+        ErrorReceived(input, status);
+      }
+    }
+  }
+
+  Status StartProducing() override {
+    finished_ = Future<>::Make();
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    inputs_[0]->StopProducing(this);
+
+    finished_.MarkFinished();
+  }
+
+  void StopProducing() override { StopProducing(outputs_[0]); }
+
+  Future<> finished() override { return finished_; }
+
+ private:
+  int output_batch_size() const {
+    int result = static_cast<int>(ctx_->exec_chunksize());
+    if (result < 0) {
+      result = 32 * 1024;
+    }
+    return result;
+  }
+
+  ExecContext* ctx_;
+  Future<> finished_ = Future<>::MakeFinished();
+
+  // std::atomic<int> num_input_batches_processed_;
+  // std::atomic<int> num_input_batches_total_;
+  // std::atomic<int> num_output_batches_processed_;
+
+  const std::vector<int> index_field_ids_;
+
+  struct ThreadLocalState {
+    std::unique_ptr<internal::Grouper> grouper;
+    std::vector<std::unique_ptr<KernelState>> agg_states;
+  };
+  SharedSequenceOfObjects<ThreadLocalState> local_states_;
+  SmallUniqueIdAssignment local_state_id_assignment_;
+  uint32_t num_out_groups_{0};
+  ArrayDataVector out_data_;
+  std::atomic<bool> output_started_;
+};
 } // namespace compute
 } // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc
new file mode 100644
index 00000000000..10b676be892
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/hash_join.h"
+
+namespace arrow {
+namespace compute {
+
+
+} // namespace compute
+} // namespace arrow
\ No newline at end of file
diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h
new file mode 100644
index 00000000000..492cf0a0a49
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace arrow {
+namespace compute {
+
+enum JoinType {
+  LEFT_SEMI_JOIN,
+  RIGHT_SEMI_JOIN,
+  LEFT_ANTI_SEMI_JOIN,
+  RIGHT_ANTI_SEMI_JOIN
+};
+
+} // namespace compute
+} // namespace arrow
\ No newline at end of file

From 29d4d1cd73eed21ac4fd9f4e30b9bae786bbce13 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Mon, 26 Jul 2021 18:11:01 -0400
Subject: [PATCH 02/27] adding Grouper::Find

---
 cpp/src/arrow/compute/api_aggregate.h         |   5 +
 cpp/src/arrow/compute/exec/exec_plan.cc       | 458 +++++++++---------
 .../arrow/compute/kernels/hash_aggregate.cc   |  77 ++-
 .../compute/kernels/hash_aggregate_test.cc    |  22 +
 4 files changed, 309 insertions(+), 253 deletions(-)

diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 880424e97f8..87827889ca0 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -383,6 +383,11 @@ class ARROW_EXPORT Grouper {
   /// be as wide as necessary.
   virtual Result<Datum> Consume(const ExecBatch& batch) = 0;
 
+  /// Finds (queries) the group ID for each row of the given ExecBatch. Returns the
+  /// group IDs as an integer array. If a key is not found, a null is emitted at
+  /// that index. This is a thread-safe lookup.
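+  ///
+  /// A minimal usage sketch (`grouper` and `probe_keys` are illustrative names,
+  /// not part of this API):
+  ///
+  ///   ARROW_ASSIGN_OR_RAISE(Datum ids, grouper->Find(probe_keys));
+  ///   // `ids` is a uint32 array aligned row-for-row with `probe_keys`; rows
+  ///   // whose keys were never Consume()d come back as nulls.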
+  virtual Result<Datum> Find(const ExecBatch& batch) = 0;
+
   /// Get current unique keys. May be called multiple times.
   virtual Result<ExecBatch> GetUniques() = 0;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index a7f568be58f..b102264331e 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -346,5 +346,300 @@ ExecFactoryRegistry* default_exec_factory_registry() {
   return &instance;
 }
 
-struct HashSemiIndexJoinNode : ExecNode {
+/*struct HashSemiIndexJoinNode : ExecNode {
   HashSemiIndexJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label,
                         std::shared_ptr<Schema> output_schema, ExecContext* ctx,
                         const std::vector<int>&& index_field_ids)
       : ExecNode(left_input->plan(), std::move(label), {left_input, right_input},
-                 {"hashsemiindexjoin"}, std::move(output_schema), /*num_outputs=*/1),
+                 {"hashsemiindexjoin"}, std::move(output_schema), */
+/*num_outputs=*//*1),
+ctx_(ctx),
+num_build_batches_processed_(0),
+num_build_batches_total_(-1),
+num_probe_batches_processed_(0),
+num_probe_batches_total_(-1),
+num_output_batches_processed_(0),
+index_field_ids_(std::move(index_field_ids)),
+output_started_(false),
+build_phase_finished_(false){}
+
+const char* kind_name() override { return "HashSemiIndexJoinNode"; }
+
+private:
+struct ThreadLocalState;
+
+public:
+Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+// Get input schema
+auto input_schema = inputs_[0]->output_schema();
+
+if (!state->grouper) {
+// Build vector of key field data types
+std::vector<ValueDescr> key_descrs(index_field_ids_.size());
+for (size_t i = 0; i < index_field_ids_.size(); ++i) {
+auto key_field_id = index_field_ids_[i];
+key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
+}
+
+// Construct grouper
+ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+}
+
+return Status::OK();
+}
+
+Status ProcessBuildSideBatch(const ExecBatch& batch) {
+SmallUniqueIdHolder id_holder(&local_state_id_assignment_);
+int id = id_holder.get();
+ThreadLocalState* state = local_states_.get(id);
+RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+// Create a batch with key columns
+std::vector<Datum> keys(key_field_ids_.size());
+for (size_t i = 0; i < key_field_ids_.size(); ++i) {
+keys[i] = batch.values[key_field_ids_[i]];
+}
+ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+// Create a batch with group ids
+ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
+
+// Execute aggregate kernels
+for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+KernelContext kernel_ctx{ctx_};
+kernel_ctx.SetState(state->agg_states[i].get());
+
+ARROW_ASSIGN_OR_RAISE(
+auto agg_batch,
+ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
+
+RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
+RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
+}
+
+return Status::OK();
+}
+
+// merge all other groupers to grouper[0]. nothing needs to be done on the
+// early_probe_batches, because when probing, everyone queries grouper[0]
+Status BuildSideMerge() {
+int num_local_states = local_state_id_assignment_.num_ids();
+ThreadLocalState* state0 = local_states_.get(0);
+for (int i = 1; i < num_local_states; ++i) {
+ThreadLocalState* state = local_states_.get(i);
+ARROW_DCHECK(state);
+ARROW_DCHECK(state->grouper);
+ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys));
+state->grouper.reset();
+}
+return Status::OK();
+}
+
+Status Finalize() {
+out_data_.resize(agg_kernels_.size() + key_field_ids_.size());
+auto it = out_data_.begin();
+
+ThreadLocalState* state = local_states_.get(0);
+num_out_groups_ = state->grouper->num_groups();
+
+// Aggregate fields come before key fields to match the behavior of GroupBy function
+
+for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+KernelContext batch_ctx{ctx_};
+batch_ctx.SetState(state->agg_states[i].get());
+Datum out;
+RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out));
+*it++ = out.array();
+state->agg_states[i].reset();
+}
+
+ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
+for (const auto& key : out_keys.values) {
+*it++ = key.array();
+}
+state->grouper.reset();
+
+return Status::OK();
+}
+
+Status OutputNthBatch(int n) {
+ARROW_DCHECK(output_started_.load());
+
+// Check finished flag
+if (finished_.is_finished()) {
+return Status::OK();
+}
+
+// Slice arrays
+int64_t batch_size = output_batch_size();
+int64_t batch_start = n * batch_size;
+int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start);
+std::vector<Datum> output_slices(out_data_.size());
+for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) {
+output_slices[out_field_id] =
+out_data_[out_field_id]->Slice(batch_start, batch_length);
+}
+
+ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices));
+outputs_[0]->InputReceived(this, n, output_batch);
+
+uint32_t num_output_batches_processed =
+1 + num_output_batches_processed_.fetch_add(1);
+if (num_output_batches_processed * batch_size >= num_out_groups_) {
+finished_.MarkFinished();
+}
+
+return Status::OK();
+}
+
+Status OutputResult() {
+bool expected = false;
+if (!output_started_.compare_exchange_strong(expected, true)) {
+return Status::OK();
+}
+
+RETURN_NOT_OK(BuildSideMerge());
+RETURN_NOT_OK(Finalize());
+
+int batch_size = output_batch_size();
+int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size;
+outputs_[0]->InputFinished(this, num_result_batches);
+
+auto executor = arrow::internal::GetCpuThreadPool();
+for (int i = 0; i < num_result_batches; ++i) {
+// Check finished flag
+if (finished_.is_finished()) {
+break;
+}
+
+RETURN_NOT_OK(executor->Spawn([this, i]() {
+Status status = OutputNthBatch(i);
+if (!status.ok()) {
+ErrorReceived(inputs_[0], status);
+}
+}));
+}
+
+return Status::OK();
+}
+
+void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+assert(input == inputs_[0] || input == inputs_[1]);
+
+if (finished_.is_finished()) {
+return;
+}
+
+ARROW_DCHECK(num_build_batches_processed_.load() != num_build_batches_total_.load());
+
+Status status = ProcessBuildSideBatch(batch);
+if (!status.ok()) {
+ErrorReceived(input, status);
+return;
+}
+
+num_build_batches_processed_.fetch_add(1);
+if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
+status = OutputResult();
+if (!status.ok()) {
+ErrorReceived(input, status);
+return;
+}
+}
+}
+
+void ErrorReceived(ExecNode* input, Status error) override {
+DCHECK_EQ(input, inputs_[0]);
+
+outputs_[0]->ErrorReceived(this, std::move(error));
+StopProducing();
+}
+
+void InputFinished(ExecNode* input, int seq) override {
+DCHECK_EQ(input, inputs_[0]);
+
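+// (`seq` delivered by InputFinished is the total number of batches this input
+// will ever produce; it is stored below so consumption can detect completion.)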
+num_build_batches_total_.store(seq);
+if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
+Status status = OutputResult();
+
+if (!status.ok()) {
+ErrorReceived(input, status);
+}
+}
+}
+
+Status StartProducing() override {
+finished_ = Future<>::Make();
+return Status::OK();
+}
+
+void PauseProducing(ExecNode* output) override {}
+
+void ResumeProducing(ExecNode* output) override {}
+
+void StopProducing(ExecNode* output) override {
+DCHECK_EQ(output, outputs_[0]);
+inputs_[0]->StopProducing(this);
+
+finished_.MarkFinished();
+}
+
+void StopProducing() override { StopProducing(outputs_[0]); }
+
+Future<> finished() override { return finished_; }
+
+private:
+int output_batch_size() const {
+int result = static_cast<int>(ctx_->exec_chunksize());
+if (result < 0) {
+result = 32 * 1024;
+}
+return result;
+}
+
+ExecContext* ctx_;
+Future<> finished_ = Future<>::MakeFinished();
+
+std::atomic<int> num_build_batches_processed_;
+std::atomic<int> num_build_batches_total_;
+std::atomic<int> num_probe_batches_processed_;
+std::atomic<int> num_probe_batches_total_;
+std::atomic<int> num_output_batches_processed_;
+
+const std::vector<int> index_field_ids_;
+
+struct ThreadLocalState {
+std::unique_ptr<internal::Grouper> grouper;
+std::vector<ExecBatch> early_probe_batches{};
};
+SharedSequenceOfObjects<ThreadLocalState> local_states_;
+SmallUniqueIdAssignment local_state_id_assignment_;
+uint32_t num_out_groups_{0};
+ArrayDataVector out_data_;
+std::atomic<bool> output_started_, build_phase_finished_;
+};*/
 } // namespace compute
 } // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 4fd6af9b190..c1a03ce221e 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include
+
 #include
 #include
 #include
@@ -369,30 +371,43 @@ struct GrouperImpl : Grouper {
     return std::move(impl);
   }
 
-  Result<Datum> Consume(const ExecBatch& batch) override {
-    std::vector<int32_t> offsets_batch(batch.length + 1);
+  Status PopulateKeyData(const ExecBatch& batch, std::vector<int32_t>* offsets_batch,
+                         std::vector<uint8_t>* key_bytes_batch,
+                         std::vector<uint8_t*>* key_buf_ptrs) {
+    offsets_batch->resize(batch.length + 1);
     for (int i = 0; i < batch.num_values(); ++i) {
-      encoders_[i]->AddLength(*batch[i].array(), offsets_batch.data());
+      encoders_[i]->AddLength(*batch[i].array(), offsets_batch->data());
     }
 
     int32_t total_length = 0;
     for (int64_t i = 0; i < batch.length; ++i) {
       auto total_length_before = total_length;
-      total_length += offsets_batch[i];
-      offsets_batch[i] = total_length_before;
+      total_length += offsets_batch->at(i);
+      offsets_batch->at(i) = total_length_before;
     }
-    offsets_batch[batch.length] = total_length;
+    offsets_batch->at(batch.length) = total_length;
 
-    std::vector<uint8_t> key_bytes_batch(total_length);
-    std::vector<uint8_t*> key_buf_ptrs(batch.length);
+    key_bytes_batch->resize(total_length);
+    key_buf_ptrs->resize(batch.length);
     for (int64_t i = 0; i < batch.length; ++i) {
-      key_buf_ptrs[i] = key_bytes_batch.data() + offsets_batch[i];
+      key_buf_ptrs->at(i) = key_bytes_batch->data() + offsets_batch->at(i);
     }
 
     for (int i = 0; i < batch.num_values(); ++i) {
-      RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs.data()));
+      RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs->data()));
     }
 
+    return Status::OK();
+  }
+
+  Result<Datum> Consume(const ExecBatch& batch) override {
+    std::vector<int32_t> offsets_batch;
+    std::vector<uint8_t> key_bytes_batch;
+    std::vector<uint8_t*> key_buf_ptrs;
+
+    RETURN_NOT_OK(
+        PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs));
+
     TypedBufferBuilder<uint32_t> group_ids_batch(ctx_->memory_pool());
     RETURN_NOT_OK(group_ids_batch.Resize(batch.length));
 
@@ -421,6 +436,36 @@ struct GrouperImpl : Grouper {
     return Datum(UInt32Array(batch.length, std::move(group_ids)));
   }
 
+  Result<Datum> Find(const ExecBatch& batch) override {
+    std::vector<int32_t> offsets_batch;
+    std::vector<uint8_t> key_bytes_batch;
+    std::vector<uint8_t*> key_buf_ptrs;
+
+    RETURN_NOT_OK(
+        PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs));
+
+    UInt32Builder group_ids_batch(ctx_->memory_pool());
+    RETURN_NOT_OK(group_ids_batch.Resize(batch.length));
+
+    for (int64_t i = 0; i < batch.length; ++i) {
+      int32_t key_length = offsets_batch[i + 1] - offsets_batch[i];
+      std::string key(
+          reinterpret_cast<const char*>(key_bytes_batch.data() + offsets_batch[i]),
+          key_length);
+
+      auto it = map_.find(key);
+      // no group ID was found, null will be emitted!
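+      // (unlike Consume(), this is a read-only probe of map_; no new group is
+      // created for an unseen key, which is what makes the lookup thread-safe)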
+      if (it == map_.end()) {
+        group_ids_batch.UnsafeAppendNull();
+      } else {
+        group_ids_batch.UnsafeAppend(it->second);
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish());
+    return Datum(group_ids);
+  }
+
   uint32_t num_groups() const override { return num_groups_; }
 
   Result<ExecBatch> GetUniques() override {
@@ -622,6 +667,11 @@ struct GrouperFastImpl : Grouper {
     return Datum(UInt32Array(batch.length, std::move(group_ids)));
   }
 
+  Result<Datum> Find(const ExecBatch& batch) override {
+    // TODO(niranda): implement this for GrouperFastImpl
+    return Status::NotImplemented("Grouper::Find for GrouperFastImpl");
+  }
+
   uint32_t num_groups() const override { return static_cast<uint32_t>(rows_.length()); }
 
   // Make sure padded buffers end up with the right logical size
@@ -1929,9 +1979,10 @@ Result<FieldVector> ResolveKernels(
 
 Result<std::unique_ptr<Grouper>> Grouper::Make(const std::vector<ValueDescr>& descrs,
                                                ExecContext* ctx) {
-  if (GrouperFastImpl::CanUse(descrs)) {
-    return GrouperFastImpl::Make(descrs, ctx);
-  }
+  // TODO(niranda) re-enable this!
+  // if (GrouperFastImpl::CanUse(descrs)) {
+  //   return GrouperFastImpl::Make(descrs, ctx);
+  // }
   return GrouperImpl::Make(descrs, ctx);
 }
 
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index c69b51e71fc..b392f2db4bf 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -308,6 +308,16 @@ struct TestGrouper {
     AssertEquivalentIds(expected, ids);
   }
 
+  void ExpectFind(const std::string& key_json, const std::string& expected) {
+    ExpectFind(ExecBatchFromJSON(descrs_, key_json), ArrayFromJSON(uint32(), expected));
+  }
+
+  void ExpectFind(const ExecBatch& key_batch, Datum expected) {
+    ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Find(key_batch));
+    ValidateOutput(id_batch);
+
+    AssertDatumsEqual(expected, id_batch);
+  }
+
   void AssertEquivalentIds(const Datum& expected, const Datum& actual) {
     auto left = expected.make_array();
     auto right = actual.make_array();
@@ -433,6 +443,18 @@ TEST(Grouper, NumericKey) {
 
     g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]",
                     "[0, 1, 0, 1, 3, 2, 1, 2]");
+
+    g.ExpectFind("[[3], [3]]", "[0, 0]");
+
+    g.ExpectFind("[[3], [3]]", "[0, 0]");
+
+    g.ExpectFind("[[27], [81]]", "[1, 2]");
+
+    g.ExpectFind("[[3], [27], [3], [27], [null], [81], [27], [81]]",
+                 "[0, 1, 0, 1, 3, 2, 1, 2]");
+
+    g.ExpectFind("[[27], [3], [27], [null], [81], [1], [81]]",
+                 "[1, 0, 1, 3, 2, null, 2]");
   }
 }
 

From 749927457286502b0b12fc636cb06e163a527659 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Wed, 28 Jul 2021 14:44:56 -0400
Subject: [PATCH 03/27] incomplete

---
 cpp/src/arrow/compute/exec/exec_plan.cc | 452 ++++++++++++------------
 1 file changed, 232 insertions(+), 220 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index b102264331e..9036052ac96 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -346,290 +346,302 @@ ExecFactoryRegistry* default_exec_factory_registry() {
   return &instance;
 }
 
-/*struct HashSemiIndexJoinNode : ExecNode {
-  HashSemiIndexJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label,
+struct HashSemiIndexJoinNode : ExecNode {
+  HashSemiIndexJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label,
                         std::shared_ptr<Schema> output_schema, ExecContext* ctx,
                         const std::vector<int>&& index_field_ids)
-      : ExecNode(left_input->plan(), std::move(label), {left_input, right_input},
-                 {"hashsemiindexjoin"}, std::move(output_schema), */
-/*num_outputs=*//*1),
-ctx_(ctx),
-num_build_batches_processed_(0),
-num_build_batches_total_(-1),
-num_probe_batches_processed_(0),
-num_probe_batches_total_(-1),
-num_output_batches_processed_(0),
-index_field_ids_(std::move(index_field_ids)),
-output_started_(false),
-build_phase_finished_(false){}
+      : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input},
+                 {"hash_join_build", "hash_join_probe"}, std::move(output_schema),
+                 /*num_outputs=*/1),
+        ctx_(ctx),
+        index_field_ids_(std::move(index_field_ids)) {}
 
   const char* kind_name() override { return "HashSemiIndexJoinNode"; }
 
  public:
   Status InitLocalStateIfNeeded(ThreadLocalState* state) {
     // Get input schema
     auto input_schema = inputs_[0]->output_schema();
 
     if (!state->grouper) {
       // Build vector of key field data types
       std::vector<ValueDescr> key_descrs(index_field_ids_.size());
       for (size_t i = 0; i < index_field_ids_.size(); ++i) {
         auto key_field_id = index_field_ids_[i];
         key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
       }
 
       // Construct grouper
       ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
     }
 
     return Status::OK();
   }
 
   Status ProcessBuildSideBatch(const ExecBatch& batch) {
     SmallUniqueIdHolder id_holder(&local_state_id_assignment_);
     int id = id_holder.get();
     ThreadLocalState* state = local_states_.get(id);
     RETURN_NOT_OK(InitLocalStateIfNeeded(state));
 
     // Create a batch with key columns
     std::vector<Datum> keys(key_field_ids_.size());
     for (size_t i = 0; i < key_field_ids_.size(); ++i) {
       keys[i] = batch.values[key_field_ids_[i]];
     }
     ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
 
     // Create a batch with group ids
     ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
 
     // Execute aggregate kernels
     for (size_t i = 0; i < agg_kernels_.size(); ++i) {
       KernelContext kernel_ctx{ctx_};
       kernel_ctx.SetState(state->agg_states[i].get());
 
       ARROW_ASSIGN_OR_RAISE(
           auto agg_batch,
           ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));
 
       RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
       RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
     }
 
     return Status::OK();
   }
 
   // merge all other groupers to grouper[0]. nothing needs to be done on the
   // early_probe_batches, because when probing, everyone queries grouper[0]
   Status BuildSideMerge() {
     int num_local_states = local_state_id_assignment_.num_ids();
     ThreadLocalState* state0 = local_states_.get(0);
     for (int i = 1; i < num_local_states; ++i) {
       ThreadLocalState* state = local_states_.get(i);
       ARROW_DCHECK(state);
       ARROW_DCHECK(state->grouper);
       ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
       ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys));
       state->grouper.reset();
     }
     return Status::OK();
   }
 
   Status Finalize() {
     out_data_.resize(agg_kernels_.size() + key_field_ids_.size());
     auto it = out_data_.begin();
 
     ThreadLocalState* state = local_states_.get(0);
     num_out_groups_ = state->grouper->num_groups();
 
     // Aggregate fields come before key fields to match the behavior of GroupBy function
 
     for (size_t i = 0; i < agg_kernels_.size(); ++i) {
       KernelContext batch_ctx{ctx_};
       batch_ctx.SetState(state->agg_states[i].get());
       Datum out;
       RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out));
       *it++ = out.array();
       state->agg_states[i].reset();
     }
 
     ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
     for (const auto& key : out_keys.values) {
       *it++ = key.array();
     }
     state->grouper.reset();
 
     return Status::OK();
   }
 
   Status OutputNthBatch(int n) {
     ARROW_DCHECK(output_started_.load());
 
     // Check finished flag
     if (finished_.is_finished()) {
       return Status::OK();
     }
 
     // Slice arrays
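    // (slicing is zero-copy: each output batch only adjusts offset/length over the
    // arrays already finalized into out_data_)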
     int64_t batch_size = output_batch_size();
     int64_t batch_start = n * batch_size;
     int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start);
     std::vector<Datum> output_slices(out_data_.size());
     for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) {
       output_slices[out_field_id] =
           out_data_[out_field_id]->Slice(batch_start, batch_length);
     }
 
     ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices));
     outputs_[0]->InputReceived(this, n, output_batch);
 
     uint32_t num_output_batches_processed =
         1 + num_output_batches_processed_.fetch_add(1);
     if (num_output_batches_processed * batch_size >= num_out_groups_) {
       finished_.MarkFinished();
     }
 
     return Status::OK();
   }
 
   Status OutputResult() {
     bool expected = false;
     if (!output_started_.compare_exchange_strong(expected, true)) {
       return Status::OK();
     }
 
     RETURN_NOT_OK(BuildSideMerge());
     RETURN_NOT_OK(Finalize());
 
     int batch_size = output_batch_size();
     int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size;
     outputs_[0]->InputFinished(this, num_result_batches);
 
     auto executor = arrow::internal::GetCpuThreadPool();
     for (int i = 0; i < num_result_batches; ++i) {
       // Check finished flag
       if (finished_.is_finished()) {
         break;
       }
 
       RETURN_NOT_OK(executor->Spawn([this, i]() {
         Status status = OutputNthBatch(i);
         if (!status.ok()) {
           ErrorReceived(inputs_[0], status);
         }
       }));
     }
 
     return Status::OK();
   }
 
+  Status ProcessBuildBatch(const ExecBatch& batch) {
+    // if build side is still going on
+    return Status::OK();
+  }
+
+  Status ProcessCachedProbeBatches() { return Status::OK(); }
+
+  Status ProcessProbeBatch(const ExecBatch& batch) { return Status::OK(); }
+
+  Status CacheProbeBatch(const ExecBatch& batch) { return Status::OK(); }
+
+  // If all build side batches received? continue streaming using probing
+  // else cache the batches in thread-local state
   void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
-    assert(input == inputs_[0] || input == inputs_[1]);
+    ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]);
+
+    size_t thread_index = get_thread_index_();
+    ARROW_DCHECK(thread_index < local_states_.size());
 
     if (finished_.is_finished()) {
       return;
     }
 
+    if (build_counter_.IsComplete()) {  // build side complete!
+      ARROW_DCHECK(input != inputs_[0]);  // if a build batch is received, error!
+      if (ErrorIfNotOk(ProcessProbeBatch(batch))) return;
+    } else {  // build side is still processing!
+      if (input == inputs_[1]) {  // if a probe batch is received, cache it!
+        if (ErrorIfNotOk(CacheProbeBatch(batch))) return;
+      } else {  // else process build batch
+        if (ErrorIfNotOk(ProcessBuildSideBatch(batch))) return;
+      }
+    }
+
     ARROW_DCHECK(num_build_batches_processed_.load() != num_build_batches_total_.load());
 
     Status status = ProcessBuildSideBatch(batch);
     if (!status.ok()) {
       ErrorReceived(input, status);
       return;
     }
 
     num_build_batches_processed_.fetch_add(1);
     if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
       status = OutputResult();
       if (!status.ok()) {
         ErrorReceived(input, status);
         return;
       }
     }
   }
 
   void ErrorReceived(ExecNode* input, Status error) override {
     DCHECK_EQ(input, inputs_[0]);
 
     outputs_[0]->ErrorReceived(this, std::move(error));
     StopProducing();
   }
 
   void InputFinished(ExecNode* input, int seq) override {
     DCHECK_EQ(input, inputs_[0]);
 
     num_build_batches_total_.store(seq);
     if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
       Status status = OutputResult();
 
       if (!status.ok()) {
         ErrorReceived(input, status);
       }
     }
   }
 
   Status StartProducing() override {
     finished_ = Future<>::Make();
     return Status::OK();
   }
 
   void PauseProducing(ExecNode* output) override {}
 
   void ResumeProducing(ExecNode* output) override {}
 
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
     inputs_[0]->StopProducing(this);
 
     finished_.MarkFinished();
   }
 
   void StopProducing() override { StopProducing(outputs_[0]); }
 
   Future<> finished() override { return finished_; }
 
  private:
   int output_batch_size() const {
     int result = static_cast<int>(ctx_->exec_chunksize());
     if (result < 0) {
       result = 32 * 1024;
     }
     return result;
   }
 
+  struct ThreadLocalState {
+    std::unique_ptr<internal::Grouper> grouper;
+    std::vector<ExecBatch> early_probe_batches{};
+  };
+
   ExecContext* ctx_;
   Future<> finished_ = Future<>::MakeFinished();
 
+  ThreadIndexer get_thread_index_;
   const std::vector<int> index_field_ids_;
 
+  AtomicCounter build_counter_, probe_counter_, out_counter_;
+  std::vector<ThreadLocalState> local_states_;
+  ExecBatch out_data_;
 };
 
 } // namespace compute
 } // namespace arrow

From 1a6ae7a5202ee7e293abf0def73fdf6e69240f16 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Thu, 29 Jul 2021 15:29:35 -0400
Subject: [PATCH 04/27] mid way

---
 cpp/src/arrow/compute/api_aggregate.h       |   2 +-
 cpp/src/arrow/compute/exec/exec_plan.cc     | 244 +++++++++++-------
 .../arrow/compute/kernels/hash_aggregate.cc |   6 +-
 3 files changed, 151 insertions(+), 101 deletions(-)

diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 87827889ca0..2798cb7ed04 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -386,7 +386,7 @@ class ARROW_EXPORT Grouper {
   /// Finds (queries) the group ID for each row of the given ExecBatch. Returns the
   /// group IDs as an integer array. If a key is not found, a null is emitted at
   /// that index. This is a thread-safe lookup.
-  virtual Result<Datum> Find(const ExecBatch& batch) = 0;
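+  /// (const: a lookup never inserts new groups, so a merged grouper can be
+  /// probed from multiple threads without further synchronization)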
+  virtual Result<Datum> Find(const ExecBatch& batch) const = 0;
 
   /// Get current unique keys. May be called multiple times.
   virtual Result<ExecBatch> GetUniques() = 0;
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 9036052ac96..f21d6e0a0b2 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -354,7 +354,11 @@ struct HashSemiIndexJoinNode : ExecNode {
                  {"hash_join_build", "hash_join_probe"}, std::move(output_schema),
                  /*num_outputs=*/1),
         ctx_(ctx),
-        index_field_ids_(std::move(index_field_ids)) {}
+        index_field_ids_(index_field_ids),
+        build_side_complete_(false) {}
+
+ private:
+  struct ThreadLocalState;
 
   const char* kind_name() override { return "HashSemiIndexJoinNode"; }
 
  public:
   Status InitLocalStateIfNeeded(ThreadLocalState* state) {
     // Get input schema
     auto input_schema = inputs_[0]->output_schema();
 
-    if (!state->grouper) {
-      // Build vector of key field data types
-      std::vector<ValueDescr> key_descrs(index_field_ids_.size());
-      for (size_t i = 0; i < index_field_ids_.size(); ++i) {
-        auto key_field_id = index_field_ids_[i];
-        key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
-      }
+    if (state->grouper != nullptr) return Status::OK();
 
-      // Construct grouper
-      ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+    // Build vector of key field data types
+    std::vector<ValueDescr> key_descrs(index_field_ids_.size());
+    for (size_t i = 0; i < index_field_ids_.size(); ++i) {
+      auto idx_field_id = index_field_ids_[i];
+      key_descrs[i] = ValueDescr(input_schema->field(idx_field_id)->type());
     }
 
+    // Construct grouper
+    ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+
     return Status::OK();
   }
 
+  // merge all other groupers to grouper[0]. nothing needs to be done on the
+  // cached_probe_batches, because when probing, everyone queries grouper[0]
+  Status BuildSideMerge() {
+    ThreadLocalState* state0 = &local_states_[0];
+    for (int i = 1; i < local_states_.size(); ++i) {
+      ThreadLocalState* state = &local_states_[i];
+      ARROW_DCHECK(state);
+      ARROW_DCHECK(state->grouper);
+      ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+      ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys));
+      state->grouper.reset();
+    }
+    return Status::OK();
+  }
+
+  // consumes a build batch and increments the build-batch count. If the build-batch
+  // total is reached at the end of consumption, all the local states will be merged
+  // before incrementing the total batches
+  Status ConsumeBuildBatch(const size_t thread_index, ExecBatch batch) {
+    auto state = &local_states_[thread_index];
+    RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(index_field_ids_.size());
+    for (size_t i = 0; i < index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Create a batch with group ids
+    ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));
+
+    if (build_counter_.Increment()) {
+      // while incrementing, if the total is reached, merge all the groupers to 0'th one
+      RETURN_NOT_OK(BuildSideMerge());
+
+      // enable flag that build side is completed
+      build_side_complete_.store(true);
+
+      // since the build side is completed, consume cached probe batches
+      RETURN_NOT_OK(ConsumeCachedProbeBatches(thread_index));
+    }
+
+    return Status::OK();
+  }
+
+  Result<ExecBatch> Finalize() {
+    ThreadLocalState* state = &local_states_[0];
+
+    ExecBatch out_data{{}, state->grouper->num_groups()};
+    out_data.values.resize(agg_kernels_.size() + key_field_ids_.size());
+
+    // Aggregate fields come before key fields to match the behavior of GroupBy function
+    for (size_t i = 0; i < agg_kernels_.size(); ++i) {
+      KernelContext batch_ctx{ctx_};
+      batch_ctx.SetState(state->agg_states[i].get());
+      RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i]));
+      state->agg_states[i].reset();
+    }
+
+    ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
+    std::move(out_keys.values.begin(), out_keys.values.end(),
+              out_data.values.begin() + agg_kernels_.size());
+    state->grouper.reset();
+
+    if (output_counter_.SetTotal(
+            static_cast<int>(BitUtil::CeilDiv(out_data.length, output_batch_size())))) {
+      // this will be hit if out_data.length == 0
+      finished_.MarkFinished();
+    }
+    return out_data;
+  }
+
+  Status ConsumeCachedProbeBatches(const size_t thread_index) {
+    ThreadLocalState* state = &local_states_[thread_index];
+
+    // TODO (niranda) check if this is the best way to move batches
+    for (ExecBatch batch : state->cached_probe_batches) {
+      RETURN_NOT_OK(ConsumeProbeBatch(std::move(batch)));
+    }
+    state->cached_probe_batches.clear();
+
+    return Status::OK();
+  }
+
+  // consumes a probe batch and increments the probe batches count. Probing queries
+  // grouper[0], which has been merged with all others.
+  Status ConsumeProbeBatch(ExecBatch batch) {
+    auto* grouper = local_states_[0].grouper.get();
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(index_field_ids_.size());
+    for (size_t i = 0; i < index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Query the grouper with key_batch. If no match was found, the returned group_ids
+    // would have a null.
+    ARROW_ASSIGN_OR_RAISE(Datum group_ids, grouper->Find(key_batch));
+    auto group_ids_data = *group_ids.array();
+
+    auto filter_arr =
+        std::make_shared<BooleanArray>(group_ids_data.length, group_ids_data.buffers[0],
+                                       /*null_bitmap=*/nullptr, /*null_count=*/0,
+                                       /*offset=*/group_ids_data.offset);
+    Filter();
+ ARROW_DCHECK(!build_side_complete_.load()); - Status status = ProcessBuildSideBatch(batch); - if (!status.ok()) { - ErrorReceived(input, status); - return; - } - - num_build_batches_processed_.fetch_add(1); - if (num_build_batches_processed_.load() == num_build_batches_total_.load()) { - status = OutputResult(); - if (!status.ok()) { - ErrorReceived(input, status); - return; + if (ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch)))) return; + } else { // probe input batch is received + if (build_side_complete_.load()) { // build side done, continue with probing + if (ErrorIfNotOk(ConsumeProbeBatch(std::move(batch)))) return; + } else { // build side not completed. Cache this batch! + if (ErrorIfNotOk(CacheProbeBatch(thread_index, std::move(batch)))) return; } } } @@ -573,21 +596,31 @@ struct HashSemiIndexJoinNode : ExecNode { StopProducing(); } - void InputFinished(ExecNode* input, int seq) override { - DCHECK_EQ(input, inputs_[0]); + void InputFinished(ExecNode* input, int num_total) override { + // bail if StopProducing was called + if (finished_.is_finished()) return; - num_build_batches_total_.store(seq); - if (num_build_batches_processed_.load() == num_build_batches_total_.load()) { - Status status = OutputResult(); + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); - if (!status.ok()) { - ErrorReceived(input, status); - } + // set total for build input + if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { + // only build side has completed! so process cached probe batches (of this thread) + ErrorIfNotOk(ConsumeCachedProbeBatches(get_thread_index_())); + return; } + + // set total for probe input. If it returns that probe side has completed, nothing to + // do, because probing inputs will be streamed to the output + probe_counter_.SetTotal(num_total); + + // output will be streamed from the probe side. So, they will have the same total. + out_counter_.SetTotal(num_total); } Status StartProducing() override { finished_ = Future<>::Make(); + + local_states_.resize(ThreadIndexer::Capacity()); return Status::OK(); } @@ -596,13 +629,23 @@ struct HashSemiIndexJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - DCHECK_EQ(output, outputs_[0]); - inputs_[0]->StopProducing(this); + // DCHECK_EQ(output, outputs_[0]); - finished_.MarkFinished(); + if (build_counter_.Cancel() || probe_counter_.Cancel() || out_counter_.Cancel()) { + finished_.MarkFinished(); + } + + for (auto&& input : inputs_) { + input->StopProducing(this); + } } - void StopProducing() override { StopProducing(outputs_[0]); } + // TODO(niranda) couldn't there be multiple outputs for a Node? + void StopProducing() override { + for (auto&& output : outputs_) { + StopProducing(output); + } + } Future<> finished() override { return finished_; } @@ -617,7 +660,7 @@ struct HashSemiIndexJoinNode : ExecNode { struct ThreadLocalState { std::unique_ptr grouper; - std::vector early_probe_batches{}; + std::vector cached_probe_batches{}; }; ExecContext* ctx_; @@ -628,6 +671,13 @@ struct HashSemiIndexJoinNode : ExecNode { AtomicCounter build_counter_, probe_counter_, out_counter_; std::vector local_states_; + + // need a separate atomic bool to track if the build side complete. Can't use the flag + // inside the AtomicCounter, because we need to merge the build groupers once we receive + // all the build batches. So, while merging, we need to prevent probe batches, being + // consumed. 
+ std::atomic build_side_complete_; + ExecBatch out_data_; }; diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index c1a03ce221e..09f832f2740 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -373,7 +373,7 @@ struct GrouperImpl : Grouper { Status PopulateKeyData(const ExecBatch& batch, std::vector* offsets_batch, std::vector* key_bytes_batch, - std::vector* key_buf_ptrs) { + std::vector* key_buf_ptrs) const { offsets_batch->resize(batch.length + 1); for (int i = 0; i < batch.num_values(); ++i) { encoders_[i]->AddLength(*batch[i].array(), offsets_batch->data()); @@ -436,7 +436,7 @@ struct GrouperImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } - Result Find(const ExecBatch& batch) override { + Result Find(const ExecBatch& batch) const override { std::vector offsets_batch; std::vector key_bytes_batch; std::vector key_buf_ptrs; @@ -667,7 +667,7 @@ struct GrouperFastImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } - Result Find(const ExecBatch& batch) override { + Result Find(const ExecBatch& batch) const override { // todo impl this return Result(); } From 5822c7b9ea499c81f40adec9611e3a45bd471e2b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 29 Jul 2021 23:02:48 -0400 Subject: [PATCH 05/27] untested --- cpp/src/arrow/compute/exec/exec_plan.cc | 163 ++++++------------------ 1 file changed, 40 insertions(+), 123 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index f21d6e0a0b2..c5e715f0619 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -346,10 +346,10 @@ ExecFactoryRegistry* default_exec_factory_registry() { return &instance; } -struct HashSemiIndexJoinNode : ExecNode { - HashSemiIndexJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, - std::shared_ptr output_schema, ExecContext* ctx, - const std::vector&& index_field_ids) +struct HashSemiJoinNode : ExecNode { + HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, + std::shared_ptr output_schema, ExecContext* ctx, + const std::vector&& index_field_ids) : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input}, {"hash_join_build", "hash_join_probe"}, std::move(output_schema), /*num_outputs=*/1), @@ -360,9 +360,9 @@ struct HashSemiIndexJoinNode : ExecNode { private: struct ThreadLocalState; - const char* kind_name() override { return "HashSemiIndexJoinNode"; } - public: + const char* kind_name() override { return "HashSemiJoinNode"; } + Status InitLocalStateIfNeeded(ThreadLocalState* state) { // Get input schema auto input_schema = inputs_[0]->output_schema(); @@ -428,99 +428,12 @@ struct HashSemiIndexJoinNode : ExecNode { return Status::OK(); } - Status Finalize() { - ThreadLocalState* state = &local_states_[0]; - - ExecBatch out_data{{}, state->grouper->num_groups()}; - out_data.values.resize(agg_kernels_.size() + key_field_ids_.size()); - - // Aggregate fields come before key fields to match the behavior of GroupBy function - for (size_t i = 0; i < agg_kernels_.size(); ++i) { - KernelContext batch_ctx{ctx_}; - batch_ctx.SetState(state->agg_states[i].get()); - RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i])); - state->agg_states[i].reset(); - } - - ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques()); - 
std::move(out_keys.values.begin(), out_keys.values.end(), - out_data.values.begin() + agg_kernels_.size()); - state->grouper.reset(); - - if (output_counter_.SetTotal( - static_cast(BitUtil::CeilDiv(out_data.length, output_batch_size())))) { - // this will be hit if out_data.length == 0 - finished_.MarkFinished(); - } - return out_data; - } - - Status OutputNthBatch(int n) { - ARROW_DCHECK(output_started_.load()); - - // Check finished flag - if (finished_.is_finished()) { - return Status::OK(); - } - - // Slice arrays - int64_t batch_size = output_batch_size(); - int64_t batch_start = n * batch_size; - int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start); - std::vector output_slices(out_data_.size()); - for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) { - output_slices[out_field_id] = - out_data_[out_field_id]->Slice(batch_start, batch_length); - } - - ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices)); - outputs_[0]->InputReceived(this, n, output_batch); - - uint32_t num_output_batches_processed = - 1 + num_output_batches_processed_.fetch_add(1); - if (num_output_batches_processed * batch_size >= num_out_groups_) { - finished_.MarkFinished(); - } - - return Status::OK(); - } - - Status OutputResult() { - bool expected = false; - if (!output_started_.compare_exchange_strong(expected, true)) { - return Status::OK(); - } - - RETURN_NOT_OK(Finalize()); - - int batch_size = output_batch_size(); - int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size; - outputs_[0]->InputFinished(this, num_result_batches); - - auto executor = arrow::internal::GetCpuThreadPool(); - for (int i = 0; i < num_result_batches; ++i) { - // Check finished flag - if (finished_.is_finished()) { - break; - } - - RETURN_NOT_OK(executor->Spawn([this, i]() { - Status status = OutputNthBatch(i); - if (!status.ok()) { - ErrorReceived(inputs_[0], status); - } - })); - } - - return Status::OK(); - } - Status ConsumeCachedProbeBatches(const size_t thread_index) { ThreadLocalState* state = &local_states_[thread_index]; // TODO (niranda) check if this is the best way to move batches - for (ExecBatch batch : state->cached_probe_batches) { - RETURN_NOT_OK(ConsumeProbeBatch(std::move(batch))); + for (auto cached : state->cached_probe_batches) { + RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); } state->cached_probe_batches.clear(); @@ -529,7 +442,7 @@ struct HashSemiIndexJoinNode : ExecNode { // consumes a probe batch and increment probe batches count. Probing would query the // grouper[0] which have been merged with all others. 
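// [Editorial aside — not part of the patch] The pair<seq, batch> cache that the
// hunk above introduces is the crux of the streaming design: probe batches that
// arrive before the hash table is complete are parked per thread together with
// their sequence number, then replayed once the build side finishes, so
// downstream nodes still observe the original batch numbering. A minimal sketch
// of that cache-then-replay pattern, assuming Arrow's Status/ExecBatch and the
// surrounding file's includes (ProbeBatchCache and its members are hypothetical
// names, not part of this patch):
//
struct ProbeBatchCache {
  // (seq, batch) pairs; seq is preserved so the eventual output keeps the
  // ordering the probe source assigned.
  std::vector<std::pair<int, ExecBatch>> cached;

  void Add(int seq, ExecBatch batch) { cached.emplace_back(seq, std::move(batch)); }

  // Consume is any callable with signature Status(int seq, ExecBatch batch).
  template <typename Consume>
  Status Replay(Consume&& consume) {
    for (auto& entry : cached) {
      RETURN_NOT_OK(consume(entry.first, std::move(entry.second)));
    }
    cached.clear();
    return Status::OK();
  }
};
// [End aside]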
- Status ConsumeProbeBatch(ExecBatch batch) { + Status ConsumeProbeBatch(int seq, ExecBatch batch) { auto* grouper = local_states_[0].grouper.get(); // Create a batch with key columns @@ -544,24 +457,34 @@ struct HashSemiIndexJoinNode : ExecNode { ARROW_ASSIGN_OR_RAISE(Datum group_ids, grouper->Find(key_batch)); auto group_ids_data = *group_ids.array(); - auto filter_arr = - std::make_shared(group_ids_data.length, group_ids_data.buffers[0], - /*null_bitmap=*/nullptr, /*null_count=*/0, - /*offset=*/group_ids_data.offset); - Filter(); - - probe_counter_.Increment(); + if (group_ids_data.MayHaveNulls()) { // values need to be filtered + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { // all values are valid for output + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } + + out_counter_.Increment(); return Status::OK(); } - Status CacheProbeBatch(const size_t thread_index, ExecBatch batch) { + Status CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { ThreadLocalState* state = &local_states_[thread_index]; - state->cached_probe_batches.push_back(std::move(batch)); + state->cached_probe_batches.emplace_back(seq_num, std::move(batch)); return Status::OK(); } inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } - inline bool IsProbeInput(ExecNode* input) { return input == inputs_[1]; } // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state @@ -582,9 +505,9 @@ struct HashSemiIndexJoinNode : ExecNode { if (ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch)))) return; } else { // probe input batch is received if (build_side_complete_.load()) { // build side done, continue with probing - if (ErrorIfNotOk(ConsumeProbeBatch(std::move(batch)))) return; + if (ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch)))) return; } else { // build side not completed. Cache this batch! - if (ErrorIfNotOk(CacheProbeBatch(thread_index, std::move(batch)))) return; + if (ErrorIfNotOk(CacheProbeBatch(thread_index, seq, std::move(batch)))) return; } } } @@ -611,10 +534,13 @@ struct HashSemiIndexJoinNode : ExecNode { // set total for probe input. If it returns that probe side has completed, nothing to // do, because probing inputs will be streamed to the output - probe_counter_.SetTotal(num_total); + // probe_counter_.SetTotal(num_total); // output will be streamed from the probe side. So, they will have the same total. - out_counter_.SetTotal(num_total); + if (out_counter_.SetTotal(num_total)) { + // if out_counter has completed, the future is finished! 
+ finished_.MarkFinished(); + } } Status StartProducing() override { @@ -629,9 +555,9 @@ struct HashSemiIndexJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - // DCHECK_EQ(output, outputs_[0]); + DCHECK_EQ(output, outputs_[0]); - if (build_counter_.Cancel() || probe_counter_.Cancel() || out_counter_.Cancel()) { + if (build_counter_.Cancel() || /*probe_counter_.Cancel() ||*/ out_counter_.Cancel()) { finished_.MarkFinished(); } @@ -650,17 +576,10 @@ struct HashSemiIndexJoinNode : ExecNode { Future<> finished() override { return finished_; } private: - int output_batch_size() const { - int result = static_cast(ctx_->exec_chunksize()); - if (result < 0) { - result = 32 * 1024; - } - return result; - } struct ThreadLocalState { std::unique_ptr grouper; - std::vector cached_probe_batches{}; + std::vector> cached_probe_batches{}; }; ExecContext* ctx_; @@ -669,7 +588,7 @@ struct HashSemiIndexJoinNode : ExecNode { ThreadIndexer get_thread_index_; const std::vector index_field_ids_; - AtomicCounter build_counter_, probe_counter_, out_counter_; + AtomicCounter build_counter_, /*probe_counter_,*/ out_counter_; std::vector local_states_; // need a separate atomic bool to track if the build side complete. Can't use the flag @@ -677,8 +596,6 @@ struct HashSemiIndexJoinNode : ExecNode { // all the build batches. So, while merging, we need to prevent probe batches, being // consumed. std::atomic build_side_complete_; - - ExecBatch out_data_; }; } // namespace compute From caae9dba14d1b64cd317446652b6a2ff80e3983f Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 30 Jul 2021 15:46:38 -0400 Subject: [PATCH 06/27] code complete --- cpp/src/arrow/compute/exec/exec_plan.cc | 112 ++++++++++++++++++++---- cpp/src/arrow/compute/exec/exec_plan.h | 52 +++++++++++ 2 files changed, 145 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index c5e715f0619..e6d316fbe1e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -348,13 +348,14 @@ ExecFactoryRegistry* default_exec_factory_registry() { struct HashSemiJoinNode : ExecNode { HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, - std::shared_ptr output_schema, ExecContext* ctx, - const std::vector&& index_field_ids) + ExecContext* ctx, const std::vector&& build_index_field_ids, + const std::vector&& probe_index_field_ids) : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input}, - {"hash_join_build", "hash_join_probe"}, std::move(output_schema), + {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), /*num_outputs=*/1), ctx_(ctx), - index_field_ids_(index_field_ids), + build_index_field_ids_(build_index_field_ids), + probe_index_field_ids_(probe_index_field_ids), build_side_complete_(false) {} private: @@ -365,15 +366,15 @@ struct HashSemiJoinNode : ExecNode { Status InitLocalStateIfNeeded(ThreadLocalState* state) { // Get input schema - auto input_schema = inputs_[0]->output_schema(); + auto build_schema = inputs_[0]->output_schema(); if (state->grouper != nullptr) return Status::OK(); // Build vector of key field data types - std::vector key_descrs(index_field_ids_.size()); - for (size_t i = 0; i < index_field_ids_.size(); ++i) { - auto idx_field_id = index_field_ids_[i]; - key_descrs[i] = ValueDescr(input_schema->field(idx_field_id)->type()); + std::vector key_descrs(build_index_field_ids_.size()); + 
for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + auto build_type = build_schema->field(build_index_field_ids_[i])->type(); + key_descrs[i] = ValueDescr(build_type); } // Construct grouper @@ -405,9 +406,9 @@ struct HashSemiJoinNode : ExecNode { RETURN_NOT_OK(InitLocalStateIfNeeded(state)); // Create a batch with key columns - std::vector keys(index_field_ids_.size()); - for (size_t i = 0; i < index_field_ids_.size(); ++i) { - keys[i] = batch.values[index_field_ids_[i]]; + std::vector keys(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + keys[i] = batch.values[build_index_field_ids_[i]]; } ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); @@ -446,9 +447,9 @@ struct HashSemiJoinNode : ExecNode { auto* grouper = local_states_[0].grouper.get(); // Create a batch with key columns - std::vector keys(index_field_ids_.size()); - for (size_t i = 0; i < index_field_ids_.size(); ++i) { - keys[i] = batch.values[index_field_ids_[i]]; + std::vector keys(probe_index_field_ids_.size()); + for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) { + keys[i] = batch.values[probe_index_field_ids_[i]]; } ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); @@ -502,12 +503,12 @@ struct HashSemiJoinNode : ExecNode { // if a build input is received when build side is completed, something's wrong! ARROW_DCHECK(!build_side_complete_.load()); - if (ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch)))) return; + ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch))); } else { // probe input batch is received if (build_side_complete_.load()) { // build side done, continue with probing - if (ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch)))) return; + ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); } else { // build side not completed. Cache this batch! 
- if (ErrorIfNotOk(CacheProbeBatch(thread_index, seq, std::move(batch)))) return; + ErrorIfNotOk(CacheProbeBatch(thread_index, seq, std::move(batch))); } } } @@ -586,7 +587,7 @@ struct HashSemiJoinNode : ExecNode { Future<> finished_ = Future<>::MakeFinished(); ThreadIndexer get_thread_index_; - const std::vector index_field_ids_; + const std::vector build_index_field_ids_, probe_index_field_ids_; AtomicCounter build_counter_, /*probe_counter_,*/ out_counter_; std::vector local_states_; @@ -598,5 +599,78 @@ struct HashSemiJoinNode : ExecNode { std::atomic build_side_complete_; }; +Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, + const std::vector& left_keys, + const std::vector& right_keys) { + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("left and right key sizes do not match"); + } + + const auto& l_schema = left_input->output_schema(); + const auto& r_schema = right_input->output_schema(); + for (size_t i = 0; i < left_keys.size(); i++) { + auto l_type = l_schema->GetFieldByName(left_keys[i])->type(); + auto r_type = r_schema->GetFieldByName(right_keys[i])->type(); + + if (!l_type->Equals(r_type)) { + return Status::Invalid("build and probe types do not match: " + l_type->ToString() + + "!=" + r_type->ToString()); + } + } + + return Status::OK(); +} + +Result> PopulateKeys(const Schema& schema, + const std::vector& keys) { + std::vector key_field_ids(keys.size()); + // Find input field indices for left key fields + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(keys[i]).FindOne(schema)); + key_field_ids[i] = match[0]; + } + + return key_field_ids; +} + +Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, + std::string label, + const std::vector& build_keys, + const std::vector& probe_keys) { + RETURN_NOT_OK(ValidateJoinInputs(build_input, probe_input, build_keys, probe_keys)); + + auto build_schema = build_input->output_schema(); + auto probe_schema = probe_input->output_schema(); + + ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); + ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); + + // output schema will be probe schema + auto ctx = build_input->plan()->exec_context(); + ExecPlan* plan = build_input->plan(); + + return plan->EmplaceNode(build_input, probe_input, std::move(label), + ctx, std::move(build_key_ids), + std::move(probe_key_ids)); +} + +Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + const std::vector& left_keys, + const std::vector& right_keys) { + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, + left_keys); +} + +Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + const std::vector& left_keys, + const std::vector& right_keys) { + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, + right_keys); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 4a784ceb75b..3250a75633f 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -353,5 +353,57 @@ std::shared_ptr MakeGeneratorReader( std::shared_ptr, std::function>()>, MemoryPool*); +/// \brief Make a node which excludes some rows from 
batches passed through it +/// +/// The filter Expression will be evaluated against each batch which is pushed to +/// this node. Any rows for which the filter does not evaluate to `true` will be excluded +/// in the batch emitted by this node. +/// +/// If the filter is not already bound, it will be bound against the input's schema. +ARROW_EXPORT +Result MakeFilterNode(ExecNode* input, std::string label, Expression filter); + +/// \brief Make a node which executes expressions on input batches, producing new batches. +/// +/// Each expression will be evaluated against each batch which is pushed to +/// this node to produce a corresponding output column. +/// +/// If exprs are not already bound, they will be bound against the input's schema. +/// If names are not provided, the string representations of exprs will be used. +ARROW_EXPORT +Result MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs, + std::vector names = {}); + +ARROW_EXPORT +Result MakeScalarAggregateNode(ExecNode* input, std::string label, + std::vector aggregates, + std::vector arguments, + std::vector out_field_names); + +/// \brief Make a node which groups input rows based on key fields and computes +/// aggregates for each group +ARROW_EXPORT +Result MakeGroupByNode(ExecNode* input, std::string label, + std::vector keys, + std::vector agg_srcs, + std::vector aggs); + +ARROW_EXPORT +Result GroupByUsingExecPlan(const std::vector& arguments, + const std::vector& keys, + const std::vector& aggregates, + bool use_threads, ExecContext* ctx); + +/// \brief +Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + std::vector left_keys, + std::vector right_keys); +ARROW_EXPORT +Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + const std::vector& left_keys, + const std::vector& right_keys); } // namespace compute } // namespace arrow From 1d6ec313ac2bdadcaa8adbbbd093a08264cab85f Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 30 Jul 2021 16:04:53 -0400 Subject: [PATCH 07/27] adding test case dummy --- cpp/src/arrow/compute/exec/exec_plan.cc | 2 +- cpp/src/arrow/compute/exec/plan_test.cc | 28 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index e6d316fbe1e..47727a1e111 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -387,7 +387,7 @@ struct HashSemiJoinNode : ExecNode { // cached_probe_batches, because when probing everyone Status BuildSideMerge() { ThreadLocalState* state0 = &local_states_[0]; - for (int i = 1; i < local_states_.size(); ++i) { + for (size_t i = 1; i < local_states_.size(); ++i) { ThreadLocalState* state = &local_states_[i]; ARROW_DCHECK(state); ARROW_DCHECK(state->grouper); diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index f4d81ace040..6471dc292b8 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -730,5 +730,33 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { })))); } +TEST(ExecPlanExecution, SourceHashLeftSemiJoin) { + // TODO (Niranda) add this! + /* for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + auto input = MakeGroupableBatches(*/ + /*multiplicity=*/ /*parallel ? 
100 : 1); + +ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + +ASSERT_OK_AND_ASSIGN(auto source, + MakeTestSourceNode(plan.get(), "source", input, + */ + /*parallel=*//*parallel, */ /*slow=*//*false)); +ASSERT_OK_AND_ASSIGN( +auto gby, MakeGroupByNode(source, "gby", */ + /*keys=*//*{"str"}, */ /*targets=*/ /*{"i32"}, +{{"hash_sum", nullptr}})); +auto sink_gen = MakeSinkNode(gby, "sink"); + +ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), +Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( +{int64(), utf8()}, +parallel ? R"([[800, "alfa"], [1000, "beta"], [400, "gama"]])" +: R"([[8, "alfa"], [10, "beta"], [4, "gama"]])")})))); +}*/ +} + } // namespace compute } // namespace arrow From 16c62cf990f396349da5718887e3acc181deda82 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 3 Aug 2021 00:42:01 -0400 Subject: [PATCH 08/27] adding PR comments --- cpp/src/arrow/compute/api_aggregate.h | 4 +-- cpp/src/arrow/compute/exec/exec_plan.cc | 40 +++++++++++++++---------- cpp/src/arrow/compute/exec/exec_plan.h | 4 +-- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 2798cb7ed04..2fa36b32b21 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -384,8 +384,8 @@ class ARROW_EXPORT Grouper { virtual Result Consume(const ExecBatch& batch) = 0; /// Finds/ queries the group IDs for the given ExecBatch for every index. Returns the - /// group IDs as an integer array. If a group ID not found, a UINT32_MAX will be - /// added to that index. This is a thread-safe lookup. + /// group IDs as an integer array. If a group ID not found, a null will be added to that + /// index. This is a thread-safe lookup. virtual Result Find(const ExecBatch& batch) const = 0; /// Get current unique keys. May be called multiple times. diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 47727a1e111..25fe4272100 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -356,7 +356,7 @@ struct HashSemiJoinNode : ExecNode { ctx_(ctx), build_index_field_ids_(build_index_field_ids), probe_index_field_ids_(probe_index_field_ids), - build_side_complete_(false) {} + hash_table_built_(false) {} private: struct ThreadLocalState; @@ -384,7 +384,8 @@ struct HashSemiJoinNode : ExecNode { } // merge all other groupers to grouper[0]. nothing needs to be done on the - // cached_probe_batches, because when probing everyone + // cached_probe_batches, because when probing everyone. Note: Only one thread + // should execute this out of the pool! Status BuildSideMerge() { ThreadLocalState* state0 = &local_states_[0]; for (size_t i = 1; i < local_states_.size(); ++i) { @@ -416,11 +417,13 @@ struct HashSemiJoinNode : ExecNode { ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch)); if (build_counter_.Increment()) { + // only a single thread would come inside this if-block! 
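// [Editorial aside — not part of the patch] The single-winner guarantee this
// comment leans on comes from AtomicCounter (its implementation lands in
// exec_utils.cc later in this series): Increment(), SetTotal() and Cancel()
// may race freely, but at most one call across all threads ever returns true,
// because completion is claimed by a single compare-and-swap. A condensed,
// self-contained sketch of that contract (CounterSketch is a stand-in name):
//
#include <atomic>

class CounterSketch {
 public:
  // True only for the call that makes the count reach the announced total.
  bool Increment() {
    int count = count_.fetch_add(1) + 1;
    if (count != total_.load()) return false;
    return DoneOnce();
  }
  // True if everything had already been counted when the total was announced.
  bool SetTotal(int total) {
    total_.store(total);
    if (count_.load() != total) return false;
    return DoneOnce();
  }
  bool Cancel() { return DoneOnce(); }

 private:
  // Guarantees exactly one true return across Increment()/SetTotal()/Cancel():
  // only the thread that flips complete_ from false to true wins.
  bool DoneOnce() {
    bool expected = false;
    return complete_.compare_exchange_strong(expected, true);
  }
  std::atomic<int> count_{0};
  std::atomic<int> total_{-1};  // -1 means "total not yet known"
  std::atomic<bool> complete_{false};
};
// [End aside]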
+ // while incrementing, if the total is reached, merge all the groupers to 0'th one RETURN_NOT_OK(BuildSideMerge()); // enable flag that build side is completed - build_side_complete_.store(true); + hash_table_built_.store(true); // since the build side is completed, consume cached probe batches RETURN_NOT_OK(ConsumeCachedProbeBatches(thread_index)); @@ -433,10 +436,12 @@ struct HashSemiJoinNode : ExecNode { ThreadLocalState* state = &local_states_[thread_index]; // TODO (niranda) check if this is the best way to move batches - for (auto cached : state->cached_probe_batches) { - RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + if (!state->cached_probe_batches.empty()) { + for (auto cached : state->cached_probe_batches) { + RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + } + state->cached_probe_batches.clear(); } - state->cached_probe_batches.clear(); return Status::OK(); } @@ -479,10 +484,9 @@ struct HashSemiJoinNode : ExecNode { return Status::OK(); } - Status CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { + void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { ThreadLocalState* state = &local_states_[thread_index]; state->cached_probe_batches.emplace_back(seq_num, std::move(batch)); - return Status::OK(); } inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } @@ -501,14 +505,18 @@ struct HashSemiJoinNode : ExecNode { if (IsBuildInput(input)) { // build input batch is received // if a build input is received when build side is completed, something's wrong! - ARROW_DCHECK(!build_side_complete_.load()); + ARROW_DCHECK(!hash_table_built_.load()); ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch))); - } else { // probe input batch is received - if (build_side_complete_.load()) { // build side done, continue with probing + } else { // probe input batch is received + if (hash_table_built_.load()) { // build side done, continue with probing + // consume cachedProbeBatches if available (for this thread) + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + + // consume this probe batch ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); } else { // build side not completed. Cache this batch! - ErrorIfNotOk(CacheProbeBatch(thread_index, seq, std::move(batch))); + CacheProbeBatch(thread_index, seq, std::move(batch)); } } } @@ -558,7 +566,9 @@ struct HashSemiJoinNode : ExecNode { void StopProducing(ExecNode* output) override { DCHECK_EQ(output, outputs_[0]); - if (build_counter_.Cancel() || /*probe_counter_.Cancel() ||*/ out_counter_.Cancel()) { + if (build_counter_.Cancel()) { + finished_.MarkFinished(); + } else if (out_counter_.Cancel()) { finished_.MarkFinished(); } @@ -589,14 +599,14 @@ struct HashSemiJoinNode : ExecNode { ThreadIndexer get_thread_index_; const std::vector build_index_field_ids_, probe_index_field_ids_; - AtomicCounter build_counter_, /*probe_counter_,*/ out_counter_; + AtomicCounter build_counter_, out_counter_; std::vector local_states_; // need a separate atomic bool to track if the build side complete. Can't use the flag // inside the AtomicCounter, because we need to merge the build groupers once we receive // all the build batches. So, while merging, we need to prevent probe batches, being // consumed. 
- std::atomic build_side_complete_; + std::atomic hash_table_built_; }; Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 3250a75633f..2cc5d354fd3 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -398,8 +398,8 @@ Result GroupByUsingExecPlan(const std::vector& arguments, /// \brief Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label, - std::vector left_keys, - std::vector right_keys); + const std::vector& left_keys, + const std::vector& right_keys); ARROW_EXPORT Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label, From e391dc6eab9b5b7b6fc8134bda78f414631613e6 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 3 Aug 2021 00:42:11 -0400 Subject: [PATCH 09/27] adding serial test case --- cpp/src/arrow/compute/exec/plan_test.cc | 75 ++++++++++++++++--------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 6471dc292b8..2b13652727f 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -730,32 +730,57 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { })))); } +void GenerateBatchesFromString(const std::shared_ptr& schema, + const std::vector& json_strings, + BatchesWithSchema* out_batches) { + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches->batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + out_batches->schema = schema; +} + TEST(ExecPlanExecution, SourceHashLeftSemiJoin) { - // TODO (Niranda) add this! - /* for (bool parallel : {false, true}) { - SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); - - auto input = MakeGroupableBatches(*/ - /*multiplicity=*/ /*parallel ? 100 : 1); - -ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - -ASSERT_OK_AND_ASSIGN(auto source, - MakeTestSourceNode(plan.get(), "source", input, - */ - /*parallel=*//*parallel, */ /*slow=*//*false)); -ASSERT_OK_AND_ASSIGN( -auto gby, MakeGroupByNode(source, "gby", */ - /*keys=*//*{"str"}, */ /*targets=*/ /*{"i32"}, -{{"hash_sum", nullptr}})); -auto sink_gen = MakeSinkNode(gby, "sink"); - -ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), -Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( -{int64(), utf8()}, -parallel ? 
R"([[800, "alfa"], [1000, "beta"], [400, "gama"]])" -: R"([[8, "alfa"], [10, "beta"], [4, "gama"]])")})))); -}*/ + BatchesWithSchema l_batches, r_batches; + + GenerateBatchesFromString(schema({field("l_i32", int32()), field("l_str", utf8())}), + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &l_batches); + + GenerateBatchesFromString( + schema({field("r_str", utf8()), field("r_i32", int32())}), + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + &r_batches); + + SCOPED_TRACE("serial"); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + ASSERT_OK_AND_ASSIGN(auto l_source, + MakeTestSourceNode(plan.get(), "l_source", l_batches, + /*parallel=*/false, + /*slow=*/false)); + ASSERT_OK_AND_ASSIGN(auto r_source, + MakeTestSourceNode(plan.get(), "r_source", r_batches, + /*parallel=*/false, + /*slow=*/false)); + + ASSERT_OK_AND_ASSIGN( + auto semi_join, + MakeHashLeftSemiJoinNode(l_source, r_source, "l_semi_join", + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"})); + auto sink_gen = MakeSinkNode(semi_join, "sink"); + + ASSERT_THAT( + StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( + {int64(), utf8()}, R"([[1,"b"], [5,"b"], [6,"c"], [7,"e"], [8,"e"]])")})))); } } // namespace compute From 24990ac5485828f0f9bda4a823763b3ed91ab29c Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 3 Aug 2021 01:32:54 -0400 Subject: [PATCH 10/27] passing test --- cpp/src/arrow/compute/exec/exec_plan.cc | 49 +++++++++++++++++++++++-- cpp/src/arrow/compute/exec/plan_test.cc | 12 +++--- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 25fe4272100..f776f3cfd8a 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,6 +17,9 @@ #include "arrow/compute/exec/exec_plan.h" +#include +#include +#include #include #include @@ -365,6 +368,9 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { + std::cout << "init" + << "\n"; + // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -387,11 +393,16 @@ struct HashSemiJoinNode : ExecNode { // cached_probe_batches, because when probing everyone. Note: Only one thread // should execute this out of the pool! 
Status BuildSideMerge() { + std::cout << "build side merge" + << "\n"; + ThreadLocalState* state0 = &local_states_[0]; for (size_t i = 1; i < local_states_.size(); ++i) { ThreadLocalState* state = &local_states_[i]; ARROW_DCHECK(state); - ARROW_DCHECK(state->grouper); + if (!state->grouper) { + continue; + } ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys)); state->grouper.reset(); @@ -403,6 +414,8 @@ struct HashSemiJoinNode : ExecNode { // total reached at the end of consumption, all the local states will be merged, before // incrementing the total batches Status ConsumeBuildBatch(const size_t thread_index, ExecBatch batch) { + std::cout << "ConsumeBuildBatch " << thread_index << " " << batch.length << "\n"; + auto state = &local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(state)); @@ -433,6 +446,8 @@ struct HashSemiJoinNode : ExecNode { } Status ConsumeCachedProbeBatches(const size_t thread_index) { + std::cout << "ConsumeCachedProbeBatches " << thread_index << "\n"; + ThreadLocalState* state = &local_states_[thread_index]; // TODO (niranda) check if this is the best way to move batches @@ -449,6 +464,8 @@ struct HashSemiJoinNode : ExecNode { // consumes a probe batch and increment probe batches count. Probing would query the // grouper[0] which have been merged with all others. Status ConsumeProbeBatch(int seq, ExecBatch batch) { + std::cout << "ConsumeProbeBatch " << seq << "\n"; + auto* grouper = local_states_[0].grouper.get(); // Create a batch with key columns @@ -475,8 +492,10 @@ struct HashSemiJoinNode : ExecNode { Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); + std::cout << "output " << seq << " " << out_batch.ToString() << "\n"; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // all values are valid for output + std::cout << "output " << seq << " " << batch.ToString() << "\n"; outputs_[0]->InputReceived(this, seq, std::move(batch)); } @@ -494,6 +513,9 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + std::cout << "input received " << IsBuildInput(input) << " " << seq << " " + << batch.length << "\n"; + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); size_t thread_index = get_thread_index_(); @@ -522,6 +544,7 @@ struct HashSemiJoinNode : ExecNode { } void ErrorReceived(ExecNode* input, Status error) override { + std::cout << "error received " << error.ToString() << "\n"; DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -529,15 +552,25 @@ struct HashSemiJoinNode : ExecNode { } void InputFinished(ExecNode* input, int num_total) override { + std::cout << "input finished " << IsBuildInput(input) << " " << num_total << "\n"; + // bail if StopProducing was called if (finished_.is_finished()) return; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + size_t thread_index = get_thread_index_(); + // set total for build input if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { + // while incrementing, if the total is reached, merge all the groupers to 0'th one + ErrorIfNotOk(BuildSideMerge()); + + // enable flag that build side is completed + hash_table_built_.store(true); + // only build side has completed! 
so process cached probe batches (of this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches(get_thread_index_())); + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); return; } @@ -548,11 +581,14 @@ struct HashSemiJoinNode : ExecNode { // output will be streamed from the probe side. So, they will have the same total. if (out_counter_.SetTotal(num_total)) { // if out_counter has completed, the future is finished! + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); finished_.MarkFinished(); } + outputs_[0]->InputFinished(this, num_total); } Status StartProducing() override { + std::cout << "start prod \n"; finished_ = Future<>::Make(); local_states_.resize(ThreadIndexer::Capacity()); @@ -564,6 +600,8 @@ struct HashSemiJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { + std::cout << "stop prod from node\n"; + DCHECK_EQ(output, outputs_[0]); if (build_counter_.Cancel()) { @@ -579,15 +617,18 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node? void StopProducing() override { + std::cout << "stop prod \n"; for (auto&& output : outputs_) { StopProducing(output); } } - Future<> finished() override { return finished_; } + Future<> finished() override { + std::cout << "finished? " << finished_.is_finished() << "\n"; + return finished_; + } private: - struct ThreadLocalState { std::unique_ptr grouper; std::vector> cached_probe_batches{}; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 2b13652727f..a1a0be1f3c7 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -746,7 +746,7 @@ void GenerateBatchesFromString(const std::shared_ptr& schema, } TEST(ExecPlanExecution, SourceHashLeftSemiJoin) { - BatchesWithSchema l_batches, r_batches; + BatchesWithSchema l_batches, r_batches, exp_batches; GenerateBatchesFromString(schema({field("l_i32", int32()), field("l_str", utf8())}), {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", @@ -777,10 +777,12 @@ TEST(ExecPlanExecution, SourceHashLeftSemiJoin) { /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"})); auto sink_gen = MakeSinkNode(semi_join, "sink"); - ASSERT_THAT( - StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray({ExecBatchFromJSON( - {int64(), utf8()}, R"([[1,"b"], [5,"b"], [6,"c"], [7,"e"], [8,"e"]])")})))); + GenerateBatchesFromString( + schema({field("l_i32", int32()), field("l_str", utf8())}), + {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, &exp_batches); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } } // namespace compute From 50db0e2fee274b28cdbbd1298d3d423c391fc3d8 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 3 Aug 2021 16:52:46 -0400 Subject: [PATCH 11/27] refactoring files --- cpp/src/arrow/compute/exec/CMakeLists.txt | 1 + cpp/src/arrow/compute/exec/exec_plan.cc | 375 +--------------- cpp/src/arrow/compute/exec/exec_plan.h | 26 +- cpp/src/arrow/compute/exec/exec_utils.cc | 78 ++++ .../exec/{hash_join.h => exec_utils.h} | 47 +- cpp/src/arrow/compute/exec/hash_join.cc | 412 +++++++++++++++++- cpp/src/arrow/compute/exec/hash_join_test.cc | 88 ++++ cpp/src/arrow/compute/exec/plan_test.cc | 66 +-- cpp/src/arrow/compute/exec/test_util.cc | 75 ++++ cpp/src/arrow/compute/exec/test_util.h | 22 + 10 files changed, 740 insertions(+), 450 deletions(-) create mode 100644 
cpp/src/arrow/compute/exec/exec_utils.cc rename cpp/src/arrow/compute/exec/{hash_join.h => exec_utils.h} (50%) create mode 100644 cpp/src/arrow/compute/exec/hash_join_test.cc diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index 2ed8b1c9480..281154e3518 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -25,5 +25,6 @@ add_arrow_compute_test(expression_test subtree_test.cc) add_arrow_compute_test(plan_test PREFIX "arrow-compute") +add_arrow_compute_test(hash_join_test PREFIX "arrow-compute") add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index f776f3cfd8a..846961facf9 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -24,6 +24,7 @@ #include #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_utils.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" @@ -349,379 +350,5 @@ ExecFactoryRegistry* default_exec_factory_registry() { return &instance; } -struct HashSemiJoinNode : ExecNode { - HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, - ExecContext* ctx, const std::vector&& build_index_field_ids, - const std::vector&& probe_index_field_ids) - : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input}, - {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), - /*num_outputs=*/1), - ctx_(ctx), - build_index_field_ids_(build_index_field_ids), - probe_index_field_ids_(probe_index_field_ids), - hash_table_built_(false) {} - - private: - struct ThreadLocalState; - - public: - const char* kind_name() override { return "HashSemiJoinNode"; } - - Status InitLocalStateIfNeeded(ThreadLocalState* state) { - std::cout << "init" - << "\n"; - - // Get input schema - auto build_schema = inputs_[0]->output_schema(); - - if (state->grouper != nullptr) return Status::OK(); - - // Build vector of key field data types - std::vector key_descrs(build_index_field_ids_.size()); - for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { - auto build_type = build_schema->field(build_index_field_ids_[i])->type(); - key_descrs[i] = ValueDescr(build_type); - } - - // Construct grouper - ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_)); - - return Status::OK(); - } - - // merge all other groupers to grouper[0]. nothing needs to be done on the - // cached_probe_batches, because when probing everyone. Note: Only one thread - // should execute this out of the pool! - Status BuildSideMerge() { - std::cout << "build side merge" - << "\n"; - - ThreadLocalState* state0 = &local_states_[0]; - for (size_t i = 1; i < local_states_.size(); ++i) { - ThreadLocalState* state = &local_states_[i]; - ARROW_DCHECK(state); - if (!state->grouper) { - continue; - } - ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); - ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys)); - state->grouper.reset(); - } - return Status::OK(); - } - - // consumes a build batch and increments the build_batches count. 
if the build batches - // total reached at the end of consumption, all the local states will be merged, before - // incrementing the total batches - Status ConsumeBuildBatch(const size_t thread_index, ExecBatch batch) { - std::cout << "ConsumeBuildBatch " << thread_index << " " << batch.length << "\n"; - - auto state = &local_states_[thread_index]; - RETURN_NOT_OK(InitLocalStateIfNeeded(state)); - - // Create a batch with key columns - std::vector keys(build_index_field_ids_.size()); - for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { - keys[i] = batch.values[build_index_field_ids_[i]]; - } - ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); - - // Create a batch with group ids - ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch)); - - if (build_counter_.Increment()) { - // only a single thread would come inside this if-block! - - // while incrementing, if the total is reached, merge all the groupers to 0'th one - RETURN_NOT_OK(BuildSideMerge()); - - // enable flag that build side is completed - hash_table_built_.store(true); - - // since the build side is completed, consume cached probe batches - RETURN_NOT_OK(ConsumeCachedProbeBatches(thread_index)); - } - - return Status::OK(); - } - - Status ConsumeCachedProbeBatches(const size_t thread_index) { - std::cout << "ConsumeCachedProbeBatches " << thread_index << "\n"; - - ThreadLocalState* state = &local_states_[thread_index]; - - // TODO (niranda) check if this is the best way to move batches - if (!state->cached_probe_batches.empty()) { - for (auto cached : state->cached_probe_batches) { - RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); - } - state->cached_probe_batches.clear(); - } - - return Status::OK(); - } - - // consumes a probe batch and increment probe batches count. Probing would query the - // grouper[0] which have been merged with all others. - Status ConsumeProbeBatch(int seq, ExecBatch batch) { - std::cout << "ConsumeProbeBatch " << seq << "\n"; - - auto* grouper = local_states_[0].grouper.get(); - - // Create a batch with key columns - std::vector keys(probe_index_field_ids_.size()); - for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) { - keys[i] = batch.values[probe_index_field_ids_[i]]; - } - ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); - - // Query the grouper with key_batch. If no match was found, returning group_ids would - // have null. 
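// [Editorial aside — not part of the patch] The BooleanArray construction in
// the probe path (moved to hash_join.cc by this commit) is a zero-copy
// reinterpretation: Grouper::Find marks unmatched rows as null, so the
// group-id array's validity bitmap (buffers[0]) is already the semi-join mask.
// Passing it as the *values* buffer of a BooleanArray, with no validity buffer
// of its own, turns "key was found" into "filter is true" without allocating.
// A sketch, with MatchMask as a hypothetical helper name:
//
std::shared_ptr<BooleanArray> MatchMask(const ArrayData& group_ids) {
  // group_ids.buffers[0] holds one bit per probe row, set iff the key was
  // found in the build-side hash table; reuse it directly as boolean values.
  return std::make_shared<BooleanArray>(group_ids.length, group_ids.buffers[0],
                                        /*null_bitmap=*/nullptr,
                                        /*null_count=*/0,
                                        /*offset=*/group_ids.offset);
}
// Rows whose bit is unset read as false and are dropped by Filter(), which is
// exactly left-semi-join semantics.
// [End aside]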
- ARROW_ASSIGN_OR_RAISE(Datum group_ids, grouper->Find(key_batch)); - auto group_ids_data = *group_ids.array(); - - if (group_ids_data.MayHaveNulls()) { // values need to be filtered - auto filter_arr = - std::make_shared(group_ids_data.length, group_ids_data.buffers[0], - /*null_bitmap=*/nullptr, /*null_count=*/0, - /*offset=*/group_ids_data.offset); - ARROW_ASSIGN_OR_RAISE(auto rec_batch, - batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); - ARROW_ASSIGN_OR_RAISE( - auto filtered, - Filter(rec_batch, filter_arr, - /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); - auto out_batch = ExecBatch(*filtered.record_batch()); - std::cout << "output " << seq << " " << out_batch.ToString() << "\n"; - outputs_[0]->InputReceived(this, seq, std::move(out_batch)); - } else { // all values are valid for output - std::cout << "output " << seq << " " << batch.ToString() << "\n"; - outputs_[0]->InputReceived(this, seq, std::move(batch)); - } - - out_counter_.Increment(); - return Status::OK(); - } - - void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { - ThreadLocalState* state = &local_states_[thread_index]; - state->cached_probe_batches.emplace_back(seq_num, std::move(batch)); - } - - inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } - - // If all build side batches received? continue streaming using probing - // else cache the batches in thread-local state - void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - std::cout << "input received " << IsBuildInput(input) << " " << seq << " " - << batch.length << "\n"; - - ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); - - size_t thread_index = get_thread_index_(); - ARROW_DCHECK(thread_index < local_states_.size()); - - if (finished_.is_finished()) { - return; - } - - if (IsBuildInput(input)) { // build input batch is received - // if a build input is received when build side is completed, something's wrong! - ARROW_DCHECK(!hash_table_built_.load()); - - ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch))); - } else { // probe input batch is received - if (hash_table_built_.load()) { // build side done, continue with probing - // consume cachedProbeBatches if available (for this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); - - // consume this probe batch - ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); - } else { // build side not completed. Cache this batch! - CacheProbeBatch(thread_index, seq, std::move(batch)); - } - } - } - - void ErrorReceived(ExecNode* input, Status error) override { - std::cout << "error received " << error.ToString() << "\n"; - DCHECK_EQ(input, inputs_[0]); - - outputs_[0]->ErrorReceived(this, std::move(error)); - StopProducing(); - } - - void InputFinished(ExecNode* input, int num_total) override { - std::cout << "input finished " << IsBuildInput(input) << " " << num_total << "\n"; - - // bail if StopProducing was called - if (finished_.is_finished()) return; - - ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); - - size_t thread_index = get_thread_index_(); - - // set total for build input - if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { - // while incrementing, if the total is reached, merge all the groupers to 0'th one - ErrorIfNotOk(BuildSideMerge()); - - // enable flag that build side is completed - hash_table_built_.store(true); - - // only build side has completed! 
so process cached probe batches (of this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); - return; - } - - // set total for probe input. If it returns that probe side has completed, nothing to - // do, because probing inputs will be streamed to the output - // probe_counter_.SetTotal(num_total); - - // output will be streamed from the probe side. So, they will have the same total. - if (out_counter_.SetTotal(num_total)) { - // if out_counter has completed, the future is finished! - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); - finished_.MarkFinished(); - } - outputs_[0]->InputFinished(this, num_total); - } - - Status StartProducing() override { - std::cout << "start prod \n"; - finished_ = Future<>::Make(); - - local_states_.resize(ThreadIndexer::Capacity()); - return Status::OK(); - } - - void PauseProducing(ExecNode* output) override {} - - void ResumeProducing(ExecNode* output) override {} - - void StopProducing(ExecNode* output) override { - std::cout << "stop prod from node\n"; - - DCHECK_EQ(output, outputs_[0]); - - if (build_counter_.Cancel()) { - finished_.MarkFinished(); - } else if (out_counter_.Cancel()) { - finished_.MarkFinished(); - } - - for (auto&& input : inputs_) { - input->StopProducing(this); - } - } - - // TODO(niranda) couldn't there be multiple outputs for a Node? - void StopProducing() override { - std::cout << "stop prod \n"; - for (auto&& output : outputs_) { - StopProducing(output); - } - } - - Future<> finished() override { - std::cout << "finished? " << finished_.is_finished() << "\n"; - return finished_; - } - - private: - struct ThreadLocalState { - std::unique_ptr grouper; - std::vector> cached_probe_batches{}; - }; - - ExecContext* ctx_; - Future<> finished_ = Future<>::MakeFinished(); - - ThreadIndexer get_thread_index_; - const std::vector build_index_field_ids_, probe_index_field_ids_; - - AtomicCounter build_counter_, out_counter_; - std::vector local_states_; - - // need a separate atomic bool to track if the build side complete. Can't use the flag - // inside the AtomicCounter, because we need to merge the build groupers once we receive - // all the build batches. So, while merging, we need to prevent probe batches, being - // consumed. 
- std::atomic hash_table_built_; -}; - -Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, - const std::vector& left_keys, - const std::vector& right_keys) { - if (left_keys.size() != right_keys.size()) { - return Status::Invalid("left and right key sizes do not match"); - } - - const auto& l_schema = left_input->output_schema(); - const auto& r_schema = right_input->output_schema(); - for (size_t i = 0; i < left_keys.size(); i++) { - auto l_type = l_schema->GetFieldByName(left_keys[i])->type(); - auto r_type = r_schema->GetFieldByName(right_keys[i])->type(); - - if (!l_type->Equals(r_type)) { - return Status::Invalid("build and probe types do not match: " + l_type->ToString() + - "!=" + r_type->ToString()); - } - } - - return Status::OK(); -} - -Result> PopulateKeys(const Schema& schema, - const std::vector& keys) { - std::vector key_field_ids(keys.size()); - // Find input field indices for left key fields - for (size_t i = 0; i < keys.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(keys[i]).FindOne(schema)); - key_field_ids[i] = match[0]; - } - - return key_field_ids; -} - -Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, - std::string label, - const std::vector& build_keys, - const std::vector& probe_keys) { - RETURN_NOT_OK(ValidateJoinInputs(build_input, probe_input, build_keys, probe_keys)); - - auto build_schema = build_input->output_schema(); - auto probe_schema = probe_input->output_schema(); - - ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); - ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); - - // output schema will be probe schema - auto ctx = build_input->plan()->exec_context(); - ExecPlan* plan = build_input->plan(); - - return plan->EmplaceNode(build_input, probe_input, std::move(label), - ctx, std::move(build_key_ids), - std::move(probe_key_ids)); -} - -Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - std::string label, - const std::vector& left_keys, - const std::vector& right_keys) { - // left join--> build from right and probe from left - return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, - left_keys); -} - -Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - std::string label, - const std::vector& left_keys, - const std::vector& right_keys) { - // right join--> build from left and probe from right - return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, - right_keys); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 2cc5d354fd3..540c0cf3482 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -395,15 +395,23 @@ Result GroupByUsingExecPlan(const std::vector& arguments, const std::vector& aggregates, bool use_threads, ExecContext* ctx); -/// \brief -Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - std::string label, - const std::vector& left_keys, - const std::vector& right_keys); +/// \brief Make a node which joins batches from two other nodes based on key fields +enum JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, // Not Implemented + LEFT_OUTER, // Not Implemented + RIGHT_OUTER, // Not Implemented + FULL_OUTER // Not Implemented +}; + ARROW_EXPORT -Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - 
std::string label, - const std::vector& left_keys, - const std::vector& right_keys); +Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, + ExecNode* right_input, std::string label, + const std::vector& left_keys, + const std::vector& right_keys); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_utils.cc b/cpp/src/arrow/compute/exec/exec_utils.cc new file mode 100644 index 00000000000..f1a96ac0812 --- /dev/null +++ b/cpp/src/arrow/compute/exec/exec_utils.cc @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/exec_utils.h" + +#include + +namespace arrow { +namespace compute { + +size_t ThreadIndexer::operator()() { + auto id = std::this_thread::get_id(); + + std::unique_lock lock(mutex_); + const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; + + return Check(id_index.second); +} + +size_t ThreadIndexer::Capacity() { + static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); + return max_size; +} + +size_t ThreadIndexer::Check(size_t thread_index) { + DCHECK_LT(thread_index, Capacity()) + << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")"; + + return thread_index; +} + +int AtomicCounter::count() const { return count_.load(); } + +util::optional AtomicCounter::total() const { + int total = total_.load(); + if (total == -1) return {}; + return total; +} + +bool AtomicCounter::Increment() { + DCHECK_NE(count_.load(), total_.load()); + int count = count_.fetch_add(1) + 1; + if (count != total_.load()) return false; + return DoneOnce(); +} + +// return true if the counter is complete +bool AtomicCounter::SetTotal(int total) { + total_.store(total); + if (count_.load() != total) return false; + return DoneOnce(); +} + +// return true if the counter has not already been completed +bool AtomicCounter::Cancel() { return DoneOnce(); } + +// ensure there is only one true return from Increment(), SetTotal(), or Cancel() +bool AtomicCounter::DoneOnce() { + bool expected = false; + return complete_.compare_exchange_strong(expected, true); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/exec_utils.h similarity index 50% rename from cpp/src/arrow/compute/exec/hash_join.h rename to cpp/src/arrow/compute/exec/exec_utils.h index 492cf0a0a49..c899e3b3058 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/exec_utils.h @@ -15,14 +15,51 @@ // specific language governing permissions and limitations // under the License. 
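+
+// Concurrency helpers shared by exec nodes: ThreadIndexer maps each calling
+// thread's id to a small dense index (bounded by the CPU thread pool capacity),
+// and AtomicCounter tracks a processed-batch count against an expected total,
+// guaranteeing a single completion signal across Increment/SetTotal/Cancel.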
+#include +#include +#include + +#include "arrow/util/thread_pool.h" + namespace arrow { namespace compute { -enum JoinType { - LEFT_SEMI_JOIN, - RIGHT_SEMI_JOIN, - LEFT_ANTI_SEMI_JOIN, - RIGHT_ANTI_SEMI_JOIN +class ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + std::mutex mutex_; + std::unordered_map id_to_index_; +}; + +class AtomicCounter { + public: + AtomicCounter() = default; + + int count() const; + + util::optional total() const; + + // return true if the counter is complete + bool Increment(); + + // return true if the counter is complete + bool SetTotal(int total); + + // return true if the counter has not already been completed + bool Cancel(); + + private: + // ensure there is only one true return from Increment(), SetTotal(), or Cancel() + bool DoneOnce(); + + std::atomic count_{0}, total_{-1}; + std::atomic complete_{false}; }; } // namespace compute diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 10b676be892..ba7a7646b6f 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -15,11 +15,421 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/exec/hash_join.h" +#include +#include +#include + +#include + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/exec_utils.h" namespace arrow { namespace compute { +struct HashSemiJoinNode : ExecNode { + HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, + ExecContext* ctx, const std::vector&& build_index_field_ids, + const std::vector&& probe_index_field_ids) + : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input}, + {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), + /*num_outputs=*/1), + ctx_(ctx), + build_index_field_ids_(build_index_field_ids), + probe_index_field_ids_(probe_index_field_ids), + hash_table_built_(false) {} + + private: + struct ThreadLocalState; + + public: + const char* kind_name() override { return "HashSemiJoinNode"; } + + Status InitLocalStateIfNeeded(ThreadLocalState* state) { + std::cout << "init" + << "\n"; + + // Get input schema + auto build_schema = inputs_[0]->output_schema(); + + if (state->grouper != nullptr) return Status::OK(); + + // Build vector of key field data types + std::vector key_descrs(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + auto build_type = build_schema->field(build_index_field_ids_[i])->type(); + key_descrs[i] = ValueDescr(build_type); + } + + // Construct grouper + ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_)); + + return Status::OK(); + } + + // merge all other groupers to grouper[0]. nothing needs to be done on the + // cached_probe_batches, because when probing everyone. Note: Only one thread + // should execute this out of the pool! 
+ Status BuildSideMerge() { + std::cout << "build side merge" + << "\n"; + + ThreadLocalState* state0 = &local_states_[0]; + for (size_t i = 1; i < local_states_.size(); ++i) { + ThreadLocalState* state = &local_states_[i]; + ARROW_DCHECK(state); + if (!state->grouper) { + continue; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); + ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys)); + state->grouper.reset(); + } + return Status::OK(); + } + + // consumes a build batch and increments the build_batches count. if the build batches + // total reached at the end of consumption, all the local states will be merged, before + // incrementing the total batches + Status ConsumeBuildBatch(const size_t thread_index, ExecBatch batch) { + std::cout << "ConsumeBuildBatch " << thread_index << " " << batch.length << "\n"; + + auto state = &local_states_[thread_index]; + RETURN_NOT_OK(InitLocalStateIfNeeded(state)); + + // Create a batch with key columns + std::vector keys(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + keys[i] = batch.values[build_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Create a batch with group ids + ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch)); + + if (build_counter_.Increment()) { + // only a single thread would come inside this if-block! + + // while incrementing, if the total is reached, merge all the groupers to 0'th one + RETURN_NOT_OK(BuildSideMerge()); + + // enable flag that build side is completed + hash_table_built_.store(true); + + // since the build side is completed, consume cached probe batches + RETURN_NOT_OK(ConsumeCachedProbeBatches(thread_index)); + } + + return Status::OK(); + } + + Status ConsumeCachedProbeBatches(const size_t thread_index) { + std::cout << "ConsumeCachedProbeBatches " << thread_index << "\n"; + + ThreadLocalState* state = &local_states_[thread_index]; + + // TODO (niranda) check if this is the best way to move batches + if (!state->cached_probe_batches.empty()) { + for (auto cached : state->cached_probe_batches) { + RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + } + state->cached_probe_batches.clear(); + } + + return Status::OK(); + } + + // consumes a probe batch and increment probe batches count. Probing would query the + // grouper[0] which have been merged with all others. + Status ConsumeProbeBatch(int seq, ExecBatch batch) { + std::cout << "ConsumeProbeBatch " << seq << "\n"; + + auto* grouper = local_states_[0].grouper.get(); + + // Create a batch with key columns + std::vector keys(probe_index_field_ids_.size()); + for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) { + keys[i] = batch.values[probe_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Query the grouper with key_batch. If no match was found, returning group_ids would + // have null. 
+ ARROW_ASSIGN_OR_RAISE(Datum group_ids, grouper->Find(key_batch)); + auto group_ids_data = *group_ids.array(); + + if (group_ids_data.MayHaveNulls()) { // values need to be filtered + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + std::cout << "output " << seq << " " << out_batch.ToString() << "\n"; + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { // all values are valid for output + std::cout << "output " << seq << " " << batch.ToString() << "\n"; + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } + + out_counter_.Increment(); + return Status::OK(); + } + + void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { + ThreadLocalState* state = &local_states_[thread_index]; + state->cached_probe_batches.emplace_back(seq_num, std::move(batch)); + } + + inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } + + // If all build side batches received? continue streaming using probing + // else cache the batches in thread-local state + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + std::cout << "input received " << IsBuildInput(input) << " " << seq << " " + << batch.length << "\n"; + + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + size_t thread_index = get_thread_index_(); + ARROW_DCHECK(thread_index < local_states_.size()); + + if (finished_.is_finished()) { + return; + } + + if (IsBuildInput(input)) { // build input batch is received + // if a build input is received when build side is completed, something's wrong! + ARROW_DCHECK(!hash_table_built_.load()); + + ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch))); + } else { // probe input batch is received + if (hash_table_built_.load()) { // build side done, continue with probing + // consume cachedProbeBatches if available (for this thread) + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + + // consume this probe batch + ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); + } else { // build side not completed. Cache this batch! + CacheProbeBatch(thread_index, seq, std::move(batch)); + } + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + std::cout << "error received " << error.ToString() << "\n"; + DCHECK_EQ(input, inputs_[0]); + + outputs_[0]->ErrorReceived(this, std::move(error)); + StopProducing(); + } + + void InputFinished(ExecNode* input, int num_total) override { + std::cout << "input finished " << IsBuildInput(input) << " " << num_total << "\n"; + + // bail if StopProducing was called + if (finished_.is_finished()) return; + + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + size_t thread_index = get_thread_index_(); + + // set total for build input + if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { + // while incrementing, if the total is reached, merge all the groupers to 0'th one + ErrorIfNotOk(BuildSideMerge()); + + // enable flag that build side is completed + hash_table_built_.store(true); + + // only build side has completed! 
so process cached probe batches (of this thread) + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + return; + } + + // set total for probe input. If it returns that probe side has completed, nothing to + // do, because probing inputs will be streamed to the output + // probe_counter_.SetTotal(num_total); + + // output will be streamed from the probe side. So, they will have the same total. + if (out_counter_.SetTotal(num_total)) { + // if out_counter has completed, the future is finished! + ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + finished_.MarkFinished(); + } + outputs_[0]->InputFinished(this, num_total); + } + + Status StartProducing() override { + std::cout << "start prod \n"; + finished_ = Future<>::Make(); + + local_states_.resize(ThreadIndexer::Capacity()); + return Status::OK(); + } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + std::cout << "stop prod from node\n"; + + DCHECK_EQ(output, outputs_[0]); + + if (build_counter_.Cancel()) { + finished_.MarkFinished(); + } else if (out_counter_.Cancel()) { + finished_.MarkFinished(); + } + + for (auto&& input : inputs_) { + input->StopProducing(this); + } + } + + // TODO(niranda) couldn't there be multiple outputs for a Node? + void StopProducing() override { + std::cout << "stop prod \n"; + for (auto&& output : outputs_) { + StopProducing(output); + } + } + + Future<> finished() override { + std::cout << "finished? " << finished_.is_finished() << "\n"; + return finished_; + } + + private: + struct ThreadLocalState { + std::unique_ptr grouper; + std::vector> cached_probe_batches{}; + }; + + ExecContext* ctx_; + Future<> finished_ = Future<>::MakeFinished(); + + ThreadIndexer get_thread_index_; + const std::vector build_index_field_ids_, probe_index_field_ids_; + + AtomicCounter build_counter_, out_counter_; + std::vector local_states_; + + // need a separate atomic bool to track if the build side complete. Can't use the flag + // inside the AtomicCounter, because we need to merge the build groupers once we receive + // all the build batches. So, while merging, we need to prevent probe batches, being + // consumed. 
+ std::atomic hash_table_built_; +}; + +Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, + const std::vector& left_keys, + const std::vector& right_keys) { + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("left and right key sizes do not match"); + } + + const auto& l_schema = left_input->output_schema(); + const auto& r_schema = right_input->output_schema(); + for (size_t i = 0; i < left_keys.size(); i++) { + auto l_type = l_schema->GetFieldByName(left_keys[i])->type(); + auto r_type = r_schema->GetFieldByName(right_keys[i])->type(); + + if (!l_type->Equals(r_type)) { + return Status::Invalid("build and probe types do not match: " + l_type->ToString() + + "!=" + r_type->ToString()); + } + } + + return Status::OK(); +} + +Result> PopulateKeys(const Schema& schema, + const std::vector& keys) { + std::vector key_field_ids(keys.size()); + // Find input field indices for left key fields + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(keys[i]).FindOne(schema)); + key_field_ids[i] = match[0]; + } + + return key_field_ids; +} + +Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, + std::string label, + const std::vector& build_keys, + const std::vector& probe_keys) { + RETURN_NOT_OK(ValidateJoinInputs(build_input, probe_input, build_keys, probe_keys)); + + auto build_schema = build_input->output_schema(); + auto probe_schema = probe_input->output_schema(); + + ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); + ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); + + // output schema will be probe schema + auto ctx = build_input->plan()->exec_context(); + ExecPlan* plan = build_input->plan(); + + return plan->EmplaceNode(build_input, probe_input, std::move(label), + ctx, std::move(build_key_ids), + std::move(probe_key_ids)); +} + +Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + const std::vector& left_keys, + const std::vector& right_keys) { + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, + left_keys); +} + +Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, + std::string label, + const std::vector& left_keys, + const std::vector& right_keys) { + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, + right_keys); +} + +static std::string JoinTypeToString[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + +Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, + ExecNode* right_input, std::string label, + const std::vector& left_keys, + const std::vector& right_keys) { + switch (join_type) { + case LEFT_SEMI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, + left_keys); + case RIGHT_SEMI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, + right_keys); + case LEFT_ANTI: + case RIGHT_ANTI: + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + return Status::NotImplemented(JoinTypeToString[join_type] + + " joins not implemented!"); + default: + return Status::Invalid("invalid join type"); + } +} } // 
namespace compute } // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_test.cc new file mode 100644 index 00000000000..601e7fee8fa --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_test.cc @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +void GenerateBatchesFromString(const std::shared_ptr& schema, + const std::vector& json_strings, + BatchesWithSchema* out_batches) { + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches->batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + out_batches->schema = schema; +} + +TEST(HashJoin, LeftSemi) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + GenerateBatchesFromString(l_schema, + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &l_batches); + + GenerateBatchesFromString( + r_schema, + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + &r_batches); + + SCOPED_TRACE("serial"); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + ASSERT_OK_AND_ASSIGN(auto l_source, + MakeTestSourceNode(plan.get(), "l_source", l_batches, + /*parallel=*/false, + /*slow=*/false)); + ASSERT_OK_AND_ASSIGN(auto r_source, + MakeTestSourceNode(plan.get(), "r_source", r_batches, + /*parallel=*/false, + /*slow=*/false)); + + ASSERT_OK_AND_ASSIGN( + auto semi_join, + MakeHashJoinNode(JoinType::LEFT_SEMI, l_source, r_source, "l_semi_join", + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"})); + auto sink_gen = MakeSinkNode(semi_join, "sink"); + + GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &exp_batches); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); +} + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index a1a0be1f3c7..d37fd2fca34 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -18,7 +18,6 @@ #include #include -#include #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" 
@@ -27,17 +26,14 @@
 #include "arrow/compute/exec/test_util.h"
 #include "arrow/compute/exec/util.h"
 #include "arrow/record_batch.h"
 #include "arrow/table.h"
 #include "arrow/testing/future_util.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/matchers.h"
-#include "arrow/testing/random.h"
 #include "arrow/util/async_generator.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/thread_pool.h"
-#include "arrow/util/vector.h"
 
 using testing::ElementsAre;
 using testing::ElementsAreArray;
 using testing::HasSubstr;
 using testing::Optional;
@@ -730,60 +729,5 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
 }))));
 }
 
-void GenerateBatchesFromString(const std::shared_ptr<Schema>& schema,
-                               const std::vector<std::string>& json_strings,
-                               BatchesWithSchema* out_batches) {
-  std::vector<ValueDescr> descrs;
-  for (auto&& field : schema->fields()) {
-    descrs.emplace_back(field->type());
-  }
-
-  for (auto&& s : json_strings) {
-    out_batches->batches.push_back(ExecBatchFromJSON(descrs, s));
-  }
-
-  out_batches->schema = schema;
-}
-
-TEST(ExecPlanExecution, SourceHashLeftSemiJoin) {
-  BatchesWithSchema l_batches, r_batches, exp_batches;
-
-  GenerateBatchesFromString(schema({field("l_i32", int32()), field("l_str", utf8())}),
-                            {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])",
-                             R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"},
-                            &l_batches);
-
-  GenerateBatchesFromString(
-      schema({field("r_str", utf8()), field("r_i32", int32())}),
-      {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"},
-      &r_batches);
-
-  SCOPED_TRACE("serial");
-
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-
-  ASSERT_OK_AND_ASSIGN(auto l_source,
-                       MakeTestSourceNode(plan.get(), "l_source", l_batches,
-                                          /*parallel=*/false,
-                                          /*slow=*/false));
-  ASSERT_OK_AND_ASSIGN(auto r_source,
-                       MakeTestSourceNode(plan.get(), "r_source", r_batches,
-                                          /*parallel=*/false,
-                                          /*slow=*/false));
-
-  ASSERT_OK_AND_ASSIGN(
-      auto semi_join,
-      MakeHashLeftSemiJoinNode(l_source, r_source, "l_semi_join",
-                               /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}));
-  auto sink_gen = MakeSinkNode(semi_join, "sink");
-
-  GenerateBatchesFromString(
-      schema({field("l_i32", int32()), field("l_str", utf8())}),
-      {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, &exp_batches);
-
-  ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-              Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
-}
-
 } // namespace compute
 } // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 49f089f3459..d1d06e773f0 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -35,6 +35,7 @@
 #include "arrow/datum.h"
 #include "arrow/record_batch.h"
 #include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
 #include "arrow/type.h"
 #include "arrow/util/async_generator.h"
 #include "arrow/util/iterator.h"
@@ -155,5 +156,79 @@ ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
   return batch;
 }
 
+Result<ExecNode*> MakeTestSourceNode(ExecPlan* plan, std::string label,
+                                     BatchesWithSchema batches_with_schema,
+                                     bool parallel, bool slow) {
+  DCHECK_GT(batches_with_schema.batches.size(), 0);
+
+  auto opt_batches = ::arrow::internal::MapVector(
+      [](ExecBatch batch) { return util::make_optional(std::move(batch)); },
+      std::move(batches_with_schema.batches));
+
+  AsyncGenerator<util::optional<ExecBatch>> gen;
+
+  if (parallel) {
+    // emulate batches completing initial decode-after-scan on a cpu thread
+    ARROW_ASSIGN_OR_RAISE(
+        gen, MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)),
+                                     ::arrow::internal::GetCpuThreadPool()));
+
+    // ensure that callbacks are not executed immediately on a background thread
+    gen = MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool());
+  } else {
+    gen = MakeVectorGenerator(std::move(opt_batches));
+  }
+
+  if (slow) {
+    gen = MakeMappedGenerator(std::move(gen), [](const util::optional<ExecBatch>& batch) {
+      SleepABit();
+      return batch;
+    });
+  }
+
+  return MakeSourceNode(plan, std::move(label), std::move(batches_with_schema.schema),
+                        std::move(gen));
+}
+
+Future<std::vector<ExecBatch>> StartAndCollect(
+    ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
+  RETURN_NOT_OK(plan->Validate());
+  RETURN_NOT_OK(plan->StartProducing());
+
+  auto collected_fut = CollectAsyncGenerator(gen);
+
+  return AllComplete({plan->finished(), Future<>(collected_fut)})
+      .Then([collected_fut]() -> Result<std::vector<ExecBatch>> {
+        ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
+        return ::arrow::internal::MapVector(
+            [](util::optional<ExecBatch> batch) { return std::move(*batch); },
+            std::move(collected));
+      });
+}
+
+BatchesWithSchema MakeBasicBatches() {
+  BatchesWithSchema out;
+  out.batches = {
+      ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+      ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")};
+  out.schema = schema({field("i32", int32()), field("bool", boolean())});
+  return out;
+}
+
+BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
+                                    int num_batches, int batch_size) {
+  BatchesWithSchema out;
+
+  random::RandomArrayGenerator rng(42);
+  out.batches.resize(num_batches);
+
+  for (int i = 0; i < num_batches; ++i) {
+    out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size));
+    // add a tag scalar to ensure the batches are unique
+    out.batches[i].values.emplace_back(i);
+  }
+  return out;
+}
+
 } // namespace compute
 } // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index faa395bab78..3ef1333ea42 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -24,6 +24,7 @@
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/exec_plan.h"
 #include "arrow/testing/visibility.h"
+#include "arrow/util/async_generator.h"
 #include "arrow/util/string_view.h"
 
 namespace arrow {
@@ -41,5 +42,26 @@ ARROW_TESTING_EXPORT
 ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
                             util::string_view json);
 
+struct BatchesWithSchema {
+  std::vector<ExecBatch> batches;
+  std::shared_ptr<Schema> schema;
+};
+
+ARROW_TESTING_EXPORT
+Result<ExecNode*> MakeTestSourceNode(ExecPlan* plan, std::string label,
+                                     BatchesWithSchema batches_with_schema,
+                                     bool parallel, bool slow);
+
+ARROW_TESTING_EXPORT
+Future<std::vector<ExecBatch>> StartAndCollect(
+    ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
+
+ARROW_TESTING_EXPORT
+BatchesWithSchema MakeBasicBatches();
+
+ARROW_TESTING_EXPORT
+BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
+                                    int num_batches = 10, int batch_size = 4);
+
 } // namespace compute
 } // namespace arrow

From c4e379e1193274122495037c02b42e047589e9a7 Mon Sep 17 00:00:00 2001
From: niranda perera
Date: Wed, 4 Aug 2021 16:23:19 -0400
Subject: [PATCH 12/27] adding right semi join test

---
 cpp/src/arrow/compute/exec/exec_utils.h      |   2 +-
 cpp/src/arrow/compute/exec/hash_join.cc      | 196 +++++++++++--------
 cpp/src/arrow/compute/exec/hash_join_test.cc |  47 +++--
 3 files changed, 153 insertions(+), 92 deletions(-)
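Note on the approach: the semi-join semantics this series implements (hash the
build-side keys, then keep only the probe rows whose key finds a match; in the
node itself, unmatched keys surface as null group ids and are dropped by a
filter) can be summarized with a minimal, single-threaded sketch. The sketch is
illustrative only: the file name, the LeftSemiJoin helper, and the
std::unordered_set hash table are assumptions standing in for the node's
internal::Grouper and its null-bitmap filtering, not code from this patch.

// left_semi_sketch.cc - single-threaded left-semi-join over string keys.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

std::vector<std::string> LeftSemiJoin(const std::vector<std::string>& build_keys,
                                      const std::vector<std::string>& probe_keys) {
  // "build" phase: one hash-table insertion per build-side row
  std::unordered_set<std::string> hash_table(build_keys.begin(), build_keys.end());
  // "probe" phase: emit a probe row iff its key was seen on the build side
  std::vector<std::string> out;
  for (const auto& key : probe_keys) {
    if (hash_table.count(key) > 0) out.push_back(key);
  }
  return out;
}

int main() {
  // key columns mirror the r_str (build) and l_str (probe) values in the tests
  std::vector<std::string> build = {"f", "b", "b", "c", "g", "e"};
  std::vector<std::string> probe = {"d", "b", "d", "a", "a", "b", "c", "e", "e"};
  for (const auto& key : LeftSemiJoin(build, probe)) {
    std::cout << key << " ";  // prints: b b c e e
  }
  std::cout << std::endl;
  return 0;
}
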
diff --git a/cpp/src/arrow/compute/exec/exec_utils.h b/cpp/src/arrow/compute/exec/exec_utils.h index c899e3b3058..65dd93150e1 100644 --- a/cpp/src/arrow/compute/exec/exec_utils.h +++ b/cpp/src/arrow/compute/exec/exec_utils.h @@ -63,4 +63,4 @@ class AtomicCounter { }; } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index ba7a7646b6f..789a71973ee 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -37,6 +37,7 @@ struct HashSemiJoinNode : ExecNode { ctx_(ctx), build_index_field_ids_(build_index_field_ids), probe_index_field_ids_(probe_index_field_ids), + build_result_index(-1), hash_table_built_(false) {} private: @@ -46,8 +47,7 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { - std::cout << "init" - << "\n"; + // std::cout << "init \n"; // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -67,32 +67,62 @@ struct HashSemiJoinNode : ExecNode { return Status::OK(); } - // merge all other groupers to grouper[0]. nothing needs to be done on the - // cached_probe_batches, because when probing everyone. Note: Only one thread - // should execute this out of the pool! - Status BuildSideMerge() { - std::cout << "build side merge" - << "\n"; + // Finds an appropriate index which could accumulate all build indices (i.e. the grouper + // which has the highest # of groups) + void CalculateBuildResultIndex() { + uint32_t curr_max = 0; + for (int i = 0; i < static_cast(local_states_.size()); i++) { + auto* state = &local_states_[i]; + ARROW_DCHECK(state); + if (state->grouper && curr_max < state->grouper->num_groups()) { + curr_max = state->grouper->num_groups(); + build_result_index = i; + } + } + ARROW_DCHECK(build_result_index > -1); + // std::cout << "build_result_index " << build_result_index << "\n"; + } - ThreadLocalState* state0 = &local_states_[0]; - for (size_t i = 1; i < local_states_.size(); ++i) { + // Performs the housekeeping work after the build-side is completed. Note: this method + // should be called ONLY ONCE! + Status BuildSideCompleted() { + // std::cout << "build side merge \n"; + + CalculateBuildResultIndex(); + + // merge every group into the build_result_index grouper + ThreadLocalState* result_state = &local_states_[build_result_index]; + for (int i = 0; i < static_cast(local_states_.size()); ++i) { ThreadLocalState* state = &local_states_[i]; ARROW_DCHECK(state); - if (!state->grouper) { + if (i == build_result_index || !state->grouper) { continue; } ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); - ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys)); + + // TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, result_state->grouper->Consume(other_keys)); state->grouper.reset(); } + + // enable flag that build side is completed + hash_table_built_ = true; + + // since the build side is completed, consume cached probe batches + RETURN_NOT_OK(ConsumeCachedProbeBatches()); + return Status::OK(); } // consumes a build batch and increments the build_batches count. 
if the build batches // total reached at the end of consumption, all the local states will be merged, before // incrementing the total batches - Status ConsumeBuildBatch(const size_t thread_index, ExecBatch batch) { - std::cout << "ConsumeBuildBatch " << thread_index << " " << batch.length << "\n"; + Status ConsumeBuildBatch(ExecBatch batch) { + // std::cout << "ConsumeBuildBatch tid:" << thread_index << " len:" << batch.length + // << "\n"; + + size_t thread_index = get_thread_index_(); + ARROW_DCHECK(thread_index < local_states_.size()); auto state = &local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(state)); @@ -104,47 +134,51 @@ struct HashSemiJoinNode : ExecNode { } ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); - // Create a batch with group ids - ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch)); + // Create a batch with group ids TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); if (build_counter_.Increment()) { - // only a single thread would come inside this if-block! - - // while incrementing, if the total is reached, merge all the groupers to 0'th one - RETURN_NOT_OK(BuildSideMerge()); - - // enable flag that build side is completed - hash_table_built_.store(true); - - // since the build side is completed, consume cached probe batches - RETURN_NOT_OK(ConsumeCachedProbeBatches(thread_index)); + // while incrementing, if the total is reached, call BuildSideCompleted + RETURN_NOT_OK(BuildSideCompleted()); } return Status::OK(); } - Status ConsumeCachedProbeBatches(const size_t thread_index) { - std::cout << "ConsumeCachedProbeBatches " << thread_index << "\n"; - - ThreadLocalState* state = &local_states_[thread_index]; - - // TODO (niranda) check if this is the best way to move batches - if (!state->cached_probe_batches.empty()) { - for (auto cached : state->cached_probe_batches) { - RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + // consumes cached probe batches by invoking executor::Spawn. This should be called by a + // single thread. Note: this method should be called ONLY ONCE! + Status ConsumeCachedProbeBatches() { + // std::cout << "ConsumeCachedProbeBatches tid:" << thread_index + // << " len:" << cached_probe_batches.size() << "\n"; + + if (!cached_probe_batches.empty()) { + auto executor = ctx_->executor(); + for (auto&& cached : cached_probe_batches) { + if (executor) { + Status lambda_status; + RETURN_NOT_OK(executor->Spawn([&] { + lambda_status = ConsumeProbeBatch(cached.first, std::move(cached.second)); + })); + + // if the lambda execution failed internally, return status + RETURN_NOT_OK(lambda_status); + } else { + RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + } } - state->cached_probe_batches.clear(); + // cached vector will be cleared. exec batches are expected to be moved to the + // lambdas + cached_probe_batches.clear(); } - return Status::OK(); } // consumes a probe batch and increment probe batches count. Probing would query the - // grouper[0] which have been merged with all others. + // grouper[build_result_index] which have been merged with all others. 
Status ConsumeProbeBatch(int seq, ExecBatch batch) { - std::cout << "ConsumeProbeBatch " << seq << "\n"; + // std::cout << "ConsumeProbeBatch seq:" << seq << "\n"; - auto* grouper = local_states_[0].grouper.get(); + auto& final_grouper = *local_states_[build_result_index].grouper; // Create a batch with key columns std::vector keys(probe_index_field_ids_.size()); @@ -155,7 +189,7 @@ struct HashSemiJoinNode : ExecNode { // Query the grouper with key_batch. If no match was found, returning group_ids would // have null. - ARROW_ASSIGN_OR_RAISE(Datum group_ids, grouper->Find(key_batch)); + ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch)); auto group_ids_data = *group_ids.array(); if (group_ids_data.MayHaveNulls()) { // values need to be filtered @@ -170,20 +204,25 @@ struct HashSemiJoinNode : ExecNode { Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - std::cout << "output " << seq << " " << out_batch.ToString() << "\n"; + // std::cout << "output seq:" << seq << " " << out_batch.length << "\n"; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // all values are valid for output - std::cout << "output " << seq << " " << batch.ToString() << "\n"; + // std::cout << "output seq:" << seq << " " << batch.length << "\n"; outputs_[0]->InputReceived(this, seq, std::move(batch)); } - out_counter_.Increment(); + if (out_counter_.Increment()) { + finished_.MarkFinished(); + } return Status::OK(); } - void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { - ThreadLocalState* state = &local_states_[thread_index]; - state->cached_probe_batches.emplace_back(seq_num, std::move(batch)); + // void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { + void CacheProbeBatch(int seq_num, ExecBatch batch) { + // std::cout << "cache tid:" << thread_index << " seq:" << seq_num + // << " len:" << batch.length << "\n"; + std::lock_guard lck(cached_probe_batches_mutex); + cached_probe_batches.emplace_back(seq_num, std::move(batch)); } inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } @@ -191,38 +230,35 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - std::cout << "input received " << IsBuildInput(input) << " " << seq << " " - << batch.length << "\n"; + // std::cout << "input received input:" << (IsBuildInput(input) ? "b" : "p") + // << " seq:" << seq << " len:" << batch.length << "\n"; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); - size_t thread_index = get_thread_index_(); - ARROW_DCHECK(thread_index < local_states_.size()); - if (finished_.is_finished()) { return; } if (IsBuildInput(input)) { // build input batch is received // if a build input is received when build side is completed, something's wrong! 
- ARROW_DCHECK(!hash_table_built_.load()); + ARROW_DCHECK(!hash_table_built_); - ErrorIfNotOk(ConsumeBuildBatch(thread_index, std::move(batch))); - } else { // probe input batch is received - if (hash_table_built_.load()) { // build side done, continue with probing + ErrorIfNotOk(ConsumeBuildBatch(std::move(batch))); + } else { // probe input batch is received + if (hash_table_built_) { // build side done, continue with probing // consume cachedProbeBatches if available (for this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + ErrorIfNotOk(ConsumeCachedProbeBatches()); // consume this probe batch ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); } else { // build side not completed. Cache this batch! - CacheProbeBatch(thread_index, seq, std::move(batch)); + CacheProbeBatch(seq, std::move(batch)); } } } void ErrorReceived(ExecNode* input, Status error) override { - std::cout << "error received " << error.ToString() << "\n"; + // std::cout << "error received " << error.ToString() << "\n"; DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -230,25 +266,18 @@ struct HashSemiJoinNode : ExecNode { } void InputFinished(ExecNode* input, int num_total) override { - std::cout << "input finished " << IsBuildInput(input) << " " << num_total << "\n"; + // std::cout << "input finished input:" << (IsBuildInput(input) ? "b" : "p") + // << " tot:" << num_total << "\n"; // bail if StopProducing was called if (finished_.is_finished()) return; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); - size_t thread_index = get_thread_index_(); - // set total for build input if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { - // while incrementing, if the total is reached, merge all the groupers to 0'th one - ErrorIfNotOk(BuildSideMerge()); - - // enable flag that build side is completed - hash_table_built_.store(true); - - // only build side has completed! so process cached probe batches (of this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + // while incrementing, if the total is reached, call BuildSideCompleted() + ErrorIfNotOk(BuildSideCompleted()); return; } @@ -259,14 +288,16 @@ struct HashSemiJoinNode : ExecNode { // output will be streamed from the probe side. So, they will have the same total. if (out_counter_.SetTotal(num_total)) { // if out_counter has completed, the future is finished! - ErrorIfNotOk(ConsumeCachedProbeBatches(thread_index)); + ErrorIfNotOk(ConsumeCachedProbeBatches()); + outputs_[0]->InputFinished(this, num_total); finished_.MarkFinished(); } outputs_[0]->InputFinished(this, num_total); + // std::cout << "output set:" << num_total << "\n"; } Status StartProducing() override { - std::cout << "start prod \n"; + // std::cout << "start prod \n"; finished_ = Future<>::Make(); local_states_.resize(ThreadIndexer::Capacity()); @@ -278,7 +309,7 @@ struct HashSemiJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - std::cout << "stop prod from node\n"; + // std::cout << "stop prod from node\n"; DCHECK_EQ(output, outputs_[0]); @@ -295,21 +326,20 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node? void StopProducing() override { - std::cout << "stop prod \n"; + // std::cout << "stop prod \n"; for (auto&& output : outputs_) { StopProducing(output); } } Future<> finished() override { - std::cout << "finished? " << finished_.is_finished() << "\n"; + // std::cout << "finished? 
" << finished_.is_finished() << "\n"; return finished_; } private: struct ThreadLocalState { std::unique_ptr grouper; - std::vector> cached_probe_batches{}; }; ExecContext* ctx_; @@ -321,11 +351,19 @@ struct HashSemiJoinNode : ExecNode { AtomicCounter build_counter_, out_counter_; std::vector local_states_; - // need a separate atomic bool to track if the build side complete. Can't use the flag + // we have no guarantee which threads would be coming from the build side. so, out of + // the thread local states, we need to find an appropriate index which could accumulate + // all build indices (ideally, the grouper which has the highest # of elems) + int32_t build_result_index; + + // need a separate bool to track if the build side complete. Can't use the flag // inside the AtomicCounter, because we need to merge the build groupers once we receive // all the build batches. So, while merging, we need to prevent probe batches, being // consumed. - std::atomic hash_table_built_; + bool hash_table_built_; + + std::mutex cached_probe_batches_mutex; + std::vector> cached_probe_batches{}; }; Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, @@ -432,4 +470,4 @@ Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, } } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_test.cc index 601e7fee8fa..27b51c5cb30 100644 --- a/cpp/src/arrow/compute/exec/hash_join_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_test.cc @@ -29,7 +29,7 @@ namespace compute { void GenerateBatchesFromString(const std::shared_ptr& schema, const std::vector& json_strings, - BatchesWithSchema* out_batches) { + BatchesWithSchema* out_batches, int multiplicity = 1) { std::vector descrs; for (auto&& field : schema->fields()) { descrs.emplace_back(field->type()); @@ -39,10 +39,17 @@ void GenerateBatchesFromString(const std::shared_ptr& schema, out_batches->batches.push_back(ExecBatchFromJSON(descrs, s)); } + size_t batch_count = out_batches->batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches->batches.push_back(out_batches->batches[i]); + } + } + out_batches->schema = schema; } -TEST(HashJoin, LeftSemi) { +void RunTest(JoinType type, bool parallel) { auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); BatchesWithSchema l_batches, r_batches, exp_batches; @@ -50,12 +57,12 @@ TEST(HashJoin, LeftSemi) { GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, - &l_batches); + &l_batches, /*multiplicity=*/parallel ? 100 : 1); GenerateBatchesFromString( r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, - &r_batches); + &r_batches, /*multiplicity=*/parallel ? 
100 : 1); SCOPED_TRACE("serial"); @@ -63,26 +70,42 @@ TEST(HashJoin, LeftSemi) { ASSERT_OK_AND_ASSIGN(auto l_source, MakeTestSourceNode(plan.get(), "l_source", l_batches, - /*parallel=*/false, + /*parallel=*/parallel, /*slow=*/false)); ASSERT_OK_AND_ASSIGN(auto r_source, MakeTestSourceNode(plan.get(), "r_source", r_batches, - /*parallel=*/false, + /*parallel=*/parallel, /*slow=*/false)); ASSERT_OK_AND_ASSIGN( auto semi_join, - MakeHashJoinNode(JoinType::LEFT_SEMI, l_source, r_source, "l_semi_join", + MakeHashJoinNode(type, l_source, r_source, "l_semi_join", /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"})); auto sink_gen = MakeSinkNode(semi_join, "sink"); - GenerateBatchesFromString( - l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, - &exp_batches); - + if (type == JoinType::LEFT_SEMI) { + GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &exp_batches, /*multiplicity=*/parallel ? 100 : 1); + } else if (type == JoinType::RIGHT_SEMI) { + GenerateBatchesFromString( + r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, + &exp_batches, /*multiplicity=*/parallel ? 100 : 1); + } ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } +class HashJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P(HashJoinTest, HashJoinTest, + ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, + JoinType::RIGHT_SEMI), + ::testing::Values(false, true))); + +TEST_P(HashJoinTest, TestSemiJoins) { + RunTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow From 842204ec1065f1d7110c06bffce048464e27f7c9 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 4 Aug 2021 16:49:57 -0400 Subject: [PATCH 13/27] using log instead of cout --- cpp/src/arrow/compute/exec/hash_join.cc | 49 ++++++++++++------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 789a71973ee..4a0d154f33d 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -19,10 +19,9 @@ #include #include -#include - #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/exec_utils.h" +#include "arrow/util/logging.h" namespace arrow { namespace compute { @@ -47,7 +46,7 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { - // std::cout << "init \n"; + ARROW_LOG(DEBUG) << "init "; // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -80,13 +79,13 @@ struct HashSemiJoinNode : ExecNode { } } ARROW_DCHECK(build_result_index > -1); - // std::cout << "build_result_index " << build_result_index << "\n"; + ARROW_LOG(DEBUG) << "build_result_index " << build_result_index; } // Performs the housekeeping work after the build-side is completed. Note: this method // should be called ONLY ONCE! 
Status BuildSideCompleted() { - // std::cout << "build side merge \n"; + ARROW_LOG(DEBUG) << "build side merge"; CalculateBuildResultIndex(); @@ -118,12 +117,12 @@ struct HashSemiJoinNode : ExecNode { // total reached at the end of consumption, all the local states will be merged, before // incrementing the total batches Status ConsumeBuildBatch(ExecBatch batch) { - // std::cout << "ConsumeBuildBatch tid:" << thread_index << " len:" << batch.length - // << "\n"; - size_t thread_index = get_thread_index_(); ARROW_DCHECK(thread_index < local_states_.size()); + ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index + << " len:" << batch.length; + auto state = &local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(state)); @@ -148,8 +147,8 @@ struct HashSemiJoinNode : ExecNode { // consumes cached probe batches by invoking executor::Spawn. This should be called by a // single thread. Note: this method should be called ONLY ONCE! Status ConsumeCachedProbeBatches() { - // std::cout << "ConsumeCachedProbeBatches tid:" << thread_index - // << " len:" << cached_probe_batches.size() << "\n"; + ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_() + << " len:" << cached_probe_batches.size(); if (!cached_probe_batches.empty()) { auto executor = ctx_->executor(); @@ -176,7 +175,7 @@ struct HashSemiJoinNode : ExecNode { // consumes a probe batch and increment probe batches count. Probing would query the // grouper[build_result_index] which have been merged with all others. Status ConsumeProbeBatch(int seq, ExecBatch batch) { - // std::cout << "ConsumeProbeBatch seq:" << seq << "\n"; + ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq; auto& final_grouper = *local_states_[build_result_index].grouper; @@ -204,10 +203,10 @@ struct HashSemiJoinNode : ExecNode { Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - // std::cout << "output seq:" << seq << " " << out_batch.length << "\n"; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // all values are valid for output - // std::cout << "output seq:" << seq << " " << batch.length << "\n"; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; outputs_[0]->InputReceived(this, seq, std::move(batch)); } @@ -219,8 +218,8 @@ struct HashSemiJoinNode : ExecNode { // void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { void CacheProbeBatch(int seq_num, ExecBatch batch) { - // std::cout << "cache tid:" << thread_index << " seq:" << seq_num - // << " len:" << batch.length << "\n"; + ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num + << " len:" << batch.length; std::lock_guard lck(cached_probe_batches_mutex); cached_probe_batches.emplace_back(seq_num, std::move(batch)); } @@ -230,8 +229,8 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - // std::cout << "input received input:" << (IsBuildInput(input) ? "b" : "p") - // << " seq:" << seq << " len:" << batch.length << "\n"; + // //std::cout << "input received input:" << (IsBuildInput(input) ? 
"b" : "p") + // << " seq:" << seq << " len:" << batch.length ; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); @@ -258,7 +257,7 @@ struct HashSemiJoinNode : ExecNode { } void ErrorReceived(ExecNode* input, Status error) override { - // std::cout << "error received " << error.ToString() << "\n"; + ARROW_LOG(DEBUG) << "error received " << error.ToString(); DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -266,8 +265,8 @@ struct HashSemiJoinNode : ExecNode { } void InputFinished(ExecNode* input, int num_total) override { - // std::cout << "input finished input:" << (IsBuildInput(input) ? "b" : "p") - // << " tot:" << num_total << "\n"; + ARROW_LOG(DEBUG) << "input finished input:" << (IsBuildInput(input) ? "b" : "p") + << " tot:" << num_total; // bail if StopProducing was called if (finished_.is_finished()) return; @@ -293,11 +292,11 @@ struct HashSemiJoinNode : ExecNode { finished_.MarkFinished(); } outputs_[0]->InputFinished(this, num_total); - // std::cout << "output set:" << num_total << "\n"; + ARROW_LOG(DEBUG) << "output set:" << num_total; } Status StartProducing() override { - // std::cout << "start prod \n"; + ARROW_LOG(DEBUG) << "start prod"; finished_ = Future<>::Make(); local_states_.resize(ThreadIndexer::Capacity()); @@ -309,7 +308,7 @@ struct HashSemiJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - // std::cout << "stop prod from node\n"; + ARROW_LOG(DEBUG) << "stop prod from node"; DCHECK_EQ(output, outputs_[0]); @@ -326,14 +325,14 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node? void StopProducing() override { - // std::cout << "stop prod \n"; + ARROW_LOG(DEBUG) << "stop prod "; for (auto&& output : outputs_) { StopProducing(output); } } Future<> finished() override { - // std::cout << "finished? " << finished_.is_finished() << "\n"; + ARROW_LOG(DEBUG) << "finished? " << finished_.is_finished(); return finished_; } From 2964cc2935b6a49a7d10ff3d354f069fb6890018 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 4 Aug 2021 16:55:33 -0400 Subject: [PATCH 14/27] minor changes --- cpp/src/arrow/compute/exec/hash_join.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 4a0d154f33d..47140e552b8 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include - +#include "arrow/api.h" +#include "arrow/compute/api.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/exec_utils.h" #include "arrow/util/logging.h" @@ -229,8 +227,8 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - // //std::cout << "input received input:" << (IsBuildInput(input) ? "b" : "p") - // << " seq:" << seq << " len:" << batch.length ; + ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? 
"b" : "p") + << " seq:" << seq << " len:" << batch.length; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); From 1fcbb1d5218da38432911d0365105c83e2124f62 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 4 Aug 2021 18:16:54 -0400 Subject: [PATCH 15/27] minor bug fix --- cpp/src/arrow/compute/exec/hash_join.cc | 79 +++++++++++++++++-------- 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 47140e552b8..a7fb60da43f 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -35,7 +35,8 @@ struct HashSemiJoinNode : ExecNode { build_index_field_ids_(build_index_field_ids), probe_index_field_ids_(probe_index_field_ids), build_result_index(-1), - hash_table_built_(false) {} + hash_table_built_(false), + cached_probe_batches_consumed(false) {} private: struct ThreadLocalState; @@ -44,7 +45,7 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { - ARROW_LOG(DEBUG) << "init "; + ARROW_LOG(DEBUG) << "init state"; // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -80,11 +81,15 @@ struct HashSemiJoinNode : ExecNode { ARROW_LOG(DEBUG) << "build_result_index " << build_result_index; } - // Performs the housekeeping work after the build-side is completed. Note: this method - // should be called ONLY ONCE! + // Performs the housekeeping work after the build-side is completed. + // Note: this method is not thread safe, and hence should be guaranteed that it is + // not accessed concurrently! Status BuildSideCompleted() { ARROW_LOG(DEBUG) << "build side merge"; + // if the hash table has already been built, return + if (hash_table_built_) return Status::OK(); + CalculateBuildResultIndex(); // merge every group into the build_result_index grouper @@ -131,23 +136,28 @@ struct HashSemiJoinNode : ExecNode { } ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); - // Create a batch with group ids TODO(niranda) replace with void consume method + // Create a batch with group ids + // TODO(niranda) replace with void consume method ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); if (build_counter_.Increment()) { - // while incrementing, if the total is reached, call BuildSideCompleted + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. RETURN_NOT_OK(BuildSideCompleted()); } return Status::OK(); } - // consumes cached probe batches by invoking executor::Spawn. This should be called by a - // single thread. Note: this method should be called ONLY ONCE! + // consumes cached probe batches by invoking executor::Spawn. Status ConsumeCachedProbeBatches() { ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_() << " len:" << cached_probe_batches.size(); + // acquire the mutex to access cached_probe_batches, because while consuming, other + // batches should not be cached! 
+ std::lock_guard lck(cached_probe_batches_mutex); + if (!cached_probe_batches.empty()) { auto executor = ctx_->executor(); for (auto&& cached : cached_probe_batches) { @@ -167,6 +177,9 @@ struct HashSemiJoinNode : ExecNode { // lambdas cached_probe_batches.clear(); } + + // set flag + cached_probe_batches_consumed = true; return Status::OK(); } @@ -214,12 +227,19 @@ struct HashSemiJoinNode : ExecNode { return Status::OK(); } - // void CacheProbeBatch(const size_t thread_index, int seq_num, ExecBatch batch) { - void CacheProbeBatch(int seq_num, ExecBatch batch) { + // Attempt to cache a probe batch. If it is not cached, return false. + // If cached_probe_batches_consumed is true, by the time a thread acquires + // cached_probe_batches_mutex, it should no longer be cached! Instead, it can be + // directly consumed! + bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) { ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num - << " len:" << batch.length; + << " len:" << batch->length; std::lock_guard lck(cached_probe_batches_mutex); - cached_probe_batches.emplace_back(seq_num, std::move(batch)); + if (cached_probe_batches_consumed) { + return false; + } + cached_probe_batches.emplace_back(seq_num, std::move(*batch)); + return true; } inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } @@ -241,15 +261,18 @@ struct HashSemiJoinNode : ExecNode { ARROW_DCHECK(!hash_table_built_); ErrorIfNotOk(ConsumeBuildBatch(std::move(batch))); - } else { // probe input batch is received - if (hash_table_built_) { // build side done, continue with probing - // consume cachedProbeBatches if available (for this thread) - ErrorIfNotOk(ConsumeCachedProbeBatches()); + } else { // probe input batch is received + if (hash_table_built_) { + // build side done, continue with probing. When hash_table_built_ is set, it is + // guaranteed that some thread has already called ConsumeCachedProbeBatches. // consume this probe batch ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); } else { // build side not completed. Cache this batch! - CacheProbeBatch(seq, std::move(batch)); + if (!AttemptToCacheProbeBatch(seq, &batch)) { + // if the cache attempt fails, consume the batch + ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); + } } } } @@ -273,7 +296,8 @@ struct HashSemiJoinNode : ExecNode { // set total for build input if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { - // while incrementing, if the total is reached, call BuildSideCompleted() + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. ErrorIfNotOk(BuildSideCompleted()); return; } @@ -288,9 +312,9 @@ struct HashSemiJoinNode : ExecNode { ErrorIfNotOk(ConsumeCachedProbeBatches()); outputs_[0]->InputFinished(this, num_total); finished_.MarkFinished(); + } else { + outputs_[0]->InputFinished(this, num_total); } - outputs_[0]->InputFinished(this, num_total); - ARROW_LOG(DEBUG) << "output set:" << num_total; } Status StartProducing() override { @@ -324,9 +348,7 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node?
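Note on the handoff above: a probe batch is either appended to cached_probe_batches while cached_probe_batches_mutex is held, or, once ConsumeCachedProbeBatches has drained the cache and set cached_probe_batches_consumed under that same mutex, the probing thread consumes it directly. A minimal standalone sketch of the pattern (illustrative names, not the Arrow APIs; the real node uses ExecBatch and spawns each cached batch on the executor):

  #include <mutex>
  #include <utility>
  #include <vector>

  struct Batch { int id; };  // stand-in for ExecBatch

  class ProbeCache {
   public:
    // Returns true if the batch was cached; false means "probe it directly".
    bool TryCache(int seq, Batch batch) {
      std::lock_guard<std::mutex> lock(mutex_);
      if (consumed_) return false;  // the drain has already happened
      cache_.emplace_back(seq, batch);
      return true;
    }

    // Called once the hash table is complete; drains and closes the cache.
    template <typename Probe>
    void Drain(Probe probe) {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& p : cache_) probe(p.first, p.second);
      cache_.clear();
      consumed_ = true;  // later TryCache calls fall through to direct probing
    }

   private:
    std::mutex mutex_;
    std::vector<std::pair<int, Batch>> cache_;
    bool consumed_ = false;
  };

Because both paths take the same mutex, no batch can be appended after the drain and none can be dropped before it, so every probe batch is consumed exactly once.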
void StopProducing() override { ARROW_LOG(DEBUG) << "stop prod "; - for (auto&& output : outputs_) { - StopProducing(output); - } + outputs_[0]->StopProducing(); } Future<> finished() override { @@ -348,9 +370,9 @@ struct HashSemiJoinNode : ExecNode { AtomicCounter build_counter_, out_counter_; std::vector local_states_; - // we have no guarantee which threads would be coming from the build side. so, out of - // the thread local states, we need to find an appropriate index which could accumulate - // all build indices (ideally, the grouper which has the highest # of elems) + // There's no guarantee on which threads would be coming from the build side. so, out of + // the thread local states, an appropriate state needs to be chosen to accumulate + // all built results (ideally, the grouper which has the highest # of elems) int32_t build_result_index; // need a separate bool to track if the build side is complete. Can't use the flag @@ -361,6 +383,11 @@ struct HashSemiJoinNode : ExecNode { std::mutex cached_probe_batches_mutex; std::vector> cached_probe_batches{}; + // a flag is required to indicate if the cached probe batches have already been + // consumed! If cached_probe_batches_consumed is true, by the time a thread acquires + // cached_probe_batches_mutex, it should no longer be cached! Instead, it can be + // directly consumed! + bool cached_probe_batches_consumed; }; Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, From 614a9b918027dd16b6f74c9c157b52973b6d9433 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 4 Aug 2021 18:59:13 -0400 Subject: [PATCH 16/27] adding empty tests --- cpp/src/arrow/compute/exec/hash_join.cc | 7 +- cpp/src/arrow/compute/exec/hash_join_test.cc | 101 ++++++++++++++----- 2 files changed, 81 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index a7fb60da43f..8077099f17c 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -68,12 +68,13 @@ struct HashSemiJoinNode : ExecNode { // Finds an appropriate index which could accumulate all build indices (i.e. the grouper // which has the highest # of groups) void CalculateBuildResultIndex() { - uint32_t curr_max = 0; + int32_t curr_max = -1; for (int i = 0; i < static_cast(local_states_.size()); i++) { auto* state = &local_states_[i]; ARROW_DCHECK(state); - if (state->grouper && curr_max < state->grouper->num_groups()) { - curr_max = state->grouper->num_groups(); + if (state->grouper && + curr_max < static_cast(state->grouper->num_groups())) { + curr_max = static_cast(state->grouper->num_groups()); build_result_index = i; } } diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_test.cc index 27b51c5cb30..a04eef60bc7 100644 --- a/cpp/src/arrow/compute/exec/hash_join_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_test.cc @@ -49,51 +49,100 @@ void GenerateBatchesFromString(const std::shared_ptr& schema, out_batches->schema = schema; } -void RunTest(JoinType type, bool parallel) { - auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); - auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); - BatchesWithSchema l_batches, r_batches, exp_batches; - - GenerateBatchesFromString(l_schema, - {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", - R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, - &l_batches, /*multiplicity=*/parallel ?
100 : 1); - - GenerateBatchesFromString( - r_schema, - {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, - &r_batches, /*multiplicity=*/parallel ? 100 : 1); - +void CheckRunOutput(JoinType type, BatchesWithSchema l_batches, + BatchesWithSchema r_batches, + const std::vector& left_keys, + const std::vector& right_keys, + const BatchesWithSchema& exp_batches, bool parallel = false) { SCOPED_TRACE("serial"); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); ASSERT_OK_AND_ASSIGN(auto l_source, - MakeTestSourceNode(plan.get(), "l_source", l_batches, + MakeTestSourceNode(plan.get(), "l_source", std::move(l_batches), /*parallel=*/parallel, /*slow=*/false)); ASSERT_OK_AND_ASSIGN(auto r_source, - MakeTestSourceNode(plan.get(), "r_source", r_batches, + MakeTestSourceNode(plan.get(), "r_source", std::move(r_batches), /*parallel=*/parallel, /*slow=*/false)); ASSERT_OK_AND_ASSIGN( auto semi_join, - MakeHashJoinNode(type, l_source, r_source, "l_semi_join", - /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"})); + MakeHashJoinNode(type, l_source, r_source, "hash_join", left_keys, right_keys)); auto sink_gen = MakeSinkNode(semi_join, "sink"); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); +} + +void RunNonEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + int multiplicity = parallel ? 100 : 1; + + GenerateBatchesFromString(l_schema, + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &l_batches, multiplicity); + + GenerateBatchesFromString( + r_schema, + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + &r_batches, multiplicity); + if (type == JoinType::LEFT_SEMI) { GenerateBatchesFromString( l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, - &exp_batches, /*multiplicity=*/parallel ? 100 : 1); + &exp_batches, multiplicity); } else if (type == JoinType::RIGHT_SEMI) { GenerateBatchesFromString( r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, - &exp_batches, /*multiplicity=*/parallel ? 100 : 1); + &exp_batches, multiplicity); } - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); + + CheckRunOutput(type, std::move(l_batches), std::move(r_batches), + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, + parallel); +} + +void RunEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + int multiplicity = parallel ? 
100 : 1; + + if (type == JoinType::LEFT_SEMI) { + GenerateBatchesFromString(l_schema, {R"([])"}, &exp_batches, multiplicity); + } else if (type == JoinType::RIGHT_SEMI) { + GenerateBatchesFromString(r_schema, {R"([])"}, &exp_batches, multiplicity); + } + + // both empty + GenerateBatchesFromString(l_schema, {R"([])"}, &l_batches, multiplicity); + GenerateBatchesFromString(r_schema, {R"([])"}, &r_batches, multiplicity); + CheckRunOutput(type, std::move(l_batches), std::move(r_batches), + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, + parallel); + + // left empty + GenerateBatchesFromString(l_schema, {R"([])"}, &l_batches, multiplicity); + GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, &r_batches, + multiplicity); + CheckRunOutput(type, std::move(l_batches), std::move(r_batches), + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, + parallel); + + // right empty + GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, &l_batches, + multiplicity); + GenerateBatchesFromString(r_schema, {R"([])"}, &r_batches, multiplicity); + CheckRunOutput(type, std::move(l_batches), std::move(r_batches), + /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, + parallel); } class HashJoinTest : public testing::TestWithParam> {}; @@ -104,7 +153,11 @@ INSTANTIATE_TEST_SUITE_P(HashJoinTest, HashJoinTest, ::testing::Values(false, true))); TEST_P(HashJoinTest, TestSemiJoins) { - RunTest(std::get<0>(GetParam()), std::get<1>(GetParam())); + RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +TEST_P(HashJoinTest, TestSemiJoinsLeftEmpty) { + RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); } } // namespace compute From 43115d3769d77ef884a47ff2929f8588112d84ad Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 4 Aug 2021 19:56:31 -0400 Subject: [PATCH 17/27] lint changes --- cpp/src/arrow/compute/exec/exec_utils.h | 2 ++ cpp/src/arrow/compute/exec/hash_join.cc | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_utils.h b/cpp/src/arrow/compute/exec/exec_utils.h index 65dd93150e1..d6ecbda26b6 100644 --- a/cpp/src/arrow/compute/exec/exec_utils.h +++ b/cpp/src/arrow/compute/exec/exec_utils.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#pragma once + #include #include #include diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 8077099f17c..f00344e13bb 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -464,14 +464,14 @@ Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* righ right_keys); } -static std::string JoinTypeToString[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", - "RIGHT_ANTI", "INNER", "LEFT_OUTER", - "RIGHT_OUTER", "FULL_OUTER"}; - Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, ExecNode* right_input, std::string label, const std::vector& left_keys, const std::vector& right_keys) { + static std::string join_type_string[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + switch (join_type) { case LEFT_SEMI: // left join--> build from right and probe from left @@ -487,7 +487,7 @@ Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, case LEFT_OUTER: case RIGHT_OUTER: case FULL_OUTER: - return Status::NotImplemented(JoinTypeToString[join_type] + + return Status::NotImplemented(join_type_string[join_type] + " joins not implemented!"); default: return Status::Invalid("invalid join type"); From 3fcfbb481015a967819005cce915d720b228161a Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 5 Aug 2021 10:39:38 -0400 Subject: [PATCH 18/27] fixing c++/cli mutex import --- cpp/src/arrow/compute/exec/exec_utils.cc | 4 ++-- cpp/src/arrow/compute/exec/exec_utils.h | 4 ++-- cpp/src/arrow/compute/exec/hash_join.cc | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_utils.cc b/cpp/src/arrow/compute/exec/exec_utils.cc index f1a96ac0812..7026351e0b7 100644 --- a/cpp/src/arrow/compute/exec/exec_utils.cc +++ b/cpp/src/arrow/compute/exec/exec_utils.cc @@ -17,7 +17,7 @@ #include "arrow/compute/exec/exec_utils.h" -#include +#include "arrow/util/logging.h" namespace arrow { namespace compute { @@ -25,7 +25,7 @@ namespace compute { size_t ThreadIndexer::operator()() { auto id = std::this_thread::get_id(); - std::unique_lock lock(mutex_); + auto guard = mutex_.Lock(); // acquire the lock const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; return Check(id_index.second); diff --git a/cpp/src/arrow/compute/exec/exec_utils.h b/cpp/src/arrow/compute/exec/exec_utils.h index d6ecbda26b6..93cd0775098 100644 --- a/cpp/src/arrow/compute/exec/exec_utils.h +++ b/cpp/src/arrow/compute/exec/exec_utils.h @@ -17,10 +17,10 @@ #pragma once -#include #include #include +#include "arrow/util/mutex.h" #include "arrow/util/thread_pool.h" namespace arrow { @@ -35,7 +35,7 @@ class ThreadIndexer { private: static size_t Check(size_t thread_index); - std::mutex mutex_; + util::Mutex mutex_; std::unordered_map id_to_index_; }; diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index f00344e13bb..9dd3a037dcd 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include + #include "arrow/api.h" #include "arrow/compute/api.h" #include "arrow/compute/exec/exec_plan.h" @@ -159,7 +161,7 @@ struct HashSemiJoinNode : ExecNode { // batches should not be cached! 
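Note: the util::Mutex swap in this patch is what the commit title refers to. MSVC rejects the standard threading headers (including <mutex>) in C++/CLI (/clr) translation units, so a header that may be transitively included from C++/CLI code cannot hold a std::mutex member; arrow::util::Mutex keeps the lock behind an opaque implementation so the header stays CLI-safe. Its locking idiom is guard-based, as the exec_utils.cc hunk above shows; a sketch of the pattern (assuming only the Lock() API visible in the diff):

  #include "arrow/util/mutex.h"

  // util::Mutex::Lock() returns a RAII guard; the lock is released when the
  // guard goes out of scope, rather than via std::lock_guard.
  arrow::util::Mutex mutex_;
  {
    auto guard = mutex_.Lock();
    // ... critical section ...
  }  // unlocked here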
std::lock_guard lck(cached_probe_batches_mutex); - if (!cached_probe_batches.empty()) { + if (!cached_probe_batches_consumed) { auto executor = ctx_->executor(); for (auto&& cached : cached_probe_batches) { if (executor) { From bbdd30a7253999d0e389f2d73c1f70d4d6a6f418 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 5 Aug 2021 16:32:49 -0400 Subject: [PATCH 19/27] adding anti-joins --- cpp/src/arrow/compute/exec/hash_join.cc | 115 +++++++++++------- cpp/src/arrow/compute/exec/hash_join_test.cc | 117 +++++++++++++------ 2 files changed, 156 insertions(+), 76 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 9dd3a037dcd..95db8f254cc 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -21,11 +21,13 @@ #include "arrow/compute/api.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/exec_utils.h" +#include "arrow/util/bitmap_ops.h" #include "arrow/util/logging.h" namespace arrow { namespace compute { +template struct HashSemiJoinNode : ExecNode { HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, ExecContext* ctx, const std::vector&& build_index_field_ids, @@ -186,6 +188,33 @@ struct HashSemiJoinNode : ExecNode { return Status::OK(); } + Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch batch) { + if (group_ids_data.GetNullCount() == batch.length) { + // All NULLS! hence, there are no valid outputs! + ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { // all values are valid for output + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } + + return Status::OK(); + } + // consumes a probe batch and increment probe batches count. Probing would query the // grouper[build_result_index] which have been merged with all others. 
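Note: in this design the Grouper doubles as the join hash table. Schematically, using the two calls visible in these hunks (Consume for inserting keys, Find for lookup):

  // build: every distinct key tuple is assigned a dense group id
  ARROW_ASSIGN_OR_RAISE(Datum _, grouper->Consume(build_key_batch));
  // probe: look keys up without inserting; keys never seen on the
  // build side come back as null group ids
  ARROW_ASSIGN_OR_RAISE(Datum ids, grouper->Find(probe_key_batch));

The validity bitmap of ids is therefore exactly the semi-join mask (and, inverted, the anti-join mask), which is how the GenerateOutput helper above turns the Find result into a Filter call.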
Status ConsumeProbeBatch(int seq, ExecBatch batch) { @@ -205,24 +234,7 @@ struct HashSemiJoinNode : ExecNode { ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch)); auto group_ids_data = *group_ids.array(); - if (group_ids_data.MayHaveNulls()) { // values need to be filtered - auto filter_arr = - std::make_shared(group_ids_data.length, group_ids_data.buffers[0], - /*null_bitmap=*/nullptr, /*null_count=*/0, - /*offset=*/group_ids_data.offset); - ARROW_ASSIGN_OR_RAISE(auto rec_batch, - batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); - ARROW_ASSIGN_OR_RAISE( - auto filtered, - Filter(rec_batch, filter_arr, - /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); - auto out_batch = ExecBatch(*filtered.record_batch()); - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; - outputs_[0]->InputReceived(this, seq, std::move(out_batch)); - } else { // all values are valid for output - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; - outputs_[0]->InputReceived(this, seq, std::move(batch)); - } + RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch))); if (out_counter_.Increment()) { finished_.MarkFinished(); @@ -393,6 +405,42 @@ struct HashSemiJoinNode : ExecNode { bool cached_probe_batches_consumed; }; +// template specialization for anti joins. For anti joins, group_ids_data needs to be +// inverted. Output will be taken for indices which are NULL +template <> +Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_ids_data, + ExecBatch batch) { + if (group_ids_data.GetNullCount() == group_ids_data.length) { + // All NULLS! hence, all values are valid for output + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + // invert the validity buffer + arrow::internal::InvertBitmap( + group_ids_data.buffers[0]->data(), group_ids_data.offset, group_ids_data.length, + group_ids_data.buffers[0]->mutable_data(), group_ids_data.offset); + + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { + // No NULLS! hence, there are no valid outputs! 
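Note: this anti-join specialization reuses the semi-join probe unchanged and only flips the mask. A probe row whose key exists on the build side receives a valid group id, so inverting the validity bitmap marks exactly the unmatched rows as selected. For example, one bit per row (InvertBitmap operates on the packed buffer at the given offset):

  validity of the Find() result:  1 0 1 1 0   (1 = key found on build side)
  after InvertBitmap:             0 1 0 0 1   (1 = row kept by the anti join)

Wrapping the inverted buffer in a BooleanArray and filtering with it then emits only the non-matching probe rows, mirroring the semi-join branch above.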
+ ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); + } + return Status::OK(); +} + Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, const std::vector& left_keys, const std::vector& right_keys) { @@ -427,6 +475,7 @@ Result> PopulateKeys(const Schema& schema, return key_field_ids; } +template Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, const std::vector& build_keys, @@ -443,27 +492,9 @@ Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_in auto ctx = build_input->plan()->exec_context(); ExecPlan* plan = build_input->plan(); - return plan->EmplaceNode(build_input, probe_input, std::move(label), - ctx, std::move(build_key_ids), - std::move(probe_key_ids)); -} - -Result MakeHashLeftSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - std::string label, - const std::vector& left_keys, - const std::vector& right_keys) { - // left join--> build from right and probe from left - return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, - left_keys); -} - -Result MakeHashRightSemiJoinNode(ExecNode* left_input, ExecNode* right_input, - std::string label, - const std::vector& left_keys, - const std::vector& right_keys) { - // right join--> build from left and probe from right - return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, - right_keys); + return plan->EmplaceNode>( + build_input, probe_input, std::move(label), ctx, std::move(build_key_ids), + std::move(probe_key_ids)); } Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, @@ -484,7 +515,13 @@ Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, right_keys); case LEFT_ANTI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, std::move(label), + right_keys, left_keys); case RIGHT_ANTI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, std::move(label), + left_keys, right_keys); case INNER: case LEFT_OUTER: case RIGHT_OUTER: diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_test.cc index a04eef60bc7..816d6adaf7a 100644 --- a/cpp/src/arrow/compute/exec/hash_join_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_test.cc @@ -93,14 +93,32 @@ void RunNonEmptyTest(JoinType type, bool parallel) { {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, &r_batches, multiplicity); - if (type == JoinType::LEFT_SEMI) { - GenerateBatchesFromString( - l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, - &exp_batches, multiplicity); - } else if (type == JoinType::RIGHT_SEMI) { - GenerateBatchesFromString( - r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, - &exp_batches, multiplicity); + switch (type) { + case LEFT_SEMI: + GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &exp_batches, multiplicity); + break; + case RIGHT_SEMI: + GenerateBatchesFromString( + r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, + &exp_batches, multiplicity); + break; + case LEFT_ANTI: + GenerateBatchesFromString( + l_schema, {R"([[0,"d"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([])"}, + &exp_batches, multiplicity); + break; + case RIGHT_ANTI: + 
GenerateBatchesFromString(r_schema, {R"([["f", 0]])", R"([["g", 4]])", R"([])"}, + &exp_batches, multiplicity); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; } CheckRunOutput(type, std::move(l_batches), std::move(r_batches), @@ -111,46 +129,71 @@ void RunNonEmptyTest(JoinType type, bool parallel) { void RunEmptyTest(JoinType type, bool parallel) { auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); - BatchesWithSchema l_batches, r_batches, exp_batches; int multiplicity = parallel ? 100 : 1; - if (type == JoinType::LEFT_SEMI) { - GenerateBatchesFromString(l_schema, {R"([])"}, &exp_batches, multiplicity); - } else if (type == JoinType::RIGHT_SEMI) { - GenerateBatchesFromString(r_schema, {R"([])"}, &exp_batches, multiplicity); - } + BatchesWithSchema l_empty, r_empty, l_n_empty, r_n_empty; - // both empty - GenerateBatchesFromString(l_schema, {R"([])"}, &l_batches, multiplicity); - GenerateBatchesFromString(r_schema, {R"([])"}, &r_batches, multiplicity); - CheckRunOutput(type, std::move(l_batches), std::move(r_batches), - /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, - parallel); + GenerateBatchesFromString(l_schema, {R"([])"}, &l_empty, multiplicity); + GenerateBatchesFromString(r_schema, {R"([])"}, &r_empty, multiplicity); - // left empty - GenerateBatchesFromString(l_schema, {R"([])"}, &l_batches, multiplicity); - GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, &r_batches, + GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, &l_n_empty, multiplicity); - CheckRunOutput(type, std::move(l_batches), std::move(r_batches), - /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, - parallel); - - // right empty - GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, &l_batches, + GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, &r_n_empty, multiplicity); - GenerateBatchesFromString(r_schema, {R"([])"}, &r_batches, multiplicity); - CheckRunOutput(type, std::move(l_batches), std::move(r_batches), - /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, - parallel); + + std::vector l_keys{"l_str"}; + std::vector r_keys{"r_str"}; + + switch (type) { + case LEFT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_empty, parallel); + break; + case LEFT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_n_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, 
l_empty, r_n_empty, l_keys, r_keys, r_n_empty, parallel); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } } class HashJoinTest : public testing::TestWithParam> {}; -INSTANTIATE_TEST_SUITE_P(HashJoinTest, HashJoinTest, - ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, - JoinType::RIGHT_SEMI), - ::testing::Values(false, true))); +INSTANTIATE_TEST_SUITE_P( + HashJoinTest, HashJoinTest, + ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI), + ::testing::Values(false, true))); TEST_P(HashJoinTest, TestSemiJoins) { RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); From ac197856854e2a8016549392a56d1d8ce6fcb818 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 9 Aug 2021 14:36:07 -0400 Subject: [PATCH 20/27] attempting to solve the threading issue --- cpp/src/arrow/compute/exec/hash_join.cc | 123 ++++++++++++++++-------- 1 file changed, 84 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 95db8f254cc..154a117a1fe 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include + #include #include "arrow/api.h" @@ -27,6 +29,39 @@ namespace arrow { namespace compute { +namespace { + +struct FreeIndexFinder { + public: + explicit FreeIndexFinder(size_t max_indices) : indices_(max_indices) { + std::lock_guard lock(mutex_); + std::fill(indices_.begin(), indices_.end(), true); + ARROW_DCHECK_LE(indices_.size(), max_indices); + } + + /// return the first available index + size_t FindAvailableIndex() { + std::lock_guard lock(mutex_); + auto it = std::find(indices_.begin(), indices_.end(), true); + ARROW_DCHECK_NE(it, indices_.end()); + *it = false; + return std::distance(indices_.begin(), it); + } + + /// release the index + void ReleaseIndex(size_t idx) { + std::lock_guard lock(mutex_); + // check if the indices_[idx] == false + ARROW_DCHECK(idx < indices_.size() && !indices_.at(idx)); + indices_.at(idx) = true; + } + + private: + std::mutex mutex_; + std::vector indices_; +}; +} // namespace + template struct HashSemiJoinNode : ExecNode { HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, @@ -36,6 +71,7 @@ struct HashSemiJoinNode : ExecNode { {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), /*num_outputs=*/1), ctx_(ctx), + free_index_finder(nullptr), build_index_field_ids_(build_index_field_ids), probe_index_field_ids_(probe_index_field_ids), build_result_index(-1), @@ -49,7 +85,7 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { - ARROW_LOG(DEBUG) << "init state"; + ARROW_LOG(WARNING) << "init state"; // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -72,25 +108,26 @@ struct HashSemiJoinNode : ExecNode { // Finds an appropriate index which could accumulate all build indices (i.e. 
the grouper // which has the highest # of groups) void CalculateBuildResultIndex() { - int32_t curr_max = -1; - for (int i = 0; i < static_cast(local_states_.size()); i++) { - auto* state = &local_states_[i]; - ARROW_DCHECK(state); - if (state->grouper && - curr_max < static_cast(state->grouper->num_groups())) { - curr_max = static_cast(state->grouper->num_groups()); - build_result_index = i; - } - } - ARROW_DCHECK(build_result_index > -1); - ARROW_LOG(DEBUG) << "build_result_index " << build_result_index; + // int32_t curr_max = -1; + // for (int i = 0; i < static_cast(local_states_.size()); i++) { + // auto* state = &local_states_[i]; + // ARROW_DCHECK(state); + // if (state->grouper && + // curr_max < static_cast(state->grouper->num_groups())) { + // curr_max = static_cast(state->grouper->num_groups()); + // build_result_index = i; + // } + // } + // ARROW_DCHECK(build_result_index > -1); + // ARROW_LOG(WARNING) << "build_result_index " << build_result_index; + build_result_index = 0; } // Performs the housekeeping work after the build-side is completed. // Note: this method is not thread safe, and hence should be guaranteed that it is // not accessed concurrently! Status BuildSideCompleted() { - ARROW_LOG(DEBUG) << "build side merge"; + ARROW_LOG(WARNING) << "build side merge"; // if the hash table has already been built, return if (hash_table_built_) return Status::OK(); @@ -125,11 +162,13 @@ struct HashSemiJoinNode : ExecNode { // total reached at the end of consumption, all the local states will be merged, before // incrementing the total batches Status ConsumeBuildBatch(ExecBatch batch) { - size_t thread_index = get_thread_index_(); - ARROW_DCHECK(thread_index < local_states_.size()); + // size_t thread_index = get_thread_index_(); + // get a free index from the finder + int thread_index = free_index_finder->FindAvailableIndex(); + ARROW_DCHECK(static_cast(thread_index) < local_states_.size()); - ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index - << " len:" << batch.length; + ARROW_LOG(WARNING) << "ConsumeBuildBatch tid:" << thread_index + << " len:" << batch.length; auto state = &local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(state)); @@ -145,6 +184,8 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) replace with void consume method ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); + free_index_finder->ReleaseIndex(thread_index); + if (build_counter_.Increment()) { // only one thread would get inside this block! // while incrementing, if the total is reached, call BuildSideCompleted. @@ -156,8 +197,8 @@ struct HashSemiJoinNode : ExecNode { // consumes cached probe batches by invoking executor::Spawn. Status ConsumeCachedProbeBatches() { - ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_() - << " len:" << cached_probe_batches.size(); + ARROW_LOG(WARNING) << "ConsumeCachedProbeBatches tid:" /*<< get_thread_index_()*/ + << " len:" << cached_probe_batches.size(); // acquire the mutex to access cached_probe_batches, because while consuming, other // batches should not be cached! @@ -191,7 +232,7 @@ struct HashSemiJoinNode : ExecNode { Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch batch) { if (group_ids_data.GetNullCount() == batch.length) { // All NULLS! hence, there are no valid outputs! 
- ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + ARROW_LOG(WARNING) << "output seq:" << seq << " 0"; outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered auto filter_arr = @@ -205,10 +246,10 @@ struct HashSemiJoinNode : ExecNode { Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + ARROW_LOG(WARNING) << "output seq:" << seq << " " << out_batch.length; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // all values are valid for output - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + ARROW_LOG(WARNING) << "output seq:" << seq << " " << batch.length; outputs_[0]->InputReceived(this, seq, std::move(batch)); } @@ -218,7 +259,7 @@ struct HashSemiJoinNode : ExecNode { // consumes a probe batch and increment probe batches count. Probing would query the // grouper[build_result_index] which have been merged with all others. Status ConsumeProbeBatch(int seq, ExecBatch batch) { - ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq; + ARROW_LOG(WARNING) << "ConsumeProbeBatch seq:" << seq; auto& final_grouper = *local_states_[build_result_index].grouper; @@ -247,8 +288,8 @@ struct HashSemiJoinNode : ExecNode { // cached_probe_batches_mutex, it should no longer be cached! instead, it can be // directly consumed! bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) { - ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num - << " len:" << batch->length; + ARROW_LOG(WARNING) << "cache tid:" /*<< get_thread_index_() */ << " seq:" << seq_num + << " len:" << batch->length; std::lock_guard lck(cached_probe_batches_mutex); if (cached_probe_batches_consumed) { return false; @@ -262,8 +303,8 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? "b" : "p") - << " seq:" << seq << " len:" << batch.length; + ARROW_LOG(WARNING) << "input received input:" << (IsBuildInput(input) ? "b" : "p") + << " seq:" << seq << " len:" << batch.length; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); @@ -293,7 +334,7 @@ struct HashSemiJoinNode : ExecNode { } void ErrorReceived(ExecNode* input, Status error) override { - ARROW_LOG(DEBUG) << "error received " << error.ToString(); + ARROW_LOG(WARNING) << "error received " << error.ToString(); DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -301,8 +342,8 @@ struct HashSemiJoinNode : ExecNode { } void InputFinished(ExecNode* input, int num_total) override { - ARROW_LOG(DEBUG) << "input finished input:" << (IsBuildInput(input) ? "b" : "p") - << " tot:" << num_total; + ARROW_LOG(WARNING) << "input finished input:" << (IsBuildInput(input) ? 
"b" : "p") + << " tot:" << num_total; // bail if StopProducing was called if (finished_.is_finished()) return; @@ -333,10 +374,12 @@ struct HashSemiJoinNode : ExecNode { } Status StartProducing() override { - ARROW_LOG(DEBUG) << "start prod"; + ARROW_LOG(WARNING) << "start prod"; finished_ = Future<>::Make(); local_states_.resize(ThreadIndexer::Capacity()); + free_index_finder = + arrow::internal::make_unique(ThreadIndexer::Capacity()); return Status::OK(); } @@ -345,7 +388,7 @@ struct HashSemiJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - ARROW_LOG(DEBUG) << "stop prod from node"; + ARROW_LOG(WARNING) << "stop prod from node"; DCHECK_EQ(output, outputs_[0]); @@ -362,12 +405,12 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node? void StopProducing() override { - ARROW_LOG(DEBUG) << "stop prod "; + ARROW_LOG(WARNING) << "stop prod "; outputs_[0]->StopProducing(); } Future<> finished() override { - ARROW_LOG(DEBUG) << "finished? " << finished_.is_finished(); + ARROW_LOG(WARNING) << "finished? " << finished_.is_finished(); return finished_; } @@ -379,10 +422,12 @@ struct HashSemiJoinNode : ExecNode { ExecContext* ctx_; Future<> finished_ = Future<>::MakeFinished(); - ThreadIndexer get_thread_index_; + // ThreadIndexer get_thread_index_; + std::unique_ptr free_index_finder; const std::vector build_index_field_ids_, probe_index_field_ids_; AtomicCounter build_counter_, out_counter_; + std::vector local_states_; // There's no guarantee on which threads would be coming from the build side. so, out of @@ -412,7 +457,7 @@ Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_id ExecBatch batch) { if (group_ids_data.GetNullCount() == group_ids_data.length) { // All NULLS! hence, all values are valid for output - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + ARROW_LOG(WARNING) << "output seq:" << seq << " " << batch.length; outputs_[0]->InputReceived(this, seq, std::move(batch)); } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered // invert the validity buffer @@ -431,11 +476,11 @@ Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_id Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + ARROW_LOG(WARNING) << "output seq:" << seq << " " << out_batch.length; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // No NULLS! hence, there are no valid outputs! - ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + ARROW_LOG(WARNING) << "output seq:" << seq << " 0"; outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); } return Status::OK(); From c76df38abcf5646dd5d261a60d7ea9b6a2cda87e Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 9 Aug 2021 15:45:38 -0400 Subject: [PATCH 21/27] Revert "attempting to solve the threading issue" This reverts commit 0a3bcbf58a346b252884be0a4e06e302f86d17bb. 
--- cpp/src/arrow/compute/exec/hash_join.cc | 123 ++++++++---------------- 1 file changed, 39 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 154a117a1fe..95db8f254cc 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -#include - #include #include "arrow/api.h" @@ -29,39 +27,6 @@ namespace arrow { namespace compute { -namespace { - -struct FreeIndexFinder { - public: - explicit FreeIndexFinder(size_t max_indices) : indices_(max_indices) { - std::lock_guard lock(mutex_); - std::fill(indices_.begin(), indices_.end(), true); - ARROW_DCHECK_LE(indices_.size(), max_indices); - } - - /// return the first available index - size_t FindAvailableIndex() { - std::lock_guard lock(mutex_); - auto it = std::find(indices_.begin(), indices_.end(), true); - ARROW_DCHECK_NE(it, indices_.end()); - *it = false; - return std::distance(indices_.begin(), it); - } - - /// release the index - void ReleaseIndex(size_t idx) { - std::lock_guard lock(mutex_); - // check if the indices_[idx] == false - ARROW_DCHECK(idx < indices_.size() && !indices_.at(idx)); - indices_.at(idx) = true; - } - - private: - std::mutex mutex_; - std::vector indices_; -}; -} // namespace - template struct HashSemiJoinNode : ExecNode { HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, @@ -71,7 +36,6 @@ struct HashSemiJoinNode : ExecNode { {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), /*num_outputs=*/1), ctx_(ctx), - free_index_finder(nullptr), build_index_field_ids_(build_index_field_ids), probe_index_field_ids_(probe_index_field_ids), build_result_index(-1), @@ -85,7 +49,7 @@ struct HashSemiJoinNode : ExecNode { const char* kind_name() override { return "HashSemiJoinNode"; } Status InitLocalStateIfNeeded(ThreadLocalState* state) { - ARROW_LOG(WARNING) << "init state"; + ARROW_LOG(DEBUG) << "init state"; // Get input schema auto build_schema = inputs_[0]->output_schema(); @@ -108,26 +72,25 @@ struct HashSemiJoinNode : ExecNode { // Finds an appropriate index which could accumulate all build indices (i.e. the grouper // which has the highest # of groups) void CalculateBuildResultIndex() { - // int32_t curr_max = -1; - // for (int i = 0; i < static_cast(local_states_.size()); i++) { - // auto* state = &local_states_[i]; - // ARROW_DCHECK(state); - // if (state->grouper && - // curr_max < static_cast(state->grouper->num_groups())) { - // curr_max = static_cast(state->grouper->num_groups()); - // build_result_index = i; - // } - // } - // ARROW_DCHECK(build_result_index > -1); - // ARROW_LOG(WARNING) << "build_result_index " << build_result_index; - build_result_index = 0; + int32_t curr_max = -1; + for (int i = 0; i < static_cast(local_states_.size()); i++) { + auto* state = &local_states_[i]; + ARROW_DCHECK(state); + if (state->grouper && + curr_max < static_cast(state->grouper->num_groups())) { + curr_max = static_cast(state->grouper->num_groups()); + build_result_index = i; + } + } + ARROW_DCHECK(build_result_index > -1); + ARROW_LOG(DEBUG) << "build_result_index " << build_result_index; } // Performs the housekeeping work after the build-side is completed. // Note: this method is not thread safe, and hence should be guaranteed that it is // not accessed concurrently! 
Status BuildSideCompleted() { - ARROW_LOG(WARNING) << "build side merge"; + ARROW_LOG(DEBUG) << "build side merge"; // if the hash table has already been built, return if (hash_table_built_) return Status::OK(); @@ -162,13 +125,11 @@ struct HashSemiJoinNode : ExecNode { // total reached at the end of consumption, all the local states will be merged, before // incrementing the total batches Status ConsumeBuildBatch(ExecBatch batch) { - // size_t thread_index = get_thread_index_(); - // get a free index from the finder - int thread_index = free_index_finder->FindAvailableIndex(); - ARROW_DCHECK(static_cast(thread_index) < local_states_.size()); + size_t thread_index = get_thread_index_(); + ARROW_DCHECK(thread_index < local_states_.size()); - ARROW_LOG(WARNING) << "ConsumeBuildBatch tid:" << thread_index - << " len:" << batch.length; + ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index + << " len:" << batch.length; auto state = &local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(state)); @@ -184,8 +145,6 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) replace with void consume method ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); - free_index_finder->ReleaseIndex(thread_index); - if (build_counter_.Increment()) { // only one thread would get inside this block! // while incrementing, if the total is reached, call BuildSideCompleted. @@ -197,8 +156,8 @@ struct HashSemiJoinNode : ExecNode { // consumes cached probe batches by invoking executor::Spawn. Status ConsumeCachedProbeBatches() { - ARROW_LOG(WARNING) << "ConsumeCachedProbeBatches tid:" /*<< get_thread_index_()*/ - << " len:" << cached_probe_batches.size(); + ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_() + << " len:" << cached_probe_batches.size(); // acquire the mutex to access cached_probe_batches, because while consuming, other // batches should not be cached! @@ -232,7 +191,7 @@ struct HashSemiJoinNode : ExecNode { Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch batch) { if (group_ids_data.GetNullCount() == batch.length) { // All NULLS! hence, there are no valid outputs! - ARROW_LOG(WARNING) << "output seq:" << seq << " 0"; + ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered auto filter_arr = @@ -246,10 +205,10 @@ struct HashSemiJoinNode : ExecNode { Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - ARROW_LOG(WARNING) << "output seq:" << seq << " " << out_batch.length; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // all values are valid for output - ARROW_LOG(WARNING) << "output seq:" << seq << " " << batch.length; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; outputs_[0]->InputReceived(this, seq, std::move(batch)); } @@ -259,7 +218,7 @@ struct HashSemiJoinNode : ExecNode { // consumes a probe batch and increment probe batches count. Probing would query the // grouper[build_result_index] which have been merged with all others. 
Status ConsumeProbeBatch(int seq, ExecBatch batch) { - ARROW_LOG(WARNING) << "ConsumeProbeBatch seq:" << seq; + ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq; auto& final_grouper = *local_states_[build_result_index].grouper; @@ -288,8 +247,8 @@ struct HashSemiJoinNode : ExecNode { // cached_probe_batches_mutex, it should no longer be cached! instead, it can be // directly consumed! bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) { - ARROW_LOG(WARNING) << "cache tid:" /*<< get_thread_index_() */ << " seq:" << seq_num - << " len:" << batch->length; + ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num + << " len:" << batch->length; std::lock_guard lck(cached_probe_batches_mutex); if (cached_probe_batches_consumed) { return false; @@ -303,8 +262,8 @@ struct HashSemiJoinNode : ExecNode { // If all build side batches received? continue streaming using probing // else cache the batches in thread-local state void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - ARROW_LOG(WARNING) << "input received input:" << (IsBuildInput(input) ? "b" : "p") - << " seq:" << seq << " len:" << batch.length; + ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? "b" : "p") + << " seq:" << seq << " len:" << batch.length; ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); @@ -334,7 +293,7 @@ struct HashSemiJoinNode : ExecNode { } void ErrorReceived(ExecNode* input, Status error) override { - ARROW_LOG(WARNING) << "error received " << error.ToString(); + ARROW_LOG(DEBUG) << "error received " << error.ToString(); DCHECK_EQ(input, inputs_[0]); outputs_[0]->ErrorReceived(this, std::move(error)); @@ -342,8 +301,8 @@ struct HashSemiJoinNode : ExecNode { } void InputFinished(ExecNode* input, int num_total) override { - ARROW_LOG(WARNING) << "input finished input:" << (IsBuildInput(input) ? "b" : "p") - << " tot:" << num_total; + ARROW_LOG(DEBUG) << "input finished input:" << (IsBuildInput(input) ? "b" : "p") + << " tot:" << num_total; // bail if StopProducing was called if (finished_.is_finished()) return; @@ -374,12 +333,10 @@ struct HashSemiJoinNode : ExecNode { } Status StartProducing() override { - ARROW_LOG(WARNING) << "start prod"; + ARROW_LOG(DEBUG) << "start prod"; finished_ = Future<>::Make(); local_states_.resize(ThreadIndexer::Capacity()); - free_index_finder = - arrow::internal::make_unique(ThreadIndexer::Capacity()); return Status::OK(); } @@ -388,7 +345,7 @@ struct HashSemiJoinNode : ExecNode { void ResumeProducing(ExecNode* output) override {} void StopProducing(ExecNode* output) override { - ARROW_LOG(WARNING) << "stop prod from node"; + ARROW_LOG(DEBUG) << "stop prod from node"; DCHECK_EQ(output, outputs_[0]); @@ -405,12 +362,12 @@ struct HashSemiJoinNode : ExecNode { // TODO(niranda) couldn't there be multiple outputs for a Node? void StopProducing() override { - ARROW_LOG(WARNING) << "stop prod "; + ARROW_LOG(DEBUG) << "stop prod "; outputs_[0]->StopProducing(); } Future<> finished() override { - ARROW_LOG(WARNING) << "finished? " << finished_.is_finished(); + ARROW_LOG(DEBUG) << "finished? 
" << finished_.is_finished(); return finished_; } @@ -422,12 +379,10 @@ struct HashSemiJoinNode : ExecNode { ExecContext* ctx_; Future<> finished_ = Future<>::MakeFinished(); - // ThreadIndexer get_thread_index_; - std::unique_ptr free_index_finder; + ThreadIndexer get_thread_index_; const std::vector build_index_field_ids_, probe_index_field_ids_; AtomicCounter build_counter_, out_counter_; - std::vector local_states_; // There's no guarantee on which threads would be coming from the build side. so, out of @@ -457,7 +412,7 @@ Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_id ExecBatch batch) { if (group_ids_data.GetNullCount() == group_ids_data.length) { // All NULLS! hence, all values are valid for output - ARROW_LOG(WARNING) << "output seq:" << seq << " " << batch.length; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; outputs_[0]->InputReceived(this, seq, std::move(batch)); } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered // invert the validity buffer @@ -476,11 +431,11 @@ Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_id Filter(rec_batch, filter_arr, /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); auto out_batch = ExecBatch(*filtered.record_batch()); - ARROW_LOG(WARNING) << "output seq:" << seq << " " << out_batch.length; + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; outputs_[0]->InputReceived(this, seq, std::move(out_batch)); } else { // No NULLS! hence, there are no valid outputs! - ARROW_LOG(WARNING) << "output seq:" << seq << " 0"; + ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); } return Status::OK(); From 5a89c174b0b687c0c4fa8727faa4192cf605214b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 9 Aug 2021 13:15:20 -0400 Subject: [PATCH 22/27] ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry - An extensible registry of exec node factories (`std::function(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options)>`) is provided - Hard coded factories like `compute::MakeSinkNode`, `dataset::MakeScanNode` are replaced by factories in the registry named "sink", "scan", etc - `arrow::compute::Declaration` is provided to represent an unconstructed set of `ExecNode`s, which can be validated and emplaced into an `ExecPlan` as a unit Closes #10793 from bkietz/exec-node-factory-registry Authored-by: Benjamin Kietzman Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/exec/aggregate_node.cc | 6 +- cpp/src/arrow/compute/exec/exec_plan.h | 60 --- cpp/src/arrow/compute/exec/options.h | 3 + cpp/src/arrow/compute/exec/plan_test.cc | 25 + cpp/src/arrow/compute/exec/sink_node.cc | 82 +--- cpp/src/arrow/compute/exec/util.cc | 465 ++++++++++--------- 7 files changed, 269 insertions(+), 373 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 0145ec83472..6293f1bcc05 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -372,6 +372,7 @@ if(ARROW_COMPUTE) compute/exec.cc compute/exec/aggregate_node.cc compute/exec/exec_plan.cc + compute/exec/exec_utils.cc compute/exec/expression.cc compute/exec/filter_node.cc compute/exec/project_node.cc diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 131fd44ba87..0b416e9f47f 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ 
b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/exec/exec_plan.h" - #include #include #include #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" #include "arrow/compute/exec_internal.h" @@ -451,8 +450,7 @@ struct GroupByNode : ExecNode { // bail if StopProducing was called if (finished_.is_finished()) break; - auto plan = this->plan()->shared_from_this(); - RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); })); + RETURN_NOT_OK(executor->Spawn([this, i] { OutputNthBatch(i); })); } else { OutputNthBatch(i); } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 540c0cf3482..4a784ceb75b 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -353,65 +353,5 @@ std::shared_ptr MakeGeneratorReader( std::shared_ptr, std::function>()>, MemoryPool*); -/// \brief Make a node which excludes some rows from batches passed through it -/// -/// The filter Expression will be evaluated against each batch which is pushed to -/// this node. Any rows for which the filter does not evaluate to `true` will be excluded -/// in the batch emitted by this node. -/// -/// If the filter is not already bound, it will be bound against the input's schema. -ARROW_EXPORT -Result MakeFilterNode(ExecNode* input, std::string label, Expression filter); - -/// \brief Make a node which executes expressions on input batches, producing new batches. -/// -/// Each expression will be evaluated against each batch which is pushed to -/// this node to produce a corresponding output column. -/// -/// If exprs are not already bound, they will be bound against the input's schema. -/// If names are not provided, the string representations of exprs will be used. 
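[Editorial aside, not part of the patch: the hard-coded helper declarations being
removed from exec_plan.h in this hunk are superseded by the factory registry
described in the commit message above. A minimal sketch of the replacement
pattern, assuming the "source"/"filter"/"sink" factories and option classes
registered elsewhere in this series; `schema_` and `batch_gen_` are hypothetical
placeholders:

    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
    ASSERT_OK(Declaration::Sequence(
                  {
                      {"source", SourceNodeOptions{schema_, batch_gen_}},
                      {"filter", FilterNodeOptions{greater(field_ref("x"), literal(3))}},
                      {"sink", SinkNodeOptions{&sink_gen}},
                  })
                  .AddToPlan(plan.get()));

Because the whole Declaration sequence is validated and emplaced as a unit, a
misspelled factory name or a mismatched input count surfaces as a Status at
AddToPlan time instead of a compile-time dependency on per-node Make functions.]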
-ARROW_EXPORT -Result MakeProjectNode(ExecNode* input, std::string label, - std::vector exprs, - std::vector names = {}); - -ARROW_EXPORT -Result MakeScalarAggregateNode(ExecNode* input, std::string label, - std::vector aggregates, - std::vector arguments, - std::vector out_field_names); - -/// \brief Make a node which groups input rows based on key fields and computes -/// aggregates for each group -ARROW_EXPORT -Result MakeGroupByNode(ExecNode* input, std::string label, - std::vector keys, - std::vector agg_srcs, - std::vector aggs); - -ARROW_EXPORT -Result GroupByUsingExecPlan(const std::vector& arguments, - const std::vector& keys, - const std::vector& aggregates, - bool use_threads, ExecContext* ctx); - -/// \brief Make a node which joins batches from two other nodes based on key fields -enum JoinType { - LEFT_SEMI, - RIGHT_SEMI, - LEFT_ANTI, - RIGHT_ANTI, - INNER, // Not Implemented - LEFT_OUTER, // Not Implemented - RIGHT_OUTER, // Not Implemented - FULL_OUTER // Not Implemented -}; - -ARROW_EXPORT -Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, - ExecNode* right_input, std::string label, - const std::vector& left_keys, - const std::vector& right_keys); - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index acc79bdfdde..a54ab961814 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -112,6 +112,7 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { std::function>()>* generator; }; +<<<<<<< HEAD /// \brief Make a node which sorts rows passed through it /// /// All batches pushed to this node will be accumulated, then sorted, by the given @@ -126,5 +127,7 @@ class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { SortOptions sort_options; }; +======= +>>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index d37fd2fca34..5f50218677e 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -685,6 +685,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { BatchesWithSchema scalar_data; scalar_data.batches = { +<<<<<<< HEAD ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(boolean())}, "[[5, false], [5, false], [5, false]]"), ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [6, false], [7, true]]")}; @@ -715,6 +716,30 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); +======= + ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(int32()), + ValueDescr::Scalar(int32())}, + "[[5, 5, 5], [5, 5, 5], [5, 5, 5]]"), + ExecBatchFromJSON({int32(), int32(), int32()}, + "[[5, 5, 5], [6, 6, 6], [7, 7, 7]]")}; + scalar_data.schema = + schema({field("a", int32()), field("b", int32()), field("c", int32())}); + + ASSERT_OK(Declaration::Sequence( + { + {"source", SourceNodeOptions{scalar_data.schema, + scalar_data.gen(/*parallel=*/false, + /*slow=*/false)}}, + {"aggregate", + AggregateNodeOptions{ + /*aggregates=*/{ + {"count", nullptr}, {"sum", nullptr}, {"mean", nullptr}}, + /*targets=*/{"a", "b", "c"}, + /*names=*/{"count(a)", "sum(b)", "mean(c)"}}}, + {"sink", SinkNodeOptions{&sink_gen}}, + }) + .AddToPlan(plan.get())); +>>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded 
ExecNode factories to a registry ASSERT_THAT( StartAndCollect(plan.get(), sink_gen), diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 4d9f82e582b..a8f21dbda1a 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -16,12 +16,11 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/exec/exec_plan.h" - #include #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" @@ -137,8 +136,8 @@ class SinkNode : public ExecNode { } } - protected: - virtual void Finish() { + private: + void Finish() { if (producer_.Close()) { finished_.MarkFinished(); } @@ -150,82 +149,7 @@ class SinkNode : public ExecNode { PushGenerator>::Producer producer_; }; -// A sink node that accumulates inputs, then sorts them before emitting them. -struct OrderBySinkNode final : public SinkNode { - OrderBySinkNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options, - AsyncGenerator>* generator) - : SinkNode(plan, std::move(inputs), generator), - sort_options_(std::move(sort_options)) {} - - const char* kind_name() override { return "OrderBySinkNode"; } - - static Result Make(ExecPlan* plan, std::vector inputs, - const ExecNodeOptions& options) { - RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode")); - - const auto& sink_options = checked_cast(options); - return plan->EmplaceNode( - plan, std::move(inputs), sink_options.sort_options, sink_options.generator); - } - - void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { - DCHECK_EQ(input, inputs_[0]); - - // Accumulate data - { - std::unique_lock lock(mutex_); - auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), - plan()->exec_context()->memory_pool()); - if (ErrorIfNotOk(maybe_batch.status())) return; - batches_.push_back(maybe_batch.MoveValueUnsafe()); - } - - if (input_counter_.Increment()) { - Finish(); - } - } - - protected: - Status DoFinish() { - Datum sorted; - { - std::unique_lock lock(mutex_); - ARROW_ASSIGN_OR_RAISE( - auto table, - Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_))); - ARROW_ASSIGN_OR_RAISE(auto indices, - SortIndices(table, sort_options_, plan()->exec_context())); - ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(), - plan()->exec_context())); - } - TableBatchReader reader(*sorted.table()); - while (true) { - std::shared_ptr batch; - RETURN_NOT_OK(reader.ReadNext(&batch)); - if (!batch) break; - bool did_push = producer_.Push(ExecBatch(*batch)); - if (!did_push) break; // producer_ was Closed already - } - return Status::OK(); - } - - void Finish() override { - Status st = DoFinish(); - if (ErrorIfNotOk(st)) { - producer_.Push(std::move(st)); - } - SinkNode::Finish(); - } - - private: - SortOptions sort_options_; - std::mutex mutex_; - std::vector> batches_; -}; - ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make); -ExecFactoryRegistry::AddOnLoad kRegisterOrderBySink("order_by_sink", - OrderBySinkNode::Make); } // namespace } // namespace compute diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index aad6dc3d587..49bafc295fe 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -18,294 +18,299 @@ #include 
"arrow/compute/exec/util.h" #include "arrow/compute/exec/exec_plan.h" +<<<<<<< HEAD #include "arrow/table.h" -#include "arrow/util/bit_util.h" +======= +>>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry + #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/ubsan.h" -namespace arrow { + namespace arrow { -using BitUtil::CountTrailingZeros; + using BitUtil::CountTrailingZeros; -namespace util { + namespace util { -inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index, - int* num_indexes, uint16_t* indexes) { - int n = *num_indexes; - while (word) { - indexes[n++] = base_index + static_cast(CountTrailingZeros(word)); - word &= word - 1; + inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index, + int* num_indexes, uint16_t* indexes) { + int n = *num_indexes; + while (word) { + indexes[n++] = base_index + static_cast(CountTrailingZeros(word)); + word &= word - 1; + } + *num_indexes = n; } - *num_indexes = n; -} -inline void BitUtil::bits_filter_indexes_helper(uint64_t word, - const uint16_t* input_indexes, - int* num_indexes, uint16_t* indexes) { - int n = *num_indexes; - while (word) { - indexes[n++] = input_indexes[CountTrailingZeros(word)]; - word &= word - 1; + inline void BitUtil::bits_filter_indexes_helper(uint64_t word, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes) { + int n = *num_indexes; + while (word) { + indexes[n++] = input_indexes[CountTrailingZeros(word)]; + word &= word - 1; + } + *num_indexes = n; } - *num_indexes = n; -} -template -void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, - const uint8_t* bits, const uint16_t* input_indexes, - int* num_indexes, uint16_t* indexes) { - // 64 bits at a time - constexpr int unroll = 64; - int tail = num_bits % unroll; + template + void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, + const uint16_t* input_indexes, int* num_indexes, + uint16_t* indexes) { + // 64 bits at a time + constexpr int unroll = 64; + int tail = num_bits % unroll; #if defined(ARROW_HAVE_AVX2) - if (hardware_flags & arrow::internal::CpuInfo::AVX2) { - if (filter_input_indexes) { - bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes, - num_indexes, indexes); + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + if (filter_input_indexes) { + bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes, + num_indexes, indexes); + } else { + bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes); + } } else { - bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes); +#endif + *num_indexes = 0; + for (int i = 0; i < num_bits / unroll; ++i) { + uint64_t word = util::SafeLoad(&reinterpret_cast(bits)[i]); + if (bit_to_search == 0) { + word = ~word; + } + if (filter_input_indexes) { + bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes); + } else { + bits_to_indexes_helper(word, i * 64, num_indexes, indexes); + } + } +#if defined(ARROW_HAVE_AVX2) } - } else { #endif - *num_indexes = 0; - for (int i = 0; i < num_bits / unroll; ++i) { - uint64_t word = util::SafeLoad(&reinterpret_cast(bits)[i]); + // Optionally process the last partial word with masking out bits outside range + if (tail) { + uint64_t word = + util::SafeLoad(&reinterpret_cast(bits)[num_bits / unroll]); if (bit_to_search == 0) { word = ~word; } 
+ word &= ~0ULL >> (64 - tail); if (filter_input_indexes) { - bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes); + bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes, + indexes); } else { - bits_to_indexes_helper(word, i * 64, num_indexes, indexes); + bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes); } } -#if defined(ARROW_HAVE_AVX2) } -#endif - // Optionally process the last partial word with masking out bits outside range - if (tail) { - uint64_t word = - util::SafeLoad(&reinterpret_cast(bits)[num_bits / unroll]); - if (bit_to_search == 0) { - word = ~word; + + void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, int* num_indexes, + uint16_t* indexes, int bit_offset) { + bits += bit_offset / 8; + bit_offset %= 8; + if (bit_offset != 0) { + int num_indexes_head = 0; + uint64_t bits_head = + util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); + bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte, + reinterpret_cast(&bits_head), &num_indexes_head, + indexes); + int num_indexes_tail = 0; + if (num_bits > bits_in_first_byte) { + bits_to_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte, + bits + 1, &num_indexes_tail, indexes + num_indexes_head); + } + *num_indexes = num_indexes_head + num_indexes_tail; + return; } - word &= ~0ULL >> (64 - tail); - if (filter_input_indexes) { - bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes, - indexes); + + if (bit_to_search == 0) { + bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr, + num_indexes, indexes); } else { - bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes); + ARROW_DCHECK(bit_to_search == 1); + bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr, + num_indexes, indexes); } } -} -void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, - const int num_bits, const uint8_t* bits, int* num_indexes, - uint16_t* indexes, int bit_offset) { - bits += bit_offset / 8; - bit_offset %= 8; - if (bit_offset != 0) { - int num_indexes_head = 0; - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; - int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); - bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte, - reinterpret_cast(&bits_head), &num_indexes_head, - indexes); - int num_indexes_tail = 0; - if (num_bits > bits_in_first_byte) { - bits_to_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte, - bits + 1, &num_indexes_tail, indexes + num_indexes_head); + void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, int* num_indexes, + uint16_t* indexes, int bit_offset) { + bits += bit_offset / 8; + bit_offset %= 8; + if (bit_offset != 0) { + int num_indexes_head = 0; + uint64_t bits_head = + util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); + bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte, + reinterpret_cast(&bits_head), input_indexes, + &num_indexes_head, indexes); + int num_indexes_tail = 0; + if (num_bits > bits_in_first_byte) { + bits_filter_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte, + bits + 1, input_indexes + bits_in_first_byte, + 
&num_indexes_tail, indexes + num_indexes_head); + } + *num_indexes = num_indexes_head + num_indexes_tail; + return; } - *num_indexes = num_indexes_head + num_indexes_tail; - return; - } - if (bit_to_search == 0) { - bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr, - num_indexes, indexes); - } else { - ARROW_DCHECK(bit_to_search == 1); - bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr, - num_indexes, indexes); - } -} - -void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags, - const int num_bits, const uint8_t* bits, - const uint16_t* input_indexes, int* num_indexes, - uint16_t* indexes, int bit_offset) { - bits += bit_offset / 8; - bit_offset %= 8; - if (bit_offset != 0) { - int num_indexes_head = 0; - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; - int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); - bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte, - reinterpret_cast(&bits_head), input_indexes, - &num_indexes_head, indexes); - int num_indexes_tail = 0; - if (num_bits > bits_in_first_byte) { - bits_filter_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte, - bits + 1, input_indexes + bits_in_first_byte, &num_indexes_tail, - indexes + num_indexes_head); + if (bit_to_search == 0) { + bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes, + num_indexes, indexes); + } else { + ARROW_DCHECK(bit_to_search == 1); + bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes, + num_indexes, indexes); } - *num_indexes = num_indexes_head + num_indexes_tail; - return; } - if (bit_to_search == 0) { - bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes, - num_indexes, indexes); - } else { - ARROW_DCHECK(bit_to_search == 1); - bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes, - num_indexes, indexes); + void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, int* num_indexes_bit0, + uint16_t* indexes_bit0, uint16_t* indexes_bit1, + int bit_offset) { + bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0, + bit_offset); + int num_indexes_bit1; + bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1, + bit_offset); } -} - -void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits, - const uint8_t* bits, int* num_indexes_bit0, - uint16_t* indexes_bit0, uint16_t* indexes_bit1, - int bit_offset) { - bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0, - bit_offset); - int num_indexes_bit1; - bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1, - bit_offset); -} -void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits, - const uint8_t* bits, uint8_t* bytes, int bit_offset) { - bits += bit_offset / 8; - bit_offset %= 8; - if (bit_offset != 0) { - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; - int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); - bits_to_bytes(hardware_flags, bits_in_first_byte, - reinterpret_cast(&bits_head), bytes); - if (num_bits > bits_in_first_byte) { - bits_to_bytes(hardware_flags, num_bits - bits_in_first_byte, bits + 1, - bytes + bits_in_first_byte); + void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, uint8_t* bytes, int bit_offset) { + bits += bit_offset 
/ 8; + bit_offset %= 8; + if (bit_offset != 0) { + uint64_t bits_head = + util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); + bits_to_bytes(hardware_flags, bits_in_first_byte, + reinterpret_cast(&bits_head), bytes); + if (num_bits > bits_in_first_byte) { + bits_to_bytes(hardware_flags, num_bits - bits_in_first_byte, bits + 1, + bytes + bits_in_first_byte); + } + return; } - return; - } - int num_processed = 0; + int num_processed = 0; #if defined(ARROW_HAVE_AVX2) - if (hardware_flags & arrow::internal::CpuInfo::AVX2) { - // The function call below processes whole 32 bit chunks together. - num_processed = num_bits - (num_bits % 32); - bits_to_bytes_avx2(num_processed, bits, bytes); - } + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + // The function call below processes whole 32 bit chunks together. + num_processed = num_bits - (num_bits % 32); + bits_to_bytes_avx2(num_processed, bits, bytes); + } #endif - // Processing 8 bits at a time - constexpr int unroll = 8; - for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { - uint8_t bits_next = bits[i]; - // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart - // from the previous. - uint64_t unpacked = static_cast(bits_next & 0xfe) * - ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | - (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); - unpacked |= (bits_next & 1); - unpacked &= 0x0101010101010101ULL; - unpacked *= 255; - util::SafeStore(&reinterpret_cast(bytes)[i], unpacked); + // Processing 8 bits at a time + constexpr int unroll = 8; + for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + uint8_t bits_next = bits[i]; + // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits + // apart from the previous. 
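  // [Editorial aside, tracing the multiply trick below for bits_next = 0b00000011:
  //   (bits_next & 0xfe) == 0b00000010, and multiplying by
  //   (1<<7 | 1<<14 | ... | 1<<49) places bit k of the input at bit 8*k of the
  //   product for k = 1..7, carry-free, since no two partial products collide.
  //   OR-ing bit 0 back in and masking with 0x0101010101010101 leaves at most
  //   one bit per byte; multiplying by 255 widens each 0/1 byte to 0x00/0xFF.
  //   Here the result is 0xFFFF, i.e. bytes [0xFF, 0xFF, 0, 0, 0, 0, 0, 0]:
  //   one output byte per input bit.]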
+ uint64_t unpacked = static_cast(bits_next & 0xfe) * + ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | + (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); + unpacked |= (bits_next & 1); + unpacked &= 0x0101010101010101ULL; + unpacked *= 255; + util::SafeStore(&reinterpret_cast(bytes)[i], unpacked); + } } -} -void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits, - const uint8_t* bytes, uint8_t* bits, int bit_offset) { - bits += bit_offset / 8; - bit_offset %= 8; - if (bit_offset != 0) { - uint64_t bits_head; - int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); - bytes_to_bits(hardware_flags, bits_in_first_byte, bytes, - reinterpret_cast(&bits_head)); - uint8_t mask = (1 << bit_offset) - 1; - *bits = static_cast((*bits & mask) | (bits_head << bit_offset)); + void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits, + const uint8_t* bytes, uint8_t* bits, int bit_offset) { + bits += bit_offset / 8; + bit_offset %= 8; + if (bit_offset != 0) { + uint64_t bits_head; + int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); + bytes_to_bits(hardware_flags, bits_in_first_byte, bytes, + reinterpret_cast(&bits_head)); + uint8_t mask = (1 << bit_offset) - 1; + *bits = static_cast((*bits & mask) | (bits_head << bit_offset)); - if (num_bits > bits_in_first_byte) { - bytes_to_bits(hardware_flags, num_bits - bits_in_first_byte, - bytes + bits_in_first_byte, bits + 1); + if (num_bits > bits_in_first_byte) { + bytes_to_bits(hardware_flags, num_bits - bits_in_first_byte, + bytes + bits_in_first_byte, bits + 1); + } + return; } - return; - } - int num_processed = 0; + int num_processed = 0; #if defined(ARROW_HAVE_AVX2) - if (hardware_flags & arrow::internal::CpuInfo::AVX2) { - // The function call below processes whole 32 bit chunks together. - num_processed = num_bits - (num_bits % 32); - bytes_to_bits_avx2(num_processed, bytes, bits); - } + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + // The function call below processes whole 32 bit chunks together. 
+ num_processed = num_bits - (num_bits % 32); + bytes_to_bits_avx2(num_processed, bytes, bits); + } #endif - // Process 8 bits at a time - constexpr int unroll = 8; - for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { - uint64_t bytes_next = util::SafeLoad(&reinterpret_cast(bytes)[i]); - bytes_next &= 0x0101010101010101ULL; - bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes - bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes - bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte - bits[i] = static_cast(bytes_next & 0xff); + // Process 8 bits at a time + constexpr int unroll = 8; + for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + uint64_t bytes_next = util::SafeLoad(&reinterpret_cast(bytes)[i]); + bytes_next &= 0x0101010101010101ULL; + bytes_next |= + (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte + bits[i] = static_cast(bytes_next & 0xff); + } } -} -bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, - uint32_t num_bytes) { + bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, + uint32_t num_bytes) { #if defined(ARROW_HAVE_AVX2) - if (hardware_flags & arrow::internal::CpuInfo::AVX2) { - return are_all_bytes_zero_avx2(bytes, num_bytes); - } + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + return are_all_bytes_zero_avx2(bytes, num_bytes); + } #endif - uint64_t result_or = 0; - uint32_t i; - for (i = 0; i < num_bytes / 8; ++i) { - uint64_t x = util::SafeLoad(&reinterpret_cast(bytes)[i]); - result_or |= x; - } - if (num_bytes % 8 > 0) { - uint64_t tail = 0; - result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8); + uint64_t result_or = 0; + uint32_t i; + for (i = 0; i < num_bytes / 8; ++i) { + uint64_t x = util::SafeLoad(&reinterpret_cast(bytes)[i]); + result_or |= x; + } + if (num_bytes % 8 > 0) { + uint64_t tail = 0; + result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8); + } + return result_or == 0; } - return result_or == 0; -} -} // namespace util + } // namespace util -namespace compute { + namespace compute { -Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inputs, - int expected_num_inputs, const char* kind_name) { - if (static_cast(inputs.size()) != expected_num_inputs) { - return Status::Invalid(kind_name, " node requires ", expected_num_inputs, - " inputs but got ", inputs.size()); - } + Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inputs, + int expected_num_inputs, const char* kind_name) { + if (static_cast(inputs.size()) != expected_num_inputs) { + return Status::Invalid(kind_name, " node requires ", expected_num_inputs, + " inputs but got ", inputs.size()); + } - for (auto input : inputs) { - if (input->plan() != plan) { - return Status::Invalid("Constructing a ", kind_name, - " node in a different plan from its input"); + for (auto input : inputs) { + if (input->plan() != plan) { + return Status::Invalid("Constructing a ", kind_name, + " node in a different plan from its input"); + } } - } - return Status::OK(); -} + return Status::OK(); + } -Result> TableFromExecBatches( - const std::shared_ptr& schema, const std::vector& exec_batches) { - RecordBatchVector batches; - for (const auto& batch : exec_batches) { - ARROW_ASSIGN_OR_RAISE(auto rb, 
batch.ToRecordBatch(schema)); - batches.push_back(std::move(rb)); + Result> TableFromExecBatches( + const std::shared_ptr& schema, const std::vector& exec_batches) { + RecordBatchVector batches; + for (const auto& batch : exec_batches) { + ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema)); + batches.push_back(std::move(rb)); + } + return Table::FromRecordBatches(schema, batches); } - return Table::FromRecordBatches(schema, batches); -} -} // namespace compute + } // namespace compute } // namespace arrow From 0e8795a9728ec2b4005cd710564e2680721059a9 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 10 Aug 2021 15:33:53 -0400 Subject: [PATCH 23/27] refactoring to new API --- cpp/src/arrow/compute/exec/aggregate_node.cc | 28 --- cpp/src/arrow/compute/exec/exec_utils.cc | 78 --------- cpp/src/arrow/compute/exec/exec_utils.h | 68 ------- cpp/src/arrow/compute/exec/hash_join.cc | 175 ++++++++++--------- cpp/src/arrow/compute/exec/options.h | 30 +++- cpp/src/arrow/compute/exec/plan_test.cc | 25 --- cpp/src/arrow/compute/exec/test_util.cc | 34 ---- cpp/src/arrow/compute/exec/test_util.h | 40 ++++- cpp/src/arrow/compute/exec/util.cc | 22 +++ cpp/src/arrow/compute/exec/util.h | 16 ++ 10 files changed, 195 insertions(+), 321 deletions(-) delete mode 100644 cpp/src/arrow/compute/exec/exec_utils.cc delete mode 100644 cpp/src/arrow/compute/exec/exec_utils.h diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 0b416e9f47f..d371a6fe548 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -58,34 +58,6 @@ Result ResolveKernels( namespace { -class ThreadIndexer { - public: - size_t operator()() { - auto id = std::this_thread::get_id(); - - std::unique_lock lock(mutex_); - const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; - - return Check(id_index.second); - } - - static size_t Capacity() { - static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); - return max_size; - } - - private: - size_t Check(size_t thread_index) { - DCHECK_LT(thread_index, Capacity()) << "thread index " << thread_index - << " is out of range [0, " << Capacity() << ")"; - - return thread_index; - } - - std::mutex mutex_; - std::unordered_map id_to_index_; -}; - struct ScalarAggregateNode : ExecNode { ScalarAggregateNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema, diff --git a/cpp/src/arrow/compute/exec/exec_utils.cc b/cpp/src/arrow/compute/exec/exec_utils.cc deleted file mode 100644 index 7026351e0b7..00000000000 --- a/cpp/src/arrow/compute/exec/exec_utils.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "arrow/compute/exec/exec_utils.h" - -#include "arrow/util/logging.h" - -namespace arrow { -namespace compute { - -size_t ThreadIndexer::operator()() { - auto id = std::this_thread::get_id(); - - auto guard = mutex_.Lock(); // acquire the lock - const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; - - return Check(id_index.second); -} - -size_t ThreadIndexer::Capacity() { - static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); - return max_size; -} - -size_t ThreadIndexer::Check(size_t thread_index) { - DCHECK_LT(thread_index, Capacity()) - << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")"; - - return thread_index; -} - -int AtomicCounter::count() const { return count_.load(); } - -util::optional AtomicCounter::total() const { - int total = total_.load(); - if (total == -1) return {}; - return total; -} - -bool AtomicCounter::Increment() { - DCHECK_NE(count_.load(), total_.load()); - int count = count_.fetch_add(1) + 1; - if (count != total_.load()) return false; - return DoneOnce(); -} - -// return true if the counter is complete -bool AtomicCounter::SetTotal(int total) { - total_.store(total); - if (count_.load() != total) return false; - return DoneOnce(); -} - -// return true if the counter has not already been completed -bool AtomicCounter::Cancel() { return DoneOnce(); } - -// ensure there is only one true return from Increment(), SetTotal(), or Cancel() -bool AtomicCounter::DoneOnce() { - bool expected = false; - return complete_.compare_exchange_strong(expected, true); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_utils.h b/cpp/src/arrow/compute/exec/exec_utils.h deleted file mode 100644 index 93cd0775098..00000000000 --- a/cpp/src/arrow/compute/exec/exec_utils.h +++ /dev/null @@ -1,68 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#pragma once - -#include -#include - -#include "arrow/util/mutex.h" -#include "arrow/util/thread_pool.h" - -namespace arrow { -namespace compute { - -class ThreadIndexer { - public: - size_t operator()(); - - static size_t Capacity(); - - private: - static size_t Check(size_t thread_index); - - util::Mutex mutex_; - std::unordered_map id_to_index_; -}; - -class AtomicCounter { - public: - AtomicCounter() = default; - - int count() const; - - util::optional total() const; - - // return true if the counter is complete - bool Increment(); - - // return true if the counter is complete - bool SetTotal(int total); - - // return true if the counter has not already been completed - bool Cancel(); - - private: - // ensure there is only one true return from Increment(), SetTotal(), or Cancel() - bool DoneOnce(); - - std::atomic count_{0}, total_{-1}; - std::atomic complete_{false}; -}; - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 95db8f254cc..8bffe826dad 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -20,19 +20,60 @@ #include "arrow/api.h" #include "arrow/compute/api.h" #include "arrow/compute/exec/exec_plan.h" -#include "arrow/compute/exec/exec_utils.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" #include "arrow/util/logging.h" +#include "arrow/util/thread_pool.h" namespace arrow { + +using internal::checked_cast; + namespace compute { +namespace { +Status ValidateJoinInputs(const std::shared_ptr& left_schema, + const std::shared_ptr& right_schema, + const std::vector& left_keys, + const std::vector& right_keys) { + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("left and right key sizes do not match"); + } + + for (size_t i = 0; i < left_keys.size(); i++) { + auto l_type = left_schema->field(left_keys[i])->type(); + auto r_type = right_schema->field(right_keys[i])->type(); + + if (!l_type->Equals(r_type)) { + return Status::Invalid("build and probe types do not match: " + l_type->ToString() + + "!=" + r_type->ToString()); + } + } + + return Status::OK(); +} + +Result> PopulateKeys(const Schema& schema, + const std::vector& keys) { + std::vector key_field_ids(keys.size()); + // Find input field indices for left key fields + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema)); + key_field_ids[i] = match[0]; + } + return key_field_ids; +} +} // namespace + template struct HashSemiJoinNode : ExecNode { - HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, std::string label, - ExecContext* ctx, const std::vector&& build_index_field_ids, + HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext* ctx, + const std::vector&& build_index_field_ids, const std::vector&& probe_index_field_ids) - : ExecNode(build_input->plan(), std::move(label), {build_input, probe_input}, + : ExecNode(build_input->plan(), {build_input, probe_input}, {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), /*num_outputs=*/1), ctx_(ctx), @@ -441,97 +482,71 @@ Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_id return Status::OK(); } -Status ValidateJoinInputs(ExecNode* left_input, ExecNode* right_input, - const std::vector& left_keys, - const std::vector& right_keys) { - if (left_keys.size() != 
right_keys.size()) { - return Status::Invalid("left and right key sizes do not match"); - } - - const auto& l_schema = left_input->output_schema(); - const auto& r_schema = right_input->output_schema(); - for (size_t i = 0; i < left_keys.size(); i++) { - auto l_type = l_schema->GetFieldByName(left_keys[i])->type(); - auto r_type = r_schema->GetFieldByName(right_keys[i])->type(); - - if (!l_type->Equals(r_type)) { - return Status::Invalid("build and probe types do not match: " + l_type->ToString() + - "!=" + r_type->ToString()); - } - } - - return Status::OK(); -} - -Result> PopulateKeys(const Schema& schema, - const std::vector& keys) { - std::vector key_field_ids(keys.size()); - // Find input field indices for left key fields - for (size_t i = 0; i < keys.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(keys[i]).FindOne(schema)); - key_field_ids[i] = match[0]; - } - - return key_field_ids; -} - template Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, - std::string label, - const std::vector& build_keys, - const std::vector& probe_keys) { - RETURN_NOT_OK(ValidateJoinInputs(build_input, probe_input, build_keys, probe_keys)); - + const std::vector& build_keys, + const std::vector& probe_keys) { auto build_schema = build_input->output_schema(); auto probe_schema = probe_input->output_schema(); ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); + RETURN_NOT_OK( + ValidateJoinInputs(build_schema, probe_schema, build_key_ids, probe_key_ids)); + // output schema will be probe schema auto ctx = build_input->plan()->exec_context(); ExecPlan* plan = build_input->plan(); return plan->EmplaceNode>( - build_input, probe_input, std::move(label), ctx, std::move(build_key_ids), - std::move(probe_key_ids)); + build_input, probe_input, ctx, std::move(build_key_ids), std::move(probe_key_ids)); } -Result MakeHashJoinNode(JoinType join_type, ExecNode* left_input, - ExecNode* right_input, std::string label, - const std::vector& left_keys, - const std::vector& right_keys) { - static std::string join_type_string[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", - "RIGHT_ANTI", "INNER", "LEFT_OUTER", - "RIGHT_OUTER", "FULL_OUTER"}; - - switch (join_type) { - case LEFT_SEMI: - // left join--> build from right and probe from left - return MakeHashSemiJoinNode(right_input, left_input, std::move(label), right_keys, - left_keys); - case RIGHT_SEMI: - // right join--> build from left and probe from right - return MakeHashSemiJoinNode(left_input, right_input, std::move(label), left_keys, - right_keys); - case LEFT_ANTI: - // left join--> build from right and probe from left - return MakeHashSemiJoinNode(right_input, left_input, std::move(label), - right_keys, left_keys); - case RIGHT_ANTI: - // right join--> build from left and probe from right - return MakeHashSemiJoinNode(left_input, right_input, std::move(label), - left_keys, right_keys); - case INNER: - case LEFT_OUTER: - case RIGHT_OUTER: - case FULL_OUTER: - return Status::NotImplemented(join_type_string[join_type] + - " joins not implemented!"); - default: - return Status::Invalid("invalid join type"); - } -} +ExecFactoryRegistry::AddOnLoad kRegisterHashJoin( + "hash_join", + [](ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) -> Result { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode")); + + const auto& join_options = checked_cast(options); + + static std::string 
join_type_string[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + + auto join_type = join_options.join_type; + + ExecNode* left_input = inputs[0]; + ExecNode* right_input = inputs[1]; + const auto& left_keys = join_options.left_keys; + const auto& right_keys = join_options.right_keys; + + switch (join_type) { + case LEFT_SEMI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, right_keys, left_keys); + case RIGHT_SEMI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, right_keys); + case LEFT_ANTI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, right_keys, + left_keys); + case RIGHT_ANTI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, + right_keys); + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + return Status::NotImplemented(join_type_string[join_type] + + " joins not implemented!"); + default: + return Status::Invalid("invalid join type"); + } + }); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a54ab961814..b103120b4d2 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -112,7 +112,6 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { std::function>()>* generator; }; -<<<<<<< HEAD /// \brief Make a node which sorts rows passed through it /// /// All batches pushed to this node will be accumulated, then sorted, by the given @@ -127,7 +126,32 @@ class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { SortOptions sort_options; }; -======= ->>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry +enum JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, // Not Implemented + LEFT_OUTER, // Not Implemented + RIGHT_OUTER, // Not Implemented + FULL_OUTER // Not Implemented +}; + +class ARROW_EXPORT JoinNodeOptions : public ExecNodeOptions { + public: + JoinNodeOptions(JoinType join_type, std::vector left_keys, + std::vector right_keys) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)) {} + + // type of the join + JoinType join_type; + + // index keys of the join + std::vector left_keys; + std::vector right_keys; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 5f50218677e..d37fd2fca34 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -685,7 +685,6 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { BatchesWithSchema scalar_data; scalar_data.batches = { -<<<<<<< HEAD ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(boolean())}, "[[5, false], [5, false], [5, false]]"), ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [6, false], [7, true]]")}; @@ -716,30 +715,6 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); -======= - ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(int32()), - ValueDescr::Scalar(int32())}, - "[[5, 5, 5], [5, 5, 5], [5, 5, 5]]"), - ExecBatchFromJSON({int32(), int32(), int32()}, - 
"[[5, 5, 5], [6, 6, 6], [7, 7, 7]]")}; - scalar_data.schema = - schema({field("a", int32()), field("b", int32()), field("c", int32())}); - - ASSERT_OK(Declaration::Sequence( - { - {"source", SourceNodeOptions{scalar_data.schema, - scalar_data.gen(/*parallel=*/false, - /*slow=*/false)}}, - {"aggregate", - AggregateNodeOptions{ - /*aggregates=*/{ - {"count", nullptr}, {"sum", nullptr}, {"mean", nullptr}}, - /*targets=*/{"a", "b", "c"}, - /*names=*/{"count(a)", "sum(b)", "mean(c)"}}}, - {"sink", SinkNodeOptions{&sink_gen}}, - }) - .AddToPlan(plan.get())); ->>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry ASSERT_THAT( StartAndCollect(plan.get(), sink_gen), diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index d1d06e773f0..91c993c101e 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -156,40 +156,6 @@ ExecBatch ExecBatchFromJSON(const std::vector& descrs, return batch; } -Result MakeTestSourceNode(ExecPlan* plan, std::string label, - BatchesWithSchema batches_with_schema, bool parallel, - bool slow) { - DCHECK_GT(batches_with_schema.batches.size(), 0); - - auto opt_batches = ::arrow::internal::MapVector( - [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, - std::move(batches_with_schema.batches)); - - AsyncGenerator> gen; - - if (parallel) { - // emulate batches completing initial decode-after-scan on a cpu thread - ARROW_ASSIGN_OR_RAISE( - gen, MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), - ::arrow::internal::GetCpuThreadPool())); - - // ensure that callbacks are not executed immediately on a background thread - gen = MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool()); - } else { - gen = MakeVectorGenerator(std::move(opt_batches)); - } - - if (slow) { - gen = MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { - SleepABit(); - return batch; - }); - } - - return MakeSourceNode(plan, std::move(label), std::move(batches_with_schema.schema), - std::move(gen)); -} - Future> StartAndCollect( ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 3ef1333ea42..55c971954ea 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -17,6 +17,9 @@ #pragma once +#include +#include + #include #include #include @@ -45,12 +48,39 @@ ExecBatch ExecBatchFromJSON(const std::vector& descrs, struct BatchesWithSchema { std::vector batches; std::shared_ptr schema; -}; -ARROW_TESTING_EXPORT -Result MakeTestSourceNode(ExecPlan* plan, std::string label, - BatchesWithSchema batches_with_schema, bool parallel, - bool slow); + AsyncGenerator> gen(bool parallel, bool slow) const { + DCHECK_GT(batches.size(), 0); + + auto opt_batches = ::arrow::internal::MapVector( + [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches); + + AsyncGenerator> gen; + + if (parallel) { + // emulate batches completing initial decode-after-scan on a cpu thread + gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), + ::arrow::internal::GetCpuThreadPool()) + .ValueOrDie(); + + // ensure that callbacks are not executed immediately on a background thread + gen = + MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool()); + } else { + gen = MakeVectorGenerator(std::move(opt_batches)); + } + 
+ if (slow) { + gen = + MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { + SleepABit(); + return batch; + }); + } + + return gen; + } +}; ARROW_TESTING_EXPORT Future> StartAndCollect( diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index 49bafc295fe..8eac7ddb6b7 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -24,6 +24,7 @@ >>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/ubsan.h" namespace arrow { @@ -312,5 +313,26 @@ return Table::FromRecordBatches(schema, batches); } + size_t ThreadIndexer::operator()() { + auto id = std::this_thread::get_id(); + + auto guard = mutex_.Lock(); // acquire the lock + const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; + + return Check(id_index.second); + } + + size_t ThreadIndexer::Capacity() { + static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); + return max_size; + } + + size_t ThreadIndexer::Check(size_t thread_index) { + DCHECK_LT(thread_index, Capacity()) << "thread index " << thread_index + << " is out of range [0, " << Capacity() << ")"; + + return thread_index; + } + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 8bd6a3c5d62..24306329bd6 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include "arrow/buffer.h" @@ -29,6 +31,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" +#include "arrow/util/mutex.h" #include "arrow/util/optional.h" #if defined(__clang__) || defined(__GNUC__) @@ -233,5 +236,18 @@ class AtomicCounter { std::atomic complete_{false}; }; +class ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + util::Mutex mutex_; + std::unordered_map id_to_index_; +}; + } // namespace compute } // namespace arrow From 7977837b1aad01303192a53028f71c3dee53bbf4 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 10 Aug 2021 16:02:34 -0400 Subject: [PATCH 24/27] porting tests --- cpp/src/arrow/compute/exec/hash_join_test.cc | 43 +++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_test.cc index 816d6adaf7a..8a1035bf4d5 100644 --- a/cpp/src/arrow/compute/exec/hash_join_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_test.cc @@ -18,6 +18,7 @@ #include #include "arrow/api.h" +#include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" @@ -49,28 +50,30 @@ void GenerateBatchesFromString(const std::shared_ptr& schema, out_batches->schema = schema; } -void CheckRunOutput(JoinType type, BatchesWithSchema l_batches, - BatchesWithSchema r_batches, - const std::vector& left_keys, - const std::vector& right_keys, +void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const std::vector& left_keys, + const std::vector& right_keys, const BatchesWithSchema& exp_batches, bool parallel = false) { SCOPED_TRACE("serial"); ASSERT_OK_AND_ASSIGN(auto plan, 
ExecPlan::Make()); - ASSERT_OK_AND_ASSIGN(auto l_source, - MakeTestSourceNode(plan.get(), "l_source", std::move(l_batches), - /*parallel=*/parallel, - /*slow=*/false)); - ASSERT_OK_AND_ASSIGN(auto r_source, - MakeTestSourceNode(plan.get(), "r_source", std::move(r_batches), - /*parallel=*/parallel, - /*slow=*/false)); + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; - ASSERT_OK_AND_ASSIGN( - auto semi_join, - MakeHashJoinNode(type, l_source, r_source, "hash_join", left_keys, right_keys)); - auto sink_gen = MakeSinkNode(semi_join, "sink"); + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); @@ -121,8 +124,8 @@ void RunNonEmptyTest(JoinType type, bool parallel) { FAIL() << "join type not implemented!"; } - CheckRunOutput(type, std::move(l_batches), std::move(r_batches), - /*left_keys=*/{"l_str"}, /*right_keys=*/{"r_str"}, exp_batches, + CheckRunOutput(type, l_batches, r_batches, + /*left_keys=*/{{"l_str"}}, /*right_keys=*/{{"r_str"}}, exp_batches, parallel); } @@ -142,8 +145,8 @@ void RunEmptyTest(JoinType type, bool parallel) { GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, &r_n_empty, multiplicity); - std::vector l_keys{"l_str"}; - std::vector r_keys{"r_str"}; + std::vector l_keys{{"l_str"}}; + std::vector r_keys{{"r_str"}}; switch (type) { case LEFT_SEMI: From db5c72499107b97c22552e66d5a75664615d327c Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 11 Aug 2021 00:21:33 -0400 Subject: [PATCH 25/27] extending the test cases --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/compute/exec/CMakeLists.txt | 3 +- .../exec/{hash_join.cc => hash_join_node.cc} | 0 .../compute/exec/hash_join_node_benchmark.cc | 19 +++++++ ...sh_join_test.cc => hash_join_node_test.cc} | 55 ++++++++++++++++++- cpp/src/arrow/compute/exec/test_util.cc | 2 + 6 files changed, 78 insertions(+), 3 deletions(-) rename cpp/src/arrow/compute/exec/{hash_join.cc => hash_join_node.cc} (100%) create mode 100644 cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc rename cpp/src/arrow/compute/exec/{hash_join_test.cc => hash_join_node_test.cc} (77%) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6293f1bcc05..513205a7c45 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -412,7 +412,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_replace.cc compute/kernels/vector_selection.cc compute/kernels/vector_sort.cc - compute/exec/hash_join.cc + compute/exec/hash_join_node.cc compute/exec/key_hash.cc compute/exec/key_map.cc compute/exec/key_compare.cc diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index 281154e3518..030685c68b6 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -25,6 +25,7 @@ add_arrow_compute_test(expression_test subtree_test.cc) add_arrow_compute_test(plan_test PREFIX "arrow-compute") -add_arrow_compute_test(hash_join_test PREFIX "arrow-compute") 

+add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute") add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(hash_join_node_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc similarity index 100% rename from cpp/src/arrow/compute/exec/hash_join.cc rename to cpp/src/arrow/compute/exec/hash_join_node.cc diff --git a/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc new file mode 100644 index 00000000000..1f109ed3664 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + diff --git a/cpp/src/arrow/compute/exec/hash_join_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc similarity index 77% rename from cpp/src/arrow/compute/exec/hash_join_test.cc rename to cpp/src/arrow/compute/exec/hash_join_node_test.cc index 8a1035bf4d5..7e058a6864d 100644 --- a/cpp/src/arrow/compute/exec/hash_join_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -202,9 +202,62 @@ TEST_P(HashJoinTest, TestSemiJoins) { RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); } -TEST_P(HashJoinTest, TestSemiJoinsLeftEmpty) { +TEST_P(HashJoinTest, TestSemiJoinstEmpty) { RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); } +void TestJoinRandom(const std::shared_ptr& data_type, JoinType type, + bool parallel, int num_batches, int batch_size) { + auto l_schema = schema({field("l0", data_type), field("l1", data_type)}); + auto r_schema = schema({field("r0", data_type), field("r1", data_type)}); + + // generate data + auto l_batches = MakeRandomBatches(l_schema, num_batches, batch_size); + auto r_batches = MakeRandomBatches(r_schema, num_batches, batch_size); + + std::vector left_keys{{"l0"}}; + std::vector right_keys{{"r1"}}; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + // TODO(niranda) add a verification step for res +} + +class HashJoinTestRand : public testing::TestWithParam< 
+class HashJoinTestRand : public testing::TestWithParam<
+                             std::tuple<std::shared_ptr<DataType>, JoinType, bool>> {};
+
+static constexpr int kNumBatches = 1000;
+static constexpr int kBatchSize = 100;
+
+INSTANTIATE_TEST_SUITE_P(
+    HashJoinTestRand, HashJoinTestRand,
+    ::testing::Combine(::testing::Values(int8(), int32(), int64(), float32(), float64()),
+                       ::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI,
+                                         JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI),
+                       ::testing::Values(false, true)));
+
+TEST_P(HashJoinTestRand, TestingTypes) {
+  TestJoinRandom(std::get<0>(GetParam()), std::get<1>(GetParam()),
+                 std::get<2>(GetParam()), kNumBatches, kBatchSize);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 91c993c101e..46905f25ac8 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -193,6 +193,8 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
     // add a tag scalar to ensure the batches are unique
     out.batches[i].values.emplace_back(i);
   }
+
+  out.schema = schema;
   return out;
 }
 

From b250e99be3a805205259dc0472cf6b60166aa297 Mon Sep 17 00:00:00 2001
From: michalursa
Date: Mon, 16 Aug 2021 23:10:53 -0700
Subject: [PATCH 26/27] Hash semi-join: debugging ThreadIndexer

---
 cpp/src/arrow/CMakeLists.txt                  |   1 -
 cpp/src/arrow/compute/exec/exec_plan.cc       |   1 -
 cpp/src/arrow/compute/exec/hash_join_node.cc  |  28 +
 .../compute/exec/hash_join_node_benchmark.cc  |   1 -
 cpp/src/arrow/compute/exec/plan_test.cc       |  88 +--
 cpp/src/arrow/compute/exec/sink_node.cc       |  79 ++-
 cpp/src/arrow/compute/exec/test_util.h        |   2 +-
 cpp/src/arrow/compute/exec/util.cc            | 507 +++++++++---------
 8 files changed, 359 insertions(+), 348 deletions(-)

diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 513205a7c45..e329a1274fa 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -372,7 +372,6 @@ if(ARROW_COMPUTE)
     compute/exec.cc
     compute/exec/aggregate_node.cc
     compute/exec/exec_plan.cc
-    compute/exec/exec_utils.cc
     compute/exec/expression.cc
     compute/exec/filter_node.cc
     compute/exec/project_node.cc
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 846961facf9..6c0aee1baf6 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -24,7 +24,6 @@
 #include
 
 #include "arrow/compute/exec.h"
-#include "arrow/compute/exec/exec_utils.h"
#include "arrow/compute/exec/expression.h"
 #include "arrow/compute/exec_internal.h"
 #include "arrow/compute/registry.h"
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 8bffe826dad..3a9dadcfa20 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <iostream>
 #include
 
 #include "arrow/api.h"
@@ -166,7 +167,15 @@ struct HashSemiJoinNode : ExecNode {
   // total reached at the end of consumption, all the local states will be merged, before
   // incrementing the total batches
   Status ConsumeBuildBatch(ExecBatch batch) {
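+    // Debug tracing for the ThreadIndexer investigation: log whether the CPU
+    // thread pool owns this thread and which logical index it resolves to.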
+    auto executor = ::arrow::internal::GetCpuThreadPool();
+    std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 1 : 0)
+              << std::endl;
+    std::cout << "ConsumeBuildBatch: std::this_thread::get_id() = " << std::hex
+              << std::showbase << std::this_thread::get_id() << std::dec
+              << std::noshowbase;
     size_t thread_index = get_thread_index_();
+    std::cout << " get_thread_index_() = " << thread_index;
+    std::cout << std::endl;
     ARROW_DCHECK(thread_index < local_states_.size());
 
     ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
@@ -197,8 +206,16 @@ struct HashSemiJoinNode : ExecNode {
 
   // consumes cached probe batches by invoking executor::Spawn.
   Status ConsumeCachedProbeBatches() {
+    auto executor = ::arrow::internal::GetCpuThreadPool();
+    std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 1 : 0)
+              << std::endl;
+    std::cout << "ConsumeCachedProbeBatches: std::this_thread::get_id() = " << std::hex
+              << std::showbase << std::this_thread::get_id() << std::dec
+              << std::noshowbase;
     ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
                      << " len:" << cached_probe_batches.size();
+    std::cout << " get_thread_index_() = " << get_thread_index_();
+    std::cout << std::endl;
 
     // acquire the mutex to access cached_probe_batches, because while consuming, other
     // batches should not be cached!
@@ -288,8 +305,16 @@ struct HashSemiJoinNode : ExecNode {
   // cached_probe_batches_mutex, it should no longer be cached! instead, it can be
   // directly consumed!
   bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) {
+    auto executor = ::arrow::internal::GetCpuThreadPool();
+    std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 1 : 0)
+              << std::endl;
+    std::cout << "AttemptToCacheProbeBatch: std::this_thread::get_id() = " << std::hex
+              << std::showbase << std::this_thread::get_id() << std::dec
+              << std::noshowbase;
     ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num
                      << " len:" << batch->length;
+    std::cout << " get_thread_index_() = " << get_thread_index_();
+    std::cout << std::endl;
     std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
     if (cached_probe_batches_consumed) {
       return false;
@@ -303,6 +328,8 @@ struct HashSemiJoinNode : ExecNode {
   // If all build side batches received? continue streaming using probing
   // else cache the batches in thread-local state
   void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    std::cout << "InputReceived " << seq << std::endl;
+
     ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? "b" : "p")
                      << " seq:" << seq << " len:" << batch.length;
 
@@ -374,6 +401,7 @@ struct HashSemiJoinNode : ExecNode {
   }
 
   Status StartProducing() override {
+    std::cout << "Start Producing" << std::endl;
     ARROW_LOG(DEBUG) << "start prod";
 
     finished_ = Future<>::Make();
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc
index 1f109ed3664..91daa7d5295 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc
@@ -16,4 +16,3 @@
 // under the License.
 
#include "benchmark/benchmark.h" - diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index d37fd2fca34..c0feedad509 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -26,17 +26,14 @@ #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" #include "arrow/record_batch.h" -<<<<<<< HEAD #include "arrow/table.h" #include "arrow/testing/future_util.h" -======= ->>>>>>> refactoring files - #include "arrow/testing/gtest_util.h" +#include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/util/async_generator.h" #include "arrow/util/logging.h" - using testing::ElementsAre; +using testing::ElementsAre; using testing::ElementsAreArray; using testing::HasSubstr; using testing::Optional; @@ -197,87 +194,6 @@ TEST(ExecPlan, DummyStartProducingError) { ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink")); } -namespace { - -struct BatchesWithSchema { - std::vector batches; - std::shared_ptr schema; - - AsyncGenerator> gen(bool parallel, bool slow) const { - DCHECK_GT(batches.size(), 0); - - auto opt_batches = ::arrow::internal::MapVector( - [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches); - - AsyncGenerator> gen; - - if (parallel) { - // emulate batches completing initial decode-after-scan on a cpu thread - gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), - ::arrow::internal::GetCpuThreadPool()) - .ValueOrDie(); - - // ensure that callbacks are not executed immediately on a background thread - gen = - MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool()); - } else { - gen = MakeVectorGenerator(std::move(opt_batches)); - } - - if (slow) { - gen = - MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { - SleepABit(); - return batch; - }); - } - - return gen; - } -}; - -Future> StartAndCollect( - ExecPlan* plan, AsyncGenerator> gen) { - RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); - - auto collected_fut = CollectAsyncGenerator(gen); - - return AllComplete({plan->finished(), Future<>(collected_fut)}) - .Then([collected_fut]() -> Result> { - ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); - return ::arrow::internal::MapVector( - [](util::optional batch) { return std::move(*batch); }, - std::move(collected)); - }); -} - -BatchesWithSchema MakeBasicBatches() { - BatchesWithSchema out; - out.batches = { - ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), - ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; - out.schema = schema({field("i32", int32()), field("bool", boolean())}); - return out; -} - -BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, - int num_batches = 10, int batch_size = 4) { - BatchesWithSchema out; - out.schema = schema; - - random::RandomArrayGenerator rng(42); - out.batches.resize(num_batches); - - for (int i = 0; i < num_batches; ++i) { - out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); - // add a tag scalar to ensure the batches are unique - out.batches[i].values.emplace_back(i); - } - return out; -} -} // namespace - TEST(ExecPlanExecution, SourceSink) { for (bool slow : {false, true}) { SCOPED_TRACE(slow ? 
"slowed" : "unslowed"); diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index a8f21dbda1a..bdbc0b0fa00 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -136,8 +136,8 @@ class SinkNode : public ExecNode { } } - private: - void Finish() { + protected: + virtual void Finish() { if (producer_.Close()) { finished_.MarkFinished(); } @@ -149,7 +149,82 @@ class SinkNode : public ExecNode { PushGenerator>::Producer producer_; }; +// A sink node that accumulates inputs, then sorts them before emitting them. +struct OrderBySinkNode final : public SinkNode { + OrderBySinkNode(ExecPlan* plan, std::vector inputs, SortOptions sort_options, + AsyncGenerator>* generator) + : SinkNode(plan, std::move(inputs), generator), + sort_options_(std::move(sort_options)) {} + + const char* kind_name() override { return "OrderBySinkNode"; } + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "OrderBySinkNode")); + + const auto& sink_options = checked_cast(options); + return plan->EmplaceNode( + plan, std::move(inputs), sink_options.sort_options, sink_options.generator); + } + + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + // Accumulate data + { + std::unique_lock lock(mutex_); + auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(), + plan()->exec_context()->memory_pool()); + if (ErrorIfNotOk(maybe_batch.status())) return; + batches_.push_back(maybe_batch.MoveValueUnsafe()); + } + + if (input_counter_.Increment()) { + Finish(); + } + } + + protected: + Status DoFinish() { + Datum sorted; + { + std::unique_lock lock(mutex_); + ARROW_ASSIGN_OR_RAISE( + auto table, + Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_))); + ARROW_ASSIGN_OR_RAISE(auto indices, + SortIndices(table, sort_options_, plan()->exec_context())); + ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(), + plan()->exec_context())); + } + TableBatchReader reader(*sorted.table()); + while (true) { + std::shared_ptr batch; + RETURN_NOT_OK(reader.ReadNext(&batch)); + if (!batch) break; + bool did_push = producer_.Push(ExecBatch(*batch)); + if (!did_push) break; // producer_ was Closed already + } + return Status::OK(); + } + + void Finish() override { + Status st = DoFinish(); + if (ErrorIfNotOk(st)) { + producer_.Push(std::move(st)); + } + SinkNode::Finish(); + } + + private: + SortOptions sort_options_; + std::mutex mutex_; + std::vector> batches_; +}; + ExecFactoryRegistry::AddOnLoad kRegisterSink("sink", SinkNode::Make); +ExecFactoryRegistry::AddOnLoad kRegisterOrderBySink("order_by_sink", + OrderBySinkNode::Make); } // namespace } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 55c971954ea..e21dfd673ec 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -61,7 +61,7 @@ struct BatchesWithSchema { // emulate batches completing initial decode-after-scan on a cpu thread gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), ::arrow::internal::GetCpuThreadPool()) - .ValueOrDie(); + .ValueOrDie(); // ensure that callbacks are not executed immediately on a background thread gen = diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index 8eac7ddb6b7..57a21665b28 100644 
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index 8eac7ddb6b7..57a21665b28 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -18,321 +18,316 @@
 #include "arrow/compute/exec/util.h"
 
 #include "arrow/compute/exec/exec_plan.h"
-<<<<<<< HEAD
 #include "arrow/table.h"
-=======
->>>>>>> ARROW-13482: [C++][Compute] Refactoring away from hard coded ExecNode factories to a registry
- #include "arrow/util/bit_util.h"
+#include "arrow/util/bit_util.h"
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/ubsan.h"
 
- namespace arrow {
+namespace arrow {
 
- using BitUtil::CountTrailingZeros;
+using BitUtil::CountTrailingZeros;
 
- namespace util {
+namespace util {
 
- inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index,
-                                             int* num_indexes, uint16_t* indexes) {
-   int n = *num_indexes;
-   while (word) {
-     indexes[n++] = base_index + static_cast<uint16_t>(CountTrailingZeros(word));
-     word &= word - 1;
-   }
-   *num_indexes = n;
- }
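+// Scans `word` for set bits: CountTrailingZeros locates the next set bit and
+// `word &= word - 1` clears it, so e.g. word = 0b10100 appends base_index + 2
+// and then base_index + 4 to indexes[].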
+inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index,
+                                            int* num_indexes, uint16_t* indexes) {
+  int n = *num_indexes;
+  while (word) {
+    indexes[n++] = base_index + static_cast<uint16_t>(CountTrailingZeros(word));
+    word &= word - 1;
+  }
+  *num_indexes = n;
+}
 
- inline void BitUtil::bits_filter_indexes_helper(uint64_t word,
-                                                 const uint16_t* input_indexes,
-                                                 int* num_indexes, uint16_t* indexes) {
-   int n = *num_indexes;
-   while (word) {
-     indexes[n++] = input_indexes[CountTrailingZeros(word)];
-     word &= word - 1;
-   }
-   *num_indexes = n;
- }
+inline void BitUtil::bits_filter_indexes_helper(uint64_t word,
+                                                const uint16_t* input_indexes,
+                                                int* num_indexes, uint16_t* indexes) {
+  int n = *num_indexes;
+  while (word) {
+    indexes[n++] = input_indexes[CountTrailingZeros(word)];
+    word &= word - 1;
+  }
+  *num_indexes = n;
+}
 
- template <int bit_to_search, bool filter_input_indexes>
- void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits,
-                                        const uint8_t* bits,
-                                        const uint16_t* input_indexes, int* num_indexes,
-                                        uint16_t* indexes) {
-   // 64 bits at a time
-   constexpr int unroll = 64;
-   int tail = num_bits % unroll;
-#if defined(ARROW_HAVE_AVX2)
-   if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
-     if (filter_input_indexes) {
-       bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes,
-                                num_indexes, indexes);
-     } else {
-       bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes);
-     }
-   } else {
-#endif
-     *num_indexes = 0;
-     for (int i = 0; i < num_bits / unroll; ++i) {
-       uint64_t word = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[i]);
-       if (bit_to_search == 0) {
-         word = ~word;
-       }
-       if (filter_input_indexes) {
-         bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes);
-       } else {
-         bits_to_indexes_helper(word, i * 64, num_indexes, indexes);
-       }
-     }
-#if defined(ARROW_HAVE_AVX2)
-   }
-#endif
-   // Optionally process the last partial word with masking out bits outside range
-   if (tail) {
-     uint64_t word =
-         util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
-     if (bit_to_search == 0) {
-       word = ~word;
-     }
-     word &= ~0ULL >> (64 - tail);
-     if (filter_input_indexes) {
-       bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes,
-                                  indexes);
-     } else {
-       bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes);
-     }
-   }
- }
+template <int bit_to_search, bool filter_input_indexes>
+void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits,
+                                       const uint8_t* bits, const uint16_t* input_indexes,
+                                       int* num_indexes, uint16_t* indexes) {
+  // 64 bits at a time
+  constexpr int unroll = 64;
+  int tail = num_bits % unroll;
+#if defined(ARROW_HAVE_AVX2)
+  if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+    if (filter_input_indexes) {
+      bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes,
+                               num_indexes, indexes);
+    } else {
+      bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes);
+    }
+  } else {
+#endif
+    *num_indexes = 0;
+    for (int i = 0; i < num_bits / unroll; ++i) {
+      uint64_t word = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[i]);
+      if (bit_to_search == 0) {
+        word = ~word;
+      }
+      if (filter_input_indexes) {
+        bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes);
+      } else {
+        bits_to_indexes_helper(word, i * 64, num_indexes, indexes);
+      }
+    }
+#if defined(ARROW_HAVE_AVX2)
+  }
+#endif
+  // Optionally process the last partial word with masking out bits outside range
+  if (tail) {
+    uint64_t word =
+        util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
+    if (bit_to_search == 0) {
+      word = ~word;
+    }
+    word &= ~0ULL >> (64 - tail);
+    if (filter_input_indexes) {
+      bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes,
+                                 indexes);
+    } else {
+      bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes);
+    }
+  }
+}
 
- void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags,
-                               const int num_bits, const uint8_t* bits, int* num_indexes,
-                               uint16_t* indexes, int bit_offset) {
-   bits += bit_offset / 8;
-   bit_offset %= 8;
-   if (bit_offset != 0) {
-     int num_indexes_head = 0;
-     uint64_t bits_head =
-         util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
-     int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
-     bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
-                     reinterpret_cast<const uint8_t*>(&bits_head), &num_indexes_head,
-                     indexes);
-     int num_indexes_tail = 0;
-     if (num_bits > bits_in_first_byte) {
-       bits_to_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte,
-                       bits + 1, &num_indexes_tail, indexes + num_indexes_head);
-     }
-     *num_indexes = num_indexes_head + num_indexes_tail;
-     return;
-   }
-
-   if (bit_to_search == 0) {
-     bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr,
-                                        num_indexes, indexes);
-   } else {
-     ARROW_DCHECK(bit_to_search == 1);
-     bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr,
-                                        num_indexes, indexes);
-   }
- }
+void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags,
+                              const int num_bits, const uint8_t* bits, int* num_indexes,
+                              uint16_t* indexes, int bit_offset) {
+  bits += bit_offset / 8;
+  bit_offset %= 8;
+  if (bit_offset != 0) {
+    int num_indexes_head = 0;
+    uint64_t bits_head =
+        util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+    int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+    bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
+                    reinterpret_cast<const uint8_t*>(&bits_head), &num_indexes_head,
+                    indexes);
+    int num_indexes_tail = 0;
+    if (num_bits > bits_in_first_byte) {
+      bits_to_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte,
+                      bits + 1, &num_indexes_tail, indexes + num_indexes_head);
+    }
+    *num_indexes = num_indexes_head + num_indexes_tail;
+    return;
+  }
+
+  if (bit_to_search == 0) {
+    bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr,
+                                       num_indexes, indexes);
+  } else {
+    ARROW_DCHECK(bit_to_search == 1);
+    bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr,
+                                       num_indexes, indexes);
+  }
+}
 
- void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags,
-                                   const int num_bits, const uint8_t* bits,
-                                   const uint16_t* input_indexes, int* num_indexes,
-                                   uint16_t* indexes, int bit_offset) {
-   bits += bit_offset / 8;
-   bit_offset %= 8;
-   if (bit_offset != 0) {
-     int num_indexes_head = 0;
-     uint64_t bits_head =
-         util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
-     int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
-     bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
-                         reinterpret_cast<const uint8_t*>(&bits_head), input_indexes,
-                         &num_indexes_head, indexes);
-     int num_indexes_tail = 0;
-     if (num_bits > bits_in_first_byte) {
-       bits_filter_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte,
-                           bits + 1, input_indexes + bits_in_first_byte,
-                           &num_indexes_tail, indexes + num_indexes_head);
-     }
-     *num_indexes = num_indexes_head + num_indexes_tail;
-     return;
-   }
-
-   if (bit_to_search == 0) {
-     bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes,
-                                       num_indexes, indexes);
-   } else {
-     ARROW_DCHECK(bit_to_search == 1);
-     bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes,
-                                       num_indexes, indexes);
-   }
- }
+void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags,
+                                  const int num_bits, const uint8_t* bits,
+                                  const uint16_t* input_indexes, int* num_indexes,
+                                  uint16_t* indexes, int bit_offset) {
+  bits += bit_offset / 8;
+  bit_offset %= 8;
+  if (bit_offset != 0) {
+    int num_indexes_head = 0;
+    uint64_t bits_head =
+        util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+    int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+    bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
+                        reinterpret_cast<const uint8_t*>(&bits_head), input_indexes,
+                        &num_indexes_head, indexes);
+    int num_indexes_tail = 0;
+    if (num_bits > bits_in_first_byte) {
+      bits_filter_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte,
+                          bits + 1, input_indexes + bits_in_first_byte, &num_indexes_tail,
+                          indexes + num_indexes_head);
+    }
+    *num_indexes = num_indexes_head + num_indexes_tail;
+    return;
+  }
+
+  if (bit_to_search == 0) {
+    bits_to_indexes_internal<0, true>(hardware_flags, num_bits, bits, input_indexes,
+                                      num_indexes, indexes);
+  } else {
+    ARROW_DCHECK(bit_to_search == 1);
+    bits_to_indexes_internal<1, true>(hardware_flags, num_bits, bits, input_indexes,
+                                      num_indexes, indexes);
+  }
+}
 
- void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits,
-                                  const uint8_t* bits, int* num_indexes_bit0,
-                                  uint16_t* indexes_bit0, uint16_t* indexes_bit1,
-                                  int bit_offset) {
-   bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0,
-                   bit_offset);
-   int num_indexes_bit1;
-   bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1,
-                   bit_offset);
- }
+void BitUtil::bits_split_indexes(int64_t hardware_flags, const int num_bits,
+                                 const uint8_t* bits, int* num_indexes_bit0,
+                                 uint16_t* indexes_bit0, uint16_t* indexes_bit1,
+                                 int bit_offset) {
+  bits_to_indexes(0, hardware_flags, num_bits, bits, num_indexes_bit0, indexes_bit0,
+                  bit_offset);
+  int num_indexes_bit1;
+  bits_to_indexes(1, hardware_flags, num_bits, bits, &num_indexes_bit1, indexes_bit1,
+                  bit_offset);
+}
 
- void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits,
-                             const uint8_t* bits, uint8_t* bytes, int bit_offset) {
-   bits += bit_offset / 8;
-   bit_offset %= 8;
-   if (bit_offset != 0) {
-     uint64_t bits_head =
-         util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
-     int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
-     bits_to_bytes(hardware_flags, bits_in_first_byte,
-                   reinterpret_cast<const uint8_t*>(&bits_head), bytes);
-     if (num_bits > bits_in_first_byte) {
-       bits_to_bytes(hardware_flags, num_bits - bits_in_first_byte, bits + 1,
-                     bytes + bits_in_first_byte);
-     }
-     return;
-   }
-
-   int num_processed = 0;
-#if defined(ARROW_HAVE_AVX2)
-   if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
-     // The function call below processes whole 32 bit chunks together.
-     num_processed = num_bits - (num_bits % 32);
-     bits_to_bytes_avx2(num_processed, bits, bytes);
-   }
-#endif
-   // Processing 8 bits at a time
-   constexpr int unroll = 8;
-   for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
-     uint8_t bits_next = bits[i];
-     // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits
-     // apart from the previous.
-     uint64_t unpacked = static_cast<uint64_t>(bits_next & 0xfe) *
-                         ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) |
-                          (1ULL << 35) | (1ULL << 42) | (1ULL << 49));
-     unpacked |= (bits_next & 1);
-     unpacked &= 0x0101010101010101ULL;
-     unpacked *= 255;
-     util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
-   }
- }
+void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits,
+                            const uint8_t* bits, uint8_t* bytes, int bit_offset) {
+  bits += bit_offset / 8;
+  bit_offset %= 8;
+  if (bit_offset != 0) {
+    uint64_t bits_head =
+        util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
+    int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+    bits_to_bytes(hardware_flags, bits_in_first_byte,
+                  reinterpret_cast<const uint8_t*>(&bits_head), bytes);
+    if (num_bits > bits_in_first_byte) {
+      bits_to_bytes(hardware_flags, num_bits - bits_in_first_byte, bits + 1,
+                    bytes + bits_in_first_byte);
+    }
+    return;
+  }
+
+  int num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+  if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+    // The function call below processes whole 32 bit chunks together.
+    num_processed = num_bits - (num_bits % 32);
+    bits_to_bytes_avx2(num_processed, bits, bytes);
+  }
+#endif
+  // Processing 8 bits at a time
+  constexpr int unroll = 8;
+  for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
+    uint8_t bits_next = bits[i];
+    // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits
+    // apart from the previous.
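+    // E.g. bits_next = 0b00000011: bit 0 is OR'd back in at position 0, and the
+    // (1ULL << 7) term moves bit 1 to position 8; after masking with
+    // 0x0101010101010101 and multiplying by 255, the first two output bytes
+    // become 0xFF and the remaining six stay 0x00.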
+    uint64_t unpacked = static_cast<uint64_t>(bits_next & 0xfe) *
+                        ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) |
+                         (1ULL << 35) | (1ULL << 42) | (1ULL << 49));
+    unpacked |= (bits_next & 1);
+    unpacked &= 0x0101010101010101ULL;
+    unpacked *= 255;
+    util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
+  }
+}
 
- void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits,
-                             const uint8_t* bytes, uint8_t* bits, int bit_offset) {
-   bits += bit_offset / 8;
-   bit_offset %= 8;
-   if (bit_offset != 0) {
-     uint64_t bits_head;
-     int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
-     bytes_to_bits(hardware_flags, bits_in_first_byte, bytes,
-                   reinterpret_cast<uint8_t*>(&bits_head));
-     uint8_t mask = (1 << bit_offset) - 1;
-     *bits = static_cast<uint8_t>((*bits & mask) | (bits_head << bit_offset));
-
-     if (num_bits > bits_in_first_byte) {
-       bytes_to_bits(hardware_flags, num_bits - bits_in_first_byte,
-                     bytes + bits_in_first_byte, bits + 1);
-     }
-     return;
-   }
-
-   int num_processed = 0;
-#if defined(ARROW_HAVE_AVX2)
-   if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
-     // The function call below processes whole 32 bit chunks together.
-     num_processed = num_bits - (num_bits % 32);
-     bytes_to_bits_avx2(num_processed, bytes, bits);
-   }
-#endif
+void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits,
+                            const uint8_t* bytes, uint8_t* bits, int bit_offset) {
+  bits += bit_offset / 8;
+  bit_offset %= 8;
+  if (bit_offset != 0) {
+    uint64_t bits_head;
+    int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
+    bytes_to_bits(hardware_flags, bits_in_first_byte, bytes,
+                  reinterpret_cast<uint8_t*>(&bits_head));
+    uint8_t mask = (1 << bit_offset) - 1;
+    *bits = static_cast<uint8_t>((*bits & mask) | (bits_head << bit_offset));
+
+    if (num_bits > bits_in_first_byte) {
+      bytes_to_bits(hardware_flags, num_bits - bits_in_first_byte,
+                    bytes + bits_in_first_byte, bits + 1);
+    }
+    return;
+  }
+
+  int num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+  if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+    // The function call below processes whole 32 bit chunks together.
+    num_processed = num_bits - (num_bits % 32);
+    bytes_to_bits_avx2(num_processed, bytes, bits);
+  }
+#endif
- // Process 8 bits at a time
-   constexpr int unroll = 8;
-   for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
-     uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
-     bytes_next &= 0x0101010101010101ULL;
-     bytes_next |= (bytes_next >> 7);   // Pairs of adjacent output bits in individual bytes
-     bytes_next |= (bytes_next >> 14);  // 4 adjacent output bits in individual bytes
-     bytes_next |= (bytes_next >> 28);  // All 8 output bits in the lowest byte
-     bits[i] = static_cast<uint8_t>(bytes_next & 0xff);
-   }
- }
+  // Process 8 bits at a time
+  constexpr int unroll = 8;
+  for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
+    uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
+    bytes_next &= 0x0101010101010101ULL;
+    bytes_next |= (bytes_next >> 7);   // Pairs of adjacent output bits in individual bytes
+    bytes_next |= (bytes_next >> 14);  // 4 adjacent output bits in individual bytes
+    bytes_next |= (bytes_next >> 28);  // All 8 output bits in the lowest byte
+    bits[i] = static_cast<uint8_t>(bytes_next & 0xff);
+  }
+}
 
- bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
-                                  uint32_t num_bytes) {
-#if defined(ARROW_HAVE_AVX2)
-   if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
-     return are_all_bytes_zero_avx2(bytes, num_bytes);
-   }
-#endif
-   uint64_t result_or = 0;
-   uint32_t i;
-   for (i = 0; i < num_bytes / 8; ++i) {
-     uint64_t x = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
-     result_or |= x;
-   }
-   if (num_bytes % 8 > 0) {
-     uint64_t tail = 0;
-     result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8);
-   }
-   return result_or == 0;
- }
+bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
+                                 uint32_t num_bytes) {
+#if defined(ARROW_HAVE_AVX2)
+  if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
+    return are_all_bytes_zero_avx2(bytes, num_bytes);
+  }
+#endif
+  uint64_t result_or = 0;
+  uint32_t i;
+  for (i = 0; i < num_bytes / 8; ++i) {
+    uint64_t x = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
+    result_or |= x;
+  }
+  if (num_bytes % 8 > 0) {
+    uint64_t tail = 0;
+    result_or |= memcmp(bytes + i * 8, &tail, num_bytes % 8);
+  }
+  return result_or == 0;
+}
 
- }  // namespace util
+}  // namespace util
 
- namespace compute {
+namespace compute {
 
- Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
-                               int expected_num_inputs, const char* kind_name) {
-   if (static_cast<int>(inputs.size()) != expected_num_inputs) {
-     return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
-                            " inputs but got ", inputs.size());
-   }
-
-   for (auto input : inputs) {
-     if (input->plan() != plan) {
-       return Status::Invalid("Constructing a ", kind_name,
-                              " node in a different plan from its input");
-     }
-   }
-
-   return Status::OK();
- }
+Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
+                              int expected_num_inputs, const char* kind_name) {
+  if (static_cast<int>(inputs.size()) != expected_num_inputs) {
+    return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+                           " inputs but got ", inputs.size());
+  }
+
+  for (auto input : inputs) {
+    if (input->plan() != plan) {
+      return Status::Invalid("Constructing a ", kind_name,
+                             " node in a different plan from its input");
+    }
+  }
+
+  return Status::OK();
+}
 
- Result<std::shared_ptr<Table>> TableFromExecBatches(
-     const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
-   RecordBatchVector batches;
-   for (const auto& batch : exec_batches) {
-     ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
-     batches.push_back(std::move(rb));
-   }
-   return Table::FromRecordBatches(schema, batches);
- }
+Result<std::shared_ptr<Table>> TableFromExecBatches(
+    const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
+  RecordBatchVector batches;
+  for (const auto& batch : exec_batches) {
+    ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
+    batches.push_back(std::move(rb));
+  }
+  return Table::FromRecordBatches(schema, batches);
+}
 
- size_t ThreadIndexer::operator()() {
-   auto id = std::this_thread::get_id();
-
-   auto guard = mutex_.Lock();  // acquire the lock
-   const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first;
-
-   return Check(id_index.second);
- }
+size_t ThreadIndexer::operator()() {
+  auto id = std::this_thread::get_id();
+
+  auto guard = mutex_.Lock();  // acquire the lock
+  const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first;
+
+  return Check(id_index.second);
+}
 
- size_t ThreadIndexer::Capacity() {
-   static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity();
-   return max_size;
- }
+size_t ThreadIndexer::Capacity() {
+  static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity();
+  return max_size;
+}
 
- size_t ThreadIndexer::Check(size_t thread_index) {
-   DCHECK_LT(thread_index, Capacity()) << "thread index " << thread_index
-                                       << " is out of range [0, " << Capacity() << ")";
-
-   return thread_index;
- }
+size_t ThreadIndexer::Check(size_t thread_index) {
+  DCHECK_LT(thread_index, Capacity())
+      << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")";
+
+  return thread_index;
+}
 
- }  // namespace compute
+}  // namespace compute
 
 }  // namespace arrow

From ec09202cb28d885221a06bf3a9ed22b3d9fc9491 Mon Sep 17 00:00:00 2001
From: michalursa
Date: Tue, 24 Aug 2021 00:28:50 -0700
Subject: [PATCH 27/27] Hash semi-join: fixing exec context for exec plan in
 unit tests

---
 cpp/src/arrow/compute/exec/hash_join_node_test.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index 7e058a6864d..56483e1860d 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -57,7 +57,11 @@ void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches,
                     const BatchesWithSchema& exp_batches, bool parallel = false) {
   SCOPED_TRACE("serial");
 
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  ExecContext parallel_exec_context(default_memory_pool(),
+                                    ::arrow::internal::GetCpuThreadPool());
+
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(parallel ? &parallel_exec_context
+                                                          : default_exec_context()));
 
   JoinNodeOptions join_options{type, left_keys, right_keys};
   Declaration join{"hash_join", join_options};