diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index ac07ccf4b7d..93dd1297bd7 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -404,6 +404,7 @@ if(ARROW_COMPUTE) compute/exec/project_node.cc compute/exec/sink_node.cc compute/exec/source_node.cc + compute/exec/swiss_join.cc compute/exec/task_util.cc compute/exec/tpch_node.cc compute/exec/union_node.cc @@ -459,6 +460,7 @@ if(ARROW_COMPUTE) append_avx2_src(compute/exec/bloom_filter_avx2.cc) append_avx2_src(compute/exec/key_hash_avx2.cc) append_avx2_src(compute/exec/key_map_avx2.cc) + append_avx2_src(compute/exec/swiss_join_avx2.cc) append_avx2_src(compute/exec/util_avx2.cc) append_avx2_src(compute/row/compare_internal_avx2.cc) append_avx2_src(compute/row/encode_internal_avx2.cc) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index a376fb5f57b..e821979ae10 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -41,18 +41,21 @@ class HashJoinBasicImpl : public HashJoinImpl { public: Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads, - HashJoinSchema* schema_mgr, std::vector key_cmp, - Expression filter, OutputBatchCallback output_batch_callback, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector key_cmp, Expression filter, + OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, TaskScheduler* scheduler) override { START_COMPUTE_SPAN(span_, "HashJoinBasicImpl", {{"detail", filter.ToString()}, - {"join.kind", ToString(join_type)}, + {"join.kind", arrow::compute::ToString(join_type)}, {"join.threads", static_cast(num_threads)}}); num_threads_ = num_threads; ctx_ = ctx; join_type_ = join_type; - schema_mgr_ = schema_mgr; + schema_[0] = proj_map_left; + schema_[1] = proj_map_right; key_cmp_ = std::move(key_cmp); filter_ = std::move(filter); output_batch_callback_ = std::move(output_batch_callback); @@ -82,13 +85,15 @@ class HashJoinBasicImpl : public HashJoinImpl { scheduler_->Abort(std::move(pos_abort_callback)); } + std::string ToString() const override { return "HashJoinBasicImpl"; } + private: void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) { std::vector data_types; - int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); + int num_cols = schema_[side]->num_cols(projection_handle); data_types.resize(num_cols); for (int icol = 0; icol < num_cols; ++icol) { - data_types[icol] = schema_mgr_->proj_maps[side].data_type(projection_handle, icol); + data_types[icol] = schema_[side]->data_type(projection_handle, icol); } encoder->Init(data_types, ctx_); encoder->Clear(); @@ -99,8 +104,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ThreadLocalState& local_state = local_states_[thread_index]; if (!local_state.is_initialized) { InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys); - bool has_payload = - (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_payload) { InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads); } @@ -112,11 +116,10 @@ class HashJoinBasicImpl : public HashJoinImpl { Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder, const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) { ExecBatch projected({}, batch.length); - int num_cols = 
schema_mgr_->proj_maps[side].num_cols(projection_handle); + int num_cols = schema_[side]->num_cols(projection_handle); projected.values.resize(num_cols); - auto to_input = - schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT); + auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT); for (int icol = 0; icol < num_cols; ++icol) { projected.values[icol] = batch.values[to_input.get(icol)]; } @@ -179,19 +182,17 @@ class HashJoinBasicImpl : public HashJoinImpl { ExecBatch* opt_left_payload, ExecBatch* opt_right_key, ExecBatch* opt_right_payload) { ExecBatch result({}, batch_size_next); - int num_out_cols_left = - schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT); - int num_out_cols_right = - schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT); + int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT); + int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT); result.values.resize(num_out_cols_left + num_out_cols_right); - auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, - HashJoinProjection::KEY); - auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, - HashJoinProjection::PAYLOAD); + auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto from_payload = + schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); for (int icol = 0; icol < num_out_cols_left; ++icol) { - bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField()); - bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField()); + bool is_from_key = (from_key.get(icol) != HashJoinProjectionMaps::kMissingField); + bool is_from_payload = + (from_payload.get(icol) != HashJoinProjectionMaps::kMissingField); ARROW_DCHECK(is_from_key != is_from_payload); ARROW_DCHECK(!is_from_key || (opt_left_key && @@ -206,13 +207,13 @@ class HashJoinBasicImpl : public HashJoinImpl { ? 
opt_left_key->values[from_key.get(icol)] : opt_left_payload->values[from_payload.get(icol)]; } - from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, - HashJoinProjection::KEY); - from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT, - HashJoinProjection::PAYLOAD); + from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + from_payload = + schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); for (int icol = 0; icol < num_out_cols_right; ++icol) { - bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField()); - bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField()); + bool is_from_key = (from_key.get(icol) != HashJoinProjectionMaps::kMissingField); + bool is_from_payload = + (from_payload.get(icol) != HashJoinProjectionMaps::kMissingField); ARROW_DCHECK(is_from_key != is_from_payload); ARROW_DCHECK(!is_from_key || (opt_right_key && @@ -228,7 +229,7 @@ class HashJoinBasicImpl : public HashJoinImpl { : opt_right_payload->values[from_payload.get(icol)]; } - output_batch_callback_(std::move(result)); + output_batch_callback_(0, std::move(result)); // Update the counter of produced batches // @@ -254,13 +255,13 @@ class HashJoinBasicImpl : public HashJoinImpl { hash_table_keys_.Decode(match_right.size(), match_right.data())); ExecBatch left_payload; - if (!schema_mgr_->LeftPayloadIsEmpty()) { + if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) { ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode( match_left.size(), match_left.data())); } ExecBatch right_payload; - if (!schema_mgr_->RightPayloadIsEmpty()) { + if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) { ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode( match_right.size(), match_right.data())); } @@ -280,14 +281,14 @@ class HashJoinBasicImpl : public HashJoinImpl { } }; - SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map( - HashJoinProjection::FILTER, HashJoinProjection::KEY); - SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map( - HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); - SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map( - HashJoinProjection::FILTER, HashJoinProjection::KEY); - SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map( - HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + SchemaProjectionMap left_to_key = + schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap left_to_pay = + schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + SchemaProjectionMap right_to_key = + schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap right_to_pay = + schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); AppendFields(left_to_key, left_to_pay, left_key, left_payload); AppendFields(right_to_key, right_to_pay, right_key, right_payload); @@ -363,15 +364,14 @@ class HashJoinBasicImpl : public HashJoinImpl { bool has_left = (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI && - schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0); + schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0); bool has_right = (join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI && - schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0); + schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0); bool has_left_payload = - has_left && 
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); bool has_right_payload = - has_right && - (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0); ThreadLocalState& local_state = local_states_[thread_index]; RETURN_NOT_OK(InitLocalStateIfNeeded(thread_index)); @@ -394,7 +394,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ARROW_ASSIGN_OR_RAISE(right_key, hash_table_keys_.Decode(batch_size_next, opt_right_ids)); // Post process build side keys that use dictionary - RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_)); + RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_)); } if (has_right_payload) { ARROW_ASSIGN_OR_RAISE(right_payload, @@ -494,8 +494,7 @@ class HashJoinBasicImpl : public HashJoinImpl { RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, batch, &batch_key_for_lookups)); - bool has_left_payload = - (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); + bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_left_payload) { local_state.exec_batch_payloads.Clear(); RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD, @@ -507,13 +506,13 @@ class HashJoinBasicImpl : public HashJoinImpl { local_state.match_left.clear(); local_state.match_right.clear(); - bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded( - thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_); + bool use_key_batch_for_dicts = + dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_); RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys; if (use_key_batch_for_dicts) { - RETURN_NOT_OK(dict_probe_.EncodeBatch( - thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_, - batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_)); + RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1], + dict_build_, batch, &row_encoder_for_lookups, + &batch_key_for_lookups, ctx_)); } // Collect information about all nulls in key columns. 
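The output-assembly code above follows one pattern throughout: every output column of the join is sourced from exactly one of the KEY or PAYLOAD projections of its side, and kMissingField marks the projection it does not come from. A minimal sketch of that lookup, with `schema` standing in for `schema_[side]` (an illustrative fragment, not part of the patch):

    const HashJoinProjectionMaps* schema = /* schema_[side] */ nullptr;
    auto from_key = schema->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
    auto from_payload = schema->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
    for (int icol = 0; icol < schema->num_cols(HashJoinProjection::OUTPUT); ++icol) {
      if (from_key.get(icol) != HashJoinProjectionMaps::kMissingField) {
        // Output column icol comes from decoded key column from_key.get(icol).
      } else {
        // Otherwise it comes from decoded payload column from_payload.get(icol).
      }
    }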
@@ -561,9 +560,8 @@ class HashJoinBasicImpl : public HashJoinImpl { if (batches.empty()) { hash_table_empty_ = true; } else { - dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_); - bool has_payload = - (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); + dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_); + bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_payload) { InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_); } @@ -578,11 +576,11 @@ class HashJoinBasicImpl : public HashJoinImpl { } else if (hash_table_empty_) { hash_table_empty_ = false; - RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_)); + RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_)); } int32_t num_rows_before = hash_table_keys_.num_rows(); - RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1], - batch, &hash_table_keys_, ctx_)); + RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch, + &hash_table_keys_, ctx_)); if (has_payload) { RETURN_NOT_OK( EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch)); @@ -595,7 +593,7 @@ class HashJoinBasicImpl : public HashJoinImpl { } if (hash_table_empty_) { - RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_)); + RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_)); } return Status::OK(); @@ -740,7 +738,7 @@ class HashJoinBasicImpl : public HashJoinImpl { ExecContext* ctx_; JoinType join_type_; size_t num_threads_; - HashJoinSchema* schema_mgr_; + const HashJoinProjectionMaps* schema_[2]; std::vector key_cmp_; Expression filter_; TaskScheduler* scheduler_; diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 97bdf166a01..19add7d4405 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -36,81 +36,18 @@ namespace compute { using arrow::util::AccumulationQueue; -class ARROW_EXPORT HashJoinSchema { - public: - Status Init(JoinType join_type, const Schema& left_schema, - const std::vector& left_keys, const Schema& right_schema, - const std::vector& right_keys, const Expression& filter, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix); - - Status Init(JoinType join_type, const Schema& left_schema, - const std::vector& left_keys, - const std::vector& left_output, const Schema& right_schema, - const std::vector& right_keys, - const std::vector& right_output, const Expression& filter, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix); - - static Status ValidateSchemas(JoinType join_type, const Schema& left_schema, - const std::vector& left_keys, - const std::vector& left_output, - const Schema& right_schema, - const std::vector& right_keys, - const std::vector& right_output, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix); - - Result BindFilter(Expression filter, const Schema& left_schema, - const Schema& right_schema, ExecContext* exec_context); - std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, - const std::string& right_field_name_suffix); - - bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); } - - bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); } - - static int kMissingField() { - return SchemaProjectionMaps::kMissingField; - } - - SchemaProjectionMaps proj_maps[2]; - - private: - static bool 
IsTypeSupported(const DataType& type); - - Status CollectFilterColumns(std::vector& left_filter, - std::vector& right_filter, - const Expression& filter, const Schema& left_schema, - const Schema& right_schema); - - Expression RewriteFilterToUseFilterSchema(int right_filter_offset, - const SchemaProjectionMap& left_to_filter, - const SchemaProjectionMap& right_to_filter, - const Expression& filter); - - bool PayloadIsEmpty(int side) { - ARROW_DCHECK(side == 0 || side == 1); - return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0; - } - - static Result> ComputePayload(const Schema& schema, - const std::vector& output, - const std::vector& filter, - const std::vector& key); -}; - class HashJoinImpl { public: - using OutputBatchCallback = std::function; + using OutputBatchCallback = std::function; using BuildFinishedCallback = std::function; - using ProbeFinishedCallback = std::function; using FinishedCallback = std::function; virtual ~HashJoinImpl() = default; virtual Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads, - HashJoinSchema* schema_mgr, std::vector key_cmp, - Expression filter, OutputBatchCallback output_batch_callback, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector key_cmp, Expression filter, + OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, TaskScheduler* scheduler) = 0; virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches, @@ -118,8 +55,10 @@ class HashJoinImpl { virtual Status ProbeSingleBatch(size_t thread_index, ExecBatch batch) = 0; virtual Status ProbingFinished(size_t thread_index) = 0; virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0; + virtual std::string ToString() const = 0; static Result> MakeBasic(); + static Result> MakeSwiss(); protected: util::tracing::Span span_; diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc index 0786071f997..97badb84233 100644 --- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc @@ -38,6 +38,9 @@ namespace compute { struct BenchmarkSettings { int num_threads = 1; JoinType join_type = JoinType::INNER; + // Change to 'true' to benchmark alternative, non-default and less optimized version of + // a hash join node implementation. 
+ bool use_basic_implementation = false; int batch_size = 1024; int num_build_batches = 32; int num_probe_batches = 32 * 16; @@ -78,6 +81,8 @@ class JoinBenchmark { build_metadata["null_probability"] = std::to_string(settings.null_percentage); build_metadata["min"] = std::to_string(min_build_value); build_metadata["max"] = std::to_string(max_build_value); + build_metadata["min_length"] = "2"; + build_metadata["max_length"] = "20"; std::unordered_map probe_metadata; probe_metadata["null_probability"] = std::to_string(settings.null_percentage); @@ -132,7 +137,11 @@ class JoinBenchmark { left_keys, *r_batches_with_schema.schema, right_keys, filter, "l_", "r_")); - join_ = *HashJoinImpl::MakeBasic(); + if (settings.use_basic_implementation) { + join_ = *HashJoinImpl::MakeBasic(); + } else { + join_ = *HashJoinImpl::MakeSwiss(); + } omp_set_num_threads(settings.num_threads); auto schedule_callback = [](std::function func) -> Status { @@ -143,9 +152,9 @@ class JoinBenchmark { scheduler_ = TaskScheduler::Make(); DCHECK_OK(join_->Init( - ctx_.get(), settings.join_type, settings.num_threads, schema_mgr_.get(), - std::move(key_cmp), std::move(filter), [](ExecBatch) {}, [](int64_t x) {}, - scheduler_.get())); + ctx_.get(), settings.join_type, settings.num_threads, + &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), std::move(key_cmp), + std::move(filter), [](ExecBatch) {}, [](int64_t x) {}, scheduler_.get())); task_group_probe_ = scheduler_->RegisterTaskGroup( [this](size_t thread_index, int64_t task_id) -> Status { diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index baa82671259..73df78b46e8 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -22,6 +22,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/hash_join.h" #include "arrow/compute/exec/hash_join_dict.h" +#include "arrow/compute/exec/hash_join_node.h" #include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" @@ -560,7 +561,8 @@ struct BloomFilterPushdownContext { } } ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(std::move(keys))); - RETURN_NOT_OK(Hashing32::HashBatch(key_batch, hashes.data(), + std::vector temp_column_arrays; + RETURN_NOT_OK(Hashing32::HashBatch(key_batch, hashes.data(), temp_column_arrays, ctx_->cpu_info()->hardware_flags(), stack, 0, key_batch.length)); @@ -654,6 +656,33 @@ struct BloomFilterPushdownContext { FilterFinishedCallback on_finished_; } eval_; }; +bool HashJoinSchema::HasDictionaries() const { + for (int side = 0; side <= 1; ++side) { + for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT); + ++icol) { + const std::shared_ptr& column_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, icol); + if (column_type->id() == Type::DICTIONARY) { + return true; + } + } + } + return false; +} + +bool HashJoinSchema::HasLargeBinary() const { + for (int side = 0; side <= 1; ++side) { + for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT); + ++icol) { + const std::shared_ptr& column_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, icol); + if (is_large_binary_like(column_type->id())) { + return true; + } + } + } + return false; +} class HashJoinNode : public ExecNode { public: @@ -708,8 +737,26 @@ class HashJoinNode : public ExecNode { // Generate output schema std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( 
join_options.output_suffix_for_left, join_options.output_suffix_for_right); + // Create hash join implementation object - ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, HashJoinImpl::MakeBasic()); + // SwissJoin does not support: + // a) 64-bit string offsets + // b) residual predicates + // c) dictionaries + // + bool use_swiss_join; +#if ARROW_LITTLE_ENDIAN + use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() && + !schema_mgr->HasLargeBinary(); +#else + use_swiss_join = false; +#endif + std::unique_ptr impl; + if (use_swiss_join) { + ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeSwiss()); + } else { + ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeBasic()); + } return plan->EmplaceNode( plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr), @@ -907,8 +954,11 @@ class HashJoinNode : public ExecNode { disable_bloom_filter_, use_sync_execution); RETURN_NOT_OK(impl_->Init( - plan_->exec_context(), join_type_, num_threads, schema_mgr_.get(), key_cmp_, - filter_, [this](ExecBatch batch) { this->OutputBatchCallback(batch); }, + plan_->exec_context(), join_type_, num_threads, &(schema_mgr_->proj_maps[0]), + &(schema_mgr_->proj_maps[1]), key_cmp_, filter_, + [this](int64_t /*ignored*/, ExecBatch batch) { + this->OutputBatchCallback(batch); + }, [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); }, scheduler_.get())); @@ -968,6 +1018,11 @@ class HashJoinNode : public ExecNode { Future<> finished() override { return task_group_.OnFinished(); } + protected: + std::string ToStringExtra(int indent = 0) const override { + return "implementation=" + impl_->ToString(); + } + private: void OutputBatchCallback(ExecBatch batch) { outputs_[0]->InputReceived(this, std::move(batch)); @@ -1124,8 +1179,11 @@ Status BloomFilterPushdownContext::BuildBloomFilter_exec_task(size_t thread_inde for (int64_t i = 0; i < key_batch.length; i += util::MiniBatch::kMiniBatchLength) { int64_t length = std::min(static_cast(key_batch.length - i), static_cast(util::MiniBatch::kMiniBatchLength)); - RETURN_NOT_OK(Hashing32::HashBatch( - key_batch, hashes, ctx_->cpu_info()->hardware_flags(), stack, i, length)); + + std::vector temp_column_arrays; + RETURN_NOT_OK(Hashing32::HashBatch(key_batch, hashes, temp_column_arrays, + ctx_->cpu_info()->hardware_flags(), stack, i, + length)); RETURN_NOT_OK(build_.builder_->PushNextBatch(thread_index, length, hashes)); } return Status::OK(); diff --git a/cpp/src/arrow/compute/exec/hash_join_node.h b/cpp/src/arrow/compute/exec/hash_join_node.h new file mode 100644 index 00000000000..8dc7ea0b8bf --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_node.h @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
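Both call sites of Init in this patch (the benchmark and HashJoinNode above) show the reshaped contract: the two HashJoinProjectionMaps are passed directly instead of a HashJoinSchema*, and the output-batch callback gains a leading int64_t argument (ignored by HashJoinNode's lambda, and passed as 0 by the basic implementation). A hedged caller sketch; `impl`, `ctx`, `schema_mgr`, and `scheduler` are assumed to already exist:

    std::vector<JoinKeyCmp> key_cmp = {JoinKeyCmp::EQ};
    RETURN_NOT_OK(impl->Init(
        ctx, JoinType::INNER, /*num_threads=*/1,
        &schema_mgr->proj_maps[0], &schema_mgr->proj_maps[1], std::move(key_cmp),
        /*filter=*/literal(true),
        /*output_batch_callback=*/[](int64_t, ExecBatch batch) { /* consume batch */ },
        /*finished_callback=*/[](int64_t total_num_batches) {},
        scheduler));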
+ +#pragma once + +#include + +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace compute { + +class ARROW_EXPORT HashJoinSchema { + public: + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, const Schema& right_schema, + const std::vector& right_keys, const Expression& filter, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, const Expression& filter, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + static Status ValidateSchemas(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, + const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + bool HasDictionaries() const; + + bool HasLargeBinary() const; + + Result BindFilter(Expression filter, const Schema& left_schema, + const Schema& right_schema, ExecContext* exec_context); + std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix); + + bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); } + + bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); } + + static int kMissingField() { + return SchemaProjectionMaps::kMissingField; + } + + SchemaProjectionMaps proj_maps[2]; + + private: + static bool IsTypeSupported(const DataType& type); + + Status CollectFilterColumns(std::vector& left_filter, + std::vector& right_filter, + const Expression& filter, const Schema& left_schema, + const Schema& right_schema); + + Expression RewriteFilterToUseFilterSchema(int right_filter_offset, + const SchemaProjectionMap& left_to_filter, + const SchemaProjectionMap& right_to_filter, + const Expression& filter); + + bool PayloadIsEmpty(int side) { + ARROW_DCHECK(side == 0 || side == 1); + return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0; + } + + static Result> ComputePayload(const Schema& schema, + const std::vector& output, + const std::vector& filter, + const std::vector& key); +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_hash.cc b/cpp/src/arrow/compute/exec/key_hash.cc index 5a5d524c404..3f495bc9e60 100644 --- a/cpp/src/arrow/compute/exec/key_hash.cc +++ b/cpp/src/arrow/compute/exec/key_hash.cc @@ -458,9 +458,9 @@ void Hashing32::HashMultiColumn(const std::vector& cols, } Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes, + std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, int64_t offset, int64_t length) { - std::vector column_arrays; RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays)); LightContext ctx; @@ -890,9 +890,9 @@ void Hashing64::HashMultiColumn(const std::vector& cols, } Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes, + std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, int64_t offset, int64_t length) { - std::vector column_arrays; RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, 
&column_arrays)); LightContext ctx; diff --git a/cpp/src/arrow/compute/exec/key_hash.h b/cpp/src/arrow/compute/exec/key_hash.h index f8af7988387..68197973e02 100644 --- a/cpp/src/arrow/compute/exec/key_hash.h +++ b/cpp/src/arrow/compute/exec/key_hash.h @@ -49,6 +49,7 @@ class ARROW_EXPORT Hashing32 { uint32_t* out_hash); static Status HashBatch(const ExecBatch& key_batch, uint32_t* hashes, + std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, int64_t offset, int64_t length); @@ -161,6 +162,7 @@ class ARROW_EXPORT Hashing64 { uint64_t* hashes); static Status HashBatch(const ExecBatch& key_batch, uint64_t* hashes, + std::vector& column_arrays, int64_t hardware_flags, util::TempVectorStack* temp_stack, int64_t offset, int64_t length); diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index fe5ed98bb3e..a61184e4ca9 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -42,8 +42,8 @@ constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; // b) first empty slot is encountered, // c) we reach the end of the block. // -// Optionally an index of the first slot to start the search from can be specified. -// In this case slots before it will be ignored. +// Optionally an index of the first slot to start the search from can be specified. In +// this case slots before it will be ignored. // template inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, @@ -88,29 +88,12 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, // We get 0 if there are no matches *out_match_found = (matches == 0 ? 0 : 1); - // Now if we or with the highest bits of the block and scan zero bits in reverse, - // we get 8x slot index that we were looking for. - // This formula works in all three cases a), b) and c). + // Now if we or with the highest bits of the block and scan zero bits in reverse, we get + // 8x slot index that we were looking for. This formula works in all three cases a), b) + // and c). *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask) const { - // Group id values for all 8 slots in the block are bit-packed and follow the status - // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In - // that case we can extract group id using aligned 64-bit word access. - int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); - ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || - num_group_id_bits == 32 || num_group_id_bits == 64); - - int bit_offset = slot * num_group_id_bits; - const uint64_t* group_id_bytes = - reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); - uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; - - return group_id; -} - template void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, @@ -147,14 +130,16 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32); + int num_processed = 0; + // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
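The Hashing32/Hashing64 change above moves the column-array scratch vector out of HashBatch and into the caller, so it can be reused across mini-batches instead of being reallocated on every call. A small usage sketch, assuming the scratch element type is KeyColumnArray (the type that ColumnArraysFromExecBatch fills in) and that `batch`, `hardware_flags`, and `stack` are in scope:

    std::vector<uint32_t> hashes(batch.length);
    std::vector<KeyColumnArray> temp_column_arrays;  // caller-owned, reusable scratch
    RETURN_NOT_OK(Hashing32::HashBatch(batch, hashes.data(), temp_column_arrays,
                                       hardware_flags, stack,
                                       /*offset=*/0, /*length=*/batch.length));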
#if defined(ARROW_HAVE_AVX2) int num_group_id_bytes = num_group_id_bits / 8; if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) { - extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t), - 8 + 8 * num_group_id_bytes, num_group_id_bytes); - return; + num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, + sizeof(uint64_t), 8 + 8 * num_group_id_bytes, + num_group_id_bytes); } #endif switch (num_group_id_bits) { @@ -163,8 +148,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 8, 16); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 8, 16); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 8, 16); } break; case 16: @@ -172,8 +158,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 4, 12); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 4, 12); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 4, 12); } break; case 32: @@ -181,8 +168,9 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, out_group_ids, 2, 10); } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 2, 10); + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed, 2, 10); } break; default: @@ -312,24 +300,21 @@ void SwissTable::early_filter(const int num_keys, const uint32_t* hashes, uint8_t* out_local_slots) const { // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
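The control flow introduced here, and repeated in early_filter below, is: the AVX2 helper now processes only the prefix it can handle in full SIMD unrolls and returns how many inputs it consumed, and the scalar implementation finishes the remainder; previously the AVX2 path rounded the count up and relied on reading past the end of the temporary buffers. Schematically, with placeholder names `simd_path` and `scalar_path` (not real functions):

    int num_processed = 0;
    #if defined(ARROW_HAVE_AVX2)
    if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
      // Returns a multiple of the SIMD unroll factor, possibly less than num_keys.
      num_processed = simd_path(num_keys, input, output);
    }
    #endif
    // Scalar path covers whatever is left (everything when AVX2 is unavailable).
    scalar_path(num_keys - num_processed, input + num_processed, output + num_processed);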
+ int num_processed = 0; #if defined(ARROW_HAVE_AVX2) if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { if (log_blocks_ <= 4) { - int tail = num_keys % 32; - int delta = num_keys - tail; - early_filter_imp_avx2_x32(num_keys - tail, hashes, out_match_bitvector, - out_local_slots); - early_filter_imp_avx2_x8(tail, hashes + delta, out_match_bitvector + delta / 8, - out_local_slots + delta); - } else { - early_filter_imp_avx2_x8(num_keys, hashes, out_match_bitvector, out_local_slots); + num_processed = early_filter_imp_avx2_x32(num_keys, hashes, out_match_bitvector, + out_local_slots); } - } else { -#endif - early_filter_imp(num_keys, hashes, out_match_bitvector, out_local_slots); -#if defined(ARROW_HAVE_AVX2) + num_processed += early_filter_imp_avx2_x8( + num_keys - num_processed, hashes + num_processed, + out_match_bitvector + num_processed / 8, out_local_slots + num_processed); } #endif + early_filter_imp(num_keys - num_processed, hashes + num_processed, + out_match_bitvector + num_processed / 8, + out_local_slots + num_processed); } // Input selection may be: @@ -348,10 +333,16 @@ void SwissTable::run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const { + uint16_t* out_not_equal_selection, + const EqualImpl& equal_impl, void* callback_ctx) const { ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector); ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector); + if (num_keys == 0) { + *out_num_not_equal = 0; + return; + } + if (!optional_selection_ids && optional_selection_bitvector) { // Count rows with matches (based on stamp comparison) // and decide based on their percentage whether to call dense or sparse comparison @@ -368,21 +359,22 @@ void SwissTable::run_comparisons(const int num_keys, if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { uint32_t out_num; - equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection); + equal_impl(num_keys, nullptr, groupids, &out_num, out_not_equal_selection, + callback_ctx); *out_num_not_equal = static_cast(out_num); } else { util::bit_util::bits_to_indexes(1, hardware_flags_, num_keys, optional_selection_bitvector, out_num_not_equal, out_not_equal_selection); uint32_t out_num; - equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, - out_not_equal_selection); + equal_impl(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, + out_not_equal_selection, callback_ctx); *out_num_not_equal = static_cast(out_num); } } else { uint32_t out_num; - equal_impl_(num_keys, optional_selection_ids, groupids, &out_num, - out_not_equal_selection); + equal_impl(num_keys, optional_selection_ids, groupids, &out_num, + out_not_equal_selection, callback_ctx); *out_num_not_equal = static_cast(out_num); } } @@ -432,35 +424,6 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl return match_found; } -void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, - uint32_t group_id) { - const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - - // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. - // In that case we can insert group id value using aligned 64-bit word access. 
- ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || - num_groupid_bits == 32 || num_groupid_bits == 64); - - const uint64_t num_block_bytes = (8 + num_groupid_bits); - constexpr uint64_t stamp_mask = 0x7f; - - int start_slot = (slot_id & 7); - int stamp = - static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); - uint64_t block_id = slot_id >> 3; - uint8_t* blockbase = blocks_ + num_block_bytes * block_id; - - blockbase[7 - start_slot] = static_cast(stamp); - int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); - - // Block status bytes should start at an address aligned to 8 bytes - ARROW_DCHECK((reinterpret_cast(blockbase) & 7) == 0); - uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); - *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); - - hashes_[slot_id] = hash; -} - // Find method is the continuation of processing from early_filter. // Its input consists of hash values and the output of early_filter. // It updates match bit-vector, clearing it from any false positives @@ -471,7 +434,8 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, // void SwissTable::find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, const uint8_t* local_slots, - uint32_t* out_group_ids) const { + uint32_t* out_group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, void* callback_ctx) const { // Temporary selection vector. // It will hold ids of keys for which we do not know yet // if they have a match in hash table or not. @@ -481,12 +445,12 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // to array of ids. // ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); - auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + auto ids_buf = util::TempVectorHolder(temp_stack, num_keys); uint16_t* ids = ids_buf.mutable_data(); int num_ids; - int64_t num_matches = - arrow::internal::CountSetBits(inout_match_bitvector, /*offset=*/0, num_keys); + int64_t num_matches = arrow::internal::CountSetBits(inout_match_bitvector, + /*offset=*/0, num_keys); // If there is a high density of selected input rows // (majority of them are present in the selection), @@ -498,19 +462,20 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, if (visit_all) { extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids); run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids, - ids); + ids, equal_impl, callback_ctx); } else { util::bit_util::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector, &num_ids, ids); extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids); - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, equal_impl, + callback_ctx); } if (num_ids == 0) { return; } - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + auto slot_ids_buf = util::TempVectorHolder(temp_stack, num_keys); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids); @@ -531,9 +496,10 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, } } - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, equal_impl, + callback_ctx); } -} // namespace compute +} // Slow processing of input keys in the most generic case. 
// Handles inserting new keys. @@ -545,11 +511,11 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // Update selection vector to reflect which items have been processed. // Ids in selection vector do not have to be sorted. // -Status SwissTable::map_new_keys_helper(const uint32_t* hashes, - uint32_t* inout_num_selected, - uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, - uint32_t* inout_next_slot_ids) { +Status SwissTable::map_new_keys_helper( + const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, + bool* out_need_resize, uint32_t* out_group_ids, uint32_t* inout_next_slot_ids, + util::TempVectorStack* temp_stack, const EqualImpl& equal_impl, + const AppendImpl& append_impl, void* callback_ctx) { auto num_groups_limit = num_groups_for_resize(); ARROW_DCHECK(num_inserted_ < num_groups_limit); @@ -560,7 +526,7 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t); auto match_bitvector_buf = util::TempVectorHolder( - temp_stack_, static_cast(num_bytes_for_bits)); + temp_stack, static_cast(num_bytes_for_bits)); uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); memset(match_bitvector, 0xff, num_bytes_for_bits); @@ -580,11 +546,12 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, // out_group_ids[id] = num_inserted_ + num_inserted_new; insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]); + hashes_[inout_next_slot_ids[id]] = hashes[id]; ::arrow::bit_util::ClearBit(match_bitvector, num_processed); ++num_inserted_new; - // We need to break processing and have the caller of this function - // resize hash table if we reach the limit of the number of groups present. + // We need to break processing and have the caller of this function resize hash + // table if we reach the limit of the number of groups present. // if (num_inserted_ + num_inserted_new == num_groups_limit) { ++num_processed; @@ -594,7 +561,7 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, } auto temp_ids_buffer = - util::TempVectorHolder(temp_stack_, *inout_num_selected); + util::TempVectorHolder(temp_stack, *inout_num_selected); uint16_t* temp_ids = temp_ids_buffer.mutable_data(); int num_temp_ids = 0; @@ -603,16 +570,18 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, util::bit_util::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector, inout_selection, &num_temp_ids, temp_ids); ARROW_DCHECK(static_cast(num_inserted_new) == num_temp_ids); - RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids)); + RETURN_NOT_OK(append_impl(num_inserted_new, temp_ids, callback_ctx)); num_inserted_ += num_inserted_new; // Evaluate comparisons and append ids of rows that failed it to the non-match set. util::bit_util::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector, inout_selection, &num_temp_ids, temp_ids); - run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, - temp_ids); + run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, temp_ids, + equal_impl, callback_ctx); - memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); + if (num_temp_ids > 0) { + memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); + } // Append ids of any unprocessed entries if we aborted processing due to the need // to resize. 
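run_comparisons, find, and map_new_keys no longer read the equality and append callbacks from member state; they receive them per call together with a void* callback_ctx. The exact template arguments of the EqualImpl and AppendImpl aliases are not visible here, so the shapes below are reconstructed from the call sites above and should be read as a best-effort sketch rather than the exact declarations:

    // Given candidate group ids, reports which keys do not actually match
    // (the stamp comparison gave a false positive).
    using EqualImpl = std::function<void(
        int num_keys, const uint16_t* selection_maybe_null, const uint32_t* group_ids,
        uint32_t* out_num_mismatch, uint16_t* out_selection_mismatch, void* callback_ctx)>;

    // Appends the selected input rows as new keys of the table.
    using AppendImpl =
        std::function<Status(int num_keys, const uint16_t* selection, void* callback_ctx)>;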
if (num_processed < *inout_num_selected) { @@ -629,7 +598,9 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, // this set). // Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids) { + uint32_t* group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, + const AppendImpl& append_impl, void* callback_ctx) { if (num_ids == 0) { return Status::OK(); } @@ -645,7 +616,7 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* ARROW_DCHECK(static_cast(max_id + 1) <= (1 << log_minibatch_)); // Allocate temporary buffers for slot ids and intialize them - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, max_id + 1); + auto slot_ids_buf = util::TempVectorHolder(temp_stack, max_id + 1); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids); @@ -658,7 +629,8 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* // bigger hash table. bool out_of_capacity; RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids, - slot_ids)); + slot_ids, temp_stack, equal_impl, append_impl, + callback_ctx)); if (out_of_capacity) { RETURN_NOT_OK(grow_double()); // Reset start slot ids for still unprocessed input keys. @@ -803,17 +775,13 @@ Status SwissTable::grow_double() { return Status::OK(); } -Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, - util::TempVectorStack* temp_stack, int log_minibatch, - EqualImpl equal_impl, AppendImpl append_impl) { +Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks, + bool no_hash_array) { hardware_flags_ = hardware_flags; pool_ = pool; - temp_stack_ = temp_stack; - log_minibatch_ = log_minibatch; - equal_impl_ = equal_impl; - append_impl_ = append_impl; + log_minibatch_ = util::MiniBatch::kLogMiniBatchLength; - log_blocks_ = 0; + log_blocks_ = log_blocks; int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); num_inserted_ = 0; @@ -829,12 +797,16 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte); } - uint64_t num_slots = 1ULL << (log_blocks_ + 3); - const uint64_t hash_size = sizeof(uint32_t); - const uint64_t hash_bytes = hash_size * num_slots + padding_; - uint8_t* hashes8; - RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8)); - hashes_ = reinterpret_cast(hashes8); + if (no_hash_array) { + hashes_ = nullptr; + } else { + uint64_t num_slots = 1ULL << (log_blocks_ + 3); + const uint64_t hash_size = sizeof(uint32_t); + const uint64_t hash_bytes = hash_size * num_slots + padding_; + uint8_t* hashes8; + RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8)); + hashes_ = reinterpret_cast(hashes8); + } return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h index 12c1e393c4a..cc630e0b1c3 100644 --- a/cpp/src/arrow/compute/exec/key_map.h +++ b/cpp/src/arrow/compute/exec/key_map.h @@ -27,7 +27,17 @@ namespace arrow { namespace compute { +// SwissTable is a variant of a hash table implementation. +// This implementation is vectorized, that is: main interface methods take arrays of input +// values and output arrays of result values. +// +// A detailed explanation of this data structure (including concepts such as blocks, +// slots, stamps) and operations provided by this class is given in the document: +// arrow/compute/exec/doc/key_map.md. 
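Taken together with the class comment being added in key_map.h, the intended per-mini-batch call sequence is: hash the keys, early_filter for cheap candidate filtering, find to confirm matches and emit group ids, then map_new_keys for keys that still have no match and should be inserted. A hedged outline of that flow; buffer allocation is omitted and `equal_impl`, `append_impl`, and `ctx` are the caller-supplied callbacks described above:

    table.early_filter(num_keys, hashes, match_bitvector, local_slots);
    table.find(num_keys, hashes, match_bitvector, local_slots, group_ids,
               temp_stack, equal_impl, /*callback_ctx=*/ctx);
    // Keys whose match bit is still 0 have no equal key in the table yet.
    int num_new;
    util::bit_util::bits_to_indexes(/*bit_to_search=*/0, hardware_flags, num_keys,
                                    match_bitvector, &num_new, ids);
    RETURN_NOT_OK(table.map_new_keys(num_new, ids, hashes, group_ids,
                                     temp_stack, equal_impl, append_impl, ctx));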
+// class SwissTable { + friend class SwissTableMerge; + public: SwissTable() = default; ~SwissTable() { cleanup(); } @@ -35,11 +45,12 @@ class SwissTable { using EqualImpl = std::function; - using AppendImpl = std::function; + uint16_t* out_selection_mismatch, void* callback_ctx)>; + using AppendImpl = + std::function; - Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack, - int log_minibatch, EqualImpl equal_impl, AppendImpl append_impl); + Status init(int64_t hardware_flags, MemoryPool* pool, int log_blocks = 0, + bool no_hash_array = false); void cleanup(); @@ -47,10 +58,22 @@ class SwissTable { uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, - const uint8_t* local_slots, uint32_t* out_group_ids) const; + const uint8_t* local_slots, uint32_t* out_group_ids, + util::TempVectorStack* temp_stack, const EqualImpl& equal_impl, + void* callback_ctx) const; Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids); + uint32_t* group_ids, util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, const AppendImpl& append_impl, + void* callback_ctx); + + int minibatch_size() const { return 1 << log_minibatch_; } + + int64_t num_inserted() const { return num_inserted_; } + + int64_t hardware_flags() const { return hardware_flags_; } + + MemoryPool* pool() const { return pool_; } private: // Lookup helpers @@ -116,21 +139,22 @@ class SwissTable { void early_filter_imp(const int num_keys, const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; #if defined(ARROW_HAVE_AVX2) - void early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + int early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const; + int early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; - void early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const; - void extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, - const uint8_t* local_slots, uint32_t* out_group_ids, - int byte_offset, int byte_multiplier, int byte_size) const; + int extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, uint32_t* out_group_ids, + int byte_offset, int byte_multiplier, int byte_size) const; #endif void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const; + uint16_t* out_not_equal_selection, const EqualImpl& equal_impl, + void* callback_ctx) const; inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, uint32_t* out_slot_id, uint32_t* out_group_id) const; @@ -145,7 +169,10 @@ class SwissTable { // Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, uint32_t* out_next_slot_ids); + uint32_t* out_group_ids, uint32_t* out_next_slot_ids, + util::TempVectorStack* temp_stack, + const EqualImpl& equal_impl, const AppendImpl& append_impl, + void* callback_ctx); // Resize small hash tables when 50% full (up to 8KB). 
// Resize large hash tables when 75% full. @@ -198,11 +225,51 @@ class SwissTable { int64_t hardware_flags_; MemoryPool* pool_; - util::TempVectorStack* temp_stack_; - - EqualImpl equal_impl_; - AppendImpl append_impl_; }; +uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, + uint64_t group_id_mask) const { + // Group id values for all 8 slots in the block are bit-packed and follow the status + // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In + // that case we can extract group id using aligned 64-bit word access. + int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); + ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || + num_group_id_bits == 32 || num_group_id_bits == 64); + + int bit_offset = slot * num_group_id_bits; + const uint64_t* group_id_bytes = + reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); + uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; + + return group_id; +} + +void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, + uint32_t group_id) { + const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + + // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. + // In that case we can insert group id value using aligned 64-bit word access. + ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || + num_groupid_bits == 32 || num_groupid_bits == 64); + + const uint64_t num_block_bytes = (8 + num_groupid_bits); + constexpr uint64_t stamp_mask = 0x7f; + + int start_slot = (slot_id & 7); + int stamp = + static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); + uint64_t block_id = slot_id >> 3; + uint8_t* blockbase = blocks_ + num_block_bytes * block_id; + + blockbase[7 - start_slot] = static_cast(stamp); + int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); + + // Block status bytes should start at an address aligned to 8 bytes + ARROW_DCHECK((reinterpret_cast(blockbase) & 7) == 0); + uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); + *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/key_map_avx2.cc b/cpp/src/arrow/compute/exec/key_map_avx2.cc index 2fca6bf6c10..4c77f3af237 100644 --- a/cpp/src/arrow/compute/exec/key_map_avx2.cc +++ b/cpp/src/arrow/compute/exec/key_map_avx2.cc @@ -24,21 +24,15 @@ namespace compute { #if defined(ARROW_HAVE_AVX2) -// Why it is OK to round up number of rows internally: -// All of the buffers: hashes, out_match_bitvector, out_group_ids, out_next_slot_ids -// are temporary buffers of group id mapping. -// Temporary buffers are buffers that live only within the boundaries of a single -// minibatch. Temporary buffers add 64B at the end, so that SIMD code does not have to -// worry about reading and writing outside of the end of the buffer up to 64B. If the -// hashes array contains garbage after the last element, it cannot cause computation to -// fail, since any random data is a valid hash for the purpose of lookup. +// This is more or less translation of equivalent scalar code, adjusted for a +// different instruction set (e.g. missing leading zero count instruction). // -// This is more or less translation of equivalent scalar code, adjusted for a different -// instruction set (e.g. missing leading zero count instruction). 
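extract_group_id and insert_into_empty_slot, now defined inline in the header, both rely on group ids being bit-packed immediately after the 8 status bytes of a block, with the width rounded up to 8, 16, 32 or 64 bits so a single aligned 64-bit load or store is enough. A small worked example of the addressing arithmetic (values chosen only for illustration):

    // Suppose num_group_id_bits == 16 (group_id_mask == 0xFFFF) and slot == 5:
    //   bit_offset = 5 * 16 = 80
    //   word index after the status bytes = 80 >> 6 = 1
    //   shift inside that word = 80 & 63 = 16
    // so extract_group_id reads:  group_id = (group_id_words[1] >> 16) & 0xFFFF;
    // and insert_into_empty_slot ORs the new group id into the same word at bit 16.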
+// Returns the number of hashes actually processed, which may be less than +// requested due to alignment required by SIMD. // -void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const { +int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { // Number of inputs processed together in a loop constexpr int unroll = 8; @@ -46,8 +40,7 @@ void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* const __m256i* vhash_ptr = reinterpret_cast(hashes); const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1); - // TODO: explain why it is ok to process hashes outside of buffer boundaries - for (int i = 0; i < ((num_hashes + unroll - 1) / unroll); ++i) { + for (int i = 0; i < num_hashes / unroll; ++i) { constexpr uint64_t kEachByteIs8 = 0x0808080808080808ULL; constexpr uint64_t kByteSequenceOfPowersOf2 = 0x8040201008040201ULL; @@ -139,6 +132,8 @@ void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found), 0x11111111); // 0b00010001 repeated 4x } + + return num_hashes - (num_hashes % unroll); } // Take a set of 16 64-bit elements, @@ -173,8 +168,8 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256 // k4, o4, l4, p4, ... k7, o7, l7, p7} __m256i byte01 = _mm256_unpacklo_epi32( - a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, c0, g0, - // d0, h0, k0, o0, l0, p0, ...} + a, b); // {a0, e0, b0, f0, i0, m0, j0, n0, a1, e1, b1, f1, i1, m1, j1, n1, + // c0, g0, d0, h0, k0, o0, l0, p0, ...} __m256i shuffle_const = _mm256_setr_epi8(0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15, 0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15); @@ -206,9 +201,13 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256 // using a different method. // TODO: Explain the idea behind storing arrays in SIMD registers. // Explain why it is faster with SIMD than using memory loads. -void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, - uint8_t* out_local_slots) const { +// +// Returns the number of hashes actually processed, which may be less than +// requested due to alignment required by SIMD. +// +int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { constexpr int unroll = 32; // There is a limit on the number of input blocks, @@ -366,12 +365,14 @@ void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* reinterpret_cast(out_match_bitvector)[i] = _mm256_movemask_epi8(vmatch_found); } + + return num_hashes - (num_hashes % unroll); } -void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, - const uint8_t* local_slots, - uint32_t* out_group_ids, int byte_offset, - int byte_multiplier, int byte_size) const { +int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, + uint32_t* out_group_ids, int byte_offset, + int byte_multiplier, int byte_size) const { ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4); uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 
0xFFFF : 0xFFFFFFFF; auto elements = reinterpret_cast(blocks_ + byte_offset); @@ -380,7 +381,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16); __m256i block_group_ids = _mm256_set1_epi64x(reinterpret_cast(blocks_)[1]); - for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + for (int i = 0; i < num_keys / unroll; ++i) { __m256i local_slot = _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); __m256i group_id = _mm256_shuffle_epi8(block_group_ids, local_slot); @@ -390,7 +391,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } else { - for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + for (int i = 0; i < num_keys / unroll; ++i) { __m256i hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); __m256i local_slot = _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); @@ -406,6 +407,7 @@ void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hash _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } + return num_keys - (num_keys % unroll); } #endif diff --git a/cpp/src/arrow/compute/exec/partition_util.h b/cpp/src/arrow/compute/exec/partition_util.h index 07fb91f2f1a..b3f302511a7 100644 --- a/cpp/src/arrow/compute/exec/partition_util.h +++ b/cpp/src/arrow/compute/exec/partition_util.h @@ -118,6 +118,54 @@ class PartitionLocks { /// \brief Release a partition so that other threads can work on it void ReleasePartitionLock(int prtn_id); + // Executes (synchronously and using current thread) the same operation on a set of + // multiple partitions. Tries to minimize partition locking overhead by randomizing and + // adjusting order in which partitions are processed. + // + // PROCESS_PRTN_FN is a callback which will be executed for each partition after + // acquiring the lock for that partition. It gets partition id as an argument. + // IS_PRTN_EMPTY_FN is a callback which filters out (when returning true) partitions + // with specific ids from processing. 
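A usage sketch for ForEachPartition as declared below (the hash join build later in this patch drives it in this shape; ProcessAllPartitions, prtn_row_counts and the trivial callback bodies are assumed names for illustration): the caller supplies one int of scratch space per partition, a predicate that filters out partitions with no work, and a callback that runs while the partition lock is held.

#include <vector>

// Assumes the arrow::compute context of this header (Status, PartitionLocks).
Status ProcessAllPartitions(PartitionLocks* locks, size_t thread_id, int num_prtns,
                            const std::vector<int>& prtn_row_counts) {
  std::vector<int> scratch(num_prtns);  // scratch space: one element per partition
  return locks->ForEachPartition(
      thread_id, scratch.data(),
      /*is_prtn_empty_fn=*/[&](int prtn_id) { return prtn_row_counts[prtn_id] == 0; },
      /*process_prtn_fn=*/[&](int prtn_id) {
        // Per-partition work goes here; the lock is released automatically afterwards.
        return Status::OK();
      });
}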
+ // + template + Status ForEachPartition(size_t thread_id, + /*scratch space buffer with space for one element per partition; + dirty in and dirty out*/ + int* temp_unprocessed_prtns, IS_PRTN_EMPTY_FN is_prtn_empty_fn, + PROCESS_PRTN_FN process_prtn_fn) { + int num_unprocessed_partitions = 0; + for (int i = 0; i < num_prtns_; ++i) { + bool is_prtn_empty = is_prtn_empty_fn(i); + if (!is_prtn_empty) { + temp_unprocessed_prtns[num_unprocessed_partitions++] = i; + } + } + while (num_unprocessed_partitions > 0) { + int locked_prtn_id; + int locked_prtn_id_pos; + AcquirePartitionLock(thread_id, num_unprocessed_partitions, temp_unprocessed_prtns, + /*limit_retries=*/false, /*max_retries=*/-1, &locked_prtn_id, + &locked_prtn_id_pos); + { + class AutoReleaseLock { + public: + AutoReleaseLock(PartitionLocks* locks, int prtn_id) + : locks(locks), prtn_id(prtn_id) {} + ~AutoReleaseLock() { locks->ReleasePartitionLock(prtn_id); } + PartitionLocks* locks; + int prtn_id; + } auto_release_lock(this, locked_prtn_id); + ARROW_RETURN_NOT_OK(process_prtn_fn(locked_prtn_id)); + } + if (locked_prtn_id_pos < num_unprocessed_partitions - 1) { + temp_unprocessed_prtns[locked_prtn_id_pos] = + temp_unprocessed_prtns[num_unprocessed_partitions - 1]; + } + --num_unprocessed_partitions; + } + return Status::OK(); + } + private: std::atomic* lock_ptr(int prtn_id); int random_int(size_t thread_id, int num_values); diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h index 91b7e6cfc6e..f2b14aa5450 100644 --- a/cpp/src/arrow/compute/exec/schema_util.h +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -24,7 +24,6 @@ #include "arrow/compute/light_array.h" // for KeyColumnMetadata #include "arrow/type.h" // for DataType, FieldRef, Field and Schema -#include "arrow/util/mutex.h" namespace arrow { @@ -79,16 +78,28 @@ class SchemaProjectionMaps { int num_cols(ProjectionIdEnum schema_handle) const { int id = schema_id(schema_handle); - return static_cast(schemas_[id].second.size()); + return static_cast(schemas_[id].second.data_types.size()); + } + + bool is_empty(ProjectionIdEnum schema_handle) const { + return num_cols(schema_handle) == 0; } const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const { - return field(schema_handle, field_id).field_name; + int id = schema_id(schema_handle); + return schemas_[id].second.field_names[field_id]; } const std::shared_ptr& data_type(ProjectionIdEnum schema_handle, int field_id) const { - return field(schema_handle, field_id).data_type; + int id = schema_id(schema_handle); + return schemas_[id].second.data_types[field_id]; + } + + const std::vector>& data_types( + ProjectionIdEnum schema_handle) const { + int id = schema_id(schema_handle); + return schemas_[id].second.data_types; } SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) const { @@ -102,22 +113,24 @@ class SchemaProjectionMaps { } protected: - struct FieldInfo { - int field_path; - std::string field_name; - std::shared_ptr data_type; + struct FieldInfos { + std::vector field_paths; + std::vector field_names; + std::vector> data_types; }; Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { - std::vector out_fields; + FieldInfos out_fields; const FieldVector& in_fields = schema.fields(); - out_fields.resize(in_fields.size()); + out_fields.field_paths.resize(in_fields.size()); + out_fields.field_names.resize(in_fields.size()); + out_fields.data_types.resize(in_fields.size()); for (size_t i = 0; i < in_fields.size(); ++i) { 
const std::string& name = in_fields[i]->name(); const std::shared_ptr& type = in_fields[i]->type(); - out_fields[i].field_path = static_cast(i); - out_fields[i].field_name = name; - out_fields[i].data_type = type; + out_fields.field_paths[i] = static_cast(i); + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; } schemas_.push_back(std::make_pair(handle, out_fields)); return Status::OK(); @@ -126,17 +139,19 @@ class SchemaProjectionMaps { Status RegisterProjectedSchema(ProjectionIdEnum handle, const std::vector& selected_fields, const Schema& full_schema) { - std::vector out_fields; + FieldInfos out_fields; const FieldVector& in_fields = full_schema.fields(); - out_fields.resize(selected_fields.size()); + out_fields.field_paths.resize(selected_fields.size()); + out_fields.field_names.resize(selected_fields.size()); + out_fields.data_types.resize(selected_fields.size()); for (size_t i = 0; i < selected_fields.size(); ++i) { // All fields must be found in schema without ambiguity ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema)); const std::string& name = in_fields[match[0]]->name(); const std::shared_ptr& type = in_fields[match[0]]->type(); - out_fields[i].field_path = match[0]; - out_fields[i].field_name = name; - out_fields[i].data_type = type; + out_fields.field_paths[i] = match[0]; + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; } schemas_.push_back(std::make_pair(handle, out_fields)); return Status::OK(); @@ -163,15 +178,9 @@ class SchemaProjectionMaps { return -1; } - const FieldInfo& field(ProjectionIdEnum schema_handle, int field_id) const { - int id = schema_id(schema_handle); - const std::vector& field_infos = schemas_[id].second; - return field_infos[field_id]; - } - void GenerateMapForProjection(int id_proj, int id_base) { - int num_cols_proj = static_cast(schemas_[id_proj].second.size()); - int num_cols_base = static_cast(schemas_[id_base].second.size()); + int num_cols_proj = static_cast(schemas_[id_proj].second.data_types.size()); + int num_cols_base = static_cast(schemas_[id_base].second.data_types.size()); std::vector& mapping = mappings_[id_proj]; std::vector& inverse_mapping = inverse_mappings_[id_proj]; @@ -183,15 +192,15 @@ class SchemaProjectionMaps { mapping[i] = inverse_mapping[i] = i; } } else { - const std::vector& fields_proj = schemas_[id_proj].second; - const std::vector& fields_base = schemas_[id_base].second; + const FieldInfos& fields_proj = schemas_[id_proj].second; + const FieldInfos& fields_base = schemas_[id_base].second; for (int i = 0; i < num_cols_base; ++i) { inverse_mapping[i] = SchemaProjectionMap::kMissingField; } for (int i = 0; i < num_cols_proj; ++i) { int field_id = SchemaProjectionMap::kMissingField; for (int j = 0; j < num_cols_base; ++j) { - if (fields_proj[i].field_path == fields_base[j].field_path) { + if (fields_proj.field_paths[i] == fields_base.field_paths[j]) { field_id = j; // If there are multiple matches for the same input field, // it will be mapped to the first match. 
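The change above swaps the per-field struct (array of structs) for FieldInfos, a struct of parallel vectors indexed by column position. The practical payoff is that accessors such as the new data_types() can return a reference to an existing contiguous vector instead of materializing one. Schematically (types abbreviated; assumes arrow::DataType):

#include <memory>
#include <string>
#include <vector>

// Before: one struct per field; a data_types() accessor would have to copy.
struct FieldInfo {
  int field_path;
  std::string field_name;
  std::shared_ptr<DataType> data_type;
};

// After: parallel vectors, so data_types() can hand out the stored vector by
// const reference.
struct FieldInfos {
  std::vector<int> field_paths;
  std::vector<std::string> field_names;
  std::vector<std::shared_ptr<DataType>> data_types;
};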
@@ -206,10 +215,12 @@ class SchemaProjectionMaps { } // vector used as a mapping from ProjectionIdEnum to fields - std::vector>> schemas_; + std::vector> schemas_; std::vector> mappings_; std::vector> inverse_mappings_; }; +using HashJoinProjectionMaps = SchemaProjectionMaps; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc new file mode 100644 index 00000000000..5d70e01b1e3 --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join.cc @@ -0,0 +1,2526 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/swiss_join.h" +#include +#include // std::upper_bound +#include +#include +#include +#include "arrow/array/util.h" // MakeArrayFromScalar +#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/key_hash.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/compute/row/compare_internal.h" +#include "arrow/compute/row/encode_internal.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/tracing_internal.h" + +namespace arrow { +namespace compute { + +int RowArrayAccessor::VarbinaryColumnId(const RowTableMetadata& row_metadata, + int column_id) { + ARROW_DCHECK(row_metadata.num_cols() > static_cast(column_id)); + ARROW_DCHECK(!row_metadata.is_fixed_length); + ARROW_DCHECK(!row_metadata.column_metadatas[column_id].is_fixed_length); + + int varbinary_column_id = 0; + for (int i = 0; i < column_id; ++i) { + if (!row_metadata.column_metadatas[i].is_fixed_length) { + ++varbinary_column_id; + } + } + return varbinary_column_id; +} + +int RowArrayAccessor::NumRowsToSkip(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, int num_tail_bytes_to_skip) { + uint32_t num_bytes_skipped = 0; + int num_rows_left = num_rows; + + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + if (!is_fixed_length_column) { + // Varying length column + // + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + + while (num_rows_left > 0 && + num_bytes_skipped < static_cast(num_tail_bytes_to_skip)) { + // Find the pointer to the last requested row + // + uint32_t last_row_id = row_ids[num_rows_left - 1]; + const uint8_t* row_ptr = rows.data(2) + rows.offsets()[last_row_id]; + + // Find the length of the requested varying length field in that row + // + uint32_t field_offset_within_row, field_length; + if (varbinary_column_id == 0) { + rows.metadata().first_varbinary_offset_and_length( + row_ptr, &field_offset_within_row, &field_length); + } else { + rows.metadata().nth_varbinary_offset_and_length( + row_ptr, varbinary_column_id, &field_offset_within_row, 
&field_length); + } + + num_bytes_skipped += field_length; + --num_rows_left; + } + } else { + // Fixed length column + // + uint32_t field_length = rows.metadata().column_metadatas[column_id].fixed_length; + uint32_t num_bytes_skipped = 0; + while (num_rows_left > 0 && + num_bytes_skipped < static_cast(num_tail_bytes_to_skip)) { + num_bytes_skipped += field_length; + --num_rows_left; + } + } + + return num_rows - num_rows_left; +} + +template +void RowArrayAccessor::Visit(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn) { + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + // There are 4 cases, each requiring different steps: + // 1. Varying length column that is the first varying length column in a row + // 2. Varying length column that is not the first varying length column in a + // row + // 3. Fixed length column in a fixed length row + // 4. Fixed length column in a varying length row + + if (!is_fixed_length_column) { + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + uint32_t field_offset_within_row, field_length; + + if (varbinary_column_id == 0) { + // Case 1: This is the first varbinary column + // + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + rows.metadata().first_varbinary_offset_and_length( + row_ptr, &field_offset_within_row, &field_length); + process_value_fn(i, row_ptr + field_offset_within_row, field_length); + } + } else { + // Case 2: This is second or later varbinary column + // + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + rows.metadata().nth_varbinary_offset_and_length( + row_ptr, varbinary_column_id, &field_offset_within_row, &field_length); + process_value_fn(i, row_ptr + field_offset_within_row, field_length); + } + } + } + + if (is_fixed_length_column) { + uint32_t field_offset_within_row = rows.metadata().encoded_field_offset( + rows.metadata().pos_after_encoding(column_id)); + uint32_t field_length = rows.metadata().column_metadatas[column_id].fixed_length; + // Bit column is encoded as a single byte + // + if (field_length == 0) { + field_length = 1; + } + uint32_t row_length = rows.metadata().fixed_length; + + bool is_fixed_length_row = rows.metadata().is_fixed_length; + if (is_fixed_length_row) { + // Case 3: This is a fixed length column in a fixed length row + // + const uint8_t* row_ptr_base = rows.data(1) + field_offset_within_row; + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_length * row_id; + process_value_fn(i, row_ptr, field_length); + } + } else { + // Case 4: This is a fixed length column in a varying length row + // + const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row; + const uint32_t* row_offsets = rows.offsets(); + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; + process_value_fn(i, row_ptr, field_length); + } + } + } +} + +template +void RowArrayAccessor::VisitNulls(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, + PROCESS_VALUE_FN process_value_fn) { + const uint8_t* null_masks = rows.null_masks(); + uint32_t null_mask_num_bytes = 
rows.metadata().null_masks_bytes_per_row; + uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id); + for (int i = 0; i < num_rows; ++i) { + uint32_t row_id = row_ids[i]; + int64_t bit_id = row_id * null_mask_num_bytes * 8 + pos_after_encoding; + process_value_fn(i, bit_util::GetBit(null_masks, bit_id) ? 0xff : 0); + } +} + +Status RowArray::InitIfNeeded(MemoryPool* pool, const RowTableMetadata& row_metadata) { + if (is_initialized_) { + return Status::OK(); + } + encoder_.Init(row_metadata.column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); + RETURN_NOT_OK(rows_temp_.Init(pool, row_metadata)); + RETURN_NOT_OK(rows_.Init(pool, row_metadata)); + is_initialized_ = true; + return Status::OK(); +} + +Status RowArray::InitIfNeeded(MemoryPool* pool, const ExecBatch& batch) { + if (is_initialized_) { + return Status::OK(); + } + std::vector column_metadatas; + RETURN_NOT_OK(ColumnMetadatasFromExecBatch(batch, &column_metadatas)); + RowTableMetadata row_metadata; + row_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), + sizeof(uint64_t)); + + return InitIfNeeded(pool, row_metadata); +} + +Status RowArray::AppendBatchSelection(MemoryPool* pool, const ExecBatch& batch, + int begin_row_id, int end_row_id, int num_row_ids, + const uint16_t* row_ids, + std::vector& temp_column_arrays) { + RETURN_NOT_OK(InitIfNeeded(pool, batch)); + RETURN_NOT_OK(ColumnArraysFromExecBatch(batch, begin_row_id, end_row_id - begin_row_id, + &temp_column_arrays)); + encoder_.PrepareEncodeSelected( + /*start_row=*/0, end_row_id - begin_row_id, temp_column_arrays); + RETURN_NOT_OK(encoder_.EncodeSelected(&rows_temp_, num_row_ids, row_ids)); + RETURN_NOT_OK(rows_.AppendSelectionFrom(rows_temp_, num_row_ids, nullptr)); + return Status::OK(); +} + +void RowArray::Compare(const ExecBatch& batch, int begin_row_id, int end_row_id, + int num_selected, const uint16_t* batch_selection_maybe_null, + const uint32_t* array_row_ids, uint32_t* out_num_not_equal, + uint16_t* out_not_equal_selection, int64_t hardware_flags, + util::TempVectorStack* temp_stack, + std::vector& temp_column_arrays, + uint8_t* out_match_bitvector_maybe_null) { + Status status = ColumnArraysFromExecBatch( + batch, begin_row_id, end_row_id - begin_row_id, &temp_column_arrays); + ARROW_DCHECK(status.ok()); + + LightContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + KeyCompare::CompareColumnsToRows( + num_selected, batch_selection_maybe_null, array_row_ids, &ctx, out_num_not_equal, + out_not_equal_selection, temp_column_arrays, rows_, + /*are_cols_in_encoding_order=*/false, out_match_bitvector_maybe_null); +} + +Status RowArray::DecodeSelected(ResizableArrayData* output, int column_id, + int num_rows_to_append, const uint32_t* row_ids, + MemoryPool* pool) const { + int num_rows_before = output->num_rows(); + RETURN_NOT_OK(output->ResizeFixedLengthBuffers(num_rows_before + num_rows_to_append)); + + // Both input (KeyRowArray) and output (ResizableArrayData) have buffers with + // extra bytes added at the end to avoid buffer overruns when using wide load + // instructions. 
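The decoding loops below rely on that padding: a value of num_bytes bytes is copied as CeilDiv(num_bytes, 8) whole 64-bit words, which may touch a few bytes past the logical end of the value but stays inside the padded allocations. A stripped-down sketch of the pattern (the helper name is made up for illustration):

#include <cstdint>
#include <cstring>

// Copy num_bytes bytes in whole 64-bit words. Both src and dst must be padded so that
// rounding up to a multiple of 8 bytes stays inside their allocations, mirroring the
// guarantee described in the comment above.
inline void CopyPaddedWords(uint8_t* dst, const uint8_t* src, uint32_t num_bytes) {
  uint32_t num_words = (num_bytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
  for (uint32_t word = 0; word < num_words; ++word) {
    uint64_t tmp;
    std::memcpy(&tmp, src + word * sizeof(uint64_t), sizeof(tmp));  // SafeLoad
    std::memcpy(dst + word * sizeof(uint64_t), &tmp, sizeof(tmp));  // SafeStore
  }
}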
+ // + + ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata column_metadata, output->column_metadata()); + + if (column_metadata.is_fixed_length) { + uint32_t fixed_length = column_metadata.fixed_length; + switch (fixed_length) { + case 0: + RowArrayAccessor::Visit(rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + bit_util::SetBitTo(output->mutable_data(1), + num_rows_before + i, *ptr != 0); + }); + break; + case 1: + RowArrayAccessor::Visit(rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + output->mutable_data(1)[num_rows_before + i] = *ptr; + }); + break; + case 2: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 4: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + case 8: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + reinterpret_cast(output->mutable_data(1))[num_rows_before + i] = + *reinterpret_cast(ptr); + }); + break; + default: + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + output->mutable_data(1) + num_bytes * (num_rows_before + i)); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + break; + } + } else { + uint32_t* offsets = + reinterpret_cast(output->mutable_data(1)) + num_rows_before; + uint32_t sum = num_rows_before == 0 ? 
0 : offsets[0]; + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { offsets[i] = num_bytes; }); + for (int i = 0; i < num_rows_to_append; ++i) { + uint32_t length = offsets[i]; + offsets[i] = sum; + sum += length; + } + offsets[num_rows_to_append] = sum; + RETURN_NOT_OK(output->ResizeVaryingLengthBuffer()); + RowArrayAccessor::Visit( + rows_, column_id, num_rows_to_append, row_ids, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + uint64_t* dst = reinterpret_cast( + output->mutable_data(2) + + reinterpret_cast( + output->mutable_data(1))[num_rows_before + i]); + const uint64_t* src = reinterpret_cast(ptr); + for (uint32_t word_id = 0; + word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) { + util::SafeStore(dst + word_id, util::SafeLoad(src + word_id)); + } + }); + } + + // Process nulls + // + RowArrayAccessor::VisitNulls( + rows_, column_id, num_rows_to_append, row_ids, [&](int i, uint8_t value) { + bit_util::SetBitTo(output->mutable_data(0), num_rows_before + i, value == 0); + }); + + return Status::OK(); +} + +void RowArray::DebugPrintToFile(const char* filename, bool print_sorted) const { + FILE* fout; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fout, filename, "wt"); +#else + fout = fopen(filename, "wt"); +#endif + if (!fout) { + return; + } + + for (int64_t row_id = 0; row_id < rows_.length(); ++row_id) { + for (uint32_t column_id = 0; column_id < rows_.metadata().num_cols(); ++column_id) { + bool is_null; + uint32_t row_id_cast = static_cast(row_id); + RowArrayAccessor::VisitNulls(rows_, column_id, 1, &row_id_cast, + [&](int i, uint8_t value) { is_null = (value != 0); }); + if (is_null) { + fprintf(fout, "null"); + } else { + RowArrayAccessor::Visit(rows_, column_id, 1, &row_id_cast, + [&](int i, const uint8_t* ptr, uint32_t num_bytes) { + fprintf(fout, "\""); + for (uint32_t ibyte = 0; ibyte < num_bytes; ++ibyte) { + fprintf(fout, "%02x", ptr[ibyte]); + } + fprintf(fout, "\""); + }); + } + fprintf(fout, "\t"); + } + fprintf(fout, "\n"); + } + fclose(fout); + + if (print_sorted) { + struct stat sb; + if (stat(filename, &sb) == -1) { + ARROW_DCHECK(false); + return; + } + std::vector buffer; + buffer.resize(sb.st_size); + std::vector lines; + FILE* fin; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fin, filename, "rt"); +#else + fin = fopen(filename, "rt"); +#endif + if (!fin) { + return; + } + while (fgets(buffer.data(), static_cast(buffer.size()), fin)) { + lines.push_back(std::string(buffer.data())); + } + fclose(fin); + std::sort(lines.begin(), lines.end()); + FILE* fout2; +#if defined(_MSC_VER) && _MSC_VER >= 1400 + fopen_s(&fout2, filename, "wt"); +#else + fout2 = fopen(filename, "wt"); +#endif + if (!fout2) { + return; + } + for (size_t i = 0; i < lines.size(); ++i) { + fprintf(fout2, "%s\n", lines[i].c_str()); + } + fclose(fout2); + } +} + +Status RowArrayMerge::PrepareForMerge(RowArray* target, + const std::vector& sources, + std::vector* first_target_row_id, + MemoryPool* pool) { + ARROW_DCHECK(!sources.empty()); + + ARROW_DCHECK(sources[0]->is_initialized_); + const RowTableMetadata& metadata = sources[0]->rows_.metadata(); + ARROW_DCHECK(!target->is_initialized_); + RETURN_NOT_OK(target->InitIfNeeded(pool, metadata)); + + // Sum the number of rows from all input sources and calculate their total + // size. 
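first_target_row_id, filled in just below, is an exclusive prefix sum over the source lengths, so every source is assigned a disjoint, precomputed range in the target and the later per-partition merges can run concurrently without coordination. A toy illustration of that prefix sum (sizes invented):

#include <cstdint>
#include <vector>

// Sources with 5, 3 and 7 rows -> {0, 5, 8, 15}; source i writes rows [out[i], out[i+1]).
std::vector<int64_t> FirstTargetRowIds(const std::vector<int64_t>& source_lengths) {
  std::vector<int64_t> out(source_lengths.size() + 1, 0);
  for (size_t i = 0; i < source_lengths.size(); ++i) {
    out[i + 1] = out[i] + source_lengths[i];
  }
  return out;
}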
+ // + int64_t num_rows = 0; + int64_t num_bytes = 0; + if (first_target_row_id) { + first_target_row_id->resize(sources.size() + 1); + } + for (size_t i = 0; i < sources.size(); ++i) { + // All input sources must be initialized and have the same row format. + // + ARROW_DCHECK(sources[i]->is_initialized_); + ARROW_DCHECK(metadata.is_compatible(sources[i]->rows_.metadata())); + if (first_target_row_id) { + (*first_target_row_id)[i] = num_rows; + } + num_rows += sources[i]->rows_.length(); + if (!metadata.is_fixed_length) { + num_bytes += sources[i]->rows_.offsets()[sources[i]->rows_.length()]; + } + } + if (first_target_row_id) { + (*first_target_row_id)[sources.size()] = num_rows; + } + + // Allocate target memory + // + target->rows_.Clean(); + RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), + static_cast(num_bytes))); + + // In case of varying length rows, + // initialize the first row offset for each range of rows corresponding to a + // single source. + // + if (!metadata.is_fixed_length) { + num_rows = 0; + num_bytes = 0; + for (size_t i = 0; i < sources.size(); ++i) { + target->rows_.mutable_offsets()[num_rows] = static_cast(num_bytes); + num_rows += sources[i]->rows_.length(); + num_bytes += sources[i]->rows_.offsets()[sources[i]->rows_.length()]; + } + target->rows_.mutable_offsets()[num_rows] = static_cast(num_bytes); + } + + return Status::OK(); +} + +void RowArrayMerge::MergeSingle(RowArray* target, const RowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + // Source and target must: + // - be initialized + // - use the same row format + // - use 64-bit alignment + // + ARROW_DCHECK(source.is_initialized_ && target->is_initialized_); + ARROW_DCHECK(target->rows_.metadata().is_compatible(source.rows_.metadata())); + ARROW_DCHECK(target->rows_.metadata().row_alignment == sizeof(uint64_t)); + + if (target->rows_.metadata().is_fixed_length) { + CopyFixedLength(&target->rows_, source.rows_, first_target_row_id, + source_rows_permutation); + } else { + CopyVaryingLength(&target->rows_, source.rows_, first_target_row_id, + target->rows_.offsets()[first_target_row_id], + source_rows_permutation); + } + CopyNulls(&target->rows_, source.rows_, first_target_row_id, source_rows_permutation); +} + +void RowArrayMerge::CopyFixedLength(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + + int64_t fixed_length = target->metadata().fixed_length; + + // Permutation of source rows is optional. Without permutation all that is + // needed is memcpy. + // + if (!source_rows_permutation) { + memcpy(target->mutable_data(1) + fixed_length * first_target_row_id, source.data(1), + fixed_length * num_source_rows); + } else { + // Row length must be a multiple of 64-bits due to enforced alignment. + // Loop for each output row copying a fixed number of 64-bit words. 
+ // + ARROW_DCHECK(fixed_length % sizeof(uint64_t) == 0); + + int64_t num_words_per_row = fixed_length / sizeof(uint64_t); + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint64_t* source_row_ptr = reinterpret_cast( + source.data(1) + fixed_length * source_row_id); + uint64_t* target_row_ptr = reinterpret_cast( + target->mutable_data(1) + fixed_length * (first_target_row_id + i)); + + for (int64_t word = 0; word < num_words_per_row; ++word) { + target_row_ptr[word] = source_row_ptr[word]; + } + } + } +} + +void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + int64_t first_target_row_offset, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + uint32_t* target_offsets = target->mutable_offsets(); + const uint32_t* source_offsets = source.offsets(); + + // Permutation of source rows is optional. + // + if (!source_rows_permutation) { + int64_t target_row_offset = first_target_row_offset; + for (int64_t i = 0; i < num_source_rows; ++i) { + target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_row_offset += source_offsets[i + 1] - source_offsets[i]; + } + // We purposefully skip outputting of N+1 offset, to allow concurrent + // copies of rows done to adjacent ranges in target array. + // It should have already been initialized during preparation for merge. + // + + // We can simply memcpy bytes of rows if their order has not changed. + // + memcpy(target->mutable_data(2) + target_offsets[first_target_row_id], source.data(2), + source_offsets[num_source_rows] - source_offsets[0]); + } else { + int64_t target_row_offset = first_target_row_offset; + uint64_t* target_row_ptr = + reinterpret_cast(target->mutable_data(2) + target_row_offset); + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint64_t* source_row_ptr = reinterpret_cast( + source.data(2) + source_offsets[source_row_id]); + uint32_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + + // Rows should be 64-bit aligned. + // In that case we can copy them using a sequence of 64-bit read/writes. 
+ // + ARROW_DCHECK(length % sizeof(uint64_t) == 0); + + for (uint32_t word = 0; word < length / sizeof(uint64_t); ++word) { + *target_row_ptr++ = *source_row_ptr++; + } + + target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_row_offset += length; + } + } +} + +void RowArrayMerge::CopyNulls(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation) { + int64_t num_source_rows = source.length(); + int num_bytes_per_row = target->metadata().null_masks_bytes_per_row; + uint8_t* target_nulls = target->null_masks() + num_bytes_per_row * first_target_row_id; + if (!source_rows_permutation) { + memcpy(target_nulls, source.null_masks(), num_bytes_per_row * num_source_rows); + } else { + for (int64_t i = 0; i < num_source_rows; ++i) { + int64_t source_row_id = source_rows_permutation[i]; + const uint8_t* source_nulls = + source.null_masks() + num_bytes_per_row * source_row_id; + for (int64_t byte = 0; byte < num_bytes_per_row; ++byte) { + *target_nulls++ = *source_nulls++; + } + } + } +} + +Status SwissTableMerge::PrepareForMerge(SwissTable* target, + const std::vector& sources, + std::vector* first_target_group_id, + MemoryPool* pool) { + ARROW_DCHECK(!sources.empty()); + + // Each source should correspond to a range of hashes. + // A row belongs to a source with index determined by K highest bits of hash. + // That means that the number of sources must be a power of 2. + // + int log_num_sources = bit_util::Log2(sources.size()); + ARROW_DCHECK((1 << log_num_sources) == static_cast(sources.size())); + + // Determine the number of blocks in the target table. + // We will use max of numbers of blocks in any of the sources multiplied by + // the number of sources. + // + int log_blocks_max = 1; + for (size_t i = 0; i < sources.size(); ++i) { + log_blocks_max = std::max(log_blocks_max, sources[i]->log_blocks_); + } + int log_blocks = log_num_sources + log_blocks_max; + + // Allocate target blocks and mark all slots as empty + // + // We will skip allocating the array of hash values in target table. + // Target will be used in read-only mode and that array is only needed when + // resizing table which may occur only after new inserts. + // + RETURN_NOT_OK(target->init(sources[0]->hardware_flags_, pool, log_blocks, + /*no_hash_array=*/true)); + + // Calculate and output the first group id index for each source. + // + if (first_target_group_id) { + uint32_t num_groups = 0; + first_target_group_id->resize(sources.size()); + for (size_t i = 0; i < sources.size(); ++i) { + (*first_target_group_id)[i] = num_groups; + num_groups += sources[i]->num_inserted_; + } + target->num_inserted_ = num_groups; + } + + return Status::OK(); +} + +void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* source, + uint32_t partition_id, int num_partition_bits, + uint32_t base_group_id, + std::vector* overflow_group_ids, + std::vector* overflow_hashes) { + // Prepare parameters needed for scanning full slots in source. + // + int source_group_id_bits = + SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks_); + uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits); + int64_t source_block_bytes = source_group_id_bits + 8; + ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0); + + // Compute index of the last block in target that corresponds to the given + // partition. 
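A worked example of the index arithmetic used below, under assumed sizes (bits_hash_ = 32, a target table with log_blocks_ = 8, num_partition_bits = 2): partition 3 owns target blocks 192..255, and a source hash is remapped by dropping its top two bits and writing the partition id there instead.

#include <cstdint>

constexpr int kBitsHash = 32;
constexpr int kTargetLogBlocks = 8;
constexpr int kNumPartitionBits = 2;
constexpr uint32_t kPartitionId = 3;

// Last target block that belongs to this partition.
constexpr int64_t kMaxBlockId =
    ((kPartitionId + 1) << (kTargetLogBlocks - kNumPartitionBits)) - 1;
static_assert(kMaxBlockId == 255, "partition 3 of 4 owns blocks 192..255");

// Hash remapping: shift right, then place the partition id in the highest bits
// (same expression shape as in MergePartition below).
constexpr uint32_t kSourceHash = 0xDEADBEEFu;
constexpr uint32_t kRemapped =
    (kSourceHash >> kNumPartitionBits) |
    (kPartitionId << (kBitsHash - 1 - kNumPartitionBits) << 1);
static_assert((kRemapped >> (kBitsHash - kNumPartitionBits)) == kPartitionId,
              "the top bits now encode the partition id");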
+ // + ARROW_DCHECK(num_partition_bits <= target->log_blocks_); + int64_t target_max_block_id = + ((partition_id + 1) << (target->log_blocks_ - num_partition_bits)) - 1; + + overflow_group_ids->clear(); + overflow_hashes->clear(); + + // For each source block... + int64_t source_blocks = 1LL << source->log_blocks_; + for (int64_t block_id = 0; block_id < source_blocks; ++block_id) { + uint8_t* block_bytes = source->blocks_ + block_id * source_block_bytes; + uint64_t block = *reinterpret_cast(block_bytes); + + // For each non-empty source slot... + constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + constexpr int kSlotsPerBlock = 8; + int num_full_slots = + kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) { + // Read group id and hash for this slot. + // + uint64_t group_id = + source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask); + int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; + uint32_t hash = source->hashes_[global_slot_id]; + // Insert partition id into the highest bits of hash, shifting the + // remaining hash bits right. + // + hash >>= num_partition_bits; + hash |= (partition_id << (SwissTable::bits_hash_ - 1 - num_partition_bits) << 1); + // Add base group id + // + group_id += base_group_id; + + // Insert new entry into target. Store in overflow vectors if not + // successful. + // + bool was_inserted = InsertNewGroup(target, group_id, hash, target_max_block_id); + if (!was_inserted) { + overflow_group_ids->push_back(static_cast(group_id)); + overflow_hashes->push_back(hash); + } + } + } +} + +inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id, + uint32_t hash, int64_t max_block_id) { + // Load the first block to visit for this hash + // + int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks_); + int64_t block_id_mask = ((1LL << target->log_blocks_) - 1); + int num_group_id_bits = + SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks_); + int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t); + ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); + uint8_t* block_bytes = target->blocks_ + block_id * num_block_bytes; + uint64_t block = *reinterpret_cast(block_bytes); + + // Search for the first block with empty slots. + // Stop after reaching max block id. 
+ // + constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) { + block_id = (block_id + 1) & block_id_mask; + block_bytes = target->blocks_ + block_id * num_block_bytes; + block = *reinterpret_cast(block_bytes); + } + if ((block & kHighBitOfEachByte) == 0) { + return false; + } + constexpr int kSlotsPerBlock = 8; + int local_slot_id = + kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; + target->insert_into_empty_slot(static_cast(global_slot_id), hash, + static_cast(group_id)); + return true; +} + +void SwissTableMerge::InsertNewGroups(SwissTable* target, + const std::vector& group_ids, + const std::vector& hashes) { + int64_t num_blocks = 1LL << target->log_blocks_; + for (size_t i = 0; i < group_ids.size(); ++i) { + std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks); + } +} + +SwissTableWithKeys::Input::Input(const ExecBatch* in_batch, int in_batch_start_row, + int in_batch_end_row, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays) + : batch(in_batch), + batch_start_row(in_batch_start_row), + batch_end_row(in_batch_end_row), + num_selected(0), + selection_maybe_null(nullptr), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(nullptr) {} + +SwissTableWithKeys::Input::Input(const ExecBatch* in_batch, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays) + : batch(in_batch), + batch_start_row(0), + batch_end_row(static_cast(in_batch->length)), + num_selected(0), + selection_maybe_null(nullptr), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(nullptr) {} + +SwissTableWithKeys::Input::Input(const ExecBatch* in_batch, int in_num_selected, + const uint16_t* in_selection, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays, + std::vector* in_temp_group_ids) + : batch(in_batch), + batch_start_row(0), + batch_end_row(static_cast(in_batch->length)), + num_selected(in_num_selected), + selection_maybe_null(in_selection), + temp_stack(in_temp_stack), + temp_column_arrays(in_temp_column_arrays), + temp_group_ids(in_temp_group_ids) {} + +SwissTableWithKeys::Input::Input(const Input& base, int num_rows_to_skip, + int num_rows_to_include) + : batch(base.batch), + temp_stack(base.temp_stack), + temp_column_arrays(base.temp_column_arrays), + temp_group_ids(base.temp_group_ids) { + if (base.selection_maybe_null) { + batch_start_row = 0; + batch_end_row = static_cast(batch->length); + ARROW_DCHECK(num_rows_to_skip + num_rows_to_include <= base.num_selected); + num_selected = num_rows_to_include; + selection_maybe_null = base.selection_maybe_null + num_rows_to_skip; + } else { + ARROW_DCHECK(base.batch_start_row + num_rows_to_skip + num_rows_to_include <= + base.batch_end_row); + batch_start_row = base.batch_start_row + num_rows_to_skip; + batch_end_row = base.batch_start_row + num_rows_to_skip + num_rows_to_include; + num_selected = 0; + selection_maybe_null = nullptr; + } +} + +Status SwissTableWithKeys::Init(int64_t hardware_flags, MemoryPool* pool) { + InitCallbacks(); + return swiss_table_.init(hardware_flags, pool); +} + +void SwissTableWithKeys::EqualCallback(int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, + uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, + void* callback_ctx) { + if (num_keys == 
0) { + *out_num_keys_mismatch = 0; + return; + } + + ARROW_DCHECK(num_keys <= swiss_table_.minibatch_size()); + + Input* in = reinterpret_cast(callback_ctx); + + int64_t hardware_flags = swiss_table_.hardware_flags(); + + int batch_start_to_use; + int batch_end_to_use; + const uint16_t* selection_to_use; + const uint32_t* group_ids_to_use; + + if (in->selection_maybe_null) { + auto selection_to_use_buf = + util::TempVectorHolder(in->temp_stack, num_keys); + ARROW_DCHECK(in->temp_group_ids); + in->temp_group_ids->resize(in->batch->length); + + if (selection_maybe_null) { + for (int i = 0; i < num_keys; ++i) { + uint16_t local_row_id = selection_maybe_null[i]; + uint16_t global_row_id = in->selection_maybe_null[local_row_id]; + selection_to_use_buf.mutable_data()[i] = global_row_id; + (*in->temp_group_ids)[global_row_id] = group_ids[local_row_id]; + } + selection_to_use = selection_to_use_buf.mutable_data(); + } else { + for (int i = 0; i < num_keys; ++i) { + uint16_t global_row_id = in->selection_maybe_null[i]; + (*in->temp_group_ids)[global_row_id] = group_ids[i]; + } + selection_to_use = in->selection_maybe_null; + } + batch_start_to_use = 0; + batch_end_to_use = static_cast(in->batch->length); + group_ids_to_use = in->temp_group_ids->data(); + + auto match_bitvector_buf = util::TempVectorHolder(in->temp_stack, num_keys); + uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); + + keys_.Compare(*in->batch, batch_start_to_use, batch_end_to_use, num_keys, + selection_to_use, group_ids_to_use, nullptr, nullptr, hardware_flags, + in->temp_stack, *in->temp_column_arrays, match_bitvector); + + if (selection_maybe_null) { + int num_keys_mismatch = 0; + util::bit_util::bits_filter_indexes(0, hardware_flags, num_keys, match_bitvector, + selection_maybe_null, &num_keys_mismatch, + out_selection_mismatch); + *out_num_keys_mismatch = num_keys_mismatch; + } else { + int num_keys_mismatch = 0; + util::bit_util::bits_to_indexes(0, hardware_flags, num_keys, match_bitvector, + &num_keys_mismatch, out_selection_mismatch); + *out_num_keys_mismatch = num_keys_mismatch; + } + + } else { + batch_start_to_use = in->batch_start_row; + batch_end_to_use = in->batch_end_row; + selection_to_use = selection_maybe_null; + group_ids_to_use = group_ids; + keys_.Compare(*in->batch, batch_start_to_use, batch_end_to_use, num_keys, + selection_to_use, group_ids_to_use, out_num_keys_mismatch, + out_selection_mismatch, hardware_flags, in->temp_stack, + *in->temp_column_arrays); + } +} + +Status SwissTableWithKeys::AppendCallback(int num_keys, const uint16_t* selection, + void* callback_ctx) { + ARROW_DCHECK(num_keys <= swiss_table_.minibatch_size()); + ARROW_DCHECK(selection); + + Input* in = reinterpret_cast(callback_ctx); + + int batch_start_to_use; + int batch_end_to_use; + const uint16_t* selection_to_use; + + if (in->selection_maybe_null) { + auto selection_to_use_buf = + util::TempVectorHolder(in->temp_stack, num_keys); + for (int i = 0; i < num_keys; ++i) { + selection_to_use_buf.mutable_data()[i] = in->selection_maybe_null[selection[i]]; + } + batch_start_to_use = 0; + batch_end_to_use = static_cast(in->batch->length); + selection_to_use = selection_to_use_buf.mutable_data(); + + return keys_.AppendBatchSelection(swiss_table_.pool(), *in->batch, batch_start_to_use, + batch_end_to_use, num_keys, selection_to_use, + *in->temp_column_arrays); + } else { + batch_start_to_use = in->batch_start_row; + batch_end_to_use = in->batch_end_row; + selection_to_use = selection; + + return 
keys_.AppendBatchSelection(swiss_table_.pool(), *in->batch, batch_start_to_use, + batch_end_to_use, num_keys, selection_to_use, + *in->temp_column_arrays); + } +} + +void SwissTableWithKeys::InitCallbacks() { + equal_impl_ = [&](int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx) { + EqualCallback(num_keys, selection_maybe_null, group_ids, out_num_keys_mismatch, + out_selection_mismatch, callback_ctx); + }; + append_impl_ = [&](int num_keys, const uint16_t* selection, void* callback_ctx) { + return AppendCallback(num_keys, selection, callback_ctx); + }; +} + +void SwissTableWithKeys::Hash(Input* input, uint32_t* hashes, int64_t hardware_flags) { + // Hashing does not support selection of rows + // + ARROW_DCHECK(input->selection_maybe_null == nullptr); + + Status status = + Hashing32::HashBatch(*input->batch, hashes, *input->temp_column_arrays, + hardware_flags, input->temp_stack, input->batch_start_row, + input->batch_end_row - input->batch_start_row); + ARROW_DCHECK(status.ok()); +} + +void SwissTableWithKeys::MapReadOnly(Input* input, const uint32_t* hashes, + uint8_t* match_bitvector, uint32_t* key_ids) { + std::ignore = Map(input, /*insert_missing=*/false, hashes, match_bitvector, key_ids); +} + +Status SwissTableWithKeys::MapWithInserts(Input* input, const uint32_t* hashes, + uint32_t* key_ids) { + return Map(input, /*insert_missing=*/true, hashes, nullptr, key_ids); +} + +Status SwissTableWithKeys::Map(Input* input, bool insert_missing, const uint32_t* hashes, + uint8_t* match_bitvector_maybe_null, uint32_t* key_ids) { + util::TempVectorStack* temp_stack = input->temp_stack; + + // Split into smaller mini-batches + // + int minibatch_size = swiss_table_.minibatch_size(); + int num_rows_to_process = input->selection_maybe_null + ? input->num_selected + : input->batch_end_row - input->batch_start_row; + auto hashes_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto match_bitvector_buf = util::TempVectorHolder( + temp_stack, + static_cast(bit_util::BytesForBits(minibatch_size)) + sizeof(uint64_t)); + for (int minibatch_start = 0; minibatch_start < num_rows_to_process;) { + int minibatch_size_next = + std::min(minibatch_size, num_rows_to_process - minibatch_start); + + // Prepare updated input buffers that represent the current minibatch. + // + Input minibatch_input(*input, minibatch_start, minibatch_size_next); + uint8_t* minibatch_match_bitvector = + insert_missing ? match_bitvector_buf.mutable_data() + : match_bitvector_maybe_null + minibatch_start / 8; + const uint32_t* minibatch_hashes; + if (input->selection_maybe_null) { + minibatch_hashes = hashes_buf.mutable_data(); + for (int i = 0; i < minibatch_size_next; ++i) { + hashes_buf.mutable_data()[i] = hashes[minibatch_input.selection_maybe_null[i]]; + } + } else { + minibatch_hashes = hashes + minibatch_start; + } + uint32_t* minibatch_key_ids = key_ids + minibatch_start; + + // Lookup existing keys. + { + auto slots = util::TempVectorHolder(temp_stack, minibatch_size_next); + swiss_table_.early_filter(minibatch_size_next, minibatch_hashes, + minibatch_match_bitvector, slots.mutable_data()); + swiss_table_.find(minibatch_size_next, minibatch_hashes, minibatch_match_bitvector, + slots.mutable_data(), minibatch_key_ids, temp_stack, equal_impl_, + &minibatch_input); + } + + // Perform inserts of missing keys if required. 
+ // + if (insert_missing) { + auto ids_buf = util::TempVectorHolder(temp_stack, minibatch_size_next); + int num_ids; + util::bit_util::bits_to_indexes(0, swiss_table_.hardware_flags(), + minibatch_size_next, minibatch_match_bitvector, + &num_ids, ids_buf.mutable_data()); + + RETURN_NOT_OK(swiss_table_.map_new_keys( + num_ids, ids_buf.mutable_data(), minibatch_hashes, minibatch_key_ids, + temp_stack, equal_impl_, append_impl_, &minibatch_input)); + } + + minibatch_start += minibatch_size_next; + } + + return Status::OK(); +} + +uint8_t* SwissTableForJoin::local_has_match(int64_t thread_id) { + int64_t num_rows_hash_table = num_rows(); + if (num_rows_hash_table == 0) { + return nullptr; + } + + ThreadLocalState& local_state = local_states_[thread_id]; + if (local_state.has_match.empty() && num_rows_hash_table > 0) { + local_state.has_match.resize(bit_util::BytesForBits(num_rows_hash_table) + + sizeof(uint64_t)); + memset(local_state.has_match.data(), 0, bit_util::BytesForBits(num_rows_hash_table)); + } + + return local_states_[thread_id].has_match.data(); +} + +void SwissTableForJoin::UpdateHasMatchForKeys(int64_t thread_id, int num_ids, + const uint32_t* key_ids) { + uint8_t* bit_vector = local_has_match(thread_id); + if (num_ids == 0 || !bit_vector) { + return; + } + for (int i = 0; i < num_ids; ++i) { + // Mark row in hash table as having a match + // + bit_util::SetBit(bit_vector, key_ids[i]); + } +} + +void SwissTableForJoin::MergeHasMatch() { + int64_t num_rows_hash_table = num_rows(); + if (num_rows_hash_table == 0) { + return; + } + + has_match_.resize(bit_util::BytesForBits(num_rows_hash_table) + sizeof(uint64_t)); + memset(has_match_.data(), 0, bit_util::BytesForBits(num_rows_hash_table)); + + for (size_t tid = 0; tid < local_states_.size(); ++tid) { + if (!local_states_[tid].has_match.empty()) { + arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(), + 0, num_rows_hash_table, 0, has_match_.data()); + } + } +} + +uint32_t SwissTableForJoin::payload_id_to_key_id(uint32_t payload_id) const { + if (no_duplicate_keys_) { + return payload_id; + } + int64_t num_entries = num_keys(); + const uint32_t* entries = key_to_payload(); + ARROW_DCHECK(entries); + ARROW_DCHECK(entries[num_entries] > payload_id); + const uint32_t* first_greater = + std::upper_bound(entries, entries + num_entries + 1, payload_id); + ARROW_DCHECK(first_greater > entries); + return static_cast(first_greater - entries) - 1; +} + +void SwissTableForJoin::payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, + uint32_t* key_ids) const { + if (num_rows == 0) { + return; + } + if (no_duplicate_keys_) { + memcpy(key_ids, payload_ids, num_rows * sizeof(uint32_t)); + return; + } + + const uint32_t* entries = key_to_payload(); + uint32_t key_id = payload_id_to_key_id(payload_ids[0]); + key_ids[0] = key_id; + for (int i = 1; i < num_rows; ++i) { + ARROW_DCHECK(payload_ids[i] > payload_ids[i - 1]); + while (entries[key_id + 1] <= payload_ids[i]) { + ++key_id; + ARROW_DCHECK(key_id < num_keys()); + } + key_ids[i] = key_id; + } +} + +Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t num_rows, + bool reject_duplicate_keys, bool no_payload, + const std::vector& key_types, + const std::vector& payload_types, + MemoryPool* pool, int64_t hardware_flags) { + target_ = target; + dop_ = dop; + num_rows_ = num_rows; + + // Make sure that we do not use many partitions if there are not enough rows. 
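A worked example for the partition-count formula just below, with CeilLog2 and CeilDiv standing in for bit_util::Log2 and bit_util::CeilDiv (assumed here to round up) and min_num_rows_per_prtn = 1 << 18: with 16 threads and about a million rows, only four partitions are used.

#include <cstdint>

constexpr int CeilLog2(int64_t x) { return x <= 1 ? 0 : 1 + CeilLog2((x + 1) / 2); }
constexpr int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

constexpr int64_t kDop = 16;
constexpr int64_t kNumRows = 1000000;
constexpr int64_t kMinRowsPerPrtn = int64_t{1} << 18;  // 262144

// min(ceil(log2(16)), ceil(log2(ceil(1000000 / 262144)))) = min(4, 2) = 2
constexpr int kLogNumPrtns =
    CeilLog2(kDop) < CeilLog2(CeilDiv(kNumRows, kMinRowsPerPrtn))
        ? CeilLog2(kDop)
        : CeilLog2(CeilDiv(kNumRows, kMinRowsPerPrtn));
static_assert(kLogNumPrtns == 2, "1 << 2 = 4 partitions, not 16");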
+ // + constexpr int64_t min_num_rows_per_prtn = 1 << 18; + log_num_prtns_ = + std::min(bit_util::Log2(dop_), + bit_util::Log2(bit_util::CeilDiv(num_rows, min_num_rows_per_prtn))); + num_prtns_ = 1 << log_num_prtns_; + + reject_duplicate_keys_ = reject_duplicate_keys; + no_payload_ = no_payload; + pool_ = pool; + hardware_flags_ = hardware_flags; + + prtn_states_.resize(num_prtns_); + thread_states_.resize(dop_); + prtn_locks_.Init(dop_, num_prtns_); + + RowTableMetadata key_row_metadata; + key_row_metadata.FromColumnMetadataVector(key_types, + /*row_alignment=*/sizeof(uint64_t), + /*string_alignment=*/sizeof(uint64_t)); + RowTableMetadata payload_row_metadata; + payload_row_metadata.FromColumnMetadataVector(payload_types, + /*row_alignment=*/sizeof(uint64_t), + /*string_alignment=*/sizeof(uint64_t)); + + for (int i = 0; i < num_prtns_; ++i) { + PartitionState& prtn_state = prtn_states_[i]; + RETURN_NOT_OK(prtn_state.keys.Init(hardware_flags_, pool_)); + RETURN_NOT_OK(prtn_state.keys.keys()->InitIfNeeded(pool, key_row_metadata)); + RETURN_NOT_OK(prtn_state.payloads.InitIfNeeded(pool, payload_row_metadata)); + } + + target_->dop_ = dop_; + target_->local_states_.resize(dop_); + target_->no_payload_columns_ = no_payload; + target_->no_duplicate_keys_ = reject_duplicate_keys; + target_->map_.InitCallbacks(); + + return Status::OK(); +} + +Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id, + const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack) { + ARROW_DCHECK(thread_id < dop_); + ThreadState& locals = thread_states_[thread_id]; + + // Compute hash + // + locals.batch_hashes.resize(key_batch.length); + RETURN_NOT_OK(Hashing32::HashBatch( + key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, hardware_flags_, + temp_stack, /*start_row=*/0, static_cast(key_batch.length))); + + // Partition on hash + // + locals.batch_prtn_row_ids.resize(locals.batch_hashes.size()); + locals.batch_prtn_ranges.resize(num_prtns_ + 1); + int num_rows = static_cast(locals.batch_hashes.size()); + if (num_prtns_ == 1) { + // We treat single partition case separately to avoid extra checks in row + // partitioning implementation for general case. + // + locals.batch_prtn_ranges[0] = 0; + locals.batch_prtn_ranges[1] = num_rows; + for (int i = 0; i < num_rows; ++i) { + locals.batch_prtn_row_ids[i] = i; + } + } else { + PartitionSort::Eval( + static_cast(locals.batch_hashes.size()), num_prtns_, + locals.batch_prtn_ranges.data(), + [this, &locals](int64_t i) { + // SwissTable uses the highest bits of the hash for block index. + // We want each partition to correspond to a range of block indices, + // so we also partition on the highest bits of the hash. + // + return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1; + }, + [&locals](int64_t i, int pos) { + locals.batch_prtn_row_ids[pos] = static_cast(i); + }); + } + + // Update hashes, shifting left to get rid of the bits that were already used + // for partitioning. 
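The partition index expression above and the shift just below can be checked on a concrete value (toy numbers: 32-bit hashes, log_num_prtns_ = 2): the partition index is the top two bits of the hash, and the left shift discards exactly those bits so the per-partition SwissTable indexes its blocks with the next-highest bits.

#include <cstdint>

constexpr int kLogNumPrtns = 2;
constexpr uint32_t kHash = 0xDEADBEEFu;

// Partition id = highest kLogNumPrtns bits (same shape as the lambda above).
constexpr uint32_t kPrtnId = kHash >> (31 - kLogNumPrtns) >> 1;
static_assert(kPrtnId == 3, "0xD... starts with binary 11");

// After partitioning, the partition bits are shifted out of the hash.
constexpr uint32_t kHashInPartition = kHash << kLogNumPrtns;
static_assert(kHashInPartition == 0x7AB6FBBCu, "top two bits dropped");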
+ // + for (size_t i = 0; i < locals.batch_hashes.size(); ++i) { + locals.batch_hashes[i] <<= log_num_prtns_; + } + + // For each partition: + // - map keys to unique integers using (this partition's) hash table + // - append payloads (if present) to (this partition's) row array + // + locals.temp_prtn_ids.resize(num_prtns_); + + RETURN_NOT_OK(prtn_locks_.ForEachPartition( + thread_id, locals.temp_prtn_ids.data(), + /*is_prtn_empty_fn=*/ + [&](int prtn_id) { + return locals.batch_prtn_ranges[prtn_id + 1] == locals.batch_prtn_ranges[prtn_id]; + }, + /*process_prtn_fn=*/ + [&](int prtn_id) { + return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null, + temp_stack, prtn_id); + })); + + return Status::OK(); +} + +Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id, + const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack, + int prtn_id) { + ARROW_DCHECK(thread_id < dop_); + ThreadState& locals = thread_states_[thread_id]; + + int num_rows_new = + locals.batch_prtn_ranges[prtn_id + 1] - locals.batch_prtn_ranges[prtn_id]; + const uint16_t* row_ids = + locals.batch_prtn_row_ids.data() + locals.batch_prtn_ranges[prtn_id]; + PartitionState& prtn_state = prtn_states_[prtn_id]; + size_t num_rows_before = prtn_state.key_ids.size(); + // Insert new keys into hash table associated with the current partition + // and map existing keys to integer ids. + // + prtn_state.key_ids.resize(num_rows_before + num_rows_new); + SwissTableWithKeys::Input input(&key_batch, num_rows_new, row_ids, temp_stack, + &locals.temp_column_arrays, &locals.temp_group_ids); + RETURN_NOT_OK(prtn_state.keys.MapWithInserts( + &input, locals.batch_hashes.data(), prtn_state.key_ids.data() + num_rows_before)); + // Append input batch rows from current partition to an array of payload + // rows for this partition. + // + // The order of payloads is the same as the order of key ids accumulated + // in a vector (we will use the vector of key ids later on to sort + // payload on key ids before merging into the final row array). + // + if (!no_payload_) { + ARROW_DCHECK(payload_batch_maybe_null); + RETURN_NOT_OK(prtn_state.payloads.AppendBatchSelection( + pool_, *payload_batch_maybe_null, 0, + static_cast(payload_batch_maybe_null->length), num_rows_new, row_ids, + locals.temp_column_arrays)); + } + // We do not need to keep track of key ids if we reject rows with + // duplicate keys. + // + if (reject_duplicate_keys_) { + prtn_state.key_ids.clear(); + } + return Status::OK(); +} + +Status SwissTableForJoinBuild::PreparePrtnMerge() { + // There are 4 data structures that require partition merging: + // 1. array of key rows + // 2. SwissTable + // 3. array of payload rows (only when no_payload_ is false) + // 4. mapping from key id to first payload id (only when + // reject_duplicate_keys_ is false and there are duplicate keys) + // + + // 1. Array of key rows: + // + std::vector partition_keys; + partition_keys.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_keys[i] = prtn_states_[i].keys.keys(); + } + RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(target_->map_.keys(), partition_keys, + &partition_keys_first_row_id_, pool_)); + + // 2. 
SwissTable: + // + std::vector partition_tables; + partition_tables.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_tables[i] = prtn_states_[i].keys.swiss_table(); + } + std::vector partition_first_group_id; + RETURN_NOT_OK(SwissTableMerge::PrepareForMerge( + target_->map_.swiss_table(), partition_tables, &partition_first_group_id, pool_)); + + // 3. Array of payload rows: + // + if (!no_payload_) { + std::vector partition_payloads; + partition_payloads.resize(num_prtns_); + for (int i = 0; i < num_prtns_; ++i) { + partition_payloads[i] = &prtn_states_[i].payloads; + } + RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(&target_->payloads_, partition_payloads, + &partition_payloads_first_row_id_, + pool_)); + } + + // Check if we have duplicate keys + // + int64_t num_keys = partition_keys_first_row_id_[num_prtns_]; + int64_t num_rows = 0; + for (int i = 0; i < num_prtns_; ++i) { + num_rows += static_cast(prtn_states_[i].key_ids.size()); + } + bool no_duplicate_keys = reject_duplicate_keys_ || num_keys == num_rows; + + // 4. Mapping from key id to first payload id: + // + target_->no_duplicate_keys_ = no_duplicate_keys; + if (!no_duplicate_keys) { + target_->row_offset_for_key_.resize(num_keys + 1); + int64_t num_rows = 0; + for (int i = 0; i < num_prtns_; ++i) { + int64_t first_key = partition_keys_first_row_id_[i]; + target_->row_offset_for_key_[first_key] = static_cast(num_rows); + num_rows += static_cast(prtn_states_[i].key_ids.size()); + } + target_->row_offset_for_key_[num_keys] = static_cast(num_rows); + } + + return Status::OK(); +} + +void SwissTableForJoinBuild::PrtnMerge(int prtn_id) { + PartitionState& prtn_state = prtn_states_[prtn_id]; + + // There are 4 data structures that require partition merging: + // 1. array of key rows + // 2. SwissTable + // 3. mapping from key id to first payload id (only when + // reject_duplicate_keys_ is false and there are duplicate keys) + // 4. array of payload rows (only when no_payload_ is false) + // + + // 1. Array of key rows: + // + RowArrayMerge::MergeSingle(target_->map_.keys(), *prtn_state.keys.keys(), + partition_keys_first_row_id_[prtn_id], + /*source_rows_permutation=*/nullptr); + + // 2. SwissTable: + // + SwissTableMerge::MergePartition( + target_->map_.swiss_table(), prtn_state.keys.swiss_table(), prtn_id, log_num_prtns_, + static_cast(partition_keys_first_row_id_[prtn_id]), + &prtn_state.overflow_key_ids, &prtn_state.overflow_hashes); + + std::vector source_payload_ids; + + // 3. mapping from key id to first payload id + // + if (!target_->no_duplicate_keys_) { + // Count for each local (within partition) key id how many times it appears + // in input rows. + // + // For convenience, we use an array in merged hash table mapping key ids to + // first payload ids to collect the counters. + // + int64_t first_key = partition_keys_first_row_id_[prtn_id]; + int64_t num_keys = partition_keys_first_row_id_[prtn_id + 1] - first_key; + uint32_t* counters = target_->row_offset_for_key_.data() + first_key; + uint32_t first_payload = counters[0]; + for (int64_t i = 0; i < num_keys; ++i) { + counters[i] = 0; + } + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + ++counters[key_id]; + } + + if (!no_payload_) { + // Count sort payloads on key id + // + // Start by computing inclusive cummulative sum of counters. 
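What follows is a standard counting sort of payload positions by key id: histogram the key ids, turn the histogram into cumulative offsets, then scatter each input position into its slot while decrementing the corresponding counter. A compact, self-contained version of the same idea, detached from the partition/merge state (like the code below, the scatter walks inputs forward with a pre-decrement, so rows sharing a key id end up in reverse input order within their group):

#include <cstdint>
#include <vector>

// Returns, for each output slot, the index of the input row that lands there,
// grouped by key id.
std::vector<uint32_t> CountSortByKeyId(const std::vector<uint32_t>& key_ids,
                                       uint32_t num_keys) {
  std::vector<uint32_t> counters(num_keys, 0);
  for (uint32_t key_id : key_ids) ++counters[key_id];  // histogram
  uint32_t sum = 0;
  for (uint32_t& counter : counters) {                 // inclusive prefix sum
    sum += counter;
    counter = sum;
  }
  std::vector<uint32_t> out(key_ids.size());
  for (size_t i = 0; i < key_ids.size(); ++i) {        // scatter
    uint32_t pos = --counters[key_ids[i]];
    out[pos] = static_cast<uint32_t>(i);
  }
  return out;
}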
+ // + uint32_t sum = 0; + for (int64_t i = 0; i < num_keys; ++i) { + sum += counters[i]; + counters[i] = sum; + } + // Now use cummulative sum of counters to obtain the target position in + // the sorted order for each row. At the end of this process the counters + // will contain exclusive cummulative sum (instead of inclusive that is + // there at the beginning). + // + source_payload_ids.resize(prtn_state.key_ids.size()); + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + int64_t position = --counters[key_id]; + source_payload_ids[position] = static_cast(i); + } + // Add base payload id to all of the counters. + // + for (int64_t i = 0; i < num_keys; ++i) { + counters[i] += first_payload; + } + } else { + // When there is no payload to process, we just need to compute exclusive + // cummulative sum of counters and add the base payload id to all of them. + // + uint32_t sum = 0; + for (int64_t i = 0; i < num_keys; ++i) { + uint32_t sum_next = sum + counters[i]; + counters[i] = sum + first_payload; + sum = sum_next; + } + } + } + + // 4. Array of payload rows: + // + if (!no_payload_) { + // If there are duplicate keys, then we have already initialized permutation + // of payloads for this partition. + // + if (target_->no_duplicate_keys_) { + source_payload_ids.resize(prtn_state.key_ids.size()); + for (size_t i = 0; i < prtn_state.key_ids.size(); ++i) { + uint32_t key_id = prtn_state.key_ids[i]; + source_payload_ids[key_id] = static_cast(i); + } + } + // Merge partition payloads into target array using the permutation. + // + RowArrayMerge::MergeSingle(&target_->payloads_, prtn_state.payloads, + partition_payloads_first_row_id_[prtn_id], + source_payload_ids.data()); + } +} + +void SwissTableForJoinBuild::FinishPrtnMerge(util::TempVectorStack* temp_stack) { + // Process overflow key ids + // + for (int prtn_id = 0; prtn_id < num_prtns_; ++prtn_id) { + SwissTableMerge::InsertNewGroups(target_->map_.swiss_table(), + prtn_states_[prtn_id].overflow_key_ids, + prtn_states_[prtn_id].overflow_hashes); + } + + // Calculate whether we have nulls in hash table keys + // (it is lazily evaluated but since we will be accessing it from multiple + // threads we need to make sure that the value gets calculated here). + // + LightContext ctx; + ctx.hardware_flags = hardware_flags_; + ctx.stack = temp_stack; + std::ignore = target_->map_.keys()->rows_.has_any_nulls(&ctx); +} + +void JoinResultMaterialize::Init(MemoryPool* pool, + const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas) { + pool_ = pool; + probe_schemas_ = probe_schemas; + build_schemas_ = build_schemas; + num_rows_ = 0; + null_ranges_.clear(); + num_produced_batches_ = 0; + + // Initialize mapping of columns from output batch column index to key and + // payload batch column index. + // + probe_output_to_key_and_payload_.resize( + probe_schemas_->num_cols(HashJoinProjection::OUTPUT)); + int num_key_cols = probe_schemas_->num_cols(HashJoinProjection::KEY); + auto to_key = probe_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto to_payload = + probe_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; static_cast(i) < probe_output_to_key_and_payload_.size(); ++i) { + probe_output_to_key_and_payload_[i] = + to_key.get(i) == SchemaProjectionMap::kMissingField + ? 
to_payload.get(i) + num_key_cols + : to_key.get(i); + } +} + +void JoinResultMaterialize::SetBuildSide(const RowArray* build_keys, + const RowArray* build_payloads, + bool payload_id_same_as_key_id) { + build_keys_ = build_keys; + build_payloads_ = build_payloads; + payload_id_same_as_key_id_ = payload_id_same_as_key_id; +} + +bool JoinResultMaterialize::HasProbeOutput() const { + return probe_schemas_->num_cols(HashJoinProjection::OUTPUT) > 0; +} + +bool JoinResultMaterialize::HasBuildKeyOutput() const { + auto to_key = build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + for (int i = 0; i < build_schemas_->num_cols(HashJoinProjection::OUTPUT); ++i) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + return true; + } + } + return false; +} + +bool JoinResultMaterialize::HasBuildPayloadOutput() const { + auto to_payload = + build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; i < build_schemas_->num_cols(HashJoinProjection::OUTPUT); ++i) { + if (to_payload.get(i) != SchemaProjectionMap::kMissingField) { + return true; + } + } + return false; +} + +bool JoinResultMaterialize::NeedsKeyId() const { + return HasBuildKeyOutput() || (HasBuildPayloadOutput() && payload_id_same_as_key_id_); +} + +bool JoinResultMaterialize::NeedsPayloadId() const { + return HasBuildPayloadOutput() && !payload_id_same_as_key_id_; +} + +Status JoinResultMaterialize::AppendProbeOnly(const ExecBatch& key_and_payload, + int num_rows_to_append, + const uint16_t* row_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendSelected( + pool_, key_and_payload, num_rows_to_append, row_ids, + static_cast(probe_output_to_key_and_payload_.size()), + probe_output_to_key_and_payload_.data())); + } + if (!null_ranges_.empty() && + null_ranges_.back().first + null_ranges_.back().second == num_rows_) { + // We can extend the last range of null rows on build side. 
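+    // For instance (hypothetical numbers): null_ranges_ stores pairs of
+    // (first output row, run length) for output rows whose build-side columns
+    // will be filled with nulls on flush. If num_rows_ is 10 and the last
+    // recorded range is {7, 3}, appending 5 more probe-only rows extends that
+    // range to {7, 8} instead of pushing a new {10, 5} entry; key_ids_ and
+    // payload_ids_ are only consulted for rows outside these ranges.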
+ // + null_ranges_.back().second += num_rows_to_append; + } else { + null_ranges_.push_back( + std::make_pair(static_cast(num_rows_), num_rows_to_append)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Status JoinResultMaterialize::AppendBuildOnly(int num_rows_to_append, + const uint32_t* key_ids, + const uint32_t* payload_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendNulls( + pool_, probe_schemas_->data_types(HashJoinProjection::OUTPUT), + num_rows_to_append)); + } + if (NeedsKeyId()) { + ARROW_DCHECK(key_ids != nullptr); + key_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(key_ids_.data() + num_rows_, key_ids, num_rows_to_append * sizeof(uint32_t)); + } + if (NeedsPayloadId()) { + ARROW_DCHECK(payload_ids != nullptr); + payload_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(payload_ids_.data() + num_rows_, payload_ids, + num_rows_to_append * sizeof(uint32_t)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Status JoinResultMaterialize::Append(const ExecBatch& key_and_payload, + int num_rows_to_append, const uint16_t* row_ids, + const uint32_t* key_ids, const uint32_t* payload_ids, + int* num_rows_appended) { + num_rows_to_append = + std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); + if (HasProbeOutput()) { + RETURN_NOT_OK(batch_builder_.AppendSelected( + pool_, key_and_payload, num_rows_to_append, row_ids, + static_cast(probe_output_to_key_and_payload_.size()), + probe_output_to_key_and_payload_.data())); + } + if (NeedsKeyId()) { + ARROW_DCHECK(key_ids != nullptr); + key_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(key_ids_.data() + num_rows_, key_ids, num_rows_to_append * sizeof(uint32_t)); + } + if (NeedsPayloadId()) { + ARROW_DCHECK(payload_ids != nullptr); + payload_ids_.resize(num_rows_ + num_rows_to_append); + memcpy(payload_ids_.data() + num_rows_, payload_ids, + num_rows_to_append * sizeof(uint32_t)); + } + num_rows_ += num_rows_to_append; + *num_rows_appended = num_rows_to_append; + return Status::OK(); +} + +Result> JoinResultMaterialize::FlushBuildColumn( + const std::shared_ptr& data_type, const RowArray* row_array, int column_id, + uint32_t* row_ids) { + ResizableArrayData output; + output.Init(data_type, pool_, bit_util::Log2(num_rows_)); + + for (size_t i = 0; i <= null_ranges_.size(); ++i) { + int row_id_begin = + i == 0 ? 0 : null_ranges_[i - 1].first + null_ranges_[i - 1].second; + int row_id_end = i == null_ranges_.size() ? num_rows_ : null_ranges_[i].first; + if (row_id_end > row_id_begin) { + RETURN_NOT_OK(row_array->DecodeSelected( + &output, column_id, row_id_end - row_id_begin, row_ids + row_id_begin, pool_)); + } + int num_nulls = i == null_ranges_.size() ? 
0 : null_ranges_[i].second; + if (num_nulls > 0) { + RETURN_NOT_OK(ExecBatchBuilder::AppendNulls(data_type, output, num_nulls, pool_)); + } + } + + return output.array_data(); +} + +Status JoinResultMaterialize::Flush(ExecBatch* out) { + if (num_rows_ == 0) { + return Status::OK(); + } + + out->length = num_rows_; + out->values.clear(); + + int num_probe_cols = probe_schemas_->num_cols(HashJoinProjection::OUTPUT); + int num_build_cols = build_schemas_->num_cols(HashJoinProjection::OUTPUT); + out->values.resize(num_probe_cols + num_build_cols); + + if (HasProbeOutput()) { + ExecBatch probe_batch = batch_builder_.Flush(); + ARROW_DCHECK(static_cast(probe_batch.values.size()) == num_probe_cols); + for (size_t i = 0; i < probe_batch.values.size(); ++i) { + out->values[i] = std::move(probe_batch.values[i]); + } + } + auto to_key = build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); + auto to_payload = + build_schemas_->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD); + for (int i = 0; i < num_build_cols; ++i) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + std::shared_ptr column; + ARROW_ASSIGN_OR_RAISE( + column, + FlushBuildColumn(build_schemas_->data_type(HashJoinProjection::OUTPUT, i), + build_keys_, to_key.get(i), key_ids_.data())); + out->values[num_probe_cols + i] = std::move(column); + } else if (to_payload.get(i) != SchemaProjectionMap::kMissingField) { + std::shared_ptr column; + ARROW_ASSIGN_OR_RAISE( + column, + FlushBuildColumn( + build_schemas_->data_type(HashJoinProjection::OUTPUT, i), build_payloads_, + to_payload.get(i), + payload_id_same_as_key_id_ ? key_ids_.data() : payload_ids_.data())); + out->values[num_probe_cols + i] = std::move(column); + } else { + ARROW_DCHECK(false); + } + } + + num_rows_ = 0; + key_ids_.clear(); + payload_ids_.clear(); + null_ranges_.clear(); + + ++num_produced_batches_; + + return Status::OK(); +} + +void JoinNullFilter::Filter(const ExecBatch& key_batch, int batch_start_row, + int num_batch_rows, const std::vector& cmp, + bool* all_valid, bool and_with_input, + uint8_t* inout_bit_vector) { + // AND together validity vectors for columns that use equality comparison. 
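+  // Example (hypothetical data): with cmp = {EQ, IS} and a validity bitmap of
+  // 1,0,1,1 for the first key column in this window, the output bit vector
+  // becomes 1,0,1,1 (further ANDed with the incoming vector if and_with_input
+  // is set); the second column is skipped because IS treats null == null as
+  // equal. If no EQ column carries a validity buffer and and_with_input is
+  // false, the output is left untouched and *all_valid is set to true.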
+ // + bool is_output_initialized = and_with_input; + for (size_t i = 0; i < cmp.size(); ++i) { + // No null filtering if null == null is true + // + if (cmp[i] != JoinKeyCmp::EQ) { + continue; + } + + // No null filtering when there are no nulls + // + const Datum& data = key_batch.values[i]; + ARROW_DCHECK(data.is_array()); + const std::shared_ptr& array_data = data.array(); + if (!array_data->buffers[0]) { + continue; + } + + const uint8_t* non_null_buffer = array_data->buffers[0]->data(); + int64_t offset = array_data->offset + batch_start_row; + + // Filter out nulls for this column + // + if (!is_output_initialized) { + memset(inout_bit_vector, 0xff, bit_util::BytesForBits(num_batch_rows)); + is_output_initialized = true; + } + arrow::internal::BitmapAnd(inout_bit_vector, 0, non_null_buffer, offset, + num_batch_rows, 0, inout_bit_vector); + } + *all_valid = !is_output_initialized; +} + +void JoinMatchIterator::SetLookupResult(int num_batch_rows, int start_batch_row, + const uint8_t* batch_has_match, + const uint32_t* key_ids, bool no_duplicate_keys, + const uint32_t* key_to_payload) { + num_batch_rows_ = num_batch_rows; + start_batch_row_ = start_batch_row; + batch_has_match_ = batch_has_match; + key_ids_ = key_ids; + + no_duplicate_keys_ = no_duplicate_keys; + key_to_payload_ = key_to_payload; + + current_row_ = 0; + current_match_for_row_ = 0; +} + +bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows, + uint16_t* batch_row_ids, uint32_t* key_ids, + uint32_t* payload_ids) { + *out_num_rows = 0; + + if (no_duplicate_keys_) { + // When every input key can have at most one match, + // then we only need to filter according to has match bit vector. + // + // We stop when either we produce a full batch or when we reach the end of + // matches to output. + // + while (current_row_ < num_batch_rows_ && *out_num_rows < num_rows_max) { + batch_row_ids[*out_num_rows] = start_batch_row_ + current_row_; + key_ids[*out_num_rows] = payload_ids[*out_num_rows] = key_ids_[current_row_]; + (*out_num_rows) += bit_util::GetBit(batch_has_match_, current_row_) ? 1 : 0; + ++current_row_; + } + } else { + // When every input key can have zero, one or many matches, + // then we need to filter out ones with no match and + // iterate over all matches for the remaining ones. + // + // We stop when either we produce a full batch or when we reach the end of + // matches to output. 
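+    // Illustration (hypothetical values): with key_to_payload_ = {0, 2, 5},
+    // key id 1 owns payload ids 2..4. If row 0 of the window matches key 1
+    // and num_rows_max is 2, the first call emits (start_batch_row_, 1, 2)
+    // and (start_batch_row_, 1, 3) and returns; the next call resumes from
+    // current_match_for_row_ == 2, emits (start_batch_row_, 1, 4) and only
+    // then advances to the next input row.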
+ // + while (current_row_ < num_batch_rows_ && *out_num_rows < num_rows_max) { + if (!bit_util::GetBit(batch_has_match_, current_row_)) { + ++current_row_; + current_match_for_row_ = 0; + continue; + } + uint32_t base_payload_id = key_to_payload_[key_ids_[current_row_]]; + + // Total number of matches for the currently selected input row + // + int num_matches_total = + key_to_payload_[key_ids_[current_row_] + 1] - base_payload_id; + + // Number of remaining matches for the currently selected input row + // + int num_matches_left = num_matches_total - current_match_for_row_; + + // Number of matches for the currently selected input row that will fit + // into the next batch + // + int num_matches_next = std::min(num_matches_left, num_rows_max - *out_num_rows); + + for (int imatch = 0; imatch < num_matches_next; ++imatch) { + batch_row_ids[*out_num_rows] = start_batch_row_ + current_row_; + key_ids[*out_num_rows] = key_ids_[current_row_]; + payload_ids[*out_num_rows] = base_payload_id + current_match_for_row_ + imatch; + ++(*out_num_rows); + } + current_match_for_row_ += num_matches_next; + + if (current_match_for_row_ == num_matches_total) { + ++current_row_; + current_match_for_row_ = 0; + } + } + } + + return (*out_num_rows) > 0; +} + +void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type, + SwissTableForJoin* hash_table, + std::vector materialize, + const std::vector* cmp, + OutputBatchFn output_batch_fn) { + num_key_columns_ = num_key_columns; + join_type_ = join_type; + hash_table_ = hash_table; + materialize_.resize(materialize.size()); + for (size_t i = 0; i < materialize.size(); ++i) { + materialize_[i] = materialize[i]; + } + cmp_ = cmp; + output_batch_fn_ = output_batch_fn; +} + +Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, + const ExecBatch& keypayload_batch, + util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays) { + const SwissTable* swiss_table = hash_table_->keys()->swiss_table(); + int64_t hardware_flags = swiss_table->hardware_flags(); + int minibatch_size = swiss_table->minibatch_size(); + int num_rows = static_cast(keypayload_batch.length); + + ExecBatch key_batch({}, keypayload_batch.length); + key_batch.values.resize(num_key_columns_); + for (int i = 0; i < num_key_columns_; ++i) { + key_batch.values[i] = keypayload_batch.values[i]; + } + + // Break into mini-batches + // + // Start by allocating mini-batch buffers + // + auto hashes_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto match_bitvector_buf = util::TempVectorHolder( + temp_stack, static_cast(bit_util::BytesForBits(minibatch_size))); + auto key_ids_buf = util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_batch_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_key_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + auto materialize_payload_ids_buf = + util::TempVectorHolder(temp_stack, minibatch_size); + + for (int minibatch_start = 0; minibatch_start < num_rows;) { + uint32_t minibatch_size_next = std::min(minibatch_size, num_rows - minibatch_start); + + SwissTableWithKeys::Input input(&key_batch, minibatch_start, + minibatch_start + minibatch_size_next, temp_stack, + temp_column_arrays); + hash_table_->keys()->Hash(&input, hashes_buf.mutable_data(), hardware_flags); + hash_table_->keys()->MapReadOnly(&input, hashes_buf.mutable_data(), + match_bitvector_buf.mutable_data(), + key_ids_buf.mutable_data()); + + // AND bit vector with null key filter for join + // + bool ignored; + 
JoinNullFilter::Filter(key_batch, minibatch_start, minibatch_size_next, *cmp_, + &ignored, + /*and_with_input=*/true, match_bitvector_buf.mutable_data()); + // Semi-joins + // + if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI || + join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { + int num_passing_ids = 0; + util::bit_util::bits_to_indexes( + (join_type_ == JoinType::LEFT_ANTI) ? 0 : 1, hardware_flags, + minibatch_size_next, match_bitvector_buf.mutable_data(), &num_passing_ids, + materialize_batch_ids_buf.mutable_data()); + + // For right-semi, right-anti joins: update has-match flags for the rows + // in hash table. + // + if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { + for (int i = 0; i < num_passing_ids; ++i) { + uint16_t id = materialize_batch_ids_buf.mutable_data()[i]; + key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id]; + } + hash_table_->UpdateHasMatchForKeys(thread_id, num_passing_ids, + key_ids_buf.mutable_data()); + } else { + // For left-semi, left-anti joins: call materialize using match + // bit-vector. + // + + // Add base batch row index. + // + for (int i = 0; i < num_passing_ids; ++i) { + materialize_batch_ids_buf.mutable_data()[i] += + static_cast(minibatch_start); + } + + RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( + keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + } else { + // We need to output matching pairs of rows from both sides of the join. + // Since every hash table lookup for an input row might have multiple + // matches we use a helper class that implements enumerating all of them. + // + bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr); + bool no_payload_columns = (hash_table_->payloads() == nullptr); + JoinMatchIterator match_iterator; + match_iterator.SetLookupResult( + minibatch_size_next, minibatch_start, match_bitvector_buf.mutable_data(), + key_ids_buf.mutable_data(), no_duplicate_keys, hash_table_->key_to_payload()); + int num_matches_next; + while (match_iterator.GetNextBatch(minibatch_size, &num_matches_next, + materialize_batch_ids_buf.mutable_data(), + materialize_key_ids_buf.mutable_data(), + materialize_payload_ids_buf.mutable_data())) { + const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data(); + const uint32_t* materialize_key_ids = materialize_key_ids_buf.mutable_data(); + const uint32_t* materialize_payload_ids = + no_duplicate_keys || no_payload_columns + ? materialize_key_ids_buf.mutable_data() + : materialize_payload_ids_buf.mutable_data(); + + // For right-outer, full-outer joins we need to update has-match flags + // for the rows in hash table. + // + if (join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER) { + hash_table_->UpdateHasMatchForKeys(thread_id, num_matches_next, + materialize_key_ids); + } + + // Call materialize for resulting id tuples pointing to matching pairs + // of rows. + // + RETURN_NOT_OK(materialize_[thread_id]->Append( + keypayload_batch, num_matches_next, materialize_batch_ids, + materialize_key_ids, materialize_payload_ids, + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + + // For left-outer and full-outer joins output non-matches. + // + // Call materialize. Nulls will be output in all columns that come from + // the other side of the join. 
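+      // Example (hypothetical minibatch): if the match bit vector of a 3-row
+      // minibatch is 1,0,1, bits_to_indexes with bit_to_search == 0 collects
+      // {1}; after adding minibatch_start, that row goes through
+      // AppendProbeOnly, so its build-side output columns are recorded as a
+      // null range and emitted as nulls when the materializer flushes.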
+ // + if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) { + int num_passing_ids = 0; + util::bit_util::bits_to_indexes( + /*bit_to_search=*/0, hardware_flags, minibatch_size_next, + match_bitvector_buf.mutable_data(), &num_passing_ids, + materialize_batch_ids_buf.mutable_data()); + + // Add base batch row index. + // + for (int i = 0; i < num_passing_ids; ++i) { + materialize_batch_ids_buf.mutable_data()[i] += + static_cast(minibatch_start); + } + + RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( + keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), + [&](ExecBatch batch) { output_batch_fn_(thread_id, std::move(batch)); })); + } + } + + minibatch_start += minibatch_size_next; + } + + return Status::OK(); +} + +Status JoinProbeProcessor::OnFinished() { + // Flush all instances of materialize that have non-zero accumulated output + // rows. + // + for (size_t i = 0; i < materialize_.size(); ++i) { + JoinResultMaterialize& materialize = *materialize_[i]; + RETURN_NOT_OK(materialize.Flush( + [&](ExecBatch batch) { output_batch_fn_(i, std::move(batch)); })); + } + + return Status::OK(); +} + +class SwissJoin : public HashJoinImpl { + public: + Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector key_cmp, Expression filter, + OutputBatchCallback output_batch_callback, + FinishedCallback finished_callback, TaskScheduler* scheduler) override { + START_COMPUTE_SPAN(span_, "SwissJoinImpl", + {{"detail", filter.ToString()}, + {"join.kind", arrow::compute::ToString(join_type)}, + {"join.threads", static_cast(num_threads)}}); + + num_threads_ = static_cast(num_threads); + ctx_ = ctx; + hardware_flags_ = ctx->cpu_info()->hardware_flags(); + pool_ = ctx->memory_pool(); + + join_type_ = join_type; + key_cmp_.resize(key_cmp.size()); + for (size_t i = 0; i < key_cmp.size(); ++i) { + key_cmp_[i] = key_cmp[i]; + } + schema_[0] = proj_map_left; + schema_[1] = proj_map_right; + output_batch_callback_ = output_batch_callback; + finished_callback_ = finished_callback; + scheduler_ = scheduler; + hash_table_ready_.store(false); + cancelled_.store(false); + { + std::lock_guard lock(state_mutex_); + left_side_finished_ = false; + right_side_finished_ = false; + error_status_ = Status::OK(); + } + + local_states_.resize(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + local_states_[i].hash_table_ready = false; + local_states_[i].num_output_batches = 0; + RETURN_NOT_OK(CancelIfNotOK(local_states_[i].temp_stack.Init( + pool_, 1024 + 64 * util::MiniBatch::kMiniBatchLength))); + local_states_[i].materialize.Init(pool_, proj_map_left, proj_map_right); + } + + std::vector materialize; + materialize.resize(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + materialize[i] = &local_states_[i].materialize; + } + + probe_processor_.Init(proj_map_left->num_cols(HashJoinProjection::KEY), join_type_, + &hash_table_, materialize, &key_cmp_, output_batch_callback_); + + InitTaskGroups(); + + return Status::OK(); + } + + void InitTaskGroups() { + task_group_build_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return BuildTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return BuildFinished(thread_index); }); + task_group_merge_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return MergeTask(thread_index, 
task_id); + }, + [this](size_t thread_index) -> Status { return MergeFinished(thread_index); }); + task_group_scan_ = scheduler_->RegisterTaskGroup( + [this](size_t thread_index, int64_t task_id) -> Status { + return ScanTask(thread_index, task_id); + }, + [this](size_t thread_index) -> Status { return ScanFinished(thread_index); }); + } + + Status ProbeSingleBatch(size_t thread_index, ExecBatch batch) override { + if (IsCancelled()) { + return status(); + } + + if (!local_states_[thread_index].hash_table_ready) { + local_states_[thread_index].hash_table_ready = hash_table_ready_.load(); + } + ARROW_DCHECK(local_states_[thread_index].hash_table_ready); + + ExecBatch keypayload_batch; + ARROW_ASSIGN_OR_RAISE(keypayload_batch, KeyPayloadFromInput(/*side=*/0, &batch)); + + return CancelIfNotOK(probe_processor_.OnNextBatch( + thread_index, keypayload_batch, &local_states_[thread_index].temp_stack, + &local_states_[thread_index].temp_column_arrays)); + } + + Status ProbingFinished(size_t thread_index) override { + if (IsCancelled()) { + return status(); + } + + return CancelIfNotOK(StartScanHashTable(static_cast(thread_index))); + } + + Status BuildHashTable(size_t thread_id, AccumulationQueue batches, + BuildFinishedCallback on_finished) override { + if (IsCancelled()) { + return status(); + } + + build_side_batches_ = std::move(batches); + build_finished_callback_ = on_finished; + + return CancelIfNotOK(StartBuildHashTable(static_cast(thread_id))); + } + + void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) override { + EVENT(span_, "Abort"); + END_SPAN(span_); + std::ignore = CancelIfNotOK(Status::Cancelled("Hash Join Cancelled")); + scheduler_->Abort(std::move(pos_abort_callback)); + } + + std::string ToString() const override { return "SwissJoin"; } + + private: + Status StartBuildHashTable(int64_t thread_id) { + // Initialize build class instance + // + const HashJoinProjectionMaps* schema = schema_[1]; + bool reject_duplicate_keys = + join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI; + bool no_payload = + reject_duplicate_keys || schema->num_cols(HashJoinProjection::PAYLOAD) == 0; + + std::vector key_types; + for (int i = 0; i < schema->num_cols(HashJoinProjection::KEY); ++i) { + ARROW_ASSIGN_OR_RAISE( + KeyColumnMetadata metadata, + ColumnMetadataFromDataType(schema->data_type(HashJoinProjection::KEY, i))); + key_types.push_back(metadata); + } + std::vector payload_types; + for (int i = 0; i < schema->num_cols(HashJoinProjection::PAYLOAD); ++i) { + ARROW_ASSIGN_OR_RAISE( + KeyColumnMetadata metadata, + ColumnMetadataFromDataType(schema->data_type(HashJoinProjection::PAYLOAD, i))); + payload_types.push_back(metadata); + } + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.Init( + &hash_table_, num_threads_, build_side_batches_.row_count(), + reject_duplicate_keys, no_payload, key_types, payload_types, pool_, + hardware_flags_))); + + // Process all input batches + // + return CancelIfNotOK(scheduler_->StartTaskGroup(static_cast(thread_id), + task_group_build_, + build_side_batches_.batch_count())); + } + + Status BuildTask(size_t thread_id, int64_t batch_id) { + if (IsCancelled()) { + return Status::OK(); + } + + const HashJoinProjectionMaps* schema = schema_[1]; + bool no_payload = hash_table_build_.no_payload(); + + ExecBatch input_batch; + ARROW_ASSIGN_OR_RAISE( + input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id])); + + if (input_batch.length == 0) { + return Status::OK(); + } + + // Split batch into key batch and optional 
payload batch + // + // Input batch is key-payload batch (key columns followed by payload + // columns). We split it into two separate batches. + // + // TODO: Change SwissTableForJoinBuild interface to use key-payload + // batch instead to avoid this operation, which involves increasing + // shared pointer ref counts. + // + ExecBatch key_batch({}, input_batch.length); + key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY)); + for (size_t icol = 0; icol < key_batch.values.size(); ++icol) { + key_batch.values[icol] = input_batch.values[icol]; + } + ExecBatch payload_batch({}, input_batch.length); + + if (!no_payload) { + payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD)); + for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) { + payload_batch.values[icol] = + input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol]; + } + } + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PushNextBatch( + static_cast(thread_id), key_batch, no_payload ? nullptr : &payload_batch, + &local_states_[thread_id].temp_stack))); + + // Release input batch + // + input_batch.values.clear(); + + return Status::OK(); + } + + Status BuildFinished(size_t thread_id) { + RETURN_NOT_OK(status()); + + build_side_batches_.Clear(); + + // On a single thread prepare for merging partitions of the resulting hash + // table. + // + RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PreparePrtnMerge())); + return CancelIfNotOK(scheduler_->StartTaskGroup(thread_id, task_group_merge_, + hash_table_build_.num_prtns())); + } + + Status MergeTask(size_t /*thread_id*/, int64_t prtn_id) { + if (IsCancelled()) { + return Status::OK(); + } + hash_table_build_.PrtnMerge(static_cast(prtn_id)); + return Status::OK(); + } + + Status MergeFinished(size_t thread_id) { + RETURN_NOT_OK(status()); + hash_table_build_.FinishPrtnMerge(&local_states_[thread_id].temp_stack); + return CancelIfNotOK(OnBuildHashTableFinished(static_cast(thread_id))); + } + + Status OnBuildHashTableFinished(int64_t thread_id) { + if (IsCancelled()) { + return status(); + } + + for (int i = 0; i < num_threads_; ++i) { + local_states_[i].materialize.SetBuildSide(hash_table_.keys()->keys(), + hash_table_.payloads(), + hash_table_.key_to_payload() == nullptr); + } + hash_table_ready_.store(true); + + return build_finished_callback_(thread_id); + } + + Status StartScanHashTable(int64_t thread_id) { + if (IsCancelled()) { + return status(); + } + + bool need_to_scan = + (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI || + join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER); + + if (need_to_scan) { + hash_table_.MergeHasMatch(); + int64_t num_tasks = bit_util::CeilDiv(hash_table_.num_rows(), kNumRowsPerScanTask); + + return CancelIfNotOK(scheduler_->StartTaskGroup(static_cast(thread_id), + task_group_scan_, num_tasks)); + } else { + return CancelIfNotOK(OnScanHashTableFinished()); + } + } + + Status ScanTask(size_t thread_id, int64_t task_id) { + if (IsCancelled()) { + return Status::OK(); + } + + // Should we output matches or non-matches? 
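+    // RIGHT_SEMI scans for build rows whose has-match bit is set, while
+    // RIGHT_ANTI, RIGHT_OUTER and FULL_OUTER scan for rows whose bit is clear
+    // (for the outer joins the matching build rows were already emitted while
+    // probing). Rows selected here are appended with AppendBuildOnly, so
+    // their probe-side output columns come out as nulls.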
+ // + bool bit_to_output = (join_type_ == JoinType::RIGHT_SEMI); + + int64_t start_row = task_id * kNumRowsPerScanTask; + int64_t end_row = + std::min((task_id + 1) * kNumRowsPerScanTask, hash_table_.num_rows()); + // Get thread index and related temp vector stack + // + util::TempVectorStack* temp_stack = &local_states_[thread_id].temp_stack; + + // Split into mini-batches + // + auto payload_ids_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + auto key_ids_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + auto selection_buf = + util::TempVectorHolder(temp_stack, util::MiniBatch::kMiniBatchLength); + for (int64_t mini_batch_start = start_row; mini_batch_start < end_row;) { + // Compute the size of the next mini-batch + // + int64_t mini_batch_size_next = + std::min(end_row - mini_batch_start, + static_cast(util::MiniBatch::kMiniBatchLength)); + + // Get the list of key and payload ids from this mini-batch to output. + // + uint32_t first_key_id = + hash_table_.payload_id_to_key_id(static_cast(mini_batch_start)); + uint32_t last_key_id = hash_table_.payload_id_to_key_id( + static_cast(mini_batch_start + mini_batch_size_next - 1)); + int num_output_rows = 0; + for (uint32_t key_id = first_key_id; key_id <= last_key_id; ++key_id) { + if (bit_util::GetBit(hash_table_.has_match(), key_id) == bit_to_output) { + uint32_t first_payload_for_key = + std::max(static_cast(mini_batch_start), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id] + : key_id); + uint32_t last_payload_for_key = std::min( + static_cast(mini_batch_start + mini_batch_size_next - 1), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id + 1] - 1 + : key_id); + uint32_t num_payloads_for_key = + last_payload_for_key - first_payload_for_key + 1; + for (uint32_t i = 0; i < num_payloads_for_key; ++i) { + key_ids_buf.mutable_data()[num_output_rows + i] = key_id; + payload_ids_buf.mutable_data()[num_output_rows + i] = + first_payload_for_key + i; + } + num_output_rows += num_payloads_for_key; + } + } + + if (num_output_rows > 0) { + // Materialize (and output whenever buffers get full) hash table + // values according to the generated list of ids. + // + Status status = local_states_[thread_id].materialize.AppendBuildOnly( + num_output_rows, key_ids_buf.mutable_data(), payload_ids_buf.mutable_data(), + [&](ExecBatch batch) { + output_batch_callback_(static_cast(thread_id), std::move(batch)); + }); + RETURN_NOT_OK(CancelIfNotOK(status)); + if (!status.ok()) { + break; + } + } + mini_batch_start += mini_batch_size_next; + } + + return Status::OK(); + } + + Status ScanFinished(size_t thread_id) { + if (IsCancelled()) { + return status(); + } + + return CancelIfNotOK(OnScanHashTableFinished()); + } + + Status OnScanHashTableFinished() { + if (IsCancelled()) { + return status(); + } + END_SPAN(span_); + + // Flush all instances of materialize that have non-zero accumulated output + // rows. 
+ // + RETURN_NOT_OK(CancelIfNotOK(probe_processor_.OnFinished())); + + int64_t num_produced_batches = 0; + for (size_t i = 0; i < local_states_.size(); ++i) { + JoinResultMaterialize& materialize = local_states_[i].materialize; + num_produced_batches += materialize.num_produced_batches(); + } + + finished_callback_(num_produced_batches); + + return Status::OK(); + } + + Result KeyPayloadFromInput(int side, ExecBatch* input) { + ExecBatch projected({}, input->length); + int num_key_cols = schema_[side]->num_cols(HashJoinProjection::KEY); + int num_payload_cols = schema_[side]->num_cols(HashJoinProjection::PAYLOAD); + projected.values.resize(num_key_cols + num_payload_cols); + + auto key_to_input = + schema_[side]->map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_key_cols; ++icol) { + const Datum& value_in = input->values[key_to_input.get(icol)]; + if (value_in.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + projected.values[icol], + MakeArrayFromScalar(*value_in.scalar(), projected.length, pool_)); + } else { + projected.values[icol] = value_in; + } + } + auto payload_to_input = + schema_[side]->map(HashJoinProjection::PAYLOAD, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_payload_cols; ++icol) { + const Datum& value_in = input->values[payload_to_input.get(icol)]; + if (value_in.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + projected.values[num_key_cols + icol], + MakeArrayFromScalar(*value_in.scalar(), projected.length, pool_)); + } else { + projected.values[num_key_cols + icol] = value_in; + } + } + + return projected; + } + + bool IsCancelled() { return cancelled_.load(); } + + Status status() { + if (IsCancelled()) { + std::lock_guard lock(state_mutex_); + return error_status_; + } + return Status::OK(); + } + + Status CancelIfNotOK(Status status) { + if (!status.ok()) { + { + std::lock_guard lock(state_mutex_); + // Only update the status for the first error encountered. + // + if (error_status_.ok()) { + error_status_ = status; + } + } + cancelled_.store(true); + } + return status; + } + + static constexpr int kNumRowsPerScanTask = 512 * 1024; + + ExecContext* ctx_; + int64_t hardware_flags_; + MemoryPool* pool_; + int num_threads_; + JoinType join_type_; + std::vector key_cmp_; + const HashJoinProjectionMaps* schema_[2]; + + // Task scheduling + TaskScheduler* scheduler_; + int task_group_build_; + int task_group_merge_; + int task_group_scan_; + + // Callbacks + OutputBatchCallback output_batch_callback_; + BuildFinishedCallback build_finished_callback_; + FinishedCallback finished_callback_; + + struct ThreadLocalState { + JoinResultMaterialize materialize; + util::TempVectorStack temp_stack; + std::vector temp_column_arrays; + int64_t num_output_batches; + bool hash_table_ready; + }; + std::vector local_states_; + + SwissTableForJoin hash_table_; + JoinProbeProcessor probe_processor_; + SwissTableForJoinBuild hash_table_build_; + AccumulationQueue build_side_batches_; + + // Atomic state flags. + // These flags are kept outside of mutex, since they can be queried for every + // batch. + // + // The other flags that follow them, protected by mutex, will be queried or + // updated only a fixed number of times during entire join processing. + // + std::atomic hash_table_ready_; + std::atomic cancelled_; + + // Mutex protecting state flags. + // + std::mutex state_mutex_; + + // Mutex protected state flags. 
+ // + bool left_side_finished_; + bool right_side_finished_; + Status error_status_; +}; + +Result> HashJoinImpl::MakeSwiss() { + std::unique_ptr impl{new SwissJoin()}; + return std::move(impl); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join.h b/cpp/src/arrow/compute/exec/swiss_join.h new file mode 100644 index 00000000000..bf3273c4e04 --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join.h @@ -0,0 +1,761 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include "arrow/compute/exec/key_map.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/partition_util.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/task_util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/compute/light_array.h" +#include "arrow/compute/row/encode_internal.h" + +namespace arrow { +namespace compute { + +class RowArrayAccessor { + public: + // Find the index of this varbinary column within the sequence of all + // varbinary columns encoded in rows. + // + static int VarbinaryColumnId(const RowTableMetadata& row_metadata, int column_id); + + // Calculate how many rows to skip from the tail of the + // sequence of selected rows, such that the total size of skipped rows is at + // least equal to the size specified by the caller. Skipping of the tail rows + // is used to allow for faster processing by the caller of remaining rows + // without checking buffer bounds (useful with SIMD or fixed size memory loads + // and stores). + // + static int NumRowsToSkip(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, int num_tail_bytes_to_skip); + + // The supplied lambda will be called for each row in the given list of rows. + // The arguments given to it will be: + // - index of a row (within the set of selected rows), + // - pointer to the value, + // - byte length of the value. + // + // The information about nulls (validity bitmap) is not used in this call and + // has to be processed separately. + // + template + static void Visit(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn); + + // The supplied lambda will be called for each row in the given list of rows. + // The arguments given to it will be: + // - index of a row (within the set of selected rows), + // - byte 0xFF if the null is set for the row or 0x00 otherwise. + // + template + static void VisitNulls(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn); + + private: +#if defined(ARROW_HAVE_AVX2) + // This is equivalent to Visit method, but processing 8 rows at a time in a + // loop. 
+ // Returns the number of processed rows, which may be less than requested (up + // to 7 rows at the end may be skipped). + // + template + static int Visit_avx2(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, PROCESS_8_VALUES_FN process_8_values_fn); + + // This is equivalent to VisitNulls method, but processing 8 rows at a time in + // a loop. Returns the number of processed rows, which may be less than + // requested (up to 7 rows at the end may be skipped). + // + template + static int VisitNulls_avx2(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn); +#endif +}; + +// Write operations (appending batch rows) must not be called by more than one +// thread at the same time. +// +// Read operations (row comparison, column decoding) +// can be called by multiple threads concurrently. +// +struct RowArray { + RowArray() : is_initialized_(false) {} + + Status InitIfNeeded(MemoryPool* pool, const ExecBatch& batch); + Status InitIfNeeded(MemoryPool* pool, const RowTableMetadata& row_metadata); + + Status AppendBatchSelection(MemoryPool* pool, const ExecBatch& batch, int begin_row_id, + int end_row_id, int num_row_ids, const uint16_t* row_ids, + std::vector& temp_column_arrays); + + // This can only be called for a minibatch. + // + void Compare(const ExecBatch& batch, int begin_row_id, int end_row_id, int num_selected, + const uint16_t* batch_selection_maybe_null, const uint32_t* array_row_ids, + uint32_t* out_num_not_equal, uint16_t* out_not_equal_selection, + int64_t hardware_flags, util::TempVectorStack* temp_stack, + std::vector& temp_column_arrays, + uint8_t* out_match_bitvector_maybe_null = NULLPTR); + + // TODO: add AVX2 version + // + Status DecodeSelected(ResizableArrayData* target, int column_id, int num_rows_to_append, + const uint32_t* row_ids, MemoryPool* pool) const; + + void DebugPrintToFile(const char* filename, bool print_sorted) const; + + int64_t num_rows() const { return is_initialized_ ? rows_.length() : 0; } + + bool is_initialized_; + RowTableEncoder encoder_; + RowTableImpl rows_; + RowTableImpl rows_temp_; +}; + +// Implements concatenating multiple row arrays into a single one, using +// potentially multiple threads, each processing a single input row array. +// +class RowArrayMerge { + public: + // Calculate total number of rows and size in bytes for merged sequence of + // rows and allocate memory for it. + // + // If the rows are of varying length, initialize in the offset array the first + // entry for the write area for each input row array. Leave all other + // offsets and buffers uninitialized. + // + // All input sources must be initialized, but they can contain zero rows. + // + // Output in vector the first target row id for each source (exclusive + // cummulative sum of number of rows in sources). This output is optional, + // caller can pass in nullptr to indicate that it is not needed. + // + static Status PrepareForMerge(RowArray* target, const std::vector& sources, + std::vector* first_target_row_id, + MemoryPool* pool); + + // Copy rows from source array to target array. + // Both arrays must have the same row metadata. + // Target array must already have the memory reserved in all internal buffers + // for the copy of the rows. + // + // Copy of the rows will occupy the same amount of space in the target array + // buffers as in the source array, but in the target array we pick at what row + // position and offset we start writing. 
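+  // For example (hypothetical sizes): merging three sources holding 3, 5 and
+  // 2 rows produces a 10-row target, and the optional first-row-id vector is
+  // the exclusive cumulative sum {0, 3, 8} (with the grand total, 10, stored
+  // one past the last source). MergeSingle for the second source then writes
+  // its 5 rows starting at target row id 3.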
+ // + // Optionally, the rows may be reordered during copy according to the + // provided permutation, which represents some sorting order of source rows. + // Nth element of the permutation array is the source row index for the Nth + // row written into target array. If permutation is missing (null), then the + // order of source rows will remain unchanged. + // + // In case of varying length rows, we purposefully skip outputting of N+1 (one + // after last) offset, to allow concurrent copies of rows done to adjacent + // ranges in the target array. This offset should already contain the right + // value after calling the method preparing target array for merge (which + // initializes boundary offsets for target row ranges for each source). + // + static void MergeSingle(RowArray* target, const RowArray& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); + + private: + // Copy rows from source array to a region of the target array. + // This implementation is for fixed length rows. + // Null information needs to be handled separately. + // + static void CopyFixedLength(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); + + // Copy rows from source array to a region of the target array. + // This implementation is for varying length rows. + // Null information needs to be handled separately. + // + static void CopyVaryingLength(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + int64_t first_target_row_offset, + const int64_t* source_rows_permutation); + + // Copy null information from rows from source array to a region of the target + // array. + // + static void CopyNulls(RowTableImpl* target, const RowTableImpl& source, + int64_t first_target_row_id, + const int64_t* source_rows_permutation); +}; + +// Implements merging of multiple SwissTables into a single one, using +// potentially multiple threads, each processing a single input source. +// +// Each source should correspond to a range of original hashes. +// A row belongs to a source with index determined by K highest bits of +// original hash. That means that the number of sources must be a power of 2. +// +// We assume that the hash values used and stored inside source tables +// have K highest bits removed from the original hash in order to avoid huge +// number of hash collisions that would occur otherwise. +// These bits will be reinserted back (original hashes will be used) when +// merging into target. +// +class SwissTableMerge { + public: + // Calculate total number of blocks for merged table. + // Allocate buffers sized accordingly and initialize empty target table. + // + // All input sources must be initialized, but they can be empty. + // + // Output in a vector the first target group id for each source (exclusive + // cummulative sum of number of groups in sources). This output is optional, + // caller can pass in nullptr to indicate that it is not needed. + // + static Status PrepareForMerge(SwissTable* target, + const std::vector& sources, + std::vector* first_target_group_id, + MemoryPool* pool); + + // Copy all entries from source to a range of blocks (partition) of target. + // + // During copy, adjust group ids from source by adding provided base id. + // + // Skip entries from source that would cross partition boundaries (range of + // blocks) when inserted into target. Save their data in output vector for + // processing later. 
We postpone inserting these overflow entries in order to + // allow concurrent processing of all partitions. Overflow entries will be + // handled by a single-thread afterwards. + // + static void MergePartition(SwissTable* target, const SwissTable* source, + uint32_t partition_id, int num_partition_bits, + uint32_t base_group_id, + std::vector* overflow_group_ids, + std::vector* overflow_hashes); + + // Single-threaded processing of remaining groups, that could not be + // inserted in partition merge phase + // (due to entries from one partition spilling over due to full blocks into + // the next partition). + // + static void InsertNewGroups(SwissTable* target, const std::vector& group_ids, + const std::vector& hashes); + + private: + // Insert a new group id. + // + // Assumes that there are enough slots in the target + // and there is no need to resize it. + // + // Max block id can be provided, in which case the search for an empty slot to + // insert new entry to will stop after visiting that block. + // + // Max block id value greater or equal to the number of blocks guarantees that + // the search will not be stopped. + // + static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash, + int64_t max_block_id); +}; + +struct SwissTableWithKeys { + struct Input { + Input(const ExecBatch* in_batch, int in_batch_start_row, int in_batch_end_row, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays); + + Input(const ExecBatch* in_batch, util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays); + + Input(const ExecBatch* in_batch, int in_num_selected, const uint16_t* in_selection, + util::TempVectorStack* in_temp_stack, + std::vector* in_temp_column_arrays, + std::vector* in_temp_group_ids); + + Input(const Input& base, int num_rows_to_skip, int num_rows_to_include); + + const ExecBatch* batch; + // Window of the batch to operate on. + // The window information is only used if row selection is null. + // + int batch_start_row; + int batch_end_row; + // Optional selection. + // Used instead of window of the batch if not null. + // + int num_selected; + const uint16_t* selection_maybe_null; + // Thread specific scratch buffers for storing temporary data. + // + util::TempVectorStack* temp_stack; + std::vector* temp_column_arrays; + std::vector* temp_group_ids; + }; + + Status Init(int64_t hardware_flags, MemoryPool* pool); + + void InitCallbacks(); + + static void Hash(Input* input, uint32_t* hashes, int64_t hardware_flags); + + // If input uses selection, then hashes array must have one element for every + // row in the whole (unfiltered and not spliced) input exec batch. Otherwise, + // there must be one element in hashes array for every value in the window of + // the exec batch specified by input. + // + // Output arrays will contain one element for every selected batch row in + // input (selected either by selection vector if provided or input window + // otherwise). 
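+  // Rough sketch of how the join code in this change drives these calls:
+  // the build side computes hashes and inserts,
+  //   Hash(&input, hashes, hardware_flags);
+  //   MapWithInserts(&input, hashes, key_ids);
+  // while the probe side only looks keys up,
+  //   Hash(&input, hashes, hardware_flags);
+  //   MapReadOnly(&input, hashes, match_bitvector, key_ids);
+  // and reads the match bit vector plus key ids for the selected rows.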
+ // + void MapReadOnly(Input* input, const uint32_t* hashes, uint8_t* match_bitvector, + uint32_t* key_ids); + Status MapWithInserts(Input* input, const uint32_t* hashes, uint32_t* key_ids); + + SwissTable* swiss_table() { return &swiss_table_; } + const SwissTable* swiss_table() const { return &swiss_table_; } + RowArray* keys() { return &keys_; } + const RowArray* keys() const { return &keys_; } + + private: + void EqualCallback(int num_keys, const uint16_t* selection_maybe_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx); + Status AppendCallback(int num_keys, const uint16_t* selection, void* callback_ctx); + Status Map(Input* input, bool insert_missing, const uint32_t* hashes, + uint8_t* match_bitvector_maybe_null, uint32_t* key_ids); + + SwissTable::EqualImpl equal_impl_; + SwissTable::AppendImpl append_impl_; + + SwissTable swiss_table_; + RowArray keys_; +}; + +// Enhances SwissTableWithKeys with the following structures used by hash join: +// - storage of payloads (that unlike keys do not have to be unique) +// - mapping from a key to all inserted payloads corresponding to it (we can +// store multiple rows corresponding to a single key) +// - bit-vectors for keeping track of whether each payload had a match during +// evaluation of join. +// +class SwissTableForJoin { + friend class SwissTableForJoinBuild; + + public: + void UpdateHasMatchForKeys(int64_t thread_id, int num_rows, const uint32_t* key_ids); + void MergeHasMatch(); + + const SwissTableWithKeys* keys() const { return &map_; } + SwissTableWithKeys* keys() { return &map_; } + const RowArray* payloads() const { return no_payload_columns_ ? NULLPTR : &payloads_; } + const uint32_t* key_to_payload() const { + return no_duplicate_keys_ ? NULLPTR : row_offset_for_key_.data(); + } + const uint8_t* has_match() const { + return has_match_.empty() ? NULLPTR : has_match_.data(); + } + int64_t num_keys() const { return map_.keys()->num_rows(); } + int64_t num_rows() const { + return no_duplicate_keys_ ? num_keys() : row_offset_for_key_[num_keys()]; + } + + uint32_t payload_id_to_key_id(uint32_t payload_id) const; + // Input payload ids must form an increasing sequence. + // + void payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, + uint32_t* key_ids) const; + + private: + uint8_t* local_has_match(int64_t thread_id); + + // Degree of parallelism (number of threads) + int dop_; + + struct ThreadLocalState { + std::vector has_match; + }; + std::vector local_states_; + std::vector has_match_; + + SwissTableWithKeys map_; + + bool no_duplicate_keys_; + // Not used if no_duplicate_keys_ is true. + std::vector row_offset_for_key_; + + bool no_payload_columns_; + // Not used if no_payload_columns_ is true. + RowArray payloads_; +}; + +// Implements parallel build process for hash table for join from a sequence of +// exec batches with input rows. +// +class SwissTableForJoinBuild { + public: + Status Init(SwissTableForJoin* target, int dop, int64_t num_rows, + bool reject_duplicate_keys, bool no_payload, + const std::vector& key_types, + const std::vector& payload_types, MemoryPool* pool, + int64_t hardware_flags); + + // In the first phase of parallel hash table build, threads pick unprocessed + // exec batches, partition the rows based on hash, and update all of the + // partitions with information related to that batch of rows. 
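+  // Sketch of the overall build protocol, as driven by the hash join
+  // implementation in this change:
+  //   1. every thread:  PushNextBatch() for each input batch it picks up,
+  //   2. single thread: PreparePrtnMerge(),
+  //   3. every thread:  PrtnMerge(prtn_id), one task per partition,
+  //   4. single thread: FinishPrtnMerge().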
+ // + Status PushNextBatch(int64_t thread_id, const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack); + + // Allocate memory and initialize counters required for parallel merging of + // hash table partitions. + // Single-threaded. + // + Status PreparePrtnMerge(); + + // Second phase of parallel hash table build. + // Each partition can be processed by a different thread. + // Parallel step. + // + void PrtnMerge(int prtn_id); + + // Single-threaded processing of the rows that have been skipped during + // parallel merging phase, due to hash table search resulting in crossing + // partition boundaries. + // + void FinishPrtnMerge(util::TempVectorStack* temp_stack); + + // The number of partitions is the number of parallel tasks to execute during + // the final phase of hash table build process. + // + int num_prtns() const { return num_prtns_; } + + bool no_payload() const { return no_payload_; } + + private: + void InitRowArray(); + Status ProcessPartition(int64_t thread_id, const ExecBatch& key_batch, + const ExecBatch* payload_batch_maybe_null, + util::TempVectorStack* temp_stack, int prtn_id); + + SwissTableForJoin* target_; + // DOP stands for Degree Of Parallelism - the maximum number of participating + // threads. + // + int dop_; + // Partition is a unit of parallel work. + // + // There must be power of 2 partitions (bits of hash will be used to + // identify them). + // + // Pick number of partitions at least equal to the number of threads (degree + // of parallelism). + // + int log_num_prtns_; + int num_prtns_; + int64_t num_rows_; + // Left-semi and left-anti-semi joins do not need more than one copy of the + // same key in the hash table. + // This flag, if set, will result in filtering rows with duplicate keys before + // inserting them into hash table. + // + // Since left-semi and left-anti-semi joins also do not need payload, when + // this flag is set there also will not be any processing of payload. + // + bool reject_duplicate_keys_; + // This flag, when set, will result in skipping any processing of the payload. + // + // The flag for rejecting duplicate keys (which should be set for left-semi + // and left-anti joins), when set, will force this flag to also be set, but + // other join flavors may set it to true as well if no payload columns are + // needed for join output. + // + bool no_payload_; + MemoryPool* pool_; + int64_t hardware_flags_; + + // One per partition. + // + struct PartitionState { + SwissTableWithKeys keys; + RowArray payloads; + std::vector key_ids; + std::vector overflow_key_ids; + std::vector overflow_hashes; + }; + + // One per thread. + // + // Buffers for storing temporary intermediate results when processing input + // batches. 
+ // + struct ThreadState { + std::vector batch_hashes; + std::vector batch_prtn_ranges; + std::vector batch_prtn_row_ids; + std::vector temp_prtn_ids; + std::vector temp_group_ids; + std::vector temp_column_arrays; + }; + + std::vector prtn_states_; + std::vector thread_states_; + PartitionLocks prtn_locks_; + + std::vector partition_keys_first_row_id_; + std::vector partition_payloads_first_row_id_; +}; + +class JoinResultMaterialize { + public: + void Init(MemoryPool* pool, const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas); + + void SetBuildSide(const RowArray* build_keys, const RowArray* build_payloads, + bool payload_id_same_as_key_id); + + // Input probe side batches should contain all key columns followed by all + // payload columns. + // + Status AppendProbeOnly(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, int* num_rows_appended); + + Status AppendBuildOnly(int num_rows_to_append, const uint32_t* key_ids, + const uint32_t* payload_ids, int* num_rows_appended); + + Status Append(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, const uint32_t* key_ids, + const uint32_t* payload_ids, int* num_rows_appended); + + // Should only be called if num_rows() returns non-zero. + // + Status Flush(ExecBatch* out); + + int num_rows() const { return num_rows_; } + + template + Status AppendAndOutput(int num_rows_to_append, const APPEND_ROWS_FN& append_rows_fn, + const OUTPUT_BATCH_FN& output_batch_fn) { + int offset = 0; + for (;;) { + int num_rows_appended = 0; + ARROW_RETURN_NOT_OK(append_rows_fn(num_rows_to_append, offset, &num_rows_appended)); + if (num_rows_appended < num_rows_to_append) { + ExecBatch batch; + ARROW_RETURN_NOT_OK(Flush(&batch)); + output_batch_fn(batch); + num_rows_to_append -= num_rows_appended; + offset += num_rows_appended; + } else { + break; + } + } + return Status::OK(); + } + + template + Status AppendProbeOnly(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return AppendProbeOnly(key_and_payload, num_rows_to_append_left, + row_ids + offset, num_rows_appended); + }, + output_batch_fn); + } + + template + Status AppendBuildOnly(int num_rows_to_append, const uint32_t* key_ids, + const uint32_t* payload_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return AppendBuildOnly( + num_rows_to_append_left, key_ids ? key_ids + offset : NULLPTR, + payload_ids ? payload_ids + offset : NULLPTR, num_rows_appended); + }, + output_batch_fn); + } + + template + Status Append(const ExecBatch& key_and_payload, int num_rows_to_append, + const uint16_t* row_ids, const uint32_t* key_ids, + const uint32_t* payload_ids, OUTPUT_BATCH_FN output_batch_fn) { + return AppendAndOutput( + num_rows_to_append, + [&](int num_rows_to_append_left, int offset, int* num_rows_appended) { + return Append(key_and_payload, num_rows_to_append_left, + row_ids ? row_ids + offset : NULLPTR, + key_ids ? key_ids + offset : NULLPTR, + payload_ids ? 
payload_ids + offset : NULLPTR, num_rows_appended); + }, + output_batch_fn); + } + + template + Status Flush(OUTPUT_BATCH_FN output_batch_fn) { + if (num_rows_ > 0) { + ExecBatch batch({}, num_rows_); + ARROW_RETURN_NOT_OK(Flush(&batch)); + output_batch_fn(std::move(batch)); + } + return Status::OK(); + } + + int64_t num_produced_batches() const { return num_produced_batches_; } + + private: + bool HasProbeOutput() const; + bool HasBuildKeyOutput() const; + bool HasBuildPayloadOutput() const; + bool NeedsKeyId() const; + bool NeedsPayloadId() const; + Result> FlushBuildColumn( + const std::shared_ptr& data_type, const RowArray* row_array, + int column_id, uint32_t* row_ids); + + MemoryPool* pool_; + const HashJoinProjectionMaps* probe_schemas_; + const HashJoinProjectionMaps* build_schemas_; + const RowArray* build_keys_; + // Payload array pointer may be left as null, if no payload columns are + // in the output column set. + // + const RowArray* build_payloads_; + // If true, then ignore updating payload ids and use key ids instead when + // reading. + // + bool payload_id_same_as_key_id_; + std::vector probe_output_to_key_and_payload_; + + // Number of accumulated rows (since last flush) + // + int num_rows_; + // Accumulated output columns from probe side batches. + // + ExecBatchBuilder batch_builder_; + // Accumulated build side row references. + // + std::vector key_ids_; + std::vector payload_ids_; + // Information about ranges of rows from build side, + // that in the accumulated materialized results have all fields set to null. + // + // Each pair contains index of the first output row in the range and the + // length of the range. Only rows outside of these ranges have data present in + // the key_ids_ and payload_ids_ arrays. + // + std::vector> null_ranges_; + + int64_t num_produced_batches_; +}; + +// When comparing two join key values to check if they are equal, hash join allows to +// chose (even separately for each field within the join key) whether two null values are +// considered to be equal (IS comparison) or not (EQ comparison). For EQ comparison we +// need to filter rows with nulls in keys outside of hash table lookups, since hash table +// implementation always treats two nulls as equal (like IS comparison). +// +// Implements evaluating filter bit vector eliminating rows that do not have +// join matches due to nulls in key columns. +// +class JoinNullFilter { + public: + // The batch for which the filter bit vector will be computed + // needs to start with all key columns but it may contain more columns + // (payload) following them. 
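As a rough, self-contained illustration of the rule described above (this is not the code added by the patch, and the names are invented for the example): under EQ semantics a row with a null in a key column can never match, so its bit is cleared; under IS semantics nulls are left for the hash table itself to match.

#include <cstdint>
#include <vector>

namespace example {
enum class KeyCmp { EQ, IS };

// validity[c] is the validity bitmap of key column c (bit set = value present),
// or nullptr when the column contains no nulls; cmp has one entry per key column.
inline void NaiveNullFilter(int num_rows, const std::vector<const uint8_t*>& validity,
                            const std::vector<KeyCmp>& cmp, uint8_t* out_bits) {
  for (int row = 0; row < num_rows; ++row) {
    bool can_match = true;
    for (size_t col = 0; col < cmp.size(); ++col) {
      bool is_null = validity[col] && !((validity[col][row / 8] >> (row % 8)) & 1);
      if (is_null && cmp[col] == KeyCmp::EQ) {
        can_match = false;  // EQ: null never equals anything, including null
        break;
      }
    }
    if (can_match) {
      out_bits[row / 8] |= static_cast<uint8_t>(1 << (row % 8));
    } else {
      out_bits[row / 8] &= static_cast<uint8_t>(~(1 << (row % 8)));
    }
  }
}
}  // namespace example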
+ // + static void Filter(const ExecBatch& key_batch, int batch_start_row, int num_batch_rows, + const std::vector& cmp, bool* all_valid, + bool and_with_input, uint8_t* out_bit_vector); +}; + +// A helper class that takes hash table lookup results for a range of rows in +// input batch, that is: +// - bit vector marking whether there was a key match in the hash table +// - key id if there was a match +// - mapping from key id to a range of payload ids associated with that key +// (representing multiple matching rows in a hash table for a single row in an +// input batch), and iterates output batches of limited size containing tuples +// describing all matching pairs of rows: +// - input batch row id (only rows that have matches in the hash table are +// included) +// - key id for a match +// - payload id (different one for each matching row in the hash table) +// +class JoinMatchIterator { + public: + void SetLookupResult(int num_batch_rows, int start_batch_row, + const uint8_t* batch_has_match, const uint32_t* key_ids, + bool no_duplicate_keys, const uint32_t* key_to_payload); + bool GetNextBatch(int num_rows_max, int* out_num_rows, uint16_t* batch_row_ids, + uint32_t* key_ids, uint32_t* payload_ids); + + private: + int num_batch_rows_; + int start_batch_row_; + const uint8_t* batch_has_match_; + const uint32_t* key_ids_; + + bool no_duplicate_keys_; + const uint32_t* key_to_payload_; + + // Index of the first not fully processed input row, or number of rows if all + // have been processed. May be pointing to a row with no matches. + // + int current_row_; + // Index of the first unprocessed match for the input row. May be zero if the + // row has no matches. + // + int current_match_for_row_; +}; + +// Implements entire processing of a probe side exec batch, +// provided the join hash table is already built and available. +// +class JoinProbeProcessor { + public: + using OutputBatchFn = std::function; + + void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, + std::vector materialize, + const std::vector* cmp, OutputBatchFn output_batch_fn); + Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch, + util::TempVectorStack* temp_stack, + std::vector* temp_column_arrays); + + // Must be called by a single-thread having exclusive access to the instance + // of this class. The caller is responsible for ensuring that. + // + Status OnFinished(); + + private: + int num_key_columns_; + JoinType join_type_; + + SwissTableForJoin* hash_table_; + // One element per thread + // + std::vector materialize_; + const std::vector* cmp_; + OutputBatchFn output_batch_fn_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/swiss_join_avx2.cc b/cpp/src/arrow/compute/exec/swiss_join_avx2.cc new file mode 100644 index 00000000000..261b458132f --- /dev/null +++ b/cpp/src/arrow/compute/exec/swiss_join_avx2.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec/swiss_join.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +#if defined(ARROW_HAVE_AVX2) + +template +int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int num_rows, + const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn) { + // Number of rows processed together in a single iteration of the loop (single + // call to the provided processing lambda). + // + constexpr int unroll = 8; + + bool is_fixed_length_column = + rows.metadata().column_metadatas[column_id].is_fixed_length; + + // There are 4 cases, each requiring different steps: + // 1. Varying length column that is the first varying length column in a row + // 2. Varying length column that is not the first varying length column in a + // row + // 3. Fixed length column in a fixed length row + // 4. Fixed length column in a varying length row + + if (!is_fixed_length_column) { + int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + + if (varbinary_column_id == 0) { + // Case 1: This is the first varbinary column + // + __m256i field_offset_within_row = _mm256_set1_epi32(rows.metadata().fixed_length); + __m256i varbinary_end_array_offset = + _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i field_length = _mm256_sub_epi32( + _mm256_i32gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi32(row_offset, varbinary_end_array_offset), 1), + field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, + _mm256_add_epi32(row_offset, field_offset_within_row), + field_length); + } + } else { + // Case 2: This is second or later varbinary column + // + __m256i varbinary_end_array_offset = + _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset + + sizeof(uint32_t) * (varbinary_column_id - 1)); + auto row_ptr_base_i64 = + reinterpret_cast(row_ptr_base); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i end_array_offset = + _mm256_add_epi32(row_offset, varbinary_end_array_offset); + + __m256i field_offset_within_row_A = _mm256_i32gather_epi64( + row_ptr_base_i64, _mm256_castsi256_si128(end_array_offset), 1); + __m256i field_offset_within_row_B = _mm256_i32gather_epi64( + row_ptr_base_i64, _mm256_extracti128_si256(end_array_offset, 1), 1); + field_offset_within_row_A = _mm256_permutevar8x32_epi32( + field_offset_within_row_A, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + field_offset_within_row_B = _mm256_permutevar8x32_epi32( + field_offset_within_row_B, _mm256_setr_epi32(1, 3, 5, 7, 0, 2, 4, 6)); + + __m256i field_offset_within_row = 
_mm256_blend_epi32( + field_offset_within_row_A, field_offset_within_row_B, 0xf0); + + __m256i alignment_padding = + _mm256_andnot_si256(field_offset_within_row, _mm256_set1_epi8(0xff)); + alignment_padding = _mm256_add_epi32(alignment_padding, _mm256_set1_epi32(1)); + alignment_padding = _mm256_and_si256( + alignment_padding, _mm256_set1_epi32(rows.metadata().string_alignment - 1)); + + field_offset_within_row = + _mm256_add_epi32(field_offset_within_row, alignment_padding); + + __m256i field_length = _mm256_blend_epi32(field_offset_within_row_A, + field_offset_within_row_B, 0x0f); + field_length = _mm256_permute4x64_epi64(field_length, + 0x4e); // Swapping low and high 128-bits + field_length = _mm256_sub_epi32(field_length, field_offset_within_row); + + process_8_values_fn(i * unroll, row_ptr_base, + _mm256_add_epi32(row_offset, field_offset_within_row), + field_length); + } + } + } + + if (is_fixed_length_column) { + __m256i field_offset_within_row = + _mm256_set1_epi32(rows.metadata().encoded_field_offset( + rows.metadata().pos_after_encoding(column_id))); + __m256i field_length = + _mm256_set1_epi32(rows.metadata().column_metadatas[column_id].fixed_length); + + bool is_fixed_length_row = rows.metadata().is_fixed_length; + if (is_fixed_length_row) { + // Case 3: This is a fixed length column in fixed length row + // + const uint8_t* row_ptr_base = rows.data(1); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_mullo_epi32(row_id, field_length); + __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + } + } else { + // Case 4: This is a fixed length column in varying length row + // + const uint8_t* row_ptr_base = rows.data(2); + const uint32_t* row_offsets = rows.offsets(); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = + _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i row_offset = _mm256_i32gather_epi32( + reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + } + } + } + + return num_rows - (num_rows % unroll); +} + +template +int RowArrayAccessor::VisitNulls_avx2(const RowTableImpl& rows, int column_id, + int num_rows, const uint32_t* row_ids, + PROCESS_8_VALUES_FN process_8_values_fn) { + // Number of rows processed together in a single iteration of the loop (single + // call to the provided processing lambda). 
+ // + constexpr int unroll = 8; + + const uint8_t* null_masks = rows.null_masks(); + __m256i null_bits_per_row = + _mm256_set1_epi32(8 * rows.metadata().null_masks_bytes_per_row); + for (int i = 0; i < num_rows / unroll; ++i) { + __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); + __m256i bit_id = _mm256_mullo_epi32(row_id, null_bits_per_row); + bit_id = _mm256_add_epi32(bit_id, _mm256_set1_epi32(column_id)); + __m256i bytes = _mm256_i32gather_epi32(reinterpret_cast(null_masks), + _mm256_srli_epi32(bit_id, 3), 1); + __m256i bit_in_word = _mm256_sllv_epi32( + _mm256_set1_epi32(1), _mm256_and_si256(bit_id, _mm256_set1_epi32(7))); + __m256i result = + _mm256_cmpeq_epi32(_mm256_and_si256(bytes, bit_in_word), bit_in_word); + uint64_t null_bytes = static_cast( + _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(result)))); + null_bytes |= static_cast(_mm256_movemask_epi8( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(result, 1)))) + << 32; + + process_8_values_fn(i * unroll, null_bytes); + } + + return num_rows - (num_rows % unroll); +} + +#endif + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 839a8a7d29c..1a635857f91 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -77,7 +77,8 @@ using int64_for_gather_t = const long long int; // NOLINT runtime-int // class MiniBatch { public: - static constexpr int kMiniBatchLength = 1024; + static constexpr int kLogMiniBatchLength = 10; + static constexpr int kMiniBatchLength = 1 << kLogMiniBatchLength; }; /// Storage used to allocate temporary vectors of a batch size. @@ -295,5 +296,51 @@ class ARROW_EXPORT ThreadIndexer { std::unordered_map id_to_index_; }; +// Helper class to calculate the modified number of rows to process using SIMD. +// +// Some array elements at the end will be skipped in order to avoid buffer +// overrun, when doing memory loads and stores using larger word size than a +// single array element. +// +class TailSkipForSIMD { + public: + static int64_t FixBitAccess(int num_bytes_accessed_together, int64_t num_rows, + int bit_offset) { + int64_t num_bytes = bit_util::BytesForBits(num_rows + bit_offset); + int64_t num_bytes_safe = + std::max(static_cast(0LL), num_bytes - num_bytes_accessed_together + 1); + int64_t num_rows_safe = + std::max(static_cast(0LL), 8 * num_bytes_safe - bit_offset); + return std::min(num_rows_safe, num_rows); + } + static int64_t FixBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + int64_t length) { + int64_t num_rows_to_skip = bit_util::CeilDiv(length, num_bytes_accessed_together); + int64_t num_rows_safe = + std::max(static_cast(0LL), num_rows - num_rows_to_skip); + return num_rows_safe; + } + static int64_t FixVarBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + const uint32_t* offsets) { + // Do not process rows that could read past the end of the buffer using N + // byte loads/stores. 
+ // + int64_t num_rows_safe = num_rows; + while (num_rows_safe > 0 && + offsets[num_rows_safe] + num_bytes_accessed_together > offsets[num_rows]) { + --num_rows_safe; + } + return num_rows_safe; + } + static int FixSelection(int64_t num_rows_safe, int num_selected, + const uint16_t* selection) { + int num_selected_safe = num_selected; + while (num_selected_safe > 0 && selection[num_selected_safe] >= num_rows_safe) { + --num_selected_safe; + } + return num_selected_safe; + } +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util_test.cc b/cpp/src/arrow/compute/exec/util_test.cc index 6d859917351..3861446bb3c 100644 --- a/cpp/src/arrow/compute/exec/util_test.cc +++ b/cpp/src/arrow/compute/exec/util_test.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/hash_join_node.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index e863c9cd05f..750012e60e2 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -35,7 +35,8 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, - uint8_t* match_bytevector) { + uint8_t* match_bytevector, + bool are_cols_in_encoding_order) { if (!rows.has_any_nulls(ctx) && !col.data(0)) { return; } @@ -48,6 +49,9 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com } #endif + uint32_t null_bit_id = + are_cols_in_encoding_order ? id_col : rows.metadata().pos_after_encoding(id_col); + if (!col.data(0)) { // Remove rows from the result for which the column value is a null const uint8_t* null_masks = rows.null_masks(); @@ -55,11 +59,12 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid = irow_right * null_mask_num_bytes * 8 + id_col; + int64_t bitid = irow_right * null_mask_num_bytes * 8 + null_bit_id; match_bytevector[i] &= (bit_util::GetBit(null_masks, bitid) ? 0 : 0xff); } } else if (!rows.has_any_nulls(ctx)) { - // Remove rows from the result for which the column value on left side is null + // Remove rows from the result for which the column value on left side is + // null const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { @@ -75,7 +80,7 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + id_col; + int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + null_bit_id; int right_null = bit_util::GetBit(null_masks, bitid_right) ? 0xff : 0; int left_null = bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 
0 : 0xff; @@ -228,27 +233,16 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, // Overwrites the match_bytevector instead of updating it template -void KeyCompare::CompareVarBinaryColumnToRow(uint32_t id_varbinary_col, - uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - LightContext* ctx, const KeyColumnArray& col, - const RowTableImpl& rows, - uint8_t* match_bytevector) { -#if defined(ARROW_HAVE_AVX2) - if (ctx->has_avx2()) { - CompareVarBinaryColumnToRow_avx2( - use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare, - sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector); - return; - } -#endif - +void KeyCompare::CompareVarBinaryColumnToRowHelper( + uint32_t id_varbinary_col, uint32_t first_row_to_compare, + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, + const RowTableImpl& rows, uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); const uint32_t* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); - for (uint32_t i = 0; i < num_rows_to_compare; ++i) { + for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; @@ -292,6 +286,29 @@ void KeyCompare::CompareVarBinaryColumnToRow(uint32_t id_varbinary_col, } } +// Overwrites the match_bytevector instead of updating it +template +void KeyCompare::CompareVarBinaryColumnToRow(uint32_t id_varbinary_col, + uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, + LightContext* ctx, const KeyColumnArray& col, + const RowTableImpl& rows, + uint8_t* match_bytevector) { + uint32_t num_processed = 0; +#if defined(ARROW_HAVE_AVX2) + if (ctx->has_avx2()) { + num_processed = CompareVarBinaryColumnToRow_avx2( + use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare, + sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector); + } +#endif + + CompareVarBinaryColumnToRowHelper( + id_varbinary_col, num_processed, num_rows_to_compare, sel_left_maybe_null, + left_to_right_map, ctx, col, rows, match_bytevector); +} + void KeyCompare::AndByteVectors(LightContext* ctx, uint32_t num_elements, uint8_t* bytevector_A, const uint8_t* bytevector_B) { uint32_t num_processed = 0; @@ -308,13 +325,12 @@ void KeyCompare::AndByteVectors(LightContext* ctx, uint32_t num_elements, } } -void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, - LightContext* ctx, uint32_t* out_num_rows, - uint16_t* out_sel_left_maybe_same, - const std::vector& cols, - const RowTableImpl& rows) { +void KeyCompare::CompareColumnsToRows( + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, LightContext* ctx, uint32_t* out_num_rows, + uint16_t* out_sel_left_maybe_same, const std::vector& cols, + const RowTableImpl& rows, bool are_cols_in_encoding_order, + uint8_t* out_match_bitvector_maybe_null) { if (num_rows_to_compare == 0) { *out_num_rows = 0; return; @@ -335,6 +351,7 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, bool is_first_column = true; for (size_t icol = 0; 
icol < cols.size(); ++icol) { const KeyColumnArray& col = cols[icol]; + if (col.metadata().is_null_type) { // If this null type col is the first column, the match_bytevector_A needs to be // initialized with 0xFF. Otherwise, the calculation can be skipped @@ -343,8 +360,11 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, } continue; } - uint32_t offset_within_row = - rows.metadata().encoded_field_offset(static_cast(icol)); + + uint32_t offset_within_row = rows.metadata().encoded_field_offset( + are_cols_in_encoding_order + ? static_cast(icol) + : rows.metadata().pos_after_encoding(static_cast(icol))); if (col.metadata().is_fixed_length) { if (sel_left_maybe_null) { CompareBinaryColumnToRow( @@ -354,7 +374,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } else { // Version without using selection vector CompareBinaryColumnToRow( @@ -364,7 +385,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } if (!is_first_column) { AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B); @@ -391,7 +413,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } else { if (ivarbinary == 0) { CompareVarBinaryColumnToRow( @@ -405,7 +428,8 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, NullUpdateColumnToRow( static_cast(icol), num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, - is_first_column ? match_bytevector_A : match_bytevector_B); + is_first_column ? 
match_bytevector_A : match_bytevector_B, + are_cols_in_encoding_order); } if (!is_first_column) { AndByteVectors(ctx, num_rows_to_compare, match_bytevector_A, match_bytevector_B); @@ -417,18 +441,26 @@ void KeyCompare::CompareColumnsToRows(uint32_t num_rows_to_compare, util::bit_util::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare, match_bytevector_A, match_bitvector); - if (sel_left_maybe_null) { - int out_num_rows_int; - util::bit_util::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare, - match_bitvector, sel_left_maybe_null, - &out_num_rows_int, out_sel_left_maybe_same); - *out_num_rows = out_num_rows_int; + + if (out_match_bitvector_maybe_null) { + ARROW_DCHECK(out_num_rows == nullptr); + ARROW_DCHECK(out_sel_left_maybe_same == nullptr); + memcpy(out_match_bitvector_maybe_null, match_bitvector, + bit_util::BytesForBits(num_rows_to_compare)); } else { - int out_num_rows_int; - util::bit_util::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare, - match_bitvector, &out_num_rows_int, - out_sel_left_maybe_same); - *out_num_rows = out_num_rows_int; + if (sel_left_maybe_null) { + int out_num_rows_int; + util::bit_util::bits_filter_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, sel_left_maybe_null, + &out_num_rows_int, out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } else { + int out_num_rows_int; + util::bit_util::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare, + match_bitvector, &out_num_rows_int, + out_sel_left_maybe_same); + *out_num_rows = out_num_rows_int; + } } } diff --git a/cpp/src/arrow/compute/row/compare_internal.h b/cpp/src/arrow/compute/row/compare_internal.h index e3b9057115e..f9ec1e7f535 100644 --- a/cpp/src/arrow/compute/row/compare_internal.h +++ b/cpp/src/arrow/compute/row/compare_internal.h @@ -35,13 +35,12 @@ class KeyCompare { // Returns a single 16-bit selection vector of rows that failed comparison. // If there is input selection on the left, the resulting selection is a filtered image // of input selection. 
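The branch added above either hands the caller the raw match bit vector (when out_match_bitvector_maybe_null is supplied) or converts it into a 16-bit selection vector of row indexes via the bits_to_indexes / bits_filter_indexes calls. A scalar, self-contained sketch of that conversion is shown below for illustration only; it is not the Arrow helper itself.

#include <cstdint>
#include <vector>

// Scalar sketch: collect the indexes of all bits equal to bit_to_search.
// (The calls above search for 0, i.e. the rows that failed the comparison,
// matching the doc comment that follows.)
inline std::vector<uint16_t> BitsToIndexes(int bit_to_search, const uint8_t* bits,
                                           int num_bits) {
  std::vector<uint16_t> out;
  for (int i = 0; i < num_bits; ++i) {
    if (((bits[i / 8] >> (i % 8)) & 1) == bit_to_search) {
      out.push_back(static_cast<uint16_t>(i));
    }
  }
  return out;
}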
- static void CompareColumnsToRows(uint32_t num_rows_to_compare, - const uint16_t* sel_left_maybe_null, - const uint32_t* left_to_right_map, LightContext* ctx, - uint32_t* out_num_rows, - uint16_t* out_sel_left_maybe_same, - const std::vector& cols, - const RowTableImpl& rows); + static void CompareColumnsToRows( + uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, + const uint32_t* left_to_right_map, LightContext* ctx, uint32_t* out_num_rows, + uint16_t* out_sel_left_maybe_same, const std::vector& cols, + const RowTableImpl& rows, bool are_cols_in_encoding_order, + uint8_t* out_match_bitvector_maybe_null = NULLPTR); private: template @@ -49,7 +48,8 @@ class KeyCompare { const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, - uint8_t* match_bytevector); + uint8_t* match_bytevector, + bool are_cols_in_encoding_order); template static void CompareBinaryColumnToRowHelper( @@ -67,6 +67,13 @@ class KeyCompare { const RowTableImpl& rows, uint8_t* match_bytevector); + template + static void CompareVarBinaryColumnToRowHelper( + uint32_t id_varlen_col, uint32_t first_row_to_compare, uint32_t num_rows_to_compare, + const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, + LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, + uint8_t* match_bytevector); + template static void CompareVarBinaryColumnToRow(uint32_t id_varlen_col, uint32_t num_rows_to_compare, @@ -125,7 +132,7 @@ class KeyCompare { LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector); - static void CompareVarBinaryColumnToRow_avx2( + static uint32_t CompareVarBinaryColumnToRow_avx2( bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc b/cpp/src/arrow/compute/row/compare_internal_avx2.cc index 818f4c4fe7f..96dacab6797 100644 --- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc @@ -44,6 +44,9 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( if (!rows.has_any_nulls(ctx) && !col.data(0)) { return num_rows_to_compare; } + + uint32_t null_bit_id = rows.metadata().pos_after_encoding(id_col); + if (!col.data(0)) { // Remove rows from the result for which the column value is a null const uint8_t* null_masks = rows.null_masks(); @@ -63,7 +66,7 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } __m256i bitid = _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); __m256i right = _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); right = _mm256_and_si256( @@ -80,7 +83,8 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( num_processed = num_rows_to_compare / unroll * unroll; return num_processed; } else if (!rows.has_any_nulls(ctx)) { - // Remove rows from the result for which the column value on left side is null + // Remove rows from the result for which the column value on left side is + // null const uint8_t* non_nulls = col.data(0); ARROW_DCHECK(non_nulls); uint32_t num_processed = 0; @@ -145,7 +149,7 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } __m256i bitid = 
_mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes * 8)); - bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(id_col)); + bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id)); __m256i right = _mm256_i32gather_epi32((const int*)null_masks, _mm256_srli_epi32(bitid, 3), 1); right = _mm256_and_si256( @@ -252,22 +256,22 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r int bit_offset = 0) { __m256i left; switch (column_width) { - case 0: + case 0: { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(bit_offset)); left = _mm256_i32gather_epi32((const int*)left_base, - _mm256_srli_epi32(irow_left, 3), 1); - left = _mm256_and_si256( - _mm256_set1_epi32(1), - _mm256_srlv_epi32(left, _mm256_and_si256(irow_left, _mm256_set1_epi32(7)))); - left = _mm256_mullo_epi32(left, _mm256_set1_epi32(0xff)); - break; + _mm256_srli_epi32(irow_left, 5), 4); + __m256i bit_selection = _mm256_sllv_epi32( + _mm256_set1_epi32(1), _mm256_and_si256(irow_left, _mm256_set1_epi32(31))); + left = _mm256_cmpeq_epi32(bit_selection, _mm256_and_si256(left, bit_selection)); + left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); + } break; case 1: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 1); left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); break; case 2: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 2); - left = _mm256_and_si256(left, _mm256_set1_epi32(0xff)); + left = _mm256_and_si256(left, _mm256_set1_epi32(0xffff)); break; case 4: left = _mm256_i32gather_epi32((const int*)left_base, irow_left, 4); @@ -311,15 +315,15 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas } break; case 1: left = _mm256_cvtepu8_epi32(_mm_set1_epi64x( - reinterpret_cast(left_base)[irow_left_first / 8])); + *reinterpret_cast(left_base + irow_left_first))); break; case 2: left = _mm256_cvtepu16_epi32(_mm_loadu_si128( - reinterpret_cast(left_base) + irow_left_first / 8)); + reinterpret_cast(left_base + 2 * irow_left_first))); break; case 4: - left = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 8); + left = _mm256_loadu_si256( + reinterpret_cast(left_base + 4 * irow_left_first)); break; default: ARROW_DCHECK(false); @@ -347,19 +351,17 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig __m256i offset_right) { auto left_base_i64 = reinterpret_cast(left_base); - __m256i left_lo = - _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8); - __m256i left_hi = - _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8); + __m256i left_lo, left_hi; if (use_selection) { left_lo = _mm256_i32gather_epi64(left_base_i64, _mm256_castsi256_si128(irow_left), 8); left_hi = _mm256_i32gather_epi64(left_base_i64, _mm256_extracti128_si256(irow_left, 1), 8); } else { - left_lo = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 4); - left_hi = _mm256_loadu_si256(reinterpret_cast(left_base) + - irow_left_first / 4 + 1); + left_lo = _mm256_loadu_si256( + reinterpret_cast(left_base + irow_left_first * sizeof(uint64_t))); + left_hi = _mm256_loadu_si256( + reinterpret_cast(left_base + irow_left_first * sizeof(uint64_t)) + + 1); } auto right_base_i64 = reinterpret_cast(right_base); @@ -532,7 +534,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( const __m256i* key_right_ptr = reinterpret_cast(rows_right + begin_right); int32_t j; - // length can be zero + // length is greater than zero for (j = 0; j < 
(static_cast(length) + 31) / 32 - 1; ++j) { __m256i key_left = _mm256_loadu_si256(key_left_ptr + j); __m256i key_right = _mm256_loadu_si256(key_right_ptr + j); @@ -569,6 +571,15 @@ uint32_t KeyCompare::NullUpdateColumnToRow_avx2( const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { + int64_t num_rows_safe = + TailSkipForSIMD::FixBitAccess(sizeof(uint32_t), col.length(), col.bit_offset(0)); + if (sel_left_maybe_null) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast(num_rows_safe); + } + if (use_selection) { return NullUpdateColumnToRowImp_avx2(id_col, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, @@ -585,6 +596,29 @@ uint32_t KeyCompare::CompareBinaryColumnToRow_avx2( const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { + uint32_t col_width = col.metadata().fixed_length; + int64_t num_rows_safe = col.length(); + if (col_width == 0) { + // In this case we will access left column memory 4B at a time + num_rows_safe = + TailSkipForSIMD::FixBitAccess(sizeof(uint32_t), col.length(), col.bit_offset(1)); + } else if (col_width == 1 || col_width == 2) { + // In this case we will access left column memory 4B at a time + num_rows_safe = + TailSkipForSIMD::FixBinaryAccess(sizeof(uint32_t), col.length(), col_width); + } else if (col_width != 4 && col_width != 8) { + // In this case we will access left column memory 32B at a time + num_rows_safe = + TailSkipForSIMD::FixBinaryAccess(sizeof(__m256i), col.length(), col_width); + } + if (sel_left_maybe_null) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast( + std::min(num_rows_safe, static_cast(num_rows_to_compare))); + } + if (use_selection) { return CompareBinaryColumnToRowImp_avx2(offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, @@ -596,11 +630,20 @@ uint32_t KeyCompare::CompareBinaryColumnToRow_avx2( } } -void KeyCompare::CompareVarBinaryColumnToRow_avx2( +uint32_t KeyCompare::CompareVarBinaryColumnToRow_avx2( bool use_selection, bool is_first_varbinary_col, uint32_t id_varlen_col, uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { + int64_t num_rows_safe = + TailSkipForSIMD::FixVarBinaryAccess(sizeof(__m256i), col.length(), col.offsets()); + if (use_selection) { + num_rows_to_compare = static_cast(TailSkipForSIMD::FixSelection( + num_rows_safe, static_cast(num_rows_to_compare), sel_left_maybe_null)); + } else { + num_rows_to_compare = static_cast(num_rows_safe); + } + if (use_selection) { if (is_first_varbinary_col) { CompareVarBinaryColumnToRowImp_avx2( @@ -622,6 +665,8 @@ void KeyCompare::CompareVarBinaryColumnToRow_avx2( col, rows, match_bytevector); } } + + return num_rows_to_compare; } #endif diff --git a/cpp/src/arrow/compute/row/encode_internal.cc b/cpp/src/arrow/compute/row/encode_internal.cc index cbfd169b448..9d138258d66 100644 --- a/cpp/src/arrow/compute/row/encode_internal.cc +++ b/cpp/src/arrow/compute/row/encode_internal.cc @@ 
-16,13 +16,14 @@ // under the License. #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/exec.h" +#include "arrow/util/checked_cast.h" namespace arrow { namespace compute { -void RowTableEncoder::Init(const std::vector& cols, LightContext* ctx, - int row_alignment, int string_alignment) { - ctx_ = ctx; +void RowTableEncoder::Init(const std::vector& cols, int row_alignment, + int string_alignment) { row_metadata_.FromColumnMetadataVector(cols, row_alignment, string_alignment); uint32_t num_cols = row_metadata_.num_cols(); uint32_t num_varbinary_cols = row_metadata_.num_varbinary_cols(); @@ -59,18 +60,24 @@ void RowTableEncoder::PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows void RowTableEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const RowTableImpl& rows, - std::vector* cols) { + std::vector* cols, + int64_t hardware_flags, + util::TempVectorStack* temp_stack) { // Prepare column array vectors PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + LightContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + // Create two temp vectors with 16-bit elements auto temp_buffer_holder_A = - util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + util::TempVectorHolder(ctx.stack, static_cast(num_rows)); auto temp_buffer_A = KeyColumnArray( KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, reinterpret_cast(temp_buffer_holder_A.mutable_data()), nullptr); auto temp_buffer_holder_B = - util::TempVectorHolder(ctx_->stack, static_cast(num_rows)); + util::TempVectorHolder(ctx.stack, static_cast(num_rows)); auto temp_buffer_B = KeyColumnArray( KeyColumnMetadata(true, sizeof(uint16_t)), num_rows, nullptr, reinterpret_cast(temp_buffer_holder_B.mutable_data()), nullptr); @@ -79,7 +86,7 @@ void RowTableEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, if (!is_row_fixed_length) { EncoderOffsets::Decode(static_cast(start_row_input), static_cast(num_rows), rows, &batch_varbinary_cols_, - batch_varbinary_cols_base_offsets_, ctx_); + batch_varbinary_cols_base_offsets_, &ctx); } // Process fixed length columns @@ -98,13 +105,13 @@ void RowTableEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, EncoderBinary::Decode(static_cast(start_row_input), static_cast(num_rows), row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], - ctx_, &temp_buffer_A); + &ctx, &temp_buffer_A); i += 1; } else { EncoderBinaryPair::Decode( static_cast(start_row_input), static_cast(num_rows), row_metadata_.column_offsets[i], rows, &batch_all_cols_[i], - &batch_all_cols_[i + 1], ctx_, &temp_buffer_A, &temp_buffer_B); + &batch_all_cols_[i + 1], &ctx, &temp_buffer_A, &temp_buffer_B); i += 2; } } @@ -114,14 +121,17 @@ void RowTableEncoder::DecodeFixedLengthBuffers(int64_t start_row_input, static_cast(num_rows), rows, &batch_all_cols_); } -void RowTableEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input, - int64_t start_row_output, - int64_t num_rows, - const RowTableImpl& rows, - std::vector* cols) { +void RowTableEncoder::DecodeVaryingLengthBuffers( + int64_t start_row_input, int64_t start_row_output, int64_t num_rows, + const RowTableImpl& rows, std::vector* cols, int64_t hardware_flags, + util::TempVectorStack* temp_stack) { // Prepare column array vectors PrepareKeyColumnArrays(start_row_output, num_rows, *cols); + LightContext ctx; + ctx.hardware_flags = hardware_flags; + ctx.stack = temp_stack; + bool is_row_fixed_length = row_metadata_.is_fixed_length; if 
(!is_row_fixed_length) { for (size_t i = 0; i < batch_varbinary_cols_.size(); ++i) { @@ -129,7 +139,7 @@ void RowTableEncoder::DecodeVaryingLengthBuffers(int64_t start_row_input, // positions in the output row buffer. EncoderVarBinary::Decode(static_cast(start_row_input), static_cast(num_rows), static_cast(i), - rows, &batch_varbinary_cols_[i], ctx_); + rows, &batch_varbinary_cols_[i], &ctx); } } } diff --git a/cpp/src/arrow/compute/row/encode_internal.h b/cpp/src/arrow/compute/row/encode_internal.h index ce887313466..970537a3067 100644 --- a/cpp/src/arrow/compute/row/encode_internal.h +++ b/cpp/src/arrow/compute/row/encode_internal.h @@ -46,8 +46,8 @@ namespace compute { /// Does not support nested types class RowTableEncoder { public: - void Init(const std::vector& cols, LightContext* ctx, - int row_alignment, int string_alignment); + void Init(const std::vector& cols, int row_alignment, + int string_alignment); const RowTableMetadata& row_metadata() { return row_metadata_; } // GrouperFastImpl right now needs somewhat intrusive visibility into RowTableEncoder @@ -84,7 +84,8 @@ class RowTableEncoder { /// for the call to DecodeVaryingLengthBuffers void DecodeFixedLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const RowTableImpl& rows, - std::vector* cols); + std::vector* cols, int64_t hardware_flags, + util::TempVectorStack* temp_stack); /// \brief Decode the varlength columns of a row table into column storage /// \param start_row_input The starting row to decode @@ -94,7 +95,9 @@ class RowTableEncoder { /// \param cols The column arrays to decode into void DecodeVaryingLengthBuffers(int64_t start_row_input, int64_t start_row_output, int64_t num_rows, const RowTableImpl& rows, - std::vector* cols); + std::vector* cols, + int64_t hardware_flags, + util::TempVectorStack* temp_stack); private: /// Prepare column array vectors. 
@@ -107,8 +110,6 @@ class RowTableEncoder { void PrepareKeyColumnArrays(int64_t start_row, int64_t num_rows, const std::vector& cols_in); - LightContext* ctx_; - // Data initialized once, based on data types of key columns RowTableMetadata row_metadata_; diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc index ba76bad0d17..28ebc9f1967 100644 --- a/cpp/src/arrow/compute/row/grouper.cc +++ b/cpp/src/arrow/compute/row/grouper.cc @@ -247,7 +247,7 @@ struct GrouperFastImpl : Grouper { impl->key_types_[icol] = key; } - impl->encoder_.Init(impl->col_metadata_, &impl->encode_ctx_, + impl->encoder_.Init(impl->col_metadata_, /* row_alignment = */ sizeof(uint64_t), /* string_alignment = */ sizeof(uint64_t)); RETURN_NOT_OK(impl->rows_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); @@ -255,24 +255,23 @@ struct GrouperFastImpl : Grouper { impl->rows_minibatch_.Init(ctx->memory_pool(), impl->encoder_.row_metadata())); impl->minibatch_size_ = impl->minibatch_size_min_; GrouperFastImpl* impl_ptr = impl.get(); - auto equal_func = [impl_ptr]( - int num_keys_to_compare, const uint16_t* selection_may_be_null, - const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, - uint16_t* out_selection_mismatch) { - KeyCompare::CompareColumnsToRows( - num_keys_to_compare, selection_may_be_null, group_ids, &impl_ptr->encode_ctx_, - out_num_keys_mismatch, out_selection_mismatch, - impl_ptr->encoder_.batch_all_cols(), impl_ptr->rows_); - }; - auto append_func = [impl_ptr](int num_keys, const uint16_t* selection) { + impl->map_equal_impl_ = + [impl_ptr](int num_keys_to_compare, const uint16_t* selection_may_be_null, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void*) { + KeyCompare::CompareColumnsToRows( + num_keys_to_compare, selection_may_be_null, group_ids, + &impl_ptr->encode_ctx_, out_num_keys_mismatch, out_selection_mismatch, + impl_ptr->encoder_.batch_all_cols(), impl_ptr->rows_, + /* are_cols_in_encoding_order=*/true); + }; + impl->map_append_impl_ = [impl_ptr](int num_keys, const uint16_t* selection, void*) { RETURN_NOT_OK(impl_ptr->encoder_.EncodeSelected(&impl_ptr->rows_minibatch_, num_keys, selection)); return impl_ptr->rows_.AppendSelectionFrom(impl_ptr->rows_minibatch_, num_keys, nullptr); }; - RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool(), - impl->encode_ctx_.stack, impl->log_minibatch_max_, - equal_func, append_func)); + RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool())); impl->cols_.resize(num_columns); impl->minibatch_hashes_.resize(impl->minibatch_size_max_ + kPaddingForSIMD / sizeof(uint32_t)); @@ -372,7 +371,8 @@ struct GrouperFastImpl : Grouper { match_bitvector.mutable_data(), local_slots.mutable_data()); map_.find(batch_size_next, minibatch_hashes_.data(), match_bitvector.mutable_data(), local_slots.mutable_data(), - reinterpret_cast(group_ids->mutable_data()) + start_row); + reinterpret_cast(group_ids->mutable_data()) + start_row, + &temp_stack_, map_equal_impl_, nullptr); } auto ids = util::TempVectorHolder(&temp_stack_, batch_size_next); int num_ids; @@ -382,7 +382,8 @@ struct GrouperFastImpl : Grouper { RETURN_NOT_OK(map_.map_new_keys( num_ids, ids.mutable_data(), minibatch_hashes_.data(), - reinterpret_cast(group_ids->mutable_data()) + start_row)); + reinterpret_cast(group_ids->mutable_data()) + start_row, + &temp_stack_, map_equal_impl_, map_append_impl_, nullptr)); start_row += batch_size_next; @@ -450,7 +451,7 @@ struct 
GrouperFastImpl : Grouper { int64_t batch_size_next = std::min(num_groups - start_row, static_cast(minibatch_size_max_)); encoder_.DecodeFixedLengthBuffers(start_row, start_row, batch_size_next, rows_, - &cols_); + &cols_, encode_ctx_.hardware_flags, &temp_stack_); start_row += batch_size_next; } @@ -470,7 +471,8 @@ struct GrouperFastImpl : Grouper { int64_t batch_size_next = std::min(num_groups - start_row, static_cast(minibatch_size_max_)); encoder_.DecodeVaryingLengthBuffers(start_row, start_row, batch_size_next, rows_, - &cols_); + &cols_, encode_ctx_.hardware_flags, + &temp_stack_); start_row += batch_size_next; } } @@ -535,6 +537,8 @@ struct GrouperFastImpl : Grouper { RowTableImpl rows_minibatch_; RowTableEncoder encoder_; SwissTable map_; + SwissTable::EqualImpl map_equal_impl_; + SwissTable::AppendImpl map_append_impl_; }; } // namespace diff --git a/cpp/src/arrow/compute/row/row_internal.cc b/cpp/src/arrow/compute/row/row_internal.cc index e99ff75d64a..11a8a0bc436 100644 --- a/cpp/src/arrow/compute/row/row_internal.cc +++ b/cpp/src/arrow/compute/row/row_internal.cc @@ -110,6 +110,10 @@ void RowTableMetadata::FromColumnMetadataVector( } return left < right; }); + inverse_column_order.resize(num_cols); + for (uint32_t i = 0; i < num_cols; ++i) { + inverse_column_order[column_order[i]] = i; + } row_alignment = in_row_alignment; string_alignment = in_string_alignment; diff --git a/cpp/src/arrow/compute/row/row_internal.h b/cpp/src/arrow/compute/row/row_internal.h index d46ac0e9a9e..c9194267aa3 100644 --- a/cpp/src/arrow/compute/row/row_internal.h +++ b/cpp/src/arrow/compute/row/row_internal.h @@ -73,6 +73,7 @@ struct ARROW_EXPORT RowTableMetadata { /// Order in which fields are encoded. std::vector column_order; + std::vector inverse_column_order; /// Offsets within a row to fields in their encoding order. std::vector column_offsets; @@ -133,6 +134,8 @@ struct ARROW_EXPORT RowTableMetadata { uint32_t encoded_field_order(uint32_t icol) const { return column_order[icol]; } + uint32_t pos_after_encoding(uint32_t icol) const { return inverse_column_order[icol]; } + uint32_t encoded_field_offset(uint32_t icol) const { return column_offsets[icol]; } uint32_t num_cols() const { return static_cast(column_metadatas.size()); } diff --git a/docs/source/python/compute.rst b/docs/source/python/compute.rst index b0867123d53..bcbca9dff36 100644 --- a/docs/source/python/compute.rst +++ b/docs/source/python/compute.rst @@ -237,10 +237,10 @@ In that case the result would be:: n_legs: int64 animal: string ---- - id: [[3,1,2],[4]] - year: [[2019,2020,2022],[null]] - n_legs: [[5,null,null],[100]] - animal: [["Brittle stars",null,null],["Centipede"]] + id: [[3,1,2,4]] + year: [[2019,2020,2022,null]] + n_legs: [[5,null,null,100]] + animal: [["Brittle stars",null,null,"Centipede"]] It's also possible to provide additional join keys, so that the join happens on two keys instead of one. For example we can add diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 17ea6d3558e..17f88aca4eb 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -4631,6 +4631,89 @@ cdef class Table(_PandasConvertible): return table + def group_by(self, keys): + """Declare a grouping over the columns of the table. + + Resulting grouping can then be used to perform aggregations + with a subsequent ``aggregate()`` method. + + Parameters + ---------- + keys : str or list[str] + Name of the columns that should be used as the grouping key. 
+ + Returns + ------- + TableGroupBy + + See Also + -------- + TableGroupBy.aggregate + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.group_by('year').aggregate([('n_legs', 'sum')]) + pyarrow.Table + n_legs_sum: int64 + year: int64 + ---- + n_legs_sum: [[2,6,104,5]] + year: [[2020,2022,2021,2019]] + """ + return TableGroupBy(self, keys) + + def sort_by(self, sorting): + """ + Sort the table by one or multiple columns. + + Parameters + ---------- + sorting : str or list[tuple(name, order)] + Name of the column to use to sort (ascending), or + a list of multiple sorting conditions where + each entry is a tuple with column name + and sorting order ("ascending" or "descending") + + Returns + ------- + Table + A new table sorted according to the sort keys. + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.sort_by('animal') + pyarrow.Table + year: int64 + n_legs: int64 + animal: string + ---- + year: [[2019,2021,2021,2020,2022,2022]] + n_legs: [[5,100,4,2,4,2]] + animal: [["Brittle stars","Centipede","Dog","Flamingo","Horse","Parrot"]] + """ + if isinstance(sorting, str): + sorting = [(sorting, "ascending")] + + indices = _pc().sort_indices( + self, + sort_keys=sorting + ) + return self.take(indices) + def join(self, right_table, keys, right_keys=None, join_type="left outer", left_suffix=None, right_suffix=None, coalesce_keys=True, use_threads=True): @@ -4686,7 +4769,7 @@ cdef class Table(_PandasConvertible): Left outer join: - >>> t1.join(t2, 'id') + >>> t1.join(t2, 'id').combine_chunks().sort_by('year') pyarrow.Table id: int64 year: int64 @@ -4700,31 +4783,31 @@ cdef class Table(_PandasConvertible): Full outer join: - >>> t1.join(t2, 'id', join_type="full outer") + >>> t1.join(t2, 'id', join_type="full outer").combine_chunks().sort_by('year') pyarrow.Table id: int64 year: int64 n_legs: int64 animal: string ---- - id: [[3,1,2],[4]] - year: [[2019,2020,2022],[null]] - n_legs: [[5,null,null],[100]] - animal: [["Brittle stars",null,null],["Centipede"]] + id: [[3,1,2,4]] + year: [[2019,2020,2022,null]] + n_legs: [[5,null,null,100]] + animal: [["Brittle stars",null,null,"Centipede"]] Right outer join: - >>> t1.join(t2, 'id', join_type="right outer") + >>> t1.join(t2, 'id', join_type="right outer").combine_chunks().sort_by('year') pyarrow.Table year: int64 id: int64 n_legs: int64 animal: string ---- - year: [[2019],[null]] - id: [[3],[4]] - n_legs: [[5],[100]] - animal: [["Brittle stars"],["Centipede"]] + year: [[2019,null]] + id: [[3,4]] + n_legs: [[5,100]] + animal: [["Brittle stars","Centipede"]] Right anti join @@ -4745,89 +4828,6 @@ cdef class Table(_PandasConvertible): use_threads=use_threads, coalesce_keys=coalesce_keys, output_type=Table) - def group_by(self, keys): - """Declare a grouping over the columns of the table. - - Resulting grouping can then be used to perform aggregations - with a subsequent ``aggregate()`` method. - - Parameters - ---------- - keys : str or list[str] - Name of the columns that should be used as the grouping key. 
- - Returns - ------- - TableGroupBy - - See Also - -------- - TableGroupBy.aggregate - - Examples - -------- - >>> import pandas as pd - >>> import pyarrow as pa - >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], - ... 'n_legs': [2, 2, 4, 4, 5, 100], - ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", - ... "Brittle stars", "Centipede"]}) - >>> table = pa.Table.from_pandas(df) - >>> table.group_by('year').aggregate([('n_legs', 'sum')]) - pyarrow.Table - n_legs_sum: int64 - year: int64 - ---- - n_legs_sum: [[2,6,104,5]] - year: [[2020,2022,2021,2019]] - """ - return TableGroupBy(self, keys) - - def sort_by(self, sorting): - """ - Sort the table by one or multiple columns. - - Parameters - ---------- - sorting : str or list[tuple(name, order)] - Name of the column to use to sort (ascending), or - a list of multiple sorting conditions where - each entry is a tuple with column name - and sorting order ("ascending" or "descending") - - Returns - ------- - Table - A new table sorted according to the sort keys. - - Examples - -------- - >>> import pandas as pd - >>> import pyarrow as pa - >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], - ... 'n_legs': [2, 2, 4, 4, 5, 100], - ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", - ... "Brittle stars", "Centipede"]}) - >>> table = pa.Table.from_pandas(df) - >>> table.sort_by('animal') - pyarrow.Table - year: int64 - n_legs: int64 - animal: string - ---- - year: [[2019,2021,2021,2020,2022,2022]] - n_legs: [[5,100,4,2,4,2]] - animal: [["Brittle stars","Centipede","Dog","Flamingo","Horse","Parrot"]] - """ - if isinstance(sorting, str): - sorting = [(sorting, "ascending")] - - indices = _pc().sort_indices( - self, - sort_keys=sorting - ) - return self.take(indices) - def _reconstruct_table(arrays, schema): """ diff --git a/python/pyarrow/tests/test_exec_plan.py b/python/pyarrow/tests/test_exec_plan.py index f93aac3f869..209eed9d258 100644 --- a/python/pyarrow/tests/test_exec_plan.py +++ b/python/pyarrow/tests/test_exec_plan.py @@ -134,20 +134,24 @@ def test_table_join_collisions(): result = ep._perform_join( "full outer", t1, ["colA", "colB"], t2, ["colA", "colB"]) - assert result.combine_chunks() == pa.table([ - [1, 2, 6, None], - [10, 20, 60, None], - ["a", "b", "f", None], - [10, 20, None, 99], - ["A", "B", None, "Z"], - [300, 200, None, 100], - [1, 2, None, 99], + result = result.combine_chunks() + result = result.sort_by("colUniq") + assert result == pa.table([ + [None, 2, 1, 6], + [None, 20, 10, 60], + [None, "b", "a", "f"], + [99, 20, 10, None], + ["Z", "B", "A", None], + [100, 200, 300, None], + [99, 2, 1, None], ], names=["colA", "colB", "colVals", "colB", "colVals", "colUniq", "colA"]) result = ep._perform_join("full outer", t1, "colA", t2, "colA", right_suffix="_r", coalesce_keys=False) - assert result.combine_chunks() == pa.table({ + result = result.combine_chunks() + result = result.sort_by("colA") + assert result == pa.table({ "colA": [1, 2, 6, None], "colB": [10, 20, 60, None], "colVals": ["a", "b", "f", None], @@ -160,7 +164,9 @@ def test_table_join_collisions(): result = ep._perform_join("full outer", t1, "colA", t2, "colA", right_suffix="_r", coalesce_keys=True) - assert result.combine_chunks() == pa.table({ + result = result.combine_chunks() + result = result.sort_by("colA") + assert result == pa.table({ "colA": [1, 2, 6, 99], "colB": [10, 20, 60, None], "colVals": ["a", "b", "f", None], @@ -185,7 +191,9 @@ def test_table_join_keys_order(): result = ep._perform_join("full outer", t1, 
"colA", t2, "colX", left_suffix="_l", right_suffix="_r", coalesce_keys=True) - assert result.combine_chunks() == pa.table({ + result = result.combine_chunks() + result = result.sort_by("colA") + assert result == pa.table({ "colB": [10, 20, 60, None], "colA": [1, 2, 6, 99], "colVals_l": ["a", "b", "f", None],