diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index cb6e91bd40e..bea1ddb28fe 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -406,6 +406,7 @@ if(ARROW_COMPUTE) compute/kernels/vector_replace.cc compute/kernels/vector_selection.cc compute/kernels/vector_sort.cc + compute/exec/hash_join.cc compute/exec/key_hash.cc compute/exec/key_map.cc compute/exec/key_compare.cc diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index d66d4f1517c..195e76fe9f1 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -349,6 +349,11 @@ class ARROW_EXPORT Grouper { /// be as wide as necessary. virtual Result Consume(const ExecBatch& batch) = 0; + /// Looks up the group ID of every row of the given ExecBatch without inserting new + /// keys. Returns the group IDs as an integer array; a row whose key has no group ID + /// yet is null at that index. Meant to be a thread-safe lookup (TODO confirm). + virtual Result Find(const ExecBatch& batch) = 0; + /// Get current unique keys. May be called multiple times. 
virtual Result GetUniques() = 0; diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 20c8c347cc1..f7d62c63673 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -1291,5 +1291,278 @@ Result GroupByUsingExecPlan(const std::vector& arguments, /*null_count=*/0); } +/*struct HashSemiIndexJoinNode : ExecNode { + HashSemiIndexJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label, + std::shared_ptr output_schema, ExecContext* ctx, + const std::vector&& index_field_ids) + : ExecNode(left_input->plan(), std::move(label), {left_input, right_input}, + {"hashsemiindexjoin"}, std::move(output_schema), */ +/*num_outputs=*//*1), +ctx_(ctx), +num_build_batches_processed_(0), +num_build_batches_total_(-1), +num_probe_batches_processed_(0), +num_probe_batches_total_(-1), +num_output_batches_processed_(0), +index_field_ids_(std::move(index_field_ids)), +output_started_(false), +build_phase_finished_(false){} + +const char* kind_name() override { return "HashSemiIndexJoinNode"; } + +private: +struct ThreadLocalState; + +public: +Status InitLocalStateIfNeeded(ThreadLocalState* state) { +// Get input schema +auto input_schema = inputs_[0]->output_schema(); + +if (!state->grouper) { +// Build vector of key field data types +std::vector key_descrs(index_field_ids_.size()); +for (size_t i = 0; i < index_field_ids_.size(); ++i) { +auto key_field_id = index_field_ids_[i]; +key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type()); +} + +// Construct grouper +ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_)); +} + +return Status::OK(); +} + +Status ProcessBuildSideBatch(const ExecBatch& batch) { +SmallUniqueIdHolder id_holder(&local_state_id_assignment_); +int id = id_holder.get(); +ThreadLocalState* state = local_states_.get(id); +RETURN_NOT_OK(InitLocalStateIfNeeded(state)); + +// Create a batch with key columns +std::vector 
keys(key_field_ids_.size()); +for (size_t i = 0; i < key_field_ids_.size(); ++i) { +keys[i] = batch.values[key_field_ids_[i]]; +} +ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys)); + +// Create a batch with group ids +ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch)); + +// Execute aggregate kernels +for (size_t i = 0; i < agg_kernels_.size(); ++i) { +KernelContext kernel_ctx{ctx_}; +kernel_ctx.SetState(state->agg_states[i].get()); + +ARROW_ASSIGN_OR_RAISE( +auto agg_batch, +ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch})); + +RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups())); +RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch)); +} + +return Status::OK(); +} + +// merge all other groupers to grouper[0]. nothing needs to be done on the +// early_probe_batches, because when probing everyone +Status BuildSideMerge() { +int num_local_states = local_state_id_assignment_.num_ids(); +ThreadLocalState* state0 = local_states_.get(0); +for (int i = 1; i < num_local_states; ++i) { +ThreadLocalState* state = local_states_.get(i); +ARROW_DCHECK(state); +ARROW_DCHECK(state->grouper); +ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques()); +ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys)); +state->grouper.reset(); +} +return Status::OK(); +} + +Status Finalize() { +out_data_.resize(agg_kernels_.size() + key_field_ids_.size()); +auto it = out_data_.begin(); + +ThreadLocalState* state = local_states_.get(0); +num_out_groups_ = state->grouper->num_groups(); + +// Aggregate fields come before key fields to match the behavior of GroupBy function + +for (size_t i = 0; i < agg_kernels_.size(); ++i) { +KernelContext batch_ctx{ctx_}; +batch_ctx.SetState(state->agg_states[i].get()); +Datum out; +RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out)); +*it++ = out.array(); +state->agg_states[i].reset(); +} + +ARROW_ASSIGN_OR_RAISE(ExecBatch 
out_keys, state->grouper->GetUniques()); +for (const auto& key : out_keys.values) { +*it++ = key.array(); +} +state->grouper.reset(); + +return Status::OK(); +} + +Status OutputNthBatch(int n) { +ARROW_DCHECK(output_started_.load()); + +// Check finished flag +if (finished_.is_finished()) { +return Status::OK(); +} + +// Slice arrays +int64_t batch_size = output_batch_size(); +int64_t batch_start = n * batch_size; +int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start); +std::vector output_slices(out_data_.size()); +for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) { +output_slices[out_field_id] = +out_data_[out_field_id]->Slice(batch_start, batch_length); +} + +ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices)); +outputs_[0]->InputReceived(this, n, output_batch); + +uint32_t num_output_batches_processed = +1 + num_output_batches_processed_.fetch_add(1); +if (num_output_batches_processed * batch_size >= num_out_groups_) { +finished_.MarkFinished(); +} + +return Status::OK(); +} + +Status OutputResult() { +bool expected = false; +if (!output_started_.compare_exchange_strong(expected, true)) { +return Status::OK(); +} + +RETURN_NOT_OK(BuildSideMerge()); +RETURN_NOT_OK(Finalize()); + +int batch_size = output_batch_size(); +int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size; +outputs_[0]->InputFinished(this, num_result_batches); + +auto executor = arrow::internal::GetCpuThreadPool(); +for (int i = 0; i < num_result_batches; ++i) { +// Check finished flag +if (finished_.is_finished()) { +break; +} + +RETURN_NOT_OK(executor->Spawn([this, i]() { +Status status = OutputNthBatch(i); +if (!status.ok()) { +ErrorReceived(inputs_[0], status); +} +})); +} + +return Status::OK(); +} + +void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { +assert(input == inputs_[0] || input == inputs_[1]); + +if (finished_.is_finished()) { +return; +} + 
+ARROW_DCHECK(num_build_batches_processed_.load() != num_build_batches_total_.load()); + +Status status = ProcessBuildSideBatch(batch); +if (!status.ok()) { +ErrorReceived(input, status); +return; +} + +num_build_batches_processed_.fetch_add(1); +if (num_build_batches_processed_.load() == num_build_batches_total_.load()) { +status = OutputResult(); +if (!status.ok()) { +ErrorReceived(input, status); +return; +} +} +} + +void ErrorReceived(ExecNode* input, Status error) override { +DCHECK_EQ(input, inputs_[0]); + +outputs_[0]->ErrorReceived(this, std::move(error)); +StopProducing(); +} + +void InputFinished(ExecNode* input, int seq) override { +DCHECK_EQ(input, inputs_[0]); + +num_build_batches_total_.store(seq); +if (num_build_batches_processed_.load() == num_build_batches_total_.load()) { +Status status = OutputResult(); + +if (!status.ok()) { +ErrorReceived(input, status); +} +} +} + +Status StartProducing() override { +finished_ = Future<>::Make(); +return Status::OK(); +} + +void PauseProducing(ExecNode* output) override {} + +void ResumeProducing(ExecNode* output) override {} + +void StopProducing(ExecNode* output) override { +DCHECK_EQ(output, outputs_[0]); +inputs_[0]->StopProducing(this); + +finished_.MarkFinished(); +} + +void StopProducing() override { StopProducing(outputs_[0]); } + +Future<> finished() override { return finished_; } + +private: +int output_batch_size() const { +int result = static_cast(ctx_->exec_chunksize()); +if (result < 0) { +result = 32 * 1024; +} +return result; +} + +ExecContext* ctx_; +Future<> finished_ = Future<>::MakeFinished(); + +std::atomic num_build_batches_processed_; +std::atomic num_build_batches_total_; +std::atomic num_probe_batches_processed_; +std::atomic num_probe_batches_total_; +std::atomic num_output_batches_processed_; + +const std::vector index_field_ids_; + +struct ThreadLocalState { +std::unique_ptr grouper; +std::vector early_probe_batches{}; +}; +SharedSequenceOfObjects local_states_; 
+SmallUniqueIdAssignment local_state_id_assignment_; +uint32_t num_out_groups_{0}; +ArrayDataVector out_data_; +std::atomic output_started_, build_phase_finished_; +};*/ } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc new file mode 100644 index 00000000000..10b676be892 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/hash_join.h" + +namespace arrow { +namespace compute { + + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h new file mode 100644 index 00000000000..492cf0a0a49 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { +namespace compute { + +/// Kinds of join named so far: semi and anti-semi joins for either input side. +enum JoinType { + LEFT_SEMI_JOIN, + RIGHT_SEMI_JOIN, + LEFT_ANTI_SEMI_JOIN, + RIGHT_ANTI_SEMI_JOIN +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 3e4b401bae9..30675d8e9fe 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#include + #include #include #include @@ -364,30 +366,43 @@ struct GrouperImpl : Grouper { return std::move(impl); } - Result Consume(const ExecBatch& batch) override { - std::vector offsets_batch(batch.length + 1); + Status PopulateKeyData(const ExecBatch& batch, std::vector* offsets_batch, + std::vector* key_bytes_batch, + std::vector* key_buf_ptrs) { + offsets_batch->resize(batch.length + 1); for (int i = 0; i < batch.num_values(); ++i) { - encoders_[i]->AddLength(*batch[i].array(), offsets_batch.data()); + encoders_[i]->AddLength(*batch[i].array(), offsets_batch->data()); } int32_t total_length = 0; for (int64_t i = 0; i < batch.length; ++i) { auto total_length_before = total_length; - total_length += offsets_batch[i]; - offsets_batch[i] = total_length_before; + total_length += offsets_batch->at(i); + offsets_batch->at(i) = total_length_before; } - offsets_batch[batch.length] = total_length; + offsets_batch->at(batch.length) = total_length; - std::vector key_bytes_batch(total_length); - std::vector key_buf_ptrs(batch.length); + key_bytes_batch->resize(total_length); + key_buf_ptrs->resize(batch.length); for (int64_t i = 0; i < batch.length; ++i) { - key_buf_ptrs[i] = key_bytes_batch.data() + offsets_batch[i]; + key_buf_ptrs->at(i) = key_bytes_batch->data() + offsets_batch->at(i); } for (int i = 0; i < batch.num_values(); ++i) { - RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs.data())); + RETURN_NOT_OK(encoders_[i]->Encode(*batch[i].array(), key_buf_ptrs->data())); } + return Status::OK(); + } + + Result Consume(const ExecBatch& batch) override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + TypedBufferBuilder group_ids_batch(ctx_->memory_pool()); RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); @@ -416,6 +431,36 @@ struct GrouperImpl : Grouper { return Datum(UInt32Array(batch.length, 
std::move(group_ids))); } + Result Find(const ExecBatch& batch) override { + std::vector offsets_batch; + std::vector key_bytes_batch; + std::vector key_buf_ptrs; + + RETURN_NOT_OK( + PopulateKeyData(batch, &offsets_batch, &key_bytes_batch, &key_buf_ptrs)); + + UInt32Builder group_ids_batch(ctx_->memory_pool()); + RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); + + for (int64_t i = 0; i < batch.length; ++i) { + int32_t key_length = offsets_batch[i + 1] - offsets_batch[i]; + std::string key( + reinterpret_cast(key_bytes_batch.data() + offsets_batch[i]), + key_length); + + auto it = map_.find(key); + // no group ID was found, null will be emitted! + if (it == map_.end()) { + group_ids_batch.UnsafeAppendNull(); + } else { + group_ids_batch.UnsafeAppend(it->second); + } + } + + ARROW_ASSIGN_OR_RAISE(auto group_ids, group_ids_batch.Finish()); + return Datum(group_ids); + } + uint32_t num_groups() const override { return num_groups_; } Result GetUniques() override { @@ -617,6 +662,11 @@ struct GrouperFastImpl : Grouper { return Datum(UInt32Array(batch.length, std::move(group_ids))); } + Result Find(const ExecBatch& batch) override { + // TODO: implement lookup for the fast grouper; report it explicitly until then + return Status::NotImplemented("GrouperFastImpl::Find is not implemented yet"); + } + uint32_t num_groups() const override { return static_cast(rows_.length()); } // Make sure padded buffers end up with the right logical size @@ -1342,9 +1392,10 @@ Result ResolveKernels( Result> Grouper::Make(const std::vector& descrs, ExecContext* ctx) { - if (GrouperFastImpl::CanUse(descrs)) { - return GrouperFastImpl::Make(descrs, ctx); - } + // TODO(niranda) re-enable this! 
+ // if (GrouperFastImpl::CanUse(descrs)) { + // return GrouperFastImpl::Make(descrs, ctx); + // } return GrouperImpl::Make(descrs, ctx); } diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 46c7716abce..8374c333742 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -216,6 +216,16 @@ struct TestGrouper { AssertEquivalentIds(expected, ids); } + void ExpectFind(const std::string& key_json, const std::string& expected) { + ExpectFind(ExecBatchFromJSON(descrs_, key_json), ArrayFromJSON(uint32(), expected)); + } + void ExpectFind(const ExecBatch& key_batch, Datum expected) { + ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Find(key_batch)); + ValidateOutput(id_batch); + + AssertDatumsEqual(expected, id_batch); + } + void AssertEquivalentIds(const Datum& expected, const Datum& actual) { auto left = expected.make_array(); auto right = actual.make_array(); @@ -341,6 +351,18 @@ TEST(Grouper, NumericKey) { g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]", "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[3], [3]]", "[0, 0]"); + + g.ExpectFind("[[27], [81]]", "[1, 2]"); + + g.ExpectFind("[[3], [27], [3], [27], [null], [81], [27], [81]]", + "[0, 1, 0, 1, 3, 2, 1, 2]"); + + g.ExpectFind("[[27], [3], [27], [null], [81], [1], [81]]", + "[1, 0, 1, 3, 2, null, 2]"); } }