diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e06fad9a1de..2344e4a8464 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -412,6 +412,7 @@ if(ARROW_COMPUTE)
       compute/kernels/vector_selection.cc
       compute/kernels/vector_sort.cc
       compute/exec/union_node.cc
+      compute/exec/hash_join_node.cc
       compute/exec/key_hash.cc
       compute/exec/key_map.cc
       compute/exec/key_compare.cc
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index c8df81773d4..e4e8646de70 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -410,6 +410,11 @@ class ARROW_EXPORT Grouper {
   /// be as wide as necessary.
   virtual Result<Datum> Consume(const ExecBatch& batch) = 0;
 
+  /// Queries the group id of each row in the given ExecBatch, returning the ids as an
+  /// integer array. If a key has not been seen by Consume, a null is emitted at that
+  /// index. This lookup is thread-safe.
+  virtual Result<Datum> Find(const ExecBatch& batch) = 0;
+
   /// Get current unique keys. May be called multiple times.
   virtual Result<ExecBatch> GetUniques() = 0;
 
diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt
index f269ec33ed5..0d1c0f6df29 100644
--- a/cpp/src/arrow/compute/exec/CMakeLists.txt
+++ b/cpp/src/arrow/compute/exec/CMakeLists.txt
@@ -25,6 +25,7 @@ add_arrow_compute_test(expression_test
                        subtree_test.cc)
 
 add_arrow_compute_test(plan_test PREFIX "arrow-compute")
+add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute")
 add_arrow_compute_test(union_node_test PREFIX "arrow-compute")
 
 add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute")
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index a53c6532591..5c64cf2fc30 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -356,6 +356,7 @@ void RegisterProjectNode(ExecFactoryRegistry*);
 void RegisterUnionNode(ExecFactoryRegistry*);
 void RegisterAggregateNode(ExecFactoryRegistry*);
 void RegisterSinkNode(ExecFactoryRegistry*);
+void RegisterHashJoinNode(ExecFactoryRegistry*);
 
 }  // namespace internal
 
@@ -369,6 +370,7 @@ ExecFactoryRegistry* default_exec_factory_registry() {
     internal::RegisterUnionNode(this);
     internal::RegisterAggregateNode(this);
     internal::RegisterSinkNode(this);
+    internal::RegisterHashJoinNode(this);
   }
 
   Result GetFactory(const std::string& factory_name) override {
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
new file mode 100644
index 00000000000..9a867081f1b
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include + +#include "arrow/api.h" +#include "arrow/compute/api.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +namespace { +Status ValidateJoinInputs(const std::shared_ptr& left_schema, + const std::shared_ptr& right_schema, + const std::vector& left_keys, + const std::vector& right_keys) { + if (left_keys.empty() || right_keys.empty()) { + return Status::Invalid("left and right key sizes can not be empty"); + } + + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("left and right key sizes do not match"); + } + + for (size_t i = 0; i < left_keys.size(); i++) { + auto l_type = left_schema->field(left_keys[i])->type(); + auto r_type = right_schema->field(right_keys[i])->type(); + + if (!l_type->Equals(r_type)) { + return Status::Invalid("build and probe types do not match: " + l_type->ToString() + + "!=" + r_type->ToString()); + } + } + + return Status::OK(); +} + +Result> PopulateKeys(const Schema& schema, + const std::vector& keys) { + std::vector key_field_ids(keys.size()); + // Find input field indices for key fields + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema)); + key_field_ids[i] = match[0]; + } + return key_field_ids; +} + +Result MakeEmptyExecBatch(const std::shared_ptr& schema, + MemoryPool* pool) { + std::vector values; + values.reserve(schema->num_fields()); + + for (const auto& field : schema->fields()) { + ARROW_ASSIGN_OR_RAISE(auto arr, MakeArrayOfNull(field->type(), 0, pool)); + values.emplace_back(arr); + } + + return ExecBatch{std::move(values), 0}; +} + +} // namespace + +template +class HashSemiJoinNode : public ExecNode { + private: + struct ThreadLocalState; + + public: + HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext* ctx, + const std::vector&& build_index_field_ids, + const std::vector&& probe_index_field_ids) + : ExecNode(build_input->plan(), {build_input, probe_input}, + {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(), + /*num_outputs=*/1), + ctx_(ctx), + build_index_field_ids_(build_index_field_ids), + probe_index_field_ids_(probe_index_field_ids), + build_result_index(-1), + hash_table_built_(false), + cached_probe_batches_consumed(false) {} + + const char* kind_name() const override { return "HashSemiJoinNode"; } + + // If all build side batches received, continue streaming using probing + // else cache the batches in thread-local state + void InputReceived(ExecNode* input, ExecBatch batch) override { + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + if (finished_.is_finished()) { + return; + } + + if (IsBuildInput(input)) { // build input batch is received + // if a build input is received when build side is completed, something's wrong! + ARROW_DCHECK(!hash_table_built_); + + ErrorIfNotOk(ConsumeBuildBatch(std::move(batch))); + } else { // probe input batch is received + if (hash_table_built_) { + // build side done, continue with probing. when hash_table_built_ is set, it is + // guaranteed that some thread has already called the ConsumeCachedProbeBatches + + // consume this probe batch + ErrorIfNotOk(ConsumeProbeBatch(std::move(batch))); + } else { // build side not completed. 
Cache this batch! + if (!AttemptToCacheProbeBatch(&batch)) { + // if the cache attempt fails, consume the batch + ErrorIfNotOk(ConsumeProbeBatch(std::move(batch))); + } + } + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + + outputs_[0]->ErrorReceived(this, std::move(error)); + StopProducing(); + } + + void InputFinished(ExecNode* input, int num_total) override { + // bail if StopProducing was called + if (finished_.is_finished()) return; + + ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]); + + // set total for build input + if (IsBuildInput(input) && build_counter_.SetTotal(num_total)) { + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. + ErrorIfNotOk(BuildSideCompleted()); + return; + } + + // output will be streamed from the probe side. So, they will have the same total. + if (out_counter_.SetTotal(num_total)) { + // if out_counter has completed, the future is finished! + ErrorIfNotOk(ConsumeCachedProbeBatches()); + outputs_[0]->InputFinished(this, num_total); + finished_.MarkFinished(); + } else { + outputs_[0]->InputFinished(this, num_total); + } + } + + Status StartProducing() override { + finished_ = Future<>::Make(); + + local_states_.resize(ThreadIndexer::Capacity()); + return Status::OK(); + } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + + if (build_counter_.Cancel()) { + finished_.MarkFinished(); + } else if (out_counter_.Cancel()) { + finished_.MarkFinished(); + } + + for (auto&& input : inputs_) { + input->StopProducing(this); + } + } + + void StopProducing() override { outputs_[0]->StopProducing(); } + + Future<> finished() override { return finished_; } + + private: + Status InitLocalStateIfNeeded(ThreadLocalState* state) { + // Get input schema + auto build_schema = inputs_[0]->output_schema(); + + if (state->grouper != nullptr) return Status::OK(); + + // Build vector of key field data types + std::vector key_descrs(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + auto build_type = build_schema->field(build_index_field_ids_[i])->type(); + key_descrs[i] = ValueDescr(build_type); + } + + // Construct grouper + ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_)); + + return Status::OK(); + } + + // Finds an appropriate index which could accumulate all build indices (i.e. the grouper + // which has the highest # of groups) + void CalculateBuildResultIndex() { + int32_t curr_max = -1; + for (int i = 0; i < static_cast(local_states_.size()); i++) { + auto* state = &local_states_[i]; + ARROW_DCHECK(state); + if (state->grouper && + curr_max < static_cast(state->grouper->num_groups())) { + curr_max = static_cast(state->grouper->num_groups()); + build_result_index = i; + } + } + DCHECK_GT(build_result_index, -1); + } + + // Performs the housekeeping work after the build-side is completed. + // Note: this method is not thread safe, and hence should be guaranteed that it is + // not accessed concurrently! 
+ Status BuildSideCompleted() { + // if the hash table has already been built, return + if (hash_table_built_) return Status::OK(); + + CalculateBuildResultIndex(); + + // merge every group into the build_result_index grouper + ThreadLocalState* result_state = &local_states_[build_result_index]; + for (int i = 0; i < static_cast(local_states_.size()); ++i) { + ThreadLocalState* state = &local_states_[i]; + ARROW_DCHECK(state); + if (i == build_result_index || !state->grouper) { + continue; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); + + // TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, result_state->grouper->Consume(other_keys)); + state->grouper.reset(); + } + + // enable flag that build side is completed + hash_table_built_ = true; + + // since the build side is completed, consume cached probe batches + RETURN_NOT_OK(ConsumeCachedProbeBatches()); + + return Status::OK(); + } + + // consumes a build batch and increments the build_batches count. if the build batches + // total reached at the end of consumption, all the local states will be merged, before + // incrementing the total batches + Status ConsumeBuildBatch(ExecBatch batch) { + size_t thread_index = get_thread_index_(); + ARROW_DCHECK(thread_index < local_states_.size()); + + auto state = &local_states_[thread_index]; + RETURN_NOT_OK(InitLocalStateIfNeeded(state)); + + // Create a batch with key columns + std::vector keys(build_index_field_ids_.size()); + for (size_t i = 0; i < build_index_field_ids_.size(); ++i) { + keys[i] = batch.values[build_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Create a batch with group ids + // TODO(niranda) replace with void consume method + ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch)); + + if (build_counter_.Increment()) { + // only one thread would get inside this block! + // while incrementing, if the total is reached, call BuildSideCompleted. + RETURN_NOT_OK(BuildSideCompleted()); + } + + return Status::OK(); + } + + // consumes cached probe batches by invoking executor::Spawn. + Status ConsumeCachedProbeBatches() { + // acquire the mutex to access cached_probe_batches, because while consuming, other + // batches should not be cached! + std::lock_guard lck(cached_probe_batches_mutex); + + if (!cached_probe_batches_consumed) { + auto executor = ctx_->executor(); + + while (!cached_probe_batches.empty()) { + ExecBatch cached = std::move(cached_probe_batches.back()); + cached_probe_batches.pop_back(); + + if (executor) { + RETURN_NOT_OK(executor->Spawn( + // since cached will be going out-of-scope, it needs to be copied into the + // capture list + [&, cached]() mutable { + // since batch consumption is done asynchronously, a failed status would + // have to be propagated then and there! + ErrorIfNotOk(ConsumeProbeBatch(std::move(cached))); + })); + } else { + RETURN_NOT_OK(ConsumeProbeBatch(std::move(cached))); + } + } + } + + // set flag + cached_probe_batches_consumed = true; + return Status::OK(); + } + + Status GenerateOutput(const ArrayData& group_ids_data, ExecBatch batch) { + if (group_ids_data.GetNullCount() == batch.length) { + // All NULLS! hence, there are no valid outputs! 
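+      // Illustration (values taken from hash_join_node_test.cc): with build-side keys
+      // {"f", "b", "c", "g", "e"} and a probe batch ["d", "b", "c"], Find() returns
+      // group ids with validity bitmap [0, 1, 1]; the semi join keeps rows 1 and 2,
+      // while the anti-join specialization below inverts the bitmap and keeps only
+      // row 0.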
+ ARROW_ASSIGN_OR_RAISE(auto empty_batch, + MakeEmptyExecBatch(output_schema_, ctx_->memory_pool())); + outputs_[0]->InputReceived(this, std::move(empty_batch)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + outputs_[0]->InputReceived(this, std::move(out_batch)); + } else { // all values are valid for output + outputs_[0]->InputReceived(this, std::move(batch)); + } + + return Status::OK(); + } + + // consumes a probe batch and increment probe batches count. Probing would query the + // grouper[build_result_index] which have been merged with all others. + Status ConsumeProbeBatch(ExecBatch batch) { + auto& final_grouper = *local_states_[build_result_index].grouper; + + // Create a batch with key columns + std::vector keys(probe_index_field_ids_.size()); + for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) { + keys[i] = batch.values[probe_index_field_ids_[i]]; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + + // Query the grouper with key_batch. If no match was found, returning group_ids would + // have null. + ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch)); + auto group_ids_data = *group_ids.array(); + + RETURN_NOT_OK(GenerateOutput(group_ids_data, std::move(batch))); + + if (out_counter_.Increment()) { + finished_.MarkFinished(); + } + return Status::OK(); + } + + // Attempt to cache a probe batch. If it is not cached, return false. + // if cached_probe_batches_consumed is true, by the time a thread acquires + // cached_probe_batches_mutex, it should no longer be cached! instead, it can be + // directly consumed! + bool AttemptToCacheProbeBatch(ExecBatch* batch) { + std::lock_guard lck(cached_probe_batches_mutex); + if (cached_probe_batches_consumed) { + return false; + } + cached_probe_batches.push_back(std::move(*batch)); + return true; + } + + bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; } + + struct ThreadLocalState { + std::unique_ptr grouper; + }; + + ExecContext* ctx_; + Future<> finished_ = Future<>::MakeFinished(); + + ThreadIndexer get_thread_index_; + const std::vector build_index_field_ids_, probe_index_field_ids_; + + AtomicCounter build_counter_, out_counter_; + std::vector local_states_; + + // There's no guarantee on which threads would be coming from the build side. so, out of + // the thread local states, an appropriate state needs to be chosen to accumulate + // all built results (ideally, the grouper which has the highest # of elems) + int32_t build_result_index; + + // need a separate bool to track if the build side complete. Can't use the flag + // inside the AtomicCounter, because we need to merge the build groupers once we receive + // all the build batches. So, while merging, we need to prevent probe batches, being + // consumed. + bool hash_table_built_; + + std::mutex cached_probe_batches_mutex; + std::vector cached_probe_batches{}; + // a flag is required to indicate if the cached probe batches have already been + // consumed! 
if cached_probe_batches_consumed is true, by the time a thread aquires + // cached_probe_batches_mutex, it should no longer be cached! instead, it can be + // directly consumed! + bool cached_probe_batches_consumed; +}; + +// template specialization for anti joins. For anti joins, group_ids_data needs to be +// inverted. Output will be taken for indices which are NULL +template <> +Status HashSemiJoinNode::GenerateOutput(const ArrayData& group_ids_data, + ExecBatch batch) { + if (group_ids_data.GetNullCount() == group_ids_data.length) { + // All NULLS! hence, all values are valid for output + outputs_[0]->InputReceived(this, std::move(batch)); + } else if (group_ids_data.MayHaveNulls()) { // values need to be filtered + // invert the validity buffer + arrow::internal::InvertBitmap( + group_ids_data.buffers[0]->data(), group_ids_data.offset, group_ids_data.length, + group_ids_data.buffers[0]->mutable_data(), group_ids_data.offset); + + auto filter_arr = + std::make_shared(group_ids_data.length, group_ids_data.buffers[0], + /*null_bitmap=*/nullptr, /*null_count=*/0, + /*offset=*/group_ids_data.offset); + ARROW_ASSIGN_OR_RAISE(auto rec_batch, + batch.ToRecordBatch(output_schema_, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto filtered, + Filter(rec_batch, filter_arr, + /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_)); + auto out_batch = ExecBatch(*filtered.record_batch()); + outputs_[0]->InputReceived(this, std::move(out_batch)); + } else { + // No NULLS! hence, there are no valid outputs! + ARROW_ASSIGN_OR_RAISE(auto empty_batch, + MakeEmptyExecBatch(output_schema_, ctx_->memory_pool())); + outputs_[0]->InputReceived(this, std::move(empty_batch)); + } + return Status::OK(); +} + +template +Result MakeHashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, + const std::vector& build_keys, + const std::vector& probe_keys) { + auto build_schema = build_input->output_schema(); + auto probe_schema = probe_input->output_schema(); + + ARROW_ASSIGN_OR_RAISE(auto build_key_ids, PopulateKeys(*build_schema, build_keys)); + ARROW_ASSIGN_OR_RAISE(auto probe_key_ids, PopulateKeys(*probe_schema, probe_keys)); + + RETURN_NOT_OK( + ValidateJoinInputs(build_schema, probe_schema, build_key_ids, probe_key_ids)); + + // output schema will be probe schema + auto ctx = build_input->plan()->exec_context(); + ExecPlan* plan = build_input->plan(); + + return plan->EmplaceNode>( + build_input, probe_input, ctx, std::move(build_key_ids), std::move(probe_key_ids)); +} + +Result MakeHashJoinNode(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode")); + + const auto& join_options = checked_cast(options); + + static std::string join_type_string[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + + auto join_type = join_options.join_type; + + ExecNode* left_input = inputs[0]; + ExecNode* right_input = inputs[1]; + const auto& left_keys = join_options.left_keys; + const auto& right_keys = join_options.right_keys; + + switch (join_type) { + case LEFT_SEMI: + // left join--> build from right and probe from left + return MakeHashSemiJoinNode(right_input, left_input, right_keys, left_keys); + case RIGHT_SEMI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, right_keys); + case LEFT_ANTI: + // left join--> build from right and probe from left + return 
MakeHashSemiJoinNode(right_input, left_input, right_keys, left_keys); + case RIGHT_ANTI: + // right join--> build from left and probe from right + return MakeHashSemiJoinNode(left_input, right_input, left_keys, right_keys); + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + return Status::NotImplemented(join_type_string[join_type] + + " joins not implemented!"); + default: + return Status::Invalid("invalid join type"); + } +} + +namespace internal { +void RegisterHashJoinNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("hash_join", MakeHashJoinNode)); +} +} // namespace internal + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc new file mode 100644 index 00000000000..5efb14012d9 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +BatchesWithSchema GenerateBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + BatchesWithSchema out_batches{{}, schema}; + + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + size_t batch_count = out_batches.batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches.batches.push_back(out_batches.batches[i]); + } + } + + return out_batches; +} + +void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const std::vector& left_keys, + const std::vector& right_keys, + const BatchesWithSchema& exp_batches, bool parallel = false) { + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); +} + +void RunNonEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + int multiplicity = parallel ? 100 : 1; + + l_batches = GenerateBatchesFromString( + l_schema, + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + multiplicity); + + r_batches = GenerateBatchesFromString( + r_schema, + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + multiplicity); + + switch (type) { + case LEFT_SEMI: + exp_batches = GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + multiplicity); + break; + case RIGHT_SEMI: + exp_batches = GenerateBatchesFromString( + r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, + multiplicity); + break; + case LEFT_ANTI: + exp_batches = GenerateBatchesFromString( + l_schema, {R"([[0,"d"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([])"}, + multiplicity); + break; + case RIGHT_ANTI: + exp_batches = GenerateBatchesFromString( + r_schema, {R"([["f", 0]])", R"([["g", 4]])", R"([])"}, multiplicity); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } + + CheckRunOutput(type, l_batches, r_batches, + /*left_keys=*/{{"l_str"}}, /*right_keys=*/{{"r_str"}}, exp_batches, + parallel); +} + +void RunEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + + int multiplicity = parallel ? 
100 : 1; + + BatchesWithSchema l_empty, r_empty, l_n_empty, r_n_empty; + + l_empty = GenerateBatchesFromString(l_schema, {R"([])"}, multiplicity); + r_empty = GenerateBatchesFromString(r_schema, {R"([])"}, multiplicity); + + l_n_empty = + GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, multiplicity); + r_n_empty = GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, + multiplicity); + + std::vector l_keys{{"l_str"}}; + std::vector r_keys{{"r_str"}}; + + switch (type) { + case LEFT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_empty, parallel); + break; + case LEFT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_n_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case RIGHT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_n_empty, parallel); + break; + case INNER: + case LEFT_OUTER: + case RIGHT_OUTER: + case FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } +} + +class HashJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P( + HashJoinTest, HashJoinTest, + ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI), + ::testing::Values(false, true))); + +TEST_P(HashJoinTest, TestSemiJoins) { + RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +TEST_P(HashJoinTest, TestSemiJoinstEmpty) { + RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +template +static Status SimpleVerifySemiJoinOutputImpl(int index_col, + const std::shared_ptr& schema, + const std::vector& build_batches, + const std::vector& probe_batches, + const std::vector& output_batches, + bool anti_join = false) { + // populate hash set + std::unordered_set hash_set; + bool has_null = false; + for (auto&& b : build_batches) { + const std::shared_ptr& arr = b[index_col].array(); + VisitArrayDataInline( + *arr, [&](C_TYPE val) { hash_set.insert(val); }, [&]() { has_null = true; }); + } + + // probe hash set + RecordBatchVector exp_batches; + exp_batches.reserve(probe_batches.size()); + for (auto&& b : probe_batches) { + const std::shared_ptr& arr = b[index_col].array(); + + BooleanBuilder builder; + RETURN_NOT_OK(builder.Reserve(arr->length)); + VisitArrayDataInline( + *arr, + [&](C_TYPE val) { + auto res = hash_set.find(val); + // setting anti_join, would invert res != hash_set.end() + builder.UnsafeAppend(anti_join != (res != hash_set.end())); + }, + [&]() { builder.UnsafeAppend(anti_join != has_null); }); + + ARROW_ASSIGN_OR_RAISE(auto filter, builder.Finish()); + + ARROW_ASSIGN_OR_RAISE(auto rec_batch, 
b.ToRecordBatch(schema));
+    ARROW_ASSIGN_OR_RAISE(auto filtered, Filter(rec_batch, filter));
+
+    exp_batches.push_back(filtered.record_batch());
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto exp_table, Table::FromRecordBatches(exp_batches));
+  std::vector sort_keys;
+  for (auto&& f : schema->fields()) {
+    sort_keys.emplace_back(f->name());
+  }
+  ARROW_ASSIGN_OR_RAISE(auto exp_table_sort_ids,
+                        SortIndices(exp_table, SortOptions(sort_keys)));
+  ARROW_ASSIGN_OR_RAISE(auto exp_table_sorted, Take(exp_table, exp_table_sort_ids));
+
+  // create a table from the output batches and sort it the same way
+  RecordBatchVector output_rbs;
+  for (auto&& b : output_batches) {
+    ARROW_ASSIGN_OR_RAISE(auto rb, b.ToRecordBatch(schema));
+    output_rbs.push_back(std::move(rb));
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto out_table, Table::FromRecordBatches(output_rbs));
+  ARROW_ASSIGN_OR_RAISE(auto out_table_sort_ids,
+                        SortIndices(out_table, SortOptions(sort_keys)));
+  ARROW_ASSIGN_OR_RAISE(auto out_table_sorted, Take(out_table, out_table_sort_ids));
+
+  AssertTablesEqual(*exp_table_sorted.table(), *out_table_sorted.table(),
+                    /*same_chunk_layout=*/false, /*flatten=*/true);
+
+  return Status::OK();
+}
+
+template
+struct SimpleVerifySemiJoinOutput {};
+
+template
+struct SimpleVerifySemiJoinOutput> {
+  static Status Verify(int index_col, const std::shared_ptr& schema,
+                       const std::vector& build_batches,
+                       const std::vector& probe_batches,
+                       const std::vector& output_batches,
+                       bool anti_join = false) {
+    return SimpleVerifySemiJoinOutputImpl::CType>(
+        index_col, schema, build_batches, probe_batches, output_batches, anti_join);
+  }
+};
+
+template
+struct SimpleVerifySemiJoinOutput> {
+  static Status Verify(int index_col, const std::shared_ptr& schema,
+                       const std::vector& build_batches,
+                       const std::vector& probe_batches,
+                       const std::vector& output_batches,
+                       bool anti_join = false) {
+    return SimpleVerifySemiJoinOutputImpl(
+        index_col, schema, build_batches, probe_batches, output_batches, anti_join);
+  }
+};
+
+template
+void TestSemiJoinRandom(JoinType type, bool parallel, int num_batches, int batch_size) {
+  auto data_type = default_type_instance();
+  auto l_schema = schema({field("l0", data_type), field("l1", data_type)});
+  auto r_schema = schema({field("r0", data_type), field("r1", data_type)});
+
+  // generate data
+  auto l_batches = MakeRandomBatches(l_schema, num_batches, batch_size);
+  auto r_batches = MakeRandomBatches(r_schema, num_batches, batch_size);
+
+  std::vector left_keys{{"l0"}};
+  std::vector right_keys{{"r1"}};
+
+  auto exec_ctx = arrow::internal::make_unique(
+      default_memory_pool(), parallel ?
arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + JoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hash_join", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + // verification step for res + switch (type) { + case LEFT_SEMI: + ASSERT_OK(SimpleVerifySemiJoinOutput::Verify( + 0, l_schema, r_batches.batches, l_batches.batches, res)); + return; + case RIGHT_SEMI: + ASSERT_OK(SimpleVerifySemiJoinOutput::Verify( + 0, r_schema, l_batches.batches, r_batches.batches, res)); + return; + case LEFT_ANTI: + ASSERT_OK(SimpleVerifySemiJoinOutput::Verify( + 0, l_schema, r_batches.batches, l_batches.batches, res, true)); + return; + case RIGHT_ANTI: + ASSERT_OK(SimpleVerifySemiJoinOutput::Verify( + 0, l_schema, r_batches.batches, l_batches.batches, res, true)); + return; + default: + FAIL() << "Unsupported join type"; + } +} + +static constexpr int kNumBatches = 100; +static constexpr int kBatchSize = 10; + +using TestingTypes = ::testing::Types; + +template +class HashJoinTestRand : public testing::Test {}; + +TYPED_TEST_SUITE(HashJoinTestRand, TestingTypes); + +TYPED_TEST(HashJoinTestRand, LeftSemiJoin) { + for (bool parallel : {false, true}) { + TestSemiJoinRandom(JoinType::LEFT_SEMI, parallel, kNumBatches, kBatchSize); + } +} + +TYPED_TEST(HashJoinTestRand, RightSemiJoin) { + for (bool parallel : {false, true}) { + TestSemiJoinRandom(JoinType::RIGHT_SEMI, parallel, kNumBatches, + kBatchSize); + } +} + +TYPED_TEST(HashJoinTestRand, LeftAntiJoin) { + for (bool parallel : {false, true}) { + TestSemiJoinRandom(JoinType::LEFT_ANTI, parallel, kNumBatches, kBatchSize); + } +} + +TYPED_TEST(HashJoinTestRand, RightAntiJoin) { + for (bool parallel : {false, true}) { + TestSemiJoinRandom(JoinType::RIGHT_ANTI, parallel, kNumBatches, + kBatchSize); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index acc79bdfdde..e936d07e5ae 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -112,6 +112,33 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { std::function>()>* generator; }; +enum JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, // Not Implemented + LEFT_OUTER, // Not Implemented + RIGHT_OUTER, // Not Implemented + FULL_OUTER // Not Implemented +}; + +class ARROW_EXPORT JoinNodeOptions : public ExecNodeOptions { + public: + JoinNodeOptions(JoinType join_type, std::vector left_keys, + std::vector right_keys) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)) {} + + // type of the join + JoinType join_type; + + // index keys of the join + std::vector left_keys; + std::vector right_keys; +}; + /// \brief Make a node which sorts rows passed through it /// /// All batches pushed to this node will be accumulated, then sorted, by the given diff --git a/cpp/src/arrow/compute/exec/util.cc 
b/cpp/src/arrow/compute/exec/util.cc index e2fe61a63c6..020431ebf97 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -21,6 +21,8 @@ #include "arrow/table.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/io_util.h" +#include "arrow/util/thread_pool.h" #include "arrow/util/ubsan.h" namespace arrow { @@ -301,6 +303,27 @@ Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inpu return Status::OK(); } +size_t ThreadIndexer::operator()() { + auto id = arrow::internal::GetThreadId(); + + auto guard = mutex_.Lock(); // acquire the lock + const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; + + return Check(id_index.second); +} + +size_t ThreadIndexer::Capacity() { + static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity(); + return max_size; +} + +size_t ThreadIndexer::Check(size_t thread_index) { + DCHECK_LT(thread_index, Capacity()) + << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")"; + + return thread_index; +} + Result> TableFromExecBatches( const std::shared_ptr& schema, const std::vector& exec_batches) { RecordBatchVector batches; diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 63f3315f7e0..d54feea2e99 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/buffer.h" @@ -29,6 +30,7 @@ #include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" +#include "arrow/util/mutex.h" #include "arrow/util/optional.h" #if defined(__clang__) || defined(__GNUC__) @@ -242,5 +244,18 @@ class AtomicCounter { std::atomic complete_{false}; }; +class ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + util::Mutex mutex_; + std::unordered_map id_to_index_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 6e7c074573d..900cdd2e585 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#include + #include #include #include @@ -40,6 +42,7 @@ #include "arrow/compute/kernels/util_internal.h" #include "arrow/record_batch.h" #include "arrow/util/bit_run_reader.h" +#include "arrow/util/bitmap.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/checked_cast.h" @@ -443,30 +446,43 @@ struct GrouperImpl : Grouper { return std::move(impl); } - Result Consume(const ExecBatch& batch) override { - std::vector offsets_batch(batch.length + 1); + Status PopulateKeyData(const ExecBatch& batch, std::vector* offsets_batch, + std::vector* key_bytes_batch, + std::vector* key_buf_ptrs) const { + offsets_batch->resize(batch.length + 1); for (int i = 0; i < batch.num_values(); ++i) { - encoders_[i]->AddLength(batch[i], batch.length, offsets_batch.data()); + encoders_[i]->AddLength(batch[i], batch.length, offsets_batch->data()); } int32_t total_length = 0; for (int64_t i = 0; i < batch.length; ++i) { auto total_length_before = total_length; - total_length += offsets_batch[i]; - offsets_batch[i] = total_length_before; + total_length += (*offsets_batch)[i]; + (*offsets_batch)[i] = total_length_before; } - offsets_batch[batch.length] = total_length; + (*offsets_batch)[batch.length] = total_length; - std::vector key_bytes_batch(total_length); - std::vector key_buf_ptrs(batch.length); + key_bytes_batch->resize(total_length); + key_buf_ptrs->resize(batch.length); for (int64_t i = 0; i < batch.length; ++i) { - key_buf_ptrs[i] = key_bytes_batch.data() + offsets_batch[i]; + key_buf_ptrs->at(i) = key_bytes_batch->data() + (*offsets_batch)[i]; } for (int i = 0; i < batch.num_values(); ++i) { - RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, key_buf_ptrs.data())); + RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, key_buf_ptrs->data())); } + return Status::OK(); + } + + Result Consume(const ExecBatch& batch) override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + TypedBufferBuilder group_ids_batch(ctx_->memory_pool()); RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); @@ -495,6 +511,36 @@ struct GrouperImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } + Result Find(const ExecBatch& batch) override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + + UInt32Builder group_ids_batch(ctx_->memory_pool()); + RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); + + for (int64_t i = 0; i < batch.length; ++i) { + int32_t key_length = offsets_batch[i + 1] - offsets_batch[i]; + std::string key( + reinterpret_cast(key_bytes_batch.data() + offsets_batch[i]), + key_length); + + auto it = map_.find(key); + // no group ID was found, null will be emitted! 
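+      // Illustration (mirroring the Grouper test below): after
+      // Consume("[[3], [27], [3], [27], [null], [81], [27], [81]]") assigns the groups
+      // {3 -> 0, 27 -> 1, 81 -> 2, null -> 3}, Find("[[27], [3], [1]]") yields
+      // [1, 0, null], since the key 1 was never consumed.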
+ if (it == map_.end()) { + group_ids_batch.UnsafeAppendNull(); + } else { + group_ids_batch.UnsafeAppend(it->second); + } + } + + ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish()); + return Datum(group_ids); + } + uint32_t num_groups() const override { return num_groups_; } Result GetUniques() override { @@ -632,6 +678,7 @@ struct GrouperFastImpl : Grouper { return ConsumeImpl(batch); } + template Result ConsumeImpl(const ExecBatch& batch) { int64_t num_rows = batch.length; int num_columns = batch.num_values(); @@ -656,6 +703,15 @@ struct GrouperFastImpl : Grouper { ARROW_ASSIGN_OR_RAISE( group_ids, AllocateBuffer(sizeof(uint32_t) * num_rows, ctx_->memory_pool())); + std::shared_ptr group_ids_validity; + if (Find) { + ARROW_ASSIGN_OR_RAISE( + group_ids_validity, + AllocateBitmap(sizeof(uint32_t) * num_rows, ctx_->memory_pool())); + } else { + group_ids_validity = nullptr; + } + for (int icol = 0; icol < num_columns; ++icol) { const uint8_t* non_nulls = nullptr; if (batch[icol].array()->buffers[0] != NULLPTR) { @@ -701,16 +757,23 @@ struct GrouperFastImpl : Grouper { match_bitvector.mutable_data(), local_slots.mutable_data(), reinterpret_cast(group_ids->mutable_data()) + start_row); } - auto ids = util::TempVectorHolder(&temp_stack_, batch_size_next); - int num_ids; - util::BitUtil::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next, - match_bitvector.mutable_data(), &num_ids, - ids.mutable_data()); - - RETURN_NOT_OK(map_.map_new_keys( - num_ids, ids.mutable_data(), minibatch_hashes_.data(), - reinterpret_cast(group_ids->mutable_data()) + start_row)); + if (Find) { + // In find mode, don't insert the new keys. just copy the match bitvector to + // the group_ids_validity buffer. Valid group_ids are already populated. + arrow::internal::CopyBitmap(match_bitvector.mutable_data(), 0, batch_size_next, + group_ids_validity->mutable_data(), start_row); + } else { + auto ids = util::TempVectorHolder(&temp_stack_, batch_size_next); + int num_ids; + util::BitUtil::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next, + match_bitvector.mutable_data(), &num_ids, + ids.mutable_data()); + + RETURN_NOT_OK(map_.map_new_keys( + num_ids, ids.mutable_data(), minibatch_hashes_.data(), + reinterpret_cast(group_ids->mutable_data()) + start_row)); + } start_row += batch_size_next; if (minibatch_size_ * 2 <= minibatch_size_max_) { @@ -718,7 +781,27 @@ struct GrouperFastImpl : Grouper { } } - return Datum(UInt32Array(batch.length, std::move(group_ids))); + return Datum( + UInt32Array(batch.length, std::move(group_ids), std::move(group_ids_validity))); + } + + Result Find(const ExecBatch& batch) override { + // ARROW-14027: broadcast scalar arguments for now + for (int i = 0; i < batch.num_values(); i++) { + if (batch.values[i].is_scalar()) { + ExecBatch expanded = batch; + for (int j = i; j < expanded.num_values(); j++) { + if (expanded.values[j].is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + expanded.values[j], + MakeArrayFromScalar(*expanded.values[j].scalar(), expanded.length, + ctx_->memory_pool())); + } + } + return ConsumeImpl(expanded); + } + } + return ConsumeImpl(batch); } uint32_t num_groups() const override { return static_cast(rows_.length()); } diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 412290aa777..5a42d169524 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -320,6 +320,16 @@ struct TestGrouper { 
AssertEquivalentIds(expected, ids); } + void ExpectFind(const std::string& key_json, const std::string& expected) { + ExpectFind(ExecBatchFromJSON(descrs_, key_json), ArrayFromJSON(uint32(), expected)); + } + void ExpectFind(const ExecBatch& key_batch, Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Find(key_batch)); + ValidateOutput(id_batch); + + AssertDatumsEqual(expected, id_batch); + } + void ExpectUniques(const ExecBatch& uniques) { EXPECT_THAT(grouper_->GetUniques(), ResultWith(Eq(uniques))); } @@ -459,6 +469,19 @@ TEST(Grouper, NumericKey) { g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]", "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[27], [81]]", "[1, 2]"); + + g.ExpectFind("[[3], [27], [3], [27], [null], [81], [27], [81]]", + "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[27], [3], [27], [null], [81], [1], [81]]", + "[1, 0, 1, 3, 2, null, 2]"); + g.ExpectUniques("[[3], [27], [81], [null]]"); } }
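
Usage sketch for reviewers (mirrors CheckRunOutput in hash_join_node_test.cc and is not
part of the patch itself; left, right, and plan stand for any pair of source declarations
and a plan created with ExecPlan::Make):

  // Declare a LEFT_SEMI hash join over two sources and drain it through a sink.
  // For LEFT_SEMI the right input is the build side and the left input is probed.
  JoinNodeOptions join_options{JoinType::LEFT_SEMI,
                               /*left_keys=*/{{"l_str"}},
                               /*right_keys=*/{{"r_str"}}};
  Declaration join{"hash_join", join_options};
  join.inputs.emplace_back(left);   // probe side
  join.inputs.emplace_back(right);  // build side
  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
  ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
                .AddToPlan(plan.get()));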