Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 70 additions & 34 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace compute {
using internal::RowEncoder;

class HashJoinBasicImpl : public HashJoinImpl {
private:
struct ThreadLocalState;

public:
Status InputReceived(size_t thread_index, int side, ExecBatch batch) override {
if (cancelled_) {
Expand Down Expand Up @@ -91,6 +94,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_states_.resize(num_threads);
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
}

has_hash_table_ = false;
Expand Down Expand Up @@ -150,23 +154,26 @@ class HashJoinBasicImpl : public HashJoinImpl {
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
projected.values.resize(num_cols);

const int* to_input =
auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input[icol]];
projected.values[icol] = batch.values[to_input.get(icol)];
}

return encoder->EncodeAndAppend(projected);
}

void ProbeBatch_Lookup(const RowEncoder& exec_batch_keys,
void ProbeBatch_Lookup(ThreadLocalState* local_state, const RowEncoder& exec_batch_keys,
const std::vector<const uint8_t*>& non_null_bit_vectors,
const std::vector<int64_t>& non_null_bit_vector_offsets,
std::vector<int32_t>* output_match,
std::vector<int32_t>* output_no_match,
std::vector<int32_t>* output_match_left,
std::vector<int32_t>* output_match_right) {
ARROW_DCHECK(has_hash_table_);

InitHasMatchIfNeeded(local_state);

int num_cols = static_cast<int>(non_null_bit_vectors.size());
for (int32_t irow = 0; irow < exec_batch_keys.num_rows(); ++irow) {
// Apply null key filtering
Expand All @@ -191,7 +198,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (auto it = range.first; it != range.second; ++it) {
output_match_left->push_back(irow);
output_match_right->push_back(it->second);
has_match_[it->second] = 0xFF;
// Mark row in hash table as having a match
BitUtil::SetBit(local_state->has_match.data(), it->second);
has_match = true;
}
if (!has_match) {
Expand All @@ -215,46 +223,47 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_DCHECK((opt_right_payload == nullptr) ==
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));
result.values.resize(num_out_cols_left + num_out_cols_right);
const int* from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
const int* from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_left_key &&
from_key[icol] < static_cast<int>(opt_left_key->values.size()) &&
from_key.get(icol) < static_cast<int>(opt_left_key->values.size()) &&
opt_left_key->length == batch_size_next));
ARROW_DCHECK(
!is_from_payload ||
(opt_left_payload &&
from_payload[icol] < static_cast<int>(opt_left_payload->values.size()) &&
from_payload.get(icol) < static_cast<int>(opt_left_payload->values.size()) &&
opt_left_payload->length == batch_size_next));
result.values[icol] = is_from_key ? opt_left_key->values[from_key[icol]]
: opt_left_payload->values[from_payload[icol]];
result.values[icol] = is_from_key
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_right_key &&
from_key[icol] < static_cast<int>(opt_right_key->values.size()) &&
from_key.get(icol) < static_cast<int>(opt_right_key->values.size()) &&
opt_right_key->length == batch_size_next));
ARROW_DCHECK(
!is_from_payload ||
(opt_right_payload &&
from_payload[icol] < static_cast<int>(opt_right_payload->values.size()) &&
from_payload.get(icol) < static_cast<int>(opt_right_payload->values.size()) &&
opt_right_payload->length == batch_size_next));
result.values[num_out_cols_left + icol] =
is_from_key ? opt_right_key->values[from_key[icol]]
: opt_right_payload->values[from_payload[icol]];
is_from_key ? opt_right_key->values[from_key.get(icol)]
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
Expand Down Expand Up @@ -384,10 +393,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
int num_key_cols = schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::KEY);
non_null_bit_vectors.resize(num_key_cols);
non_null_bit_vector_offsets.resize(num_key_cols);
const int* from_batch =
auto from_batch =
schema_mgr_->proj_maps[0].map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
for (int i = 0; i < num_key_cols; ++i) {
int input_col_id = from_batch[i];
int input_col_id = from_batch.get(i);
const uint8_t* non_nulls = nullptr;
int64_t offset = 0;
if (batch[input_col_id].array()->buffers[0] != NULLPTR) {
Expand All @@ -398,7 +407,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
non_null_bit_vector_offsets[i] = offset;
}

ProbeBatch_Lookup(local_state.exec_batch_keys, non_null_bit_vectors,
ProbeBatch_Lookup(&local_state, local_state.exec_batch_keys, non_null_bit_vectors,
non_null_bit_vector_offsets, &local_state.match,
&local_state.no_match, &local_state.match_left,
&local_state.match_right);
Expand Down Expand Up @@ -446,11 +455,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_.insert(std::make_pair(hash_table_keys_.encoded_row(irow), irow));
}
}
if (!hash_table_empty_) {
int32_t num_rows = hash_table_keys_.num_rows();
has_match_.resize(num_rows);
memset(has_match_.data(), 0, num_rows);
}
}
return Status::OK();
}
Expand Down Expand Up @@ -563,9 +567,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
id_right.clear();
bool use_left = false;

