Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ if(ARROW_COMPUTE)
compute/exec/project_node.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
compute/exec/task_util.cc
compute/exec/tpch_node.cc
compute/exec/union_node.cc
Expand Down Expand Up @@ -459,6 +460,7 @@ if(ARROW_COMPUTE)
append_avx2_src(compute/exec/bloom_filter_avx2.cc)
append_avx2_src(compute/exec/key_hash_avx2.cc)
append_avx2_src(compute/exec/key_map_avx2.cc)
append_avx2_src(compute/exec/swiss_join_avx2.cc)
append_avx2_src(compute/exec/util_avx2.cc)
append_avx2_src(compute/row/compare_internal_avx2.cc)
append_avx2_src(compute/row/encode_internal_avx2.cc)
Expand Down
116 changes: 57 additions & 59 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,21 @@ class HashJoinBasicImpl : public HashJoinImpl {

public:
Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
HashJoinSchema* schema_mgr, std::vector<JoinKeyCmp> key_cmp,
Expression filter, OutputBatchCallback output_batch_callback,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback, TaskScheduler* scheduler) override {
START_COMPUTE_SPAN(span_, "HashJoinBasicImpl",
{{"detail", filter.ToString()},
{"join.kind", ToString(join_type)},
{"join.kind", arrow::compute::ToString(join_type)},
{"join.threads", static_cast<uint32_t>(num_threads)}});

num_threads_ = num_threads;
ctx_ = ctx;
join_type_ = join_type;
schema_mgr_ = schema_mgr;
schema_[0] = proj_map_left;
schema_[1] = proj_map_right;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
Expand Down Expand Up @@ -82,13 +85,15 @@ class HashJoinBasicImpl : public HashJoinImpl {
scheduler_->Abort(std::move(pos_abort_callback));
}

std::string ToString() const override { return "HashJoinBasicImpl"; }

private:
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
std::vector<TypeHolder> data_types;
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] = schema_mgr_->proj_maps[side].data_type(projection_handle, icol);
data_types[icol] = schema_[side]->data_type(projection_handle, icol);
}
encoder->Init(data_types, ctx_);
encoder->Clear();
Expand All @@ -99,8 +104,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
bool has_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
}
Expand All @@ -112,11 +116,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
ExecBatch projected({}, batch.length);
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
projected.values.resize(num_cols);

auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input.get(icol)];
}
Expand Down Expand Up @@ -179,19 +182,17 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
ExecBatch* opt_right_payload) {
ExecBatch result({}, batch_size_next);
int num_out_cols_left =
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT);

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
auto from_payload =
schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinProjectionMaps::kMissingField);
bool is_from_payload =
(from_payload.get(icol) != HashJoinProjectionMaps::kMissingField);
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_left_key &&
Expand All @@ -206,13 +207,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
from_payload =
schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinProjectionMaps::kMissingField);
bool is_from_payload =
(from_payload.get(icol) != HashJoinProjectionMaps::kMissingField);
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_right_key &&
Expand All @@ -228,7 +229,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
output_batch_callback_(0, std::move(result));

// Update the counter of produced batches
//
Expand All @@ -254,13 +255,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
Expand All @@ -280,14 +281,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
};

SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap left_to_key =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
Expand Down Expand Up @@ -363,15 +364,14 @@ class HashJoinBasicImpl : public HashJoinImpl {

bool has_left =
(join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_right =
(join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_left_payload =
has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_right_payload =
has_right &&
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);

ThreadLocalState& local_state = local_states_[thread_index];
RETURN_NOT_OK(InitLocalStateIfNeeded(thread_index));
Expand All @@ -394,7 +394,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
Expand Down Expand Up @@ -494,8 +494,7 @@ class HashJoinBasicImpl : public HashJoinImpl {

RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
batch, &batch_key_for_lookups));
bool has_left_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_left_payload) {
local_state.exec_batch_payloads.Clear();
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
Expand All @@ -507,13 +506,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
bool use_key_batch_for_dicts =
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
dict_build_, batch, &row_encoder_for_lookups,
&batch_key_for_lookups, ctx_));
}

// Collect information about all nulls in key columns.
Expand Down Expand Up @@ -561,9 +560,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (batches.empty()) {
hash_table_empty_ = true;
} else {
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
bool has_payload =
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
}
Expand All @@ -578,11 +576,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
batch, &hash_table_keys_, ctx_));
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
&hash_table_keys_, ctx_));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
Expand All @@ -595,7 +593,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
}

return Status::OK();
Expand Down Expand Up @@ -740,7 +738,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecContext* ctx_;
JoinType join_type_;
size_t num_threads_;
HashJoinSchema* schema_mgr_;
const HashJoinProjectionMaps* schema_[2];
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
TaskScheduler* scheduler_;
Expand Down
75 changes: 7 additions & 68 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,90 +36,29 @@ namespace compute {

using arrow::util::AccumulationQueue;

class ARROW_EXPORT HashJoinSchema {
public:
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys, const Schema& right_schema,
const std::vector<FieldRef>& right_keys, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

static Status ValidateSchemas(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output,
const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema, ExecContext* exec_context);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);

bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }

bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
}

SchemaProjectionMaps<HashJoinProjection> proj_maps[2];

private:
static bool IsTypeSupported(const DataType& type);

Status CollectFilterColumns(std::vector<FieldRef>& left_filter,
std::vector<FieldRef>& right_filter,
const Expression& filter, const Schema& left_schema,
const Schema& right_schema);

Expression RewriteFilterToUseFilterSchema(int right_filter_offset,
const SchemaProjectionMap& left_to_filter,
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

bool PayloadIsEmpty(int side) {
ARROW_DCHECK(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}

static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
const std::vector<FieldRef>& output,
const std::vector<FieldRef>& filter,
const std::vector<FieldRef>& key);
};

class HashJoinImpl {
public:
using OutputBatchCallback = std::function<void(ExecBatch)>;
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
using BuildFinishedCallback = std::function<Status(size_t)>;
using ProbeFinishedCallback = std::function<Status(size_t)>;
using FinishedCallback = std::function<void(int64_t)>;

virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
HashJoinSchema* schema_mgr, std::vector<JoinKeyCmp> key_cmp,
Expression filter, OutputBatchCallback output_batch_callback,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback, TaskScheduler* scheduler) = 0;

virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches,
BuildFinishedCallback on_finished) = 0;
virtual Status ProbeSingleBatch(size_t thread_index, ExecBatch batch) = 0;
virtual Status ProbingFinished(size_t thread_index) = 0;
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;
virtual std::string ToString() const = 0;

static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();

protected:
util::tracing::Span span_;
Expand Down
Loading