Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ if(ARROW_COMPUTE)
compute/kernels/vector_replace.cc
compute/kernels/vector_selection.cc
compute/kernels/vector_sort.cc
compute/exec/hash_join.cc
compute/exec/key_hash.cc
compute/exec/key_map.cc
compute/exec/key_compare.cc
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ class ARROW_EXPORT Grouper {
/// be as wide as necessary.
virtual Result<Datum> Consume(const ExecBatch& batch) = 0;

/// Find (query) the group ID for every index of the given ExecBatch. Returns the
/// group IDs as an integer array. If no group ID is found for an index, UINT32_MAX
/// is stored at that index instead. This is a thread-safe lookup.
virtual Result<Datum> Find(const ExecBatch& batch) = 0;

/// Get the current unique keys as an ExecBatch. May be called multiple times.
virtual Result<ExecBatch> GetUniques() = 0;

Expand Down
273 changes: 273 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1291,5 +1291,278 @@ Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
/*null_count=*/0);
}

/*struct HashSemiIndexJoinNode : ExecNode {
HashSemiIndexJoinNode(ExecNode* left_input, ExecNode* right_input, std::string label,
std::shared_ptr<Schema> output_schema, ExecContext* ctx,
const std::vector<int>&& index_field_ids)
: ExecNode(left_input->plan(), std::move(label), {left_input, right_input},
{"hashsemiindexjoin"}, std::move(output_schema), */
/*num_outputs=*//*1),
ctx_(ctx),
num_build_batches_processed_(0),
num_build_batches_total_(-1),
num_probe_batches_processed_(0),
num_probe_batches_total_(-1),
num_output_batches_processed_(0),
index_field_ids_(std::move(index_field_ids)),
output_started_(false),
build_phase_finished_(false){}

const char* kind_name() override { return "HashSemiIndexJoinNode"; }

private:
struct ThreadLocalState;

public:
Status InitLocalStateIfNeeded(ThreadLocalState* state) {
// Get input schema
auto input_schema = inputs_[0]->output_schema();

if (!state->grouper) {
// Build vector of key field data types
std::vector<ValueDescr> key_descrs(index_field_ids_.size());
for (size_t i = 0; i < index_field_ids_.size(); ++i) {
auto key_field_id = index_field_ids_[i];
key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type());
}

// Construct grouper
ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
}

return Status::OK();
}

Status ProcessBuildSideBatch(const ExecBatch& batch) {
SmallUniqueIdHolder id_holder(&local_state_id_assignment_);
int id = id_holder.get();
ThreadLocalState* state = local_states_.get(id);
RETURN_NOT_OK(InitLocalStateIfNeeded(state));

// Create a batch with key columns
std::vector<Datum> keys(key_field_ids_.size());
for (size_t i = 0; i < key_field_ids_.size(); ++i) {
keys[i] = batch.values[key_field_ids_[i]];
}
ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));

// Create a batch with group ids
ARROW_ASSIGN_OR_RAISE(Datum id_batch, state->grouper->Consume(key_batch));

// Execute aggregate kernels
for (size_t i = 0; i < agg_kernels_.size(); ++i) {
KernelContext kernel_ctx{ctx_};
kernel_ctx.SetState(state->agg_states[i].get());

ARROW_ASSIGN_OR_RAISE(
auto agg_batch,
ExecBatch::Make({batch.values[agg_src_field_ids_[i]], id_batch}));

RETURN_NOT_OK(agg_kernels_[i]->resize(&kernel_ctx, state->grouper->num_groups()));
RETURN_NOT_OK(agg_kernels_[i]->consume(&kernel_ctx, agg_batch));
}

return Status::OK();
}

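// Merge all other groupers into grouper[0]. Nothing needs to be done on the
// early_probe_batches, because when probing everyone presumably uses the merged
// grouper[0] (TODO: confirm -- the original sentence was left unfinished).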
Status BuildSideMerge() {
int num_local_states = local_state_id_assignment_.num_ids();
ThreadLocalState* state0 = local_states_.get(0);
for (int i = 1; i < num_local_states; ++i) {
ThreadLocalState* state = local_states_.get(i);
ARROW_DCHECK(state);
ARROW_DCHECK(state->grouper);
ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
ARROW_ASSIGN_OR_RAISE(Datum _, state0->grouper->Consume(other_keys));
state->grouper.reset();
}
return Status::OK();
}

Status Finalize() {
out_data_.resize(agg_kernels_.size() + key_field_ids_.size());
auto it = out_data_.begin();

ThreadLocalState* state = local_states_.get(0);
num_out_groups_ = state->grouper->num_groups();

// Aggregate fields come before key fields to match the behavior of GroupBy function

for (size_t i = 0; i < agg_kernels_.size(); ++i) {
KernelContext batch_ctx{ctx_};
batch_ctx.SetState(state->agg_states[i].get());
Datum out;
RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out));
*it++ = out.array();
state->agg_states[i].reset();
}

ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
for (const auto& key : out_keys.values) {
*it++ = key.array();
}
state->grouper.reset();

return Status::OK();
}

Status OutputNthBatch(int n) {
ARROW_DCHECK(output_started_.load());

// Check finished flag
if (finished_.is_finished()) {
return Status::OK();
}

// Slice arrays
int64_t batch_size = output_batch_size();
int64_t batch_start = n * batch_size;
int64_t batch_length = std::min(batch_size, num_out_groups_ - batch_start);
std::vector<Datum> output_slices(out_data_.size());
for (size_t out_field_id = 0; out_field_id < out_data_.size(); ++out_field_id) {
output_slices[out_field_id] =
out_data_[out_field_id]->Slice(batch_start, batch_length);
}

ARROW_ASSIGN_OR_RAISE(ExecBatch output_batch, ExecBatch::Make(output_slices));
outputs_[0]->InputReceived(this, n, output_batch);

uint32_t num_output_batches_processed =
1 + num_output_batches_processed_.fetch_add(1);
if (num_output_batches_processed * batch_size >= num_out_groups_) {
finished_.MarkFinished();
}

return Status::OK();
}

Status OutputResult() {
bool expected = false;
if (!output_started_.compare_exchange_strong(expected, true)) {
return Status::OK();
}

RETURN_NOT_OK(BuildSideMerge());
RETURN_NOT_OK(Finalize());

int batch_size = output_batch_size();
int num_result_batches = (num_out_groups_ + batch_size - 1) / batch_size;
outputs_[0]->InputFinished(this, num_result_batches);

auto executor = arrow::internal::GetCpuThreadPool();
for (int i = 0; i < num_result_batches; ++i) {
// Check finished flag
if (finished_.is_finished()) {
break;
}

RETURN_NOT_OK(executor->Spawn([this, i]() {
Status status = OutputNthBatch(i);
if (!status.ok()) {
ErrorReceived(inputs_[0], status);
}
}));
}

return Status::OK();
}

void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
assert(input == inputs_[0] || input == inputs_[1]);

if (finished_.is_finished()) {
return;
}

ARROW_DCHECK(num_build_batches_processed_.load() != num_build_batches_total_.load());

Status status = ProcessBuildSideBatch(batch);
if (!status.ok()) {
ErrorReceived(input, status);
return;
}

num_build_batches_processed_.fetch_add(1);
if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
status = OutputResult();
if (!status.ok()) {
ErrorReceived(input, status);
return;
}
}
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);

outputs_[0]->ErrorReceived(this, std::move(error));
StopProducing();
}

void InputFinished(ExecNode* input, int seq) override {
DCHECK_EQ(input, inputs_[0]);

num_build_batches_total_.store(seq);
if (num_build_batches_processed_.load() == num_build_batches_total_.load()) {
Status status = OutputResult();

if (!status.ok()) {
ErrorReceived(input, status);
}
}
}

Status StartProducing() override {
finished_ = Future<>::Make();
return Status::OK();
}

void PauseProducing(ExecNode* output) override {}

void ResumeProducing(ExecNode* output) override {}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
inputs_[0]->StopProducing(this);

finished_.MarkFinished();
}

void StopProducing() override { StopProducing(outputs_[0]); }

Future<> finished() override { return finished_; }

private:
int output_batch_size() const {
int result = static_cast<int>(ctx_->exec_chunksize());
if (result < 0) {
result = 32 * 1024;
}
return result;
}

ExecContext* ctx_;
Future<> finished_ = Future<>::MakeFinished();

std::atomic<int> num_build_batches_processed_;
std::atomic<int> num_build_batches_total_;
std::atomic<int> num_probe_batches_processed_;
std::atomic<int> num_probe_batches_total_;
std::atomic<uint32_t> num_output_batches_processed_;

const std::vector<int> index_field_ids_;

struct ThreadLocalState {
std::unique_ptr<internal::Grouper> grouper;
std::vector<ExecBatch> early_probe_batches{};
};
SharedSequenceOfObjects<ThreadLocalState> local_states_;
SmallUniqueIdAssignment local_state_id_assignment_;
uint32_t num_out_groups_{0};
ArrayDataVector out_data_;
std::atomic<bool> output_started_, build_phase_finished_;
};*/
} // namespace compute
} // namespace arrow
25 changes: 25 additions & 0 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/exec/hash_join.h"

namespace arrow {
namespace compute {


} // namespace compute
} // namespace arrow
29 changes: 29 additions & 0 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

namespace arrow {
namespace compute {

/// \brief Kinds of join named by this header.
///
/// Only semi-join and anti-semi-join variants are declared. The enumerators are
/// unscoped, so they are referenced directly as e.g. arrow::compute::LEFT_SEMI_JOIN,
/// and they take the default consecutive values starting at 0.
///
/// NOTE(review): presumably LEFT_/RIGHT_ selects which input's rows are emitted and
/// ANTI_ inverts the match condition, as in the usual relational definitions --
/// confirm against the join node implementation.
enum JoinType {
  LEFT_SEMI_JOIN,
  RIGHT_SEMI_JOIN,
  LEFT_ANTI_SEMI_JOIN,
  RIGHT_ANTI_SEMI_JOIN
};

}  // namespace compute
}  // namespace arrow
Loading