Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ if(ARROW_COMPUTE)
compute/exec/project_node.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
compute/exec/task_util.cc
compute/exec/tpch_node.cc
compute/exec/union_node.cc
Expand Down Expand Up @@ -452,6 +453,7 @@ if(ARROW_COMPUTE)
append_avx2_src(compute/exec/key_encode_avx2.cc)
append_avx2_src(compute/exec/key_hash_avx2.cc)
append_avx2_src(compute/exec/key_map_avx2.cc)
append_avx2_src(compute/exec/swiss_join_avx2.cc)
append_avx2_src(compute/exec/util_avx2.cc)

list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc)
Expand Down
102 changes: 47 additions & 55 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
Expand All @@ -98,7 +99,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
ctx_ = ctx;
join_type_ = join_type;
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
schema_[0] = proj_map_left;
schema_[1] = proj_map_right;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we get rid of the schema manager?

key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
Expand Down Expand Up @@ -141,12 +143,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
private:
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
std::vector<ValueDescr> data_types;
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] =
ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol),
ValueDescr::ARRAY);
data_types[icol] = ValueDescr(schema_[side]->data_type(projection_handle, icol),
ValueDescr::ARRAY);
}
encoder->Init(data_types, ctx_);
encoder->Clear();
Expand All @@ -157,8 +158,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
bool has_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
}
Expand All @@ -170,11 +170,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
ExecBatch projected({}, batch.length);
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
int num_cols = schema_[side]->num_cols(projection_handle);
projected.values.resize(num_cols);

auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input.get(icol)];
}
Expand Down Expand Up @@ -237,16 +236,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
ExecBatch* opt_right_payload) {
ExecBatch result({}, batch_size_next);
int num_out_cols_left =
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT);

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
auto from_payload =
schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -264,10 +260,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
from_payload =
schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
Expand All @@ -286,7 +281,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
output_batch_callback_(0, std::move(result));

// Update the counter of produced batches
//
Expand All @@ -312,13 +307,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
Expand All @@ -338,14 +333,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
};

SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap left_to_key =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay =
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay =
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
Expand Down Expand Up @@ -421,15 +416,14 @@ class HashJoinBasicImpl : public HashJoinImpl {

bool has_left =
(join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_right =
(join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0);
bool has_left_payload =
has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_right_payload =
has_right &&
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);

ThreadLocalState& local_state = local_states_[thread_index];
InitLocalStateIfNeeded(thread_index);
Expand All @@ -452,7 +446,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
Expand Down Expand Up @@ -552,8 +546,7 @@ class HashJoinBasicImpl : public HashJoinImpl {

RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
batch, &batch_key_for_lookups));
bool has_left_payload =
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_left_payload) {
local_state.exec_batch_payloads.Clear();
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
Expand All @@ -565,13 +558,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
bool use_key_batch_for_dicts =
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
dict_build_, batch, &row_encoder_for_lookups,
&batch_key_for_lookups, ctx_));
}

// Collect information about all nulls in key columns.
Expand Down Expand Up @@ -611,9 +604,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (batches.empty()) {
hash_table_empty_ = true;
} else {
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
bool has_payload =
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
}
Expand All @@ -628,11 +620,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
batch, &hash_table_keys_, ctx_));
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
&hash_table_keys_, ctx_));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
Expand All @@ -645,7 +637,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
}

return Status::OK();
Expand Down Expand Up @@ -871,7 +863,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
ExecContext* ctx_;
JoinType join_type_;
size_t num_threads_;
HashJoinSchema* schema_mgr_;
const HashJoinProjectionMaps* schema_[2];
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

bool HasDictionaries() const;

bool HasLargeBinary() const;

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
Expand Down Expand Up @@ -98,12 +102,13 @@ class ARROW_EXPORT HashJoinSchema {

class HashJoinImpl {
public:
using OutputBatchCallback = std::function<void(ExecBatch)>;
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this new `int64_t` parameter for? Doesn't it just get ignored later?

using FinishedCallback = std::function<void(int64_t)>;

virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
Expand All @@ -113,6 +118,7 @@ class HashJoinImpl {
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;

static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();

protected:
util::tracing::Span span_;
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class JoinBenchmark {
build_metadata["null_probability"] = std::to_string(settings.null_percentage);
build_metadata["min"] = std::to_string(min_build_value);
build_metadata["max"] = std::to_string(max_build_value);
build_metadata["min_length"] = "2";
build_metadata["max_length"] = "20";

std::unordered_map<std::string, std::string> probe_metadata;
probe_metadata["null_probability"] = std::to_string(settings.null_percentage);
Expand Down Expand Up @@ -124,7 +126,7 @@ class JoinBenchmark {
DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_.schema, left_keys,
*r_batches_.schema, right_keys, filter, "l_", "r_"));

join_ = *HashJoinImpl::MakeBasic();
join_ = *HashJoinImpl::MakeSwiss();

omp_set_num_threads(settings.num_threads);
auto schedule_callback = [](std::function<Status(size_t)> func) -> Status {
Expand All @@ -135,8 +137,9 @@ class JoinBenchmark {

DCHECK_OK(join_->Init(
ctx_.get(), settings.join_type, !is_parallel, settings.num_threads,
schema_mgr_.get(), {JoinKeyCmp::EQ}, std::move(filter), [](ExecBatch) {},
[](int64_t x) {}, schedule_callback));
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ},
std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {},
schedule_callback));
}

void RunJoin() {
Expand Down
Loading