From f8e5807286d501caef6843b07eeda67075675a00 Mon Sep 17 00:00:00 2001 From: michalursa Date: Tue, 5 Oct 2021 01:17:14 -0700 Subject: [PATCH 1/3] Support for dictionaries in hash join --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/exec/hash_join.cc | 103 ++- cpp/src/arrow/compute/exec/hash_join.h | 4 - cpp/src/arrow/compute/exec/hash_join_dict.cc | 662 ++++++++++++++++++ cpp/src/arrow/compute/exec/hash_join_dict.h | 321 +++++++++ cpp/src/arrow/compute/exec/hash_join_node.cc | 28 +- .../arrow/compute/exec/hash_join_node_test.cc | 534 ++++++++++++++ cpp/src/arrow/compute/exec/schema_util.h | 6 +- cpp/src/arrow/compute/exec/source_node.cc | 14 +- cpp/src/arrow/compute/kernels/row_encoder.cc | 4 +- cpp/src/arrow/compute/kernels/row_encoder.h | 6 +- 11 files changed, 1628 insertions(+), 55 deletions(-) create mode 100644 cpp/src/arrow/compute/exec/hash_join_dict.cc create mode 100644 cpp/src/arrow/compute/exec/hash_join_dict.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index d7e433f4844..231000ac76e 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -422,6 +422,7 @@ if(ARROW_COMPUTE) compute/exec/key_compare.cc compute/exec/key_encode.cc compute/exec/util.cc + compute/exec/hash_join_dict.cc compute/exec/hash_join.cc compute/exec/hash_join_node.cc compute/exec/task_util.cc) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 8bbd8182451..a89e23796d4 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -24,6 +24,7 @@ #include #include +#include "arrow/compute/exec/hash_join_dict.h" #include "arrow/compute/exec/task_util.h" #include "arrow/compute/kernels/row_encoder.h" @@ -96,6 +97,7 @@ class HashJoinBasicImpl : public HashJoinImpl { local_states_[i].is_initialized = false; local_states_[i].is_has_match_initialized = false; } + dict_probe_.Init(num_threads); has_hash_table_ = false; 
num_batches_produced_.store(0); @@ -144,12 +146,13 @@ class HashJoinBasicImpl : public HashJoinImpl { if (has_payload) { InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads); } + local_state.is_initialized = true; } } Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder, - const ExecBatch& batch) { + const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) { ExecBatch projected({}, batch.length); int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); projected.values.resize(num_cols); @@ -160,6 +163,10 @@ class HashJoinBasicImpl : public HashJoinImpl { projected.values[icol] = batch.values[to_input.get(icol)]; } + if (opt_projected_batch) { + *opt_projected_batch = projected; + } + return encoder->EncodeAndAppend(projected); } @@ -170,6 +177,8 @@ class HashJoinBasicImpl : public HashJoinImpl { std::vector* output_no_match, std::vector* output_match_left, std::vector* output_match_right) { + InitHasMatchIfNeeded(local_state); + ARROW_DCHECK(has_hash_table_); InitHasMatchIfNeeded(local_state); @@ -311,6 +320,8 @@ class HashJoinBasicImpl : public HashJoinImpl { ARROW_DCHECK(opt_right_ids); ARROW_ASSIGN_OR_RAISE(right_key, hash_table_keys_.Decode(batch_size_next, opt_right_ids)); + // Post process build side keys that use dictionary + RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_)); } if (has_right_payload) { ARROW_ASSIGN_OR_RAISE(right_payload, @@ -368,13 +379,48 @@ class HashJoinBasicImpl : public HashJoinImpl { return Status::OK(); } + void NullInfoFromBatch(const ExecBatch& batch, + std::vector* nn_bit_vectors, + std::vector* nn_offsets, + std::vector* nn_bit_vector_all_nulls) { + int num_cols = static_cast(batch.values.size()); + nn_bit_vectors->resize(num_cols); + nn_offsets->resize(num_cols); + nn_bit_vector_all_nulls->clear(); + for (int64_t i = 0; i < num_cols; ++i) { + const uint8_t* nn = nullptr; + int64_t offset = 0; + if 
(batch[i].is_array()) { + if (batch[i].array()->buffers[0] != NULLPTR) { + nn = batch[i].array()->buffers[0]->data(); + offset = batch[i].array()->offset; + } + } else { + ARROW_DCHECK(batch[i].is_scalar()); + if (!batch[i].scalar_as().is_valid) { + if (nn_bit_vector_all_nulls->empty()) { + nn_bit_vector_all_nulls->resize(BitUtil::BytesForBits(batch.length)); + memset(nn_bit_vector_all_nulls->data(), 0, + BitUtil::BytesForBits(batch.length)); + } + nn = nn_bit_vector_all_nulls->data(); + } + } + (*nn_bit_vectors)[i] = nn; + (*nn_offsets)[i] = offset; + } + } + Status ProbeBatch(size_t thread_index, const ExecBatch& batch) { ThreadLocalState& local_state = local_states_[thread_index]; InitLocalStateIfNeeded(thread_index); local_state.exec_batch_keys.Clear(); - RETURN_NOT_OK( - EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, batch)); + + ExecBatch batch_key_for_lookups; + + RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys, + batch, &batch_key_for_lookups)); bool has_left_payload = (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_left_payload) { @@ -388,26 +434,24 @@ class HashJoinBasicImpl : public HashJoinImpl { local_state.match_left.clear(); local_state.match_right.clear(); + bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded( + thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_); + RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys; + if (use_key_batch_for_dicts) { + RETURN_NOT_OK(dict_probe_.EncodeBatch( + thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_, + batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_)); + } + + // Collect information about all nulls in key columns. 
+ // std::vector non_null_bit_vectors; std::vector non_null_bit_vector_offsets; - int num_key_cols = schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::KEY); - non_null_bit_vectors.resize(num_key_cols); - non_null_bit_vector_offsets.resize(num_key_cols); - auto from_batch = - schema_mgr_->proj_maps[0].map(HashJoinProjection::KEY, HashJoinProjection::INPUT); - for (int i = 0; i < num_key_cols; ++i) { - int input_col_id = from_batch.get(i); - const uint8_t* non_nulls = nullptr; - int64_t offset = 0; - if (batch[input_col_id].array()->buffers[0] != NULLPTR) { - non_nulls = batch[input_col_id].array()->buffers[0]->data(); - offset = batch[input_col_id].array()->offset; - } - non_null_bit_vectors[i] = non_nulls; - non_null_bit_vector_offsets[i] = offset; - } + std::vector all_nulls; + NullInfoFromBatch(batch_key_for_lookups, &non_null_bit_vectors, + &non_null_bit_vector_offsets, &all_nulls); - ProbeBatch_Lookup(&local_state, local_state.exec_batch_keys, non_null_bit_vectors, + ProbeBatch_Lookup(&local_state, *row_encoder_for_lookups, non_null_bit_vectors, non_null_bit_vector_offsets, &local_state.match, &local_state.no_match, &local_state.match_left, &local_state.match_right); @@ -427,7 +471,7 @@ class HashJoinBasicImpl : public HashJoinImpl { if (batches.empty()) { hash_table_empty_ = true; } else { - InitEncoder(1, HashJoinProjection::KEY, &hash_table_keys_); + dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_); bool has_payload = (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); if (has_payload) { @@ -441,11 +485,14 @@ class HashJoinBasicImpl : public HashJoinImpl { const ExecBatch& batch = batches[ibatch]; if (batch.length == 0) { continue; - } else { + } else if (hash_table_empty_) { hash_table_empty_ = false; + + RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_)); } int32_t num_rows_before = hash_table_keys_.num_rows(); - RETURN_NOT_OK(EncodeBatch(1, HashJoinProjection::KEY, &hash_table_keys_, 
batch)); + RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1], + batch, &hash_table_keys_, ctx_)); if (has_payload) { RETURN_NOT_OK( EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch)); @@ -456,6 +503,11 @@ class HashJoinBasicImpl : public HashJoinImpl { } } } + + if (hash_table_empty_) { + RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_)); + } + return Status::OK(); } @@ -713,6 +765,11 @@ class HashJoinBasicImpl : public HashJoinImpl { std::vector has_match_; bool hash_table_empty_; + // Dictionary handling + // + HashJoinDictBuildMulti dict_build_; + HashJoinDictProbeMulti dict_probe_; + std::vector left_batches_; bool has_hash_table_; std::mutex left_batches_mutex_; diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index a2312e09653..11b36d9af27 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -31,10 +31,6 @@ namespace arrow { namespace compute { -// Identifiers for all different row schemas that are used in a join -// -enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 }; - class ARROW_EXPORT HashJoinSchema { public: Status Init(JoinType join_type, const Schema& left_schema, diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc new file mode 100644 index 00000000000..f3b0812ca7c --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -0,0 +1,662 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/hash_join_dict.h" + +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace compute { + +bool HashJoinDictUtil::KeyDataTypesValid( + const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type) { + bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY); + bool r_is_dict = (build_data_type->id() == Type::DICTIONARY); + DataType* l_type; + if (l_is_dict) { + const auto& dict_type = checked_cast(*probe_data_type); + l_type = dict_type.value_type().get(); + } else { + l_type = probe_data_type.get(); + } + DataType* r_type; + if (r_is_dict) { + const auto& dict_type = checked_cast(*build_data_type); + r_type = dict_type.value_type().get(); + } else { + r_type = build_data_type.get(); + } + return l_type->Equals(*r_type); +} + +Result> HashJoinDictUtil::IndexRemapUsingLUT( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& map_array, + const std::shared_ptr& data_type) { + ARROW_DCHECK(indices.is_array() || indices.is_scalar()); + + const uint8_t* map_non_nulls = map_array->buffers[0]->data(); + const int32_t* map = reinterpret_cast(map_array->buffers[1]->data()); + + ARROW_DCHECK(data_type->id() == Type::DICTIONARY); + const auto& dict_type = checked_cast(*data_type); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + CvtToInt32(dict_type.index_type(), indices, batch_length, ctx)); + + uint8_t* nns = 
result->buffers[0]->mutable_data(); + int32_t* ids = reinterpret_cast(result->buffers[1]->mutable_data()); + for (int64_t i = 0; i < batch_length; ++i) { + bool is_null = !BitUtil::GetBit(nns, i); + if (is_null) { + ids[i] = kNullId; + } else { + ARROW_DCHECK(ids[i] >= 0 && ids[i] < map_array->length); + if (!BitUtil::GetBit(map_non_nulls, ids[i])) { + BitUtil::ClearBit(nns, i); + ids[i] = kNullId; + } else { + ids[i] = map[ids[i]]; + } + } + } + + return result; +} + +Result> HashJoinDictUtil::CvtToInt32( + const std::shared_ptr& from_type, const Datum& input, int64_t batch_length, + ExecContext* ctx) { + switch (from_type->id()) { + case Type::UINT8: + return CvtImp(int32(), input, batch_length, ctx); + case Type::INT8: + return CvtImp(int32(), input, batch_length, ctx); + case Type::UINT16: + return CvtImp(int32(), input, batch_length, ctx); + case Type::INT16: + return CvtImp(int32(), input, batch_length, ctx); + case Type::UINT32: + return CvtImp(int32(), input, batch_length, ctx); + case Type::INT32: + return CvtImp(int32(), input, batch_length, ctx); + case Type::UINT64: + return CvtImp(int32(), input, batch_length, ctx); + case Type::INT64: + return CvtImp(int32(), input, batch_length, ctx); + default: + ARROW_DCHECK(false); + return nullptr; + } +} + +template +Result> HashJoinDictUtil::CvtImp( + const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, + ExecContext* ctx) { + ARROW_DCHECK(input.is_array() || input.is_scalar()); + bool is_scalar = input.is_scalar(); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr to_buf, + AllocateBuffer(batch_length * sizeof(TO), ctx->memory_pool())); + TO* to = reinterpret_cast(to_buf->mutable_data()); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr to_nn_buf, + AllocateBitmap(batch_length, ctx->memory_pool())); + uint8_t* to_nn = to_nn_buf->mutable_data(); + memset(to_nn, 0xff, BitUtil::BytesForBits(batch_length)); + + if (!is_scalar) { + const ArrayData& arr = *input.array(); + const FROM* from = arr.GetValues(1); + 
DCHECK_EQ(arr.length, batch_length); + + for (int64_t i = 0; i < arr.length; ++i) { + to[i] = static_cast(from[i]); + // Make sure we did not lose information during cast + ARROW_DCHECK(static_cast(to[i]) == from[i]); + + bool is_null = (arr.buffers[0] != NULLPTR) && + !BitUtil::GetBit(arr.buffers[0]->data(), arr.offset + i); + if (is_null) { + BitUtil::ClearBit(to_nn, i); + } + } + + // Pass null buffer unchanged + return ArrayData::Make(to_type, arr.length, + {std::move(to_nn_buf), std::move(to_buf)}); + } else { + const auto& scalar = input.scalar_as(); + if (scalar.is_valid) { + const util::string_view data = scalar.view(); + DCHECK_EQ(data.size(), sizeof(FROM)); + const FROM from = *reinterpret_cast(data.data()); + const TO to_value = static_cast(from); + // Make sure we did not lose information during cast + ARROW_DCHECK(static_cast(to_value) == from); + + for (int64_t i = 0; i < batch_length; ++i) { + to[i] = to_value; + } + + memset(to_nn, 0xff, BitUtil::BytesForBits(batch_length)); + return ArrayData::Make(to_type, batch_length, + {std::move(to_nn_buf), std::move(to_buf)}); + } else { + memset(to_nn, 0, BitUtil::BytesForBits(batch_length)); + return ArrayData::Make(to_type, batch_length, + {std::move(to_nn_buf), std::move(to_buf)}); + } + } +} + +Result> HashJoinDictUtil::CvtFromInt32( + const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, + ExecContext* ctx) { + switch (to_type->id()) { + case Type::UINT8: + return CvtImp(to_type, input, batch_length, ctx); + case Type::INT8: + return CvtImp(to_type, input, batch_length, ctx); + case Type::UINT16: + return CvtImp(to_type, input, batch_length, ctx); + case Type::INT16: + return CvtImp(to_type, input, batch_length, ctx); + case Type::UINT32: + return CvtImp(to_type, input, batch_length, ctx); + case Type::INT32: + return CvtImp(to_type, input, batch_length, ctx); + case Type::UINT64: + return CvtImp(to_type, input, batch_length, ctx); + case Type::INT64: + return CvtImp(to_type, input, 
batch_length, ctx); + default: + ARROW_DCHECK(false); + return nullptr; + } +} + +std::shared_ptr HashJoinDictUtil::ExtractDictionary(const Datum& data) { + return data.is_array() ? MakeArray(data.array()->dictionary) + : data.scalar_as().value.dictionary; +} + +Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr dictionary, + std::shared_ptr index_type, + std::shared_ptr value_type) { + index_type_ = std::move(index_type); + value_type_ = std::move(value_type); + hash_table_.clear(); + + if (!dictionary) { + ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(value_type_, 0)); + unified_dictionary_ = dict->data(); + return Status::OK(); + } + + dictionary_ = dictionary; + + // Initialize encoder + internal::RowEncoder encoder; + std::vector encoder_types; + encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + encoder.Init(encoder_types, ctx); + + // Encode all dictionary values + int64_t length = dictionary->data()->length; + if (length >= std::numeric_limits::max()) { + return Status::Invalid( + "Dictionary length in hash join must fit into signed 32-bit integer."); + } + ExecBatch batch({dictionary->data()}, length); + RETURN_NOT_OK(encoder.EncodeAndAppend(batch)); + + std::vector entries_to_take; + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr non_nulls_buf, + AllocateBitmap(length, ctx->memory_pool())); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ids_buf, + AllocateBuffer(length * sizeof(int32_t), ctx->memory_pool())); + uint8_t* non_nulls = non_nulls_buf->mutable_data(); + int32_t* ids = reinterpret_cast(ids_buf->mutable_data()); + memset(non_nulls, 0xff, BitUtil::BytesForBits(length)); + + int32_t num_entries = 0; + for (int64_t i = 0; i < length; ++i) { + std::string str = encoder.encoded_row(static_cast(i)); + + // Do not insert null values into resulting dictionary. + // Null values will always be represented as null not an id pointing to a + // dictionary entry for null. 
+ // + if (internal::KeyEncoder::IsNull(reinterpret_cast(str.data()))) { + ids[i] = HashJoinDictUtil::kNullId; + BitUtil::ClearBit(non_nulls, i); + continue; + } + + auto iter = hash_table_.find(str); + if (iter == hash_table_.end()) { + hash_table_.insert(std::make_pair(str, num_entries)); + ids[i] = num_entries; + entries_to_take.push_back(static_cast(i)); + ++num_entries; + } else { + ids[i] = iter->second; + } + } + + ARROW_ASSIGN_OR_RAISE(auto out, encoder.Decode(num_entries, entries_to_take.data())); + + unified_dictionary_ = out[0].array(); + remapped_ids_ = ArrayData::Make(DataTypeAfterRemapping(), length, + {std::move(non_nulls_buf), std::move(ids_buf)}); + + return Status::OK(); +} + +Result> HashJoinDictBuild::RemapInputValues( + ExecContext* ctx, const Datum& values, int64_t batch_length) const { + // Initialize encoder + // + internal::RowEncoder encoder; + std::vector encoder_types; + encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + encoder.Init(encoder_types, ctx); + + // Encode all + // + ARROW_DCHECK(values.is_array() || values.is_scalar()); + bool is_scalar = values.is_scalar(); + int64_t encoded_length = is_scalar ? 
1 : batch_length; + ExecBatch batch({values}, encoded_length); + RETURN_NOT_OK(encoder.EncodeAndAppend(batch)); + + // Allocate output buffers + // + ARROW_ASSIGN_OR_RAISE(std::shared_ptr non_nulls_buf, + AllocateBitmap(batch_length, ctx->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr ids_buf, + AllocateBuffer(batch_length * sizeof(int32_t), ctx->memory_pool())); + uint8_t* non_nulls = non_nulls_buf->mutable_data(); + int32_t* ids = reinterpret_cast(ids_buf->mutable_data()); + memset(non_nulls, 0xff, BitUtil::BytesForBits(batch_length)); + + // Populate output buffers (for scalar only the first entry is populated) + // + for (int64_t i = 0; i < encoded_length; ++i) { + std::string str = encoder.encoded_row(static_cast(i)); + if (internal::KeyEncoder::IsNull(reinterpret_cast(str.data()))) { + // Map nulls to nulls + BitUtil::ClearBit(non_nulls, i); + ids[i] = HashJoinDictUtil::kNullId; + } else { + auto iter = hash_table_.find(str); + if (iter == hash_table_.end()) { + ids[i] = HashJoinDictUtil::kMissingValueId; + } else { + ids[i] = iter->second; + } + } + } + + // Generate array of repeated values for scalar input + // + if (is_scalar) { + if (!BitUtil::GetBit(non_nulls, 0)) { + memset(non_nulls, 0, BitUtil::BytesForBits(batch_length)); + } + for (int64_t i = 1; i < batch_length; ++i) { + ids[i] = ids[0]; + } + } + + return ArrayData::Make(DataTypeAfterRemapping(), batch_length, + {std::move(non_nulls_buf), std::move(ids_buf)}); +} + +Result> HashJoinDictBuild::RemapInput( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& data_type) const { + auto dict = HashJoinDictUtil::ExtractDictionary(indices); + + if (!dictionary_->Equals(dict)) { + return Status::NotImplemented("Unifying differing dictionaries"); + } + + return HashJoinDictUtil::IndexRemapUsingLUT(ctx, indices, batch_length, remapped_ids_, + data_type); +} + +Result> HashJoinDictBuild::RemapOutput( + const ArrayData& indices32Bit, ExecContext* ctx) const { 
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr indices, + HashJoinDictUtil::CvtFromInt32(index_type_, Datum(indices32Bit), + indices32Bit.length, ctx)); + + auto type = std::make_shared(index_type_, value_type_); + return ArrayData::Make(type, indices->length, indices->buffers, {}, + unified_dictionary_); +} + +void HashJoinDictBuild::CleanUp() { + index_type_.reset(); + value_type_.reset(); + hash_table_.clear(); + remapped_ids_.reset(); + unified_dictionary_.reset(); +} + +bool HashJoinDictProbe::KeyNeedsProcessing( + const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type) { + bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY); + bool r_is_dict = (build_data_type->id() == Type::DICTIONARY); + return l_is_dict || r_is_dict; +} + +std::shared_ptr HashJoinDictProbe::DataTypeAfterRemapping( + const std::shared_ptr& build_data_type) { + bool r_is_dict = (build_data_type->id() == Type::DICTIONARY); + if (r_is_dict) { + return HashJoinDictBuild::DataTypeAfterRemapping(); + } else { + return build_data_type; + } +} + +Result> HashJoinDictProbe::RemapInput( + const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length, + const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type, ExecContext* ctx) { + // Cases: + // 1. Dictionary(probe)-Dictionary(build) + // 2. Dictionary(probe)-Value(build) + // 3. Value(probe)-Dictionary(build) + // + bool l_is_dict = (probe_data_type->id() == Type::DICTIONARY); + bool r_is_dict = (build_data_type->id() == Type::DICTIONARY); + if (l_is_dict) { + auto dict = HashJoinDictUtil::ExtractDictionary(data); + const auto& dict_type = checked_cast(*probe_data_type); + + // Verify that the dictionary is always the same. 
+ if (dictionary_) { + if (!dictionary_->Equals(dict)) { + return Status::NotImplemented( + "Unifying differing dictionaries for probe key of hash join"); + } + } else { + dictionary_ = dict; + + // Precompute helper data for the given dictionary if this is the first call. + if (r_is_dict) { + ARROW_DCHECK(opt_build_side); + ARROW_ASSIGN_OR_RAISE( + remapped_ids_, + opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length())); + } else { + std::vector encoder_types; + encoder_types.emplace_back(dict_type.value_type(), ValueDescr::ARRAY); + encoder_.Init(encoder_types, ctx); + ExecBatch batch({dict->data()}, dict->length()); + RETURN_NOT_OK(encoder_.EncodeAndAppend(batch)); + } + } + + if (r_is_dict) { + // CASE 1: + // Remap dictionary ids + return HashJoinDictUtil::IndexRemapUsingLUT(ctx, data, batch_length, remapped_ids_, + probe_data_type); + } else { + // CASE 2: + // Decode selected rows from encoder. + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr row_ids_arr, + HashJoinDictUtil::CvtToInt32(dict_type.index_type(), data, batch_length, ctx)); + // Change nulls to internal::RowEncoder::kRowIdForNulls() in index. + int32_t* row_ids = + reinterpret_cast(row_ids_arr->buffers[1]->mutable_data()); + const uint8_t* non_nulls = row_ids_arr->buffers[0]->data(); + for (int64_t i = 0; i < batch_length; ++i) { + if (!BitUtil::GetBit(non_nulls, i)) { + row_ids[i] = internal::RowEncoder::kRowIdForNulls(); + } + } + + ARROW_ASSIGN_OR_RAISE(ExecBatch batch, encoder_.Decode(batch_length, row_ids)); + return batch.values[0].array(); + } + } else { + // CASE 3: + // Map values to dictionary ids from build side. + // Values missing in the dictionary will get assigned a special constant + // HashJoinDictUtil::kMissingValueId (different than any valid id). 
+ // + ARROW_DCHECK(r_is_dict); + ARROW_DCHECK(opt_build_side); + return opt_build_side->RemapInputValues(ctx, data, batch_length); + } +} + +void HashJoinDictProbe::CleanUp() { + dictionary_.reset(); + remapped_ids_.reset(); + encoder_.Clear(); +} + +Status HashJoinDictBuildMulti::Init( + const SchemaProjectionMaps& proj_map, + const ExecBatch* opt_non_empty_batch, ExecContext* ctx) { + int num_keys = proj_map.num_cols(HashJoinProjection::KEY); + needs_remap_.resize(num_keys); + remap_imp_.resize(num_keys); + for (int i = 0; i < num_keys; ++i) { + needs_remap_[i] = HashJoinDictBuild::KeyNeedsProcessing( + proj_map.data_type(HashJoinProjection::KEY, i)); + } + + bool build_side_empty = (opt_non_empty_batch == nullptr); + + if (!build_side_empty) { + auto key_to_input = proj_map.map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int i = 0; i < num_keys; ++i) { + const std::shared_ptr& data_type = + proj_map.data_type(HashJoinProjection::KEY, i); + if (data_type->id() == Type::DICTIONARY) { + const auto& dict_type = checked_cast(*data_type); + const auto& dict = HashJoinDictUtil::ExtractDictionary( + opt_non_empty_batch->values[key_to_input.get(i)]); + RETURN_NOT_OK(remap_imp_[i].Init(ctx, dict, dict_type.index_type(), + dict_type.value_type())); + } + } + } else { + for (int i = 0; i < num_keys; ++i) { + const std::shared_ptr& data_type = + proj_map.data_type(HashJoinProjection::KEY, i); + if (data_type->id() == Type::DICTIONARY) { + const auto& dict_type = checked_cast(*data_type); + RETURN_NOT_OK(remap_imp_[i].Init(ctx, nullptr, dict_type.index_type(), + dict_type.value_type())); + } + } + } + return Status::OK(); +} + +void HashJoinDictBuildMulti::InitEncoder( + const SchemaProjectionMaps& proj_map, RowEncoder* encoder, + ExecContext* ctx) { + int num_cols = proj_map.num_cols(HashJoinProjection::KEY); + std::vector data_types(num_cols); + for (int icol = 0; icol < num_cols; ++icol) { + std::shared_ptr data_type = + 
proj_map.data_type(HashJoinProjection::KEY, icol); + if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) { + data_type = HashJoinDictBuild::DataTypeAfterRemapping(); + } + data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + } + encoder->Init(data_types, ctx); +} + +Status HashJoinDictBuildMulti::EncodeBatch( + size_t thread_index, const SchemaProjectionMaps& proj_map, + const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const { + ExecBatch projected({}, batch.length); + int num_cols = proj_map.num_cols(HashJoinProjection::KEY); + projected.values.resize(num_cols); + + auto to_input = proj_map.map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_cols; ++icol) { + projected.values[icol] = batch.values[to_input.get(icol)]; + + if (needs_remap_[icol]) { + ARROW_ASSIGN_OR_RAISE( + projected.values[icol], + remap_imp_[icol].RemapInput(ctx, projected.values[icol], batch.length, + proj_map.data_type(HashJoinProjection::KEY, icol))); + } + } + return encoder->EncodeAndAppend(projected); +} + +Status HashJoinDictBuildMulti::PostDecode( + const SchemaProjectionMaps& proj_map, + ExecBatch* decoded_key_batch, ExecContext* ctx) { + // Post process build side keys that use dictionary + int num_keys = proj_map.num_cols(HashJoinProjection::KEY); + for (int i = 0; i < num_keys; ++i) { + if (needs_remap_[i]) { + ARROW_ASSIGN_OR_RAISE( + decoded_key_batch->values[i], + remap_imp_[i].RemapOutput(*decoded_key_batch->values[i].array(), ctx)); + } + } + return Status::OK(); +} + +void HashJoinDictProbeMulti::Init(size_t num_threads) { + local_states_.resize(num_threads); + for (size_t i = 0; i < local_states_.size(); ++i) { + local_states_[i].is_initialized = false; + } +} + +bool HashJoinDictProbeMulti::BatchRemapNeeded( + size_t thread_index, const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, ExecContext* ctx) { + InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx); 
+ return local_states_[thread_index].any_needs_remap; +} + +void HashJoinDictProbeMulti::InitLocalStateIfNeeded( + size_t thread_index, const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, ExecContext* ctx) { + ThreadLocalState& local_state = local_states_[thread_index]; + + // Check if we need to remap any of the input keys because of dictionary encoding + // on either side of the join + // + int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY); + local_state.any_needs_remap = false; + local_state.needs_remap.resize(num_cols); + local_state.remap_imp.resize(num_cols); + for (int i = 0; i < num_cols; ++i) { + local_state.needs_remap[i] = HashJoinDictProbe::KeyNeedsProcessing( + proj_map_probe.data_type(HashJoinProjection::KEY, i), + proj_map_build.data_type(HashJoinProjection::KEY, i)); + if (local_state.needs_remap[i]) { + local_state.any_needs_remap = true; + } + } + + if (local_state.any_needs_remap) { + InitEncoder(proj_map_probe, proj_map_build, &local_state.post_remap_encoder, ctx); + } +} + +void HashJoinDictProbeMulti::InitEncoder( + const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, RowEncoder* encoder, + ExecContext* ctx) { + int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY); + std::vector data_types(num_cols); + for (int icol = 0; icol < num_cols; ++icol) { + std::shared_ptr data_type = + proj_map_probe.data_type(HashJoinProjection::KEY, icol); + std::shared_ptr build_data_type = + proj_map_build.data_type(HashJoinProjection::KEY, icol); + if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) { + data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type); + } + data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + } + encoder->Init(data_types, ctx); +} + +Status HashJoinDictProbeMulti::EncodeBatch( + size_t thread_index, const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + const 
HashJoinDictBuildMulti& dict_build, const ExecBatch& batch, + RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, ExecContext* ctx) { + ThreadLocalState& local_state = local_states_[thread_index]; + InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx); + + ExecBatch projected({}, batch.length); + int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY); + projected.values.resize(num_cols); + + auto to_input = proj_map_probe.map(HashJoinProjection::KEY, HashJoinProjection::INPUT); + for (int icol = 0; icol < num_cols; ++icol) { + projected.values[icol] = batch.values[to_input.get(icol)]; + + if (local_state.needs_remap[icol]) { + ARROW_ASSIGN_OR_RAISE( + projected.values[icol], + local_state.remap_imp[icol].RemapInput( + &(dict_build.get_dict_build(icol)), projected.values[icol], batch.length, + proj_map_probe.data_type(HashJoinProjection::KEY, icol), + proj_map_build.data_type(HashJoinProjection::KEY, icol), ctx)); + } + } + + if (opt_out_key_batch) { + *opt_out_key_batch = projected; + } + + local_state.post_remap_encoder.Clear(); + RETURN_NOT_OK(local_state.post_remap_encoder.EncodeAndAppend(projected)); + *out_encoder = &local_state.post_remap_encoder; + + return Status::OK(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.h b/cpp/src/arrow/compute/exec/hash_join_dict.h new file mode 100644 index 00000000000..9b746c2d0cf --- /dev/null +++ b/cpp/src/arrow/compute/exec/hash_join_dict.h @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/kernels/row_encoder.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" + +// This file contains hash join logic related to handling of dictionary encoded key +// columns. +// +// A key column from probe side of the join can be matched against a key column from build +// side of the join, as long as the underlying value types are equal. That means that: +// - both scalars and arrays can be used and even mixed in the same column +// - dictionary column can be matched against non-dictionary column if underlying value +// types are equal +// - dictionary column can be matched against dictionary column with a different index +// type, and potentially using a different dictionary, if underlying value types are equal +// +// We currently require in hash join that for all dictionary encoded columns, the same +// dictionary is used in all input exec batches. +// +// In order to allow matching columns with different dictionaries, different dictionary +// index types, and dictionary key against non-dictionary key, internally comparisons will +// be evaluated after remapping values on both sides of the join to a common +// representation (which will be called "unified representation"). This common +// representation is a column of int32() type (not a dictionary column).
It represents an +// index in the unified dictionary computed for the (only) dictionary present on build +// side (an empty dictionary is still created for an empty build side). Null value is +// always represented in this common representation as null int32 value, unified +// dictionary will never contain a null value (so there is no ambiguity of representing +// nulls as either index to a null entry in the dictionary or null index). +// +// Unified dictionary represents values present on build side. There may be values on +// probe side that are not present in it. All such values, that are not null, are mapped +// in the common representation to a special constant kMissingValueId. +// + +namespace arrow { +namespace compute { + +using internal::RowEncoder; + +/// Helper class with operations that are stateless and common to processing of dictionary +/// keys on both build and probe side. +class HashJoinDictUtil { + public: + // Null values in unified representation are always represented as null that has + // corresponding integer set to this constant + static constexpr int32_t kNullId = 0; + // Constant representing a value, that is not null, missing on the build side, in + // unified representation. + static constexpr int32_t kMissingValueId = -1; + + // Check if data types of corresponding pair of key column on build and probe side are + // compatible + static bool KeyDataTypesValid(const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type); + + // Input must be dictionary array or dictionary scalar. + // A precomputed and provided here lookup table in the form of int32() array will be + // used to remap input indices to unified representation. + // + static Result> IndexRemapUsingLUT( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& map_array, + const std::shared_ptr& data_type); + + // Return int32() array that contains indices of input dictionary array or scalar after + // type casting. 
+ static Result> CvtToInt32( + const std::shared_ptr& from_type, const Datum& input, + int64_t batch_length, ExecContext* ctx); + + // Return an array that contains elements of input int32() array after casting to a + // given integer type. This is used for mapping unified representation stored in the + // hash table on build side back to original input data type of hash join, when + // outputting hash join results to parent exec node. + // + static Result> CvtFromInt32( + const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, + ExecContext* ctx); + + // Return dictionary referenced in either dictionary array or dictionary scalar + static std::shared_ptr ExtractDictionary(const Datum& data); + + private: + template + static Result> CvtImp( + const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, + ExecContext* ctx); +}; + +/// Implements processing of dictionary arrays/scalars in key columns on the build side of +/// a hash join. +/// Each instance of this class corresponds to a single column and stores and +/// processes only the information related to that column. +/// Const methods are thread-safe, non-const methods are not (the caller must make sure +/// that only one thread at any time will access them). +/// +class HashJoinDictBuild { + public: + // Returns true if the key column (described in input by its data type) requires any + // pre- or post-processing related to handling dictionaries. + // + static bool KeyNeedsProcessing(const std::shared_ptr& build_data_type) { + return (build_data_type->id() == Type::DICTIONARY); + } + + // Data type of unified representation + static std::shared_ptr DataTypeAfterRemapping() { return int32(); } + + // Should be called only once in hash join, before processing any build or probe + // batches. + // + // Takes a pointer to the dictionary for a corresponding key column on the build side as + // an input. 
If the build side is empty, it still needs to be called, but with + // dictionary pointer set to null. + // + // Currently it is required that all input batches on build side share the same + // dictionary. For each input batch during its pre-processing, dictionary will be + // checked and error will be returned if it is different than the one provided in the + // call to this method. + // + // Unifies the dictionary. The order of the values is still preserved. + // Null and duplicate entries are removed. If the dictionary is already unified, its + // copy will be produced and stored within this class. + // + // Prepares the mapping from ids within original dictionary to the ids in the resulting + // dictionary. This is used later on to pre-process (map to unified representation) key + // column on build side. + // + // Prepares the reverse mapping (in the form of hash table) from values to the ids in + // the resulting dictionary. This will be used later on to pre-process (map to unified + // representation) key column on probe side. Values on probe side that are not present + // in the original dictionary will be mapped to a special constant kMissingValueId. The + // exception is made for nulls, which get always mapped to nulls (both when null is + // represented as a dictionary id pointing to a null and a null dictionary id). + // + Status Init(ExecContext* ctx, std::shared_ptr dictionary, + std::shared_ptr index_type, std::shared_ptr value_type); + + // Remap array or scalar values into unified representation (array of int32()). + // Outputs kMissingValueId if input value is not found in the unified dictionary. + // Outputs null for null input value (with corresponding data set to kNullId). + // + Result> RemapInputValues(ExecContext* ctx, + const Datum& values, + int64_t batch_length) const; + + // Remap dictionary array or dictionary scalar on build side to unified representation.
+ // Dictionary referenced in the input must match the dictionary that was + // given during initialization. + // The output is a dictionary array that references unified dictionary. + // + Result> RemapInput( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& data_type) const; + + // Outputs dictionary array referencing unified dictionary, given an array with 32-bit + // ids. + // Used to post-process values looked up in a hash table on build side of the hash join + // before outputting to the parent exec node. + // + Result> RemapOutput(const ArrayData& indices32Bit, + ExecContext* ctx) const; + + // Release shared pointers and memory + void CleanUp(); + + private: + // Data type of dictionary ids for the input dictionary on build side + std::shared_ptr index_type_; + // Data type of values for the input dictionary on build side + std::shared_ptr value_type_; + // Mapping from (encoded as string) values to the ids in unified dictionary + std::unordered_map hash_table_; + // Mapping from input dictionary ids to unified dictionary ids + std::shared_ptr remapped_ids_; + // Input dictionary + std::shared_ptr dictionary_; + // Unified dictionary + std::shared_ptr unified_dictionary_; +}; + +/// Implements processing of dictionary arrays/scalars in key columns on the probe side of +/// a hash join. +/// Each instance of this class corresponds to a single column and stores and +/// processes only the information related to that column. +/// It is not thread-safe - every participating thread should use its own instance of +/// this class. +/// +class HashJoinDictProbe { + public: + static bool KeyNeedsProcessing(const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type); + + // Data type of the result of remapping input key column. + // + // The result of remapping is what is used in hash join for matching keys on build and + // probe side. 
The exact data types may be different, as described below, and therefore + // a common representation is needed for simplifying comparisons of pairs of keys on + // both sides. + // + // We support matching key that is of non-dictionary type with key that is of dictionary + // type, as long as the underlying value types are equal. We support matching when both + // keys are of dictionary type, regardless whether underlying dictionary index types are + // the same or not. + // + static std::shared_ptr DataTypeAfterRemapping( + const std::shared_ptr& build_data_type); + + // Should only be called if KeyNeedsProcessing method returns true for a pair of + // corresponding key columns from build and probe side. + // Converts values in order to match the common representation for + // both build and probe side used in hash table comparison. + // Supports arrays and scalars as input. + // Argument opt_build_side should be null if dictionary key on probe side is matched + // with non-dictionary key on build side. + // + Result> RemapInput( + const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length, + const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type, ExecContext* ctx); + + void CleanUp(); + + private: + // May be null if probe side key is non-dictionary. Otherwise it is used to verify that + // only a single dictionary is referenced in exec batch on probe side of hash join. + std::shared_ptr dictionary_; + // Mapping from dictionary on probe side of hash join (if it is used) to unified + // representation. + std::shared_ptr remapped_ids_; + // Encoder of key columns that uses unified representation instead of original data type + // for key columns that need to use it (have dictionaries on either side of the join). + internal::RowEncoder encoder_; +}; + +// Encapsulates dictionary handling logic for build side of hash join. 
+// +class HashJoinDictBuildMulti { + public: + Status Init(const SchemaProjectionMaps& proj_map, + const ExecBatch* opt_non_empty_batch, ExecContext* ctx); + static void InitEncoder(const SchemaProjectionMaps& proj_map, + RowEncoder* encoder, ExecContext* ctx); + Status EncodeBatch(size_t thread_index, + const SchemaProjectionMaps& proj_map, + const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const; + Status PostDecode(const SchemaProjectionMaps& proj_map, + ExecBatch* decoded_key_batch, ExecContext* ctx); + const HashJoinDictBuild& get_dict_build(int icol) const { return remap_imp_[icol]; } + + private: + std::vector needs_remap_; + std::vector remap_imp_; +}; + +// Encapsulates dictionary handling logic for probe side of hash join +// +class HashJoinDictProbeMulti { + public: + void Init(size_t num_threads); + bool BatchRemapNeeded(size_t thread_index, + const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + ExecContext* ctx); + Status EncodeBatch(size_t thread_index, + const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch, + RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, + ExecContext* ctx); + + private: + void InitLocalStateIfNeeded( + size_t thread_index, const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, ExecContext* ctx); + static void InitEncoder(const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + RowEncoder* encoder, ExecContext* ctx); + struct ThreadLocalState { + bool is_initialized; + // Whether any key column needs remapping (because of dictionaries used) before doing + // join hash table lookups + bool any_needs_remap; + // Whether each key column needs remapping before doing join hash table lookups + std::vector needs_remap; + std::vector remap_imp; + // Encoder of key columns that uses unified representation 
instead of original data + // type for key columns that need to use it (have dictionaries on either side of the + // join). + RowEncoder post_remap_encoder; + }; + std::vector local_states_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 3e02054fbed..583ac9a1468 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -19,6 +19,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/hash_join_dict.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" @@ -163,13 +164,6 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc const FieldPath& match = result.ValueUnsafe(); const std::shared_ptr& type = (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); - if (type->id() == Type::DICTIONARY) { - return Status::Invalid( - "Dictionary type support for join key is not yet implemented, key field " - "reference: ", - field_ref.ToString(), left_side ? 
" on left " : " on right ", - "side of the join"); - } if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) && !is_binary_like(type->id())) || is_large_binary_like(type->id())) { @@ -184,11 +178,11 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc int right_id = right_ref.FindOne(right_schema).ValueUnsafe()[0]; const std::shared_ptr& left_type = left_schema.fields()[left_id]->type(); const std::shared_ptr& right_type = right_schema.fields()[right_id]->type(); - if (!left_type->Equals(right_type)) { - return Status::Invalid("Mismatched data types for corresponding join field keys: ", - left_ref.ToString(), " of type ", left_type->ToString(), - " and ", right_ref.ToString(), " of type ", - right_type->ToString()); + if (!HashJoinDictUtil::KeyDataTypesValid(left_type, right_type)) { + return Status::Invalid( + "Incompatible data types for corresponding join field keys: ", + left_ref.ToString(), " of type ", left_type->ToString(), " and ", + right_ref.ToString(), " of type ", right_type->ToString()); } } @@ -228,16 +222,6 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc field_ref.ToString(), left_side ? " on left " : " on right ", "side of the join"); } - const FieldPath& match = result.ValueUnsafe(); - const std::shared_ptr& type = - (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); - if (type->id() == Type::DICTIONARY) { - return Status::Invalid( - "Dictionary type support for join output field is not yet implemented, output " - "field reference: ", - field_ref.ToString(), left_side ? 
" on left " : " on right ", - "side of the join"); - } } return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index a5410b0d37a..3cd84e07b1e 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -1113,5 +1113,539 @@ TEST(HashJoin, Random) { } } +void DecodeScalarsAndDictionariesInBatch(ExecBatch* batch, MemoryPool* pool) { + for (size_t i = 0; i < batch->values.size(); ++i) { + if (batch->values[i].is_scalar()) { + ASSERT_OK_AND_ASSIGN( + std::shared_ptr col, + MakeArrayFromScalar(*(batch->values[i].scalar()), batch->length, pool)); + batch->values[i] = Datum(col); + } + if (batch->values[i].type()->id() == Type::DICTIONARY) { + const auto& dict_type = + checked_cast(*batch->values[i].type()); + std::shared_ptr indices = + ArrayData::Make(dict_type.index_type(), batch->values[i].array()->length, + batch->values[i].array()->buffers); + const std::shared_ptr& dictionary = batch->values[i].array()->dictionary; + ASSERT_OK_AND_ASSIGN(Datum col, Take(*dictionary, *indices)); + batch->values[i] = col; + } + } +} + +std::shared_ptr UpdateSchemaAfterDecodingDictionaries( + const std::shared_ptr& schema) { + std::vector> output_fields(schema->num_fields()); + for (int i = 0; i < schema->num_fields(); ++i) { + const std::shared_ptr& field = schema->field(i); + if (field->type()->id() == Type::DICTIONARY) { + const auto& dict_type = checked_cast(*field->type()); + output_fields[i] = std::make_shared(field->name(), dict_type.value_type(), + true /* nullable */); + } else { + output_fields[i] = field->Copy(); + } + } + return std::make_shared(std::move(output_fields)); +} + +void TestHashJoinDictionaryHelper( + JoinType join_type, JoinKeyCmp cmp, + // Whether to run parallel hash join. + // This requires generating multiple copies of each input batch on one side of the + // join. 
Expected results will be automatically adjusted to reflect the multiplication + // of input batches. + bool parallel, Datum l_key, Datum l_payload, Datum r_key, Datum r_payload, + Datum l_out_key, Datum l_out_payload, Datum r_out_key, Datum r_out_payload, + // Number of rows at the end of expected output that represent rows from the right + // side that do not have a match on the left side. This number is needed to + // automatically adjust expected result when multiplying input batches on the left + // side. + int expected_num_r_no_match, + // Whether to swap two inputs to the hash join + bool swap_sides) { + int64_t l_length = l_key.is_array() + ? l_key.array()->length + : l_payload.is_array() ? l_payload.array()->length : -1; + int64_t r_length = r_key.is_array() + ? r_key.array()->length + : r_payload.is_array() ? r_payload.array()->length : -1; + ARROW_DCHECK(l_length >= 0 && r_length >= 0); + + constexpr int batch_multiplicity_for_parallel = 2; + + // Split both sides into exactly two batches + int64_t l_first_length = l_length / 2; + int64_t r_first_length = r_length / 2; + BatchesWithSchema l_batches, r_batches; + l_batches.batches.resize(2); + r_batches.batches.resize(2); + ASSERT_OK_AND_ASSIGN( + l_batches.batches[0], + ExecBatch::Make({l_key.is_array() ? l_key.array()->Slice(0, l_first_length) : l_key, + l_payload.is_array() ? l_payload.array()->Slice(0, l_first_length) + : l_payload})); + ASSERT_OK_AND_ASSIGN( + l_batches.batches[1], + ExecBatch::Make( + {l_key.is_array() + ? l_key.array()->Slice(l_first_length, l_length - l_first_length) + : l_key, + l_payload.is_array() + ? l_payload.array()->Slice(l_first_length, l_length - l_first_length) + : l_payload})); + ASSERT_OK_AND_ASSIGN( + r_batches.batches[0], + ExecBatch::Make({r_key.is_array() ? r_key.array()->Slice(0, r_first_length) : r_key, + r_payload.is_array() ? 
r_payload.array()->Slice(0, r_first_length) + : r_payload})); + ASSERT_OK_AND_ASSIGN( + r_batches.batches[1], + ExecBatch::Make( + {r_key.is_array() + ? r_key.array()->Slice(r_first_length, r_length - r_first_length) + : r_key, + r_payload.is_array() + ? r_payload.array()->Slice(r_first_length, r_length - r_first_length) + : r_payload})); + l_batches.schema = + schema({field("l_key", l_key.type()), field("l_payload", l_payload.type())}); + r_batches.schema = + schema({field("r_key", r_key.type()), field("r_payload", r_payload.type())}); + + // Add copies of input batches on originally left side of the hash join + if (parallel) { + for (int i = 0; i < batch_multiplicity_for_parallel - 1; ++i) { + l_batches.batches.push_back(l_batches.batches[0]); + l_batches.batches.push_back(l_batches.batches[1]); + } + } + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + ASSERT_OK_AND_ASSIGN( + ExecNode * l_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, + /*slow=*/false)})); + ASSERT_OK_AND_ASSIGN( + ExecNode * r_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, + /*slow=*/false)})); + HashJoinNodeOptions join_options{join_type, + {FieldRef(swap_sides ? "r_key" : "l_key")}, + {FieldRef(swap_sides ? "l_key" : "r_key")}, + {FieldRef(swap_sides ? "r_key" : "l_key"), + FieldRef(swap_sides ? "r_payload" : "l_payload")}, + {FieldRef(swap_sides ? "l_key" : "r_key"), + FieldRef(swap_sides ? "l_payload" : "r_payload")}, + {cmp}}; + ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(), + {(swap_sides ? r_source : l_source), + (swap_sides ? 
l_source : r_source)}, + join_options)); + AsyncGenerator> sink_gen; + ASSERT_OK_AND_ASSIGN( + std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen})); + ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + + for (auto& batch : res) { + DecodeScalarsAndDictionariesInBatch(&batch, exec_ctx->memory_pool()); + } + std::shared_ptr output_schema = + UpdateSchemaAfterDecodingDictionaries(join->output_schema()); + + ASSERT_OK_AND_ASSIGN(std::shared_ptr output, + TableFromExecBatches(output_schema, res)); + + ExecBatch expected_batch; + if (swap_sides) { + ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({r_out_key, r_out_payload, + l_out_key, l_out_payload})); + } else { + ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({l_out_key, l_out_payload, + r_out_key, r_out_payload})); + } + + DecodeScalarsAndDictionariesInBatch(&expected_batch, exec_ctx->memory_pool()); + + // Slice expected batch into two to separate rows on right side with no matches from + // everything else. + // + std::vector expected_batches; + ASSERT_OK_AND_ASSIGN( + auto prefix_batch, + ExecBatch::Make({expected_batch.values[0].array()->Slice( + 0, expected_batch.length - expected_num_r_no_match), + expected_batch.values[1].array()->Slice( + 0, expected_batch.length - expected_num_r_no_match), + expected_batch.values[2].array()->Slice( + 0, expected_batch.length - expected_num_r_no_match), + expected_batch.values[3].array()->Slice( + 0, expected_batch.length - expected_num_r_no_match)})); + for (int i = 0; i < (parallel ? 
batch_multiplicity_for_parallel : 1); ++i) { + expected_batches.push_back(prefix_batch); + } + if (expected_num_r_no_match > 0) { + ASSERT_OK_AND_ASSIGN( + auto suffix_batch, + ExecBatch::Make({expected_batch.values[0].array()->Slice( + expected_batch.length - expected_num_r_no_match, + expected_num_r_no_match), + expected_batch.values[1].array()->Slice( + expected_batch.length - expected_num_r_no_match, + expected_num_r_no_match), + expected_batch.values[2].array()->Slice( + expected_batch.length - expected_num_r_no_match, + expected_num_r_no_match), + expected_batch.values[3].array()->Slice( + expected_batch.length - expected_num_r_no_match, + expected_num_r_no_match)})); + expected_batches.push_back(suffix_batch); + } + + ASSERT_OK_AND_ASSIGN(std::shared_ptr
expected, + TableFromExecBatches(output_schema, expected_batches)); + + // Compare results + AssertTablesEqual(expected, output); + + // TODO: This was added for debugging. Remove in the final version. + // std::cout << output->ToString(); +} + +TEST(HashJoin, Dictionary) { + auto int8_utf8 = std::make_shared(int8(), utf8()); + auto uint8_utf8 = std::make_shared(uint8(), utf8()); + auto int16_utf8 = std::make_shared(int16(), utf8()); + auto uint16_utf8 = std::make_shared(uint16(), utf8()); + auto int32_utf8 = std::make_shared(int32(), utf8()); + auto uint32_utf8 = std::make_shared(uint32(), utf8()); + auto int64_utf8 = std::make_shared(int64(), utf8()); + auto uint64_utf8 = std::make_shared(uint64(), utf8()); + std::shared_ptr dict_types[] = {int8_utf8, uint8_utf8, int16_utf8, + uint16_utf8, int32_utf8, uint32_utf8, + int64_utf8, uint64_utf8}; + + Random64Bit rng(43); + + // Dictionaries in payload columns + for (auto parallel : {false, true}) + for (auto swap_sides : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, parallel, + // Input + ArrayFromJSON(utf8(), R"(["a", "c", "c", "d"])"), + DictArrayFromJSON(int8_utf8, R"([4, 2, 3, 0])", + R"(["p", "q", "r", null, "r"])"), + ArrayFromJSON(utf8(), R"(["a", "a", "b", "c"])"), + DictArrayFromJSON(int16_utf8, R"([0, 1, 0, 2])", R"(["r", null, "r", "q"])"), + // Expected output + ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", "d", null])"), + DictArrayFromJSON(int8_utf8, R"([4, 4, 2, 3, 0, null])", + R"(["p", "q", "r", null, "r"])"), + ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", null, "b"])"), + DictArrayFromJSON(int16_utf8, R"([0, 1, 2, 2, null, 0])", + R"(["r", null, "r", "q"])"), + 1, swap_sides); + } + + // Dictionaries in key columns + for (auto parallel : {false, true}) + for (auto swap_sides : {false, true}) + for (auto l_key_dict : {true, false}) + for (auto r_key_dict : {true, false}) { + auto l_key_dict_type = dict_types[rng.from_range(0, 7)]; + auto r_key_dict_type = 
dict_types[rng.from_range(0, 7)]; + + auto l_key = l_key_dict ? DictArrayFromJSON(l_key_dict_type, R"([2, 2, 0, 1])", + R"(["b", null, "a"])") + : ArrayFromJSON(utf8(), R"(["a", "a", "b", null])"); + auto l_payload = ArrayFromJSON(utf8(), R"(["x", "y", "z", "y"])"); + auto r_key = r_key_dict + ? DictArrayFromJSON(int16_utf8, R"([1, 0, null, 1, 2])", + R"([null, "b", "c"])") + : ArrayFromJSON(utf8(), R"(["b", null, null, "b", "c"])"); + auto r_payload = ArrayFromJSON(utf8(), R"(["p", "r", "p", "q", "s"])"); + + // IS comparison function (null is equal to null when matching keys) + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::IS, parallel, + // Input + l_key, l_payload, r_key, r_payload, + // Expected + l_key_dict ? DictArrayFromJSON(l_key_dict_type, R"([2, 2, 0, 0, 1, 1, + null])", + R"(["b", null, "a"])") + : ArrayFromJSON(utf8(), R"(["a", "a", "b", "b", null, null, + null])"), + ArrayFromJSON(utf8(), R"(["x", "y", "z", "z", "y", "y", null])"), + r_key_dict + ? DictArrayFromJSON(r_key_dict_type, R"([null, null, 0, 0, null, null, + 1])", + R"(["b", "c"])") + : ArrayFromJSON(utf8(), R"([null, null, "b", "b", null, null, "c"])"), + ArrayFromJSON(utf8(), R"([null, null, "p", "q", "r", "p", "s"])"), 1, + swap_sides); + + // EQ comparison function (null is not matching null) + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, parallel, + // Input + l_key, l_payload, r_key, r_payload, + // Expected + l_key_dict ? DictArrayFromJSON(l_key_dict_type, + R"([2, 2, 0, 0, 1, null, null, null])", + R"(["b", null, "a"])") + : ArrayFromJSON( + utf8(), R"(["a", "a", "b", "b", null, null, null, null])"), + ArrayFromJSON(utf8(), R"(["x", "y", "z", "z", "y", null, null, null])"), + r_key_dict + ? 
DictArrayFromJSON(r_key_dict_type, + R"([null, null, 0, 0, null, null, null, 1])", + R"(["b", "c"])") + : ArrayFromJSON(utf8(), + R"([null, null, "b", "b", null, null, null, "c"])"), + ArrayFromJSON(utf8(), R"([null, null, "p", "q", null, "r", "p", "s"])"), 3, + swap_sides); + } + + // Empty build side + { + auto l_key_dict_type = dict_types[rng.from_range(0, 7)]; + auto l_payload_dict_type = dict_types[rng.from_range(0, 7)]; + auto r_key_dict_type = dict_types[rng.from_range(0, 7)]; + auto r_payload_dict_type = dict_types[rng.from_range(0, 7)]; + + for (auto parallel : {false, true}) + for (auto swap_sides : {false, true}) + for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, cmp, parallel, + // Input + DictArrayFromJSON(l_key_dict_type, R"([2, 0, 1])", R"(["b", null, "a"])"), + DictArrayFromJSON(l_payload_dict_type, R"([2, 2, 0])", + R"(["x", "y", "z"])"), + DictArrayFromJSON(r_key_dict_type, R"([])", R"([null, "b", "c"])"), + DictArrayFromJSON(r_payload_dict_type, R"([])", R"(["p", "r", "s"])"), + // Expected + DictArrayFromJSON(l_key_dict_type, R"([2, 0, 1])", R"(["b", null, "a"])"), + DictArrayFromJSON(l_payload_dict_type, R"([2, 2, 0])", + R"(["x", "y", "z"])"), + DictArrayFromJSON(r_key_dict_type, R"([null, null, null])", + R"(["b", "c"])"), + DictArrayFromJSON(r_payload_dict_type, R"([null, null, null])", + R"(["p", "r", "s"])"), + 0, swap_sides); + } + } + + // Empty probe side + { + auto l_key_dict_type = dict_types[rng.from_range(0, 7)]; + auto l_payload_dict_type = dict_types[rng.from_range(0, 7)]; + auto r_key_dict_type = dict_types[rng.from_range(0, 7)]; + auto r_payload_dict_type = dict_types[rng.from_range(0, 7)]; + + for (auto parallel : {false, true}) + for (auto swap_sides : {false, true}) + for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, cmp, parallel, + // Input + DictArrayFromJSON(l_key_dict_type, R"([])", R"(["b", null, "a"])"), 
+ DictArrayFromJSON(l_payload_dict_type, R"([])", R"(["x", "y", "z"])"), + DictArrayFromJSON(r_key_dict_type, R"([2, 0, 1, null])", + R"([null, "b", "c"])"), + DictArrayFromJSON(r_payload_dict_type, R"([1, 1, null, 0])", + R"(["p", "r", "s"])"), + // Expected + DictArrayFromJSON(l_key_dict_type, R"([null, null, null, null])", + R"(["b", null, "a"])"), + DictArrayFromJSON(l_payload_dict_type, R"([null, null, null, null])", + R"(["x", "y", "z"])"), + DictArrayFromJSON(r_key_dict_type, R"([1, null, 0, null])", + R"(["b", "c"])"), + DictArrayFromJSON(r_payload_dict_type, R"([1, 1, null, 0])", + R"(["p", "r", "s"])"), + 4, swap_sides); + } + } +} + +TEST(HashJoin, Scalars) { + auto int8_utf8 = std::make_shared(int8(), utf8()); + auto int16_utf8 = std::make_shared(int16(), utf8()); + auto int32_utf8 = std::make_shared(int32(), utf8()); + + // Scalars in payload columns + for (auto use_scalar_dict : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, + // Input + ArrayFromJSON(utf8(), R"(["a", "c", "c", "d"])"), + use_scalar_dict ? DictScalarFromJSON(int16_utf8, "1", R"(["z", "x", "y"])") + : ScalarFromJSON(utf8(), "\"x\""), + ArrayFromJSON(utf8(), R"(["a", "a", "b", "c"])"), + use_scalar_dict ? DictScalarFromJSON(int32_utf8, "0", R"(["z", "x", "y"])") + : ScalarFromJSON(utf8(), "\"z\""), + // Expected output + ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", "d", null])"), + ArrayFromJSON(utf8(), R"(["x", "x", "x", "x", "x", null])"), + ArrayFromJSON(utf8(), R"(["a", "a", "c", "c", null, "b"])"), + ArrayFromJSON(utf8(), R"(["z", "z", "z", "z", null, "z"])"), 1, + false /*swap sides*/); + } + + // Scalars in key columns + for (auto use_scalar_dict : {false, true}) + for (auto swap_sides : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, + // Input + use_scalar_dict ? 
DictScalarFromJSON(int8_utf8, "1", R"(["b", "a", "c"])") + : ScalarFromJSON(utf8(), "\"a\""), + ArrayFromJSON(utf8(), R"(["x", "y"])"), + ArrayFromJSON(utf8(), R"(["a", null, "b"])"), + ArrayFromJSON(utf8(), R"(["p", "q", "r"])"), + // Expected output + ArrayFromJSON(utf8(), R"(["a", "a", null, null])"), + ArrayFromJSON(utf8(), R"(["x", "y", null, null])"), + ArrayFromJSON(utf8(), R"(["a", "a", null, "b"])"), + ArrayFromJSON(utf8(), R"(["p", "p", "q", "r"])"), 2, swap_sides); + } + + // Null scalars in key columns + for (auto use_scalar_dict : {false, true}) + for (auto swap_sides : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, + // Input + use_scalar_dict ? DictScalarFromJSON(int16_utf8, "2", R"(["a", "b", null])") + : ScalarFromJSON(utf8(), "null"), + ArrayFromJSON(utf8(), R"(["x", "y"])"), + ArrayFromJSON(utf8(), R"(["a", null, "b"])"), + ArrayFromJSON(utf8(), R"(["p", "q", "r"])"), + // Expected output + ArrayFromJSON(utf8(), R"([null, null, null, null, null])"), + ArrayFromJSON(utf8(), R"(["x", "y", null, null, null])"), + ArrayFromJSON(utf8(), R"([null, null, "a", null, "b"])"), + ArrayFromJSON(utf8(), R"([null, null, "p", "q", "r"])"), 3, swap_sides); + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::IS, false /*parallel*/, + // Input + use_scalar_dict ? 
DictScalarFromJSON(int16_utf8, "null", R"(["a", "b", null])") + : ScalarFromJSON(utf8(), "null"), + ArrayFromJSON(utf8(), R"(["x", "y"])"), + ArrayFromJSON(utf8(), R"(["a", null, "b"])"), + ArrayFromJSON(utf8(), R"(["p", "q", "r"])"), + // Expected output + ArrayFromJSON(utf8(), R"([null, null, null, null])"), + ArrayFromJSON(utf8(), R"(["x", "y", null, null])"), + ArrayFromJSON(utf8(), R"([null, null, "a", "b"])"), + ArrayFromJSON(utf8(), R"(["q", "q", "p", "r"])"), 2, swap_sides); + } + + // Scalars with the empty build/probe side + for (auto use_scalar_dict : {false, true}) + for (auto swap_sides : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, + // Input + use_scalar_dict ? DictScalarFromJSON(int8_utf8, "1", R"(["b", "a", "c"])") + : ScalarFromJSON(utf8(), "\"a\""), + ArrayFromJSON(utf8(), R"(["x", "y"])"), ArrayFromJSON(utf8(), R"([])"), + ArrayFromJSON(utf8(), R"([])"), + // Expected output + ArrayFromJSON(utf8(), R"(["a", "a"])"), ArrayFromJSON(utf8(), R"(["x", "y"])"), + ArrayFromJSON(utf8(), R"([null, null])"), + ArrayFromJSON(utf8(), R"([null, null])"), 0, swap_sides); + } + + // Scalars vs dictionaries in key columns + for (auto use_scalar_dict : {false, true}) + for (auto swap_sides : {false, true}) { + TestHashJoinDictionaryHelper( + JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, + // Input + use_scalar_dict ? 
DictScalarFromJSON(int32_utf8, "1", R"(["b", "a", "c"])") + : ScalarFromJSON(utf8(), "\"a\""), + ArrayFromJSON(utf8(), R"(["x", "y"])"), + DictArrayFromJSON(int32_utf8, R"([2, 2, 1])", R"(["b", null, "a"])"), + ArrayFromJSON(utf8(), R"(["p", "q", "r"])"), + // Expected output + ArrayFromJSON(utf8(), R"(["a", "a", "a", "a", null])"), + ArrayFromJSON(utf8(), R"(["x", "x", "y", "y", null])"), + ArrayFromJSON(utf8(), R"(["a", "a", "a", "a", null])"), + ArrayFromJSON(utf8(), R"(["p", "q", "p", "q", "r"])"), 1, swap_sides); + } +} + +TEST(HashJoin, DictNegative) { + // For dictionary keys, all batches must share a single dictionary. + // Eventually, differing dictionaries will be unified and indices transposed + // during encoding to relieve this restriction. + const auto dictA = ArrayFromJSON(utf8(), R"(["ex", "why", "zee", null])"); + const auto dictB = ArrayFromJSON(utf8(), R"(["different", "dictionary"])"); + + Datum datumFirst = Datum( + *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([0, 1, 2, 3])"), dictA)); + Datum datumSecondA = Datum( + *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([3, 2, 2, 3])"), dictA)); + Datum datumSecondB = Datum( + *DictionaryArray::FromArrays(ArrayFromJSON(int32(), R"([0, 1, 1, 0])"), dictB)); + + for (int i = 0; i < 4; ++i) { + BatchesWithSchema l, r; + l.schema = schema({field("l_key", dictionary(int32(), utf8())), + field("l_payload", dictionary(int32(), utf8()))}); + r.schema = schema({field("r_key", dictionary(int32(), utf8())), + field("r_payload", dictionary(int32(), utf8()))}); + l.batches.resize(2); + r.batches.resize(2); + ASSERT_OK_AND_ASSIGN(l.batches[0], ExecBatch::Make({datumFirst, datumFirst})); + ASSERT_OK_AND_ASSIGN(r.batches[0], ExecBatch::Make({datumFirst, datumFirst})); + ASSERT_OK_AND_ASSIGN(l.batches[1], + ExecBatch::Make({i == 0 ? datumSecondB : datumSecondA, + i == 1 ? datumSecondB : datumSecondA})); + ASSERT_OK_AND_ASSIGN(r.batches[1], + ExecBatch::Make({i == 2 ? 
datumSecondB : datumSecondA, + i == 3 ? datumSecondB : datumSecondA})); + + auto exec_ctx = + arrow::internal::make_unique(default_memory_pool(), nullptr); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + ASSERT_OK_AND_ASSIGN( + ExecNode * l_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{l.schema, l.gen(/*parallel=*/false, + /*slow=*/false)})); + ASSERT_OK_AND_ASSIGN( + ExecNode * r_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{r.schema, r.gen(/*parallel=*/false, + /*slow=*/false)})); + HashJoinNodeOptions join_options{JoinType::INNER, + {FieldRef("l_key")}, + {FieldRef("r_key")}, + {FieldRef("l_key"), FieldRef("l_payload")}, + {FieldRef("r_key"), FieldRef("r_payload")}, + {JoinKeyCmp::EQ}}; + ASSERT_OK_AND_ASSIGN( + ExecNode * join, + MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options)); + AsyncGenerator> sink_gen; + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {join}, + SinkNodeOptions{&sink_gen})); + + EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Unifying differing dictionaries"), + StartAndCollect(plan.get(), sink_gen)); + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h index ba14d577dc9..33f42701ff5 100644 --- a/cpp/src/arrow/compute/exec/schema_util.h +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -32,6 +32,10 @@ using internal::checked_cast; namespace compute { +// Identifiers for all different row schemas that are used in a join +// +enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 }; + struct SchemaProjectionMap { static constexpr int kMissingField = -1; int num_cols; @@ -86,7 +90,7 @@ class SchemaProjectionMaps { return field(schema_handle, field_id).data_type; } - SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) { + SchemaProjectionMap map(ProjectionIdEnum 
from, ProjectionIdEnum to) const { int id_from = schema_id(from); int id_to = schema_id(to); SchemaProjectionMap result; diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 127a1b4f9b3..46bba5609d4 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/exec/exec_plan.h" - #include #include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" @@ -67,7 +66,16 @@ struct SourceNode : ExecNode { [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); } Status StartProducing() override { - DCHECK(!stop_requested_) << "Restarted SourceNode"; + { + // If another exec node encountered an error during its StartProducing call + // it might have already called StopProducing on all of its inputs (including this + // node). 
+ // + std::unique_lock lock(mutex_); + if (stop_requested_) { + return Status::OK(); + } + } CallbackOptions options; auto executor = plan()->exec_context()->executor(); diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/kernels/row_encoder.cc index 63bff8c2688..840e4634fc8 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.cc +++ b/cpp/src/arrow/compute/kernels/row_encoder.cc @@ -238,7 +238,9 @@ Result> DictionaryKeyEncoder::Decode(uint8_t** encode if (dictionary_) { data->dictionary = dictionary_->data(); } else { - ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(type_, 0)); + ARROW_DCHECK(type_->id() == Type::DICTIONARY); + const auto& dict_type = checked_cast(*type_); + ARROW_ASSIGN_OR_RAISE(auto dict, MakeArrayOfNull(dict_type.value_type(), 0)); data->dictionary = dict->data(); } diff --git a/cpp/src/arrow/compute/kernels/row_encoder.h b/cpp/src/arrow/compute/kernels/row_encoder.h index 49356c5e9fc..40509f2df7b 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.h +++ b/cpp/src/arrow/compute/kernels/row_encoder.h @@ -53,6 +53,10 @@ struct KeyEncoder { // extract the null bitmap from the leading nullity bytes of encoded keys static Status DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, std::shared_ptr* null_bitmap, int32_t* null_count); + + static bool IsNull(const uint8_t* encoded_bytes) { + return encoded_bytes[0] == kNullByte; + } }; struct BooleanKeyEncoder : KeyEncoder { @@ -156,8 +160,8 @@ struct VarLengthKeyEncoder : KeyEncoder { }); } else { const auto& scalar = data.scalar_as(); - const auto& bytes = *scalar.value; if (scalar.is_valid) { + const auto& bytes = *scalar.value; for (int64_t i = 0; i < batch_length; i++) { auto& encoded_ptr = *encoded_bytes++; *encoded_ptr++ = kValidByte; From 2f426181b2b5675e6d8e4de3ad16970ded7b1131 Mon Sep 17 00:00:00 2001 From: michalursa Date: Tue, 2 Nov 2021 17:11:47 -0700 Subject: [PATCH 2/3] Dictionary support for hash join - addressing code review 
comments --- cpp/src/arrow/compute/exec/hash_join_dict.cc | 89 ++++++++++--------- cpp/src/arrow/compute/exec/hash_join_dict.h | 12 +-- .../arrow/compute/exec/hash_join_node_test.cc | 61 +++++++------ 3 files changed, 84 insertions(+), 78 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index f3b0812ca7c..195331a5976 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -64,8 +64,9 @@ Result> HashJoinDictUtil::IndexRemapUsingLUT( ARROW_DCHECK(data_type->id() == Type::DICTIONARY); const auto& dict_type = checked_cast(*data_type); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, - CvtToInt32(dict_type.index_type(), indices, batch_length, ctx)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr result, + ConvertToInt32(dict_type.index_type(), indices, batch_length, ctx)); uint8_t* nns = result->buffers[0]->mutable_data(); int32_t* ids = reinterpret_cast(result->buffers[1]->mutable_data()); @@ -87,34 +88,9 @@ Result> HashJoinDictUtil::IndexRemapUsingLUT( return result; } -Result> HashJoinDictUtil::CvtToInt32( - const std::shared_ptr& from_type, const Datum& input, int64_t batch_length, - ExecContext* ctx) { - switch (from_type->id()) { - case Type::UINT8: - return CvtImp(int32(), input, batch_length, ctx); - case Type::INT8: - return CvtImp(int32(), input, batch_length, ctx); - case Type::UINT16: - return CvtImp(int32(), input, batch_length, ctx); - case Type::INT16: - return CvtImp(int32(), input, batch_length, ctx); - case Type::UINT32: - return CvtImp(int32(), input, batch_length, ctx); - case Type::INT32: - return CvtImp(int32(), input, batch_length, ctx); - case Type::UINT64: - return CvtImp(int32(), input, batch_length, ctx); - case Type::INT64: - return CvtImp(int32(), input, batch_length, ctx); - default: - ARROW_DCHECK(false); - return nullptr; - } -} - +namespace { template -Result> HashJoinDictUtil::CvtImp( +static Result> ConvertImp( const 
std::shared_ptr& to_type, const Datum& input, int64_t batch_length, ExecContext* ctx) { ARROW_DCHECK(input.is_array() || input.is_scalar()); @@ -172,27 +148,54 @@ Result> HashJoinDictUtil::CvtImp( } } } +} // namespace -Result> HashJoinDictUtil::CvtFromInt32( +Result> HashJoinDictUtil::ConvertToInt32( + const std::shared_ptr& from_type, const Datum& input, int64_t batch_length, + ExecContext* ctx) { + switch (from_type->id()) { + case Type::UINT8: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::INT8: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::UINT16: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::INT16: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::UINT32: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::INT32: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::UINT64: + return ConvertImp(int32(), input, batch_length, ctx); + case Type::INT64: + return ConvertImp(int32(), input, batch_length, ctx); + default: + ARROW_DCHECK(false); + return nullptr; + } +} + +Result> HashJoinDictUtil::ConvertFromInt32( const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, ExecContext* ctx) { switch (to_type->id()) { case Type::UINT8: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::INT8: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::UINT16: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::INT16: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::UINT32: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::INT32: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, 
input, batch_length, ctx); case Type::UINT64: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); case Type::INT64: - return CvtImp(to_type, input, batch_length, ctx); + return ConvertImp(to_type, input, batch_length, ctx); default: ARROW_DCHECK(false); return nullptr; @@ -355,8 +358,8 @@ Result> HashJoinDictBuild::RemapInput( Result> HashJoinDictBuild::RemapOutput( const ArrayData& indices32Bit, ExecContext* ctx) const { ARROW_ASSIGN_OR_RAISE(std::shared_ptr indices, - HashJoinDictUtil::CvtFromInt32(index_type_, Datum(indices32Bit), - indices32Bit.length, ctx)); + HashJoinDictUtil::ConvertFromInt32( + index_type_, Datum(indices32Bit), indices32Bit.length, ctx)); auto type = std::make_shared(index_type_, value_type_); return ArrayData::Make(type, indices->length, indices->buffers, {}, @@ -436,9 +439,9 @@ Result> HashJoinDictProbe::RemapInput( } else { // CASE 2: // Decode selected rows from encoder. - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr row_ids_arr, - HashJoinDictUtil::CvtToInt32(dict_type.index_type(), data, batch_length, ctx)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr row_ids_arr, + HashJoinDictUtil::ConvertToInt32(dict_type.index_type(), data, + batch_length, ctx)); // Change nulls to internal::RowEncoder::kRowIdForNulls() in index. 
int32_t* row_ids = reinterpret_cast(row_ids_arr->buffers[1]->mutable_data()); diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.h b/cpp/src/arrow/compute/exec/hash_join_dict.h index 9b746c2d0cf..26605cc449a 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.h +++ b/cpp/src/arrow/compute/exec/hash_join_dict.h @@ -38,7 +38,7 @@ // - dictionary column can be matched against dictionary column with a different index // type, and potentially using a different dictionary, if underlying value types are equal // -// We currently require in hash that for all dictionary encoded columns, the same +// We currently require in hash join that for all dictionary encoded columns, the same // dictionary is used in all input exec batches. // // In order to allow matching columns with different dictionaries, different dictionary @@ -89,7 +89,7 @@ class HashJoinDictUtil { // Return int32() array that contains indices of input dictionary array or scalar after // type casting. - static Result> CvtToInt32( + static Result> ConvertToInt32( const std::shared_ptr& from_type, const Datum& input, int64_t batch_length, ExecContext* ctx); @@ -98,18 +98,12 @@ class HashJoinDictUtil { // hash table on build side back to original input data type of hash join, when // outputting hash join results to parent exec node. 
// - static Result> CvtFromInt32( + static Result> ConvertFromInt32( const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, ExecContext* ctx); // Return dictionary referenced in either dictionary array or dictionary scalar static std::shared_ptr ExtractDictionary(const Datum& data); - - private: - template - static Result> CvtImp( - const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, - ExecContext* ctx); }; /// Implements processing of dictionary arrays/scalars in key columns on the build side of diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 3cd84e07b1e..d20b456fec5 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -1312,28 +1312,25 @@ void TestHashJoinDictionaryHelper( // Compare results AssertTablesEqual(expected, output); - - // TODO: This was added for debugging. Remove in the final version. - // std::cout << output->ToString(); } TEST(HashJoin, Dictionary) { - auto int8_utf8 = std::make_shared(int8(), utf8()); - auto uint8_utf8 = std::make_shared(uint8(), utf8()); - auto int16_utf8 = std::make_shared(int16(), utf8()); - auto uint16_utf8 = std::make_shared(uint16(), utf8()); - auto int32_utf8 = std::make_shared(int32(), utf8()); - auto uint32_utf8 = std::make_shared(uint32(), utf8()); - auto int64_utf8 = std::make_shared(int64(), utf8()); - auto uint64_utf8 = std::make_shared(uint64(), utf8()); - std::shared_ptr dict_types[] = {int8_utf8, uint8_utf8, int16_utf8, - uint16_utf8, int32_utf8, uint32_utf8, - int64_utf8, uint64_utf8}; + auto int8_utf8 = dictionary(int8(), utf8()); + auto uint8_utf8 = arrow::dictionary(uint8(), utf8()); + auto int16_utf8 = arrow::dictionary(int16(), utf8()); + auto uint16_utf8 = arrow::dictionary(uint16(), utf8()); + auto int32_utf8 = arrow::dictionary(int32(), utf8()); + auto uint32_utf8 = arrow::dictionary(uint32(), utf8()); + auto int64_utf8 = 
arrow::dictionary(int64(), utf8()); + auto uint64_utf8 = arrow::dictionary(uint64(), utf8()); + std::shared_ptr dict_types[] = {int8_utf8, uint8_utf8, int16_utf8, + uint16_utf8, int32_utf8, uint32_utf8, + int64_utf8, uint64_utf8}; Random64Bit rng(43); // Dictionaries in payload columns - for (auto parallel : {false, true}) + for (auto parallel : {false, true}) { for (auto swap_sides : {false, true}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, JoinKeyCmp::EQ, parallel, @@ -1352,11 +1349,12 @@ TEST(HashJoin, Dictionary) { R"(["r", null, "r", "q"])"), 1, swap_sides); } + } // Dictionaries in key columns - for (auto parallel : {false, true}) - for (auto swap_sides : {false, true}) - for (auto l_key_dict : {true, false}) + for (auto parallel : {false, true}) { + for (auto swap_sides : {false, true}) { + for (auto l_key_dict : {true, false}) { for (auto r_key_dict : {true, false}) { auto l_key_dict_type = dict_types[rng.from_range(0, 7)]; auto r_key_dict_type = dict_types[rng.from_range(0, 7)]; @@ -1412,6 +1410,9 @@ TEST(HashJoin, Dictionary) { ArrayFromJSON(utf8(), R"([null, null, "p", "q", null, "r", "p", "s"])"), 3, swap_sides); } + } + } + } // Empty build side { @@ -1420,8 +1421,8 @@ TEST(HashJoin, Dictionary) { auto r_key_dict_type = dict_types[rng.from_range(0, 7)]; auto r_payload_dict_type = dict_types[rng.from_range(0, 7)]; - for (auto parallel : {false, true}) - for (auto swap_sides : {false, true}) + for (auto parallel : {false, true}) { + for (auto swap_sides : {false, true}) { for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, cmp, parallel, @@ -1441,6 +1442,8 @@ TEST(HashJoin, Dictionary) { R"(["p", "r", "s"])"), 0, swap_sides); } + } + } } // Empty probe side @@ -1450,8 +1453,8 @@ TEST(HashJoin, Dictionary) { auto r_key_dict_type = dict_types[rng.from_range(0, 7)]; auto r_payload_dict_type = dict_types[rng.from_range(0, 7)]; - for (auto parallel : {false, true}) - for (auto swap_sides : {false, 
true}) + for (auto parallel : {false, true}) { + for (auto swap_sides : {false, true}) { for (auto cmp : {JoinKeyCmp::IS, JoinKeyCmp::EQ}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, cmp, parallel, @@ -1473,6 +1476,8 @@ TEST(HashJoin, Dictionary) { R"(["p", "r", "s"])"), 4, swap_sides); } + } + } } } @@ -1501,7 +1506,7 @@ TEST(HashJoin, Scalars) { } // Scalars in key columns - for (auto use_scalar_dict : {false, true}) + for (auto use_scalar_dict : {false, true}) { for (auto swap_sides : {false, true}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, @@ -1517,9 +1522,10 @@ TEST(HashJoin, Scalars) { ArrayFromJSON(utf8(), R"(["a", "a", null, "b"])"), ArrayFromJSON(utf8(), R"(["p", "p", "q", "r"])"), 2, swap_sides); } + } // Null scalars in key columns - for (auto use_scalar_dict : {false, true}) + for (auto use_scalar_dict : {false, true}) { for (auto swap_sides : {false, true}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, @@ -1548,9 +1554,10 @@ TEST(HashJoin, Scalars) { ArrayFromJSON(utf8(), R"([null, null, "a", "b"])"), ArrayFromJSON(utf8(), R"(["q", "q", "p", "r"])"), 2, swap_sides); } + } // Scalars with the empty build/probe side - for (auto use_scalar_dict : {false, true}) + for (auto use_scalar_dict : {false, true}) { for (auto swap_sides : {false, true}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, @@ -1564,9 +1571,10 @@ TEST(HashJoin, Scalars) { ArrayFromJSON(utf8(), R"([null, null])"), ArrayFromJSON(utf8(), R"([null, null])"), 0, swap_sides); } + } // Scalars vs dictionaries in key columns - for (auto use_scalar_dict : {false, true}) + for (auto use_scalar_dict : {false, true}) { for (auto swap_sides : {false, true}) { TestHashJoinDictionaryHelper( JoinType::FULL_OUTER, JoinKeyCmp::EQ, false /*parallel*/, @@ -1582,6 +1590,7 @@ TEST(HashJoin, Scalars) { ArrayFromJSON(utf8(), R"(["a", "a", "a", "a", null])"), 
ArrayFromJSON(utf8(), R"(["p", "q", "p", "q", "r"])"), 1, swap_sides); } + } } TEST(HashJoin, DictNegative) { From 8e688d6d4f146b0d76bf74002b1023447a351c77 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 5 Nov 2021 14:35:19 -0400 Subject: [PATCH 3/3] Test joins with dictionary columns in R --- r/tests/testthat/test-dplyr-join.R | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 3ff9ad8ff1a..d8239f81085 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -20,11 +20,6 @@ skip_if_not_available("dataset") library(dplyr, warn.conflicts = FALSE) left <- example_data -# Error: Invalid: Dictionary type support for join output field -# is not yet implemented, output field reference: FieldRef.Name(fct) -# on left side of the join -# (select(-fct) also solves this but remove once) -left$fct <- NULL left$some_grouping <- rep(c(1, 2), 5) left_tab <- Table$create(left) @@ -37,7 +32,6 @@ to_join <- tibble::tibble( to_join_tab <- Table$create(to_join) - test_that("left_join", { expect_message( compare_dplyr_binding( @@ -68,8 +62,6 @@ test_that("left_join `by` args", { left ) - # TODO: allow renaming columns on the right side as well - skip("ARROW-14184") compare_dplyr_binding( .input %>% rename(the_grouping = some_grouping) %>% @@ -82,7 +74,6 @@ test_that("left_join `by` args", { ) }) - test_that("join two tables", { expect_identical( left_tab %>% @@ -146,6 +137,9 @@ test_that("semi_join", { test_that("anti_join", { compare_dplyr_binding( .input %>% + # Factor levels when there are no rows in the data don't match + # TODO: use better anti_join test data + select(-fct) %>% anti_join(to_join, by = "some_grouping") %>% collect(), left