diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 308ee49972c..e329a1274fa 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -411,6 +411,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_replace.cc compute/kernels/vector_selection.cc compute/kernels/vector_sort.cc + compute/exec/hash_join_node.cc compute/exec/key_hash.cc compute/exec/key_map.cc compute/exec/key_compare.cc diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 880424e97f8..2fa36b32b21 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -383,6 +383,11 @@ class ARROW_EXPORT Grouper { /// be as wide as necessary. virtual Result Consume(const ExecBatch& batch) = 0; + /// Finds/ queries the group IDs for the given ExecBatch for every index. Returns the + /// group IDs as an integer array. If a group ID not found, a null will be added to that + /// index. This is a thread-safe lookup. + virtual Result Find(const ExecBatch& batch) const = 0; + /// Get current unique keys. May be called multiple times. virtual Result GetUniques() = 0; diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index 2ed8b1c9480..030685c68b6 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -25,5 +25,7 @@ add_arrow_compute_test(expression_test subtree_test.cc) add_arrow_compute_test(plan_test PREFIX "arrow-compute") +add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute") add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(hash_join_node_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 260fd8e52fd..327f6eeec62 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. 
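For context on the Grouper::Find contract added to api_aggregate.h above, here is a minimal usage sketch (not part of the patch): a single int32 key column is assumed, the helper name is illustrative, and error handling is condensed with the usual Arrow macros.

#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/exec.h"
#include "arrow/result.h"

// Sketch of the Consume/Find contract: Consume assigns group ids to build-side
// keys, while Find only looks probe-side keys up and yields null for unseen keys.
arrow::Status GrouperFindSketch(arrow::compute::ExecContext* ctx,
                                const arrow::compute::ExecBatch& build_keys,
                                const arrow::compute::ExecBatch& probe_keys) {
  using arrow::compute::internal::Grouper;
  ARROW_ASSIGN_OR_RAISE(
      auto grouper, Grouper::Make({arrow::ValueDescr(arrow::int32())}, ctx));
  // Build side: every distinct key gets a group id and is remembered.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum build_ids, grouper->Consume(build_keys));
  // Probe side: a pure lookup, safe to call concurrently per the new doc comment.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum probe_ids, grouper->Find(probe_keys));
  // probe_ids is a uint32 array aligned with probe_keys; its validity bitmap
  // marks which probe rows have a matching build-side key.
  return arrow::Status::OK();
}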
-#include "arrow/compute/exec/exec_plan.h" - #include #include #include #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" #include "arrow/compute/exec_internal.h" @@ -59,34 +58,6 @@ Result ResolveKernels( namespace { -class ThreadIndexer { - public: - size_t operator()() { - auto id = std::this_thread::get_id(); - - std::unique_lock lock(mutex_); - const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; - - return Check(id_index.second); - } - - static size_t Capacity() { - static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); - return max_size; - } - - private: - size_t Check(size_t thread_index) { - DCHECK_LT(thread_index, Capacity()) << "thread index " << thread_index - << " is out of range [0, " << Capacity() << ")"; - - return thread_index; - } - - std::mutex mutex_; - std::unordered_map id_to_index_; -}; - class ScalarAggregateNode : public ExecNode { public: ScalarAggregateNode(ExecPlan* plan, std::vector inputs, @@ -461,8 +432,7 @@ class GroupByNode : public ExecNode { // bail if StopProducing was called if (finished_.is_finished()) break; - auto plan = this->plan()->shared_from_this(); - RETURN_NOT_OK(executor->Spawn([plan, this, i] { OutputNthBatch(i); })); + RETURN_NOT_OK(executor->Spawn([this, i] { OutputNthBatch(i); })); } else { OutputNthBatch(i); } diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 5047c7a58d6..6c0aee1baf6 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,6 +17,9 @@ #include "arrow/compute/exec/exec_plan.h" +#include +#include +#include #include #include diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc new file mode 100644 index 00000000000..3a9dadcfa20 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -0,0 +1,580 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include "arrow/api.h" +#include "arrow/compute/api.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +namespace { +Status ValidateJoinInputs(const std::shared_ptr& left_schema, + const std::shared_ptr& right_schema, + const std::vector& left_keys, + const std::vector& right_keys) { + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("left and right key sizes do not match"); + } + + for (size_t i = 0; i < left_keys.size(); i++) { + auto l_type = left_schema->field(left_keys[i])->type(); + auto r_type = right_schema->field(right_keys[i])->type(); + + if (!l_type->Equals(r_type)) { + return Status::Invalid("build and probe types do not match: " + l_type->ToString() + + "!=" + r_type->ToString()); + } + } + + return Status::OK(); +} + +Result> PopulateKeys(const Schema& schema, + const std::vector& keys) { + std::vector key_field_ids(keys.size()); + // Find input field indices for left key fields + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema)); + key_field_ids[i] = match[0]; + } + return key_field_ids; +} +} // namespace + +template +struct HashSemiJoinNode : ExecNode { + HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext* ctx, + const std::vector&& build_index_field_ids, + const std::vector&& probe_index_field_ids) + : ExecNode(build_input->plan(), {build_input, probe_input}, + {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), + /*num_outputs=*/1), + ctx_(ctx), + build_index_field_ids_(build_index_field_ids), + probe_index_field_ids_(probe_index_field_ids), + build_result_index(-1), + hash_table_built_(false), + cached_probe_batches_consumed(false) {} + + private: + struct ThreadLocalState; + + public: + const char* kind_name() override { return "HashSemiJoinNode"; } + + Status InitLocalStateIfNeeded(ThreadLocalState* state) { + ARROW_LOG(DEBUG) << "init state"; + + // Get input schema + auto build_schema = inputs_[0]->output_schema(); + + if (state->grouper != nullptr) return Status::OK(); + + // Build vector of key field data types + std::vector key_descrs(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + auto build_type = build_schema->field(build_index_field_ids_[i])->type(); + key_descrs[i] = ValueDescr(build_type); + } + + // Construct grouper + ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_)); + + return Status::OK(); + } + + // Finds an appropriate index which could accumulate all build indices (i.e. the grouper + // which has the highest # of groups) + void CalculateBuildResultIndex() { + int32_t curr_max = -1; + for (int i = 0; i < static_cast(local_states_.size()); i++) { + auto* state = &local_states_[i]; + ARROW_DCHECK(state); + if (state->grouper && + curr_max < static_cast(state->grouper->num_groups())) { + curr_max = static_cast(state->grouper->num_groups()); + build_result_index = i; + } + } + ARROW_DCHECK(build_result_index > -1); + ARROW_LOG(DEBUG) << "build_result_index " << build_result_index; + } + + // Performs the housekeeping work after the build-side is completed. 
+ // Note: this method is not thread safe, and hence should be guaranteed that it is + // not accessed concurrently! + Status BuildSideCompleted() { + ARROW_LOG(DEBUG) << "build side merge"; + + // if the hash table has already been built, return + if (hash_table_built_) return Status::OK(); + + CalculateBuildResultIndex(); + + // merge every group into the build_result_index grouper + ThreadLocalState* result_state = &local_states_[build_result_index]; + for (int i = 0; i < static_cast(local_states_.size()); ++i) { + ThreadLocalState* state = &local_states_[i]; + ARROW_DCHECK(state); + if (i == build_result_index || !state->grouper) { + continue; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); + + // TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, result_state->grouper->Consume(other_keys)); + state->grouper.reset(); + } + + // enable flag that build side is completed + hash_table_built_ = true; + + // since the build side is completed, consume cached probe batches + RETURN_NOT_OK(ConsumeCachedProbeBatches()); + + return Status::OK(); + } + + // consumes a build batch and increments the build_batches count. if the build batches + // total reached at the end of consumption, all the local states will be merged, before + // incrementing the total batches + Status ConsumeBuildBatch(ExecBatch batch) { + auto executor = ::arrow::internal::GetCpuThreadPool(); + std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 1 : 0) + << std::endl; + std::cout << "ConsumeBuildBatch: std::this_thread::get_id() = " << std::hex + << std::showbase << std::this_thread::get_id() << std::dec + << std::noshowbase; + size_t thread_index = get_thread_index_(); + std::cout << " get_thread_index_() = " << thread_index; + std::cout << std::endl; + ARROW_DCHECK(thread_index < local_states_.size()); + + ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index + << " len:" << batch.length; + + auto state = &local_states_[thread_index]; + RETURN_NOT_OK(InitLocalStateIfNeeded(state)); + + // Create a batch with key columns + std::vector keys(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + keys[i] = batch.values[build_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Create a batch with group ids + // TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); + + if (build_counter_.Increment()) { + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. + RETURN_NOT_OK(BuildSideCompleted()); + } + + return Status::OK(); + } + + // consumes cached probe batches by invoking executor::Spawn. + Status ConsumeCachedProbeBatches() { + auto executor = ::arrow::internal::GetCpuThreadPool(); + std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 1 : 0) + << std::endl; + std::cout << "ConsumeCachedProbeBatches: std::this_thread::get_id() = " << std::hex + << std::showbase << std::this_thread::get_id() << std::dec + << std::noshowbase; + ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_() + << " len:" << cached_probe_batches.size(); + std::cout << " get_thread_index_() = " << get_thread_index_(); + std::cout << std::endl; + + // acquire the mutex to access cached_probe_batches, because while consuming, other + // batches should not be cached! 
+ std::lock_guard lck(cached_probe_batches_mutex); + + if (!cached_probe_batches_consumed) { + auto executor = ctx_->executor(); + for (auto&& cached : cached_probe_batches) { + if (executor) { + Status lambda_status; + RETURN_NOT_OK(executor->Spawn([&] { + lambda_status = ConsumeProbeBatch(cached.first, std::move(cached.second)); + })); + + // if the lambda execution failed internally, return status + RETURN_NOT_OK(lambda_status); + } else { + RETURN_NOT_OK(ConsumeProbeBatch(cached.first, std::move(cached.second))); + } + } + // cached vector will be cleared. exec batches are expected to be moved to the + // lambdas + cached_probe_batches.clear(); + } + + // set flag + cached_probe_batches_consumed = true; + return Status::OK(); + } + + Status GenerateOutput(int seq, const ArrayData& group_ids_data, ExecBatch batch) { + if (group_ids_data.GetNullCount() == batch.length) { + // All NULLS! hence, there are no valid outputs! + ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { // all values are valid for output + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } + + return Status::OK(); + } + + // consumes a probe batch and increment probe batches count. Probing would query the + // grouper[build_result_index] which have been merged with all others. + Status ConsumeProbeBatch(int seq, ExecBatch batch) { + ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:" << seq; + + auto& final_grouper = *local_states_[build_result_index].grouper; + + // Create a batch with key columns + std::vector keys(probe_index_field_ids_.size()); + for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) { + keys[i] = batch.values[probe_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Query the grouper with key_batch. If no match was found, returning group_ids would + // have null. + ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch)); + auto group_ids_data = *group_ids.array(); + + RETURN_NOT_OK(GenerateOutput(seq, group_ids_data, std::move(batch))); + + if (out_counter_.Increment()) { + finished_.MarkFinished(); + } + return Status::OK(); + } + + // Attempt to cache a probe batch. If it is not cached, return false. + // if cached_probe_batches_consumed is true, by the time a thread acquires + // cached_probe_batches_mutex, it should no longer be cached! instead, it can be + // directly consumed! + bool AttemptToCacheProbeBatch(int seq_num, ExecBatch* batch) { + auto executor = ::arrow::internal::GetCpuThreadPool(); + std::cout << " executor->OwnsThisThread() = " << (executor->OwnsThisThread() ? 
1 : 0) + << std::endl; + std::cout << "AttemptToCacheProbeBatch: std::this_thread::get_id() = " << std::hex + << std::showbase << std::this_thread::get_id() << std::dec + << std::noshowbase; + ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " seq:" << seq_num + << " len:" << batch->length; + std::cout << " get_thread_index_() = " << get_thread_index_(); + std::cout << std::endl; + std::lock_guard lck(cached_probe_batches_mutex); + if (cached_probe_batches_consumed) { + return false; + } + cached_probe_batches.emplace_back(seq_num, std::move(*batch)); + return true; + } + + inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } + + // If all build side batches received? continue streaming using probing + // else cache the batches in thread-local state + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + std::cout << "InputReceived " << seq << std::endl; + + ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? "b" : "p") + << " seq:" << seq << " len:" << batch.length; + + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + if (finished_.is_finished()) { + return; + } + + if (IsBuildInput(input)) { // build input batch is received + // if a build input is received when build side is completed, something's wrong! + ARROW_DCHECK(!hash_table_built_); + + ErrorIfNotOk(ConsumeBuildBatch(std::move(batch))); + } else { // probe input batch is received + if (hash_table_built_) { + // build side done, continue with probing. when hash_table_built_ is set, it is + // guaranteed that some thread has already called the ConsumeCachedProbeBatches + + // consume this probe batch + ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); + } else { // build side not completed. Cache this batch! + if (!AttemptToCacheProbeBatch(seq, &batch)) { + // if the cache attempt fails, consume the batch + ErrorIfNotOk(ConsumeProbeBatch(seq, std::move(batch))); + } + } + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + ARROW_LOG(DEBUG) << "error received " << error.ToString(); + DCHECK_EQ(input, inputs_[0]); + + outputs_[0]->ErrorReceived(this, std::move(error)); + StopProducing(); + } + + void InputFinished(ExecNode* input, int num_total) override { + ARROW_LOG(DEBUG) << "input finished input:" << (IsBuildInput(input) ? "b" : "p") + << " tot:" << num_total; + + // bail if StopProducing was called + if (finished_.is_finished()) return; + + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + // set total for build input + if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. + ErrorIfNotOk(BuildSideCompleted()); + return; + } + + // set total for probe input. If it returns that probe side has completed, nothing to + // do, because probing inputs will be streamed to the output + // probe_counter_.SetTotal(num_total); + + // output will be streamed from the probe side. So, they will have the same total. + if (out_counter_.SetTotal(num_total)) { + // if out_counter has completed, the future is finished! 
+ ErrorIfNotOk(ConsumeCachedProbeBatches()); + outputs_[0]->InputFinished(this, num_total); + finished_.MarkFinished(); + } else { + outputs_[0]->InputFinished(this, num_total); + } + } + + Status StartProducing() override { + std::cout << "Start Producing" << std::endl; + ARROW_LOG(DEBUG) << "start prod"; + finished_ = Future<>::Make(); + + local_states_.resize(ThreadIndexer::Capacity()); + return Status::OK(); + } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + ARROW_LOG(DEBUG) << "stop prod from node"; + + DCHECK_EQ(output, outputs_[0]); + + if (build_counter_.Cancel()) { + finished_.MarkFinished(); + } else if (out_counter_.Cancel()) { + finished_.MarkFinished(); + } + + for (auto&& input : inputs_) { + input->StopProducing(this); + } + } + + // TODO(niranda) couldn't there be multiple outputs for a Node? + void StopProducing() override { + ARROW_LOG(DEBUG) << "stop prod "; + outputs_[0]->StopProducing(); + } + + Future<> finished() override { + ARROW_LOG(DEBUG) << "finished? " << finished_.is_finished(); + return finished_; + } + + private: + struct ThreadLocalState { + std::unique_ptr grouper; + }; + + ExecContext* ctx_; + Future<> finished_ = Future<>::MakeFinished(); + + ThreadIndexer get_thread_index_; + const std::vector build_index_field_ids_, probe_index_field_ids_; + + AtomicCounter build_counter_, out_counter_; + std::vector local_states_; + + // There's no guarantee on which threads would be coming from the build side. so, out of + // the thread local states, an appropriate state needs to be chosen to accumulate + // all built results (ideally, the grouper which has the highest # of elems) + int32_t build_result_index; + + // need a separate bool to track if the build side complete. Can't use the flag + // inside the AtomicCounter, because we need to merge the build groupers once we receive + // all the build batches. So, while merging, we need to prevent probe batches, being + // consumed. + bool hash_table_built_; + + std::mutex cached_probe_batches_mutex; + std::vector> cached_probe_batches{}; + // a flag is required to indicate if the cached probe batches have already been + // consumed! if cached_probe_batches_consumed is true, by the time a thread aquires + // cached_probe_batches_mutex, it should no longer be cached! instead, it can be + // directly consumed! + bool cached_probe_batches_consumed; +}; + +// template specialization for anti joins. For anti joins, group_ids_data needs to be +// inverted. Output will be taken for indices which are NULL +template <> +Status HashSemiJoinNode::GenerateOutput(int seq, const ArrayData& group_ids_data, + ExecBatch batch) { + if (group_ids_data.GetNullCount() == group_ids_data.length) { + // All NULLS! 
hence, all values are valid for output + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << batch.length; + outputs_[0]->InputReceived(this, seq, std::move(batch)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + // invert the validity buffer + arrow::internal::InvertBitmap( + group_ids_data.buffers[0]->data(), group_ids_data.offset, group_ids_data.length, + group_ids_data.buffers[0]->mutable_data(), group_ids_data.offset); + + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + ARROW_LOG(DEBUG) << "output seq:" << seq << " " << out_batch.length; + outputs_[0]->InputReceived(this, seq, std::move(out_batch)); + } else { + // No NULLS! hence, there are no valid outputs! + ARROW_LOG(DEBUG) << "output seq:" << seq << " 0"; + outputs_[0]->InputReceived(this, seq, batch.Slice(0, 0)); + } + return Status::OK(); +} + +template +Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, + const std::vector& build_keys, + const std::vector& probe_keys) { + auto build_schema = build_input->output_schema(); + auto probe_schema = probe_input->output_schema(); + + ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); + ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); + + RETURN_NOT_OK( + ValidateJoinInputs(build_schema, probe_schema, build_key_ids, probe_key_ids)); + + // output schema will be probe schema + auto ctx = build_input->plan()->exec_context(); + ExecPlan* plan = build_input->plan(); + + return plan->EmplaceNode>( + build_input, probe_input, ctx, std::move(build_key_ids), std::move(probe_key_ids)); +} + +ExecFactoryRegistry::AddOnLoad kRegisterHashJoin( + "hash_join", + [](ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) -> Result { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode")); + + const auto& join_options = checked_cast(options); + + static std::string join_type_string[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + + auto join_type = join_options.join_type; + + ExecNode* left_input = inputs[0]; + ExecNode* right_input = inputs[1]; + const auto& left_keys = join_options.left_keys; + const auto& right_keys = join_options.right_keys; + + switch (join_type) { + case LEFT_SEMI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, right_keys, left_keys); + case RIGHT_SEMI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, right_keys); + case LEFT_ANTI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, right_keys, + left_keys); + case RIGHT_ANTI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, + right_keys); + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + return Status::NotImplemented(join_type_string[join_type] + + " joins not implemented!"); + default: + return 
Status::Invalid("invalid join type"); + } + }); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc new file mode 100644 index 00000000000..91daa7d5295 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node_benchmark.cc @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc new file mode 100644 index 00000000000..56483e1860d --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
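The tests that follow drive the registered "hash_join" factory through the Declaration API; for readers coming from the node code above, a trimmed sketch of that wiring (BatchesWithSchema comes from test_util.h later in this patch; the helper name and key names are illustrative):

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/result.h"

// Build source -> hash_join -> sink, mirroring CheckRunOutput in the test below.
arrow::Status BuildSemiJoinPlanSketch(const arrow::compute::BatchesWithSchema& l_batches,
                                      const arrow::compute::BatchesWithSchema& r_batches) {
  using namespace arrow::compute;  // NOLINT

  ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make());

  std::vector<FieldRef> left_keys{{"l_str"}}, right_keys{{"r_str"}};
  JoinNodeOptions join_options{JoinType::LEFT_SEMI, left_keys, right_keys};
  Declaration join{"hash_join", join_options};

  // The first input is the left side, the second the right side.
  join.inputs.emplace_back(Declaration{
      "source", SourceNodeOptions{l_batches.schema,
                                  l_batches.gen(/*parallel=*/false, /*slow=*/false)}});
  join.inputs.emplace_back(Declaration{
      "source", SourceNodeOptions{r_batches.schema,
                                  r_batches.gen(/*parallel=*/false, /*slow=*/false)}});

  arrow::AsyncGenerator<arrow::util::optional<ExecBatch>> sink_gen;
  RETURN_NOT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
                    .AddToPlan(plan.get()));
  // From here the plan is started and sink_gen drained, exactly as StartAndCollect does.
  return arrow::Status::OK();
}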
+ +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +void GenerateBatchesFromString(const std::shared_ptr& schema, + const std::vector& json_strings, + BatchesWithSchema* out_batches, int multiplicity = 1) { + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches->batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + size_t batch_count = out_batches->batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches->batches.push_back(out_batches->batches[i]); + } + } + + out_batches->schema = schema; +} + +void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const std::vector& left_keys, + const std::vector& right_keys, + const BatchesWithSchema& exp_batches, bool parallel = false) { + SCOPED_TRACE("serial"); + + ExecContext parallel_exec_context(default_memory_pool(), + ::arrow::internal::GetCpuThreadPool()); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(parallel ? ¶llel_exec_context + : default_exec_context())); + + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); +} + +void RunNonEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + int multiplicity = parallel ? 
100 : 1; + + GenerateBatchesFromString(l_schema, + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &l_batches, multiplicity); + + GenerateBatchesFromString( + r_schema, + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + &r_batches, multiplicity); + + switch (type) { + case LEFT_SEMI: + GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + &exp_batches, multiplicity); + break; + case RIGHT_SEMI: + GenerateBatchesFromString( + r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, + &exp_batches, multiplicity); + break; + case LEFT_ANTI: + GenerateBatchesFromString( + l_schema, {R"([[0,"d"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([])"}, + &exp_batches, multiplicity); + break; + case RIGHT_ANTI: + GenerateBatchesFromString(r_schema, {R"([["f", 0]])", R"([["g", 4]])", R"([])"}, + &exp_batches, multiplicity); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } + + CheckRunOutput(type, l_batches, r_batches, + /*left_keys=*/{{"l_str"}}, /*right_keys=*/{{"r_str"}}, exp_batches, + parallel); +} + +void RunEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + + int multiplicity = parallel ? 100 : 1; + + BatchesWithSchema l_empty, r_empty, l_n_empty, r_n_empty; + + GenerateBatchesFromString(l_schema, {R"([])"}, &l_empty, multiplicity); + GenerateBatchesFromString(r_schema, {R"([])"}, &r_empty, multiplicity); + + GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, &l_n_empty, + multiplicity); + GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, &r_n_empty, + multiplicity); + + std::vector l_keys{{"l_str"}}; + std::vector r_keys{{"r_str"}}; + + switch (type) { + case LEFT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_empty, parallel); + break; + case LEFT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_n_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_n_empty, parallel); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } +} + +class HashJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P( + HashJoinTest, HashJoinTest, + 
::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI), + ::testing::Values(false, true))); + +TEST_P(HashJoinTest, TestSemiJoins) { + RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +TEST_P(HashJoinTest, TestSemiJoinstEmpty) { + RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +void TestJoinRandom(const std::shared_ptr& data_type, JoinType type, + bool parallel, int num_batches, int batch_size) { + auto l_schema = schema({field("l0", data_type), field("l1", data_type)}); + auto r_schema = schema({field("r0", data_type), field("r1", data_type)}); + + // generate data + auto l_batches = MakeRandomBatches(l_schema, num_batches, batch_size); + auto r_batches = MakeRandomBatches(r_schema, num_batches, batch_size); + + std::vector left_keys{{"l0"}}; + std::vector right_keys{{"r1"}}; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + // TODO(niranda) add a verification step for res +} + +class HashJoinTestRand : public testing::TestWithParam< + std::tuple, JoinType, bool>> {}; + +static constexpr int kNumBatches = 1000; +static constexpr int kBatchSize = 100; + +INSTANTIATE_TEST_SUITE_P( + HashJoinTestRand, HashJoinTestRand, + ::testing::Combine(::testing::Values(int8(), int32(), int64(), float32(), float64()), + ::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI), + ::testing::Values(false, true))); + +TEST_P(HashJoinTestRand, TestingTypes) { + TestJoinRandom(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam()), kNumBatches, kBatchSize); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index acc79bdfdde..b103120b4d2 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -126,5 +126,32 @@ class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { SortOptions sort_options; }; +enum JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, // Not Implemented + LEFT_OUTER, // Not Implemented + RIGHT_OUTER, // Not Implemented + FULL_OUTER // Not Implemented +}; + +class ARROW_EXPORT JoinNodeOptions : public ExecNodeOptions { + public: + JoinNodeOptions(JoinType join_type, std::vector left_keys, + std::vector right_keys) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)) {} + + // type of the join + JoinType join_type; + + // index keys of the join + std::vector left_keys; + std::vector right_keys; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 96bdf746dc8..0c7e7cb39e4 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ 
-18,7 +18,6 @@ #include #include -#include #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" @@ -31,11 +30,8 @@ #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" -#include "arrow/testing/random.h" #include "arrow/util/async_generator.h" #include "arrow/util/logging.h" -#include "arrow/util/thread_pool.h" -#include "arrow/util/vector.h" using testing::ElementsAre; using testing::ElementsAreArray; @@ -198,87 +194,6 @@ TEST(ExecPlan, DummyStartProducingError) { ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink")); } -namespace { - -struct BatchesWithSchema { - std::vector batches; - std::shared_ptr schema; - - AsyncGenerator> gen(bool parallel, bool slow) const { - DCHECK_GT(batches.size(), 0); - - auto opt_batches = ::arrow::internal::MapVector( - [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches); - - AsyncGenerator> gen; - - if (parallel) { - // emulate batches completing initial decode-after-scan on a cpu thread - gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), - ::arrow::internal::GetCpuThreadPool()) - .ValueOrDie(); - - // ensure that callbacks are not executed immediately on a background thread - gen = - MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool()); - } else { - gen = MakeVectorGenerator(std::move(opt_batches)); - } - - if (slow) { - gen = - MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { - SleepABit(); - return batch; - }); - } - - return gen; - } -}; - -Future> StartAndCollect( - ExecPlan* plan, AsyncGenerator> gen) { - RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); - - auto collected_fut = CollectAsyncGenerator(gen); - - return AllComplete({plan->finished(), Future<>(collected_fut)}) - .Then([collected_fut]() -> Result> { - ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); - return ::arrow::internal::MapVector( - [](util::optional batch) { return std::move(*batch); }, - std::move(collected)); - }); -} - -BatchesWithSchema MakeBasicBatches() { - BatchesWithSchema out; - out.batches = { - ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), - ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; - out.schema = schema({field("i32", int32()), field("bool", boolean())}); - return out; -} - -BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, - int num_batches = 10, int batch_size = 4) { - BatchesWithSchema out; - out.schema = schema; - - random::RandomArrayGenerator rng(42); - out.batches.resize(num_batches); - - for (int i = 0; i < num_batches; ++i) { - out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); - // add a tag scalar to ensure the batches are unique - out.batches[i].values.emplace_back(i); - } - return out; -} -} // namespace - TEST(ExecPlanExecution, SourceSink) { for (bool slow : {false, true}) { SCOPED_TRACE(slow ? "slowed" : "unslowed"); diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index 4d9f82e582b..bdbc0b0fa00 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -16,12 +16,11 @@ // specific language governing permissions and limitations // under the License. 
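The helpers removed from plan_test.cc above are relocated essentially unchanged into test_util (see the test_util.cc/test_util.h hunks below), so call sites keep the same shape; for example, draining a plan in a test still reads as follows, with plan and sink_gen wired up as in the join tests above:

// Unchanged usage pattern for the relocated helper.
ASSERT_FINISHES_OK_AND_ASSIGN(std::vector<arrow::compute::ExecBatch> results,
                              arrow::compute::StartAndCollect(plan.get(), sink_gen));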
-#include "arrow/compute/exec/exec_plan.h" - #include #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 49f089f3459..46905f25ac8 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -35,6 +35,7 @@ #include "arrow/datum.h" #include "arrow/record_batch.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" @@ -155,5 +156,47 @@ ExecBatch ExecBatchFromJSON(const std::vector& descrs, return batch; } +Future> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + + auto collected_fut = CollectAsyncGenerator(gen); + + return AllComplete({plan->finished(), Future<>(collected_fut)}) + .Then([collected_fut]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); + return ::arrow::internal::MapVector( + [](util::optional batch) { return std::move(*batch); }, + std::move(collected)); + }); +} + +BatchesWithSchema MakeBasicBatches() { + BatchesWithSchema out; + out.batches = { + ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), + ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; + out.schema = schema({field("i32", int32()), field("bool", boolean())}); + return out; +} + +BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, + int num_batches, int batch_size) { + BatchesWithSchema out; + + random::RandomArrayGenerator rng(42); + out.batches.resize(num_batches); + + for (int i = 0; i < num_batches; ++i) { + out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); + // add a tag scalar to ensure the batches are unique + out.batches[i].values.emplace_back(i); + } + + out.schema = schema; + return out; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index faa395bab78..e21dfd673ec 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -17,6 +17,9 @@ #pragma once +#include +#include + #include #include #include @@ -24,6 +27,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/testing/visibility.h" +#include "arrow/util/async_generator.h" #include "arrow/util/string_view.h" namespace arrow { @@ -41,5 +45,53 @@ ARROW_TESTING_EXPORT ExecBatch ExecBatchFromJSON(const std::vector& descrs, util::string_view json); +struct BatchesWithSchema { + std::vector batches; + std::shared_ptr schema; + + AsyncGenerator> gen(bool parallel, bool slow) const { + DCHECK_GT(batches.size(), 0); + + auto opt_batches = ::arrow::internal::MapVector( + [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches); + + AsyncGenerator> gen; + + if (parallel) { + // emulate batches completing initial decode-after-scan on a cpu thread + gen = MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), + ::arrow::internal::GetCpuThreadPool()) + .ValueOrDie(); + + // ensure that callbacks are not executed immediately on a background thread + gen = + MakeTransferredGenerator(std::move(gen), 
::arrow::internal::GetCpuThreadPool()); + } else { + gen = MakeVectorGenerator(std::move(opt_batches)); + } + + if (slow) { + gen = + MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { + SleepABit(); + return batch; + }); + } + + return gen; + } +}; + +ARROW_TESTING_EXPORT +Future> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen); + +ARROW_TESTING_EXPORT +BatchesWithSchema MakeBasicBatches(); + +ARROW_TESTING_EXPORT +BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, + int num_batches = 10, int batch_size = 4); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index aad6dc3d587..57a21665b28 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -21,6 +21,7 @@ #include "arrow/table.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/ubsan.h" namespace arrow { @@ -205,8 +206,8 @@ void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits, constexpr int unroll = 8; for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { uint8_t bits_next = bits[i]; - // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart - // from the previous. + // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits + // apart from the previous. uint64_t unpacked = static_cast(bits_next & 0xfe) * ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); @@ -307,5 +308,26 @@ Result> TableFromExecBatches( return Table::FromRecordBatches(schema, batches); } +size_t ThreadIndexer::operator()() { + auto id = std::this_thread::get_id(); + + auto guard = mutex_.Lock(); // acquire the lock + const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; + + return Check(id_index.second); +} + +size_t ThreadIndexer::Capacity() { + static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); + return max_size; +} + +size_t ThreadIndexer::Check(size_t thread_index) { + DCHECK_LT(thread_index, Capacity()) + << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")"; + + return thread_index; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 8bd6a3c5d62..24306329bd6 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include "arrow/buffer.h" @@ -29,6 +31,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" +#include "arrow/util/mutex.h" #include "arrow/util/optional.h" #if defined(__clang__) || defined(__GNUC__) @@ -233,5 +236,18 @@ class AtomicCounter { std::atomic complete_{false}; }; +class ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + util::Mutex mutex_; + std::unordered_map id_to_index_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 4fd6af9b190..09f832f2740 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
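ThreadIndexer, previously private to aggregate_node.cc, is promoted into compute/exec/util above so that other nodes (notably the hash join in this patch) can keep an array of per-thread states. A small sketch of the intended pattern follows; LocalState and SketchNode are illustrative, not part of the patch.

#include <cstdint>
#include <vector>

#include "arrow/compute/exec.h"
#include "arrow/compute/exec/util.h"
#include "arrow/status.h"

// Per-thread-state pattern enabled by the promoted ThreadIndexer.
struct LocalState {
  int64_t rows_seen = 0;
};

class SketchNode {
 public:
  arrow::Status StartProducing() {
    // One slot per possible CPU-pool thread; indices are handed out lazily.
    local_states_.resize(arrow::compute::ThreadIndexer::Capacity());
    return arrow::Status::OK();
  }

  void Consume(const arrow::compute::ExecBatch& batch) {
    // Maps std::this_thread::get_id() to a stable small index, so each thread
    // can mutate its own slot without further locking.
    size_t thread_index = get_thread_index_();
    local_states_[thread_index].rows_seen += batch.length;
  }

 private:
  arrow::compute::ThreadIndexer get_thread_index_;
  std::vector<LocalState> local_states_;
};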
+#include + #include #include #include @@ -369,30 +371,43 @@ struct GrouperImpl : Grouper { return std::move(impl); } - Result Consume(const ExecBatch& batch) override { - std::vector offsets_batch(batch.length + 1); + Status PopulateKeyData(const ExecBatch& batch, std::vector* offsets_batch, + std::vector* key_bytes_batch, + std::vector* key_buf_ptrs) const { + offsets_batch->resize(batch.length + 1); for (int i = 0; i < batch.num_values(); ++i) { - encoders_[i]->AddLength(*batch[i].array(), offsets_batch.data()); + encoders_[i]->AddLength(*batch[i].array(), offsets_batch->data()); } int32_t total_length = 0; for (int64_t i = 0; i < batch.length; ++i) { auto total_length_before = total_length; - total_length += offsets_batch[i]; - offsets_batch[i] = total_length_before; + total_length += offsets_batch->at(i); + offsets_batch->at(i) = total_length_before; } - offsets_batch[batch.length] = total_length; + offsets_batch->at(batch.length) = total_length; - std::vector key_bytes_batch(total_length); - std::vector key_buf_ptrs(batch.length); + key_bytes_batch->resize(total_length); + key_buf_ptrs->resize(batch.length); for (int64_t i = 0; i < batch.length; ++i) { - key_buf_ptrs[i] = key_bytes_batch.data() + offsets_batch[i]; + key_buf_ptrs->at(i) = key_bytes_batch->data() + offsets_batch->at(i); } for (int i = 0; i < batch.num_values(); ++i) { - RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs.data())); + RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs->data())); } + return Status::OK(); + } + + Result Consume(const ExecBatch& batch) override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + TypedBufferBuilder group_ids_batch(ctx_->memory_pool()); RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); @@ -421,6 +436,36 @@ struct GrouperImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } + Result Find(const ExecBatch& batch) const override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + + UInt32Builder group_ids_batch(ctx_->memory_pool()); + RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); + + for (int64_t i = 0; i < batch.length; ++i) { + int32_t key_length = offsets_batch[i + 1] - offsets_batch[i]; + std::string key( + reinterpret_cast(key_bytes_batch.data() + offsets_batch[i]), + key_length); + + auto it = map_.find(key); + // no group ID was found, null will be emitted! 
+ if (it == map_.end()) { + group_ids_batch.UnsafeAppendNull(); + } else { + group_ids_batch.UnsafeAppend(it->second); + } + } + + ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish()); + return Datum(group_ids); + } + uint32_t num_groups() const override { return num_groups_; } Result GetUniques() override { @@ -622,6 +667,11 @@ struct GrouperFastImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } + Result Find(const ExecBatch& batch) const override { + // todo impl this + return Result(); + } + uint32_t num_groups() const override { return static_cast(rows_.length()); } // Make sure padded buffers end up with the right logical size @@ -1929,9 +1979,10 @@ Result ResolveKernels( Result> Grouper::Make(const std::vector& descrs, ExecContext* ctx) { - if (GrouperFastImpl::CanUse(descrs)) { - return GrouperFastImpl::Make(descrs, ctx); - } + // TODO(niranda) re-enable this! + // if (GrouperFastImpl::CanUse(descrs)) { + // return GrouperFastImpl::Make(descrs, ctx); + // } return GrouperImpl::Make(descrs, ctx); } diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index c69b51e71fc..b392f2db4bf 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -308,6 +308,16 @@ struct TestGrouper { AssertEquivalentIds(expected, ids); } + void ExpectFind(const std::string& key_json, const std::string& expected) { + ExpectFind(ExecBatchFromJSON(descrs_, key_json), ArrayFromJSON(uint32(), expected)); + } + void ExpectFind(const ExecBatch& key_batch, Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Find(key_batch)); + ValidateOutput(id_batch); + + AssertDatumsEqual(expected, id_batch); + } + void AssertEquivalentIds(const Datum& expected, const Datum& actual) { auto left = expected.make_array(); auto right = actual.make_array(); @@ -433,6 +443,18 @@ TEST(Grouper, NumericKey) { g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]", "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[27], [81]]", "[1, 2]"); + + g.ExpectFind("[[3], [27], [3], [27], [null], [81], [27], [81]]", + "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[27], [3], [27], [null], [81], [1], [81]]", + "[1, 0, 1, 3, 2, null, 2]"); } }
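The ExpectFind cases above pin down the property the join node depends on: the validity bitmap of Find's result encodes build-side membership. For completeness, a sketch of how that bitmap becomes a semi- or anti-join selection mask, mirroring the two GenerateOutput variants in hash_join_node.cc; it assumes the result actually carries a validity bitmap (the all-matched and none-matched fast paths handled there are ignored), and the helper name is illustrative.

#include <memory>

#include "arrow/api.h"
#include "arrow/util/bitmap_ops.h"

// Turn Grouper::Find output into a boolean selection mask. Semi joins keep the
// matched (valid) rows; anti joins keep the unmatched rows, so the validity
// bits are inverted first, as in the ANTI specialization of GenerateOutput.
arrow::Result<std::shared_ptr<arrow::BooleanArray>> SelectionFromFind(
    const arrow::Datum& group_ids, bool anti) {
  const std::shared_ptr<arrow::ArrayData>& ids = group_ids.array();
  std::shared_ptr<arrow::Buffer> validity = ids->buffers[0];
  int64_t offset = ids->offset;
  if (anti) {
    ARROW_ASSIGN_OR_RAISE(
        validity, arrow::internal::InvertBitmap(arrow::default_memory_pool(),
                                                validity->data(), offset, ids->length));
    offset = 0;  // the freshly allocated bitmap starts at bit 0
  }
  // Reinterpret the (possibly inverted) bitmap as the values of a boolean array,
  // suitable for passing to Filter with null_selection = DROP.
  return std::make_shared<arrow::BooleanArray>(ids->length, validity,
                                               /*null_bitmap=*/nullptr,
                                               /*null_count=*/0, offset);
}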