diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e06fad9a1de..5d01376796b 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -411,12 +411,16 @@ if(ARROW_COMPUTE)
       compute/kernels/vector_replace.cc
       compute/kernels/vector_selection.cc
       compute/kernels/vector_sort.cc
+      compute/kernels/row_encoder.cc
       compute/exec/union_node.cc
       compute/exec/key_hash.cc
       compute/exec/key_map.cc
       compute/exec/key_compare.cc
       compute/exec/key_encode.cc
-      compute/exec/util.cc)
+      compute/exec/util.cc
+      compute/exec/hash_join.cc
+      compute/exec/hash_join_node.cc
+      compute/exec/task_util.cc)
 
 append_avx2_src(compute/kernels/aggregate_basic_avx2.cc)
 append_avx512_src(compute/kernels/aggregate_basic_avx512.cc)
diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt
index f269ec33ed5..ccc36c093e8 100644
--- a/cpp/src/arrow/compute/exec/CMakeLists.txt
+++ b/cpp/src/arrow/compute/exec/CMakeLists.txt
@@ -25,6 +25,9 @@ add_arrow_compute_test(expression_test
                        subtree_test.cc)
 
 add_arrow_compute_test(plan_test PREFIX "arrow-compute")
+add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute")
 add_arrow_compute_test(union_node_test PREFIX "arrow-compute")
+add_arrow_compute_test(util_test PREFIX "arrow-compute")
+
 add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute")
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index a53c6532591..5c64cf2fc30 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -356,6 +356,7 @@ void RegisterProjectNode(ExecFactoryRegistry*);
 void RegisterUnionNode(ExecFactoryRegistry*);
 void RegisterAggregateNode(ExecFactoryRegistry*);
 void RegisterSinkNode(ExecFactoryRegistry*);
+void RegisterHashJoinNode(ExecFactoryRegistry*);
 
 }  // namespace internal
 
@@ -369,6 +370,7 @@ ExecFactoryRegistry* default_exec_factory_registry() {
     internal::RegisterUnionNode(this);
     internal::RegisterAggregateNode(this);
     internal::RegisterSinkNode(this);
+    internal::RegisterHashJoinNode(this);
   }
 
   Result<Factory> GetFactory(const std::string& factory_name) override {
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc
new file mode 100644
index 00000000000..9500beb666a
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -0,0 +1,702 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include "arrow/compute/exec/hash_join.h" + +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/exec/task_util.h" +#include "arrow/compute/kernels/row_encoder.h" + +namespace arrow { +namespace compute { + +using internal::RowEncoder; + +class HashJoinBasicImpl : public HashJoinImpl { + public: + Status InputReceived(size_t thread_index, int side, ExecBatch batch) override { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + if (QueueBatchIfNeeded(side, batch)) { + return Status::OK(); + } else { + ARROW_DCHECK(side == 0); + return ProbeBatch(thread_index, batch); + } + } + + Status InputFinished(size_t thread_index, int side) override { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + if (side == 0) { + bool proceed; + { + std::lock_guard lock(finished_mutex_); + proceed = !left_side_finished_ && left_queue_finished_; + left_side_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK(OnLeftSideAndQueueFinished(thread_index)); + } + } else { + bool proceed; + { + std::lock_guard lock(finished_mutex_); + proceed = !right_side_finished_; + right_side_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK(OnRightSideFinished(thread_index)); + } + } + return Status::OK(); + } + + Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, + size_t num_threads, HashJoinSchema* schema_mgr, + std::vector key_cmp, OutputBatchCallback output_batch_callback, + FinishedCallback finished_callback, + TaskScheduler::ScheduleImpl schedule_task_callback) override { + num_threads = std::max(num_threads, static_cast(1)); + + ctx_ = ctx; + join_type_ = join_type; + num_threads_ = num_threads; + schema_mgr_ = schema_mgr; + key_cmp_ = std::move(key_cmp); + output_batch_callback_ = std::move(output_batch_callback); + finished_callback_ = std::move(finished_callback); + local_states_.resize(num_threads); + for (size_t i = 0; i < local_states_.size(); ++i) { + local_states_[i].is_initialized = false; + } + + has_hash_table_ = false; + num_batches_produced_.store(0); + cancelled_ = false; + right_side_finished_ = false; + left_side_finished_ = false; + left_queue_finished_ = false; + + scheduler_ = TaskScheduler::Make(); + RegisterBuildHashTable(); + RegisterProbeQueuedBatches(); + RegisterScanHashTable(); + scheduler_->RegisterEnd(); + RETURN_NOT_OK(scheduler_->StartScheduling( + 0 /*thread index*/, std::move(schedule_task_callback), + static_cast(2 * num_threads) /*concurrent tasks*/, use_sync_execution)); + + return Status::OK(); + } + + void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override { + cancelled_ = true; + scheduler_->Abort(std::move(pos_abort_callback)); + } + + private: + void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) { + std::vector data_types; + int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); + data_types.resize(num_cols); + for (int icol = 0; icol < num_cols; ++icol) { + data_types[icol] = + ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol), + ValueDescr::ARRAY); + } + encoder->Init(data_types, ctx_); + encoder->Clear(); + } + + void InitLocalStateIfNeeded(size_t thread_index) { + ThreadLocalState& local_state = local_states_[thread_index]; + if (!local_state.is_initialized) { + InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys); + bool has_payload = + (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + if (has_payload) { + InitEncoder(0, 
HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads); + } + local_state.is_initialized = true; + } + } + + Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder, + const ExecBatch& batch) { + ExecBatch projected({}, batch.length); + int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); + projected.values.resize(num_cols); + + const int* to_input = + schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_cols; ++icol) { + projected.values[icol] = batch.values[to_input[icol]]; + } + + return encoder->EncodeAndAppend(projected); + } + + void ProbeBatch_Lookup(const RowEncoder& exec_batch_keys, + const std::vector& non_null_bit_vectors, + const std::vector& non_null_bit_vector_offsets, + std::vector* output_match, + std::vector* output_no_match, + std::vector* output_match_left, + std::vector* output_match_right) { + ARROW_DCHECK(has_hash_table_); + int num_cols = static_cast(non_null_bit_vectors.size()); + for (int32_t irow = 0; irow < exec_batch_keys.num_rows(); ++irow) { + // Apply null key filtering + bool no_match = hash_table_empty_; + for (int icol = 0; icol < num_cols; ++icol) { + bool is_null = non_null_bit_vectors[icol] && + !BitUtil::GetBit(non_null_bit_vectors[icol], + non_null_bit_vector_offsets[icol] + irow); + if (key_cmp_[icol] == JoinKeyCmp::EQ && is_null) { + no_match = true; + break; + } + } + if (no_match) { + output_no_match->push_back(irow); + continue; + } + // Get all matches from hash table + bool has_match = false; + + auto range = hash_table_.equal_range(exec_batch_keys.encoded_row(irow)); + for (auto it = range.first; it != range.second; ++it) { + output_match_left->push_back(irow); + output_match_right->push_back(it->second); + has_match_[it->second] = 0xFF; + has_match = true; + } + if (!has_match) { + output_no_match->push_back(irow); + } else { + output_match->push_back(irow); + } + } + } + + void ProbeBatch_OutputOne(int64_t batch_size_next, ExecBatch* opt_left_key, + ExecBatch* opt_left_payload, ExecBatch* opt_right_key, + ExecBatch* opt_right_payload) { + ExecBatch result({}, batch_size_next); + int num_out_cols_left = + schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT); + int num_out_cols_right = + schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT); + ARROW_DCHECK((opt_left_payload == nullptr) == + (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0)); + ARROW_DCHECK((opt_right_payload == nullptr) == + (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0)); + result.values.resize(num_out_cols_left + num_out_cols_right); + const int* from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, + HashJoinProjection::KEY); + const int* from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, + HashJoinProjection::PAYLOAD); + for (int icol = 0; icol < num_out_cols_left; ++icol) { + bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField()); + bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField()); + ARROW_DCHECK(is_from_key != is_from_payload); + ARROW_DCHECK(!is_from_key || + (opt_left_key && + from_key[icol] < static_cast(opt_left_key->values.size()) && + opt_left_key->length == batch_size_next)); + ARROW_DCHECK( + !is_from_payload || + (opt_left_payload && + from_payload[icol] < static_cast(opt_left_payload->values.size()) && + opt_left_payload->length == batch_size_next)); + result.values[icol] = is_from_key ? 
opt_left_key->values[from_key[icol]] + : opt_left_payload->values[from_payload[icol]]; + } + from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, + HashJoinProjection::KEY); + from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, + HashJoinProjection::PAYLOAD); + for (int icol = 0; icol < num_out_cols_right; ++icol) { + bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField()); + bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField()); + ARROW_DCHECK(is_from_key != is_from_payload); + ARROW_DCHECK(!is_from_key || + (opt_right_key && + from_key[icol] < static_cast(opt_right_key->values.size()) && + opt_right_key->length == batch_size_next)); + ARROW_DCHECK( + !is_from_payload || + (opt_right_payload && + from_payload[icol] < static_cast(opt_right_payload->values.size()) && + opt_right_payload->length == batch_size_next)); + result.values[num_out_cols_left + icol] = + is_from_key ? opt_right_key->values[from_key[icol]] + : opt_right_payload->values[from_payload[icol]]; + } + + output_batch_callback_(std::move(result)); + + // Update the counter of produced batches + // + num_batches_produced_++; + } + + Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next, + const int32_t* opt_left_ids, const int32_t* opt_right_ids) { + if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) { + return Status::OK(); + } + + bool has_left = + (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI && + schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0); + bool has_right = + (join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI && + schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0); + bool has_left_payload = + has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + bool has_right_payload = + has_right && + (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + + ThreadLocalState& local_state = local_states_[thread_index]; + InitLocalStateIfNeeded(thread_index); + + ExecBatch left_key; + ExecBatch left_payload; + ExecBatch right_key; + ExecBatch right_payload; + if (has_left) { + ARROW_DCHECK(opt_left_ids); + ARROW_ASSIGN_OR_RAISE( + left_key, local_state.exec_batch_keys.Decode(batch_size_next, opt_left_ids)); + } + if (has_left_payload) { + ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode( + batch_size_next, opt_left_ids)); + } + if (has_right) { + ARROW_DCHECK(opt_right_ids); + ARROW_ASSIGN_OR_RAISE(right_key, + hash_table_keys_.Decode(batch_size_next, opt_right_ids)); + } + if (has_right_payload) { + ARROW_ASSIGN_OR_RAISE(right_payload, + hash_table_payloads_.Decode(batch_size_next, opt_right_ids)); + } + + ProbeBatch_OutputOne(batch_size_next, has_left ? &left_key : nullptr, + has_left_payload ? &left_payload : nullptr, + has_right ? &right_key : nullptr, + has_right_payload ? &right_payload : nullptr); + + return Status::OK(); + } + + Status ProbeBatch_OutputAll(size_t thread_index, const RowEncoder& exec_batch_keys, + const RowEncoder& exec_batch_payloads, + const std::vector& match, + const std::vector& no_match, + std::vector& match_left, + std::vector& match_right) { + if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { + // Nothing to output + return Status::OK(); + } + + if (join_type_ == JoinType::LEFT_ANTI || join_type_ == JoinType::LEFT_SEMI) { + const std::vector& out_ids = + (join_type_ == JoinType::LEFT_SEMI) ? 
match : no_match; + + for (size_t start = 0; start < out_ids.size(); start += output_batch_size_) { + int64_t batch_size_next = std::min(static_cast(out_ids.size() - start), + static_cast(output_batch_size_)); + RETURN_NOT_OK(ProbeBatch_OutputOne(thread_index, batch_size_next, + out_ids.data() + start, nullptr)); + } + } else { + if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) { + for (size_t i = 0; i < no_match.size(); ++i) { + match_left.push_back(no_match[i]); + match_right.push_back(RowEncoder::kRowIdForNulls()); + } + } + + ARROW_DCHECK(match_left.size() == match_right.size()); + + for (size_t start = 0; start < match_left.size(); start += output_batch_size_) { + int64_t batch_size_next = + std::min(static_cast(match_left.size() - start), + static_cast(output_batch_size_)); + RETURN_NOT_OK(ProbeBatch_OutputOne(thread_index, batch_size_next, + match_left.data() + start, + match_right.data() + start)); + } + } + return Status::OK(); + } + + Status ProbeBatch(size_t thread_index, const ExecBatch& batch) { + ThreadLocalState& local_state = local_states_[thread_index]; + InitLocalStateIfNeeded(thread_index); + + local_state.exec_batch_keys.Clear(); + RETURN_NOT_OK( + EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, batch)); + bool has_left_payload = + (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + if (has_left_payload) { + local_state.exec_batch_payloads.Clear(); + RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD, + &local_state.exec_batch_payloads, batch)); + } + + local_state.match.clear(); + local_state.no_match.clear(); + local_state.match_left.clear(); + local_state.match_right.clear(); + + std::vector non_null_bit_vectors; + std::vector non_null_bit_vector_offsets; + int num_key_cols = schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::KEY); + non_null_bit_vectors.resize(num_key_cols); + non_null_bit_vector_offsets.resize(num_key_cols); + const int* from_batch = + schema_mgr_->proj_maps[0].map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int i = 0; i < num_key_cols; ++i) { + int input_col_id = from_batch[i]; + const uint8_t* non_nulls = nullptr; + int64_t offset = 0; + if (batch[input_col_id].array()->buffers[0] != NULLPTR) { + non_nulls = batch[input_col_id].array()->buffers[0]->data(); + offset = batch[input_col_id].array()->offset; + } + non_null_bit_vectors[i] = non_nulls; + non_null_bit_vector_offsets[i] = offset; + } + + ProbeBatch_Lookup(local_state.exec_batch_keys, non_null_bit_vectors, + non_null_bit_vector_offsets, &local_state.match, + &local_state.no_match, &local_state.match_left, + &local_state.match_right); + + RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys, + local_state.exec_batch_payloads, local_state.match, + local_state.no_match, local_state.match_left, + local_state.match_right)); + + return Status::OK(); + } + + int64_t BuildHashTable_num_tasks() { return 1; } + + Status BuildHashTable_exec_task(size_t thread_index, int64_t /*task_id*/) { + const std::vector& batches = right_batches_; + if (batches.empty()) { + hash_table_empty_ = true; + } else { + InitEncoder(1, HashJoinProjection::KEY, &hash_table_keys_); + bool has_payload = + (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + if (has_payload) { + InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_); + } + hash_table_empty_ = true; + for (size_t ibatch = 0; ibatch < batches.size(); ++ibatch) { + if (cancelled_) { + return 
Status::Cancelled("Hash join cancelled"); + } + const ExecBatch& batch = batches[ibatch]; + if (batch.length == 0) { + continue; + } else { + hash_table_empty_ = false; + } + int32_t num_rows_before = hash_table_keys_.num_rows(); + RETURN_NOT_OK(EncodeBatch(1, HashJoinProjection::KEY, &hash_table_keys_, batch)); + if (has_payload) { + RETURN_NOT_OK( + EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch)); + } + int32_t num_rows_after = hash_table_keys_.num_rows(); + for (int32_t irow = num_rows_before; irow < num_rows_after; ++irow) { + hash_table_.insert(std::make_pair(hash_table_keys_.encoded_row(irow), irow)); + } + } + if (!hash_table_empty_) { + int32_t num_rows = hash_table_keys_.num_rows(); + has_match_.resize(num_rows); + memset(has_match_.data(), 0, num_rows); + } + } + return Status::OK(); + } + + Status BuildHashTable_on_finished(size_t thread_index) { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + + { + std::lock_guard lock(left_batches_mutex_); + has_hash_table_ = true; + } + + right_batches_.clear(); + + RETURN_NOT_OK(ProbeQueuedBatches(thread_index)); + + return Status::OK(); + } + + void RegisterBuildHashTable() { + task_group_build_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return BuildHashTable_exec_task(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { + return BuildHashTable_on_finished(thread_index); + }); + } + + Status BuildHashTable(size_t thread_index) { + return scheduler_->StartTaskGroup(thread_index, task_group_build_, + BuildHashTable_num_tasks()); + } + + int64_t ProbeQueuedBatches_num_tasks() { + return static_cast(left_batches_.size()); + } + + Status ProbeQueuedBatches_exec_task(size_t thread_index, int64_t task_id) { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + return ProbeBatch(thread_index, std::move(left_batches_[task_id])); + } + + Status ProbeQueuedBatches_on_finished(size_t thread_index) { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + + left_batches_.clear(); + + bool proceed; + { + std::lock_guard lock(finished_mutex_); + proceed = left_side_finished_ && !left_queue_finished_; + left_queue_finished_ = true; + } + if (proceed) { + RETURN_NOT_OK(OnLeftSideAndQueueFinished(thread_index)); + } + + return Status::OK(); + } + + void RegisterProbeQueuedBatches() { + task_group_queued_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return ProbeQueuedBatches_exec_task(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { + return ProbeQueuedBatches_on_finished(thread_index); + }); + } + + Status ProbeQueuedBatches(size_t thread_index) { + return scheduler_->StartTaskGroup(thread_index, task_group_queued_, + ProbeQueuedBatches_num_tasks()); + } + + int64_t ScanHashTable_num_tasks() { + if (!has_hash_table_ || hash_table_empty_) { + return 0; + } + if (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI && + join_type_ != JoinType::RIGHT_OUTER && join_type_ != JoinType::FULL_OUTER) { + return 0; + } + return BitUtil::CeilDiv(hash_table_keys_.num_rows(), hash_table_scan_unit_); + } + + Status ScanHashTable_exec_task(size_t thread_index, int64_t task_id) { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + + int32_t start_row_id = static_cast(hash_table_scan_unit_ * task_id); + int32_t end_row_id = + static_cast(std::min(static_cast(hash_table_keys_.num_rows()), + 
hash_table_scan_unit_ * (task_id + 1))); + + ThreadLocalState& local_state = local_states_[thread_index]; + InitLocalStateIfNeeded(thread_index); + + std::vector& id_left = local_state.no_match; + std::vector& id_right = local_state.match; + id_left.clear(); + id_right.clear(); + bool use_left = false; + + uint8_t match_search_value = (join_type_ == JoinType::RIGHT_SEMI) ? 0xFF : 0x00; + for (int32_t row_id = start_row_id; row_id < end_row_id; ++row_id) { + if (has_match_[row_id] == match_search_value) { + id_right.push_back(row_id); + } + } + + if (id_right.empty()) { + return Status::OK(); + } + + if (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI) { + use_left = true; + id_left.resize(id_right.size()); + for (size_t i = 0; i < id_left.size(); ++i) { + id_left[i] = RowEncoder::kRowIdForNulls(); + } + } + + RETURN_NOT_OK( + ProbeBatch_OutputOne(thread_index, static_cast(id_right.size()), + use_left ? id_left.data() : nullptr, id_right.data())); + return Status::OK(); + } + + Status ScanHashTable_on_finished(size_t thread_index) { + if (cancelled_) { + return Status::Cancelled("Hash join cancelled"); + } + finished_callback_(num_batches_produced_.load()); + return Status::OK(); + } + + void RegisterScanHashTable() { + task_group_scan_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return ScanHashTable_exec_task(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { + return ScanHashTable_on_finished(thread_index); + }); + } + + Status ScanHashTable(size_t thread_index) { + return scheduler_->StartTaskGroup(thread_index, task_group_scan_, + ScanHashTable_num_tasks()); + } + + bool QueueBatchIfNeeded(int side, ExecBatch batch) { + if (side == 0) { + if (has_hash_table_) { + return false; + } + + std::lock_guard lock(left_batches_mutex_); + if (has_hash_table_) { + return false; + } + left_batches_.emplace_back(std::move(batch)); + return true; + } else { + std::lock_guard lock(right_batches_mutex_); + right_batches_.emplace_back(std::move(batch)); + return true; + } + } + + Status OnRightSideFinished(size_t thread_index) { return BuildHashTable(thread_index); } + + Status OnLeftSideAndQueueFinished(size_t thread_index) { + return ScanHashTable(thread_index); + } + + static constexpr int64_t hash_table_scan_unit_ = 32 * 1024; + static constexpr int64_t output_batch_size_ = 32 * 1024; + + // Metadata + // + ExecContext* ctx_; + JoinType join_type_; + size_t num_threads_; + HashJoinSchema* schema_mgr_; + std::vector key_cmp_; + std::unique_ptr scheduler_; + int task_group_build_; + int task_group_queued_; + int task_group_scan_; + + // Callbacks + // + OutputBatchCallback output_batch_callback_; + FinishedCallback finished_callback_; + + // Thread local runtime state + // + struct ThreadLocalState { + bool is_initialized; + RowEncoder exec_batch_keys; + RowEncoder exec_batch_payloads; + std::vector match; + std::vector no_match; + std::vector match_left; + std::vector match_right; + }; + std::vector local_states_; + + // Shared runtime state + // + RowEncoder hash_table_keys_; + RowEncoder hash_table_payloads_; + std::unordered_multimap hash_table_; + std::vector has_match_; + bool hash_table_empty_; + + std::vector left_batches_; + bool has_hash_table_; + std::mutex left_batches_mutex_; + + std::vector right_batches_; + std::mutex right_batches_mutex_; + + std::atomic num_batches_produced_; + bool cancelled_; + + bool right_side_finished_; + bool left_side_finished_; + bool left_queue_finished_; + 
std::mutex finished_mutex_; +}; + +Result> HashJoinImpl::MakeBasic() { + std::unique_ptr impl{new HashJoinBasicImpl()}; + return std::move(impl); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h new file mode 100644 index 00000000000..a2312e09653 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/task_util.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" + +namespace arrow { +namespace compute { + +// Identifiers for all different row schemas that are used in a join +// +enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 }; + +class ARROW_EXPORT HashJoinSchema { + public: + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, const Schema& right_schema, + const std::vector& right_keys, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + static Status ValidateSchemas(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, + const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + std::shared_ptr MakeOutputSchema(const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + static int kMissingField() { + return SchemaProjectionMaps::kMissingField; + } + + SchemaProjectionMaps proj_maps[2]; + + private: + static Result> VectorDiff(const Schema& schema, + const std::vector& a, + const std::vector& b); +}; + +class HashJoinImpl { + public: + using OutputBatchCallback = std::function; + using FinishedCallback = std::function; + + virtual ~HashJoinImpl() = default; + virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, + size_t num_threads, HashJoinSchema* schema_mgr, + std::vector key_cmp, + OutputBatchCallback output_batch_callback, + FinishedCallback finished_callback, + TaskScheduler::ScheduleImpl schedule_task_callback) = 0; + virtual Status InputReceived(size_t thread_index, int side, ExecBatch batch) = 0; + 
virtual Status InputFinished(size_t thread_index, int side) = 0; + virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0; + + static Result> MakeBasic(); +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc new file mode 100644 index 00000000000..ff87aa47fae --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +Result> HashJoinSchema::VectorDiff(const Schema& schema, + const std::vector& a, + const std::vector& b) { + std::unordered_set b_paths; + for (size_t i = 0; i < b.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema)); + b_paths.insert(match[0]); + } + + std::vector result; + + for (size_t i = 0; i < a.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, a[i].FindOne(schema)); + bool is_found = (b_paths.find(match[0]) != b_paths.end()); + if (!is_found) { + result.push_back(a[i]); + } + } + + return result; +} + +Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const Schema& right_schema, + const std::vector& right_keys, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + std::vector left_output; + if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) { + const FieldVector& left_fields = left_schema.fields(); + left_output.resize(left_fields.size()); + for (size_t i = 0; i < left_fields.size(); ++i) { + left_output[i] = FieldRef(static_cast(i)); + } + } + // Repeat the same for the right side + std::vector right_output; + if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) { + const FieldVector& right_fields = right_schema.fields(); + right_output.resize(right_fields.size()); + for (size_t i = 0; i < right_fields.size(); ++i) { + right_output[i] = FieldRef(static_cast(i)); + } + } + return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys, + right_output, left_field_name_prefix, right_field_name_prefix); +} + +Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, + const Schema& right_schema, + const std::vector& 
right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output, + right_schema, right_keys, right_output, + left_field_name_prefix, right_field_name_prefix)); + + std::vector handles; + std::vector*> field_refs; + + handles.push_back(HashJoinProjection::KEY); + field_refs.push_back(&left_keys); + ARROW_ASSIGN_OR_RAISE(auto left_payload, + VectorDiff(left_schema, left_output, left_keys)); + handles.push_back(HashJoinProjection::PAYLOAD); + field_refs.push_back(&left_payload); + handles.push_back(HashJoinProjection::OUTPUT); + field_refs.push_back(&left_output); + + RETURN_NOT_OK( + proj_maps[0].Init(HashJoinProjection::INPUT, left_schema, handles, field_refs)); + + handles.clear(); + field_refs.clear(); + + handles.push_back(HashJoinProjection::KEY); + field_refs.push_back(&right_keys); + ARROW_ASSIGN_OR_RAISE(auto right_payload, + VectorDiff(right_schema, right_output, right_keys)); + handles.push_back(HashJoinProjection::PAYLOAD); + field_refs.push_back(&right_payload); + handles.push_back(HashJoinProjection::OUTPUT); + field_refs.push_back(&right_output); + + RETURN_NOT_OK( + proj_maps[1].Init(HashJoinProjection::INPUT, right_schema, handles, field_refs)); + + return Status::OK(); +} + +Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, + const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + // Checks for key fields: + // 1. Key field refs must match exactly one input field + // 2. Same number of key fields on left and right + // 3. At least one key field + // 4. Equal data types for corresponding key fields + // 5. Dictionary type is not supported in a key field + // 6. Some other data types may not be allowed in a key field + // + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("Different number of key fields on left (", left_keys.size(), + ") and right (", right_keys.size(), ") side of the join"); + } + if (left_keys.size() < 1) { + return Status::Invalid("Join key cannot be empty"); + } + for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) { + bool left_side = i < left_keys.size(); + const FieldRef& field_ref = + left_side ? left_keys[i] : right_keys[i - left_keys.size()]; + Result result = field_ref.FindOne(left_side ? left_schema : right_schema); + if (!result.ok()) { + return Status::Invalid("No match or multiple matches for key field reference ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + const FieldPath& match = result.ValueUnsafe(); + const std::shared_ptr& type = + (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); + if (type->id() == Type::DICTIONARY) { + return Status::Invalid( + "Dictionary type support for join key is not yet implemented, key field " + "reference: ", + field_ref.ToString(), left_side ? 
" on left " : " on right ", + "side of the join"); + } + if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) && + !is_binary_like(type->id())) || + is_large_binary_like(type->id())) { + return Status::Invalid("Data type ", type->ToString(), + " is not supported in join key field"); + } + } + for (size_t i = 0; i < left_keys.size(); ++i) { + const FieldRef& left_ref = left_keys[i]; + const FieldRef& right_ref = right_keys[i]; + int left_id = left_ref.FindOne(left_schema).ValueUnsafe()[0]; + int right_id = right_ref.FindOne(right_schema).ValueUnsafe()[0]; + const std::shared_ptr& left_type = left_schema.fields()[left_id]->type(); + const std::shared_ptr& right_type = right_schema.fields()[right_id]->type(); + if (!left_type->Equals(right_type)) { + return Status::Invalid("Mismatched data types for corresponding join field keys: ", + left_ref.ToString(), " of type ", left_type->ToString(), + " and ", right_ref.ToString(), " of type ", + right_type->ToString()); + } + } + + // Check for output fields: + // 1. Output field refs must match exactly one input field + // 2. At least one output field + // 3. Dictionary type is not supported in an output field + // 4. Left semi/anti join (right semi/anti join) must not output fields from right + // (left) + // 5. No name collisions in output fields after adding (potentially empty) + // prefixes to left and right output + // + if (left_output.empty() && right_output.empty()) { + return Status::Invalid("Join must output at least one field"); + } + if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) { + if (!right_output.empty()) { + return Status::Invalid( + join_type == JoinType::LEFT_SEMI ? "Left semi join " : "Left anti-semi join ", + "may not output fields from right side"); + } + } + if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { + if (!left_output.empty()) { + return Status::Invalid(join_type == JoinType::RIGHT_SEMI ? "Right semi join " + : "Right anti-semi join ", + "may not output fields from left side"); + } + } + for (size_t i = 0; i < left_output.size() + right_output.size(); ++i) { + bool left_side = i < left_output.size(); + const FieldRef& field_ref = + left_side ? left_output[i] : right_output[i - left_output.size()]; + Result result = field_ref.FindOne(left_side ? left_schema : right_schema); + if (!result.ok()) { + return Status::Invalid("No match or multiple matches for output field reference ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + const FieldPath& match = result.ValueUnsafe(); + const std::shared_ptr& type = + (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); + if (type->id() == Type::DICTIONARY) { + return Status::Invalid( + "Dictionary type support for join output field is not yet implemented, output " + "field reference: ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + } + return Status::OK(); +} + +std::shared_ptr HashJoinSchema::MakeOutputSchema( + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + std::vector> fields; + int left_size = proj_maps[0].num_cols(HashJoinProjection::OUTPUT); + int right_size = proj_maps[1].num_cols(HashJoinProjection::OUTPUT); + fields.resize(left_size + right_size); + + for (int i = 0; i < left_size + right_size; ++i) { + bool is_left = (i < left_size); + int side = (is_left ? 
0 : 1); + int input_field_id = + proj_maps[side].map(HashJoinProjection::OUTPUT, + HashJoinProjection::INPUT)[is_left ? i : i - left_size]; + const std::string& input_field_name = + proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id); + const std::shared_ptr& input_data_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, input_field_id); + + std::string output_field_name = + (is_left ? left_field_name_prefix : right_field_name_prefix) + input_field_name; + + // All fields coming out of join are marked as nullable. + fields[i] = + std::make_shared(output_field_name, input_data_type, true /*nullable*/); + } + return std::make_shared(std::move(fields)); +} + +class HashJoinNode : public ExecNode { + public: + HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options, + std::shared_ptr output_schema, + std::unique_ptr schema_mgr, + std::unique_ptr impl) + : ExecNode(plan, inputs, {"left", "right"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + join_type_(join_options.join_type), + key_cmp_(join_options.key_cmp), + schema_mgr_(std::move(schema_mgr)), + impl_(std::move(impl)) { + complete_.store(false); + } + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + // Number of input exec nodes must be 2 + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode")); + + std::unique_ptr schema_mgr = + ::arrow::internal::make_unique(); + + const auto& join_options = checked_cast(options); + + // This will also validate input schemas + if (join_options.output_all) { + RETURN_NOT_OK(schema_mgr->Init( + join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, + *(inputs[1]->output_schema()), join_options.right_keys, + join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + } else { + RETURN_NOT_OK(schema_mgr->Init( + join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, + join_options.left_output, *(inputs[1]->output_schema()), + join_options.right_keys, join_options.right_output, + join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + } + + // Generate output schema + std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( + join_options.output_prefix_for_left, join_options.output_prefix_for_right); + + // Create hash join implementation object + ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, HashJoinImpl::MakeBasic()); + + return plan->EmplaceNode(plan, inputs, join_options, + std::move(output_schema), + std::move(schema_mgr), std::move(impl)); + } + + const char* kind_name() const override { return "HashJoinNode"; } + + void InputReceived(ExecNode* input, ExecBatch batch) override { + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + + if (complete_.load()) { + return; + } + + size_t thread_index = thread_indexer_(); + int side = (input == inputs_[0]) ? 
0 : 1; + { + Status status = impl_->InputReceived(thread_index, side, std::move(batch)); + if (!status.ok()) { + StopProducing(); + ErrorIfNotOk(status); + return; + } + } + if (batch_count_[side].Increment()) { + Status status = impl_->InputFinished(thread_index, side); + if (!status.ok()) { + StopProducing(); + ErrorIfNotOk(status); + return; + } + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + StopProducing(); + outputs_[0]->ErrorReceived(this, std::move(error)); + } + + void InputFinished(ExecNode* input, int total_batches) override { + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + + size_t thread_index = thread_indexer_(); + int side = (input == inputs_[0]) ? 0 : 1; + + if (batch_count_[side].SetTotal(total_batches)) { + Status status = impl_->InputFinished(thread_index, side); + if (!status.ok()) { + StopProducing(); + ErrorIfNotOk(status); + return; + } + } + } + + Status StartProducing() override { + finished_ = Future<>::Make(); + + bool use_sync_execution = !(plan_->exec_context()->executor()); + size_t num_threads = use_sync_execution ? 1 : thread_indexer_.Capacity(); + + RETURN_NOT_OK(impl_->Init( + plan_->exec_context(), join_type_, use_sync_execution, num_threads, + schema_mgr_.get(), key_cmp_, + [this](ExecBatch batch) { this->OutputBatchCallback(batch); }, + [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); }, + [this](std::function func) -> Status { + return this->ScheduleTaskCallback(std::move(func)); + })); + return Status::OK(); + } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + StopProducing(); + } + + void StopProducing() override { + bool expected = false; + if (complete_.compare_exchange_strong(expected, true)) { + for (auto&& input : inputs_) { + input->StopProducing(this); + } + impl_->Abort([this]() { finished_.MarkFinished(); }); + } + } + + Future<> finished() override { return finished_; } + + private: + void OutputBatchCallback(ExecBatch batch) { + outputs_[0]->InputReceived(this, std::move(batch)); + } + + void FinishedCallback(int64_t total_num_batches) { + bool expected = false; + if (complete_.compare_exchange_strong(expected, true)) { + outputs_[0]->InputFinished(this, static_cast(total_num_batches)); + finished_.MarkFinished(); + } + } + + Status ScheduleTaskCallback(std::function func) { + auto executor = plan_->exec_context()->executor(); + if (executor) { + RETURN_NOT_OK(executor->Spawn([this, func] { + size_t thread_index = thread_indexer_(); + Status status = func(thread_index); + if (!status.ok()) { + StopProducing(); + ErrorIfNotOk(status); + return; + } + })); + } else { + // We should not get here in serial execution mode + ARROW_DCHECK(false); + } + return Status::OK(); + } + + private: + AtomicCounter batch_count_[2]; + std::atomic complete_; + Future<> finished_ = Future<>::MakeFinished(); + JoinType join_type_; + std::vector key_cmp_; + ThreadIndexer thread_indexer_; + std::unique_ptr schema_mgr_; + std::unique_ptr impl_; +}; + +namespace internal { + +void RegisterHashJoinNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("hashjoin", HashJoinNode::Make)); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc new file 
mode 100644 index 00000000000..4c1954d8ab2 --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -0,0 +1,1106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/testing/random.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/pcg_random.h" +#include "arrow/util/thread_pool.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +BatchesWithSchema GenerateBatchesFromString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + BatchesWithSchema out_batches{{}, schema}; + + std::vector descrs; + for (auto&& field : schema->fields()) { + descrs.emplace_back(field->type()); + } + + for (auto&& s : json_strings) { + out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + } + + size_t batch_count = out_batches.batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches.batches.push_back(out_batches.batches[i]); + } + } + + return out_batches; +} + +void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const std::vector& left_keys, + const std::vector& right_keys, + const BatchesWithSchema& exp_batches, bool parallel = false) { + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + HashJoinNodeOptions join_options{type, left_keys, right_keys}; + Declaration join{"hashjoin", join_options}; + + // add left source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + ASSERT_OK_AND_ASSIGN(auto exp_table, + TableFromExecBatches(exp_batches.schema, exp_batches.batches)); + + ASSERT_OK_AND_ASSIGN(auto out_table, TableFromExecBatches(exp_batches.schema, res)); + + if (exp_table->num_rows() == 0) { + ASSERT_EQ(exp_table->num_rows(), out_table->num_rows()); + } else { + std::vector sort_keys; + for (auto&& f : exp_batches.schema->fields()) { + sort_keys.emplace_back(f->name()); + } + ASSERT_OK_AND_ASSIGN(auto exp_table_sort_ids, + SortIndices(exp_table, SortOptions(sort_keys))); + ASSERT_OK_AND_ASSIGN(auto exp_table_sorted, Take(exp_table, exp_table_sort_ids)); + ASSERT_OK_AND_ASSIGN(auto out_table_sort_ids, + SortIndices(out_table, SortOptions(sort_keys))); + ASSERT_OK_AND_ASSIGN(auto out_table_sorted, Take(out_table, out_table_sort_ids)); + + AssertTablesEqual(*exp_table_sorted.table(), *out_table_sorted.table(), + /*same_chunk_layout=*/false, /*flatten=*/true); + } +} + +void RunNonEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + BatchesWithSchema l_batches, r_batches, exp_batches; + + int multiplicity = parallel ? 
100 : 1; + + l_batches = GenerateBatchesFromString( + l_schema, + {R"([[0,"d"], [1,"b"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", + R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + multiplicity); + + r_batches = GenerateBatchesFromString( + r_schema, + {R"([["f", 0], ["b", 1], ["b", 2]])", R"([["c", 3], ["g", 4]])", R"([["e", 5]])"}, + multiplicity); + + switch (type) { + case JoinType::LEFT_SEMI: + exp_batches = GenerateBatchesFromString( + l_schema, {R"([[1,"b"]])", R"([])", R"([[5,"b"], [6,"c"], [7,"e"], [8,"e"]])"}, + multiplicity); + break; + case JoinType::RIGHT_SEMI: + exp_batches = GenerateBatchesFromString( + r_schema, {R"([["b", 1], ["b", 2]])", R"([["c", 3]])", R"([["e", 5]])"}, + multiplicity); + break; + case JoinType::LEFT_ANTI: + exp_batches = GenerateBatchesFromString( + l_schema, {R"([[0,"d"]])", R"([[2,"d"], [3,"a"], [4,"a"]])", R"([])"}, + multiplicity); + break; + case JoinType::RIGHT_ANTI: + exp_batches = GenerateBatchesFromString( + r_schema, {R"([["f", 0]])", R"([["g", 4]])", R"([])"}, multiplicity); + break; + case JoinType::INNER: + case JoinType::LEFT_OUTER: + case JoinType::RIGHT_OUTER: + case JoinType::FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } + + CheckRunOutput(type, l_batches, r_batches, + /*left_keys=*/{{"l_str"}}, /*right_keys=*/{{"r_str"}}, exp_batches, + parallel); +} + +void RunEmptyTest(JoinType type, bool parallel) { + auto l_schema = schema({field("l_i32", int32()), field("l_str", utf8())}); + auto r_schema = schema({field("r_str", utf8()), field("r_i32", int32())}); + + int multiplicity = parallel ? 100 : 1; + + BatchesWithSchema l_empty, r_empty, l_n_empty, r_n_empty; + + l_empty = GenerateBatchesFromString(l_schema, {R"([])"}, multiplicity); + r_empty = GenerateBatchesFromString(r_schema, {R"([])"}, multiplicity); + + l_n_empty = + GenerateBatchesFromString(l_schema, {R"([[0,"d"], [1,"b"]])"}, multiplicity); + r_n_empty = GenerateBatchesFromString(r_schema, {R"([["f", 0], ["b", 1], ["b", 2]])"}, + multiplicity); + + std::vector l_keys{{"l_str"}}; + std::vector r_keys{{"r_str"}}; + + switch (type) { + case JoinType::LEFT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case JoinType::RIGHT_SEMI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_empty, parallel); + break; + case JoinType::LEFT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, l_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, l_n_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, l_empty, parallel); + break; + case JoinType::RIGHT_ANTI: + // both empty + CheckRunOutput(type, l_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // right empty + CheckRunOutput(type, l_n_empty, r_empty, l_keys, r_keys, r_empty, parallel); + // left empty + CheckRunOutput(type, l_empty, r_n_empty, l_keys, r_keys, r_n_empty, parallel); + break; + case JoinType::INNER: + case JoinType::LEFT_OUTER: + case JoinType::RIGHT_OUTER: + case JoinType::FULL_OUTER: + default: + FAIL() << "join type not implemented!"; + } 
+} + +class HashJoinTest : public testing::TestWithParam> {}; + +INSTANTIATE_TEST_SUITE_P( + HashJoinTest, HashJoinTest, + ::testing::Combine(::testing::Values(JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI), + ::testing::Values(false, true))); + +TEST_P(HashJoinTest, TestSemiJoins) { + RunNonEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +TEST_P(HashJoinTest, TestSemiJoinsEmpty) { + RunEmptyTest(std::get<0>(GetParam()), std::get<1>(GetParam())); +} + +class Random64Bit { + public: + explicit Random64Bit(random::SeedType seed) : rng_(seed) {} + uint64_t next() { return dist_(rng_); } + template + inline T from_range(const T& min_val, const T& max_val) { + return static_cast(min_val + (next() % (max_val - min_val + 1))); + } + + private: + random::pcg32_fast rng_; + std::uniform_int_distribution dist_; +}; + +struct RandomDataTypeConstraints { + int64_t data_type_enabled_mask; + // Null related + double min_null_probability; + double max_null_probability; + // Binary related + int min_binary_length; + int max_binary_length; + // String related + int min_string_length; + int max_string_length; + + void Default() { + data_type_enabled_mask = kInt1 | kInt2 | kInt4 | kInt8 | kBool | kBinary | kString; + min_null_probability = 0.0; + max_null_probability = 0.2; + min_binary_length = 1; + max_binary_length = 40; + min_string_length = 0; + max_string_length = 40; + } + + void OnlyInt(int int_size, bool allow_nulls) { + Default(); + data_type_enabled_mask = + int_size == 8 ? kInt8 : int_size == 4 ? kInt4 : int_size == 2 ? kInt2 : kInt1; + if (!allow_nulls) { + max_null_probability = 0.0; + } + } + + void OnlyString(bool allow_nulls) { + Default(); + data_type_enabled_mask = kString; + if (!allow_nulls) { + max_null_probability = 0.0; + } + } + + // Data type mask constants + static constexpr int64_t kInt1 = 1; + static constexpr int64_t kInt2 = 2; + static constexpr int64_t kInt4 = 4; + static constexpr int64_t kInt8 = 8; + static constexpr int64_t kBool = 16; + static constexpr int64_t kBinary = 32; + static constexpr int64_t kString = 64; +}; + +struct RandomDataType { + double null_probability; + bool is_fixed_length; + int fixed_length; + int min_string_length; + int max_string_length; + + static RandomDataType Random(Random64Bit& rng, + const RandomDataTypeConstraints& constraints) { + RandomDataType result; + if ((constraints.data_type_enabled_mask & constraints.kString) != 0) { + if (constraints.data_type_enabled_mask != constraints.kString) { + // Both string and fixed length types enabled + // 50% chance of string + result.is_fixed_length = ((rng.next() % 2) == 0); + } else { + result.is_fixed_length = false; + } + } else { + result.is_fixed_length = true; + } + if (constraints.max_null_probability > 0.0) { + // 25% chance of no nulls + // Uniform distribution of null probability from min to max + result.null_probability = ((rng.next() % 4) == 0) + ? 
0.0 + : static_cast(rng.next() % 1025) / 1024.0 * + (constraints.max_null_probability - + constraints.min_null_probability) + + constraints.min_null_probability; + } else { + result.null_probability = 0.0; + } + // Pick data type for fixed length + if (result.is_fixed_length) { + int log_type; + for (;;) { + log_type = rng.next() % 6; + if (constraints.data_type_enabled_mask & (1ULL << log_type)) { + break; + } + } + if ((1ULL << log_type) == constraints.kBinary) { + for (;;) { + result.fixed_length = rng.from_range(constraints.min_binary_length, + constraints.max_binary_length); + if (result.fixed_length != 1 && result.fixed_length != 2 && + result.fixed_length != 4 && result.fixed_length != 8) { + break; + } + } + } else { + result.fixed_length = + ((1ULL << log_type) == constraints.kBool) ? 0 : (1ULL << log_type); + } + } else { + // Pick parameters for string + result.min_string_length = + rng.from_range(constraints.min_string_length, constraints.max_string_length); + result.max_string_length = + rng.from_range(constraints.min_string_length, constraints.max_string_length); + if (result.min_string_length > result.max_string_length) { + std::swap(result.min_string_length, result.max_string_length); + } + } + return result; + } +}; + +struct RandomDataTypeVector { + std::vector data_types; + + void AddRandom(Random64Bit& rng, const RandomDataTypeConstraints& constraints) { + data_types.push_back(RandomDataType::Random(rng, constraints)); + } + + void Print() { + for (size_t i = 0; i < data_types.size(); ++i) { + if (!data_types[i].is_fixed_length) { + std::cout << "str[" << data_types[i].min_string_length << ".." + << data_types[i].max_string_length << "]"; + SCOPED_TRACE("str[" + std::to_string(data_types[i].min_string_length) + ".." + + std::to_string(data_types[i].max_string_length) + "]"); + } else { + std::cout << "int[" << data_types[i].fixed_length << "]"; + SCOPED_TRACE("int[" + std::to_string(data_types[i].fixed_length) + "]"); + } + } + std::cout << std::endl; + } +}; + +std::vector> GenRandomRecords( + Random64Bit& rng, const std::vector& data_types, int num_rows) { + std::vector> result; + random::RandomArrayGenerator rag(static_cast(rng.next())); + for (size_t i = 0; i < data_types.size(); ++i) { + if (data_types[i].is_fixed_length) { + switch (data_types[i].fixed_length) { + case 0: + result.push_back(rag.Boolean(num_rows, 0.5, data_types[i].null_probability)); + break; + case 1: + result.push_back(rag.UInt8(num_rows, std::numeric_limits::min(), + std::numeric_limits::max(), + data_types[i].null_probability)); + break; + case 2: + result.push_back(rag.UInt16(num_rows, std::numeric_limits::min(), + std::numeric_limits::max(), + data_types[i].null_probability)); + break; + case 4: + result.push_back(rag.UInt32(num_rows, std::numeric_limits::min(), + std::numeric_limits::max(), + data_types[i].null_probability)); + break; + case 8: + result.push_back(rag.UInt64(num_rows, std::numeric_limits::min(), + std::numeric_limits::max(), + data_types[i].null_probability)); + break; + default: + result.push_back(rag.FixedSizeBinary(num_rows, data_types[i].fixed_length, + data_types[i].null_probability)); + break; + } + } else { + result.push_back(rag.String(num_rows, data_types[i].min_string_length, + data_types[i].max_string_length, + data_types[i].null_probability)); + } + } + return result; +} + +// Index < 0 means appending null values to all columns. 
+// +void TakeUsingVector(ExecContext* ctx, const std::vector>& input, + const std::vector indices, + std::vector>* result) { + ASSERT_OK_AND_ASSIGN( + std::shared_ptr buf, + AllocateBuffer(indices.size() * sizeof(int32_t), ctx->memory_pool())); + int32_t* buf_indices = reinterpret_cast(buf->mutable_data()); + bool has_null_rows = false; + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < 0) { + buf_indices[i] = 0; + has_null_rows = true; + } else { + buf_indices[i] = indices[i]; + } + } + std::shared_ptr indices_array = MakeArray(ArrayData::Make( + int32(), indices.size(), {nullptr, std::move(buf)}, /*null_count=*/0)); + + result->resize(input.size()); + for (size_t i = 0; i < result->size(); ++i) { + ASSERT_OK_AND_ASSIGN(Datum new_array, Take(input[i], indices_array)); + (*result)[i] = new_array.make_array(); + } + if (has_null_rows) { + for (size_t i = 0; i < result->size(); ++i) { + if ((*result)[i]->data()->buffers[0] == NULLPTR) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr null_buf, + AllocateBitmap(indices.size(), ctx->memory_pool())); + uint8_t* non_nulls = null_buf->mutable_data(); + memset(non_nulls, 0xFF, BitUtil::BytesForBits(indices.size())); + if ((*result)[i]->data()->buffers.size() == 2) { + (*result)[i] = MakeArray( + ArrayData::Make((*result)[i]->type(), indices.size(), + {std::move(null_buf), (*result)[i]->data()->buffers[1]})); + } else { + (*result)[i] = MakeArray( + ArrayData::Make((*result)[i]->type(), indices.size(), + {std::move(null_buf), (*result)[i]->data()->buffers[1], + (*result)[i]->data()->buffers[2]})); + } + } + (*result)[i]->data()->SetNullCount(kUnknownNullCount); + } + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < 0) { + for (size_t col = 0; col < result->size(); ++col) { + uint8_t* non_nulls = (*result)[col]->data()->buffers[0]->mutable_data(); + BitUtil::ClearBit(non_nulls, i); + } + } + } + } +} + +// Generate random arrays given list of data type descriptions and null probabilities. +// Make sure that all generated records are unique. +// The actual number of generated records may be lower than desired because duplicates +// will be removed without replacement. 
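+// Uniqueness is checked by encoding every generated row with internal::RowEncoder and
+// keeping only the first occurrence of each encoded byte string.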
+// +std::vector> GenRandomUniqueRecords( + Random64Bit& rng, const RandomDataTypeVector& data_types, int num_desired, + int* num_actual) { + std::vector> result = + GenRandomRecords(rng, data_types.data_types, num_desired); + + ExecContext* ctx = default_exec_context(); + std::vector val_descrs; + for (size_t i = 0; i < result.size(); ++i) { + val_descrs.push_back(ValueDescr(result[i]->type(), ValueDescr::ARRAY)); + } + internal::RowEncoder encoder; + encoder.Init(val_descrs, ctx); + ExecBatch batch({}, num_desired); + batch.values.resize(result.size()); + for (size_t i = 0; i < result.size(); ++i) { + batch.values[i] = result[i]; + } + Status status = encoder.EncodeAndAppend(batch); + ARROW_DCHECK(status.ok()); + + std::unordered_map uniques; + std::vector ids; + for (int i = 0; i < num_desired; ++i) { + if (uniques.find(encoder.encoded_row(i)) == uniques.end()) { + uniques.insert(std::make_pair(encoder.encoded_row(i), i)); + ids.push_back(i); + } + } + *num_actual = static_cast(uniques.size()); + + std::vector> output; + TakeUsingVector(ctx, result, ids, &output); + return output; +} + +std::vector NullInKey(const std::vector& cmp, + const std::vector>& key) { + ARROW_DCHECK(cmp.size() <= key.size()); + ARROW_DCHECK(key.size() > 0); + std::vector result; + result.resize(key[0]->length()); + for (size_t i = 0; i < result.size(); ++i) { + result[i] = false; + } + for (size_t i = 0; i < cmp.size(); ++i) { + if (cmp[i] != JoinKeyCmp::EQ) { + continue; + } + if (key[i]->data()->buffers[0] == NULLPTR) { + continue; + } + const uint8_t* nulls = key[i]->data()->buffers[0]->data(); + if (!nulls) { + continue; + } + for (size_t j = 0; j < result.size(); ++j) { + if (!BitUtil::GetBit(nulls, j)) { + result[j] = true; + } + } + } + return result; +} + +void GenRandomJoinTables(ExecContext* ctx, Random64Bit& rng, int num_rows_l, + int num_rows_r, int num_keys_common, int num_keys_left, + int num_keys_right, const RandomDataTypeVector& key_types, + const RandomDataTypeVector& payload_left_types, + const RandomDataTypeVector& payload_right_types, + std::vector* key_id_l, std::vector* key_id_r, + std::vector>* left, + std::vector>* right) { + // Generate random keys dictionary + // + int num_keys_desired = num_keys_left + num_keys_right - num_keys_common; + int num_keys_actual = 0; + std::vector> keys = + GenRandomUniqueRecords(rng, key_types, num_keys_desired, &num_keys_actual); + + // There will be three dictionary id ranges: + // - common keys [0..num_keys_common-1] + // - keys on right that are not on left [num_keys_common..num_keys_right-1] + // - keys on left that are not on right [num_keys_right..num_keys_actual-1] + // + num_keys_common = static_cast(static_cast(num_keys_common) * + num_keys_actual / num_keys_desired); + num_keys_right = static_cast(static_cast(num_keys_right) * + num_keys_actual / num_keys_desired); + ARROW_DCHECK(num_keys_right >= num_keys_common); + num_keys_left = num_keys_actual - num_keys_right + num_keys_common; + if (num_keys_left == 0) { + ARROW_DCHECK(num_keys_common == 0 && num_keys_right > 0); + ++num_keys_left; + ++num_keys_common; + } + if (num_keys_right == 0) { + ARROW_DCHECK(num_keys_common == 0 && num_keys_left > 0); + ++num_keys_right; + ++num_keys_common; + } + ARROW_DCHECK(num_keys_left >= num_keys_common); + ARROW_DCHECK(num_keys_left + num_keys_right - num_keys_common == num_keys_actual); + + key_id_l->resize(num_rows_l); + for (int i = 0; i < num_rows_l; ++i) { + (*key_id_l)[i] = rng.from_range(0, num_keys_left - 1); + if ((*key_id_l)[i] >= 
num_keys_common) { + (*key_id_l)[i] += num_keys_right - num_keys_common; + } + } + + key_id_r->resize(num_rows_r); + for (int i = 0; i < num_rows_r; ++i) { + (*key_id_r)[i] = rng.from_range(0, num_keys_right - 1); + } + + std::vector> key_l; + std::vector> key_r; + TakeUsingVector(ctx, keys, *key_id_l, &key_l); + TakeUsingVector(ctx, keys, *key_id_r, &key_r); + std::vector> payload_l = + GenRandomRecords(rng, payload_left_types.data_types, num_rows_l); + std::vector> payload_r = + GenRandomRecords(rng, payload_right_types.data_types, num_rows_r); + + left->resize(key_l.size() + payload_l.size()); + for (size_t i = 0; i < key_l.size(); ++i) { + (*left)[i] = key_l[i]; + } + for (size_t i = 0; i < payload_l.size(); ++i) { + (*left)[key_l.size() + i] = payload_l[i]; + } + right->resize(key_r.size() + payload_r.size()); + for (size_t i = 0; i < key_r.size(); ++i) { + (*right)[i] = key_r[i]; + } + for (size_t i = 0; i < payload_r.size(); ++i) { + (*right)[key_r.size() + i] = payload_r[i]; + } +} + +std::vector> ConstructJoinOutputFromRowIds( + ExecContext* ctx, const std::vector& row_ids_l, + const std::vector& row_ids_r, const std::vector>& l, + const std::vector>& r, + const std::vector& shuffle_output_l, const std::vector& shuffle_output_r) { + std::vector> full_output_l; + std::vector> full_output_r; + TakeUsingVector(ctx, l, row_ids_l, &full_output_l); + TakeUsingVector(ctx, r, row_ids_r, &full_output_r); + std::vector> result; + result.resize(shuffle_output_l.size() + shuffle_output_r.size()); + for (size_t i = 0; i < shuffle_output_l.size(); ++i) { + result[i] = full_output_l[shuffle_output_l[i]]; + } + for (size_t i = 0; i < shuffle_output_r.size(); ++i) { + result[shuffle_output_l.size() + i] = full_output_r[shuffle_output_r[i]]; + } + return result; +} + +BatchesWithSchema TableToBatches(Random64Bit& rng, int num_batches, + const std::vector>& table, + const std::string& column_name_prefix) { + BatchesWithSchema result; + + std::vector> fields; + fields.resize(table.size()); + for (size_t i = 0; i < table.size(); ++i) { + fields[i] = std::make_shared(column_name_prefix + std::to_string(i), + table[i]->type(), true); + } + result.schema = std::make_shared(std::move(fields)); + + int64_t length = table[0]->length(); + num_batches = std::min(num_batches, static_cast(length)); + + std::vector batch_offsets; + batch_offsets.push_back(0); + batch_offsets.push_back(length); + std::unordered_set batch_offset_set; + for (int i = 0; i < num_batches - 1; ++i) { + for (;;) { + int64_t offset = rng.from_range(static_cast(1), length - 1); + if (batch_offset_set.find(offset) == batch_offset_set.end()) { + batch_offset_set.insert(offset); + batch_offsets.push_back(offset); + break; + } + } + } + std::sort(batch_offsets.begin(), batch_offsets.end()); + + for (int i = 0; i < num_batches; ++i) { + int64_t batch_offset = batch_offsets[i]; + int64_t batch_length = batch_offsets[i + 1] - batch_offsets[i]; + ExecBatch batch({}, batch_length); + batch.values.resize(table.size()); + for (size_t col = 0; col < table.size(); ++col) { + batch.values[col] = table[col]->data()->Slice(batch_offset, batch_length); + } + result.batches.push_back(batch); + } + + return result; +} + +// -1 in result means outputting all corresponding fields as nulls +// +void HashJoinSimpleInt(JoinType join_type, const std::vector& l, + const std::vector& null_in_key_l, + const std::vector& r, + const std::vector& null_in_key_r, + std::vector* result_l, std::vector* result_r, + int64_t output_length_limit, bool* length_limit_reached) { 
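+  // Reference implementation used only to validate the hash join node in this test:
+  // RIGHT_* joins are normalized to their LEFT_* counterparts by swapping sides, the
+  // build-side key ids are inserted into a std::unordered_multimap, and every probe row
+  // is looked up to emit (probe row id, build row id) pairs according to the join type.
+  // A row id of -1 means that side of the output row is filled with nulls.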
+ *length_limit_reached = false; + + bool switch_sides = false; + switch (join_type) { + case JoinType::RIGHT_SEMI: + join_type = JoinType::LEFT_SEMI; + switch_sides = true; + break; + case JoinType::RIGHT_ANTI: + join_type = JoinType::LEFT_ANTI; + switch_sides = true; + break; + case JoinType::RIGHT_OUTER: + join_type = JoinType::LEFT_OUTER; + switch_sides = true; + break; + default: + break; + } + const std::vector& build = switch_sides ? l : r; + const std::vector& probe = switch_sides ? r : l; + const std::vector& null_in_key_build = + switch_sides ? null_in_key_l : null_in_key_r; + const std::vector& null_in_key_probe = + switch_sides ? null_in_key_r : null_in_key_l; + std::vector* result_build = switch_sides ? result_l : result_r; + std::vector* result_probe = switch_sides ? result_r : result_l; + + std::unordered_multimap map_build; + for (size_t i = 0; i < build.size(); ++i) { + map_build.insert(std::make_pair(build[i], i)); + } + std::vector match_build; + match_build.resize(build.size()); + for (size_t i = 0; i < build.size(); ++i) { + match_build[i] = false; + } + + for (int32_t i = 0; i < static_cast(probe.size()); ++i) { + std::vector match_probe; + if (!null_in_key_probe[i]) { + auto range = map_build.equal_range(probe[i]); + for (auto it = range.first; it != range.second; ++it) { + if (!null_in_key_build[it->second]) { + match_probe.push_back(static_cast(it->second)); + match_build[it->second] = true; + } + } + } + switch (join_type) { + case JoinType::LEFT_SEMI: + if (!match_probe.empty()) { + result_probe->push_back(i); + result_build->push_back(-1); + } + break; + case JoinType::LEFT_ANTI: + if (match_probe.empty()) { + result_probe->push_back(i); + result_build->push_back(-1); + } + break; + case JoinType::INNER: + for (size_t j = 0; j < match_probe.size(); ++j) { + result_probe->push_back(i); + result_build->push_back(match_probe[j]); + } + break; + case JoinType::LEFT_OUTER: + case JoinType::FULL_OUTER: + if (match_probe.empty()) { + result_probe->push_back(i); + result_build->push_back(-1); + } else { + for (size_t j = 0; j < match_probe.size(); ++j) { + result_probe->push_back(i); + result_build->push_back(match_probe[j]); + } + } + break; + default: + ARROW_DCHECK(false); + break; + } + + if (static_cast(result_probe->size()) >= output_length_limit) { + *length_limit_reached = true; + return; + } + } + + if (join_type == JoinType::FULL_OUTER) { + for (int32_t i = 0; i < static_cast(build.size()); ++i) { + if (!match_build[i]) { + result_probe->push_back(-1); + result_build->push_back(i); + } + } + } +} + +std::vector GenShuffle(Random64Bit& rng, int length) { + std::vector shuffle(length); + std::iota(shuffle.begin(), shuffle.end(), 0); + for (int i = 0; i < length * 2; ++i) { + int from = rng.from_range(0, length - 1); + int to = rng.from_range(0, length - 1); + if (from != to) { + std::swap(shuffle[from], shuffle[to]); + } + } + return shuffle; +} + +void GenJoinFieldRefs(Random64Bit& rng, int num_key_fields, bool no_output, + const std::vector>& original_input, + const std::string& field_name_prefix, + std::vector>* new_input, + std::vector* keys, std::vector* output, + std::vector* output_field_ids) { + // Permute input + std::vector shuffle = GenShuffle(rng, static_cast(original_input.size())); + new_input->resize(original_input.size()); + for (size_t i = 0; i < original_input.size(); ++i) { + (*new_input)[i] = original_input[shuffle[i]]; + } + + // Compute key field refs + keys->resize(num_key_fields); + for (size_t i = 0; i < shuffle.size(); ++i) { + if 
(shuffle[i] < num_key_fields) { + bool use_by_name_ref = (rng.from_range(0, 1) == 0); + if (use_by_name_ref) { + (*keys)[shuffle[i]] = FieldRef(field_name_prefix + std::to_string(i)); + } else { + (*keys)[shuffle[i]] = FieldRef(static_cast(i)); + } + } + } + + // Compute output field refs + if (!no_output) { + int num_output = rng.from_range(1, static_cast(original_input.size() + 1)); + output_field_ids->resize(num_output); + output->resize(num_output); + for (int i = 0; i < num_output; ++i) { + int col_id = rng.from_range(0, static_cast(original_input.size() - 1)); + (*output_field_ids)[i] = col_id; + (*output)[i] = (rng.from_range(0, 1) == 0) + ? FieldRef(field_name_prefix + std::to_string(col_id)) + : FieldRef(col_id); + } + } +} + +std::shared_ptr HashJoinSimple( + ExecContext* ctx, JoinType join_type, const std::vector& cmp, + int num_key_fields, const std::vector& key_id_l, + const std::vector& key_id_r, + const std::vector>& original_l, + const std::vector>& original_r, + const std::vector>& l, + const std::vector>& r, const std::vector& output_ids_l, + const std::vector& output_ids_r, int64_t output_length_limit, + bool* length_limit_reached) { + std::vector> key_l(num_key_fields); + std::vector> key_r(num_key_fields); + for (int i = 0; i < num_key_fields; ++i) { + key_l[i] = original_l[i]; + key_r[i] = original_r[i]; + } + std::vector null_key_l = NullInKey(cmp, key_l); + std::vector null_key_r = NullInKey(cmp, key_r); + + std::vector row_ids_l; + std::vector row_ids_r; + HashJoinSimpleInt(join_type, key_id_l, null_key_l, key_id_r, null_key_r, &row_ids_l, + &row_ids_r, output_length_limit, length_limit_reached); + + std::vector> result = ConstructJoinOutputFromRowIds( + ctx, row_ids_l, row_ids_r, l, r, output_ids_l, output_ids_r); + + std::vector> fields(result.size()); + for (size_t i = 0; i < result.size(); ++i) { + fields[i] = std::make_shared("a" + std::to_string(i), result[i]->type(), true); + } + std::shared_ptr schema = std::make_shared(std::move(fields)); + return Table::Make(schema, result, result[0]->length()); +} + +void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, + const HashJoinNodeOptions& join_options, + const std::shared_ptr& output_schema, + const std::vector>& l, + const std::vector>& r, int num_batches_l, + int num_batches_r, std::shared_ptr
* output) { + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + Declaration join{"hashjoin", join_options}; + + // add left source + BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_"); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)}}); + // add right source + BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_"); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)}}); + AsyncGenerator> sink_gen; + + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res)); +} + +TEST(HashJoin, Random) { + Random64Bit rng(42); + + int num_tests = 100; + for (int test_id = 0; test_id < num_tests; ++test_id) { + bool parallel = (rng.from_range(0, 1) == 1); + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + + // Constraints + RandomDataTypeConstraints type_constraints; + type_constraints.Default(); + // type_constraints.OnlyInt(1, true); + constexpr int max_num_key_fields = 3; + constexpr int max_num_payload_fields = 3; + const char* join_type_names[] = {"LEFT_SEMI", "RIGHT_SEMI", "LEFT_ANTI", + "RIGHT_ANTI", "INNER", "LEFT_OUTER", + "RIGHT_OUTER", "FULL_OUTER"}; + std::vector join_type_options{JoinType::LEFT_SEMI, JoinType::RIGHT_SEMI, + JoinType::LEFT_ANTI, JoinType::RIGHT_ANTI, + JoinType::INNER, JoinType::LEFT_OUTER, + JoinType::RIGHT_OUTER, JoinType::FULL_OUTER}; + constexpr int join_type_mask = 0xFF; + // for INNER join only: + // constexpr int join_type_mask = 0x10; + std::vector key_cmp_options{JoinKeyCmp::EQ, JoinKeyCmp::IS}; + constexpr int key_cmp_mask = 0x03; + // for EQ only: + // constexpr int key_cmp_mask = 0x01; + constexpr int min_num_rows = 1; + const int max_num_rows = parallel ? 20000 : 2000; + constexpr int min_batch_size = 10; + constexpr int max_batch_size = 100; + + // Generate list of key field data types + int num_key_fields = rng.from_range(1, max_num_key_fields); + RandomDataTypeVector key_types; + for (int i = 0; i < num_key_fields; ++i) { + key_types.AddRandom(rng, type_constraints); + } + + // Generate lists of payload data types + int num_payload_fields[2]; + RandomDataTypeVector payload_types[2]; + for (int i = 0; i < 2; ++i) { + num_payload_fields[i] = rng.from_range(0, max_num_payload_fields); + for (int j = 0; j < num_payload_fields[i]; ++j) { + payload_types[i].AddRandom(rng, type_constraints); + } + } + + // Generate join type and comparison functions + std::vector key_cmp(num_key_fields); + std::string key_cmp_str; + for (int i = 0; i < num_key_fields; ++i) { + for (;;) { + int pos = rng.from_range(0, 1); + if ((key_cmp_mask & (1 << pos)) > 0) { + key_cmp[i] = key_cmp_options[pos]; + if (i > 0) { + key_cmp_str += "_"; + } + key_cmp_str += key_cmp[i] == JoinKeyCmp::EQ ? 
"EQ" : "IS"; + break; + } + } + } + JoinType join_type; + std::string join_type_name; + for (;;) { + int pos = rng.from_range(0, 7); + if ((join_type_mask & (1 << pos)) > 0) { + join_type = join_type_options[pos]; + join_type_name = join_type_names[pos]; + break; + } + } + + // Generate input records + int num_rows_l = rng.from_range(min_num_rows, max_num_rows); + int num_rows_r = rng.from_range(min_num_rows, max_num_rows); + int num_rows = std::min(num_rows_l, num_rows_r); + int batch_size = rng.from_range(min_batch_size, max_batch_size); + int num_keys = rng.from_range(std::max(1, num_rows / 10), num_rows); + int num_keys_r = rng.from_range(std::max(1, num_keys / 2), num_keys); + int num_keys_common = rng.from_range(std::max(1, num_keys_r / 2), num_keys_r); + int num_keys_l = num_keys_common + (num_keys - num_keys_r); + std::vector key_id_vectors[2]; + std::vector> input_arrays[2]; + GenRandomJoinTables(exec_ctx.get(), rng, num_rows_l, num_rows_r, num_keys_common, + num_keys_l, num_keys_r, key_types, payload_types[0], + payload_types[1], &(key_id_vectors[0]), &(key_id_vectors[1]), + &(input_arrays[0]), &(input_arrays[1])); + std::vector> shuffled_input_arrays[2]; + std::vector key_fields[2]; + std::vector output_fields[2]; + std::vector output_field_ids[2]; + for (int i = 0; i < 2; ++i) { + bool no_output = false; + if (i == 0) { + no_output = + join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI; + } else { + no_output = join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI; + } + GenJoinFieldRefs(rng, num_key_fields, no_output, input_arrays[i], + std::string((i == 0) ? "l_" : "r_"), &(shuffled_input_arrays[i]), + &(key_fields[i]), &(output_fields[i]), &(output_field_ids[i])); + } + + // Print test case parameters + // print num_rows, batch_size, join_type, join_cmp + std::cout << join_type_name << " " << key_cmp_str << " "; + key_types.Print(); + std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r + << " batch size = " << batch_size + << " parallel = " << (parallel ? "true" : "false"); + std::cout << std::endl; + + // Run reference join implementation + std::vector null_in_key_vectors[2]; + for (int i = 0; i < 2; ++i) { + null_in_key_vectors[i] = NullInKey(key_cmp, input_arrays[i]); + } + int64_t output_length_limit = 100000; + bool length_limit_reached = false; + std::shared_ptr
output_rows_ref = HashJoinSimple( + exec_ctx.get(), join_type, key_cmp, num_key_fields, key_id_vectors[0], + key_id_vectors[1], input_arrays[0], input_arrays[1], shuffled_input_arrays[0], + shuffled_input_arrays[1], output_field_ids[0], output_field_ids[1], + output_length_limit, &length_limit_reached); + if (length_limit_reached) { + continue; + } + + // Run tested join implementation + HashJoinNodeOptions join_options{join_type, key_fields[0], key_fields[1], + output_fields[0], output_fields[1], key_cmp}; + std::vector> output_schema_fields; + for (int i = 0; i < 2; ++i) { + for (size_t col = 0; col < output_fields[i].size(); ++col) { + output_schema_fields.push_back(std::make_shared( + std::string((i == 0) ? "l_" : "r_") + std::to_string(col), + shuffled_input_arrays[i][output_field_ids[i][col]]->type(), true)); + } + } + std::shared_ptr output_schema = + std::make_shared(std::move(output_schema_fields)); + std::shared_ptr
output_rows_test; + HashJoinWithExecPlan(rng, parallel, join_options, output_schema, + shuffled_input_arrays[0], shuffled_input_arrays[1], + static_cast(BitUtil::CeilDiv(num_rows_l, batch_size)), + static_cast(BitUtil::CeilDiv(num_rows_r, batch_size)), + &output_rows_test); + + // Compare results + AssertTablesEqual(output_rows_ref, output_rows_test); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index acc79bdfdde..3dd88210e78 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -126,5 +126,98 @@ class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { SortOptions sort_options; }; +enum class JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, + LEFT_OUTER, + RIGHT_OUTER, + FULL_OUTER +}; + +enum class JoinKeyCmp { EQ, IS }; + +/// \brief Make a node which implements join operation using hash join strategy. +class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { + public: + static constexpr const char* default_output_prefix_for_left = ""; + static constexpr const char* default_output_prefix_for_right = ""; + HashJoinNodeOptions( + JoinType in_join_type, std::vector in_left_keys, + std::vector in_right_keys, + std::string output_prefix_for_left = default_output_prefix_for_left, + std::string output_prefix_for_right = default_output_prefix_for_right) + : join_type(in_join_type), + left_keys(std::move(in_left_keys)), + right_keys(std::move(in_right_keys)), + output_all(true), + output_prefix_for_left(std::move(output_prefix_for_left)), + output_prefix_for_right(std::move(output_prefix_for_right)) { + key_cmp.resize(left_keys.size()); + for (size_t i = 0; i < left_keys.size(); ++i) { + key_cmp[i] = JoinKeyCmp::EQ; + } + } + HashJoinNodeOptions( + JoinType join_type, std::vector left_keys, + std::vector right_keys, std::vector left_output, + std::vector right_output, + std::string output_prefix_for_left = default_output_prefix_for_left, + std::string output_prefix_for_right = default_output_prefix_for_right) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)), + output_all(false), + left_output(std::move(left_output)), + right_output(std::move(right_output)), + output_prefix_for_left(std::move(output_prefix_for_left)), + output_prefix_for_right(std::move(output_prefix_for_right)) { + key_cmp.resize(left_keys.size()); + for (size_t i = 0; i < left_keys.size(); ++i) { + key_cmp[i] = JoinKeyCmp::EQ; + } + } + HashJoinNodeOptions( + JoinType join_type, std::vector left_keys, + std::vector right_keys, std::vector left_output, + std::vector right_output, std::vector key_cmp, + std::string output_prefix_for_left = default_output_prefix_for_left, + std::string output_prefix_for_right = default_output_prefix_for_right) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)), + output_all(false), + left_output(std::move(left_output)), + right_output(std::move(right_output)), + key_cmp(std::move(key_cmp)), + output_prefix_for_left(std::move(output_prefix_for_left)), + output_prefix_for_right(std::move(output_prefix_for_right)) {} + + // type of join (inner, left, semi...) 
+ JoinType join_type; + // key fields from left input + std::vector left_keys; + // key fields from right input + std::vector right_keys; + // if set all valid fields from both left and right input will be output + // (and field ref vectors for output fields will be ignored) + bool output_all; + // output fields passed from left input + std::vector left_output; + // output fields passed from right input + std::vector right_output; + // key comparison function (determines whether a null key is equal another null key or + // not) + std::vector key_cmp; + // prefix added to names of output fields coming from left input (used to distinguish, + // if necessary, between fields of the same name in left and right input and can be left + // empty if there are no name collisions) + std::string output_prefix_for_left; + // prefix added to names of output fields coming from right input + std::string output_prefix_for_right; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 1f88e0185d1..8b1c2794bb3 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -34,6 +34,7 @@ #include "arrow/testing/random.h" #include "arrow/util/async_generator.h" #include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" #include "arrow/util/thread_pool.h" #include "arrow/util/vector.h" @@ -804,5 +805,120 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { })))); } +TEST(ExecPlanExecution, SelfInnerHashJoinSink) { + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + auto input = MakeGroupableBatches(); + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + AsyncGenerator> sink_gen; + + ExecNode* left_source; + ExecNode* right_source; + for (auto source : {&left_source, &right_source}) { + ASSERT_OK_AND_ASSIGN( + *source, MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input.schema, + input.gen(parallel, /*slow=*/false)})); + } + ASSERT_OK_AND_ASSIGN( + auto left_filter, + MakeExecNode("filter", plan.get(), {left_source}, + FilterNodeOptions{greater_equal(field_ref("i32"), literal(-1))})); + ASSERT_OK_AND_ASSIGN( + auto right_filter, + MakeExecNode("filter", plan.get(), {right_source}, + FilterNodeOptions{less_equal(field_ref("i32"), literal(2))})); + + // left side: [3, "alfa"], [3, "alfa"], [12, "alfa"], [3, "beta"], [7, "beta"], + // [-1, "gama"], [5, "gama"] + // right side: [-2, "alfa"], [-8, "alfa"], [-1, "gama"] + + HashJoinNodeOptions join_opts{JoinType::INNER, + /*left_keys=*/{"str"}, + /*right_keys=*/{"str"}, "l_", "r_"}; + + ASSERT_OK_AND_ASSIGN( + auto hashjoin, + MakeExecNode("hashjoin", plan.get(), {left_filter, right_filter}, join_opts)); + + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, + SinkNodeOptions{&sink_gen})); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); + + std::vector expected = { + ExecBatchFromJSON({int32(), utf8(), int32(), utf8()}, R"([ + [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"], + [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"], + [12, "alfa", -2, "alfa"], [12, "alfa", -8, "alfa"], + [-1, "gama", -1, "gama"], [5, "gama", -1, "gama"]])")}; + + AssertExecBatchesEqual(hashjoin->output_schema(), result, expected); + } +} + +TEST(ExecPlanExecution, SelfOuterHashJoinSink) { + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + auto input = MakeGroupableBatches(); + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + AsyncGenerator> sink_gen; + + ExecNode* left_source; + ExecNode* right_source; + for (auto source : {&left_source, &right_source}) { + ASSERT_OK_AND_ASSIGN( + *source, MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input.schema, + input.gen(parallel, /*slow=*/false)})); + } + ASSERT_OK_AND_ASSIGN( + auto left_filter, + MakeExecNode("filter", plan.get(), {left_source}, + FilterNodeOptions{greater_equal(field_ref("i32"), literal(-1))})); + ASSERT_OK_AND_ASSIGN( + auto right_filter, + MakeExecNode("filter", plan.get(), {right_source}, + FilterNodeOptions{less_equal(field_ref("i32"), literal(2))})); + + // left side: [3, "alfa"], [3, "alfa"], [12, "alfa"], [3, "beta"], [7, "beta"], + // [-1, "gama"], [5, "gama"] + // right side: [-2, "alfa"], [-8, "alfa"], [-1, "gama"] + + HashJoinNodeOptions join_opts{JoinType::FULL_OUTER, + /*left_keys=*/{"str"}, + /*right_keys=*/{"str"}, "l_", "r_"}; + + ASSERT_OK_AND_ASSIGN( + auto hashjoin, + MakeExecNode("hashjoin", plan.get(), {left_filter, right_filter}, join_opts)); + + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, + SinkNodeOptions{&sink_gen})); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); + + std::vector expected = { + ExecBatchFromJSON({int32(), utf8(), int32(), utf8()}, R"([ + [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"], + [3, "alfa", -2, "alfa"], [3, "alfa", -8, "alfa"], + [12, "alfa", -2, "alfa"], [12, "alfa", -8, "alfa"], + [3, "beta", null, null], [7, "beta", null, null], + [-1, "gama", -1, "gama"], [5, "gama", -1, "gama"]])")}; + + AssertExecBatchesEqual(hashjoin->output_schema(), result, expected); + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h new file mode 100644 index 00000000000..046120714f7 --- /dev/null +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -0,0 +1,210 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/compute/exec/key_encode.h" // for KeyColumnMetadata +#include "arrow/type.h" // for DataType, FieldRef, Field and Schema +#include "arrow/util/mutex.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +/// Helper class for managing different projections of the same row schema. +/// Used to efficiently map any field in one projection to a corresponding field in +/// another projection. +/// Materialized mappings are generated lazily at the time of the first access. +/// Thread-safe apart from initialization. 
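+///
+/// A minimal usage sketch (hypothetical projection enum, schema and field refs, not
+/// part of this patch):
+///
+///   enum class MyProj { INPUT, KEY };
+///   std::vector<FieldRef> key_refs = {FieldRef("i32")};
+///   SchemaProjectionMaps<MyProj> maps;
+///   ARROW_RETURN_NOT_OK(maps.Init(MyProj::INPUT, *input_schema,
+///                                 {MyProj::KEY}, {&key_refs}));
+///   // Position of the first key column within the full input schema:
+///   int input_col = maps.map(MyProj::KEY, MyProj::INPUT)[0];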
+template +class SchemaProjectionMaps { + public: + static constexpr int kMissingField = -1; + + Status Init(ProjectionIdEnum full_schema_handle, const Schema& schema, + const std::vector& projection_handles, + const std::vector*>& projections) { + ARROW_DCHECK(projection_handles.size() == projections.size()); + RegisterSchema(full_schema_handle, schema); + for (size_t i = 0; i < projections.size(); ++i) { + ARROW_RETURN_NOT_OK( + RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema)); + } + RegisterEnd(); + return Status::OK(); + } + + int num_cols(ProjectionIdEnum schema_handle) const { + int id = schema_id(schema_handle); + return static_cast(schemas_[id].second.size()); + } + + const KeyEncoder::KeyColumnMetadata& column_metadata(ProjectionIdEnum schema_handle, + int field_id) const { + return field(schema_handle, field_id).column_metadata; + } + + const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const { + return field(schema_handle, field_id).field_name; + } + + const std::shared_ptr& data_type(ProjectionIdEnum schema_handle, + int field_id) const { + return field(schema_handle, field_id).data_type; + } + + const int* map(ProjectionIdEnum from, ProjectionIdEnum to) { + int id_from = schema_id(from); + int id_to = schema_id(to); + int num_schemas = static_cast(schemas_.size()); + int pos = id_from * num_schemas + id_to; + const int* ptr = mapping_ptrs_[pos]; + if (!ptr) { + auto guard = mutex_.Lock(); // acquire the lock + if (!ptr) { + GenerateMap(id_from, id_to); + } + ptr = mapping_ptrs_[pos]; + } + return ptr; + } + + protected: + struct FieldInfo { + int field_path; + std::string field_name; + std::shared_ptr data_type; + KeyEncoder::KeyColumnMetadata column_metadata; + }; + + void RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { + std::vector out_fields; + const FieldVector& in_fields = schema.fields(); + out_fields.resize(in_fields.size()); + for (size_t i = 0; i < in_fields.size(); ++i) { + const std::string& name = in_fields[i]->name(); + const std::shared_ptr& type = in_fields[i]->type(); + out_fields[i].field_path = static_cast(i); + out_fields[i].field_name = name; + out_fields[i].data_type = type; + out_fields[i].column_metadata = ColumnMetadataFromDataType(type); + } + schemas_.push_back(std::make_pair(handle, out_fields)); + } + + Status RegisterProjectedSchema(ProjectionIdEnum handle, + const std::vector& selected_fields, + const Schema& full_schema) { + std::vector out_fields; + const FieldVector& in_fields = full_schema.fields(); + out_fields.resize(selected_fields.size()); + for (size_t i = 0; i < selected_fields.size(); ++i) { + // All fields must be found in schema without ambiguity + ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema)); + const std::string& name = in_fields[match[0]]->name(); + const std::shared_ptr& type = in_fields[match[0]]->type(); + out_fields[i].field_path = match[0]; + out_fields[i].field_name = name; + out_fields[i].data_type = type; + out_fields[i].column_metadata = ColumnMetadataFromDataType(type); + } + schemas_.push_back(std::make_pair(handle, out_fields)); + return Status::OK(); + } + + void RegisterEnd() { + size_t size = schemas_.size(); + mapping_ptrs_.resize(size * size); + mapping_bufs_.resize(size * size); + } + + KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType( + const std::shared_ptr& type) { + if (type->id() == Type::DICTIONARY) { + auto bit_width = checked_cast(*type).bit_width(); + ARROW_DCHECK(bit_width % 8 == 0); + return 
KeyEncoder::KeyColumnMetadata(true, bit_width / 8); + } else if (type->id() == Type::BOOL) { + return KeyEncoder::KeyColumnMetadata(true, 0); + } else if (is_fixed_width(type->id())) { + return KeyEncoder::KeyColumnMetadata( + true, checked_cast(*type).bit_width() / 8); + } else if (is_binary_like(type->id())) { + return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t)); + } else { + ARROW_DCHECK(false); + return KeyEncoder::KeyColumnMetadata(true, 0); + } + } + + int schema_id(ProjectionIdEnum schema_handle) const { + for (size_t i = 0; i < schemas_.size(); ++i) { + if (schemas_[i].first == schema_handle) { + return static_cast(i); + } + } + // We should never get here + ARROW_DCHECK(false); + return -1; + } + + const FieldInfo& field(ProjectionIdEnum schema_handle, int field_id) const { + int id = schema_id(schema_handle); + const std::vector& field_infos = schemas_[id].second; + return field_infos[field_id]; + } + + void GenerateMap(int id_from, int id_to) { + int num_schemas = static_cast(schemas_.size()); + int pos = id_from * num_schemas + id_to; + + int num_cols_from = static_cast(schemas_[id_from].second.size()); + int num_cols_to = static_cast(schemas_[id_to].second.size()); + mapping_bufs_[pos].resize(num_cols_from); + const std::vector& fields_from = schemas_[id_from].second; + const std::vector& fields_to = schemas_[id_to].second; + for (int i = 0; i < num_cols_from; ++i) { + int field_id = kMissingField; + for (int j = 0; j < num_cols_to; ++j) { + if (fields_from[i].field_path == fields_to[j].field_path) { + field_id = j; + // If there are multiple matches for the same input field, + // it will be mapped to the first match. + break; + } + } + mapping_bufs_[pos][i] = field_id; + } + mapping_ptrs_[pos] = mapping_bufs_[pos].data(); + } + + std::vector mapping_ptrs_; + std::vector> mapping_bufs_; + // vector used as a mapping from ProjectionIdEnum to fields + std::vector>> schemas_; + util::Mutex mutex_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/task_util.cc b/cpp/src/arrow/compute/exec/task_util.cc new file mode 100644 index 00000000000..5693400ae91 --- /dev/null +++ b/cpp/src/arrow/compute/exec/task_util.cc @@ -0,0 +1,406 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/exec/task_util.h" + +#include +#include + +#include "arrow/util/logging.h" + +namespace arrow { +namespace compute { + +class TaskSchedulerImpl : public TaskScheduler { + public: + TaskSchedulerImpl(); + int RegisterTaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl) override; + void RegisterEnd() override; + Status StartTaskGroup(size_t thread_id, int group_id, int64_t total_num_tasks) override; + Status ExecuteMore(size_t thread_id, int num_tasks_to_execute, + bool execute_all) override; + Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl, + int num_concurrent_tasks, bool use_sync_execution) override; + void Abort(AbortContinuationImpl impl) override; + + private: + // Task group state transitions progress one way. + // Seeing an old version of the state by a thread is a valid situation. + // + enum class TaskGroupState : int { + NOT_READY, + READY, + ALL_TASKS_STARTED, + ALL_TASKS_FINISHED + }; + + struct TaskGroup { + TaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl) + : task_impl_(std::move(task_impl)), + cont_impl_(std::move(cont_impl)), + state_(TaskGroupState::NOT_READY), + num_tasks_present_(0) { + num_tasks_started_.value.store(0); + num_tasks_finished_.value.store(0); + } + TaskGroup(const TaskGroup& src) + : task_impl_(src.task_impl_), + cont_impl_(src.cont_impl_), + state_(TaskGroupState::NOT_READY), + num_tasks_present_(0) { + ARROW_DCHECK(src.state_ == TaskGroupState::NOT_READY); + num_tasks_started_.value.store(0); + num_tasks_finished_.value.store(0); + } + TaskImpl task_impl_; + TaskGroupContinuationImpl cont_impl_; + + TaskGroupState state_; + int64_t num_tasks_present_; + + AtomicWithPadding num_tasks_started_; + AtomicWithPadding num_tasks_finished_; + }; + + std::vector> PickTasks(int num_tasks, int start_task_group = 0); + Status ExecuteTask(size_t thread_id, int group_id, int64_t task_id, + bool* task_group_finished); + bool PostExecuteTask(size_t thread_id, int group_id); + Status OnTaskGroupFinished(size_t thread_id, int group_id, + bool* all_task_groups_finished); + Status ScheduleMore(size_t thread_id, int num_tasks_finished = 0); + + bool use_sync_execution_; + int num_concurrent_tasks_; + ScheduleImpl schedule_impl_; + AbortContinuationImpl abort_cont_impl_; + + std::vector task_groups_; + bool aborted_; + bool register_finished_; + std::mutex mutex_; // Mutex protecting task_groups_ (state_ and num_tasks_present_ + // fields), aborted_ flag and register_finished_ flag + + AtomicWithPadding num_tasks_to_schedule_; +}; + +TaskSchedulerImpl::TaskSchedulerImpl() + : use_sync_execution_(false), + num_concurrent_tasks_(0), + aborted_(false), + register_finished_(false) { + num_tasks_to_schedule_.value.store(0); +} + +int TaskSchedulerImpl::RegisterTaskGroup(TaskImpl task_impl, + TaskGroupContinuationImpl cont_impl) { + int result = static_cast(task_groups_.size()); + task_groups_.emplace_back(std::move(task_impl), std::move(cont_impl)); + return result; +} + +void TaskSchedulerImpl::RegisterEnd() { + std::lock_guard lock(mutex_); + + register_finished_ = true; +} + +Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id, + int64_t total_num_tasks) { + ARROW_DCHECK(group_id >= 0 && group_id < static_cast(task_groups_.size())); + TaskGroup& task_group = task_groups_[group_id]; + + bool aborted = false; + bool all_tasks_finished = false; + { + std::lock_guard lock(mutex_); + + aborted = aborted_; + + if (task_group.state_ == TaskGroupState::NOT_READY) { + 
task_group.num_tasks_present_ = total_num_tasks; + if (total_num_tasks == 0) { + task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; + all_tasks_finished = true; + } + task_group.state_ = TaskGroupState::READY; + } + } + + if (!aborted && all_tasks_finished) { + bool all_task_groups_finished = false; + RETURN_NOT_OK(OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); + if (all_task_groups_finished) { + return Status::OK(); + } + } + + if (!aborted) { + return ScheduleMore(thread_id); + } else { + return Status::Cancelled("Scheduler cancelled"); + } +} + +std::vector> TaskSchedulerImpl::PickTasks(int num_tasks, + int start_task_group) { + std::vector> result; + for (size_t i = 0; i < task_groups_.size(); ++i) { + int task_group_id = static_cast((start_task_group + i) % (task_groups_.size())); + TaskGroup& task_group = task_groups_[task_group_id]; + + if (task_group.state_ != TaskGroupState::READY) { + continue; + } + + int num_tasks_remaining = num_tasks - static_cast(result.size()); + int64_t start_task = + task_group.num_tasks_started_.value.fetch_add(num_tasks_remaining); + if (start_task >= task_group.num_tasks_present_) { + continue; + } + + int num_tasks_current_group = num_tasks_remaining; + if (start_task + num_tasks_current_group >= task_group.num_tasks_present_) { + { + std::lock_guard lock(mutex_); + if (task_group.state_ == TaskGroupState::READY) { + task_group.state_ = TaskGroupState::ALL_TASKS_STARTED; + } + } + num_tasks_current_group = + static_cast(task_group.num_tasks_present_ - start_task); + } + + for (int64_t task_id = start_task; task_id < start_task + num_tasks_current_group; + ++task_id) { + result.push_back(std::make_pair(task_group_id, task_id)); + } + + if (static_cast(result.size()) == num_tasks) { + break; + } + } + + return result; +} + +Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id, + bool* task_group_finished) { + if (!aborted_) { + RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id)); + } + *task_group_finished = PostExecuteTask(thread_id, group_id); + return Status::OK(); +} + +bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) { + int64_t total = task_groups_[group_id].num_tasks_present_; + int64_t prev_finished = task_groups_[group_id].num_tasks_finished_.value.fetch_add(1); + bool all_tasks_finished = (prev_finished + 1 == total); + return all_tasks_finished; +} + +Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id, + bool* all_task_groups_finished) { + bool aborted = false; + { + std::lock_guard lock(mutex_); + + aborted = aborted_; + TaskGroup& task_group = task_groups_[group_id]; + task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; + *all_task_groups_finished = true; + for (size_t i = 0; i < task_groups_.size(); ++i) { + if (task_groups_[i].state_ != TaskGroupState::ALL_TASKS_FINISHED) { + *all_task_groups_finished = false; + break; + } + } + } + + if (aborted && *all_task_groups_finished) { + abort_cont_impl_(); + return Status::Cancelled("Scheduler cancelled"); + } + if (!aborted) { + RETURN_NOT_OK(task_groups_[group_id].cont_impl_(thread_id)); + } + return Status::OK(); +} + +Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute, + bool execute_all) { + num_tasks_to_execute = std::max(1, num_tasks_to_execute); + + int last_id = 0; + for (;;) { + if (aborted_) { + return Status::Cancelled("Scheduler cancelled"); + } + + // Pick next bundle of tasks + const auto& tasks = 
PickTasks(num_tasks_to_execute, last_id); + if (tasks.empty()) { + break; + } + last_id = tasks.back().first; + + // Execute picked tasks immediately + for (size_t i = 0; i < tasks.size(); ++i) { + int group_id = tasks[i].first; + int64_t task_id = tasks[i].second; + bool task_group_finished = false; + Status status = ExecuteTask(thread_id, group_id, task_id, &task_group_finished); + if (!status.ok()) { + // Mark the remaining picked tasks as finished + for (size_t j = i + 1; j < tasks.size(); ++j) { + if (PostExecuteTask(thread_id, tasks[j].first)) { + bool all_task_groups_finished = false; + RETURN_NOT_OK( + OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); + if (all_task_groups_finished) { + return Status::OK(); + } + } + } + return status; + } else { + if (task_group_finished) { + bool all_task_groups_finished = false; + RETURN_NOT_OK( + OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); + if (all_task_groups_finished) { + return Status::OK(); + } + } + } + } + + if (!execute_all) { + num_tasks_to_execute -= static_cast(tasks.size()); + if (num_tasks_to_execute == 0) { + break; + } + } + } + + return Status::OK(); +} + +Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedule_impl, + int num_concurrent_tasks, + bool use_sync_execution) { + schedule_impl_ = std::move(schedule_impl); + use_sync_execution_ = use_sync_execution; + num_concurrent_tasks_ = num_concurrent_tasks; + num_tasks_to_schedule_.value += num_concurrent_tasks; + return ScheduleMore(thread_id); +} + +Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished) { + if (aborted_) { + return Status::Cancelled("Scheduler cancelled"); + } + + ARROW_DCHECK(register_finished_); + + if (use_sync_execution_) { + return ExecuteMore(thread_id, 1, true); + } + + int num_new_tasks = num_tasks_finished; + for (;;) { + int expected = num_tasks_to_schedule_.value.load(); + if (num_tasks_to_schedule_.value.compare_exchange_strong(expected, 0)) { + num_new_tasks += expected; + break; + } + } + if (num_new_tasks == 0) { + return Status::OK(); + } + + const auto& tasks = PickTasks(num_new_tasks); + if (static_cast(tasks.size()) < num_new_tasks) { + num_tasks_to_schedule_.value += num_new_tasks - static_cast(tasks.size()); + } + + for (size_t i = 0; i < tasks.size(); ++i) { + int group_id = tasks[i].first; + int64_t task_id = tasks[i].second; + RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id) -> Status { + RETURN_NOT_OK(ScheduleMore(thread_id, 1)); + + bool task_group_finished = false; + RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished)); + + if (task_group_finished) { + bool all_task_groups_finished = false; + return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished); + } + + return Status::OK(); + })); + } + + return Status::OK(); +} + +void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) { + bool all_finished = true; + { + std::lock_guard lock(mutex_); + aborted_ = true; + abort_cont_impl_ = std::move(impl); + if (register_finished_) { + for (size_t i = 0; i < task_groups_.size(); ++i) { + TaskGroup& task_group = task_groups_[i]; + if (task_group.state_ == TaskGroupState::NOT_READY) { + task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; + } else if (task_group.state_ == TaskGroupState::READY) { + int64_t expected = task_group.num_tasks_started_.value.load(); + for (;;) { + if (task_group.num_tasks_started_.value.compare_exchange_strong( + expected, 
task_group.num_tasks_present_)) { + break; + } + } + int64_t before_add = task_group.num_tasks_finished_.value.fetch_add( + task_group.num_tasks_present_ - expected); + if (before_add >= expected) { + task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; + } else { + all_finished = false; + task_group.state_ = TaskGroupState::ALL_TASKS_STARTED; + } + } + } + } + } + if (all_finished) { + abort_cont_impl_(); + } +} + +std::unique_ptr TaskScheduler::Make() { + std::unique_ptr impl{new TaskSchedulerImpl()}; + return std::move(impl); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/task_util.h b/cpp/src/arrow/compute/exec/task_util.h new file mode 100644 index 00000000000..44540d255df --- /dev/null +++ b/cpp/src/arrow/compute/exec/task_util.h @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/status.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace compute { + +// Atomic value surrounded by padding bytes to avoid cache line invalidation +// whenever it is modified by a concurrent thread on a different CPU core. +// +template +class AtomicWithPadding { + private: + static constexpr int kCacheLineSize = 64; + uint8_t padding_before[kCacheLineSize]; + + public: + std::atomic value; + + private: + uint8_t padding_after[kCacheLineSize]; +}; + +// Used for asynchronous execution of operations that can be broken into +// a fixed number of symmetric tasks that can be executed concurrently. +// +// Implements priorities between multiple such operations, called task groups. +// +// Allows to specify the maximum number of in-flight tasks at any moment. +// +// Also allows for executing next pending tasks immediately using a caller thread. +// +class TaskScheduler { + public: + using TaskImpl = std::function; + using TaskGroupContinuationImpl = std::function; + using ScheduleImpl = std::function; + using AbortContinuationImpl = std::function; + + virtual ~TaskScheduler() = default; + + // Order in which task groups are registered represents priorities of their tasks + // (the first group has the highest priority). + // + // Returns task group identifier that is used to request operations on the task group. 
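+  //
+  // Typical call sequence (a sketch, not enforced by this interface): Make() the
+  // scheduler, RegisterTaskGroup() once per group, RegisterEnd(), then StartScheduling()
+  // with a callback that forwards tasks to an executor, and StartTaskGroup() once the
+  // number of tasks in a group is known.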
+ virtual int RegisterTaskGroup(TaskImpl task_impl, + TaskGroupContinuationImpl cont_impl) = 0; + + virtual void RegisterEnd() = 0; + + // total_num_tasks may be zero, in which case task group continuation will be executed + // immediately + virtual Status StartTaskGroup(size_t thread_id, int group_id, + int64_t total_num_tasks) = 0; + + // Execute given number of tasks immediately using caller thread + virtual Status ExecuteMore(size_t thread_id, int num_tasks_to_execute, + bool execute_all) = 0; + + // Begin scheduling tasks using provided callback and + // the limit on the number of in-flight tasks at any moment. + // + // Scheduling will continue as long as there are waiting tasks. + // + // It will automatically resume whenever new task group gets started. + virtual Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl, + int num_concurrent_tasks, bool use_sync_execution) = 0; + + // Abort scheduling and execution. + // Used in case of being notified about unrecoverable error for the entire query. + virtual void Abort(AbortContinuationImpl impl) = 0; + + static std::unique_ptr Make(); +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 7062d0260ad..03fe96df966 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -30,10 +30,13 @@ #include #include +#include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/util.h" #include "arrow/datum.h" #include "arrow/record_batch.h" +#include "arrow/table.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/type.h" @@ -198,5 +201,37 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, return out; } +Result> SortTableOnAllFields(const std::shared_ptr
<Table>& tab) {
+  std::vector<SortKey> sort_keys;
+  for (auto&& f : tab->schema()->fields()) {
+    sort_keys.emplace_back(f->name());
+  }
+  ARROW_ASSIGN_OR_RAISE(auto sort_ids, SortIndices(tab, SortOptions(sort_keys)));
+  ARROW_ASSIGN_OR_RAISE(auto tab_sorted, Take(tab, sort_ids));
+  return tab_sorted.table();
+}
+
+void AssertTablesEqual(const std::shared_ptr<Table>
& exp,
+                       const std::shared_ptr<Table>
& act) {
+  ASSERT_EQ(exp->num_columns(), act->num_columns());
+  if (exp->num_rows() == 0) {
+    ASSERT_EQ(exp->num_rows(), act->num_rows());
+  } else {
+    ASSERT_OK_AND_ASSIGN(auto exp_sorted, SortTableOnAllFields(exp));
+    ASSERT_OK_AND_ASSIGN(auto act_sorted, SortTableOnAllFields(act));
+
+    AssertTablesEqual(*exp_sorted, *act_sorted,
+                      /*same_chunk_layout=*/false, /*flatten=*/true);
+  }
+}
+
+void AssertExecBatchesEqual(const std::shared_ptr<Schema>& schema,
+                            const std::vector<ExecBatch>& exp,
+                            const std::vector<ExecBatch>& act) {
+  ASSERT_OK_AND_ASSIGN(auto exp_tab, TableFromExecBatches(schema, exp));
+  ASSERT_OK_AND_ASSIGN(auto act_tab, TableFromExecBatches(schema, act));
+  AssertTablesEqual(exp_tab, act_tab);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index e21dfd673ec..2ee140a5348 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -93,5 +93,17 @@ ARROW_TESTING_EXPORT
 BatchesWithSchema MakeRandomBatches(const std::shared_ptr<Schema>& schema,
                                     int num_batches = 10, int batch_size = 4);
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>
& tab);
+
+ARROW_TESTING_EXPORT
+void AssertTablesEqual(const std::shared_ptr<Table>
& exp,
+                       const std::shared_ptr<Table>
& act); + +ARROW_TESTING_EXPORT +void AssertExecBatchesEqual(const std::shared_ptr& schema, + const std::vector& exp, + const std::vector& act); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index e2fe61a63c6..64060d44564 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -311,5 +311,26 @@ Result> TableFromExecBatches( return Table::FromRecordBatches(schema, batches); } +size_t ThreadIndexer::operator()() { + auto id = std::this_thread::get_id(); + + auto guard = mutex_.Lock(); // acquire the lock + const auto& id_index = *id_to_index_.emplace(id, id_to_index_.size()).first; + + return Check(id_index.second); +} + +size_t ThreadIndexer::Capacity() { + static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity() + 1; + return max_size; +} + +size_t ThreadIndexer::Check(size_t thread_index) { + DCHECK_LT(thread_index, Capacity()) + << "thread index " << thread_index << " is out of range [0, " << Capacity() << ")"; + + return thread_index; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 63f3315f7e0..ed89bece6a3 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include "arrow/buffer.h" @@ -29,7 +31,9 @@ #include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" +#include "arrow/util/mutex.h" #include "arrow/util/optional.h" +#include "arrow/util/thread_pool.h" #if defined(__clang__) || defined(__GNUC__) #define BYTESWAP(x) __builtin_bswap64(x) @@ -242,5 +246,18 @@ class AtomicCounter { std::atomic complete_{false}; }; +class ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + util::Mutex mutex_; + std::unordered_map id_to_index_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util_test.cc b/cpp/src/arrow/compute/exec/util_test.cc new file mode 100644 index 00000000000..9659fb2e9de --- /dev/null +++ b/cpp/src/arrow/compute/exec/util_test.cc @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::Eq; + +namespace arrow { +namespace compute { + +const char* kLeftPrefix = "left."; +const char* kRightPrefix = "right."; + +TEST(FieldMap, Trivial) { + HashJoinSchema schema_mgr; + + auto left = schema({field("i32", int32())}); + auto right = schema({field("i32", int32())}); + + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix, + kRightPrefix)); + + auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + EXPECT_THAT(*output, Eq(Schema({ + field("left.i32", int32()), + field("right.i32", int32()), + }))); + + auto i = + schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT); + EXPECT_EQ(i[0], 0); +} + +TEST(FieldMap, TrivialDuplicates) { + HashJoinSchema schema_mgr; + + auto left = schema({field("i32", int32())}); + auto right = schema({field("i32", int32())}); + + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, "", "")); + + auto output = schema_mgr.MakeOutputSchema("", ""); + EXPECT_THAT(*output, Eq(Schema({ + field("i32", int32()), + field("i32", int32()), + }))); + + auto i = + schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT); + EXPECT_EQ(i[0], 0); +} + +TEST(FieldMap, SingleKeyField) { + HashJoinSchema schema_mgr; + + auto left = schema({field("i32", int32()), field("str", utf8())}); + auto right = schema({field("f32", float32()), field("i32", int32())}); + + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix, + kRightPrefix)); + + EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::INPUT), 2); + EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::INPUT), 2); + EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::KEY), 1); + EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::KEY), 1); + EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::OUTPUT), 2); + EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::OUTPUT), 2); + + auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + EXPECT_THAT(*output, Eq(Schema({ + field("left.i32", int32()), + field("left.str", utf8()), + field("right.f32", float32()), + field("right.i32", int32()), + }))); + + auto i = + schema_mgr.proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT); + EXPECT_EQ(i[0], 0); +} + +TEST(FieldMap, TwoKeyFields) { + HashJoinSchema schema_mgr; + + auto left = schema({ + field("i32", int32()), + field("str", utf8()), + field("bool", boolean()), + }); + auto right = schema({ + field("i32", int32()), + field("str", utf8()), + field("f32", float32()), + field("f64", float64()), + }); + + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32", "str"}, *right, + {"i32", "str"}, kLeftPrefix, kRightPrefix)); + + auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + EXPECT_THAT(*output, Eq(Schema({ + field("left.i32", int32()), + field("left.str", utf8()), + field("left.bool", boolean()), + + field("right.i32", int32()), + field("right.str", utf8()), + field("right.f32", float32()), + field("right.f64", float64()), + }))); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 6e7c074573d..73c8f9d26c0 100644 --- 
a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -37,6 +37,7 @@ #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/aggregate_var_std_internal.h" #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/record_batch.h" #include "arrow/util/bit_run_reader.h" @@ -60,344 +61,6 @@ namespace compute { namespace internal { namespace { -struct KeyEncoder { - // the first byte of an encoded key is used to indicate nullity - static constexpr bool kExtraByteForNull = true; - - static constexpr uint8_t kNullByte = 1; - static constexpr uint8_t kValidByte = 0; - - virtual ~KeyEncoder() = default; - - virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0; - - virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0; - - virtual Result> Decode(uint8_t** encoded_bytes, - int32_t length, MemoryPool*) = 0; - - // extract the null bitmap from the leading nullity bytes of encoded keys - static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, - std::shared_ptr* null_bitmap, int32_t* null_count) { - // first count nulls to determine if a null bitmap is necessary - *null_count = 0; - for (int32_t i = 0; i < length; ++i) { - *null_count += (encoded_bytes[i][0] == kNullByte); - } - - if (*null_count > 0) { - ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(length, pool)); - uint8_t* validity = (*null_bitmap)->mutable_data(); - - FirstTimeBitmapWriter writer(validity, 0, length); - for (int32_t i = 0; i < length; ++i) { - if (encoded_bytes[i][0] == kValidByte) { - writer.Set(); - } else { - writer.Clear(); - } - writer.Next(); - encoded_bytes[i] += 1; - } - writer.Finish(); - } else { - for (int32_t i = 0; i < length; ++i) { - encoded_bytes[i] += 1; - } - } - return Status ::OK(); - } -}; - -struct BooleanKeyEncoder : KeyEncoder { - static constexpr int kByteWidth = 1; - - void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override { - for (int64_t i = 0; i < batch_length; ++i) { - lengths[i] += kByteWidth + kExtraByteForNull; - } - } - - Status Encode(const Datum& data, int64_t batch_length, - uint8_t** encoded_bytes) override { - if (data.is_array()) { - VisitArrayDataInline( - *data.array(), - [&](bool value) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; - }, - [&] { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - *encoded_ptr++ = 0; - }); - } else { - const auto& scalar = data.scalar_as(); - bool value = scalar.is_valid && scalar.value; - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; - } - } - return Status::OK(); - } - - Result> Decode(uint8_t** encoded_bytes, int32_t length, - MemoryPool* pool) override { - std::shared_ptr null_buf; - int32_t null_count; - RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); - - ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBitmap(length, pool)); - - uint8_t* raw_output = key_buf->mutable_data(); - for (int32_t i = 0; i < length; ++i) { - auto& encoded_ptr = encoded_bytes[i]; - BitUtil::SetBitTo(raw_output, i, encoded_ptr[0] != 0); - encoded_ptr += 1; - } - - return ArrayData::Make(boolean(), length, {std::move(null_buf), std::move(key_buf)}, - null_count); - } -}; - -struct 
FixedWidthKeyEncoder : KeyEncoder { - explicit FixedWidthKeyEncoder(std::shared_ptr type) - : type_(std::move(type)), - byte_width_(checked_cast(*type_).bit_width() / 8) {} - - void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override { - for (int64_t i = 0; i < batch_length; ++i) { - lengths[i] += byte_width_ + kExtraByteForNull; - } - } - - Status Encode(const Datum& data, int64_t batch_length, - uint8_t** encoded_bytes) override { - if (data.is_array()) { - const auto& arr = *data.array(); - ArrayData viewed(fixed_size_binary(byte_width_), arr.length, arr.buffers, - arr.null_count, arr.offset); - - VisitArrayDataInline( - viewed, - [&](util::string_view bytes) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - memcpy(encoded_ptr, bytes.data(), byte_width_); - encoded_ptr += byte_width_; - }, - [&] { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - memset(encoded_ptr, 0, byte_width_); - encoded_ptr += byte_width_; - }); - } else { - const auto& scalar = data.scalar_as(); - if (scalar.is_valid) { - const util::string_view data = scalar.view(); - DCHECK_EQ(data.size(), static_cast(byte_width_)); - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - memcpy(encoded_ptr, data.data(), data.size()); - encoded_ptr += byte_width_; - } - } else { - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - memset(encoded_ptr, 0, byte_width_); - encoded_ptr += byte_width_; - } - } - } - return Status::OK(); - } - - Result> Decode(uint8_t** encoded_bytes, int32_t length, - MemoryPool* pool) override { - std::shared_ptr null_buf; - int32_t null_count; - RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); - - ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length * byte_width_, pool)); - - uint8_t* raw_output = key_buf->mutable_data(); - for (int32_t i = 0; i < length; ++i) { - auto& encoded_ptr = encoded_bytes[i]; - std::memcpy(raw_output, encoded_ptr, byte_width_); - encoded_ptr += byte_width_; - raw_output += byte_width_; - } - - return ArrayData::Make(type_, length, {std::move(null_buf), std::move(key_buf)}, - null_count); - } - - std::shared_ptr type_; - int byte_width_; -}; - -struct DictionaryKeyEncoder : FixedWidthKeyEncoder { - DictionaryKeyEncoder(std::shared_ptr type, MemoryPool* pool) - : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {} - - Status Encode(const Datum& data, int64_t batch_length, - uint8_t** encoded_bytes) override { - auto dict = data.is_array() ? MakeArray(data.array()->dictionary) - : data.scalar_as().value.dictionary; - if (dictionary_) { - if (!dictionary_->Equals(dict)) { - // TODO(bkietz) unify if necessary. 
For now, just error if any batch's dictionary - // differs from the first we saw for this key - return Status::NotImplemented("Unifying differing dictionaries"); - } - } else { - dictionary_ = std::move(dict); - } - if (data.is_array()) { - return FixedWidthKeyEncoder::Encode(data, batch_length, encoded_bytes); - } - return FixedWidthKeyEncoder::Encode(data.scalar_as().value.index, - batch_length, encoded_bytes); - } - - Result> Decode(uint8_t** encoded_bytes, int32_t length, - MemoryPool* pool) override { - ARROW_ASSIGN_OR_RAISE(auto data, - FixedWidthKeyEncoder::Decode(encoded_bytes, length, pool)); - - if (dictionary_) { - data->dictionary = dictionary_->data(); - } else { - ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(type_, 0)); - data->dictionary = dict->data(); - } - - data->type = type_; - return data; - } - - MemoryPool* pool_; - std::shared_ptr dictionary_; -}; - -template -struct VarLengthKeyEncoder : KeyEncoder { - using Offset = typename T::offset_type; - - void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override { - if (data.is_array()) { - int64_t i = 0; - VisitArrayDataInline( - *data.array(), - [&](util::string_view bytes) { - lengths[i++] += - kExtraByteForNull + sizeof(Offset) + static_cast(bytes.size()); - }, - [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); }); - } else { - const Scalar& scalar = *data.scalar(); - const int32_t buffer_size = - scalar.is_valid ? static_cast(UnboxScalar::Unbox(scalar).size()) - : 0; - for (int64_t i = 0; i < batch_length; i++) { - lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size; - } - } - } - - Status Encode(const Datum& data, int64_t batch_length, - uint8_t** encoded_bytes) override { - if (data.is_array()) { - VisitArrayDataInline( - *data.array(), - [&](util::string_view bytes) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - util::SafeStore(encoded_ptr, static_cast(bytes.size())); - encoded_ptr += sizeof(Offset); - memcpy(encoded_ptr, bytes.data(), bytes.size()); - encoded_ptr += bytes.size(); - }, - [&] { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - util::SafeStore(encoded_ptr, static_cast(0)); - encoded_ptr += sizeof(Offset); - }); - } else { - const auto& scalar = data.scalar_as(); - const auto& bytes = *scalar.value; - if (scalar.is_valid) { - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - util::SafeStore(encoded_ptr, static_cast(bytes.size())); - encoded_ptr += sizeof(Offset); - memcpy(encoded_ptr, bytes.data(), bytes.size()); - encoded_ptr += bytes.size(); - } - } else { - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - util::SafeStore(encoded_ptr, static_cast(0)); - encoded_ptr += sizeof(Offset); - } - } - } - return Status::OK(); - } - - Result> Decode(uint8_t** encoded_bytes, int32_t length, - MemoryPool* pool) override { - std::shared_ptr null_buf; - int32_t null_count; - RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); - - Offset length_sum = 0; - for (int32_t i = 0; i < length; ++i) { - length_sum += util::SafeLoadAs(encoded_bytes[i]); - } - - ARROW_ASSIGN_OR_RAISE(auto offset_buf, - AllocateBuffer(sizeof(Offset) * (1 + length), pool)); - ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum)); - - auto raw_offsets = reinterpret_cast(offset_buf->mutable_data()); - auto raw_keys = key_buf->mutable_data(); - - Offset current_offset = 0; - 
for (int32_t i = 0; i < length; ++i) { - raw_offsets[i] = current_offset; - - auto key_length = util::SafeLoadAs(encoded_bytes[i]); - encoded_bytes[i] += sizeof(Offset); - - memcpy(raw_keys + current_offset, encoded_bytes[i], key_length); - encoded_bytes[i] += key_length; - - current_offset += key_length; - } - raw_offsets[length] = current_offset; - - return ArrayData::Make( - type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)}, - null_count); - } - - explicit VarLengthKeyEncoder(std::shared_ptr type) : type_(std::move(type)) {} - - std::shared_ptr type_; -}; - struct GrouperImpl : Grouper { static Result> Make(const std::vector& keys, ExecContext* ctx) { diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/kernels/row_encoder.cc new file mode 100644 index 00000000000..4f61022a481 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/row_encoder.cc @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/row_encoder.h" + +#include "arrow/util/bitmap_writer.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +namespace arrow { + +using internal::FirstTimeBitmapWriter; + +namespace compute { +namespace internal { + +// extract the null bitmap from the leading nullity bytes of encoded keys +Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, + std::shared_ptr* null_bitmap, + int32_t* null_count) { + // first count nulls to determine if a null bitmap is necessary + *null_count = 0; + for (int32_t i = 0; i < length; ++i) { + *null_count += (encoded_bytes[i][0] == kNullByte); + } + + if (*null_count > 0) { + ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(length, pool)); + uint8_t* validity = (*null_bitmap)->mutable_data(); + + FirstTimeBitmapWriter writer(validity, 0, length); + for (int32_t i = 0; i < length; ++i) { + if (encoded_bytes[i][0] == kValidByte) { + writer.Set(); + } else { + writer.Clear(); + } + writer.Next(); + encoded_bytes[i] += 1; + } + writer.Finish(); + } else { + for (int32_t i = 0; i < length; ++i) { + encoded_bytes[i] += 1; + } + } + return Status ::OK(); +} + +void BooleanKeyEncoder::AddLength(const Datum& data, int64_t batch_length, + int32_t* lengths) { + for (int64_t i = 0; i < batch_length; ++i) { + lengths[i] += kByteWidth + kExtraByteForNull; + } +} + +void BooleanKeyEncoder::AddLengthNull(int32_t* length) { + *length += kByteWidth + kExtraByteForNull; +} + +Status BooleanKeyEncoder::Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) { + if (data.is_array()) { + VisitArrayDataInline( + *data.array(), + [&](bool value) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + *encoded_ptr++ = value; + }, + [&] { + 
auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + *encoded_ptr++ = 0; + }); + } else { + const auto& scalar = data.scalar_as(); + bool value = scalar.is_valid && scalar.value; + for (int64_t i = 0; i < batch_length; i++) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + *encoded_ptr++ = value; + } + } + return Status::OK(); +} + +void BooleanKeyEncoder::EncodeNull(uint8_t** encoded_bytes) { + auto& encoded_ptr = *encoded_bytes; + *encoded_ptr++ = kNullByte; + *encoded_ptr++ = 0; +} + +Result> BooleanKeyEncoder::Decode(uint8_t** encoded_bytes, + int32_t length, + MemoryPool* pool) { + std::shared_ptr null_buf; + int32_t null_count; + RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); + + ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBitmap(length, pool)); + + uint8_t* raw_output = key_buf->mutable_data(); + for (int32_t i = 0; i < length; ++i) { + auto& encoded_ptr = encoded_bytes[i]; + BitUtil::SetBitTo(raw_output, i, encoded_ptr[0] != 0); + encoded_ptr += 1; + } + + return ArrayData::Make(boolean(), length, {std::move(null_buf), std::move(key_buf)}, + null_count); +} + +void FixedWidthKeyEncoder::AddLength(const Datum& data, int64_t batch_length, + int32_t* lengths) { + for (int64_t i = 0; i < batch_length; ++i) { + lengths[i] += byte_width_ + kExtraByteForNull; + } +} + +void FixedWidthKeyEncoder::AddLengthNull(int32_t* length) { + *length += byte_width_ + kExtraByteForNull; +} + +Status FixedWidthKeyEncoder::Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) { + if (data.is_array()) { + const auto& arr = *data.array(); + ArrayData viewed(fixed_size_binary(byte_width_), arr.length, arr.buffers, + arr.null_count, arr.offset); + + VisitArrayDataInline( + viewed, + [&](util::string_view bytes) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + memcpy(encoded_ptr, bytes.data(), byte_width_); + encoded_ptr += byte_width_; + }, + [&] { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + memset(encoded_ptr, 0, byte_width_); + encoded_ptr += byte_width_; + }); + } else { + const auto& scalar = data.scalar_as(); + if (scalar.is_valid) { + const util::string_view data = scalar.view(); + DCHECK_EQ(data.size(), static_cast(byte_width_)); + for (int64_t i = 0; i < batch_length; i++) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + memcpy(encoded_ptr, data.data(), data.size()); + encoded_ptr += byte_width_; + } + } else { + for (int64_t i = 0; i < batch_length; i++) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + memset(encoded_ptr, 0, byte_width_); + encoded_ptr += byte_width_; + } + } + } + return Status::OK(); +} + +void FixedWidthKeyEncoder::EncodeNull(uint8_t** encoded_bytes) { + auto& encoded_ptr = *encoded_bytes; + *encoded_ptr++ = kNullByte; + memset(encoded_ptr, 0, byte_width_); + encoded_ptr += byte_width_; +} + +Result> FixedWidthKeyEncoder::Decode(uint8_t** encoded_bytes, + int32_t length, + MemoryPool* pool) { + std::shared_ptr null_buf; + int32_t null_count; + RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); + + ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length * byte_width_, pool)); + + uint8_t* raw_output = key_buf->mutable_data(); + for (int32_t i = 0; i < length; ++i) { + auto& encoded_ptr = encoded_bytes[i]; + std::memcpy(raw_output, encoded_ptr, byte_width_); + encoded_ptr += byte_width_; + raw_output += byte_width_; + } + + return 
ArrayData::Make(type_, length, {std::move(null_buf), std::move(key_buf)}, + null_count); +} + +Status DictionaryKeyEncoder::Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) { + auto dict = data.is_array() ? MakeArray(data.array()->dictionary) + : data.scalar_as().value.dictionary; + if (dictionary_) { + if (!dictionary_->Equals(dict)) { + // TODO(bkietz) unify if necessary. For now, just error if any batch's dictionary + // differs from the first we saw for this key + return Status::NotImplemented("Unifying differing dictionaries"); + } + } else { + dictionary_ = std::move(dict); + } + if (data.is_array()) { + return FixedWidthKeyEncoder::Encode(data, batch_length, encoded_bytes); + } + return FixedWidthKeyEncoder::Encode(data.scalar_as().value.index, + batch_length, encoded_bytes); +} + +Result> DictionaryKeyEncoder::Decode(uint8_t** encoded_bytes, + int32_t length, + MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(auto data, + FixedWidthKeyEncoder::Decode(encoded_bytes, length, pool)); + + if (dictionary_) { + data->dictionary = dictionary_->data(); + } else { + ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(type_, 0)); + data->dictionary = dict->data(); + } + + data->type = type_; + return data; +} + +void RowEncoder::Init(const std::vector& column_types, ExecContext* ctx) { + ctx_ = ctx; + encoders_.resize(column_types.size()); + + for (size_t i = 0; i < column_types.size(); ++i) { + const auto& column_type = column_types[i].type; + + if (column_type->id() == Type::BOOL) { + encoders_[i] = std::make_shared(); + continue; + } + + if (column_type->id() == Type::DICTIONARY) { + encoders_[i] = + std::make_shared(column_type, ctx->memory_pool()); + continue; + } + + if (is_fixed_width(column_type->id())) { + encoders_[i] = std::make_shared(column_type); + continue; + } + + if (is_binary_like(column_type->id())) { + encoders_[i] = std::make_shared>(column_type); + continue; + } + + if (is_large_binary_like(column_type->id())) { + encoders_[i] = std::make_shared>(column_type); + continue; + } + + // We should not get here + ARROW_DCHECK(false); + } + + int32_t total_length = 0; + for (size_t i = 0; i < column_types.size(); ++i) { + encoders_[i]->AddLengthNull(&total_length); + } + encoded_nulls_.resize(total_length); + uint8_t* buf_ptr = encoded_nulls_.data(); + for (size_t i = 0; i < column_types.size(); ++i) { + encoders_[i]->EncodeNull(&buf_ptr); + } +} + +void RowEncoder::Clear() { + offsets_.clear(); + bytes_.clear(); +} + +Status RowEncoder::EncodeAndAppend(const ExecBatch& batch) { + if (offsets_.empty()) { + offsets_.resize(1); + offsets_[0] = 0; + } + size_t length_before = offsets_.size() - 1; + offsets_.resize(length_before + batch.length + 1); + for (int64_t i = 0; i < batch.length; ++i) { + offsets_[length_before + 1 + i] = 0; + } + + for (int i = 0; i < batch.num_values(); ++i) { + encoders_[i]->AddLength(batch[i], batch.length, offsets_.data() + length_before + 1); + } + + int32_t total_length = offsets_[length_before]; + for (int64_t i = 0; i < batch.length; ++i) { + total_length += offsets_[length_before + 1 + i]; + offsets_[length_before + 1 + i] = total_length; + } + + bytes_.resize(total_length); + std::vector buf_ptrs(batch.length); + for (int64_t i = 0; i < batch.length; ++i) { + buf_ptrs[i] = bytes_.data() + offsets_[length_before + i]; + } + + for (int i = 0; i < batch.num_values(); ++i) { + RETURN_NOT_OK(encoders_[i]->Encode(batch[i], batch.length, buf_ptrs.data())); + } + + return Status::OK(); +} + +Result RowEncoder::Decode(int64_t num_rows, 
const int32_t* row_ids) { + ExecBatch out({}, num_rows); + + std::vector buf_ptrs(num_rows); + for (int64_t i = 0; i < num_rows; ++i) { + buf_ptrs[i] = (row_ids[i] == kRowIdForNulls()) ? encoded_nulls_.data() + : bytes_.data() + offsets_[row_ids[i]]; + } + + out.values.resize(encoders_.size()); + for (size_t i = 0; i < encoders_.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + out.values[i], + encoders_[i]->Decode(buf_ptrs.data(), static_cast(num_rows), + ctx_->memory_pool())); + } + + return out; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/row_encoder.h b/cpp/src/arrow/compute/kernels/row_encoder.h new file mode 100644 index 00000000000..49356c5e9fc --- /dev/null +++ b/cpp/src/arrow/compute/kernels/row_encoder.h @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace internal { + +struct KeyEncoder { + // the first byte of an encoded key is used to indicate nullity + static constexpr bool kExtraByteForNull = true; + + static constexpr uint8_t kNullByte = 1; + static constexpr uint8_t kValidByte = 0; + + virtual ~KeyEncoder() = default; + + virtual void AddLength(const Datum&, int64_t batch_length, int32_t* lengths) = 0; + + virtual void AddLengthNull(int32_t* length) = 0; + + virtual Status Encode(const Datum&, int64_t batch_length, uint8_t** encoded_bytes) = 0; + + virtual void EncodeNull(uint8_t** encoded_bytes) = 0; + + virtual Result> Decode(uint8_t** encoded_bytes, + int32_t length, MemoryPool*) = 0; + + // extract the null bitmap from the leading nullity bytes of encoded keys + static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, + std::shared_ptr* null_bitmap, int32_t* null_count); +}; + +struct BooleanKeyEncoder : KeyEncoder { + static constexpr int kByteWidth = 1; + + void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override; + + void AddLengthNull(int32_t* length) override; + + Status Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) override; + + void EncodeNull(uint8_t** encoded_bytes) override; + + Result> Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) override; +}; + +struct FixedWidthKeyEncoder : KeyEncoder { + explicit FixedWidthKeyEncoder(std::shared_ptr type) + : type_(std::move(type)), + byte_width_(checked_cast(*type_).bit_width() / 8) {} + + void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override; + + void AddLengthNull(int32_t* length) override; + + Status Encode(const Datum& 
data, int64_t batch_length, + uint8_t** encoded_bytes) override; + + void EncodeNull(uint8_t** encoded_bytes) override; + + Result> Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) override; + + std::shared_ptr type_; + int byte_width_; +}; + +struct DictionaryKeyEncoder : FixedWidthKeyEncoder { + DictionaryKeyEncoder(std::shared_ptr type, MemoryPool* pool) + : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {} + + Status Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) override; + + Result> Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) override; + + MemoryPool* pool_; + std::shared_ptr dictionary_; +}; + +template +struct VarLengthKeyEncoder : KeyEncoder { + using Offset = typename T::offset_type; + + void AddLength(const Datum& data, int64_t batch_length, int32_t* lengths) override { + if (data.is_array()) { + int64_t i = 0; + VisitArrayDataInline( + *data.array(), + [&](util::string_view bytes) { + lengths[i++] += + kExtraByteForNull + sizeof(Offset) + static_cast(bytes.size()); + }, + [&] { lengths[i++] += kExtraByteForNull + sizeof(Offset); }); + } else { + const Scalar& scalar = *data.scalar(); + const int32_t buffer_size = + scalar.is_valid ? static_cast(UnboxScalar::Unbox(scalar).size()) + : 0; + for (int64_t i = 0; i < batch_length; i++) { + lengths[i] += kExtraByteForNull + sizeof(Offset) + buffer_size; + } + } + } + + void AddLengthNull(int32_t* length) override { + *length += kExtraByteForNull + sizeof(Offset); + } + + Status Encode(const Datum& data, int64_t batch_length, + uint8_t** encoded_bytes) override { + if (data.is_array()) { + VisitArrayDataInline( + *data.array(), + [&](util::string_view bytes) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + util::SafeStore(encoded_ptr, static_cast(bytes.size())); + encoded_ptr += sizeof(Offset); + memcpy(encoded_ptr, bytes.data(), bytes.size()); + encoded_ptr += bytes.size(); + }, + [&] { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + util::SafeStore(encoded_ptr, static_cast(0)); + encoded_ptr += sizeof(Offset); + }); + } else { + const auto& scalar = data.scalar_as(); + const auto& bytes = *scalar.value; + if (scalar.is_valid) { + for (int64_t i = 0; i < batch_length; i++) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + util::SafeStore(encoded_ptr, static_cast(bytes.size())); + encoded_ptr += sizeof(Offset); + memcpy(encoded_ptr, bytes.data(), bytes.size()); + encoded_ptr += bytes.size(); + } + } else { + for (int64_t i = 0; i < batch_length; i++) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + util::SafeStore(encoded_ptr, static_cast(0)); + encoded_ptr += sizeof(Offset); + } + } + } + return Status::OK(); + } + + void EncodeNull(uint8_t** encoded_bytes) override { + auto& encoded_ptr = *encoded_bytes; + *encoded_ptr++ = kNullByte; + util::SafeStore(encoded_ptr, static_cast(0)); + encoded_ptr += sizeof(Offset); + } + + Result> Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) override { + std::shared_ptr null_buf; + int32_t null_count; + ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); + + Offset length_sum = 0; + for (int32_t i = 0; i < length; ++i) { + length_sum += util::SafeLoadAs(encoded_bytes[i]); + } + + ARROW_ASSIGN_OR_RAISE(auto offset_buf, + AllocateBuffer(sizeof(Offset) * (1 + length), pool)); + ARROW_ASSIGN_OR_RAISE(auto key_buf, AllocateBuffer(length_sum)); + + auto raw_offsets = 
reinterpret_cast<Offset*>(offset_buf->mutable_data());
+    auto raw_keys = key_buf->mutable_data();
+
+    Offset current_offset = 0;
+    for (int32_t i = 0; i < length; ++i) {
+      raw_offsets[i] = current_offset;
+
+      auto key_length = util::SafeLoadAs<Offset>(encoded_bytes[i]);
+      encoded_bytes[i] += sizeof(Offset);
+
+      memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
+      encoded_bytes[i] += key_length;
+
+      current_offset += key_length;
+    }
+    raw_offsets[length] = current_offset;
+
+    return ArrayData::Make(
+        type_, length, {std::move(null_buf), std::move(offset_buf), std::move(key_buf)},
+        null_count);
+  }
+
+  explicit VarLengthKeyEncoder(std::shared_ptr<DataType> type) : type_(std::move(type)) {}
+
+  std::shared_ptr<DataType> type_;
+};
+
+class ARROW_EXPORT RowEncoder {
+ public:
+  static constexpr int kRowIdForNulls() { return -1; }
+
+  void Init(const std::vector<ValueDescr>& column_types, ExecContext* ctx);
+  void Clear();
+  Status EncodeAndAppend(const ExecBatch& batch);
+  Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
+
+  inline std::string encoded_row(int32_t i) const {
+    if (i == kRowIdForNulls()) {
+      return std::string(reinterpret_cast<const char*>(encoded_nulls_.data()),
+                         encoded_nulls_.size());
+    }
+    int32_t row_length = offsets_[i + 1] - offsets_[i];
+    return std::string(reinterpret_cast<const char*>(bytes_.data() + offsets_[i]),
+                       row_length);
+  }
+
+  int32_t num_rows() const {
+    return offsets_.size() == 0 ? 0 : static_cast<int32_t>(offsets_.size() - 1);
+  }
+
+ private:
+  ExecContext* ctx_;
+  std::vector<std::shared_ptr<KeyEncoder>> encoders_;
+  std::vector<int32_t> offsets_;
+  std::vector<uint8_t> bytes_;
+  std::vector<uint8_t> encoded_nulls_;
+};
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
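
Illustrative note, not part of the patch: the key encoders above all prepend one nullity byte (kValidByte = 0, kNullByte = 1) before the value bytes, and the variable-length encoder adds a length prefix. The concrete byte values below assume a little-endian machine, since the value and length bytes are copied in native byte order.

// Encoded-row byte layout produced by the encoders declared above
// (shown for a single column; a row is the concatenation of its columns):
//
//   int32 value 5 (FixedWidthKeyEncoder, byte_width_ = 4):
//     00 | 05 00 00 00              <- validity byte, then the 4 value bytes
//   int32 null:
//     01 | 00 00 00 00              <- null byte, then zeroed padding
//   utf8 value "ab" (VarLengthKeyEncoder, Offset = int32_t):
//     00 | 02 00 00 00 | 61 62      <- validity byte, length prefix, then the bytes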
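Illustrative sketch, not part of the patch: one way the RowEncoder declared above could be driven to append a batch and materialize selected rows. The column types, literal values, and the RoundTripRowEncoder helper name are assumptions made for the example; only Init, EncodeAndAppend, encoded_row, kRowIdForNulls and Decode come from the API in the diff.

#include "arrow/compute/exec.h"
#include "arrow/compute/kernels/row_encoder.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON (test-only helper)
#include "arrow/util/logging.h"

namespace arrow {
namespace compute {

Status RoundTripRowEncoder() {
  internal::RowEncoder encoder;
  ExecContext* ctx = default_exec_context();

  // One int32 column and one utf8 column, both encoded as arrays.
  encoder.Init({ValueDescr(int32(), ValueDescr::ARRAY),
                ValueDescr(utf8(), ValueDescr::ARRAY)},
               ctx);

  ExecBatch batch({Datum(ArrayFromJSON(int32(), "[1, 2, 3]")),
                   Datum(ArrayFromJSON(utf8(), R"(["a", "b", "c"])"))},
                  /*length=*/3);
  RETURN_NOT_OK(encoder.EncodeAndAppend(batch));

  // Encoded rows are opaque byte strings, so they can be compared or hashed directly.
  DCHECK_NE(encoder.encoded_row(0), encoder.encoded_row(1));

  // Materialize row 2, row 0, and one all-null row back into an ExecBatch.
  std::vector<int32_t> row_ids = {2, 0, internal::RowEncoder::kRowIdForNulls()};
  ARROW_ASSIGN_OR_RAISE(
      ExecBatch decoded,
      encoder.Decode(static_cast<int64_t>(row_ids.size()), row_ids.data()));
  DCHECK_EQ(decoded.num_values(), 2);
  return Status::OK();
}

}  // namespace compute
}  // namespace arrow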
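Illustrative sketch, not part of the patch: the ThreadIndexer added to compute/exec/util.h hands each calling thread a small dense index, so per-thread state can be sized with Capacity() and indexed without contention. The LocalState struct, the thread-pool size, and the spawning loop are assumptions made for the example.

#include <cstdint>
#include <vector>

#include "arrow/compute/exec/util.h"
#include "arrow/util/thread_pool.h"

namespace arrow {
namespace compute {

struct LocalState {
  int64_t rows_seen = 0;
};

Status ThreadIndexerExample() {
  ThreadIndexer thread_indexer;

  // Capacity() bounds the number of distinct indices handed out, so it is a
  // safe size for a per-thread state vector.
  std::vector<LocalState> local_states(ThreadIndexer::Capacity());

  ARROW_ASSIGN_OR_RAISE(auto pool, arrow::internal::ThreadPool::Make(4));
  for (int task = 0; task < 16; ++task) {
    RETURN_NOT_OK(pool->Spawn([&] {
      // Every call from the same thread returns the same dense index, so each
      // slot of local_states is only ever touched by one thread.
      size_t thread_index = thread_indexer();
      local_states[thread_index].rows_seen += 1;
    }));
  }
  // Shutdown() waits for the queued tasks by default before returning.
  RETURN_NOT_OK(pool->Shutdown());
  return Status::OK();
}

}  // namespace compute
}  // namespace arrow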