Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ if(ARROW_COMPUTE)
compute/exec/key_compare.cc
compute/exec/key_encode.cc
compute/exec/util.cc
compute/exec/hash_join_dict.cc
compute/exec/hash_join.cc
compute/exec/hash_join_node.cc
compute/exec/task_util.cc)
Expand Down
103 changes: 80 additions & 23 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_map>
#include <vector>

#include "arrow/compute/exec/hash_join_dict.h"
#include "arrow/compute/exec/task_util.h"
#include "arrow/compute/kernels/row_encoder.h"

Expand Down Expand Up @@ -96,6 +97,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
}
dict_probe_.Init(num_threads);

has_hash_table_ = false;
num_batches_produced_.store(0);
Expand Down Expand Up @@ -144,12 +146,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
}

local_state.is_initialized = true;
}
}

Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
const ExecBatch& batch) {
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
ExecBatch projected({}, batch.length);
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
projected.values.resize(num_cols);
Expand All @@ -160,6 +163,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
projected.values[icol] = batch.values[to_input.get(icol)];
}

if (opt_projected_batch) {
*opt_projected_batch = projected;
}

return encoder->EncodeAndAppend(projected);
}

Expand All @@ -170,6 +177,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<int32_t>* output_no_match,
std::vector<int32_t>* output_match_left,
std::vector<int32_t>* output_match_right) {
InitHasMatchIfNeeded(local_state);

ARROW_DCHECK(has_hash_table_);

InitHasMatchIfNeeded(local_state);
Expand Down Expand Up @@ -311,6 +320,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_DCHECK(opt_right_ids);
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
Expand Down Expand Up @@ -368,13 +379,48 @@ class HashJoinBasicImpl : public HashJoinImpl {
return Status::OK();
}

void NullInfoFromBatch(const ExecBatch& batch,
std::vector<const uint8_t*>* nn_bit_vectors,
std::vector<int64_t>* nn_offsets,
std::vector<uint8_t>* nn_bit_vector_all_nulls) {
int num_cols = static_cast<int>(batch.values.size());
nn_bit_vectors->resize(num_cols);
nn_offsets->resize(num_cols);
nn_bit_vector_all_nulls->clear();
for (int64_t i = 0; i < num_cols; ++i) {
const uint8_t* nn = nullptr;
int64_t offset = 0;
if (batch[i].is_array()) {
if (batch[i].array()->buffers[0] != NULLPTR) {
nn = batch[i].array()->buffers[0]->data();
offset = batch[i].array()->offset;
}
} else {
ARROW_DCHECK(batch[i].is_scalar());
if (!batch[i].scalar_as<arrow::internal::PrimitiveScalarBase>().is_valid) {
if (nn_bit_vector_all_nulls->empty()) {
nn_bit_vector_all_nulls->resize(BitUtil::BytesForBits(batch.length));
memset(nn_bit_vector_all_nulls->data(), 0,
BitUtil::BytesForBits(batch.length));
}
nn = nn_bit_vector_all_nulls->data();
}
}
(*nn_bit_vectors)[i] = nn;
(*nn_offsets)[i] = offset;
}
}

Status ProbeBatch(size_t thread_index, const ExecBatch& batch) {
ThreadLocalState& local_state = local_states_[thread_index];
InitLocalStateIfNeeded(thread_index);

local_state.exec_batch_keys.Clear();
RETURN_NOT_OK(
EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, batch));

ExecBatch batch_key_for_lookups;

RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
batch, &batch_key_for_lookups));
bool has_left_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_left_payload) {
Expand All @@ -388,26 +434,24 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
}

// Collect information about all nulls in key columns.
//
std::vector<const uint8_t*> non_null_bit_vectors;
std::vector<int64_t> non_null_bit_vector_offsets;
int num_key_cols = schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::KEY);
non_null_bit_vectors.resize(num_key_cols);
non_null_bit_vector_offsets.resize(num_key_cols);
auto from_batch =
schema_mgr_->proj_maps[0].map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
for (int i = 0; i < num_key_cols; ++i) {
int input_col_id = from_batch.get(i);
const uint8_t* non_nulls = nullptr;
int64_t offset = 0;
if (batch[input_col_id].array()->buffers[0] != NULLPTR) {
non_nulls = batch[input_col_id].array()->buffers[0]->data();
offset = batch[input_col_id].array()->offset;
}
non_null_bit_vectors[i] = non_nulls;
non_null_bit_vector_offsets[i] = offset;
}
std::vector<uint8_t> all_nulls;
NullInfoFromBatch(batch_key_for_lookups, &non_null_bit_vectors,
&non_null_bit_vector_offsets, &all_nulls);

ProbeBatch_Lookup(&local_state, local_state.exec_batch_keys, non_null_bit_vectors,
ProbeBatch_Lookup(&local_state, *row_encoder_for_lookups, non_null_bit_vectors,
non_null_bit_vector_offsets, &local_state.match,
&local_state.no_match, &local_state.match_left,
&local_state.match_right);
Expand All @@ -427,7 +471,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (batches.empty()) {
hash_table_empty_ = true;
} else {
InitEncoder(1, HashJoinProjection::KEY, &hash_table_keys_);
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
bool has_payload =
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
Expand All @@ -441,11 +485,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
const ExecBatch& batch = batches[ibatch];
if (batch.length == 0) {
continue;
} else {
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(EncodeBatch(1, HashJoinProjection::KEY, &hash_table_keys_, batch));
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
batch, &hash_table_keys_, ctx_));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
Expand All @@ -456,6 +503,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
}
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
}

return Status::OK();
}

Expand Down Expand Up @@ -713,6 +765,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<uint8_t> has_match_;
bool hash_table_empty_;

// Dictionary handling
//
HashJoinDictBuildMulti dict_build_;
HashJoinDictProbeMulti dict_probe_;

std::vector<ExecBatch> left_batches_;
bool has_hash_table_;
std::mutex left_batches_mutex_;
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@
namespace arrow {
namespace compute {

// Identifiers for all different row schemas that are used in a join
//
enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 };

class ARROW_EXPORT HashJoinSchema {
public:
Status Init(JoinType join_type, const Schema& left_schema,
Expand Down
Loading