uint8_t match_search_value = (join_type_ == JoinType::RIGHT_SEMI) ? 0xFF : 0x00;
bool match_search_value = (join_type_ == JoinType::RIGHT_SEMI);
for (int32_t row_id = start_row_id; row_id < end_row_id; ++row_id) {
if (has_match_[row_id] == match_search_value) {
if (BitUtil::GetBit(has_match_.data(), row_id) == match_search_value) {
id_right.push_back(row_id);
}
}
Expand Down Expand Up @@ -607,16 +611,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

Status ScanHashTable(size_t thread_index) {
MergeHasMatch();
return scheduler_->StartTaskGroup(thread_index, task_group_scan_,
ScanHashTable_num_tasks());
}

bool QueueBatchIfNeeded(int side, ExecBatch batch) {
if (side == 0) {
if (has_hash_table_) {
return false;
}

std::lock_guard<std::mutex> lock(left_batches_mutex_);
if (has_hash_table_) {
return false;
Expand All @@ -636,6 +637,39 @@ class HashJoinBasicImpl : public HashJoinImpl {
return ScanHashTable(thread_index);
}

void InitHasMatchIfNeeded(ThreadLocalState* local_state) {
if (local_state->is_has_match_initialized) {
return;
}
if (!hash_table_empty_) {
int32_t num_rows = hash_table_keys_.num_rows();
local_state->has_match.resize(BitUtil::BytesForBits(num_rows));
memset(local_state->has_match.data(), 0, BitUtil::BytesForBits(num_rows));
}
local_state->is_has_match_initialized = true;
}

void MergeHasMatch() {
if (hash_table_empty_) {
return;
}

int32_t num_rows = hash_table_keys_.num_rows();
has_match_.resize(BitUtil::BytesForBits(num_rows));
memset(has_match_.data(), 0, BitUtil::BytesForBits(num_rows));

for (size_t tid = 0; tid < local_states_.size(); ++tid) {
if (!local_states_[tid].is_initialized) {
continue;
}
if (!local_states_[tid].is_has_match_initialized) {
continue;
}
arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(),
0, num_rows, 0, has_match_.data());
}
}

static constexpr int64_t hash_table_scan_unit_ = 32 * 1024;
static constexpr int64_t output_batch_size_ = 32 * 1024;

Expand Down Expand Up @@ -666,6 +700,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<int32_t> no_match;
std::vector<int32_t> match_left;
std::vector<int32_t> match_right;
bool is_has_match_initialized;
std::vector<uint8_t> has_match;
};
std::vector<ThreadLocalState> local_states_;

Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
for (int i = 0; i < left_size + right_size; ++i) {
bool is_left = (i < left_size);
int side = (is_left ? 0 : 1);
int input_field_id =
proj_maps[side].map(HashJoinProjection::OUTPUT,
HashJoinProjection::INPUT)[is_left ? i : i - left_size];
int input_field_id = proj_maps[side]
.map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT)
.get(is_left ? i : i - left_size);
const std::string& input_field_name =
proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id);
const std::shared_ptr<DataType>& input_data_type =
Expand Down
33 changes: 22 additions & 11 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -921,24 +921,31 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));

Declaration join{"hashjoin", join_options};

// add left source
BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_");
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
/*slow=*/false)}});
ASSERT_OK_AND_ASSIGN(
ExecNode * l_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
/*slow=*/false)}));

// add right source
BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_");
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
/*slow=*/false)}});
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
ASSERT_OK_AND_ASSIGN(
ExecNode * r_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
/*slow=*/false)}));

ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
.AddToPlan(plan.get()));
ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
{l_source, r_source}, join_options));

AsyncGenerator<util::optional<ExecBatch>> sink_gen;
ASSERT_OK_AND_ASSIGN(
std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));

ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));

ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res));
}

Expand Down Expand Up @@ -1056,6 +1063,10 @@ TEST(HashJoin, Random) {
// print num_rows, batch_size, join_type, join_cmp
std::cout << join_type_name << " " << key_cmp_str << " ";
key_types.Print();
std::cout << " payload_l: ";
payload_types[0].Print();
std::cout << " payload_r: ";
payload_types[1].Print();
std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r
<< " batch size = " << batch_size
<< " parallel = " << (parallel ? "true" : "false");
Expand Down
Loading