From b47a8b7b608dad59eed140114cdc39565bc42adf Mon Sep 17 00:00:00 2001 From: michalursa Date: Thu, 29 Jul 2021 16:49:34 -0700 Subject: [PATCH 1/3] Grouper - adding set membership type filtering to hash table interface --- cpp/src/arrow/compute/exec/key_encode.cc | 17 +- cpp/src/arrow/compute/exec/key_encode.h | 6 +- cpp/src/arrow/compute/exec/key_map.cc | 703 ++++++++++++------ cpp/src/arrow/compute/exec/key_map.h | 103 ++- cpp/src/arrow/compute/exec/key_map_avx2.cc | 193 ++--- cpp/src/arrow/compute/exec/util.cc | 40 +- cpp/src/arrow/compute/exec/util.h | 8 +- cpp/src/arrow/compute/exec/util_avx2.cc | 14 +- .../arrow/compute/kernels/hash_aggregate.cc | 22 +- 9 files changed, 719 insertions(+), 387 deletions(-) diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc index de79558f2c2..e71b81b4574 100644 --- a/cpp/src/arrow/compute/exec/key_encode.cc +++ b/cpp/src/arrow/compute/exec/key_encode.cc @@ -732,7 +732,7 @@ void KeyEncoder::EncoderBinary::ColumnMemsetNulls( uint32_t col_width = col.metadata().fixed_length; int dispatch_const = (rows->metadata().is_fixed_length ? 5 : 0) + - (col_width == 1 ? 0 + (col_width <= 1 ? 0 : col_width == 2 ? 1 : col_width == 4 ? 2 : col_width == 8 ? 3 : 4); ColumnMemsetNullsImp_fn[dispatch_const](offset_within_row, rows, col, ctx, temp_vector_16bit, byte_value); @@ -864,6 +864,17 @@ void KeyEncoder::EncoderBinaryPair::Encode(uint32_t offset_within_row, KeyRowArr EncodeImp_fn[dispatch_const](num_processed, offset_within_row, rows, col_prep[0], col_prep[1]); } + + DCHECK(temp1->metadata().is_fixed_length); + DCHECK(temp1->length() * temp1->metadata().fixed_length >= + col1.length() * static_cast(sizeof(uint16_t))); + + KeyColumnArray temp16bit(KeyColumnMetadata(true, sizeof(uint16_t)), col1.length(), + nullptr, temp1->mutable_data(1), nullptr); + + EncoderBinary::ColumnMemsetNulls(offset_within_row, rows, col1, ctx, &temp16bit, 0xae); + EncoderBinary::ColumnMemsetNulls(offset_within_row + col_width1, rows, col2, ctx, + &temp16bit, 0xae); } template @@ -1366,8 +1377,8 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( // a) Boolean column, marked with fixed-length 0, is considered to have fixed-length // part of 1 byte. b) Columns with fixed-length part being power of 2 or multiple of row // alignment precede other columns. They are sorted among themselves based on size of - // fixed-length part. c) Fixed-length columns precede varying-length columns when both - // have the same size fixed-length part. + // fixed-length part decreasing. c) Fixed-length columns precede varying-length columns + // when both have the same size fixed-length part. 
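The ordering rules listed in the comment above are easier to see as a comparator. The sketch below is a hedged reconstruction of that rule only; ColInfo, order_columns and the power-of-two/alignment test are illustrative names and assumptions, not the actual KeyRowMetadata code.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <tuple>
#include <vector>

struct ColInfo {
  bool is_fixed_length;
  uint32_t fixed_length;  // 0 marks a boolean column (rule a)
};

std::vector<uint32_t> order_columns(const std::vector<ColInfo>& cols,
                                    uint32_t row_alignment) {
  auto key = [&](uint32_t i) {
    // Rule (a): boolean columns behave as if their fixed-length part were 1 byte.
    uint32_t len = std::max<uint32_t>(cols[i].fixed_length, 1);
    // Rule (b): power-of-two or alignment-multiple sizes go first,
    // larger fixed-length part first among them (~len sorts descending).
    bool preferred = (len & (len - 1)) == 0 || len % row_alignment == 0;
    // Rule (c): fixed-length column before varying-length one of the same size.
    return std::make_tuple(!preferred, ~len, !cols[i].is_fixed_length);
  };
  std::vector<uint32_t> order(cols.size());
  std::iota(order.begin(), order.end(), 0u);
  std::stable_sort(order.begin(), order.end(),
                   [&](uint32_t a, uint32_t b) { return key(a) < key(b); });
  return order;
}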
column_order.resize(num_cols); for (uint32_t i = 0; i < num_cols; ++i) { column_order[i] = i; diff --git a/cpp/src/arrow/compute/exec/key_encode.h b/cpp/src/arrow/compute/exec/key_encode.h index e5397b9dfd4..d4dd499a20d 100644 --- a/cpp/src/arrow/compute/exec/key_encode.h +++ b/cpp/src/arrow/compute/exec/key_encode.h @@ -372,6 +372,9 @@ class KeyEncoder { const KeyRowArray& rows, KeyColumnArray* col, KeyEncoderContext* ctx, KeyColumnArray* temp); static bool IsInteger(const KeyColumnMetadata& metadata); + static void ColumnMemsetNulls(uint32_t offset_within_row, KeyRowArray* rows, + const KeyColumnArray& col, KeyEncoderContext* ctx, + KeyColumnArray* temp_vector_16bit, uint8_t byte_value); private: template @@ -403,9 +406,6 @@ class KeyEncoder { uint32_t offset_within_row, const KeyRowArray& rows, KeyColumnArray* col); #endif - static void ColumnMemsetNulls(uint32_t offset_within_row, KeyRowArray* rows, - const KeyColumnArray& col, KeyEncoderContext* ctx, - KeyColumnArray* temp_vector_16bit, uint8_t byte_value); template static void ColumnMemsetNullsImp(uint32_t offset_within_row, KeyRowArray* rows, const KeyColumnArray& col, KeyEncoderContext* ctx, diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index ac47c04403c..89a9918af8b 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -34,22 +34,20 @@ namespace compute { constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; -// Search status bytes inside a block of 8 slots (64-bit word). -// Try to find a slot that contains a 7-bit stamp matching the one provided. -// There are three possible outcomes: -// 1. A matching slot is found. -// -> Return its index between 0 and 7 and set match found flag. -// 2. A matching slot is not found and there is an empty slot in the block. -// -> Return the index of the first empty slot and clear match found flag. -// 3. A matching slot is not found and there are no empty slots in the block. -// -> Return 8 as the output slot index and clear match found flag. +// Scan bytes in block in reverse and stop as soon +// as a position of interest is found. +// +// Positions of interest: +// a) slot with a matching stamp is encountered, +// b) first empty slot is encountered, +// c) we reach the end of the block. // // Optionally an index of the first slot to start the search from can be specified. // In this case slots before it will be ignored. // template inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, - int* out_slot, int* out_match_found) { + int* out_slot, int* out_match_found) const { // Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80. uint64_t block_high_bits = block & kHighBitOfEachByte; @@ -82,6 +80,11 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, matches &= kHighBitOfEachByte; } + // In case when there are no matches in slots and the block is full (no empty slots), + // pretend that there is a match in the last slot. + // + matches |= (~block_high_bits & 0x80); + // We get 0 if there are no matches *out_match_found = (matches == 0 ? 0 : 1); @@ -91,28 +94,16 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -// This call follows the call to search_block. 
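The contract that search_block implements after this change can be stated as a plain scalar loop. The sketch below is a hedged reference, not the SWAR implementation itself; it assumes the byte layout implied by the blockbase[7 - slot] indexing used later in this file (slot i's status byte is byte 7 - i of the little-endian block word), with 0x80 marking an empty slot and a 7-bit stamp marking a filled one.

#include <cstdint>

void search_block_reference(uint64_t block, int stamp, int start_slot,
                            int* out_slot, int* out_match_found) {
  for (int i = start_slot; i < 8; ++i) {
    uint8_t status = static_cast<uint8_t>(block >> (8 * (7 - i)));
    if (status == 0x80) {   // case b): first empty slot, stop without a match
      *out_slot = i;
      *out_match_found = 0;
      return;
    }
    if (status == stamp) {  // case a): stamp match (may still be a false positive)
      *out_slot = i;
      *out_match_found = 1;
      return;
    }
  }
  // case c): end of a full block without a match; report a "match" in the last slot
  // so that the caller keeps probing in the following block.
  *out_slot = 7;
  *out_match_found = 1;
}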
-// The input slot index is the output returned by it, which is a value from 0 to 8, -// with 8 indicating that both: no match was found and there were no empty slots. -// -// If the slot corresponds to a non-empty slot return a group id associated with it. -// Otherwise return any group id from any of the slots or -// zero, which is the default value stored in empty slots. -// inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask) { - // Input slot can be equal to 8, in which case we need to output any valid group id - // value, so we take the one from slot 0 in the block. - int clamped_slot = slot & 7; - + uint64_t group_id_mask) const { // Group id values for all 8 slots in the block are bit-packed and follow the status // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In // that case we can extract group id using aligned 64-bit word access. - int num_groupid_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); - ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || - num_groupid_bits == 32 || num_groupid_bits == 64); + int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); + ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || + num_group_id_bits == 32 || num_group_id_bits == 64); - int bit_offset = clamped_slot * num_groupid_bits; + int bit_offset = slot * num_group_id_bits; const uint64_t* group_id_bytes = reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; @@ -120,48 +111,160 @@ inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, return group_id; } -// Return global slot id (the index including the information about the block) -// where the search should continue if the first comparison fails. -// This function always follows search_block and receives the slot id returned by it. -// -inline uint64_t SwissTable::next_slot_to_visit(uint64_t block_index, int slot, - int match_found) { - // The result should be taken modulo the number of all slots in all blocks, - // but here we allow it to take a value one above the last slot index. - // Modulo operation is postponed to later. - return block_index * 8 + slot + match_found; +template +void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection, + const uint32_t* hashes, const uint8_t* local_slots, + uint32_t* out_group_ids, int element_offset, + int element_multiplier) const { + const T* elements = reinterpret_cast(blocks_) + element_offset; + if (log_blocks_ == 0) { + ARROW_DCHECK(sizeof(T) == sizeof(uint8_t)); + for (int i = 0; i < num_keys; ++i) { + uint32_t id = use_selection ? selection[i] : i; + uint32_t group_id = blocks_[8 + local_slots[id]]; + out_group_ids[id] = group_id; + } + } else { + for (int i = 0; i < num_keys; ++i) { + uint32_t id = use_selection ? selection[i] : i; + uint32_t hash = hashes[id]; + int64_t pos = + (hash >> (bits_hash_ - log_blocks_)) * element_multiplier + local_slots[id]; + uint32_t group_id = static_cast(elements[pos]); + ARROW_DCHECK(group_id < num_inserted_ || num_inserted_ == 0); + out_group_ids[id] = group_id; + } + } } -// Implements first (fast-path, optimistic) lookup. -// Searches for a match only within the start block and -// trying only the first slot with a matching stamp. -// -// Comparison callback needed for match verification is done outside of this function. 
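The block layout assumed throughout this file is 8 status bytes followed by 8 bit-packed group ids, so a block occupies 8 + num_group_id_bits bytes in total (8 fields of num_group_id_bits bits each). Below is a hedged scalar equivalent of extract_group_id, written with memcpy instead of the aligned cast and assuming a little-endian target; it is an illustration of the layout, not the exact Arrow code.

#include <cstdint>
#include <cstring>

uint64_t extract_group_id_reference(const uint8_t* block_ptr, int slot,
                                    int num_group_id_bits) {
  // num_group_id_bits is rounded up to 8, 16, 32 or 64, so a group id field
  // never straddles a 64-bit word boundary.
  const uint64_t mask =
      num_group_id_bits == 64 ? ~uint64_t{0} : (uint64_t{1} << num_group_id_bits) - 1;
  const int bit_offset = slot * num_group_id_bits;
  uint64_t word;
  // Group id fields start right after the 8 status bytes of the block.
  std::memcpy(&word, block_ptr + 8 + (bit_offset / 64) * 8, sizeof(word));
  return (word >> (bit_offset % 64)) & mask;
}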
-// Match bit vector filled by it only indicates finding a matching stamp in a slot. +void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_selection, + const uint32_t* hashes, const uint8_t* local_slots, + uint32_t* out_group_ids) const { + // Group id values for all 8 slots in the block are bit-packed and follow the status + // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In + // that case we can extract group id using aligned 64-bit word access. + int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); + ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || + num_group_id_bits == 32); + + // Optimistically use simplified lookup involving only a start block to find + // a single group id candidate for every input. +#if defined(ARROW_HAVE_AVX2) + int num_group_id_bytes = num_group_id_bits / 8; + if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) { + extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t), + 8 + 8 * num_group_id_bytes, num_group_id_bytes); + } else { +#endif + switch (num_group_id_bits) { + case 8: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 8, 16); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 8, 16); + } + break; + case 16: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 4, 12); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 4, 12); + } + break; + case 32: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 2, 10); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 2, 10); + } + break; + default: + ARROW_DCHECK(false); + } +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection, + const uint32_t* hashes, const uint8_t* local_slots, + const uint8_t* match_bitvector, + uint32_t* out_slot_ids) const { + ARROW_DCHECK(selection); + if (log_blocks_ == 0) { + for (int i = 0; i < num_keys; ++i) { + uint16_t id = selection[i]; + uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0; + uint32_t slot_id = local_slots[id] + match; + out_slot_ids[id] = slot_id; + } + } else { + for (int i = 0; i < num_keys; ++i) { + uint16_t id = selection[i]; + uint32_t hash = hashes[id]; + uint32_t iblock = (hash >> (bits_hash_ - log_blocks_)); + uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 
1 : 0; + uint32_t slot_id = iblock * 8 + local_slots[id] + match; + out_slot_ids[id] = slot_id; + } + } +} + +void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids, + const uint32_t* hashes, + uint32_t* slot_ids) const { + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + uint32_t num_block_bytes = num_groupid_bits + 8; + if (log_blocks_ == 0) { + uint64_t block = *reinterpret_cast(blocks_); + uint32_t empty_slot = static_cast(8 - ARROW_POPCOUNT64(block)); + for (uint32_t i = 0; i < num_ids; ++i) { + int id = ids[i]; + slot_ids[id] = empty_slot; + } + } else { + for (uint32_t i = 0; i < num_ids; ++i) { + int id = ids[i]; + uint32_t hash = hashes[id]; + uint32_t iblock = hash >> (bits_hash_ - log_blocks_); + uint64_t block; + for (;;) { + block = *reinterpret_cast(blocks_ + num_block_bytes * iblock); + block &= kHighBitOfEachByte; + if (block) { + break; + } + iblock = (iblock + 1) & ((1 << log_blocks_) - 1); + } + uint32_t empty_slot = static_cast(8 - ARROW_POPCOUNT64(block)); + slot_ids[id] = iblock * 8 + empty_slot; + } + } +} + +// Quickly filter out keys that have no matches based only on hash value and the +// corresponding starting 64-bit block of slot status bytes. May return false positives. // -template -void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, - const uint32_t* hashes, uint8_t* out_match_bitvector, - uint32_t* out_groupids, uint32_t* out_slot_ids) { +void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { // Clear the output bit vector memset(out_match_bitvector, 0, (num_keys + 7) / 8); // Based on the size of the table, prepare bit number constants. uint32_t stamp_mask = (1 << bits_stamp_) - 1; int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - uint32_t groupid_mask = (1 << num_groupid_bits) - 1; for (int i = 0; i < num_keys; ++i) { - int id; - if (use_selection) { - id = util::SafeLoad(&selection[i]); - } else { - id = i; - } - // Extract from hash: block index and stamp // - uint32_t hash = hashes[id]; + uint32_t hash = hashes[i]; uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_); uint32_t stamp = iblock & stamp_mask; iblock >>= bits_stamp_; @@ -169,22 +272,19 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, uint32_t num_block_bytes = num_groupid_bits + 8; const uint8_t* blockbase = reinterpret_cast(blocks_) + static_cast(iblock) * num_block_bytes; - uint64_t block = util::SafeLoadAs(blockbase); + ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); + uint64_t block = *reinterpret_cast(blockbase); // Call helper functions to obtain the output triplet: // - match (of a stamp) found flag - // - group id for key comparison - // - slot to resume search from in case of no match or false positive + // - number of slots to skip before resuming further search, in case of no match or + // false positive int match_found; int islot_in_block; search_block(block, stamp, 0, &islot_in_block, &match_found); - uint64_t groupid = extract_group_id(blockbase, islot_in_block, groupid_mask); - ARROW_DCHECK(groupid < num_inserted_ || num_inserted_ == 0); - uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found); - out_match_bitvector[id / 8] |= match_found << (id & 7); - util::SafeStore(&out_groupids[id], static_cast(groupid)); - util::SafeStore(&out_slot_ids[id], static_cast(islot)); + out_match_bitvector[i / 8] |= match_found << (i & 7); + 
out_local_slots[i] = static_cast(islot_in_block); } } @@ -203,18 +303,254 @@ uint64_t SwissTable::num_groups_for_resize() const { } } -uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) { +uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const { uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1; return global_slot_id & global_slot_id_mask; } -// Run a single round of slot search - comparison / insert - filter unprocessed. +void SwissTable::early_filter(const int num_keys, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { + // Optimistically use simplified lookup involving only a start block to find + // a single group id candidate for every input. +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { + if (log_blocks_ <= 4) { + int tail = num_keys % 32; + int delta = num_keys - tail; + early_filter_imp_avx2_x32(num_keys - tail, hashes, out_match_bitvector, + out_local_slots); + early_filter_imp_avx2_x8(tail, hashes + delta, out_match_bitvector + delta / 8, + out_local_slots + delta); + } else { + early_filter_imp_avx2_x8(num_keys, hashes, out_match_bitvector, out_local_slots); + } + } else { +#endif + early_filter_imp(num_keys, hashes, out_match_bitvector, out_local_slots); +#if defined(ARROW_HAVE_AVX2) + } +#endif +} + +// Input selection may be: +// - a range of all ids from 0 to num_keys - 1 +// - a selection vector with list of ids +// - a bit-vector marking ids that are included +// Either selection index vector or selection bit-vector must be provided +// but both cannot be set at the same time (one must be null). +// +// Input and output selection index vectors are allowed to point to the same buffer +// (in-place filtering of ids). +// +// Output selection vector needs to have enough space for num_keys entries. +// +void SwissTable::run_comparisons(const int num_keys, + const uint16_t* optional_selection_ids, + const uint8_t* optional_selection_bitvector, + const uint32_t* groupids, int* out_num_not_equal, + uint16_t* out_not_equal_selection) const { + ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector); + ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector); + + if (!optional_selection_ids && optional_selection_bitvector) { + // Count rows with matches (based on stamp comparison) + // and decide based on their percentage whether to call dense or sparse comparison + // function. Dense comparison means evaluating it for all inputs, even if the + // matching stamp was not found. It may be cheaper to evaluate comparison for all + // inputs if the extra cost of filtering is higher than the wasted processing of + // rows with no match. + // + // Dense comparison can only be used if there is at least one inserted key, + // because otherwise there is no key to compare to. 
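The early_filter_imp loop above derives both the starting block index and the 7-bit stamp from the top bits of the hash. A minimal sketch of that split follows; bits_hash_ = 32 and bits_stamp_ = 7 are assumptions inferred from the shifts and masks used in this file, and the helper name is illustrative.

#include <cstdint>

struct HashParts {
  uint32_t block_id;  // index of the starting block
  uint32_t stamp;     // 7-bit value stored in / compared against status bytes
};

// Valid for log_blocks <= kBitsHash - kBitsStamp.
HashParts split_hash(uint32_t hash, int log_blocks) {
  constexpr int kBitsHash = 32;   // assumed hash width (bits_hash_)
  constexpr int kBitsStamp = 7;   // assumed stamp width (bits_stamp_)
  uint32_t top_bits = hash >> (kBitsHash - kBitsStamp - log_blocks);
  return {top_bits >> kBitsStamp, top_bits & ((1u << kBitsStamp) - 1)};
}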
+ // + int64_t num_matches = arrow::internal::CountSetBits(optional_selection_bitvector, + /*offset=*/0, num_keys); + + if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { + uint32_t out_num; + equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection); + *out_num_not_equal = static_cast(out_num); + } else { + util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, + optional_selection_bitvector, out_num_not_equal, + out_not_equal_selection); + uint32_t out_num; + equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, + out_not_equal_selection); + *out_num_not_equal = static_cast(out_num); + } + } else { + uint32_t out_num; + equal_impl_(num_keys, optional_selection_ids, groupids, &out_num, + out_not_equal_selection); + *out_num_not_equal = static_cast(out_num); + } +} + +// Given starting slot index, search blocks for a matching stamp +// until one is found or an empty slot is reached. +// If the search stopped on a non-empty slot, output corresponding +// group id from that slot. +// +// Return true if a match was found. +// +bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, + uint32_t* out_slot_id, + uint32_t* out_group_id) const { + const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + constexpr uint64_t stamp_mask = 0x7f; + const int stamp = + static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); + uint64_t start_slot_id = wrap_global_slot_id(in_slot_id); + int match_found; + int local_slot; + uint8_t* blockbase; + for (;;) { + const uint64_t num_block_bytes = (8 + num_groupid_bits); + blockbase = blocks_ + num_block_bytes * (start_slot_id >> 3); + uint64_t block = *reinterpret_cast(blockbase); + + search_block(block, stamp, (start_slot_id & 7), &local_slot, &match_found); + + start_slot_id = + wrap_global_slot_id((start_slot_id & ~7ULL) + local_slot + match_found); + + // Match found can be 1 in two cases: + // - match was found + // - match was not found in a full block + // In the second case search needs to continue in the next block. + if (match_found == 0 || blockbase[7 - local_slot] == stamp) { + break; + } + } + + const uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1; + *out_group_id = + static_cast(extract_group_id(blockbase, local_slot, groupid_mask)); + *out_slot_id = static_cast(start_slot_id); + + return match_found; +} + +void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, + uint32_t group_id) { + const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + + // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. + // In that case we can insert group id value using aligned 64-bit word access. 
+ ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || + num_groupid_bits == 32 || num_groupid_bits == 64); + + const uint64_t num_block_bytes = (8 + num_groupid_bits); + constexpr uint64_t stamp_mask = 0x7f; + + int start_slot = (slot_id & 7); + int stamp = + static_cast((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask); + uint64_t block_id = slot_id >> 3; + uint8_t* blockbase = blocks_ + num_block_bytes * block_id; + + blockbase[7 - start_slot] = static_cast(stamp); + int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); + + // Block status bytes should start at an address aligned to 8 bytes + ARROW_DCHECK((reinterpret_cast(blockbase) & 7) == 0); + uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); + *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); + + hashes_[slot_id] = hash; +} + +// Find method is the continuation of processing from early_filter. +// Its input consists of hash values and the output of early_filter. +// It updates match bit-vector, clearing it from any false positives +// that might have been left by early_filter. +// It also outputs group ids, which are needed to be able to execute +// key comparisons. The caller may discard group ids if only the +// match flag is of interest. +// +void SwissTable::find(const int num_keys, const uint32_t* hashes, + uint8_t* inout_match_bitvector, const uint8_t* local_slots, + uint32_t* out_group_ids) const { + // Temporary selection vector. + // It will hold ids of keys for which we do not know yet + // if they have a match in hash table or not. + // + // Initially the set of these keys is represented by input + // match bit-vector. Eventually we switch from this bit-vector + // to array of ids. + // + ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); + auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + uint16_t* ids = ids_buf.mutable_data(); + int num_ids; + + int64_t num_matches = + arrow::internal::CountSetBits(inout_match_bitvector, /*offset=*/0, num_keys); + + // If there is a high density of selected input rows + // (majority of them are present in the selection), + // we may run some computation on all of the input rows ignoring + // selection and then filter the output of this computation + // (pre-filtering vs post-filtering). 
+ // + bool visit_all = num_matches > 0 && num_matches > 3 * num_keys / 4; + if (visit_all) { + extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids); + run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids, + ids); + } else { + util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector, + &num_ids, ids); + extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + } + + if (num_ids == 0) { + return; + } + + auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_ids); + uint32_t* slot_ids = slot_ids_buf.mutable_data(); + init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids); + + while (num_ids > 0) { + int num_ids_last_iteration = num_ids; + num_ids = 0; + for (int i = 0; i < num_ids_last_iteration; ++i) { + int id = ids[i]; + uint32_t next_slot_id; + bool match_found = find_next_stamp_match(hashes[id], slot_ids[id], &next_slot_id, + &(out_group_ids[id])); + slot_ids[id] = next_slot_id; + // If next match was not found then clear match bit in a bit vector + if (!match_found) { + ::arrow::BitUtil::ClearBit(inout_match_bitvector, id); + } else { + ids[num_ids++] = id; + } + } + + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + } +} // namespace compute + +// Slow processing of input keys in the most generic case. +// Handles inserting new keys. +// Pre-existing keys will be handled correctly, although the intended use is for this +// call to follow a call to find() method, which would only pass on new keys that were +// not present in the hash table. +// +// Run a single round of slot search - comparison or insert - filter unprocessed. // Update selection vector to reflect which items have been processed. // Ids in selection vector do not have to be sorted. 
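Taken together, the three public entry points introduced by this patch (early_filter, find, map_new_keys) replace the old single map() call with a pipeline that a caller can stop early when it only needs set-membership filtering. The sketch below shows one plausible way to drive them for a single minibatch; the buffer handling and error propagation around the calls are illustrative assumptions, not code from this patch (the real caller is the Grouper in hash_aggregate.cc).

#include <cstdint>
#include <vector>

#include "arrow/compute/exec/key_map.h"  // SwissTable from this patch
#include "arrow/status.h"

// `table` is an initialized SwissTable; `num_keys` must not exceed the table's
// minibatch limit; `hashes` holds one 32-bit hash per key.
arrow::Status LookupOrInsert(arrow::compute::SwissTable* table, int num_keys,
                             const uint32_t* hashes, uint32_t* out_group_ids) {
  std::vector<uint8_t> match_bitvector((num_keys + 7) / 8);
  std::vector<uint8_t> local_slots(num_keys);

  // Step 1: cheap filter based only on the starting block; may report false positives.
  table->early_filter(num_keys, hashes, match_bitvector.data(), local_slots.data());

  // Step 2: verify candidates and fill in group ids; false positives get their match
  // bit cleared. A caller doing pure set-membership filtering can stop here.
  table->find(num_keys, hashes, match_bitvector.data(), local_slots.data(),
              out_group_ids);

  // Step 3: whatever is still unmatched is a new key; insert it and assign a group id.
  std::vector<uint16_t> new_key_ids;
  for (int i = 0; i < num_keys; ++i) {
    bool matched = (match_bitvector[i / 8] >> (i & 7)) & 1;
    if (!matched) {
      new_key_ids.push_back(static_cast<uint16_t>(i));
    }
  }
  return table->map_new_keys(static_cast<uint32_t>(new_key_ids.size()),
                             new_key_ids.data(), hashes, out_group_ids);
}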
// -Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected, - uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, uint32_t* inout_next_slot_ids) { +Status SwissTable::map_new_keys_helper(const uint32_t* hashes, + uint32_t* inout_num_selected, + uint16_t* inout_selection, bool* out_need_resize, + uint32_t* out_group_ids, + uint32_t* inout_next_slot_ids) { auto num_groups_limit = num_groups_for_resize(); ARROW_DCHECK(num_inserted_ < num_groups_limit); @@ -223,198 +559,115 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected // ARROW_DCHECK(*inout_num_selected <= static_cast(1 << log_minibatch_)); - // We will split input row ids into three categories: - // - needing to visit next block [0] - // - needing comparison [1] - // - inserted [2] - // - auto ids_inserted_buf = - util::TempVectorHolder(temp_stack_, *inout_num_selected); - auto ids_for_comparison_buf = - util::TempVectorHolder(temp_stack_, *inout_num_selected); - constexpr int category_nomatch = 0; - constexpr int category_cmp = 1; - constexpr int category_inserted = 2; - int num_ids[3]; - num_ids[0] = num_ids[1] = num_ids[2] = 0; - uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(), - ids_inserted_buf.mutable_data()}; - auto push_id = [&num_ids, &ids](int category, int id) { - util::SafeStore(&ids[category][num_ids[category]++], static_cast(id)); - }; - - uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1; - constexpr uint64_t stamp_mask = 0x7f; - uint64_t num_block_bytes = (8 + num_groupid_bits); + size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t); + auto match_bitvector_buf = util::TempVectorHolder( + temp_stack_, static_cast(num_bytes_for_bits)); + uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); + memset(match_bitvector, 0xff, num_bytes_for_bits); + // Check the alignment of the input selection vector + ARROW_DCHECK((reinterpret_cast(inout_selection) & 1) == 0); + + uint32_t num_inserted_new = 0; uint32_t num_processed; - for (num_processed = 0; - // Second condition in for loop: - // We need to break processing and have the caller of this function - // resize hash table if we reach the limit of the number of groups present. 
- num_processed < *inout_num_selected && - num_inserted_ + num_ids[category_inserted] < num_groups_limit; - ++num_processed) { + for (num_processed = 0; num_processed < *inout_num_selected; ++num_processed) { // row id in original batch - int id = util::SafeLoad(&inout_selection[num_processed]); - - uint64_t slot_id = wrap_global_slot_id(util::SafeLoad(&inout_next_slot_ids[id])); - uint64_t block_id = slot_id >> 3; - uint32_t hash = hashes[id]; - uint8_t* blockbase = blocks_ + num_block_bytes * block_id; - uint64_t block = *reinterpret_cast(blockbase); - uint64_t stamp = (hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask; - int start_slot = (slot_id & 7); - - bool isempty = (blockbase[7 - start_slot] == 0x80); - if (isempty) { + int id = inout_selection[num_processed]; + bool match_found = + find_next_stamp_match(hashes[id], inout_next_slot_ids[id], + &inout_next_slot_ids[id], &out_group_ids[id]); + if (!match_found) { // If we reach the empty slot we insert key for new group + // + out_group_ids[id] = num_inserted_ + num_inserted_new; + insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]); + ::arrow::BitUtil::ClearBit(match_bitvector, num_processed); + ++num_inserted_new; - blockbase[7 - start_slot] = static_cast(stamp); - uint32_t group_id = num_inserted_ + num_ids[category_inserted]; - int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); - - // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. - // In that case we can insert group id value using aligned 64-bit word access. - ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 || - num_groupid_bits == 32 || num_groupid_bits == 64); - uint64_t* ptr = - &reinterpret_cast(blockbase + 8)[groupid_bit_offset >> 6]; - util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast(group_id) - << (groupid_bit_offset & 63))); - - hashes_[slot_id] = hash; - util::SafeStore(&out_group_ids[id], group_id); - push_id(category_inserted, id); - } else { - // We search for a slot with a matching stamp within a single block. - // We append row id to the appropriate sequence of ids based on - // whether the match has been found or not. - - int new_match_found; - int new_slot; - search_block(block, static_cast(stamp), start_slot, &new_slot, - &new_match_found); - auto new_groupid = - static_cast(extract_group_id(blockbase, new_slot, groupid_mask)); - ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]); - new_slot = - static_cast(next_slot_to_visit(block_id, new_slot, new_match_found)); - util::SafeStore(&inout_next_slot_ids[id], new_slot); - util::SafeStore(&out_group_ids[id], new_groupid); - push_id(new_match_found, id); + // We need to break processing and have the caller of this function + // resize hash table if we reach the limit of the number of groups present. 
+ // + if (num_inserted_ + num_inserted_new == num_groups_limit) { + ++num_processed; + break; + } } } + auto temp_ids_buffer = + util::TempVectorHolder(temp_stack_, *inout_num_selected); + uint16_t* temp_ids = temp_ids_buffer.mutable_data(); + int num_temp_ids = 0; + // Copy keys for newly inserted rows using callback - RETURN_NOT_OK(append_impl_(num_ids[category_inserted], ids[category_inserted])); - num_inserted_ += num_ids[category_inserted]; + // + util::BitUtil::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector, + inout_selection, &num_temp_ids, temp_ids); + ARROW_DCHECK(static_cast(num_inserted_new) == num_temp_ids); + RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids)); + num_inserted_ += num_inserted_new; // Evaluate comparisons and append ids of rows that failed it to the non-match set. - uint32_t num_not_equal; - equal_impl_(num_ids[category_cmp], ids[category_cmp], out_group_ids, &num_not_equal, - ids[category_nomatch] + num_ids[category_nomatch]); - num_ids[category_nomatch] += num_not_equal; + util::BitUtil::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector, + inout_selection, &num_temp_ids, temp_ids); + run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, + temp_ids); + memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); // Append ids of any unprocessed entries if we aborted processing due to the need // to resize. if (num_processed < *inout_num_selected) { - memmove(ids[category_nomatch] + num_ids[category_nomatch], - inout_selection + num_processed, + memmove(inout_selection + num_temp_ids, inout_selection + num_processed, sizeof(uint16_t) * (*inout_num_selected - num_processed)); - num_ids[category_nomatch] += (*inout_num_selected - num_processed); } + *inout_num_selected = num_temp_ids + (*inout_num_selected - num_processed); *out_need_resize = (num_inserted_ == num_groups_limit); - *inout_num_selected = num_ids[category_nomatch]; return Status::OK(); } -// Use hashes and callbacks to find group ids for already existing keys and -// to insert and report newly assigned group ids for new keys. +// Do inserts and find group ids for a set of new keys (with possible duplicates within +// this set). // -Status SwissTable::map(const int num_keys, const uint32_t* hashes, - uint32_t* out_groupids) { - // Temporary buffers have limited size. - // Caller is responsible for splitting larger input arrays into smaller chunks. - ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); - - // Allocate temporary buffers with a lifetime of this function - auto match_bitvector_buf = util::TempVectorHolder(temp_stack_, num_keys); - uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_keys); - uint32_t* slot_ids = slot_ids_buf.mutable_data(); - auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); - uint16_t* ids = ids_buf.mutable_data(); - uint32_t num_ids; +Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, + uint32_t* group_ids) { + if (num_ids == 0) { + return Status::OK(); + } - // First-pass processing. - // Optimistically use simplified lookup involving only a start block to find - // a single group id candidate for every input. 
-#if defined(ARROW_HAVE_AVX2) - if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { - if (log_blocks_ <= 4) { - int tail = num_keys % 32; - int delta = num_keys - tail; - lookup_1_avx2_x32(num_keys - tail, hashes, match_bitvector, out_groupids, slot_ids); - lookup_1_avx2_x8(tail, hashes + delta, match_bitvector + delta / 8, - out_groupids + delta, slot_ids + delta); - } else { - lookup_1_avx2_x8(num_keys, hashes, match_bitvector, out_groupids, slot_ids); - } - } else { -#endif - lookup_1(nullptr, num_keys, hashes, match_bitvector, out_groupids, slot_ids); -#if defined(ARROW_HAVE_AVX2) + uint16_t max_id = ids[0]; + for (uint32_t i = 1; i < num_ids; ++i) { + max_id = std::max(max_id, ids[i]); } -#endif - int64_t num_matches = - arrow::internal::CountSetBits(match_bitvector, /*offset=*/0, num_keys); + // Temporary buffers have limited size. + // Caller is responsible for splitting larger input arrays into smaller chunks. + ARROW_DCHECK(static_cast(num_ids) <= (1 << log_minibatch_)); + ARROW_DCHECK(static_cast(max_id + 1) <= (1 << log_minibatch_)); - // After the first-pass processing count rows with matches (based on stamp comparison) - // and decide based on their percentage whether to call dense or sparse comparison - // function. Dense comparison means evaluating it for all inputs, even if the matching - // stamp was not found. It may be cheaper to evaluate comparison for all inputs if the - // extra cost of filtering is higher than the wasted processing of rows with no match. - // - // Dense comparison can only be used if there is at least one inserted key, - // because otherwise there is no key to compare to. - // - if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { - // Dense comparisons - equal_impl_(num_keys, nullptr, out_groupids, &num_ids, ids); - } else { - // Sparse comparisons that involve filtering the input set of keys - auto ids_cmp_buf = util::TempVectorHolder(temp_stack_, num_keys); - uint16_t* ids_cmp = ids_cmp_buf.mutable_data(); - int num_ids_result; - util::BitUtil::bits_split_indexes(hardware_flags_, num_keys, match_bitvector, - &num_ids_result, ids, ids_cmp); - num_ids = num_ids_result; - uint32_t num_not_equal; - equal_impl_(num_keys - num_ids, ids_cmp, out_groupids, &num_not_equal, ids + num_ids); - num_ids += num_not_equal; - } + // Allocate temporary buffers for slot ids and intialize them + auto slot_ids_buf = util::TempVectorHolder(temp_stack_, max_id + 1); + uint32_t* slot_ids = slot_ids_buf.mutable_data(); + init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids); do { // A single round of slow-pass (robust) lookup or insert. - // A single round ends with either a single comparison verifying the match candidate - // or inserting a new key. A single round of slow-pass may return early if we reach - // the limit of the number of groups due to inserts of new keys. In that case we need - // to resize and recalculating starting global slot ids for new bigger hash table. + // A single round ends with either a single comparison verifying the match + // candidate or inserting a new key. A single round of slow-pass may return early + // if we reach the limit of the number of groups due to inserts of new keys. In + // that case we need to resize and recalculating starting global slot ids for new + // bigger hash table. 
bool out_of_capacity; - RETURN_NOT_OK( - lookup_2(hashes, &num_ids, ids, &out_of_capacity, out_groupids, slot_ids)); + RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids, + slot_ids)); if (out_of_capacity) { RETURN_NOT_OK(grow_double()); // Reset start slot ids for still unprocessed input keys. // for (uint32_t i = 0; i < num_ids; ++i) { // First slot in the new starting block - const int16_t id = util::SafeLoad(&ids[i]); - util::SafeStore(&slot_ids[id], (hashes[id] >> (bits_hash_ - log_blocks_)) * 8); + const int16_t id = ids[i]; + slot_ids[id] = (hashes[id] >> (bits_hash_ - log_blocks_)) * 8; } } } while (num_ids > 0); diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h index 8c472736ec4..286bbed038a 100644 --- a/cpp/src/arrow/compute/exec/key_map.h +++ b/cpp/src/arrow/compute/exec/key_map.h @@ -40,9 +40,17 @@ class SwissTable { Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack, int log_minibatch, EqualImpl equal_impl, AppendImpl append_impl); + void cleanup(); - Status map(const int ckeys, const uint32_t* hashes, uint32_t* outgroupids); + void early_filter(const int num_keys, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; + + void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, + const uint8_t* local_slots, uint32_t* out_group_ids) const; + + Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, + uint32_t* group_ids); private: // Lookup helpers @@ -55,6 +63,9 @@ class SwissTable { /// b) first empty slot is encountered, /// c) we reach the end of the block. /// + /// Optionally an index of the first slot to start the search from can be specified. + /// In this case slots before it will be ignored. + /// /// \param[in] block 8 byte block of hash table /// \param[in] stamp 7 bits of hash used as a stamp /// \param[in] start_slot Index of the first slot in the block to start search from. We @@ -63,55 +74,79 @@ class SwissTable { /// variant.) /// \param[out] out_slot index corresponding to the discovered position of interest (8 /// represents end of block). - /// \param[out] out_match_found an integer flag (0 or 1) indicating if we found a - /// matching stamp. + /// \param[out] out_match_found an integer flag (0 or 1) indicating if we reached an + /// empty slot (0) or not (1). Therefore 1 can mean that either actual match was found + /// (case a) above) or we reached the end of full block (case b) above). + /// template inline void search_block(uint64_t block, int stamp, int start_slot, int* out_slot, - int* out_match_found); + int* out_match_found) const; /// \brief Extract group id for a given slot in a given block. /// - /// Group ids follow in memory after 64-bit block data. - /// Maximum number of groups inserted is equal to the number - /// of all slots in all blocks, which is 8 * the number of blocks. - /// Group ids are bit packed using that maximum to determine the necessary number of - /// bits. 
inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask); + uint64_t group_id_mask) const; + + void extract_group_ids(const int num_keys, const uint16_t* optional_selection, + const uint32_t* hashes, const uint8_t* local_slots, + uint32_t* out_group_ids) const; - inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, int match_found); + template + void extract_group_ids_imp(const int num_keys, const uint16_t* selection, + const uint32_t* hashes, const uint8_t* local_slots, + uint32_t* out_group_ids, int elements_offset, + int element_mutltiplier) const; - inline void insert(uint8_t* block_base, uint64_t slot_id, uint32_t hash, uint8_t stamp, - uint32_t group_id); + inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, + int match_found) const; inline uint64_t num_groups_for_resize() const; - inline uint64_t wrap_global_slot_id(uint64_t global_slot_id); + inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const; - // First hash table access - // Find first match in the start block if exists. - // Possible cases: - // 1. Stamp match in a block - // 2. No stamp match in a block, no empty buckets in a block - // 3. No stamp match in a block, empty buckets in a block + void init_slot_ids(const int num_keys, const uint16_t* selection, + const uint32_t* hashes, const uint8_t* local_slots, + const uint8_t* match_bitvector, uint32_t* out_slot_ids) const; + + void init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids, + const uint32_t* hashes, uint32_t* slot_ids) const; + + // Quickly filter out keys that have no matches based only on hash value and the + // corresponding starting 64-bit block of slot status bytes. May return false positives. // - template - void lookup_1(const uint16_t* selection, const int num_keys, const uint32_t* hashes, - uint8_t* out_match_bitvector, uint32_t* out_group_ids, - uint32_t* out_slot_ids); + void early_filter_imp(const int num_keys, const uint32_t* hashes, + uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; #if defined(ARROW_HAVE_AVX2) - void lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, uint32_t* out_group_ids, - uint32_t* out_next_slot_ids); - void lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, uint32_t* out_group_ids, - uint32_t* out_next_slot_ids); + void early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const; + void early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const; + void extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, uint32_t* out_group_ids, + int byte_offset, int byte_multiplier, int byte_size) const; #endif - // Completing hash table lookup post first access - Status lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected, - uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, uint32_t* out_next_slot_ids); + void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, + const uint8_t* optional_selection_bitvector, + const uint32_t* groupids, int* out_num_not_equal, + uint16_t* out_not_equal_selection) const; + + inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, + uint32_t* out_slot_id, uint32_t* out_group_id) const; + + inline void insert_into_empty_slot(uint32_t 
slot_id, uint32_t hash, uint32_t group_id); + + // Slow processing of input keys in the most generic case. + // Handles inserting new keys. + // Pre-existing keys will be handled correctly, although the intended use is for this + // call to follow a call to find() method, which would only pass on new keys that were + // not present in the hash table. + // + Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected, + uint16_t* inout_selection, bool* out_need_resize, + uint32_t* out_group_ids, uint32_t* out_next_slot_ids); // Resize small hash tables when 50% full (up to 8KB). // Resize large hash tables when 75% full. diff --git a/cpp/src/arrow/compute/exec/key_map_avx2.cc b/cpp/src/arrow/compute/exec/key_map_avx2.cc index a2efb4d1bb9..2fca6bf6c10 100644 --- a/cpp/src/arrow/compute/exec/key_map_avx2.cc +++ b/cpp/src/arrow/compute/exec/key_map_avx2.cc @@ -36,14 +36,13 @@ namespace compute { // This is more or less translation of equivalent scalar code, adjusted for a different // instruction set (e.g. missing leading zero count instruction). // -void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, uint32_t* out_group_ids, - uint32_t* out_next_slot_ids) { +void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { // Number of inputs processed together in a loop constexpr int unroll = 8; const int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); - uint32_t group_id_mask = ~static_cast(0) >> (32 - num_group_id_bits); const __m256i* vhash_ptr = reinterpret_cast(hashes); const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1); @@ -85,6 +84,15 @@ void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, vstamp_B, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_B)); __m256i vmatches_A = _mm256_cmpeq_epi8(vblock_A, vstamp_A); __m256i vmatches_B = _mm256_cmpeq_epi8(vblock_B, vstamp_B); + + // In case when there are no matches in slots and the block is full (no empty slots), + // pretend that there is a match in the last slot. 
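The comment above states the growth policy; the sketch below is a hedged reconstruction of what num_groups_for_resize() might compute from it. The block-count cutoff standing in for "up to 8KB" is a guess, since the actual constant is not visible in this diff.

#include <cstdint>

// log_blocks: log2 of the number of 8-slot blocks in the table.
uint64_t NumGroupsForResize(int log_blocks) {
  const uint64_t num_slots = uint64_t{1} << (log_blocks + 3);
  const int kLogBlocksSmall = 9;  // assumed cutoff for "small" tables
  if (log_blocks <= kLogBlocksSmall) {
    return num_slots / 2;      // grow small tables at 50% occupancy
  }
  return num_slots * 3 / 4;    // grow large tables at 75% occupancy
}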
+ // + vmatches_A = _mm256_or_si256( + vmatches_A, _mm256_andnot_si256(vblock_highbits_A, _mm256_set1_epi64x(0xff))); + vmatches_B = _mm256_or_si256( + vmatches_B, _mm256_andnot_si256(vblock_highbits_B, _mm256_set1_epi64x(0xff))); + __m256i vmatch_found = _mm256_andnot_si256( _mm256_blend_epi32(_mm256_cmpeq_epi64(vmatches_A, _mm256_setzero_si256()), _mm256_cmpeq_epi64(vmatches_B, _mm256_setzero_si256()), @@ -106,46 +114,30 @@ void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes, // // Emulating lzcnt in lowest bytes of 32-bit elements __m256i vgt = _mm256_cmpgt_epi32(_mm256_set1_epi32(16), vmatches); - __m256i vnext_slot_id = + __m256i vlocal_slot = _mm256_blendv_epi8(_mm256_srli_epi32(vmatches, 4), _mm256_and_si256(vmatches, _mm256_set1_epi32(0x0f)), vgt); - vnext_slot_id = _mm256_shuffle_epi8( + vlocal_slot = _mm256_shuffle_epi8( _mm256_setr_epi8(4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0), - vnext_slot_id); - vnext_slot_id = - _mm256_add_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(0xff)), - _mm256_and_si256(vgt, _mm256_set1_epi32(4))); - - // Lookup group ids - // - __m256i vgroupid_bit_offset = - _mm256_mullo_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(7)), - _mm256_set1_epi32(num_group_id_bits)); - - // This only works for up to 25 bits per group id, since it uses 32-bit gather - // TODO: make sure this will never get called when there are more than 2^25 groups. - __m256i vgroupid = - _mm256_add_epi32(_mm256_srli_epi32(vgroupid_bit_offset, 3), - _mm256_add_epi32(vblock_offset, _mm256_set1_epi32(8))); - vgroupid = _mm256_i32gather_epi32(reinterpret_cast(blocks_), vgroupid, 1); - vgroupid = _mm256_srlv_epi32( - vgroupid, _mm256_and_si256(vgroupid_bit_offset, _mm256_set1_epi32(7))); - vgroupid = _mm256_and_si256(vgroupid, _mm256_set1_epi32(group_id_mask)); + vlocal_slot); + vlocal_slot = _mm256_add_epi32(_mm256_and_si256(vlocal_slot, _mm256_set1_epi32(0xff)), + _mm256_and_si256(vgt, _mm256_set1_epi32(4))); // Convert slot id relative to the block to slot id relative to the beginnning of the // table // - vnext_slot_id = _mm256_add_epi32( - _mm256_add_epi32(vnext_slot_id, - _mm256_and_si256(vmatch_found, _mm256_set1_epi32(1))), - _mm256_slli_epi32(vblock_id, 3)); + uint64_t local_slot = _mm256_extract_epi64( + _mm256_permutevar8x32_epi32( + _mm256_shuffle_epi8( + vlocal_slot, _mm256_setr_epi32(0x0c080400, 0, 0, 0, 0x0c080400, 0, 0, 0)), + _mm256_setr_epi32(0, 4, 0, 0, 0, 0, 0, 0)), + 0); + (reinterpret_cast(out_local_slots))[i] = local_slot; // Convert match found vector from 32-bit elements to bit vector out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found), 0x11111111); // 0b00010001 repeated 4x - _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, vgroupid); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + i, vnext_slot_id); } } @@ -214,9 +206,9 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256 // using a different method. // TODO: Explain the idea behind storing arrays in SIMD registers. // Explain why it is faster with SIMD than using memory loads. 
-void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, - uint8_t* out_match_bitvector, uint32_t* out_group_ids, - uint32_t* out_next_slot_ids) { +void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes, + uint8_t* out_match_bitvector, + uint8_t* out_local_slots) const { constexpr int unroll = 32; // There is a limit on the number of input blocks, @@ -227,8 +219,6 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, // table. We put them in the same order. __m256i vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4, vblock_byte5, vblock_byte6, vblock_byte7; - __m256i vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4, - vgroupid_byte5, vgroupid_byte6, vgroupid_byte7; // What we output if there is no match in the block __m256i vslot_empty_or_end; @@ -236,22 +226,16 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, constexpr uint32_t k4ByteSequence_1_5_9_13 = 0x0d090501; constexpr uint32_t k4ByteSequence_2_6_10_14 = 0x0e0a0602; constexpr uint32_t k4ByteSequence_3_7_11_15 = 0x0f0b0703; - constexpr uint64_t kEachByteIs1 = 0x0101010101010101ULL; constexpr uint64_t kByteSequence7DownTo0 = 0x0001020304050607ULL; constexpr uint64_t kByteSequence15DownTo8 = 0x08090A0B0C0D0E0FULL; // Bit unpack group ids into 1B. // Assemble the sequence of block bytes. uint64_t block_bytes[16]; - uint64_t groupid_bytes[16]; const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - uint64_t bit_unpack_mask = ((1 << num_groupid_bits) - 1) * kEachByteIs1; for (int i = 0; i < (1 << log_blocks_); ++i) { - uint64_t in_groupids = - *reinterpret_cast(blocks_ + (8 + num_groupid_bits) * i + 8); uint64_t in_blockbytes = *reinterpret_cast(blocks_ + (8 + num_groupid_bits) * i); - groupid_bytes[i] = _pdep_u64(in_groupids, bit_unpack_mask); block_bytes[i] = in_blockbytes; } @@ -275,18 +259,11 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, split_bytes_avx2(vblock_words0, vblock_words1, vblock_words2, vblock_words3, vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4, vblock_byte5, vblock_byte6, vblock_byte7); - split_bytes_avx2( - _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 0), - _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 1), - _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 2), - _mm256_loadu_si256(reinterpret_cast(groupid_bytes) + 3), - vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4, - vgroupid_byte5, vgroupid_byte6, vgroupid_byte7); // Calculate the slot to output when there is no match in a block. - // It will be the index of the first empty slot or 8 (the number of slots in block) + // It will be the index of the first empty slot or 7 (the number of slots in block) // if there are no empty slots. 
- vslot_empty_or_end = _mm256_set1_epi8(8); + vslot_empty_or_end = _mm256_set1_epi8(7); { __m256i vis_empty; #define CMP(VBLOCKBYTE, BYTENUM) \ @@ -304,6 +281,9 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, CMP(vblock_byte0, 0); #undef CMP } + __m256i vblock_is_full = _mm256_andnot_si256( + _mm256_cmpeq_epi8(vblock_byte7, _mm256_set1_epi8(static_cast(0x80))), + _mm256_set1_epi8(static_cast(0xff))); const int block_id_mask = (1 << log_blocks_) - 1; @@ -339,29 +319,28 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, __m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8)); // Visit all block bytes in reverse order (overwriting data on multiple matches) - __m256i vmatch_found = _mm256_setzero_si256(); + // + // Always set match found to true for full blocks. + // + __m256i vmatch_found = _mm256_shuffle_epi8(vblock_is_full, vblock_id); __m256i vslot_id = _mm256_shuffle_epi8(vslot_empty_or_end, vblock_id); - __m256i vgroup_id = _mm256_setzero_si256(); -#define CMP(VBLOCK_BYTE, VGROUPID_BYTE, BYTENUM) \ - { \ - __m256i vcmp = \ - _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \ - vmatch_found = _mm256_or_si256(vmatch_found, vcmp); \ - vgroup_id = _mm256_blendv_epi8(vgroup_id, \ - _mm256_shuffle_epi8(VGROUPID_BYTE, vblock_id), vcmp); \ - vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM + 1), vcmp); \ +#define CMP(VBLOCK_BYTE, BYTENUM) \ + { \ + __m256i vcmp = \ + _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \ + vmatch_found = _mm256_or_si256(vmatch_found, vcmp); \ + vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM), vcmp); \ } - CMP(vblock_byte7, vgroupid_byte7, 7); - CMP(vblock_byte6, vgroupid_byte6, 6); - CMP(vblock_byte5, vgroupid_byte5, 5); - CMP(vblock_byte4, vgroupid_byte4, 4); - CMP(vblock_byte3, vgroupid_byte3, 3); - CMP(vblock_byte2, vgroupid_byte2, 2); - CMP(vblock_byte1, vgroupid_byte1, 1); - CMP(vblock_byte0, vgroupid_byte0, 0); + CMP(vblock_byte7, 7); + CMP(vblock_byte6, 6); + CMP(vblock_byte5, 5); + CMP(vblock_byte4, 4); + CMP(vblock_byte3, 3); + CMP(vblock_byte2, 2); + CMP(vblock_byte1, 1); + CMP(vblock_byte0, 0); #undef CMP - vslot_id = _mm256_add_epi8(vslot_id, _mm256_slli_epi32(vblock_id, 3)); // So far the output is in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...] 
vmatch_found = _mm256_shuffle_epi8( vmatch_found, @@ -374,30 +353,58 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes, vmatch_found = _mm256_permutevar8x32_epi32(vmatch_found, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + // Repeat the same permutation for slot ids + vslot_id = _mm256_shuffle_epi8( + vslot_id, _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13, + k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15, + k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13, + k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15)); + vslot_id = + _mm256_permutevar8x32_epi32(vslot_id, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_local_slots) + i, vslot_id); + reinterpret_cast(out_match_bitvector)[i] = _mm256_movemask_epi8(vmatch_found); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 0, - _mm256_and_si256(vgroup_id, _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 1, - _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 8), _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 2, - _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 16), _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 3, - _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 24), _mm256_set1_epi32(0xff))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 0, - _mm256_and_si256(vslot_id, _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 1, - _mm256_and_si256(_mm256_srli_epi32(vslot_id, 8), _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 2, - _mm256_and_si256(_mm256_srli_epi32(vslot_id, 16), _mm256_set1_epi32(0xff))); - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 3, - _mm256_and_si256(_mm256_srli_epi32(vslot_id, 24), _mm256_set1_epi32(0xff))); + } +} + +void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, + const uint8_t* local_slots, + uint32_t* out_group_ids, int byte_offset, + int byte_multiplier, int byte_size) const { + ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4); + uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 
0xFFFF : 0xFFFFFFFF; + auto elements = reinterpret_cast(blocks_ + byte_offset); + constexpr int unroll = 8; + if (log_blocks_ == 0) { + ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16); + __m256i block_group_ids = + _mm256_set1_epi64x(reinterpret_cast(blocks_)[1]); + for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + __m256i local_slot = + _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); + __m256i group_id = _mm256_shuffle_epi8(block_group_ids, local_slot); + group_id = _mm256_shuffle_epi8( + group_id, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003, + 0x80808004, 0x80808005, 0x80808006, 0x80808007)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); + } + } else { + for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) { + __m256i hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); + __m256i local_slot = + _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); + local_slot = _mm256_shuffle_epi8( + local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003, + 0x80808004, 0x80808005, 0x80808006, 0x80808007)); + local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size)); + __m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_)); + pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier)); + pos = _mm256_add_epi32(pos, local_slot); + __m256i group_id = _mm256_i32gather_epi32(elements, pos, 1); + group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); + } } } diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index a44676c2f0d..2cab1d68fd8 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -51,7 +51,8 @@ inline void BitUtil::bits_filter_indexes_helper(uint64_t word, template void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, const uint8_t* bits, const uint16_t* input_indexes, - int* num_indexes, uint16_t* indexes) { + int* num_indexes, uint16_t* indexes, + uint16_t base_index) { // 64 bits at a time constexpr int unroll = 64; int tail = num_bits % unroll; @@ -61,7 +62,8 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit bits_filter_indexes_avx2(bit_to_search, num_bits - tail, bits, input_indexes, num_indexes, indexes); } else { - bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes); + bits_to_indexes_avx2(bit_to_search, num_bits - tail, bits, num_indexes, indexes, + base_index); } } else { #endif @@ -74,7 +76,7 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit if (filter_input_indexes) { bits_filter_indexes_helper(word, input_indexes + i * 64, num_indexes, indexes); } else { - bits_to_indexes_helper(word, i * 64, num_indexes, indexes); + bits_to_indexes_helper(word, i * 64 + base_index, num_indexes, indexes); } } #if defined(ARROW_HAVE_AVX2) @@ -92,41 +94,43 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit bits_filter_indexes_helper(word, input_indexes + num_bits - tail, num_indexes, indexes); } else { - bits_to_indexes_helper(word, num_bits - tail, num_indexes, indexes); + bits_to_indexes_helper(word, num_bits - tail + base_index, num_indexes, indexes); } } } -void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, - const int num_bits, const uint8_t* bits, int* num_indexes, - uint16_t* 
indexes, int bit_offset) { +void BitUtil::bits_to_indexes(int bit_to_search, int64_t hardware_flags, int num_bits, + const uint8_t* bits, int* num_indexes, uint16_t* indexes, + int bit_offset) { bits += bit_offset / 8; bit_offset %= 8; + *num_indexes = 0; + uint16_t base_index = 0; if (bit_offset != 0) { - int num_indexes_head = 0; uint64_t bits_head = util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte, - reinterpret_cast(&bits_head), &num_indexes_head, - indexes); - int num_indexes_tail = 0; - if (num_bits > bits_in_first_byte) { - bits_to_indexes(bit_to_search, hardware_flags, num_bits - bits_in_first_byte, - bits + 1, &num_indexes_tail, indexes + num_indexes_head); + reinterpret_cast(&bits_head), num_indexes, indexes); + if (num_bits <= bits_in_first_byte) { + return; } - *num_indexes = num_indexes_head + num_indexes_tail; - return; + num_bits -= bits_in_first_byte; + indexes += *num_indexes; + bits += 1; + base_index = bits_in_first_byte; } + int num_indexes_new = 0; if (bit_to_search == 0) { bits_to_indexes_internal<0, false>(hardware_flags, num_bits, bits, nullptr, - num_indexes, indexes); + &num_indexes_new, indexes, base_index); } else { ARROW_DCHECK(bit_to_search == 1); bits_to_indexes_internal<1, false>(hardware_flags, num_bits, bits, nullptr, - num_indexes, indexes); + &num_indexes_new, indexes, base_index); } + *num_indexes += num_indexes_new; } void BitUtil::bits_filter_indexes(int bit_to_search, int64_t hardware_flags, diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index d8248ceacab..6214c76b517 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -156,18 +156,20 @@ class BitUtil { template static void bits_to_indexes_internal(int64_t hardware_flags, const int num_bits, const uint8_t* bits, const uint16_t* input_indexes, - int* num_indexes, uint16_t* indexes); + int* num_indexes, uint16_t* indexes, + uint16_t base_index = 0); #if defined(ARROW_HAVE_AVX2) static void bits_to_indexes_avx2(int bit_to_search, const int num_bits, const uint8_t* bits, int* num_indexes, - uint16_t* indexes); + uint16_t* indexes, uint16_t base_index = 0); static void bits_filter_indexes_avx2(int bit_to_search, const int num_bits, const uint8_t* bits, const uint16_t* input_indexes, int* num_indexes, uint16_t* indexes); template static void bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits, - int* num_indexes, uint16_t* indexes); + int* num_indexes, uint16_t* indexes, + uint16_t base_index = 0); template static void bits_filter_indexes_imp_avx2(const int num_bits, const uint8_t* bits, const uint16_t* input_indexes, diff --git a/cpp/src/arrow/compute/exec/util_avx2.cc b/cpp/src/arrow/compute/exec/util_avx2.cc index 8cf0104db46..bdc0e41f576 100644 --- a/cpp/src/arrow/compute/exec/util_avx2.cc +++ b/cpp/src/arrow/compute/exec/util_avx2.cc @@ -27,18 +27,19 @@ namespace util { void BitUtil::bits_to_indexes_avx2(int bit_to_search, const int num_bits, const uint8_t* bits, int* num_indexes, - uint16_t* indexes) { + uint16_t* indexes, uint16_t base_index) { if (bit_to_search == 0) { - bits_to_indexes_imp_avx2<0>(num_bits, bits, num_indexes, indexes); + bits_to_indexes_imp_avx2<0>(num_bits, bits, num_indexes, indexes, base_index); } else { ARROW_DCHECK(bit_to_search == 1); - bits_to_indexes_imp_avx2<1>(num_bits, bits, num_indexes, indexes); + bits_to_indexes_imp_avx2<1>(num_bits, bits, num_indexes, 
indexes, base_index); } } template void BitUtil::bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits, - int* num_indexes, uint16_t* indexes) { + int* num_indexes, uint16_t* indexes, + uint16_t base_index) { // 64 bits at a time constexpr int unroll = 64; @@ -74,7 +75,7 @@ void BitUtil::bits_to_indexes_imp_avx2(const int num_bits, const uint8_t* bits, for (int j = 0; j < (num_indexes_loop + 15) / 16; ++j) { __m256i output = _mm256_cvtepi8_epi16( _mm_loadu_si128(reinterpret_cast(byte_indexes) + j)); - output = _mm256_add_epi16(output, _mm256_set1_epi16(i * 64)); + output = _mm256_add_epi16(output, _mm256_set1_epi16(i * 64 + base_index)); _mm256_storeu_si256(((__m256i*)(indexes + *num_indexes)) + j, output); } *num_indexes += num_indexes_loop; @@ -203,6 +204,9 @@ bool BitUtil::are_all_bytes_zero_avx2(const uint8_t* bytes, uint32_t num_bytes) __m256i x = _mm256_loadu_si256(reinterpret_cast(bytes) + i); result_or = _mm256_or_si256(result_or, x); } + result_or = _mm256_cmpeq_epi8(result_or, _mm256_set1_epi8(0)); + result_or = + _mm256_andnot_si256(result_or, _mm256_set1_epi8(static_cast(0xff))); uint32_t result_or32 = _mm256_movemask_epi8(result_or); if (num_bytes % 32 > 0) { uint64_t tail[4] = {0, 0, 0, 0}; diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 472ae956388..9d620242e10 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -606,9 +606,25 @@ struct GrouperFastImpl : Grouper { } // Map - RETURN_NOT_OK( - map_.map(batch_size_next, minibatch_hashes_.data(), - reinterpret_cast(group_ids->mutable_data()) + start_row)); + auto match_bitvector = + util::TempVectorHolder(&temp_stack_, (batch_size_next + 7) / 8); + { + auto local_slots = util::TempVectorHolder(&temp_stack_, batch_size_next); + map_.early_filter(batch_size_next, minibatch_hashes_.data(), + match_bitvector.mutable_data(), local_slots.mutable_data()); + map_.find(batch_size_next, minibatch_hashes_.data(), + match_bitvector.mutable_data(), local_slots.mutable_data(), + reinterpret_cast(group_ids->mutable_data()) + start_row); + } + auto ids = util::TempVectorHolder(&temp_stack_, batch_size_next); + int num_ids; + util::BitUtil::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next, + match_bitvector.mutable_data(), &num_ids, + ids.mutable_data()); + + RETURN_NOT_OK(map_.map_new_keys( + num_ids, ids.mutable_data(), minibatch_hashes_.data(), + reinterpret_cast(group_ids->mutable_data()) + start_row)); start_row += batch_size_next; From e33863c2e936bf7b94b66f93bd660055770f1c33 Mon Sep 17 00:00:00 2001 From: michalursa Date: Tue, 24 Aug 2021 00:06:23 -0700 Subject: [PATCH 2/3] Grouper filtering - code review requested changes --- cpp/src/arrow/compute/exec/key_encode.cc | 20 ++++--- cpp/src/arrow/compute/exec/key_map.cc | 66 ++++++++++++------------ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc index e71b81b4574..f517561bfaa 100644 --- a/cpp/src/arrow/compute/exec/key_encode.cc +++ b/cpp/src/arrow/compute/exec/key_encode.cc @@ -866,8 +866,8 @@ void KeyEncoder::EncoderBinaryPair::Encode(uint32_t offset_within_row, KeyRowArr } DCHECK(temp1->metadata().is_fixed_length); - DCHECK(temp1->length() * temp1->metadata().fixed_length >= - col1.length() * static_cast(sizeof(uint16_t))); + DCHECK_GE(temp1->length() * temp1->metadata().fixed_length, + col1.length() * 
static_cast(sizeof(uint16_t))); KeyColumnArray temp16bit(KeyColumnMetadata(true, sizeof(uint16_t)), col1.length(), nullptr, temp1->mutable_data(1), nullptr); @@ -1370,15 +1370,23 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( const auto num_cols = static_cast(cols.size()); // Sort columns. + // // Columns are sorted based on the size in bytes of their fixed-length part. // For the varying-length column, the fixed-length part is the 32-bit field storing // cumulative length of varying-length fields. + // // The rules are: + // // a) Boolean column, marked with fixed-length 0, is considered to have fixed-length - // part of 1 byte. b) Columns with fixed-length part being power of 2 or multiple of row - // alignment precede other columns. They are sorted among themselves based on size of - // fixed-length part decreasing. c) Fixed-length columns precede varying-length columns - // when both have the same size fixed-length part. + // part of 1 byte. + // + // b) Columns with fixed-length part being power of 2 or multiple of row + // alignment precede other columns. They are sorted in decreasing order of the size of + // their fixed-length part. + // + // c) Fixed-length columns precede varying-length columns when + // both have the same size fixed-length part. + // column_order.resize(num_cols); for (uint32_t i = 0; i < num_cols; ++i) { column_order[i] = i; diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index 89a9918af8b..43c011c016b 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -154,42 +154,40 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) { extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t), 8 + 8 * num_group_id_bytes, num_group_id_bytes); - } else { -#endif - switch (num_group_id_bits) { - case 8: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 8, 16); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 8, 16); - } - break; - case 16: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 4, 12); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 4, 12); - } - break; - case 32: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 2, 10); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 2, 10); - } - break; - default: - ARROW_DCHECK(false); - } -#if defined(ARROW_HAVE_AVX2) + return; } #endif + switch (num_group_id_bits) { + case 8: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 8, 16); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 8, 16); + } + break; + case 16: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 4, 12); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 4, 12); + } + break; + case 32: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 2, 10); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, 
local_slots, + out_group_ids, 2, 10); + } + break; + default: + ARROW_DCHECK(false); + } } void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection, From ea3ad2472b0237c68214abf4c94ba7c78b28ad64 Mon Sep 17 00:00:00 2001 From: michalursa Date: Tue, 31 Aug 2021 20:09:35 -0700 Subject: [PATCH 3/3] [WIP] Inner and outer hash join --- cpp/src/arrow/CMakeLists.txt | 4 + cpp/src/arrow/compute/exec/join/join.h | 119 ++ cpp/src/arrow/compute/exec/join/join_batch.cc | 1046 +++++++++++++++++ cpp/src/arrow/compute/exec/join/join_batch.h | 283 +++++ .../arrow/compute/exec/join/join_filter.cc | 82 ++ cpp/src/arrow/compute/exec/join/join_filter.h | 118 ++ .../compute/exec/join/join_filter_avx2.cc | 75 ++ .../arrow/compute/exec/join/join_hashtable.cc | 260 ++++ .../arrow/compute/exec/join/join_hashtable.h | 104 ++ cpp/src/arrow/compute/exec/join/join_probe.h | 129 ++ cpp/src/arrow/compute/exec/join/join_schema.h | 156 +++ cpp/src/arrow/compute/exec/join/join_side.h | 116 ++ cpp/src/arrow/compute/exec/join/join_type.h | 82 ++ cpp/src/arrow/compute/exec/key_encode.cc | 32 +- cpp/src/arrow/compute/exec/key_encode.h | 6 + cpp/src/arrow/compute/exec/key_map.cc | 182 ++- cpp/src/arrow/compute/exec/key_map.h | 45 +- cpp/src/arrow/compute/exec/util.h | 18 + .../arrow/compute/kernels/hash_aggregate.cc | 20 +- 19 files changed, 2750 insertions(+), 127 deletions(-) create mode 100644 cpp/src/arrow/compute/exec/join/join.h create mode 100644 cpp/src/arrow/compute/exec/join/join_batch.cc create mode 100644 cpp/src/arrow/compute/exec/join/join_batch.h create mode 100644 cpp/src/arrow/compute/exec/join/join_filter.cc create mode 100644 cpp/src/arrow/compute/exec/join/join_filter.h create mode 100644 cpp/src/arrow/compute/exec/join/join_filter_avx2.cc create mode 100644 cpp/src/arrow/compute/exec/join/join_hashtable.cc create mode 100644 cpp/src/arrow/compute/exec/join/join_hashtable.h create mode 100644 cpp/src/arrow/compute/exec/join/join_probe.h create mode 100644 cpp/src/arrow/compute/exec/join/join_schema.h create mode 100644 cpp/src/arrow/compute/exec/join/join_side.h create mode 100644 cpp/src/arrow/compute/exec/join/join_type.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index f13e5b1ef75..7dedeab436f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -410,6 +410,9 @@ if(ARROW_COMPUTE) compute/exec/key_map.cc compute/exec/key_compare.cc compute/exec/key_encode.cc + compute/exec/join/join_batch.cc + compute/exec/join/join_hashtable.cc + compute/exec/join/join_filter.cc compute/exec/util.cc) append_avx2_src(compute/kernels/aggregate_basic_avx2.cc) @@ -419,6 +422,7 @@ if(ARROW_COMPUTE) append_avx2_src(compute/exec/key_map_avx2.cc) append_avx2_src(compute/exec/key_compare_avx2.cc) append_avx2_src(compute/exec/key_encode_avx2.cc) + append_avx2_src(compute/exec/join/join_filter_avx2.cc) append_avx2_src(compute/exec/util_avx2.cc) list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc) diff --git a/cpp/src/arrow/compute/exec/join/join.h b/cpp/src/arrow/compute/exec/join/join.h new file mode 100644 index 00000000000..bbd7016b628 --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join.h @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include
+
+namespace arrow {
+namespace compute {
+
+/*
+ The basis for the future implementation of the main hash join interface - its exec node.
+ TODO: Implement missing ExecNode class
+*/
+
+/*
+class HashJoin {
+ enum class BatchSource { INPUT, SAVED, SAVED_FILTERED, HASH_TABLE };
+
+ void Make(Schema left_keys, Schema right_keys, Schema left_output, Schema right_output,
+ std::string output_field_name_prefix, Schema left_filter_input,
+ Schema right_filter_input,
+ // residual filter callback or template
+ );
+
+// Refer to state diagram
+void ProcessInputBatch(int side, BatchSource source,
+ std::shared_ptr>& batch) {
+ ARROW_DCHECK(side == 0 || side == 1);
+ int other_side = 1 - side;
+
+ auto state = input_state_mgr.get(side);
+ ARROW_DCHECK(state == HAS_MORE_ROWS);
+ auto other_state = input_state_mgr.get(other_side);
+ switch (other_state) {
+ case READING_INPUT:
+ case BUILDING_EARLY_FILTER:
+ input_data[side].AppendBatch(batch);
+ break;
+ case EARLY_FILTER_READY:
+ case BUILDING_HASH_TABLE:
+ auto join_key_batch =
+ ProjectBatch(batch, input_schema[side], join_key_schema[side]);
+ auto hash = ComputeHash(join_key_batch);
+ auto filter_result = EarlyFilter(hash, input_data[other_side].early_filter);
+ // TODO: Depending on the join type either output or remove rows with no match
+ input_data[side].AppendFiltered(batch, hash, filter_result);
+ break;
+ case HASH_TABLE_READY:
+ HashTableProbe(batch);
+ break;
+ default:
+ ARROW_DCHECK(false);
+ }
+}
+
+// List of tasks:
+// building early filter
+// building hash table
+// filtering using early filter
+// probing hash table using
+void ExecuteInternalTask() {
+ // Message processing for (;;) loop.
+ // Returns false when finished.
+ // Can be called multiple times after returning false and will consistently keep
+ // returning false. Can result in an empty call that does not return false if one of
+ // the join inputs is still streaming.
+}
+
+void SaveBatch() {}
+
+void FilterAndSaveBatch() {}
+
+void ProbeHashTable() {}
+
+void BuildEarlyFilter() {}
+
+void BuildHashTable() {}
+
+void ScanHashTable() {}
+
+void OnFinishedProcessingInputBatch(int side, BatchSource source) {}
+
+void OnFinishedReadingInput() {
+ // Check if this was the last batch
+ // State transition
+ // Build early filter
+}
+
+void OnSizeLimitReached() {}
+
+void OnFinishedBuildingEarlyFilter() {}
+
+void OnFinishedBuildingHashTable() {}
+
+void OnFinishedHashTableProbing() {}
+
+void OnFinishedHashTableScan() {}
+}
+;
+
+*/
+
+} // namespace compute
+} // namespace arrow
\ No newline at end of file
diff --git a/cpp/src/arrow/compute/exec/join/join_batch.cc b/cpp/src/arrow/compute/exec/join/join_batch.cc
new file mode 100644
index 00000000000..3c2d8321cd4
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/join/join_batch.cc
@@ -0,0 +1,1046 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/join/join_batch.h" + +#include + +#include + +#include "arrow/compute/exec/join/join_hashtable.h" +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/ubsan.h" + +namespace arrow { +namespace compute { + +ShuffleOutputDesc::ShuffleOutputDesc( + std::vector>& in_buffers, int64_t in_length, + bool in_has_nulls) { + offset = in_length; + for (int i = 0; i < 3; ++i) { + DCHECK(static_cast(in_buffers.size()) > i && in_buffers[i].get() != nullptr); + buffer[i] = in_buffers[i].get(); + } + has_nulls = in_has_nulls; +} + +Status ShuffleOutputDesc::ResizeBufferNonNull(int num_new_rows) { + int64_t new_num_rows = offset + num_new_rows; + int64_t old_size = BitUtil::BytesForBits(offset); + int64_t new_size = BitUtil::BytesForBits(new_num_rows); + RETURN_NOT_OK(buffer[0]->Resize(new_size, false)); + uint8_t* data = buffer[0]->mutable_data(); + if (!has_nulls) { + memset(data, 0xff, BitUtil::BytesForBits(offset)); + } + if (offset % 8 > 0) { + data[old_size - 1] |= static_cast(0xff << (offset % 8)); + } + memset(data + old_size, 0xff, new_size - old_size); + return Status::OK(); +} + +Status ShuffleOutputDesc::ResizeBufferFixedLen( + int num_new_rows, const KeyEncoder::KeyColumnMetadata& metadata) { + int64_t new_num_rows = offset + num_new_rows; + int64_t new_size = + metadata.is_fixed_length + ? (metadata.fixed_length == 0 ? 
BitUtil::BytesForBits(new_num_rows) + : new_num_rows * metadata.fixed_length) + : (new_num_rows + 1) * sizeof(uint32_t); + RETURN_NOT_OK(buffer[1]->Resize(new_size, false)); + if (offset == 0 && offset == 0) { + reinterpret_cast(buffer[1]->mutable_data())[0] = 0; + } + return Status::OK(); +} + +Status ShuffleOutputDesc::ResizeBufferVarLen(int num_new_rows) { + const uint32_t* offsets = reinterpret_cast(buffer[1]->mutable_data()); + int64_t new_num_rows = offset + num_new_rows; + constexpr int64_t extra_padding_for_data_move = sizeof(uint64_t); + RETURN_NOT_OK( + buffer[2]->Resize(offsets[new_num_rows] + extra_padding_for_data_move, false)); + return Status::OK(); +} + +BatchShuffle::ShuffleInputDesc::ShuffleInputDesc( + const uint8_t* non_null_buf, const uint8_t* fixed_len_buf, const uint8_t* var_len_buf, + int in_num_rows, int in_start_row, const uint16_t* in_opt_row_ids, + const KeyEncoder::KeyColumnMetadata& in_metadata) + : num_rows(in_num_rows), + opt_row_ids(in_opt_row_ids), + offset(in_start_row), + metadata(in_metadata.is_fixed_length, in_metadata.fixed_length) { + buffer[0] = non_null_buf; + buffer[1] = fixed_len_buf; + buffer[2] = var_len_buf; +} + +Status BatchShuffle::ShuffleNull(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls) { + uint8_t* dst = output.buffer[0]->mutable_data(); + const uint8_t* src = input.buffer[0]; + + *out_has_nulls = output.has_nulls; + if (!output.has_nulls && !src) { + return Status::OK(); + } + + auto temp_bytes_buf = + util::TempVectorHolder(ctx.temp_stack, ctx.minibatch_size); + uint8_t* temp_bytes = temp_bytes_buf.mutable_data(); + + bool output_buffer_resized = false; + + for (int start = 0; start < input.num_rows; start += ctx.minibatch_size) { + int mini_batch_size = std::min(input.num_rows - start, ctx.minibatch_size); + uint8_t byte_and = 0xff; + if (input.opt_row_ids) { + for (int i = 0; i < mini_batch_size; ++i) { + uint8_t next_byte = + BitUtil::GetBit(src, input.offset + input.opt_row_ids[start + i]) ? 0xFF + : 0x00; + temp_bytes[i] = next_byte; + byte_and &= next_byte; + } + } else { + util::BitUtil::bits_to_bytes(ctx.hardware_flags, mini_batch_size, src, temp_bytes, + static_cast(input.offset + start)); + } + if (byte_and == 0) { + *out_has_nulls = true; + if (!output_buffer_resized) { + RETURN_NOT_OK(output.ResizeBufferNonNull(input.num_rows)); + output_buffer_resized = true; + } + util::BitUtil::bytes_to_bits(ctx.hardware_flags, mini_batch_size, temp_bytes, dst, + static_cast(output.offset + start)); + } + } + + if (!output_buffer_resized && output.has_nulls) { + RETURN_NOT_OK(output.ResizeBufferNonNull(input.num_rows)); + output_buffer_resized = true; + } + + return Status::OK(); +} + +void BatchShuffle::ShuffleBit(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx) { + uint8_t* dst = output.buffer[1]->mutable_data(); + const uint8_t* src = input.buffer[1]; + + auto temp_bytes_buf = + util::TempVectorHolder(ctx.temp_stack, ctx.minibatch_size); + uint8_t* temp_bytes = temp_bytes_buf.mutable_data(); + + for (int start = 0; start < input.num_rows; start += ctx.minibatch_size) { + int mini_batch_size = std::min(input.num_rows - start, ctx.minibatch_size); + if (input.opt_row_ids) { + for (int i = 0; i < mini_batch_size; ++i) { + temp_bytes[i] = BitUtil::GetBit(src, input.offset + input.opt_row_ids[start + i]) + ? 
0xFF + : 0x00; + } + } else { + util::BitUtil::bits_to_bytes(ctx.hardware_flags, mini_batch_size, src, temp_bytes, + static_cast(input.offset + start)); + } + util::BitUtil::bytes_to_bits(ctx.hardware_flags, mini_batch_size, temp_bytes, dst, + static_cast(output.offset + start)); + } +} + +template +void BatchShuffle::ShuffleInteger(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + T* dst = reinterpret_cast(output.buffer[1]->mutable_data()); + const T* src = reinterpret_cast(input.buffer[1]); + dst += output.offset; + src += input.offset; + if (input.opt_row_ids) { + for (int i = 0; i < input.num_rows; ++i) { + dst[i] = src[input.opt_row_ids[i]]; + } + } else { + memcpy(dst, src, input.num_rows * sizeof(T)); + } +} + +void BatchShuffle::ShuffleBinary(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + uint8_t* dst = output.buffer[1]->mutable_data(); + const uint8_t* src = input.buffer[1]; + int binary_width = static_cast(input.metadata.fixed_length); + dst += binary_width * output.offset; + src += binary_width * input.offset; + if (input.opt_row_ids) { + if (binary_width % sizeof(uint64_t) == 0) { + uint64_t* dst64 = reinterpret_cast(dst); + const uint64_t* src64 = reinterpret_cast(src); + int num_words = binary_width / sizeof(uint64_t); + for (int i = 0; i < input.num_rows; ++i) { + for (int word = 0; word < num_words; ++word) { + dst64[i * num_words + word] = src64[input.opt_row_ids[i] * num_words + word]; + } + } + } else { + for (int i = 0; i < input.num_rows; ++i) { + memcpy(dst + i * binary_width, src + input.opt_row_ids[i] * binary_width, + binary_width); + } + } + } else { + memcpy(dst, src, binary_width * input.num_rows); + } +} + +void BatchShuffle::ShuffleOffset(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + uint32_t* dst = reinterpret_cast(output.buffer[1]->mutable_data()); + if (output.offset == 0) { + dst[0] = 0; + } + const uint32_t* src = reinterpret_cast(input.buffer[1]); + dst += output.offset; + src += input.offset; + + if (input.opt_row_ids) { + uint32_t last_dst_offset = dst[0]; + for (int i = 0; i < input.num_rows; ++i) { + int src_pos = input.opt_row_ids[i]; + last_dst_offset += src[src_pos + 1] - src[src_pos]; + dst[i + 1] = last_dst_offset; + } + } else { + int delta = dst[0] - src[0]; + for (int i = 0; i < input.num_rows; ++i) { + dst[i + 1] = static_cast(static_cast(src[i + 1]) + delta); + } + } +} + +void BatchShuffle::ShuffleVarBinary(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + uint8_t* dst = output.buffer[2]->mutable_data(); + const uint8_t* src = input.buffer[2]; + uint32_t* dst_offsets = reinterpret_cast(output.buffer[1]->mutable_data()); + const uint32_t* src_offsets = reinterpret_cast(input.buffer[1]); + dst_offsets += output.offset; + src_offsets += input.offset; + + if (input.opt_row_ids) { + for (int i = 0; i < input.num_rows; ++i) { + memcpy(dst + dst_offsets[i], src + src_offsets[input.opt_row_ids[i]], + dst_offsets[i + 1] - dst_offsets[i]); + } + } else { + memcpy(dst + dst_offsets[0], src + src_offsets[0], + dst_offsets[input.num_rows] - dst_offsets[0]); + } +} + +Status BatchShuffle::Shuffle(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls) { + if (input.num_rows == 0) { + return Status::OK(); + } + RETURN_NOT_OK(output.ResizeBufferFixedLen(input.num_rows, input.metadata)); + if (!input.metadata.is_fixed_length) { + ShuffleOffset(output, input); + RETURN_NOT_OK(output.ResizeBufferVarLen(input.num_rows)); + ShuffleVarBinary(output, 
input); + } else { + switch (input.metadata.fixed_length) { + case 0: + ShuffleBit(output, input, ctx); + break; + case 1: + ShuffleInteger(output, input); + break; + case 2: + ShuffleInteger(output, input); + break; + case 4: + ShuffleInteger(output, input); + break; + case 8: + ShuffleInteger(output, input); + break; + default: + ShuffleBinary(output, input); + break; + } + } + RETURN_NOT_OK(ShuffleNull(output, input, ctx, out_has_nulls)); + return Status::OK(); +} + +KeyRowArrayShuffle::ShuffleInputDesc::ShuffleInputDesc( + const KeyEncoder::KeyRowArray& in_rows, int in_column_id, int in_num_rows, + const key_id_type* in_row_ids) + : rows(&in_rows), + column_id(in_column_id), + num_rows(in_num_rows), + row_ids(in_row_ids) { + const KeyEncoder::KeyRowMetadata& row_metadata = rows->metadata(); + int column_id_after_reordering = -1; + for (uint32_t i = 0; i < row_metadata.num_cols(); ++i) { + if (row_metadata.encoded_field_order(i) == static_cast(column_id)) { + column_id_after_reordering = static_cast(i); + break; + } + } + DCHECK_GE(column_id_after_reordering, 0); + null_bit_id = column_id_after_reordering; + offset_within_row = row_metadata.encoded_field_offset(column_id_after_reordering); + metadata = row_metadata.column_metadatas[column_id]; + if (!metadata.is_fixed_length) { + int delta = static_cast(offset_within_row) - + static_cast(row_metadata.varbinary_end_array_offset); + DCHECK_GE(delta, 0); + DCHECK(delta % sizeof(uint32_t) == 0); + varbinary_id = delta / sizeof(uint32_t); + } else { + varbinary_id = -1; + } +} + +Status KeyRowArrayShuffle::ShuffleNull(ShuffleOutputDesc& output, + const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls) { + KeyEncoder::KeyEncoderContext encoder_ctx; + encoder_ctx.hardware_flags = ctx.hardware_flags; + encoder_ctx.stack = ctx.temp_stack; + bool input_has_nulls = input.rows->has_any_nulls(&encoder_ctx); + bool output_has_nulls = output.has_nulls; + *out_has_nulls = output.has_nulls; + if (!input_has_nulls && !output_has_nulls) { + return Status::OK(); + } + + // Allocate temporary buffers for mini batch of elements + // + auto temp_bytes_buf = + util::TempVectorHolder(ctx.temp_stack, ctx.minibatch_size); + uint8_t* temp_bytes = temp_bytes_buf.mutable_data(); + + // Prepare metadata + // + const uint8_t* null_masks = input.rows->null_masks(); + int null_masks_bytes_per_row = input.rows->metadata().null_masks_bytes_per_row; + int null_bit_id = input.null_bit_id; + + bool output_buffer_resized = false; + + // Split input into mini batches + // + for (int start = 0; start < input.num_rows; start += ctx.minibatch_size) { + int batch_size = std::min(input.num_rows - start, ctx.minibatch_size); + uint8_t byte_and = 0xff; + for (int i = 0; i < batch_size; ++i) { + int64_t row_id = input.row_ids[start + i]; + uint8_t next_byte = + BitUtil::GetBit(null_masks, row_id * null_masks_bytes_per_row * 8 + null_bit_id) + ? 
0 + : 0xff; + temp_bytes[i] = next_byte; + byte_and &= next_byte; + } + if (byte_and == 0) { + *out_has_nulls = true; + if (!output_buffer_resized) { + RETURN_NOT_OK(output.ResizeBufferNonNull(input.num_rows)); + output_buffer_resized = true; + } + util::BitUtil::bytes_to_bits(ctx.hardware_flags, batch_size, temp_bytes, + output.buffer[0]->mutable_data(), + static_cast(output.offset + start)); + } + } + + if (!output_buffer_resized && output.has_nulls) { + RETURN_NOT_OK(output.ResizeBufferNonNull(input.num_rows)); + output_buffer_resized = true; + } + + return Status::OK(); +} + +void KeyRowArrayShuffle::ShuffleBit(ShuffleOutputDesc& output, + const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx) { + auto metadata = input.rows->metadata(); + uint32_t offset_within_row = input.offset_within_row; + + auto temp_bytes_buf = + util::TempVectorHolder(ctx.temp_stack, ctx.minibatch_size); + uint8_t* temp_bytes = temp_bytes_buf.mutable_data(); + + // Split input into mini batches + // + for (int start = 0; start < input.num_rows; start += ctx.minibatch_size) { + int batch_size = std::min(input.num_rows - start, ctx.minibatch_size); + if (metadata.is_fixed_length) { + const uint8_t* src = input.rows->data(1) + offset_within_row; + for (int i = 0; i < batch_size; ++i) { + temp_bytes[i] = + *(src + input.row_ids[start + i] * metadata.fixed_length) == 0 ? 0 : 0xff; + } + } else { + const uint32_t* offsets = input.rows->offsets(); + const uint8_t* src = input.rows->data(2) + offset_within_row; + for (int i = 0; i < batch_size; ++i) { + temp_bytes[i] = *(src + offsets[input.row_ids[start + i]]) == 0 ? 0 : 0xff; + } + } + util::BitUtil::bytes_to_bits(ctx.hardware_flags, batch_size, temp_bytes, + output.buffer[1]->mutable_data(), + static_cast(output.offset + start)); + } +} + +template +void KeyRowArrayShuffle::ShuffleInteger(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + auto metadata = input.rows->metadata(); + uint32_t offset_within_row = input.offset_within_row; + + T* dst = reinterpret_cast(output.buffer[1]->mutable_data()); + if (metadata.is_fixed_length) { + const uint8_t* src = input.rows->data(1) + offset_within_row; + for (int i = 0; i < input.num_rows; ++i) { + dst[output.offset + i] = + *reinterpret_cast(src + input.row_ids[i] * metadata.fixed_length); + } + } else { + const uint32_t* offsets = input.rows->offsets(); + const uint8_t* src = input.rows->data(2) + offset_within_row; + for (int i = 0; i < input.num_rows; ++i) { + dst[output.offset + i] = + *reinterpret_cast(src + offsets[input.row_ids[i]]); + } + } +} + +void KeyRowArrayShuffle::ShuffleBinary(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + auto metadata = input.rows->metadata(); + auto column_metadata = input.metadata; + uint32_t offset_within_row = input.offset_within_row; + + uint8_t* dst = output.buffer[1]->mutable_data(); + const uint8_t* src = input.rows->data(1) + offset_within_row; + + if (column_metadata.fixed_length % sizeof(uint64_t) == 0) { + int num_words = column_metadata.fixed_length / sizeof(uint64_t); + if (metadata.is_fixed_length) { + for (int i = 0; i < input.num_rows; ++i) { + for (int word = 0; word < num_words; ++word) { + reinterpret_cast(dst)[(output.offset + i) * num_words + word] = + reinterpret_cast(src + input.row_ids[i] * + metadata.fixed_length)[word]; + } + } + } else { + const uint32_t* offsets = input.rows->offsets(); + for (int i = 0; i < input.num_rows; ++i) { + for (int word = 0; word < num_words; ++word) { + reinterpret_cast(dst)[(output.offset + i) * 
num_words + word] = + reinterpret_cast(src + offsets[input.row_ids[i]])[word]; + } + } + } + } else { + if (metadata.is_fixed_length) { + for (int i = 0; i < input.num_rows; ++i) { + memcpy(dst + (output.offset + i) * column_metadata.fixed_length, + src + input.row_ids[i] * metadata.fixed_length, + column_metadata.fixed_length); + } + } else { + const uint32_t* offsets = input.rows->offsets(); + for (int i = 0; i < input.num_rows; ++i) { + memcpy(dst + (output.offset + i) * column_metadata.fixed_length, + src + offsets[input.row_ids[i]], column_metadata.fixed_length); + } + } + } +} + +void KeyRowArrayShuffle::ShuffleOffset(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + auto metadata = input.rows->metadata(); + int varbinary_id = input.varbinary_id; + uint32_t offset_within_row = input.offset_within_row; + + uint32_t* dst = reinterpret_cast(output.buffer[1]->mutable_data()); + if (output.offset == 0) { + dst[0] = 0; + } + const uint8_t* src = input.rows->data(1) + offset_within_row; + const uint32_t* offsets = input.rows->offsets(); + + if (varbinary_id == 0) { + int prev_value = dst[output.offset]; + for (int i = 0; i < input.num_rows; ++i) { + prev_value += *reinterpret_cast(src + offsets[input.row_ids[i]]); + dst[output.offset + i + 1] = prev_value; + } + } else { + int prev_value = dst[output.offset]; + for (int i = 0; i < input.num_rows; ++i) { + const uint32_t* varbinary_end = + reinterpret_cast(src + offsets[input.row_ids[i]]); + prev_value += varbinary_end[0] - varbinary_end[-1]; + dst[output.offset + i + 1] = prev_value; + } + } +} + +void KeyRowArrayShuffle::ShuffleVarBinary(ShuffleOutputDesc& output, + const ShuffleInputDesc& input) { + auto metadata = input.rows->metadata(); + int varbinary_id = input.varbinary_id; + + uint8_t* dst = output.buffer[2]->mutable_data(); + const uint32_t* dst_offsets = + reinterpret_cast(output.buffer[1]->data()); + const uint8_t* src = input.rows->data(2); + const uint32_t* src_offsets = input.rows->offsets(); + + if (varbinary_id == 0) { + for (int i = 0; i < input.num_rows; ++i) { + const uint8_t* src_row = src + src_offsets[input.row_ids[i]]; + uint32_t offset_within_row; + uint32_t length; + metadata.first_varbinary_offset_and_length(src_row, &offset_within_row, &length); + src_row += offset_within_row; + int64_t num_words = BitUtil::CeilDiv(length, sizeof(uint64_t)); + uint8_t* dst_row = dst + dst_offsets[output.offset + i]; + for (int64_t word = 0; word < num_words; ++word) { + util::SafeStore(dst_row + word * sizeof(uint64_t), + reinterpret_cast(src_row)[word]); + } + } + } else { + for (int i = 0; i < input.num_rows; ++i) { + const uint8_t* src_row = src + src_offsets[input.row_ids[i]]; + uint32_t offset_within_row; + uint32_t length; + metadata.nth_varbinary_offset_and_length(src_row, varbinary_id, &offset_within_row, + &length); + src_row += offset_within_row; + int64_t num_words = BitUtil::CeilDiv(length, sizeof(uint64_t)); + uint8_t* dst_row = dst + dst_offsets[output.offset + i]; + for (int64_t word = 0; word < num_words; ++word) { + util::SafeStore(dst_row + word * sizeof(uint64_t), + reinterpret_cast(src_row)[word]); + } + } + } +} + +Status KeyRowArrayShuffle::Shuffle(ShuffleOutputDesc& output, + const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls) { + if (input.num_rows == 0) { + return Status::OK(); + } + RETURN_NOT_OK(output.ResizeBufferFixedLen(input.num_rows, input.metadata)); + if (!input.metadata.is_fixed_length) { + ShuffleOffset(output, input); + 
RETURN_NOT_OK(output.ResizeBufferVarLen(input.num_rows)); + ShuffleVarBinary(output, input); + } else { + switch (input.metadata.fixed_length) { + case 0: + ShuffleBit(output, input, ctx); + break; + case 1: + ShuffleInteger(output, input); + break; + case 2: + ShuffleInteger(output, input); + break; + case 4: + ShuffleInteger(output, input); + break; + case 8: + ShuffleInteger(output, input); + break; + default: + ShuffleBinary(output, input); + break; + } + } + RETURN_NOT_OK(ShuffleNull(output, input, ctx, out_has_nulls)); + return Status::OK(); +} + +BatchWithJoinData::BatchWithJoinData(int in_join_side, const ExecBatch& in_batch) { + join_side = in_join_side; + batch = in_batch; + hashes = nullptr; +} + +Status BatchWithJoinData::ComputeHashIfMissing(MemoryPool* pool, + JoinColumnMapper* schema_mgr, + JoinHashTable_ThreadLocal* locals) { + if (!hashes) { + ARROW_ASSIGN_OR_RAISE( + hashes, AllocateResizableBuffer(sizeof(hash_type) * batch.length, pool)); + hash_type* hash_values = reinterpret_cast(hashes->mutable_data()); + + // Encode + RETURN_NOT_OK(Encode( + 0, batch.length, locals->key_encoder, locals->keys_minibatch, schema_mgr, + join_side == 0 ? JoinSchemaHandle::FIRST_INPUT : JoinSchemaHandle::SECOND_INPUT, + join_side == 0 ? JoinSchemaHandle::FIRST_KEY : JoinSchemaHandle::SECOND_KEY)); + + if (locals->key_encoder.row_metadata().is_fixed_length) { + ::arrow::compute::Hashing::hash_fixed( + locals->encoder_ctx.hardware_flags, static_cast(batch.length), + locals->key_encoder.row_metadata().fixed_length, locals->keys_minibatch.data(1), + hash_values); + } else { + auto hash_temp_buf = ::arrow::util::TempVectorHolder( + &locals->stack, 4 * locals->minibatch_size); + + for (int64_t start = 0; start < batch.length; start += locals->minibatch_size) { + int next_batch_size = + std::min(static_cast(batch.length - start), locals->minibatch_size); + + ::arrow::compute::Hashing::hash_varlen( + locals->encoder_ctx.hardware_flags, next_batch_size, + locals->keys_minibatch.offsets() + start, locals->keys_minibatch.data(2), + hash_temp_buf.mutable_data(), hash_values + start); + } + } + } + return Status::OK(); +} + +Status BatchWithJoinData::Encode(int64_t start_row, int64_t num_rows, KeyEncoder& encoder, + KeyEncoder::KeyRowArray& rows, + JoinColumnMapper* schema_mgr, + JoinSchemaHandle batch_schema, + JoinSchemaHandle output_schema) const { + int num_output_cols = schema_mgr->num_cols(output_schema); + const int* col_map = schema_mgr->map(output_schema, batch_schema); + std::vector temp_cols(num_output_cols); + + for (int output_col = 0; output_col < num_output_cols; ++output_col) { + int input_col = col_map[output_col]; + KeyEncoder::KeyColumnMetadata col_metadata = + schema_mgr->data_type(output_schema, output_col); + const uint8_t* non_nulls = nullptr; + if (batch[input_col].array()->buffers[0] != NULLPTR) { + non_nulls = batch[input_col].array()->buffers[0]->data(); + } + const uint8_t* fixedlen = batch[input_col].array()->buffers[1]->data(); + const uint8_t* varlen = nullptr; + if (!col_metadata.is_fixed_length) { + varlen = batch[input_col].array()->buffers[2]->data(); + } + int64_t offset = batch[input_col].array()->offset; + auto col_base = arrow::compute::KeyEncoder::KeyColumnArray( + col_metadata, offset + start_row + num_rows, non_nulls, fixedlen, varlen); + temp_cols[output_col] = arrow::compute::KeyEncoder::KeyColumnArray( + col_base, offset + start_row, num_rows); + } + + rows.Clean(); + RETURN_NOT_OK(encoder.PrepareOutputForEncode(0, num_rows, &rows, temp_cols)); + 
encoder.Encode(0, num_rows, &rows, temp_cols); + + return Status::OK(); +} + +Status BatchAccumulation::Init(JoinSchemaHandle schema, JoinColumnMapper* schema_mgr, + int64_t max_batch_size, MemoryPool* pool) { + pool_ = pool; + schema_ = schema; + max_batch_size_ = max_batch_size; + schema_mgr_ = schema_mgr; + + RETURN_NOT_OK(AllocateEmptyBuffers()); + return Status::OK(); +} + +ShuffleOutputDesc BatchAccumulation::GetColumn(int column_id) { + return ShuffleOutputDesc(output_buffers_[column_id], output_length_, + output_buffer_has_nulls_[column_id]); +} + +Result> BatchAccumulation::MakeBatch() { + std::unique_ptr out = ::arrow::internal::make_unique(); + out->length = output_length_; + int num_columns = schema_mgr_->num_cols(schema_); + out->values.resize(num_columns); + for (int i = 0; i < num_columns; ++i) { + auto field = schema_mgr_->field(schema_, i); + int null_count = 0; + if (output_buffer_has_nulls_[i]) { + auto valid_count = arrow::internal::CountSetBits( + output_buffers_[i][0]->data(), /*offset=*/0, static_cast(output_length_)); + null_count = static_cast(output_length_) - static_cast(valid_count); + } + + if (field->data_type.is_fixed_length) { + if (null_count > 0) { + out->values[i] = ArrayData::Make( + field->full_data_type, output_length_, + {std::move(output_buffers_[i][0]), std::move(output_buffers_[i][1])}, + null_count); + } else { + out->values[i] = ArrayData::Make(field->full_data_type, output_length_, + {nullptr, std::move(output_buffers_[i][1])}, 0); + } + } else { + if (null_count > 0) { + out->values[i] = ArrayData::Make( + field->full_data_type, output_length_, + {std::move(output_buffers_[i][0]), std::move(output_buffers_[i][1]), + std::move(output_buffers_[i][2])}, + null_count); + } else { + out->values[i] = ArrayData::Make( + field->full_data_type, output_length_, + {nullptr, std::move(output_buffers_[i][1]), std::move(output_buffers_[i][2])}, + 0); + } + } + } + + RETURN_NOT_OK(AllocateEmptyBuffers()); + + return out; +} + +Result> BatchAccumulation::MakeBatchWithJoinData() { + std::unique_ptr out = + ::arrow::internal::make_unique(); + std::unique_ptr out_batch; + if (has_hashes_) { + out->hashes = std::move(hashes_); + } else { + out->hashes = nullptr; + } + ARROW_ASSIGN_OR_RAISE(out_batch, MakeBatch()); + out->batch = *(out_batch.release()); + return out; +} + +Status BatchAccumulation::AllocateEmptyBuffers() { + output_length_ = 0; + int num_columns = schema_mgr_->num_cols(schema_); + output_buffers_.resize(num_columns); + output_buffer_has_nulls_.resize(num_columns); + for (int i = 0; i < num_columns; ++i) { + output_buffers_[i].resize(3); + output_buffer_has_nulls_[i] = false; + ARROW_ASSIGN_OR_RAISE(output_buffers_[i][0], AllocateResizableBuffer(0, pool_)); + ARROW_ASSIGN_OR_RAISE(output_buffers_[i][1], AllocateResizableBuffer(0, pool_)); + ARROW_ASSIGN_OR_RAISE(output_buffers_[i][2], AllocateResizableBuffer(0, pool_)); + } + if (!hashes_) { + ARROW_ASSIGN_OR_RAISE(hashes_, AllocateResizableBuffer(0, pool_)); + } + has_hashes_ = false; + return Status::OK(); +} + +Status BatchJoinAssembler::Init(MemoryPool* pool, int64_t max_batch_size, + JoinColumnMapper* schema_mgr) { + schema_mgr_ = schema_mgr; + RETURN_NOT_OK( + output_buffers_.Init(JoinSchemaHandle::OUTPUT, schema_mgr, max_batch_size, pool)); + bound_batch_side_ = 0; + bound_hash_table_side_ = 0; + bound_batch_ = nullptr; + bound_keys_ = nullptr; + bound_payload_ = nullptr; + return Status::OK(); +} + +void BatchJoinAssembler::BindSourceBatch(int side, const ExecBatch* batch) { + 
bound_batch_side_ = side; + bound_batch_ = batch; +} + +void BatchJoinAssembler::BindSourceHashTable(int side, + const KeyEncoder::KeyRowArray* keys, + const KeyEncoder::KeyRowArray* payload) { + bound_hash_table_side_ = side; + bound_keys_ = keys; + bound_payload_ = payload; +} + +Result> BatchJoinAssembler::Push( + Shuffle_ThreadLocal& ctx, int num_rows, int hash_table_side, bool is_batch_present, + bool is_hash_table_present, int batch_start_row, const uint16_t* opt_batch_row_ids, + const key_id_type* opt_key_ids, const key_id_type* opt_payload_ids, + int* out_num_rows_processed) { + int num_rows_clamped = + std::min(num_rows, static_cast(output_buffers_.space_left())); + *out_num_rows_processed = num_rows_clamped; + + // For each output batch column find the input column which could be either in an + // input batch or in a hash table. + // + int num_output_columns = schema_mgr_->num_cols(JoinSchemaHandle::OUTPUT); + for (int i = 0; i < num_output_columns; ++i) { + // 0 - input batch, 1 - hash table key, 2 - hash table payload + int source, column_id; + column_id = schema_mgr_->map(JoinSchemaHandle::OUTPUT, + hash_table_side == 0 ? JoinSchemaHandle::FIRST_KEY + : JoinSchemaHandle::SECOND_KEY)[i]; + if (column_id != schema_mgr_->kMissingField) { + source = 1; + } else { + column_id = + schema_mgr_->map(JoinSchemaHandle::OUTPUT, + hash_table_side == 0 ? JoinSchemaHandle::FIRST_PAYLOAD + : JoinSchemaHandle::SECOND_PAYLOAD)[i]; + if (column_id != schema_mgr_->kMissingField) { + source = 2; + } else { + column_id = + schema_mgr_->map(JoinSchemaHandle::OUTPUT, + hash_table_side == 0 ? JoinSchemaHandle::SECOND_INPUT + : JoinSchemaHandle::FIRST_INPUT)[i]; + DCHECK_NE(column_id, schema_mgr_->kMissingField); + source = 0; + } + } + // Switch key from hash table to input batch if input batch is present. + // + if (source == 1 && is_batch_present) { + source = 0; + if (hash_table_side == 0) { + column_id = schema_mgr_->map(JoinSchemaHandle::SECOND_KEY, + JoinSchemaHandle::SECOND_INPUT)[column_id]; + } else { + column_id = schema_mgr_->map(JoinSchemaHandle::FIRST_KEY, + JoinSchemaHandle::FIRST_INPUT)[column_id]; + } + } + // Construct output descriptor + // + ShuffleOutputDesc output = output_buffers_.GetColumn(i); + KeyEncoder::KeyColumnMetadata column_metadata = + schema_mgr_->data_type(JoinSchemaHandle::OUTPUT, i); + + // Find whether the input source is missing, which + // means outputting nulls, or present, which means copying values. + // + if ((source == 0 && !is_batch_present) || (source == 1 && !is_hash_table_present) || + (source == 2 && !is_hash_table_present)) { + RETURN_NOT_OK(AppendNulls(output, column_metadata, num_rows)); + } else { + // Construct input descriptor and perform appropriate shuffle + // + if (source == 0) { + DCHECK(is_batch_present); + DCHECK(bound_batch_); + DCHECK(bound_batch_side_ == 1 - hash_table_side); + auto array = bound_batch_->values[column_id].array(); + BatchShuffle::ShuffleInputDesc input( + array->buffers[0]->data(), array->buffers[1]->data(), + array->buffers[2]->data(), num_rows, batch_start_row, opt_batch_row_ids, + column_metadata); + bool out_has_nulls; + RETURN_NOT_OK(BatchShuffle::Shuffle(output, input, ctx, &out_has_nulls)); + if (out_has_nulls) { + output_buffers_.SetHasNulls(i); + } + } else { + DCHECK(is_hash_table_present); + DCHECK((source == 1 && bound_keys_) || (source == 2 && bound_payload_)); + DCHECK(bound_hash_table_side_ == hash_table_side); + const KeyEncoder::KeyRowArray& rows = + source == 1 ? 
*bound_keys_ : *bound_payload_; + KeyRowArrayShuffle::ShuffleInputDesc input( + rows, column_id, num_rows, source == 1 ? opt_key_ids : opt_payload_ids); + bool out_has_nulls; + RETURN_NOT_OK(KeyRowArrayShuffle::Shuffle(output, input, ctx, &out_has_nulls)); + if (out_has_nulls) { + output_buffers_.SetHasNulls(i); + } + } + } + } + + output_buffers_.IncreaseLength(num_rows_clamped); + + if (output_buffers_.space_left() == 0) { + return output_buffers_.MakeBatch(); + } + return Status::OK(); +} + +Result> BatchJoinAssembler::Flush() { + if (!output_buffers_.is_empty()) { + return output_buffers_.MakeBatch(); + } else { + return std::unique_ptr(); + } +} + +Status BatchJoinAssembler::AppendNulls(ShuffleOutputDesc& output, + const KeyEncoder::KeyColumnMetadata& metadata, + int num_rows) { + if (num_rows == 0) { + return Status::OK(); + } + RETURN_NOT_OK(output.ResizeBufferFixedLen(num_rows, metadata)); + if (!metadata.is_fixed_length) { + uint32_t* offsets = reinterpret_cast(output.buffer[1]->mutable_data()); + offsets += output.offset; + uint32_t value = offsets[0]; + for (int i = 0; i < num_rows; ++i) { + offsets[i + 1] = value; + } + } + RETURN_NOT_OK(output.ResizeBufferNonNull(num_rows)); + uint8_t* non_nulls = output.buffer[0]->mutable_data(); + int64_t old_size = BitUtil::BytesForBits(output.offset); + if (output.offset % 8 > 0) { + non_nulls[old_size - 1] &= static_cast((1 << (output.offset % 8)) - 1); + } + memset(non_nulls + old_size, 0, BitUtil::BytesForBits(output.offset + num_rows)); + return Status::OK(); +} + +void BatchEarlyFilterEval::Init(int join_side, MemoryPool* pool, + JoinColumnMapper* schema_mgr) { + side_ = join_side; + pool_ = pool; + schema_mgr_ = schema_mgr; +} + +void BatchEarlyFilterEval::SetFilter(bool is_other_side_empty, + const std::vector& null_field_means_no_match, + const ApproximateMembershipTest* hash_based_filter) { + is_other_side_empty_ = is_other_side_empty; + null_field_means_no_match_.resize(null_field_means_no_match.size()); + for (size_t i = 0; i < null_field_means_no_match.size(); ++i) { + null_field_means_no_match_[i] = null_field_means_no_match[i]; + } + hash_based_filter_ = hash_based_filter; +} + +Status BatchEarlyFilterEval::EvalFilter(BatchWithJoinData& batch, int64_t start_row, + int64_t num_rows, uint8_t* filter_bit_vector, + JoinHashTable_ThreadLocal* locals) { + if (is_other_side_empty_) { + memset(filter_bit_vector, 0, BitUtil::BytesForBits(num_rows)); + return Status::OK(); + } + DCHECK(hash_based_filter_); + RETURN_NOT_OK(batch.ComputeHashIfMissing(pool_, schema_mgr_, locals)); + + auto byte_vector_buf = + util::TempVectorHolder(&locals->stack, locals->minibatch_size); + auto byte_vector = byte_vector_buf.mutable_data(); + + for (int64_t batch_start = 0; batch_start < num_rows; ++batch_start) { + int next_batch_size = + std::min(static_cast(num_rows - batch_start), locals->minibatch_size); + hash_based_filter_->MayHaveHash( + locals->encoder_ctx.hardware_flags, next_batch_size, + reinterpret_cast(batch.hashes->data()) + start_row + + batch_start, + byte_vector); + util::BitUtil::bytes_to_bits(locals->encoder_ctx.hardware_flags, next_batch_size, + byte_vector, filter_bit_vector, + static_cast(batch_start)); + } + + EvalNullFilter(batch.batch, start_row, num_rows, filter_bit_vector, locals); + + return Status::OK(); +} + +Result> BatchEarlyFilterEval::FilterBatch( + const BatchWithJoinData& batch, int64_t start, int64_t num_rows, + const uint8_t* filter_bit_vector, int* out_num_rows_processed) { + // TODO: Not implemented yet + return 
Status::OK(); +} + +Result> BatchEarlyFilterEval::Flush() { + if (!output_buffers_.is_empty()) { + return output_buffers_.MakeBatch(); + } else { + return std::unique_ptr(); + } +} + +void BatchEarlyFilterEval::EvalNullFilter(const ExecBatch& batch, int64_t start_row, + int64_t num_rows, uint8_t* bit_vector_to_update, + JoinHashTable_ThreadLocal* locals) { + JoinSchemaHandle batch_schema = + side_ == 0 ? JoinSchemaHandle::FIRST_INPUT : JoinSchemaHandle::SECOND_INPUT; + JoinSchemaHandle key_schema = + side_ == 0 ? JoinSchemaHandle::FIRST_KEY : JoinSchemaHandle::SECOND_KEY; + int num_key_columns = schema_mgr_->num_cols(key_schema); + const int* column_map = schema_mgr_->map(key_schema, batch_schema); + for (int column_id = 0; column_id < num_key_columns; ++column_id) { + if (!null_field_means_no_match_[column_id]) { + continue; + } + int batch_column_id = column_map[column_id]; + const uint8_t* non_nulls = nullptr; + if (batch[batch_column_id].array()->buffers[0] != NULLPTR) { + non_nulls = batch[batch_column_id].array()->buffers[0]->data(); + } + if (!non_nulls) { + continue; + } + + int64_t offset = batch[batch_column_id].array()->offset + start_row; + + auto ids_buf = + util::TempVectorHolder(&locals->stack, locals->minibatch_size); + auto ids = ids_buf.mutable_data(); + int num_ids; + + for (int64_t start_row_minibatch = 0; start_row_minibatch < num_rows; + start_row_minibatch += locals->minibatch_size) { + int next_batch_size = std::min(static_cast(num_rows - start_row_minibatch), + locals->minibatch_size); + util::BitUtil::bits_to_indexes(0, locals->encoder_ctx.hardware_flags, + next_batch_size, non_nulls, &num_ids, ids, + static_cast(offset + start_row_minibatch)); + for (int i = 0; i < num_ids; ++i) { + BitUtil::ClearBit(bit_vector_to_update, start_row_minibatch + ids[i]); + } + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/join/join_batch.h b/cpp/src/arrow/compute/exec/join/join_batch.h new file mode 100644 index 00000000000..9ea07a5a4b6 --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_batch.h @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/join/join_filter.h" +#include "arrow/compute/exec/join/join_schema.h" +#include "arrow/compute/exec/join/join_type.h" +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/key_hash.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" + +/* + This file implements operations on exec batches related to hash join processing, such +as: +- moving selected rows and columns from an input batch or a hash table to an output batch +- hash join related projections (hash value) +- hash join related filtering (early filter) +- accumulating filter results and join outputs +*/ + +namespace arrow { +namespace compute { + +// Local context that is provided by a thread when executing batch or KeyRowArray +// shuffle operations. +// +struct Shuffle_ThreadLocal { + Shuffle_ThreadLocal(int64_t in_hardware_flags, util::TempVectorStack* in_temp_stack, + int in_minibatch_size) + : hardware_flags(in_hardware_flags), + temp_stack(in_temp_stack), + minibatch_size(in_minibatch_size) {} + int64_t hardware_flags; + // For simple and fast allocation of temporary vectors + util::TempVectorStack* temp_stack; + // Size of a batch to use with temp_stack allocations, related to + // the total size of memory owned by it + int minibatch_size; +}; + +// Description of output buffers to use for a single column in a shuffle operation. +// +struct ShuffleOutputDesc { + ShuffleOutputDesc(std::vector>& in_buffers, + int64_t in_length, bool in_has_nulls); + Status ResizeBufferNonNull(int num_new_rows); + Status ResizeBufferFixedLen(int num_new_rows, + const KeyEncoder::KeyColumnMetadata& metadata); + Status ResizeBufferVarLen(int num_new_rows); + + ResizableBuffer* buffer[3]; + int64_t offset; + bool has_nulls; +}; + +// Write to output buffers a sequence of input batch fields from a single column +// specified by a sequence of row ids. +// The mapping from output rows to input rows can be +// arbitrary, and does not have to be monotonic or injective. +// Includes special handling for a single contiguous range of row ids. +// +class BatchShuffle { + public: + struct ShuffleInputDesc { + ShuffleInputDesc(const uint8_t* non_null_buf, const uint8_t* fixed_len_buf, + const uint8_t* var_len_buf, int in_num_rows, int in_start_row, + const uint16_t* in_opt_row_ids, + const KeyEncoder::KeyColumnMetadata& in_metadata); + const uint8_t* buffer[3]; + int num_rows; + const uint16_t* opt_row_ids; + int64_t offset; + KeyEncoder::KeyColumnMetadata metadata; + }; + static Status Shuffle(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls); + + private: + static Status ShuffleNull(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls); + static void ShuffleBit(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx); + template + static void ShuffleInteger(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleBinary(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleOffset(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleVarBinary(ShuffleOutputDesc& output, const ShuffleInputDesc& input); +}; + +// Write to output buffers fields from a single column of rows encoded in KeyRowArray +// according to specified sequence of row ids. 
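+// A rough usage sketch of this row-id based gather (the variable names out_buffers,
+// rows, row_ids and ctx below are illustrative assumptions, not part of this
+// interface):
+//
+//   ShuffleOutputDesc output(out_buffers, /*length=*/0, /*has_nulls=*/false);
+//   KeyRowArrayShuffle::ShuffleInputDesc input(rows, /*column_id=*/0, num_rows, row_ids);
+//   bool out_has_nulls;
+//   RETURN_NOT_OK(KeyRowArrayShuffle::Shuffle(output, input, ctx, &out_has_nulls));
+//
+// Row ids may repeat and need not be sorted; the same id appearing twice simply copies
+// that input row into two output rows.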
+// +class KeyRowArrayShuffle { + public: + struct ShuffleInputDesc { + ShuffleInputDesc(const KeyEncoder::KeyRowArray& in_rows, int in_column_id, + int in_num_rows, const key_id_type* in_row_ids); + const KeyEncoder::KeyRowArray* rows; + int column_id; + int num_rows; + const key_id_type* row_ids; + // Precomputed info for accessing this column's data inside encoded rows + // + KeyEncoder::KeyColumnMetadata metadata; + int null_bit_id; + int varbinary_id; + uint32_t offset_within_row; + }; + + static Status Shuffle(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls); + + private: + static Status ShuffleNull(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx, bool* out_has_nulls); + static void ShuffleBit(ShuffleOutputDesc& output, const ShuffleInputDesc& input, + Shuffle_ThreadLocal& ctx); + template + static void ShuffleInteger(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleBinary(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleOffset(ShuffleOutputDesc& output, const ShuffleInputDesc& input); + static void ShuffleVarBinary(ShuffleOutputDesc& output, const ShuffleInputDesc& input); +}; + +struct JoinHashTable_ThreadLocal; + +// Wrapper around ExecBatch that carries computed hash with it after it is evaluated on +// demand for the first time. +// +struct BatchWithJoinData { + BatchWithJoinData() = default; + BatchWithJoinData(int in_join_side, const ExecBatch& in_batch); + + Status ComputeHashIfMissing(MemoryPool* pool, JoinColumnMapper* schema_mgr, + JoinHashTable_ThreadLocal* locals); + Status Encode(int64_t start_row, int64_t num_rows, KeyEncoder& encoder, + KeyEncoder::KeyRowArray& rows, JoinColumnMapper* schema_mgr, + JoinSchemaHandle batch_schema, JoinSchemaHandle output_schema) const; + + int join_side; + ExecBatch batch; + std::shared_ptr hashes; +}; + +// Handles accumulation of rows in output buffers up to the provided batch size before +// producing an exec batch. +// +class BatchAccumulation { + public: + Status Init(JoinSchemaHandle schema, JoinColumnMapper* schema_mgr, + int64_t max_batch_size, MemoryPool* pool); + ShuffleOutputDesc GetColumn(int column_id); + ResizableBuffer* GetHashes() { return hashes_.get(); } + Result> MakeBatch(); + Result> MakeBatchWithJoinData(); + + void SetHasNulls(int column_id) { output_buffer_has_nulls_[column_id] = true; } + void SetHasHashes() { has_hashes_ = true; } + bool is_empty() const { return output_length_ == 0; } + int64_t space_left() const { return max_batch_size_ - output_length_; } + int64_t length() const { return output_length_; } + void IncreaseLength(int64_t delta) { output_length_ += delta; } + + private: + Status AllocateEmptyBuffers(); + + MemoryPool* pool_; + JoinSchemaHandle schema_; + int64_t max_batch_size_; + JoinColumnMapper* schema_mgr_; + std::vector>> output_buffers_; + std::vector output_buffer_has_nulls_; + bool has_hashes_; + std::shared_ptr hashes_; + int64_t output_length_; +}; + +// Assembles output exec batches based on the inputs from an exec batch on one side of the +// join and a hash table on the other side of the join. +// Rows of output batch represent a combination of two rows, one from each side of the +// join, specified by given row ids. Missing row ids are used in outer joins and mean that +// nulls will be used in place of a row columns from one of the sides. 
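+// As a concrete illustration (all ids below are hypothetical, not produced by any real
+// lookup): with the probe batch bound on side 0 and the hash table on side 1, a Push
+// call given batch row ids {0, 0, 2} and key ids {5, 7, 1} appends three output rows
+// that pair batch row 0 with hash table keys 5 and 7, and batch row 2 with key 1. When
+// one side is marked as not present, the columns coming from that side are filled with
+// nulls in the appended rows, which is how outer join rows without a match are produced.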
+// +class BatchJoinAssembler { + public: + Status Init(MemoryPool* pool, int64_t max_batch_size, JoinColumnMapper* schema_mgr); + void BindSourceBatch(int side, const ExecBatch* batch); + void BindSourceHashTable(int side, const KeyEncoder::KeyRowArray* keys, + const KeyEncoder::KeyRowArray* payload); + + // Returns null as output batch if the resulting number of rows in accumulation buffers + // is less than max batch size. + // Missing batch row ids when batch is present mean a + // sequence of consecutive row ids. + Result> Push(Shuffle_ThreadLocal& ctx, int num_rows, + int hash_table_side, bool is_batch_present, + bool is_hash_table_present, int batch_start_row, + const uint16_t* opt_batch_row_ids, + const key_id_type* opt_key_ids, + const key_id_type* opt_payload_ids, + int* out_num_rows_processed); + + Result> Flush(); + + private: + Status AppendNulls(ShuffleOutputDesc& output, + const KeyEncoder::KeyColumnMetadata& metadata, int num_rows); + + int bound_batch_side_; + int bound_hash_table_side_; + const ExecBatch* bound_batch_; + const KeyEncoder::KeyRowArray* bound_keys_; + const KeyEncoder::KeyRowArray* bound_payload_; + JoinColumnMapper* schema_mgr_; + BatchAccumulation output_buffers_; +}; + +// Evaluation of early filter (a cheap hash-based filter that allows for false positives +// but not false negatives). Also takes care of filtering rows that do not have a match in +// case of a) empty hash table, b) when nulls appear in key columns and null is not equal +// to null. +class BatchEarlyFilterEval { + void Init(int join_side, MemoryPool* pool, JoinColumnMapper* schema_mgr); + void SetFilter(bool is_other_side_empty, + const std::vector& null_field_means_no_match, + const ApproximateMembershipTest* hash_based_filter); + Status EvalFilter(BatchWithJoinData& batch, int64_t start_row, int64_t num_rows, + uint8_t* filter_bit_vector, JoinHashTable_ThreadLocal* locals); + Result> FilterBatch(const BatchWithJoinData& batch, + int64_t start, int64_t num_rows, + const uint8_t* filter_bit_vector, + int* out_num_rows_processed); + Result> Flush(); + + private: + void EvalNullFilter(const ExecBatch& batch, int64_t start_row, int64_t num_rows, + uint8_t* bit_vector_to_update, JoinHashTable_ThreadLocal* locals); + int side_; + MemoryPool* pool_; + JoinColumnMapper* schema_mgr_; + + BatchAccumulation output_buffers_; + + // Filter information + // + bool is_other_side_empty_; + std::vector null_field_means_no_match_; + const ApproximateMembershipTest* hash_based_filter_; +}; + +// Instances of classes that due to accumulation of rows are not +// thread-safe and therefore need a copy per thread. +// +class JoinBatch_ThreadLocal { + // Output assembler is shared by both sides of the join + BatchJoinAssembler assembler; + // One filter for each side of the join + BatchEarlyFilterEval early_filter[2]; +}; + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_filter.cc b/cpp/src/arrow/compute/exec/join/join_filter.cc new file mode 100644 index 00000000000..052096856bd --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_filter.cc @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/join/join_filter.h" + +#include "arrow/compute/exec/util.h" + +namespace arrow { +namespace compute { + +constexpr int ApproximateMembershipTest::BitMasksGenerator::bit_width_; +constexpr int ApproximateMembershipTest::BitMasksGenerator::min_bits_set_; +constexpr int ApproximateMembershipTest::BitMasksGenerator::max_bits_set_; +constexpr int ApproximateMembershipTest::BitMasksGenerator::log_num_masks_; +constexpr int ApproximateMembershipTest::BitMasksGenerator::num_masks_; +constexpr int ApproximateMembershipTest::BitMasksGenerator::num_masks_less_one_; + +ApproximateMembershipTest::BitMasksGenerator::BitMasksGenerator() { + memset(masks_, 0, (num_masks_ + 7) / 8 + sizeof(uint64_t)); + util::Random64Bit rnd; + int num_bits_set = rnd.from_range(min_bits_set_, max_bits_set_); + for (int i = 0; i < num_bits_set; ++i) { + for (;;) { + int bit_pos = rnd.from_range(0, bit_width_ - 1); + if (!BitUtil::GetBit(masks_, bit_pos)) { + BitUtil::SetBit(masks_, bit_pos); + break; + } + } + } + for (int next_bit = bit_width_; next_bit < num_masks_ + 64; ++next_bit) { + if (BitUtil::GetBit(masks_, next_bit - bit_width_) && num_bits_set == min_bits_set_) { + // Next bit has to be 1 + BitUtil::SetBit(masks_, next_bit); + } else if (!BitUtil::GetBit(masks_, next_bit - bit_width_) && + num_bits_set == max_bits_set_) { + // Next bit has to be 0 + } else { + // Next bit can be random + if ((rnd.next() % 2) == 0) { + BitUtil::SetBit(masks_, next_bit); + ++num_bits_set; + } + if (BitUtil::GetBit(masks_, next_bit - bit_width_)) { + --num_bits_set; + } + } + } +} + +void ApproximateMembershipTest::MayHaveHash(int64_t hardware_flags, int64_t num_rows, + const uint32_t* hashes, + uint8_t* result) const { +#if defined(ARROW_HAVE_AVX2) + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { + MayHaveHash_avx2(num_rows, hashes, result); + return; + } +#endif + for (int64_t i = 0; i < num_rows; ++i) { + result[i] = MayHaveHash(hashes[i]) ? 0xFF : 0; + } +} + +ApproximateMembershipTest::BitMasksGenerator ApproximateMembershipTest::bit_masks_; + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_filter.h b/cpp/src/arrow/compute/exec/join/join_filter.h new file mode 100644 index 00000000000..c367e7fe7bc --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_filter.h @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <vector>
+
+#include "arrow/util/bit_util.h"
+#include "arrow/util/ubsan.h"
+
+/* Implementation of a Bloom-like approximate membership test for use in hash join. */
+
+namespace arrow {
+namespace compute {
+
+// Only supports single-threaded build. TODO: add parallel build implementation.
+//
+class ApproximateMembershipTest {
+ public:
+  void StartBuild(int64_t num_hashes) {
+    num_bits_ = num_hashes * 8;
+    ceil_log_num_bits_ = 0;
+    while (num_bits_ > static_cast<int64_t>(1ULL << ceil_log_num_bits_)) {
+      ++ceil_log_num_bits_;
+    }
+    hash_mask_num_bits_ = (1ULL << ceil_log_num_bits_) - 1;
+    int64_t num_bytes = (1 << ceil_log_num_bits_) / 8;
+    bits_.resize(num_bytes + sizeof(uint64_t));
+    memset(bits_.data(), 0, num_bytes + sizeof(uint64_t));
+  }
+
+  inline void InsertHash(uint64_t hash) {
+    uint64_t mask;
+    int64_t byte_offset;
+    Prepare(hash, &mask, &byte_offset);
+    // OR the mask into the existing word so that previously inserted hashes are kept.
+    uint64_t word = util::SafeLoadAs<uint64_t>(bits_.data() + byte_offset);
+    util::SafeStore(bits_.data() + byte_offset, word | mask);
+  }
+
+  void FinishBuild() {
+    uint64_t first_word = util::SafeLoadAs<uint64_t>(bits_.data());
+    uint64_t last_word = util::SafeLoadAs<uint64_t>(bits_.data() + num_bits_ / 8);
+    util::SafeStore(bits_.data(), first_word | last_word);
+    util::SafeStore(bits_.data() + num_bits_ / 8, first_word | last_word);
+  }
+
+  inline bool MayHaveHash(uint64_t hash) const {
+    uint64_t mask;
+    int64_t byte_offset;
+    Prepare(hash, &mask, &byte_offset);
+    return (util::SafeLoadAs<uint64_t>(bits_.data() + byte_offset) & mask) == mask;
+  }
+
+  void MayHaveHash(int64_t hardware_flags, int64_t num_rows, const uint32_t* hashes,
+                   uint8_t* result) const;
+
+  class BitMasksGenerator {
+   public:
+    // In each consecutive "bit_width_" bits, there must be between "min_bits_set_" and
+    // "max_bits_set_" bits set.
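+    // For instance, with the constants below (bit_width_ = 57, min_bits_set_ = 4,
+    // max_bits_set_ = 5) every window of 57 consecutive bits of masks_ holds 4 or 5
+    // set bits, so each inserted hash sets, and each probe tests, at most 5 bits of
+    // the filter.
+    //
+    // A minimal single-threaded sketch of how the enclosing filter class is meant to
+    // be used (the hash values are made up for illustration):
+    //
+    //   ApproximateMembershipTest filter;
+    //   filter.StartBuild(/*num_hashes=*/3);
+    //   for (uint64_t h : {0x1234ULL, 0x5678ULL, 0x9abcULL}) filter.InsertHash(h);
+    //   filter.FinishBuild();
+    //   bool hit = filter.MayHaveHash(0x1234ULL);   // always true for inserted hashes
+    //   bool miss = filter.MayHaveHash(0xdef0ULL);  // usually false; false positives
+    //                                               // are possible, false negatives are not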
+ BitMasksGenerator(); + + static constexpr int bit_width_ = 57; + static constexpr int min_bits_set_ = 4; + static constexpr int max_bits_set_ = 5; + + static constexpr int log_num_masks_ = 10; + static constexpr int num_masks_ = 1 << log_num_masks_; + static constexpr int num_masks_less_one_ = num_masks_ - 1; + uint8_t masks_[(num_masks_ + 7) / 8 + sizeof(uint64_t)]; + }; + + private: + inline void Prepare(uint64_t hash, uint64_t* mask, int64_t* byte_offset) const { + int64_t bit_offset0 = hash & (BitMasksGenerator::num_masks_ - 1); + constexpr uint64_t mask_mask = (1ULL << BitMasksGenerator::bit_width_) - 1; + *mask = (util::SafeLoadAs(bit_masks_.masks_ + bit_offset0 / 8) >> + (bit_offset0 % 8)) & + mask_mask; + int64_t bit_offset = + (hash >> (BitMasksGenerator::log_num_masks_)) & hash_mask_num_bits_; + // bit_offset = bit_offset * num_bits_ >> ceil_log_num_bits_; + *mask <<= (bit_offset % 8); + *byte_offset = bit_offset / 8; + } + +#if defined(ARROW_HAVE_AVX2) + template + void MayHaveHash_imp_avx2(int64_t num_rows, const hash_type* hashes, + uint8_t* result) const; + void MayHaveHash_avx2(int64_t num_rows, const uint32_t* hashes, uint8_t* result) const; + void MayHaveHash_avx2(int64_t num_rows, const uint64_t* hashes, uint8_t* result) const; +#endif + + static BitMasksGenerator bit_masks_; + int64_t num_bits_; + int64_t ceil_log_num_bits_; + int64_t hash_mask_num_bits_; + std::vector bits_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/join/join_filter_avx2.cc b/cpp/src/arrow/compute/exec/join/join_filter_avx2.cc new file mode 100644 index 00000000000..857f6882a2a --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_filter_avx2.cc @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/compute/exec/join/join_filter.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace compute { + +template +void ApproximateMembershipTest::MayHaveHash_imp_avx2(int64_t num_rows, + const hash_type* hashes, + uint8_t* result) const { + constexpr int unroll = 4; + for (int64_t i = 0; i < num_rows / unroll; ++i) { + __m256i hash; + if (sizeof(hash_type) == sizeof(uint64_t)) { + hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); + } else { + DCHECK(sizeof(hash_type) == sizeof(uint32_t)); + hash = _mm256_cvtepu32_epi64( + _mm_loadu_si128(reinterpret_cast(hashes))); + } + __m256i bit_offset0 = + _mm256_and_si256(hash, _mm256_set1_epi64x(BitMasksGenerator::num_masks_ - 1)); + __m256i mask = + _mm256_i64gather_epi64(reinterpret_cast(bit_masks_.masks_), + _mm256_srli_epi64(bit_offset0, 3), 1); + mask = _mm256_srlv_epi64(mask, _mm256_and_si256(bit_offset0, _mm256_set1_epi64x(7))); + mask = _mm256_and_si256( + mask, _mm256_set1_epi64x((1ULL << BitMasksGenerator::bit_width_) - 1)); + __m256i bit_offset1 = + _mm256_and_si256(_mm256_srli_epi64(hash, BitMasksGenerator::log_num_masks_), + _mm256_set1_epi64x(hash_mask_num_bits_)); + mask = _mm256_sllv_epi64(mask, _mm256_and_si256(bit_offset1, _mm256_set1_epi64x(7))); + __m256i byte_offset = _mm256_srli_epi64(bit_offset1, 3); + __m256i word = _mm256_i64gather_epi64( + reinterpret_cast(bits_.data()), byte_offset, 1); + uint32_t found = + _mm256_movemask_epi8(_mm256_cmpeq_epi64(mask, _mm256_and_si256(word, mask))); + reinterpret_cast(result)[i] = found; + } + for (int64_t i = num_rows / unroll * unroll; i < num_rows; ++i) { + result[i] = MayHaveHash(hashes[i]); + } +} + +void ApproximateMembershipTest::MayHaveHash_avx2(int64_t num_rows, const uint64_t* hashes, + uint8_t* result) const { + MayHaveHash_imp_avx2(num_rows, hashes, result); +} + +void ApproximateMembershipTest::MayHaveHash_avx2(int64_t num_rows, const uint32_t* hashes, + uint8_t* result) const { + MayHaveHash_imp_avx2(num_rows, hashes, result); +} + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_hashtable.cc b/cpp/src/arrow/compute/exec/join/join_hashtable.cc new file mode 100644 index 00000000000..05b313e2092 --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_hashtable.cc @@ -0,0 +1,260 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/exec/join/join_hashtable.h" + +#include "arrow/compute/exec/join/join_batch.h" +#include "arrow/compute/exec/key_compare.h" + +namespace arrow { +namespace compute { + +constexpr int JoinHashTable_ThreadLocal::log_minibatch_size; +constexpr int JoinHashTable_ThreadLocal::minibatch_size; + +Status JoinHashTable_ThreadLocal::Init(MemoryPool* in_pool, int64_t in_hardware_flags, + JoinColumnMapper* schema_mgr) { + pool = in_pool; + RETURN_NOT_OK(stack.Init(pool, 64 * minibatch_size)); + encoder_ctx.hardware_flags = in_hardware_flags; + encoder_ctx.stack = &stack; + + // Key columns encoding + // + std::vector col_metadata; + for (int i = 0; i < schema_mgr->num_cols(JoinSchemaHandle::FIRST_KEY); ++i) { + col_metadata.push_back(schema_mgr->data_type(JoinSchemaHandle::FIRST_KEY, i)); + } + key_encoder.Init(col_metadata, &encoder_ctx, + /* row_alignment = */ sizeof(uint64_t), + /* string_alignment = */ sizeof(uint64_t)); + RETURN_NOT_OK(keys_minibatch.Init(pool, key_encoder.row_metadata())); + + return Status::OK(); +} + +Status JoinHashTable::Init(MemoryPool* pool) { + pool_ = pool; + RETURN_NOT_OK(map_.init( + pool, + [this](int num_keys, const uint16_t* selection /* may be null */, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx) { + this->Equal(num_keys, selection, group_ids, out_num_keys_mismatch, + out_selection_mismatch, callback_ctx); + }, + [this](int num_keys, const uint16_t* selection, void* callback_ctx) -> Status { + return Append(num_keys, selection, callback_ctx); + })); + return Status::OK(); +} + +Status JoinHashTable::Build(int hash_table_side, std::vector& batches, + JoinColumnMapper* schema_mgr, + JoinHashTable_ThreadLocal* locals) { + hash_table_side_ = hash_table_side; + RETURN_NOT_OK(keys_.Init(pool_, locals->key_encoder.row_metadata())); + KeyEncoder payload_encoder; + { + std::vector col_metadata; + JoinSchemaHandle handle = hash_table_side == 0 ? JoinSchemaHandle::FIRST_PAYLOAD + : JoinSchemaHandle::SECOND_PAYLOAD; + for (int i = 0; i < schema_mgr->num_cols(handle); ++i) { + col_metadata.push_back(schema_mgr->data_type(handle, i)); + } + payload_encoder.Init(col_metadata, &locals->encoder_ctx, + /* row_alignment = */ sizeof(uint64_t), + /* string_alignment = */ sizeof(uint64_t)); + } + KeyEncoder::KeyRowArray original_payloads, minibatch_payloads; + RETURN_NOT_OK(original_payloads.Init(pool_, payload_encoder.row_metadata())); + RETURN_NOT_OK(minibatch_payloads.Init(pool_, payload_encoder.row_metadata())); + + std::vector key_ids; + + // TODO: handle the case when there are no payload columns + + std::vector sequence(locals->minibatch_size); + for (int i = 0; i < locals->minibatch_size; ++i) { + sequence[i] = i; + } + + // For each batch... + // + for (size_t ibatch = 0; ibatch < batches.size(); ++ibatch) { + RETURN_NOT_OK(batches[ibatch].ComputeHashIfMissing(pool_, schema_mgr, locals)); + // Break batch into minibatches + // + for (int64_t start = 0; start < batches[ibatch].batch.length; + start += locals->minibatch_size) { + int next_minibatch_size = std::min( + static_cast(batches[ibatch].batch.length - start), locals->minibatch_size); + + size_t row_id = key_ids.size(); + key_ids.resize(row_id + next_minibatch_size); + + // Encode keys + // + RETURN_NOT_OK( + batches[ibatch].Encode(start, next_minibatch_size, locals->key_encoder, + locals->keys_minibatch, schema_mgr, + hash_table_side == 0 ? 
JoinSchemaHandle::FIRST_INPUT + : JoinSchemaHandle::SECOND_INPUT, + hash_table_side == 0 ? JoinSchemaHandle::FIRST_KEY + : JoinSchemaHandle::SECOND_KEY)); + + // Encode payloads + // + // TODO: append to original_payloads instead of making a copy from + // minibatch_payloads to original_payloads + // + RETURN_NOT_OK(batches[ibatch].Encode( + start, next_minibatch_size, payload_encoder, minibatch_payloads, schema_mgr, + hash_table_side == 0 ? JoinSchemaHandle::FIRST_INPUT + : JoinSchemaHandle::SECOND_INPUT, + hash_table_side == 0 ? JoinSchemaHandle::FIRST_PAYLOAD + : JoinSchemaHandle::SECOND_PAYLOAD)); + RETURN_NOT_OK(original_payloads.AppendSelectionFrom( + minibatch_payloads, next_minibatch_size, sequence.data())); + + SwissTable_ThreadLocal map_ctx(locals->encoder_ctx.hardware_flags, &locals->stack, + locals->log_minibatch_size, locals); + + auto match_bitvector = util::TempVectorHolder( + &locals->stack, + static_cast(BitUtil::BytesForBits(next_minibatch_size))); + + const uint32_t* hashes = + reinterpret_cast(batches[ibatch].hashes->data()) + start; + { + auto local_slots = + util::TempVectorHolder(&locals->stack, next_minibatch_size); + map_.early_filter(locals->encoder_ctx.hardware_flags, next_minibatch_size, hashes, + match_bitvector.mutable_data(), local_slots.mutable_data()); + map_.find(next_minibatch_size, hashes, match_bitvector.mutable_data(), + local_slots.mutable_data(), key_ids.data() + row_id, &map_ctx); + } + auto ids = util::TempVectorHolder(&locals->stack, next_minibatch_size); + int num_ids; + util::BitUtil::bits_to_indexes(0, locals->encoder_ctx.hardware_flags, + next_minibatch_size, match_bitvector.mutable_data(), + &num_ids, ids.mutable_data()); + RETURN_NOT_OK(map_.map_new_keys(num_ids, ids.mutable_data(), hashes, + key_ids.data() + row_id, &map_ctx)); + } + } + + // Bucket sort payloads on key_id + // + RETURN_NOT_OK(BucketSort(payloads_, original_payloads, map_.num_groups(), key_ids, + key_to_payload_, locals)); + + // TODO: allocate memory on the first access + has_match_.resize(map_.num_groups()); + memset(has_match_.data(), 0, map_.num_groups()); + + return Status::OK(); +} + +Status JoinHashTable::Find(BatchWithJoinData& batch, int start_row, int num_rows, + JoinColumnMapper* schema_mgr, + JoinHashTable_ThreadLocal* locals, key_id_type* out_key_ids, + uint8_t* out_found_bitvector) { + RETURN_NOT_OK(batch.ComputeHashIfMissing(pool_, schema_mgr, locals)); + + // Encode keys + // + RETURN_NOT_OK(batch.Encode(start_row, num_rows, locals->key_encoder, + locals->keys_minibatch, schema_mgr, + hash_table_side_ == 0 ? JoinSchemaHandle::FIRST_INPUT + : JoinSchemaHandle::SECOND_INPUT, + hash_table_side_ == 0 ? 
JoinSchemaHandle::FIRST_KEY
+                                           : JoinSchemaHandle::SECOND_KEY));
+
+  {
+    const uint32_t* hashes =
+        reinterpret_cast<const uint32_t*>(batch.hashes->data()) + start_row;
+    auto local_slots = util::TempVectorHolder<uint8_t>(&locals->stack, num_rows);
+    map_.early_filter(locals->encoder_ctx.hardware_flags, num_rows, hashes,
+                      out_found_bitvector, local_slots.mutable_data());
+    SwissTable_ThreadLocal map_ctx(locals->encoder_ctx.hardware_flags, &locals->stack,
+                                   locals->log_minibatch_size, locals);
+    map_.find(num_rows, hashes, out_found_bitvector, local_slots.mutable_data(),
+              out_key_ids, &map_ctx);
+  }
+  return Status::OK();
+}
+
+void JoinHashTable::Equal(int num_keys, const uint16_t* selection /* may be null */,
+                          const uint32_t* group_ids, uint32_t* out_num_keys_mismatch,
+                          uint16_t* out_selection_mismatch, void* callback_ctx) const {
+  JoinHashTable_ThreadLocal* locals =
+      reinterpret_cast<JoinHashTable_ThreadLocal*>(callback_ctx);
+  arrow::compute::KeyCompare::CompareRows(
+      num_keys, selection, group_ids, &locals->encoder_ctx, out_num_keys_mismatch,
+      out_selection_mismatch, locals->keys_minibatch, keys_);
+}
+
+Status JoinHashTable::Append(int num_keys, const uint16_t* selection,
+                             void* callback_ctx) {
+  JoinHashTable_ThreadLocal* locals =
+      reinterpret_cast<JoinHashTable_ThreadLocal*>(callback_ctx);
+  return keys_.AppendSelectionFrom(locals->keys_minibatch, num_keys, selection);
+}
+
+Status JoinHashTable::BucketSort(KeyEncoder::KeyRowArray& output,
+                                 const KeyEncoder::KeyRowArray& input, int64_t num_keys,
+                                 std::vector<key_id_type>& key_ids,
+                                 std::vector<key_id_type>& key_to_payload,
+                                 JoinHashTable_ThreadLocal* locals) {
+  if (num_keys == 0) {
+    return Status::OK();
+  }
+  // Reorder key_ids while updating key_to_payload
+  //
+  std::vector<key_id_type> gather_ids;
+  gather_ids.resize(key_ids.size());
+  key_to_payload.clear();
+  key_to_payload.resize(num_keys + 1);
+  for (size_t i = 0; i < key_to_payload.size(); ++i) {
+    key_to_payload[i] = 0;
+  }
+  for (size_t i = 0; i < key_ids.size(); ++i) {
+    ++key_to_payload[key_ids[i]];
+  }
+  // Turn per-key counts into inclusive cumulative sums, so that key_to_payload[i]
+  // points one past the end of the bucket for key i before the scatter below.
+  key_id_type sum = 0;
+  for (int64_t i = 0; i < num_keys; ++i) {
+    sum += key_to_payload[i];
+    key_to_payload[i] = sum;
+  }
+  key_to_payload[num_keys] = key_to_payload[num_keys - 1];
+  for (size_t i = 0; i < key_ids.size(); ++i) {
+    gather_ids[--key_to_payload[key_ids[i]]] = static_cast<key_id_type>(i);
+  }
+  key_ids.clear();
+
+  for (size_t start = 0; start < gather_ids.size(); start += locals->minibatch_size) {
+    int64_t next_minibatch_size =
+        std::min(gather_ids.size() - start, static_cast<size_t>(locals->minibatch_size));
+    RETURN_NOT_OK(output.AppendSelectionFrom(
+        input, static_cast<uint32_t>(next_minibatch_size), gather_ids.data() + start));
+  }
+
+  return Status::OK();
+}
+
+} // namespace compute
+} // namespace arrow
\ No newline at end of file
diff --git a/cpp/src/arrow/compute/exec/join/join_hashtable.h b/cpp/src/arrow/compute/exec/join/join_hashtable.h
new file mode 100644
index 00000000000..7bddd2313d5
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/join/join_hashtable.h
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/exec/join/join_schema.h" +#include "arrow/compute/exec/join/join_type.h" +#include "arrow/compute/exec/key_encode.h" +#include "arrow/compute/exec/key_map.h" +#include "arrow/compute/exec/util.h" +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +/* + This file implements hash table access for joins. + Hash table build is single-threaded. TODO: implement parallel version. + Hash table lookups can be done concurrently (hash table is constant after build has + finished). In order to do that, a thread local context for storage of temporary mutable + buffers needs to be provided. +*/ + +namespace arrow { +namespace compute { + +// Thread local context to be used as a scratch space during access to hash +// table. Allows for concurrent read-only lookups. +// +struct JoinHashTable_ThreadLocal { + Status Init(MemoryPool* in_pool, int64_t in_hardware_flags, + JoinColumnMapper* schema_mgr); + + static constexpr int log_minibatch_size = 10; + static constexpr int minibatch_size = 1 << log_minibatch_size; + MemoryPool* pool; + util::TempVectorStack stack; + KeyEncoder::KeyEncoderContext encoder_ctx; + KeyEncoder key_encoder; + KeyEncoder::KeyRowArray keys_minibatch; +}; + +struct BatchWithJoinData; + +// Represents a hash table and related data created for one of the sides of the hash join. +// Here is a list of child structures included in it: +// - SwissTable for hash-based search of matching key candidates +// - KeyRowArray storing only key columns for all inserted rows in a row-oriented way +// - optional KeyRowArray storing payload columns for all inserted rows; +// note: when multiple rows with the same key are inserted into a hash table, +// KeyRowArray for keys will only contain one copy of key shared by all inserted rows, +// while payload KeyRowArray will store one row for each input row +// - cummulative sum of row multiplicities for all keys; +// this is used to enumerate all matching rows in a hash table for a given key; +// the enumerated rows will be stored next to each other in payload KeyRowArray +// - byte vector with one element per inserted key for marking keys with matches; +// used in order to implement right outer and full outer joins. 
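+// As a concrete illustration of the cumulative sum (hypothetical contents, not produced
+// by any real build): if three keys were inserted with multiplicities 3, 2 and 1,
+// key_to_payload() holds {0, 3, 5, 6} and the payload rows matching key id k are the
+// rows [key_to_payload()[k], key_to_payload()[k + 1]) of payloads().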
+// +class JoinHashTable { + public: + Status Init(MemoryPool* pool); + Status Build(int hash_table_side, std::vector& batches, + JoinColumnMapper* schema_mgr, JoinHashTable_ThreadLocal* locals); + Status Find(BatchWithJoinData& batch, int start_row, int num_rows, + JoinColumnMapper* schema_mgr, JoinHashTable_ThreadLocal* locals, + key_id_type* out_key_ids, uint8_t* out_found_bitvector); + const KeyEncoder::KeyRowArray& keys() const { return keys_; } + const KeyEncoder::KeyRowArray& payloads() const { return payloads_; } + const key_id_type* key_to_payload() const { return key_to_payload_.data(); } + const uint8_t* has_match() const { return has_match_.data(); } + + private: + void Equal(int num_keys, const uint16_t* selection /* may be null */, + const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, + uint16_t* out_selection_mismatch, void* callback_ctx) const; + Status Append(int num_keys, const uint16_t* selection, void* callback_ctx); + Status BucketSort(KeyEncoder::KeyRowArray& output, const KeyEncoder::KeyRowArray& input, + int64_t num_keys, std::vector& key_ids, + std::vector& key_to_payload, + JoinHashTable_ThreadLocal* locals); + + MemoryPool* pool_; + SwissTable map_; + int hash_table_side_; + KeyEncoder::KeyRowArray keys_; + KeyEncoder::KeyRowArray payloads_; + std::vector key_to_payload_; + std::vector has_match_; +}; + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_probe.h b/cpp/src/arrow/compute/exec/join/join_probe.h new file mode 100644 index 00000000000..e855b7cd957 --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_probe.h @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec/join/join_type.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" + +/* + This file implements specific row processing strategies related to different types of + join (left semi-join, full outer join, ...). +*/ + +namespace arrow { +namespace compute { +/* +class MatchingPairsBatchIterator { + public: + /// \param num_batch_rows number of rows in input batch to iterate over + /// \param opt_batch_row_ids row ids in input batch. Null means the entire sequence + /// starting from 0. + /// \param opt_key_to_payload_map maps key id to a range of payload ids. Element indexed + /// by key id gives start of the range, element right after gives end of range (one + /// after the last element). Null means 1-1 mapping from key id to payload id, in which + /// case computing payload ids is redundant and will be skipped. 
+ /// \param payload_is_present false means that there is no payload and therefore + /// computing payload ids is useless and will be skipped + MatchingPairsBatchIterator(int num_batch_rows, const uint16_t* opt_batch_row_ids, + const key_id_type* key_ids, + const key_id_type* opt_key_to_payload_map, + bool payload_is_present); + + bool PayloadIdsSameAsKeyIds() const { + return !(payload_is_present_ && opt_key_to_payload_map_); + } + + /// The caller is responsible for pre-allocating output arrays with at least + /// num_rows_max elements. Payload id array may be null if payload ids are always the + /// same as key ids and therefore not populated by this method. + int GetNextBatch(int num_rows_max, uint16_t* out_batch_row_ids, + key_id_type* out_key_ids, key_id_type* out_opt_payload_ids, + bool* out_payload_ids_same_as_key_ids); + + private: + int num_batch_rows_; + const uint16_t* opt_batch_row_ids_; + const key_id_type* key_ids_; + const key_id_type* opt_key_to_payload_map_; + bool payload_is_present_; + int64_t num_processed_batch_rows_; + int64_t num_processed_matches_for_last_batch_row_; +}; + +void MarkHashTableMatch(); + +// Refer to flow diagram +void HashTableProbe_SemiOrAntiSemi(const ExecBatch& batch) {} + +// Refer to flow diagram +// TODO: in a while loop keep returning batches or take a lambda for outputting batch +// (non-template) +void HashTableProbe_InnerOrOuter(HashTable* hash_table, int side, + const ExecBatch& batch) { + auto join_key_batch = ProjectBatch(batch, input_schema[side], join_key_schema[side]); + if (!batch_filtered) { + // Compute hash + // Execute early filters + } + auto has_match_bitvector; + auto match_iterator_begin; + auto match_iterator_end; + hash_table->Find(has_match_bitvector, match_iterator_begin); + if (multiple_matches_possible) { + // Remap match iterator begin + // Load match iterator end + } else { + // Set match iterator end to match iterator begin plus one + } + // For each batch of matching pairs + std::vector matching_pairs; + for (auto batch = matching_pairs_batch_iterator.begin(); + batch != matching_pairs_batch_iterator.end(); ++matching_pairs_batch_iterator) { + if (has_residual_filter) { + // For each column in residual filter definition + // Is the column coming from batch side or hash table side? + for (;;) { + if (auto result = Find(residual_filter_input_schema[column], + input_schema[side]) != kNotFound) { + batch.AppendColumnShuffle(matching_pairs_batch_iterator.first, result); + } else { + auto result = + Find(residual_filter_input_schema[column], join_payload_schema[side]); + ARROW_DCHECK(result != kNotFound); + join_sides[other_side(side)].payload().AppendColumnShuffle( + matching_pairs_batch_iterator.second, result); + } + } + EvaluateResidualFilter(residual_filter_batch); + // Update match bit vector + // Update matching pairs + } + // Update hash table match vector if needed + // Assemble and output next result batch + } +} + +void HashTableScan() {} +*/ +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_schema.h b/cpp/src/arrow/compute/exec/join/join_schema.h new file mode 100644 index 00000000000..f97fe0e0ebb --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_schema.h @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/exec/key_encode.h" + +/* + This file implements helper classes for simple mapping of corresponding columns from one + type of storage to another. + For instance it is used to lookup a column id in input batch + on either side of the join for a given key column, or map a join output column to a + corresponding column stored in a hash table. +*/ + +namespace arrow { +namespace compute { + +// Identifiers of all different row schemas that appear during processing of a join. +// +enum class JoinSchemaHandle { + FIRST_INPUT, + SECOND_INPUT, + FIRST_KEY, + FIRST_PAYLOAD, + SECOND_KEY, + SECOND_PAYLOAD, + OUTPUT +}; + +struct JoinField { + std::string field_namespace; + std::string field_name; + std::shared_ptr full_data_type; + KeyEncoder::KeyColumnMetadata data_type; +}; + +using schema_type = std::vector; + +/* This is a helper class that makes it simple to map between input column id from + * arbitrary input source to output column id for arbitrary output destination. + * It is also thread-safe. + * Materialized mappings are created lazily during the first access request. + */ +template +class ColumnMapper { + public: + // This call is not thread-safe. Registering needs to be executed on a single thread + // before thread-safe read-only queries can be run. 
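+  // A short usage sketch (the schema variables below are illustrative assumptions):
+  //
+  //   JoinColumnMapper mapper;
+  //   mapper.RegisterHandle(JoinSchemaHandle::FIRST_INPUT, first_input_schema);
+  //   mapper.RegisterHandle(JoinSchemaHandle::FIRST_KEY, first_key_schema);
+  //   mapper.RegisterEnd();
+  //   // Read-only queries are thread-safe from this point on:
+  //   const int* key_to_input =
+  //       mapper.map(JoinSchemaHandle::FIRST_KEY, JoinSchemaHandle::FIRST_INPUT);
+  //   int batch_column = key_to_input[0];  // input column holding key column 0,
+  //                                        // or kMissingField if there is none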
+ // + void RegisterHandle(SchemaHandleType schema_handle, + std::shared_ptr schema) { + schemas_.push_back(std::make_pair(schema_handle, schema)); + } + + void RegisterEnd() { + size_t size = schemas_.size(); + mapping_pointers_.resize(size * size); + mapping_buffers_.resize(size * size); + } + + int num_cols(SchemaHandleType schema_handle) const { + int id = schema_id(schema_handle); + return static_cast(schemas_[id].second->size()); + } + + const JoinField* field(SchemaHandleType schema_handle, int field_id) const { + int id = schema_id(schema_handle); + const schema_type* schema = schemas_[id].second.get(); + return &((*schema)[field_id]); + } + + KeyEncoder::KeyColumnMetadata data_type(SchemaHandleType schema_handle, + int field_id) const { + return field(schema_handle, field_id)->data_type; + } + + const int* map(SchemaHandleType from, SchemaHandleType to) { + int id_from = schema_id(from); + int id_to = schema_id(to); + int num_schemas = static_cast(schemas_.size()); + int pos = id_from * num_schemas + id_to; + const int* ptr = mapping_pointers_[pos]; + if (!ptr) { + std::lock_guard lock(mutex_); + if (!ptr) { + int num_cols_from = static_cast(schemas_[id_from].second->size()); + int num_cols_to = static_cast(schemas_[id_to].second->size()); + mapping_buffers_[pos].resize(num_cols_from); + const std::vector& fields_from = *(schemas_[id_from].second.get()); + const std::vector& fields_to = *(schemas_[id_to].second.get()); + for (int i = 0; i < num_cols_from; ++i) { + int field_id = kMissingField; + for (int j = 0; j < num_cols_to; ++j) { + if (fields_from[i].field_namespace.compare(fields_to[j].field_namespace) == + 0 && + fields_from[i].field_name.compare(fields_to[j].field_name) == 0) { + DCHECK(fields_from[i].data_type.is_fixed_length == + fields_to[j].data_type.is_fixed_length && + fields_from[i].data_type.fixed_length == + fields_to[j].data_type.fixed_length); + field_id = j; + break; + } + } + mapping_buffers_[pos][i] = field_id; + } + mapping_pointers_[pos] = mapping_buffers_[pos].data(); + } + ptr = mapping_pointers_[pos]; + } + return ptr; + } + + static constexpr int kMissingField = -1; + + private: + int schema_id(SchemaHandleType schema_handle) const { + for (size_t i = 0; i < schemas_.size(); ++i) { + if (schemas_[i].first == schema_handle) { + return static_cast(i); + } + } + // We should never get here + DCHECK(false); + return -1; + } + + std::vector mapping_pointers_; + std::vector> mapping_buffers_; + std::vector>> schemas_; + std::mutex mutex_; +}; + +using JoinColumnMapper = ColumnMapper; + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_side.h b/cpp/src/arrow/compute/exec/join/join_side.h new file mode 100644 index 00000000000..2bb69b8bd72 --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_side.h @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/exec/join/join_batch.h" +#include "arrow/compute/exec/join/join_filter.h" + +/* + This file contains declarations of data that is stored for each of two sides of the hash + join. Classes here represent the state of processing of each side during execution of + each hash join. +*/ + +namespace arrow { +namespace compute { + +// READING_INPUT is a start state and FINISHED is an end state. +// Transition graph is acyclic. The end state can be reached from any state. +// All other states have only one valid input state (their predecessor on the list below). +enum class JoinSideState : uint8_t { + READING_INPUT, + SAVED_INPUT_READY, + EARLY_FILTER_READY, + HASH_TABLE_READY, + FINISHED, +}; + +class JoinState { + JoinState() { states_[0] = states_[1] = JoinSideState::READING_INPUT; } + + // Return true if transition was successful, false if transition has already been + bool StateTransition(int side, JoinSideState new_state) { + std::lock_guard lock(mutex_); + JoinSideState old_state = states_[side]; + switch (new_state) { + case JoinSideState::READING_INPUT: + return false; + case JoinSideState::SAVED_INPUT_READY: + if (old_state == JoinSideState::READING_INPUT) { + states_[side] = new_state; + return true; + } else { + return false; + } + case JoinSideState::EARLY_FILTER_READY: + if (old_state == JoinSideState::SAVED_INPUT_READY) { + states_[side] = new_state; + return true; + } else { + return false; + } + case JoinSideState::HASH_TABLE_READY: + if (old_state == JoinSideState::EARLY_FILTER_READY) { + states_[side] = new_state; + return true; + } else { + return false; + } + case JoinSideState::FINISHED: + if (old_state != JoinSideState::FINISHED) { + states_[side] = new_state; + return true; + } else { + return false; + } + default: + return false; + } + } + + JoinSideState state(int side) const { return states_[side]; } + + private: + JoinSideState states_[2]; + std::mutex mutex_; +}; + +struct JoinSideData { + void SaveBatch(const ExecBatch& batch, bool is_filtered) { + std::lock_guard lock(mutex); + if (is_filtered) { + saved_filtered.push_back(BatchWithJoinData(batch)); + } else { + saved.push_back(BatchWithJoinData(batch)); + } + } + std::vector saved; + std::vector saved_filtered; + ApproximateMembershipTest early_filter; + std::mutex mutex; +}; + +static inline int other_side(int side) { + ARROW_DCHECK(side == 0 || side == 1); + return 1 - side; +} + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/join/join_type.h b/cpp/src/arrow/compute/exec/join/join_type.h new file mode 100644 index 00000000000..881c06838ba --- /dev/null +++ b/cpp/src/arrow/compute/exec/join/join_type.h @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/util/logging.h" + +namespace arrow { +namespace compute { + +enum class JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, + LEFT_OUTER, + RIGHT_OUTER, + FULL_OUTER +}; + +class ReversibleJoinType { + ReversibleJoinType(JoinType join_type_when_first_child_is_probe); + JoinType get(int probe_side) { + if (probe_side == 0) { + return join_type_when_first_child_is_probe_; + } else { + ARROW_DCHECK(probe_side == 1); + switch (join_type_when_first_child_is_probe_) { + case JoinType::LEFT_SEMI: + return JoinType::RIGHT_SEMI; + break; + case JoinType::RIGHT_SEMI: + return JoinType::LEFT_SEMI; + break; + case JoinType::LEFT_ANTI: + return JoinType::RIGHT_ANTI; + break; + case JoinType::RIGHT_ANTI: + return JoinType::LEFT_ANTI; + break; + case JoinType::INNER: + return JoinType::INNER; + break; + case JoinType::LEFT_OUTER: + return JoinType::RIGHT_OUTER; + break; + case JoinType::RIGHT_OUTER: + return JoinType::LEFT_OUTER; + break; + case JoinType::FULL_OUTER: + return JoinType::FULL_OUTER; + break; + } + } + } + + private: + JoinType join_type_when_first_child_is_probe_; +}; + +using hash_type = uint32_t; +using key_id_type = uint32_t; + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc index f517561bfaa..ac00da92ee9 100644 --- a/cpp/src/arrow/compute/exec/key_encode.cc +++ b/cpp/src/arrow/compute/exec/key_encode.cc @@ -175,9 +175,10 @@ Status KeyEncoder::KeyRowArray::ResizeOptionalVaryingLengthBuffer( return Status::OK(); } -Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, - uint32_t num_rows_to_append, - const uint16_t* source_row_ids) { +template +Status KeyEncoder::KeyRowArray::AppendSelectionFromImp( + const KeyRowArray& from, uint32_t num_rows_to_append, + const row_id_type* source_row_ids) { DCHECK(metadata_.is_compatible(from.metadata())); RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); @@ -189,7 +190,7 @@ Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, uint32_t total_length = to_offsets[num_rows_]; uint32_t total_length_to_append = 0; for (uint32_t i = 0; i < num_rows_to_append; ++i) { - uint16_t row_id = source_row_ids[i]; + auto row_id = source_row_ids[i]; uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; total_length_to_append += length; to_offsets[num_rows_ + i + 1] = total_length + total_length_to_append; @@ -200,7 +201,7 @@ Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, const uint8_t* src = from.rows_->data(); uint8_t* dst = rows_->mutable_data() + total_length; for (uint32_t i = 0; i < num_rows_to_append; ++i) { - uint16_t row_id = source_row_ids[i]; + auto row_id = source_row_ids[i]; uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; auto src64 = reinterpret_cast(src + 
from_offsets[row_id]); auto dst64 = reinterpret_cast(dst); @@ -214,7 +215,7 @@ Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, const uint8_t* src = from.rows_->data(); uint8_t* dst = rows_->mutable_data() + num_rows_ * metadata_.fixed_length; for (uint32_t i = 0; i < num_rows_to_append; ++i) { - uint16_t row_id = source_row_ids[i]; + auto row_id = source_row_ids[i]; uint32_t length = metadata_.fixed_length; auto src64 = reinterpret_cast(src + length * row_id); auto dst64 = reinterpret_cast(dst); @@ -231,7 +232,7 @@ Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, const uint8_t* src_base = from.null_masks_->data(); uint8_t* dst_base = null_masks_->mutable_data(); for (uint32_t i = 0; i < num_rows_to_append; ++i) { - uint32_t row_id = source_row_ids[i]; + auto row_id = source_row_ids[i]; int64_t src_byte_offset = row_id * byte_length; const uint8_t* src = src_base + src_byte_offset; uint8_t* dst = dst_base + dst_byte_offset; @@ -246,6 +247,18 @@ Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, return Status::OK(); } +Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, + uint32_t num_rows_to_append, + const uint16_t* source_row_ids) { + return AppendSelectionFromImp(from, num_rows_to_append, source_row_ids); +} + +Status KeyEncoder::KeyRowArray::AppendSelectionFrom(const KeyRowArray& from, + uint32_t num_rows_to_append, + const uint32_t* source_row_ids) { + return AppendSelectionFromImp(from, num_rows_to_append, source_row_ids); +} + Status KeyEncoder::KeyRowArray::AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append) { RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); @@ -866,8 +879,8 @@ void KeyEncoder::EncoderBinaryPair::Encode(uint32_t offset_within_row, KeyRowArr } DCHECK(temp1->metadata().is_fixed_length); - DCHECK_GE(temp1->length() * temp1->metadata().fixed_length, - col1.length() * static_cast(sizeof(uint16_t))); + DCHECK(temp1->length() * temp1->metadata().fixed_length >= + col1.length() * static_cast(sizeof(uint16_t))); KeyColumnArray temp16bit(KeyColumnMetadata(true, sizeof(uint16_t)), col1.length(), nullptr, temp1->mutable_data(1), nullptr); @@ -1386,7 +1399,6 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector( // // c) Fixed-length columns precede varying-length columns when // both have the same size fixed-length part. 
- // column_order.resize(num_cols); for (uint32_t i = 0; i < num_cols; ++i) { column_order[i] = i; diff --git a/cpp/src/arrow/compute/exec/key_encode.h b/cpp/src/arrow/compute/exec/key_encode.h index d4dd499a20d..abad345672f 100644 --- a/cpp/src/arrow/compute/exec/key_encode.h +++ b/cpp/src/arrow/compute/exec/key_encode.h @@ -189,6 +189,8 @@ class KeyEncoder { Status AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append); Status AppendSelectionFrom(const KeyRowArray& from, uint32_t num_rows_to_append, const uint16_t* source_row_ids); + Status AppendSelectionFrom(const KeyRowArray& from, uint32_t num_rows_to_append, + const uint32_t* source_row_ids); const KeyRowMetadata& metadata() const { return metadata_; } int64_t length() const { return num_rows_; } const uint8_t* data(int i) const { @@ -207,6 +209,10 @@ class KeyEncoder { bool has_any_nulls(const KeyEncoderContext* ctx) const; private: + template + Status AppendSelectionFromImp(const KeyRowArray& from, uint32_t num_rows_to_append, + const row_id_type* source_row_ids); + Status ResizeFixedLengthBuffers(int64_t num_extra_rows); Status ResizeOptionalVaryingLengthBuffer(int64_t num_extra_bytes); diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc index 43c011c016b..e9ec0186071 100644 --- a/cpp/src/arrow/compute/exec/key_map.cc +++ b/cpp/src/arrow/compute/exec/key_map.cc @@ -137,57 +137,60 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec } } -void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_selection, +void SwissTable::extract_group_ids(int64_t hardware_flags, const int num_keys, + const uint16_t* optional_selection, const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const { // Group id values for all 8 slots in the block are bit-packed and follow the status // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In // that case we can extract group id using aligned 64-bit word access. int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); + int num_group_id_bytes = num_group_id_bits / 8; ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32); // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
#if defined(ARROW_HAVE_AVX2) - int num_group_id_bytes = num_group_id_bits / 8; - if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) { + if ((hardware_flags & arrow::internal::CpuInfo::AVX2) && !optional_selection) { extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, sizeof(uint64_t), 8 + 8 * num_group_id_bytes, num_group_id_bytes); - return; - } + } else { #endif - switch (num_group_id_bits) { - case 8: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 8, 16); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 8, 16); - } - break; - case 16: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 4, 12); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 4, 12); - } - break; - case 32: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 2, 10); - } else { - extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, - out_group_ids, 2, 10); - } - break; - default: - ARROW_DCHECK(false); + switch (num_group_id_bits) { + case 8: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 8, 16); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 8, 16); + } + break; + case 16: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 4, 12); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 4, 12); + } + break; + case 32: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids, 2, 10); + } else { + extract_group_ids_imp(num_keys, nullptr, hashes, local_slots, + out_group_ids, 2, 10); + } + break; + default: + ARROW_DCHECK(false); + }; +#if defined(ARROW_HAVE_AVX2) } +#endif } void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection, @@ -195,22 +198,13 @@ void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection, const uint8_t* match_bitvector, uint32_t* out_slot_ids) const { ARROW_DCHECK(selection); - if (log_blocks_ == 0) { - for (int i = 0; i < num_keys; ++i) { - uint16_t id = selection[i]; - uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0; - uint32_t slot_id = local_slots[id] + match; - out_slot_ids[id] = slot_id; - } - } else { - for (int i = 0; i < num_keys; ++i) { - uint16_t id = selection[i]; - uint32_t hash = hashes[id]; - uint32_t iblock = (hash >> (bits_hash_ - log_blocks_)); - uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0; - uint32_t slot_id = iblock * 8 + local_slots[id] + match; - out_slot_ids[id] = slot_id; - } + for (int i = 0; i < num_keys; ++i) { + uint16_t id = selection[i]; + uint32_t hash = hashes[id]; + uint32_t iblock = (hash >> (bits_hash_ - log_blocks_)); + uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 
1 : 0; + uint32_t slot_id = iblock * 8 + local_slots[id] + match; + out_slot_ids[id] = slot_id; } } @@ -306,13 +300,13 @@ uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const { return global_slot_id & global_slot_id_mask; } -void SwissTable::early_filter(const int num_keys, const uint32_t* hashes, - uint8_t* out_match_bitvector, +void SwissTable::early_filter(int64_t hardware_flags, const int num_keys, + const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const { // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. #if defined(ARROW_HAVE_AVX2) - if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) { + if (hardware_flags & arrow::internal::CpuInfo::AVX2) { if (log_blocks_ <= 4) { int tail = num_keys % 32; int delta = num_keys - tail; @@ -347,7 +341,8 @@ void SwissTable::run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const { + uint16_t* out_not_equal_selection, + SwissTable_ThreadLocal* local) const { ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector); ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector); @@ -367,21 +362,22 @@ void SwissTable::run_comparisons(const int num_keys, if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) { uint32_t out_num; - equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection); + equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection, + local->callback_ctx); *out_num_not_equal = static_cast(out_num); } else { - util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, + util::BitUtil::bits_to_indexes(1, local->hardware_flags, num_keys, optional_selection_bitvector, out_num_not_equal, out_not_equal_selection); uint32_t out_num; equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num, - out_not_equal_selection); + out_not_equal_selection, local->callback_ctx); *out_num_not_equal = static_cast(out_num); } } else { uint32_t out_num; equal_impl_(num_keys, optional_selection_ids, groupids, &out_num, - out_not_equal_selection); + out_not_equal_selection, local->callback_ctx); *out_num_not_equal = static_cast(out_num); } } @@ -470,7 +466,7 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, // void SwissTable::find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, const uint8_t* local_slots, - uint32_t* out_group_ids) const { + uint32_t* out_group_ids, SwissTable_ThreadLocal* local) const { // Temporary selection vector. // It will hold ids of keys for which we do not know yet // if they have a match in hash table or not. @@ -479,8 +475,8 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // match bit-vector. Eventually we switch from this bit-vector // to array of ids. 
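The 3/4 threshold used above is a density heuristic: when almost every key already has a stamp match, compacting the bit-vector into a selection vector first costs more than simply running the comparison callback over all keys. A minimal sketch of the idea, assuming an LSB-first bit-vector; CompareDenseOrSparse and its compare callback are illustrative, not part of the interface.

#include <cstdint>
#include <functional>
#include <vector>

// Decide between comparing every key (dense case, selection omitted) and first
// compacting the match bit-vector into a list of key ids (sparse case).
void CompareDenseOrSparse(int num_keys, const uint8_t* match_bitvector,
                          const std::function<void(int, const uint16_t*)>& compare) {
  int num_matches = 0;
  for (int i = 0; i < num_keys; ++i) {
    num_matches += (match_bitvector[i / 8] >> (i % 8)) & 1;
  }
  if (num_matches > 3 * num_keys / 4) {
    compare(num_keys, /*selection=*/nullptr);
  } else {
    std::vector<uint16_t> ids;
    for (int i = 0; i < num_keys; ++i) {
      if ((match_bitvector[i / 8] >> (i % 8)) & 1) {
        ids.push_back(static_cast<uint16_t>(i));
      }
    }
    compare(static_cast<int>(ids.size()), ids.data());
  }
}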
// - ARROW_DCHECK(num_keys <= (1 << log_minibatch_)); - auto ids_buf = util::TempVectorHolder(temp_stack_, num_keys); + ARROW_DCHECK(num_keys <= (1 << local->log_minibatch)); + auto ids_buf = util::TempVectorHolder(local->temp_stack, num_keys); uint16_t* ids = ids_buf.mutable_data(); int num_ids; @@ -495,21 +491,23 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, // bool visit_all = num_matches > 0 && num_matches > 3 * num_keys / 4; if (visit_all) { - extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids); + extract_group_ids(local->hardware_flags, num_keys, nullptr, hashes, local_slots, + out_group_ids); run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids, - ids); + ids, local); } else { - util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector, - &num_ids, ids); - extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids); - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + util::BitUtil::bits_to_indexes(1, local->hardware_flags, num_keys, + inout_match_bitvector, &num_ids, ids); + extract_group_ids(local->hardware_flags, num_ids, ids, hashes, local_slots, + out_group_ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, local); } if (num_ids == 0) { return; } - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, num_ids); + auto slot_ids_buf = util::TempVectorHolder(local->temp_stack, num_ids); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids); @@ -530,9 +528,9 @@ void SwissTable::find(const int num_keys, const uint32_t* hashes, } } - run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids); + run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids, local); } -} // namespace compute +} // Slow processing of input keys in the most generic case. // Handles inserting new keys. @@ -548,18 +546,19 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, bool* out_need_resize, uint32_t* out_group_ids, - uint32_t* inout_next_slot_ids) { + uint32_t* inout_next_slot_ids, + SwissTable_ThreadLocal* local) { auto num_groups_limit = num_groups_for_resize(); ARROW_DCHECK(num_inserted_ < num_groups_limit); // Temporary arrays are of limited size. // The input needs to be split into smaller portions if it exceeds that limit. 
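find() repeatedly narrows the work set by converting the match bit-vector into a list of 16-bit key ids with util::BitUtil::bits_to_indexes(). A scalar stand-in for what that call computes is sketched below, assuming LSB-first bit order; BitsToIndexes is an illustrative name, and the real helper also has an AVX2 path selected by hardware_flags.

#include <cstdint>

// Collect the positions whose bit equals `bit_to_search` (0 or 1) into a
// uint16_t selection vector. `out_indexes` must have room for `num_bits` ids.
void BitsToIndexes(int bit_to_search, int num_bits, const uint8_t* bitvector,
                   int* out_num_indexes, uint16_t* out_indexes) {
  int num = 0;
  for (int i = 0; i < num_bits; ++i) {
    int bit = (bitvector[i / 8] >> (i % 8)) & 1;
    if (bit == bit_to_search) {
      out_indexes[num++] = static_cast<uint16_t>(i);
    }
  }
  *out_num_indexes = num;
}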
// - ARROW_DCHECK(*inout_num_selected <= static_cast(1 << log_minibatch_)); + ARROW_DCHECK(*inout_num_selected <= static_cast(1 << local->log_minibatch)); size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t); auto match_bitvector_buf = util::TempVectorHolder( - temp_stack_, static_cast(num_bytes_for_bits)); + local->temp_stack, static_cast(num_bytes_for_bits)); uint8_t* match_bitvector = match_bitvector_buf.mutable_data(); memset(match_bitvector, 0xff, num_bytes_for_bits); @@ -593,23 +592,25 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, } auto temp_ids_buffer = - util::TempVectorHolder(temp_stack_, *inout_num_selected); + util::TempVectorHolder(local->temp_stack, *inout_num_selected); uint16_t* temp_ids = temp_ids_buffer.mutable_data(); int num_temp_ids = 0; // Copy keys for newly inserted rows using callback // - util::BitUtil::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector, - inout_selection, &num_temp_ids, temp_ids); + util::BitUtil::bits_filter_indexes(0, local->hardware_flags, num_processed, + match_bitvector, inout_selection, &num_temp_ids, + temp_ids); ARROW_DCHECK(static_cast(num_inserted_new) == num_temp_ids); - RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids)); + RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids, local->callback_ctx)); num_inserted_ += num_inserted_new; // Evaluate comparisons and append ids of rows that failed it to the non-match set. - util::BitUtil::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector, - inout_selection, &num_temp_ids, temp_ids); - run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, - temp_ids); + util::BitUtil::bits_filter_indexes(1, local->hardware_flags, num_processed, + match_bitvector, inout_selection, &num_temp_ids, + temp_ids); + run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids, temp_ids, + local); memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids); // Append ids of any unprocessed entries if we aborted processing due to the need @@ -628,7 +629,7 @@ Status SwissTable::map_new_keys_helper(const uint32_t* hashes, // this set). // Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids) { + uint32_t* group_ids, SwissTable_ThreadLocal* local) { if (num_ids == 0) { return Status::OK(); } @@ -640,11 +641,11 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* // Temporary buffers have limited size. // Caller is responsible for splitting larger input arrays into smaller chunks. - ARROW_DCHECK(static_cast(num_ids) <= (1 << log_minibatch_)); - ARROW_DCHECK(static_cast(max_id + 1) <= (1 << log_minibatch_)); + ARROW_DCHECK(static_cast(num_ids) <= (1 << local->log_minibatch)); + ARROW_DCHECK(static_cast(max_id + 1) <= (1 << local->log_minibatch)); // Allocate temporary buffers for slot ids and intialize them - auto slot_ids_buf = util::TempVectorHolder(temp_stack_, max_id + 1); + auto slot_ids_buf = util::TempVectorHolder(local->temp_stack, max_id + 1); uint32_t* slot_ids = slot_ids_buf.mutable_data(); init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids); @@ -657,7 +658,7 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* // bigger hash table. 
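map_new_keys_helper() stops early once inserting another key would reach the resize threshold, leaving the ids it did not process in the selection. The caller then doubles the table, recomputes the start slot ids of the remaining keys (they depend on log_blocks_) and tries again. A minimal sketch of that loop shape; InsertChunk, GrowDouble and ReinitSlotIds are placeholders for the real calls.

#include <cstdint>

struct InsertResult {
  uint32_t num_remaining;  // ids left unprocessed because a resize is needed
  bool need_resize;
};

// Keep inserting until every selected key has been handled; grow the table and
// recompute start slot ids whenever the helper reports it ran out of capacity.
template <typename InsertChunkFn, typename GrowFn, typename ReinitFn>
void MapNewKeysLoop(uint32_t num_ids, InsertChunkFn InsertChunk, GrowFn GrowDouble,
                    ReinitFn ReinitSlotIds) {
  while (num_ids > 0) {
    InsertResult r = InsertChunk(num_ids);
    if (r.need_resize) {
      GrowDouble();
      ReinitSlotIds(r.num_remaining);
    }
    num_ids = r.num_remaining;
  }
}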
bool out_of_capacity; RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids, - slot_ids)); + slot_ids, local)); if (out_of_capacity) { RETURN_NOT_OK(grow_double()); // Reset start slot ids for still unprocessed input keys. @@ -802,13 +803,8 @@ Status SwissTable::grow_double() { return Status::OK(); } -Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, - util::TempVectorStack* temp_stack, int log_minibatch, - EqualImpl equal_impl, AppendImpl append_impl) { - hardware_flags_ = hardware_flags; +Status SwissTable::init(MemoryPool* pool, EqualImpl equal_impl, AppendImpl append_impl) { pool_ = pool; - temp_stack_ = temp_stack; - log_minibatch_ = log_minibatch; equal_impl_ = equal_impl; append_impl_ = append_impl; diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h index 286bbed038a..adde4b600cf 100644 --- a/cpp/src/arrow/compute/exec/key_map.h +++ b/cpp/src/arrow/compute/exec/key_map.h @@ -27,6 +27,19 @@ namespace arrow { namespace compute { +struct SwissTable_ThreadLocal { + SwissTable_ThreadLocal(int64_t in_hardware_flags, util::TempVectorStack* in_temp_stack, + int in_log_minibatch, void* in_callback_ctx) + : hardware_flags(in_hardware_flags), + temp_stack(in_temp_stack), + log_minibatch(in_log_minibatch), + callback_ctx(in_callback_ctx) {} + int64_t hardware_flags; + util::TempVectorStack* temp_stack; + int log_minibatch; + void* callback_ctx; +}; + class SwissTable { public: SwissTable() = default; @@ -35,22 +48,25 @@ class SwissTable { using EqualImpl = std::function; - using AppendImpl = std::function; + uint16_t* out_selection_mismatch, void* local_ctx)>; + using AppendImpl = + std::function; - Status init(int64_t hardware_flags, MemoryPool* pool, util::TempVectorStack* temp_stack, - int log_minibatch, EqualImpl equal_impl, AppendImpl append_impl); + Status init(MemoryPool* pool, EqualImpl equal_impl, AppendImpl append_impl); void cleanup(); - void early_filter(const int num_keys, const uint32_t* hashes, + void early_filter(int64_t hardware_flags, const int num_keys, const uint32_t* hashes, uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector, - const uint8_t* local_slots, uint32_t* out_group_ids) const; + const uint8_t* local_slots, uint32_t* out_group_ids, + SwissTable_ThreadLocal* locals) const; Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes, - uint32_t* group_ids); + uint32_t* group_ids, SwissTable_ThreadLocal* locals); + + int64_t num_groups() const { return num_inserted_; } private: // Lookup helpers @@ -87,9 +103,9 @@ class SwissTable { inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot, uint64_t group_id_mask) const; - void extract_group_ids(const int num_keys, const uint16_t* optional_selection, - const uint32_t* hashes, const uint8_t* local_slots, - uint32_t* out_group_ids) const; + void extract_group_ids(int64_t hardware_flags, const int num_keys, + const uint16_t* optional_selection, const uint32_t* hashes, + const uint8_t* local_slots, uint32_t* out_group_ids) const; template void extract_group_ids_imp(const int num_keys, const uint16_t* selection, @@ -131,7 +147,8 @@ class SwissTable { void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, const uint8_t* optional_selection_bitvector, const uint32_t* groupids, int* out_num_not_equal, - uint16_t* out_not_equal_selection) const; + uint16_t* out_not_equal_selection, + 
SwissTable_ThreadLocal* local) const; inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, uint32_t* out_slot_id, uint32_t* out_group_id) const; @@ -146,7 +163,8 @@ class SwissTable { // Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected, uint16_t* inout_selection, bool* out_need_resize, - uint32_t* out_group_ids, uint32_t* out_next_slot_ids); + uint32_t* out_group_ids, uint32_t* out_next_slot_ids, + SwissTable_ThreadLocal* local); // Resize small hash tables when 50% full (up to 8KB). // Resize large hash tables when 75% full. @@ -169,7 +187,6 @@ class SwissTable { // Padding bytes added at the end of buffers for ease of SIMD access static constexpr int padding_ = 64; - int log_minibatch_; // Base 2 log of the number of blocks int log_blocks_ = 0; // Number of keys inserted into hash table @@ -195,9 +212,7 @@ class SwissTable { // There is 64B padding at the end. uint32_t* hashes_; - int64_t hardware_flags_; MemoryPool* pool_; - util::TempVectorStack* temp_stack_; EqualImpl equal_impl_; AppendImpl append_impl_; diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 6214c76b517..fbb7e1c8985 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include "arrow/buffer.h" @@ -180,5 +181,22 @@ class BitUtil { #endif }; +class Random64Bit { + public: + Random64Bit() : rs{0, 0, 0, 0, 0, 0, 0, 0}, re(rs) {} + uint64_t next() { return rdist(re); } + template + inline T from_range(const T& min_val, const T& max_val) { + return static_cast(min_val + (next() % (max_val - min_val + 1))); + } + std::mt19937& get_engine() { return re; } + + private: + std::random_device rd; + std::seed_seq rs; + std::mt19937 re; + std::uniform_int_distribution rdist; +}; + } // namespace util } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 9d620242e10..1f5e1a5d43b 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -513,19 +513,17 @@ struct GrouperFastImpl : Grouper { auto equal_func = [impl_ptr]( int num_keys_to_compare, const uint16_t* selection_may_be_null, const uint32_t* group_ids, uint32_t* out_num_keys_mismatch, - uint16_t* out_selection_mismatch) { + uint16_t* out_selection_mismatch, void*) { arrow::compute::KeyCompare::CompareRows( num_keys_to_compare, selection_may_be_null, group_ids, &impl_ptr->encode_ctx_, out_num_keys_mismatch, out_selection_mismatch, impl_ptr->rows_minibatch_, impl_ptr->rows_); }; - auto append_func = [impl_ptr](int num_keys, const uint16_t* selection) { + auto append_func = [impl_ptr](int num_keys, const uint16_t* selection, void*) { return impl_ptr->rows_.AppendSelectionFrom(impl_ptr->rows_minibatch_, num_keys, selection); }; - RETURN_NOT_OK(impl->map_.init(impl->encode_ctx_.hardware_flags, ctx->memory_pool(), - impl->encode_ctx_.stack, impl->log_minibatch_max_, - equal_func, append_func)); + RETURN_NOT_OK(impl->map_.init(ctx->memory_pool(), equal_func, append_func)); impl->cols_.resize(num_columns); impl->minibatch_hashes_.resize(impl->minibatch_size_max_ + kPaddingForSIMD / sizeof(uint32_t)); @@ -608,13 +606,17 @@ struct GrouperFastImpl : Grouper { // Map auto match_bitvector = util::TempVectorHolder(&temp_stack_, (batch_size_next + 7) / 8); + SwissTable_ThreadLocal map_ctx(encode_ctx_.hardware_flags, &temp_stack_, + log_minibatch_max_, nullptr); { 
       auto local_slots = util::TempVectorHolder<uint8_t>(&temp_stack_, batch_size_next);
-      map_.early_filter(batch_size_next, minibatch_hashes_.data(),
-                        match_bitvector.mutable_data(), local_slots.mutable_data());
+      map_.early_filter(encode_ctx_.hardware_flags, batch_size_next,
+                        minibatch_hashes_.data(), match_bitvector.mutable_data(),
+                        local_slots.mutable_data());
       map_.find(batch_size_next, minibatch_hashes_.data(),
                 match_bitvector.mutable_data(), local_slots.mutable_data(),
-                reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row);
+                reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row,
+                &map_ctx);
     }
     auto ids = util::TempVectorHolder<uint16_t>(&temp_stack_, batch_size_next);
     int num_ids;
@@ -624,7 +626,7 @@ struct GrouperFastImpl : Grouper {
     RETURN_NOT_OK(map_.map_new_keys(
         num_ids, ids.mutable_data(), minibatch_hashes_.data(),
-        reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row));
+        reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row, &map_ctx));
     start_row += batch_size_next;
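With the per-thread state moved out of SwissTable, a caller owns one SwissTable_ThreadLocal per worker and passes it to every probe call, as GrouperFastImpl does above with map_ctx. The sketch below shows that calling convention in isolation; ProbeMinibatch and callback_state are illustrative names, and the unmatched ids are assumed to have been collected from the match bit-vector beforehand.

#include <cstdint>

#include "arrow/compute/exec/key_map.h"
#include "arrow/compute/exec/util.h"
#include "arrow/status.h"

arrow::Status ProbeMinibatch(arrow::compute::SwissTable* table, int64_t hardware_flags,
                             arrow::util::TempVectorStack* worker_stack,
                             int log_minibatch, void* callback_state, int num_keys,
                             const uint32_t* hashes, uint8_t* match_bitvector,
                             uint8_t* local_slots, uint32_t* group_ids,
                             uint32_t num_unmatched, uint16_t* unmatched_ids) {
  // All mutable per-thread pieces travel in one small struct instead of living
  // inside the (now shareable) hash table object.
  arrow::compute::SwissTable_ThreadLocal thread_ctx(hardware_flags, worker_stack,
                                                    log_minibatch, callback_state);
  table->early_filter(hardware_flags, num_keys, hashes, match_bitvector, local_slots);
  table->find(num_keys, hashes, match_bitvector, local_slots, group_ids, &thread_ctx);
  return table->map_new_keys(num_unmatched, unmatched_ids, hashes, group_ids,
                             &thread_ctx);
}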