diff --git a/cpp/src/arrow/compute/exec/key_encode.cc b/cpp/src/arrow/compute/exec/key_encode.cc
index 1a563867e90..8ab76cd27b3 100644
--- a/cpp/src/arrow/compute/exec/key_encode.cc
+++ b/cpp/src/arrow/compute/exec/key_encode.cc
@@ -828,15 +828,23 @@ void KeyEncoder::KeyRowMetadata::FromColumnMetadataVector(
   const auto num_cols = static_cast<uint32_t>(cols.size());
 
   // Sort columns.
+  //
   // Columns are sorted based on the size in bytes of their fixed-length part.
   // For the varying-length column, the fixed-length part is the 32-bit field storing
   // cumulative length of varying-length fields.
+  //
   // The rules are:
+  //
   // a) Boolean column, marked with fixed-length 0, is considered to have fixed-length
-  // part of 1 byte. b) Columns with fixed-length part being power of 2 or multiple of row
-  // alignment precede other columns. They are sorted among themselves based on size of
-  // fixed-length part. c) Fixed-length columns precede varying-length columns when both
-  // have the same size fixed-length part.
+  // part of 1 byte.
+  //
+  // b) Columns with fixed-length part being power of 2 or multiple of row
+  // alignment precede other columns. They are sorted in decreasing order of the size of
+  // their fixed-length part.
+  //
+  // c) Fixed-length columns precede varying-length columns when
+  // both have the same size fixed-length part.
+  //
   column_order.resize(num_cols);
   for (uint32_t i = 0; i < num_cols; ++i) {
     column_order[i] = i;
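The ordering rules above are easier to check against a concrete sketch. The snippet below is illustrative only: `Col`, `SortColumns` and `aligned` are hypothetical names, and the real code orders an index vector (`column_order`) over column metadata rather than sorting column structs directly.

```cpp
// Minimal sketch of rules a), b), c); not the patch's code.
#include <algorithm>
#include <cstdint>
#include <vector>

struct Col {
  uint32_t fixed_len;  // fixed-length part in bytes; 0 marks a boolean column,
                       // varying-length columns use 4 (their 32-bit length field)
  bool is_varlen;
};

void SortColumns(std::vector<Col>* cols, uint32_t row_alignment) {
  auto aligned = [&](uint32_t len) {  // rule b) precondition
    return (len & (len - 1)) == 0 || len % row_alignment == 0;
  };
  std::stable_sort(cols->begin(), cols->end(), [&](const Col& l, const Col& r) {
    uint32_t ll = l.fixed_len == 0 ? 1 : l.fixed_len;  // rule a)
    uint32_t rl = r.fixed_len == 0 ? 1 : r.fixed_len;
    if (aligned(ll) != aligned(rl)) return aligned(ll);  // rule b): aligned first
    if (ll != rl) return ll > rl;                        // decreasing size
    return !l.is_varlen && r.is_varlen;                  // rule c)
  });
}
```

Packing the widest well-aligned fixed-length fields first keeps later fields naturally aligned without extra padding.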
diff --git a/cpp/src/arrow/compute/exec/key_map.cc b/cpp/src/arrow/compute/exec/key_map.cc
index eaa2ae3e39f..f06a090cf85 100644
--- a/cpp/src/arrow/compute/exec/key_map.cc
+++ b/cpp/src/arrow/compute/exec/key_map.cc
@@ -34,22 +34,20 @@ namespace compute {
 
 constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
 
-// Search status bytes inside a block of 8 slots (64-bit word).
-// Try to find a slot that contains a 7-bit stamp matching the one provided.
-// There are three possible outcomes:
-// 1. A matching slot is found.
-// -> Return its index between 0 and 7 and set match found flag.
-// 2. A matching slot is not found and there is an empty slot in the block.
-// -> Return the index of the first empty slot and clear match found flag.
-// 3. A matching slot is not found and there are no empty slots in the block.
-// -> Return 8 as the output slot index and clear match found flag.
+// Scan bytes in block in reverse and stop as soon
+// as a position of interest is found.
+//
+// Positions of interest:
+// a) slot with a matching stamp is encountered,
+// b) first empty slot is encountered,
+// c) we reach the end of the block.
 //
 // Optionally an index of the first slot to start the search from can be specified.
 // In this case slots before it will be ignored.
 //
 template <bool use_start_slot>
 inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
-                                     int* out_slot, int* out_match_found) {
+                                     int* out_slot, int* out_match_found) const {
   // Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80.
   uint64_t block_high_bits = block & kHighBitOfEachByte;
 
@@ -82,6 +80,11 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
     matches &= kHighBitOfEachByte;
   }
 
+  // In case when there are no matches in slots and the block is full (no empty slots),
+  // pretend that there is a match in the last slot.
+  //
+  matches |= (~block_high_bits & 0x80);
+
   // We get 0 if there are no matches
   *out_match_found = (matches == 0 ? 0 : 1);
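The added `matches |= (~block_high_bits & 0x80);` is subtle. Slots within a block fill in order, so the last slot (the lowest status byte) is occupied only when the whole block is full; checking just that byte is therefore enough to detect a full block and fabricate a match in its last slot, which forces the search to continue into the next block. A standalone walkthrough, with assumed example stamp values:

```cpp
// Illustration of the "pretend match" trick; not library code.
#include <cassert>
#include <cstdint>

int main() {
  constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
  // Full block: every status byte holds a 7-bit stamp, so no high bits are set.
  uint64_t full_block = 0x0102030405060708ULL;
  uint64_t high_bits = full_block & kHighBitOfEachByte;  // == 0
  uint64_t matches = 0;                                  // assume no stamp matched
  matches |= (~high_bits & 0x80);  // fabricated match in the last slot (lowest byte)
  assert(matches == 0x80);
  // Block whose last slot is empty (status byte 0x80): nothing is fabricated.
  uint64_t partial_block = 0x0102030405060780ULL;
  high_bits = partial_block & kHighBitOfEachByte;
  assert(((~high_bits) & 0x80) == 0);
  return 0;
}
```

Callers then distinguish a real match from the fabricated one by re-checking the stamp byte, as `find_next_stamp_match` does further below.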
@@ -91,28 +94,16 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
   *out_slot = static_cast<int>(CountLeadingZeros(matches | block_high_bits) >> 3);
 }
 
-// This call follows the call to search_block.
-// The input slot index is the output returned by it, which is a value from 0 to 8,
-// with 8 indicating that both: no match was found and there were no empty slots.
-//
-// If the slot corresponds to a non-empty slot return a group id associated with it.
-// Otherwise return any group id from any of the slots or
-// zero, which is the default value stored in empty slots.
-//
 inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
                                              uint64_t group_id_mask) const {
-  // Input slot can be equal to 8, in which case we need to output any valid group id
-  // value, so we take the one from slot 0 in the block.
-  int clamped_slot = slot & 7;
-
   // Group id values for all 8 slots in the block are bit-packed and follow the status
   // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
   // that case we can extract group id using aligned 64-bit word access.
-  int num_groupid_bits = static_cast<int>(ARROW_POPCOUNT64(group_id_mask));
-  ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
-               num_groupid_bits == 32 || num_groupid_bits == 64);
+  int num_group_id_bits = static_cast<int>(ARROW_POPCOUNT64(group_id_mask));
+  ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 ||
+               num_group_id_bits == 32 || num_group_id_bits == 64);
 
-  int bit_offset = clamped_slot * num_groupid_bits;
+  int bit_offset = slot * num_group_id_bits;
   const uint64_t* group_id_bytes =
       reinterpret_cast<const uint64_t*>(block_ptr) + 1 + (bit_offset >> 6);
   uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask;
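For readers new to the block layout, here is the same extraction written as a free function. This is a sketch, not the class method: the helper name is made up, `memcpy` stands in for the aligned 64-bit access the real code can assume, and `__builtin_popcountll` stands in for `ARROW_POPCOUNT64`.

```cpp
#include <cstdint>
#include <cstring>

// One block = 8 status bytes followed by 8 bit-packed group ids.
uint64_t ExtractGroupIdSketch(const uint8_t* block_ptr, int slot,
                              uint64_t group_id_mask) {
  int num_group_id_bits = __builtin_popcountll(group_id_mask);  // 8, 16, 32 or 64
  int bit_offset = slot * num_group_id_bits;
  uint64_t word;
  std::memcpy(&word, block_ptr + 8 + (bit_offset >> 6) * 8, sizeof(word));
  return (word >> (bit_offset & 63)) & group_id_mask;
}

int main() {
  uint8_t block[16] = {};  // 8 status bytes + 8 one-byte group ids
  block[8 + 3] = 42;       // group id 42 stored for slot 3
  return ExtractGroupIdSketch(block, 3, 0xff) == 42 ? 0 : 1;
}
```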
@@ -120,48 +111,159 @@ inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
   return group_id;
 }
 
-// Return global slot id (the index including the information about the block)
-// where the search should continue if the first comparison fails.
-// This function always follows search_block and receives the slot id returned by it.
-//
-inline uint64_t SwissTable::next_slot_to_visit(uint64_t block_index, int slot,
-                                               int match_found) {
-  // The result should be taken modulo the number of all slots in all blocks,
-  // but here we allow it to take a value one above the last slot index.
-  // Modulo operation is postponed to later.
-  return block_index * 8 + slot + match_found;
+template <typename T, bool use_selection>
+void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection,
+                                       const uint32_t* hashes,
+                                       const uint8_t* local_slots,
+                                       uint32_t* out_group_ids, int element_offset,
+                                       int element_multiplier) const {
+  const T* elements = reinterpret_cast<const T*>(blocks_) + element_offset;
+  if (log_blocks_ == 0) {
+    ARROW_DCHECK(sizeof(T) == sizeof(uint8_t));
+    for (int i = 0; i < num_keys; ++i) {
+      uint32_t id = use_selection ? selection[i] : i;
+      uint32_t group_id = blocks_[8 + local_slots[id]];
+      out_group_ids[id] = group_id;
+    }
+  } else {
+    for (int i = 0; i < num_keys; ++i) {
+      uint32_t id = use_selection ? selection[i] : i;
+      uint32_t hash = hashes[id];
+      int64_t pos =
+          (hash >> (bits_hash_ - log_blocks_)) * element_multiplier + local_slots[id];
+      uint32_t group_id = static_cast<uint32_t>(elements[pos]);
+      ARROW_DCHECK(group_id < num_inserted_ || num_inserted_ == 0);
+      out_group_ids[id] = group_id;
+    }
+  }
 }
 
-// Implements first (fast-path, optimistic) lookup.
-// Searches for a match only within the start block and
-// trying only the first slot with a matching stamp.
-//
-// Comparison callback needed for match verification is done outside of this function.
-// Match bit vector filled by it only indicates finding a matching stamp in a slot.
+void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_selection,
+                                   const uint32_t* hashes, const uint8_t* local_slots,
+                                   uint32_t* out_group_ids) const {
+  // Group id values for all 8 slots in the block are bit-packed and follow the status
+  // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In
+  // that case we can extract group id using aligned 64-bit word access.
+  int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+  ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 ||
+               num_group_id_bits == 32);
+
+  // Optimistically use simplified lookup involving only a start block to find
+  // a single group id candidate for every input.
+#if defined(ARROW_HAVE_AVX2)
+  int num_group_id_bytes = num_group_id_bits / 8;
+  if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) {
+    extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids,
+                           sizeof(uint64_t), 8 + 8 * num_group_id_bytes,
+                           num_group_id_bytes);
+    return;
+  }
+#endif
+  switch (num_group_id_bits) {
+    case 8:
+      if (optional_selection) {
+        extract_group_ids_imp<uint8_t, true>(num_keys, optional_selection, hashes,
+                                             local_slots, out_group_ids, 8, 16);
+      } else {
+        extract_group_ids_imp<uint8_t, false>(num_keys, nullptr, hashes, local_slots,
+                                              out_group_ids, 8, 16);
+      }
+      break;
+    case 16:
+      if (optional_selection) {
+        extract_group_ids_imp<uint16_t, true>(num_keys, optional_selection, hashes,
+                                              local_slots, out_group_ids, 4, 12);
+      } else {
+        extract_group_ids_imp<uint16_t, false>(num_keys, nullptr, hashes, local_slots,
+                                               out_group_ids, 4, 12);
+      }
+      break;
+    case 32:
+      if (optional_selection) {
+        extract_group_ids_imp<uint32_t, true>(num_keys, optional_selection, hashes,
+                                              local_slots, out_group_ids, 2, 10);
+      } else {
+        extract_group_ids_imp<uint32_t, false>(num_keys, nullptr, hashes, local_slots,
+                                               out_group_ids, 2, 10);
+      }
+      break;
+    default:
+      ARROW_DCHECK(false);
+  }
+}
+
+void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection,
+                               const uint32_t* hashes, const uint8_t* local_slots,
+                               const uint8_t* match_bitvector,
+                               uint32_t* out_slot_ids) const {
+  ARROW_DCHECK(selection);
+  if (log_blocks_ == 0) {
+    for (int i = 0; i < num_keys; ++i) {
+      uint16_t id = selection[i];
+      uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0;
+      uint32_t slot_id = local_slots[id] + match;
+      out_slot_ids[id] = slot_id;
+    }
+  } else {
+    for (int i = 0; i < num_keys; ++i) {
+      uint16_t id = selection[i];
+      uint32_t hash = hashes[id];
+      uint32_t iblock = (hash >> (bits_hash_ - log_blocks_));
+      uint32_t match = ::arrow::BitUtil::GetBit(match_bitvector, id) ? 1 : 0;
+      uint32_t slot_id = iblock * 8 + local_slots[id] + match;
+      out_slot_ids[id] = slot_id;
+    }
+  }
+}
+
+void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids,
+                                            const uint32_t* hashes,
+                                            uint32_t* slot_ids) const {
+  int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+  uint32_t num_block_bytes = num_groupid_bits + 8;
+  if (log_blocks_ == 0) {
+    uint64_t block = *reinterpret_cast<const uint64_t*>(blocks_);
+    uint32_t empty_slot =
+        static_cast<uint32_t>(8 - ARROW_POPCOUNT64(block & kHighBitOfEachByte));
+    for (uint32_t i = 0; i < num_ids; ++i) {
+      int id = ids[i];
+      slot_ids[id] = empty_slot;
+    }
+  } else {
+    for (uint32_t i = 0; i < num_ids; ++i) {
+      int id = ids[i];
+      uint32_t hash = hashes[id];
+      uint32_t iblock = hash >> (bits_hash_ - log_blocks_);
+      uint64_t block;
+      for (;;) {
+        block = *reinterpret_cast<const uint64_t*>(blocks_ + num_block_bytes * iblock);
+        block &= kHighBitOfEachByte;
+        if (block) {
+          break;
+        }
+        iblock = (iblock + 1) & ((1 << log_blocks_) - 1);
+      }
+      uint32_t empty_slot = static_cast<uint32_t>(8 - ARROW_POPCOUNT64(block));
+      slot_ids[id] = iblock * 8 + empty_slot;
+    }
+  }
+}
+
+// Quickly filter out keys that have no matches based only on hash value and the
+// corresponding starting 64-bit block of slot status bytes. May return false positives.
 //
-template <bool use_selection>
-void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
-                          const uint32_t* hashes, uint8_t* out_match_bitvector,
-                          uint32_t* out_groupids, uint32_t* out_slot_ids) {
+void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
+                                  uint8_t* out_match_bitvector,
+                                  uint8_t* out_local_slots) const {
   // Clear the output bit vector
   memset(out_match_bitvector, 0, (num_keys + 7) / 8);
 
   // Based on the size of the table, prepare bit number constants.
   uint32_t stamp_mask = (1 << bits_stamp_) - 1;
   int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
-  uint32_t groupid_mask = (1 << num_groupid_bits) - 1;
 
   for (int i = 0; i < num_keys; ++i) {
-    int id;
-    if (use_selection) {
-      id = util::SafeLoad(&selection[i]);
-    } else {
-      id = i;
-    }
-
     // Extract from hash: block index and stamp
     //
-    uint32_t hash = hashes[id];
+    uint32_t hash = hashes[i];
     uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_);
     uint32_t stamp = iblock & stamp_mask;
     iblock >>= bits_stamp_;
@@ -169,22 +271,19 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
     uint32_t num_block_bytes = num_groupid_bits + 8;
     const uint8_t* blockbase = reinterpret_cast<const uint8_t*>(blocks_) +
                                static_cast<uint64_t>(iblock) * num_block_bytes;
-    uint64_t block = util::SafeLoadAs<uint64_t>(blockbase);
+    ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
+    uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
 
     // Call helper functions to obtain the output triplet:
     // - match (of a stamp) found flag
-    // - group id for key comparison
-    // - slot to resume search from in case of no match or false positive
+    // - number of slots to skip before resuming further search, in case of no match or
+    // false positive
     int match_found;
     int islot_in_block;
     search_block<false>(block, stamp, 0, &islot_in_block, &match_found);
-    uint64_t groupid = extract_group_id(blockbase, islot_in_block, groupid_mask);
-    ARROW_DCHECK(groupid < num_inserted_ || num_inserted_ == 0);
-    uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found);
 
-    out_match_bitvector[id / 8] |= match_found << (id & 7);
-    util::SafeStore(&out_groupids[id], static_cast<uint32_t>(groupid));
-    util::SafeStore(&out_slot_ids[id], static_cast<uint32_t>(islot));
+    out_match_bitvector[i / 8] |= match_found << (i & 7);
+    out_local_slots[i] = static_cast<uint8_t>(islot_in_block);
   }
 }
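`early_filter_imp` above splits each 32-bit hash into a block index and a 7-bit stamp using only shifts and masks. A small sketch with the constants spelled out (`bits_hash_` is 32 and `bits_stamp_` is 7 in `SwissTable`; `Decompose` is an illustrative name):

```cpp
#include <cstdint>

struct HashParts {
  uint32_t block;  // index of the starting block
  uint32_t stamp;  // 7-bit tag remembered in the slot's status byte
};

HashParts Decompose(uint32_t hash, int log_blocks) {
  constexpr int kBitsHash = 32;
  constexpr int kBitsStamp = 7;
  // Top log_blocks bits pick the block; the next 7 bits are the stamp.
  uint32_t v = hash >> (kBitsHash - kBitsStamp - log_blocks);
  return {v >> kBitsStamp, v & ((1u << kBitsStamp) - 1)};
}
```

Taking the block index from the highest hash bits means doubling the table merely appends one more hash bit to the index, which is what lets `grow_double` redistribute entries cheaply.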
@@ -203,18 +302,254 @@ uint64_t SwissTable::num_groups_for_resize() const {
   }
 }
 
-uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) {
+uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const {
   uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1;
   return global_slot_id & global_slot_id_mask;
 }
 
-// Run a single round of slot search - comparison / insert - filter unprocessed.
+void SwissTable::early_filter(const int num_keys, const uint32_t* hashes,
+                              uint8_t* out_match_bitvector,
+                              uint8_t* out_local_slots) const {
+  // Optimistically use simplified lookup involving only a start block to find
+  // a single group id candidate for every input.
+#if defined(ARROW_HAVE_AVX2)
+  if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
+    if (log_blocks_ <= 4) {
+      int tail = num_keys % 32;
+      int delta = num_keys - tail;
+      early_filter_imp_avx2_x32(num_keys - tail, hashes, out_match_bitvector,
+                                out_local_slots);
+      early_filter_imp_avx2_x8(tail, hashes + delta, out_match_bitvector + delta / 8,
+                               out_local_slots + delta);
+    } else {
+      early_filter_imp_avx2_x8(num_keys, hashes, out_match_bitvector, out_local_slots);
+    }
+  } else {
+#endif
+    early_filter_imp(num_keys, hashes, out_match_bitvector, out_local_slots);
+#if defined(ARROW_HAVE_AVX2)
+  }
+#endif
+}
+
+// Input selection may be:
+// - a range of all ids from 0 to num_keys - 1
+// - a selection vector with list of ids
+// - a bit-vector marking ids that are included
+// Either selection index vector or selection bit-vector must be provided
+// but both cannot be set at the same time (one must be null).
+//
+// Input and output selection index vectors are allowed to point to the same buffer
+// (in-place filtering of ids).
+//
+// Output selection vector needs to have enough space for num_keys entries.
+//
+void SwissTable::run_comparisons(const int num_keys,
+                                 const uint16_t* optional_selection_ids,
+                                 const uint8_t* optional_selection_bitvector,
+                                 const uint32_t* groupids, int* out_num_not_equal,
+                                 uint16_t* out_not_equal_selection) const {
+  ARROW_DCHECK(optional_selection_ids || optional_selection_bitvector);
+  ARROW_DCHECK(!optional_selection_ids || !optional_selection_bitvector);
+
+  if (!optional_selection_ids && optional_selection_bitvector) {
+    // Count rows with matches (based on stamp comparison)
+    // and decide based on their percentage whether to call dense or sparse comparison
+    // function. Dense comparison means evaluating it for all inputs, even if the
+    // matching stamp was not found. It may be cheaper to evaluate comparison for all
+    // inputs if the extra cost of filtering is higher than the wasted processing of
+    // rows with no match.
+    //
+    // Dense comparison can only be used if there is at least one inserted key,
+    // because otherwise there is no key to compare to.
+    //
+    int64_t num_matches = arrow::internal::CountSetBits(optional_selection_bitvector,
+                                                        /*offset=*/0, num_keys);
+
+    if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) {
+      uint32_t out_num;
+      equal_impl_(num_keys, nullptr, groupids, &out_num, out_not_equal_selection);
+      *out_num_not_equal = static_cast<int>(out_num);
+    } else {
+      util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys,
+                                     optional_selection_bitvector, out_num_not_equal,
+                                     out_not_equal_selection);
+      uint32_t out_num;
+      equal_impl_(*out_num_not_equal, out_not_equal_selection, groupids, &out_num,
+                  out_not_equal_selection);
+      *out_num_not_equal = static_cast<int>(out_num);
+    }
+  } else {
+    uint32_t out_num;
+    equal_impl_(num_keys, optional_selection_ids, groupids, &out_num,
+                out_not_equal_selection);
+    *out_num_not_equal = static_cast<int>(out_num);
+  }
+}
+
+// Given starting slot index, search blocks for a matching stamp
+// until one is found or an empty slot is reached.
+// If the search stopped on a non-empty slot, output corresponding
+// group id from that slot.
+//
+// Return true if a match was found.
+//
+bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id,
+                                       uint32_t* out_slot_id,
+                                       uint32_t* out_group_id) const {
+  const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+  constexpr uint64_t stamp_mask = 0x7f;
+  const int stamp =
+      static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+  uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
+  int match_found;
+  int local_slot;
+  uint8_t* blockbase;
+  for (;;) {
+    const uint64_t num_block_bytes = (8 + num_groupid_bits);
+    blockbase = blocks_ + num_block_bytes * (start_slot_id >> 3);
+    uint64_t block = *reinterpret_cast<uint64_t*>(blockbase);
+
+    search_block<true>(block, stamp, (start_slot_id & 7), &local_slot, &match_found);
+
+    start_slot_id =
+        wrap_global_slot_id((start_slot_id & ~7ULL) + local_slot + match_found);
+
+    // Match found can be 1 in two cases:
+    // - match was found
+    // - match was not found in a full block
+    // In the second case search needs to continue in the next block.
+    if (match_found == 0 || blockbase[7 - local_slot] == stamp) {
+      break;
+    }
+  }
+
+  const uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1;
+  *out_group_id =
+      static_cast<uint32_t>(extract_group_id(blockbase, local_slot, groupid_mask));
+  *out_slot_id = static_cast<uint32_t>(start_slot_id);
+
+  return match_found;
+}
+
+void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
+                                        uint32_t group_id) {
+  const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
+
+  // We assume here that the number of bits is rounded up to 8, 16, 32 or 64.
+  // In that case we can insert group id value using aligned 64-bit word access.
+  ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
+               num_groupid_bits == 32 || num_groupid_bits == 64);
+
+  const uint64_t num_block_bytes = (8 + num_groupid_bits);
+  constexpr uint64_t stamp_mask = 0x7f;
+
+  int start_slot = (slot_id & 7);
+  int stamp =
+      static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+  uint64_t block_id = slot_id >> 3;
+  uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
+
+  blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
+  int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);
+
+  // Block status bytes should start at an address aligned to 8 bytes
+  ARROW_DCHECK((reinterpret_cast<uint64_t>(blockbase) & 7) == 0);
+  uint64_t* ptr = reinterpret_cast<uint64_t*>(blockbase) + 1 + (groupid_bit_offset >> 6);
+  *ptr |= (static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
+
+  hashes_[slot_id] = hash;
+}
+
+// Find method is the continuation of processing from early_filter.
+// Its input consists of hash values and the output of early_filter.
+// It updates match bit-vector, clearing it from any false positives
+// that might have been left by early_filter.
+// It also outputs group ids, which are needed to be able to execute
+// key comparisons. The caller may discard group ids if only the
+// match flag is of interest.
+//
+void SwissTable::find(const int num_keys, const uint32_t* hashes,
+                      uint8_t* inout_match_bitvector, const uint8_t* local_slots,
+                      uint32_t* out_group_ids) const {
+  // Temporary selection vector.
+  // It will hold ids of keys for which we do not know yet
+  // if they have a match in hash table or not.
+  //
+  // Initially the set of these keys is represented by input
+  // match bit-vector. Eventually we switch from this bit-vector
+  // to array of ids.
+  //
+  ARROW_DCHECK(num_keys <= (1 << log_minibatch_));
+  auto ids_buf = util::TempVectorHolder<uint16_t>(temp_stack_, num_keys);
+  uint16_t* ids = ids_buf.mutable_data();
+  int num_ids;
+
+  int64_t num_matches =
+      arrow::internal::CountSetBits(inout_match_bitvector, /*offset=*/0, num_keys);
+
+  // If there is a high density of selected input rows
+  // (majority of them are present in the selection),
+  // we may run some computation on all of the input rows ignoring
+  // selection and then filter the output of this computation
+  // (pre-filtering vs post-filtering).
+  //
+  bool visit_all = num_matches > 0 && num_matches > 3 * num_keys / 4;
+  if (visit_all) {
+    extract_group_ids(num_keys, nullptr, hashes, local_slots, out_group_ids);
+    run_comparisons(num_keys, nullptr, inout_match_bitvector, out_group_ids, &num_ids,
+                    ids);
+  } else {
+    util::BitUtil::bits_to_indexes(1, hardware_flags_, num_keys, inout_match_bitvector,
+                                   &num_ids, ids);
+    extract_group_ids(num_ids, ids, hashes, local_slots, out_group_ids);
+    run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids);
+  }
+
+  if (num_ids == 0) {
+    return;
+  }
+
+  auto slot_ids_buf = util::TempVectorHolder<uint32_t>(temp_stack_, num_ids);
+  uint32_t* slot_ids = slot_ids_buf.mutable_data();
+  init_slot_ids(num_ids, ids, hashes, local_slots, inout_match_bitvector, slot_ids);
+
+  while (num_ids > 0) {
+    int num_ids_last_iteration = num_ids;
+    num_ids = 0;
+    for (int i = 0; i < num_ids_last_iteration; ++i) {
+      int id = ids[i];
+      uint32_t next_slot_id;
+      bool match_found = find_next_stamp_match(hashes[id], slot_ids[id], &next_slot_id,
+                                               &(out_group_ids[id]));
+      slot_ids[id] = next_slot_id;
+      // If next match was not found then clear match bit in a bit vector
+      if (!match_found) {
+        ::arrow::BitUtil::ClearBit(inout_match_bitvector, id);
+      } else {
+        ids[num_ids++] = id;
+      }
+    }
+
+    run_comparisons(num_ids, ids, nullptr, out_group_ids, &num_ids, ids);
+  }
+}
+
+// Slow processing of input keys in the most generic case.
+// Handles inserting new keys.
+// Pre-existing keys will be handled correctly, although the intended use is for this
+// call to follow a call to find() method, which would only pass on new keys that were
+// not present in the hash table.
+//
+// Run a single round of slot search - comparison or insert - filter unprocessed.
 // Update selection vector to reflect which items have been processed.
 // Ids in selection vector do not have to be sorted.
 //
-Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected,
-                            uint16_t* inout_selection, bool* out_need_resize,
-                            uint32_t* out_group_ids, uint32_t* inout_next_slot_ids) {
+Status SwissTable::map_new_keys_helper(const uint32_t* hashes,
+                                       uint32_t* inout_num_selected,
+                                       uint16_t* inout_selection, bool* out_need_resize,
+                                       uint32_t* out_group_ids,
+                                       uint32_t* inout_next_slot_ids) {
   auto num_groups_limit = num_groups_for_resize();
   ARROW_DCHECK(num_inserted_ < num_groups_limit);
@@ -223,198 +558,115 @@
   //
   ARROW_DCHECK(*inout_num_selected <= static_cast<uint32_t>(1 << log_minibatch_));
 
-  // We will split input row ids into three categories:
-  // - needing to visit next block [0]
-  // - needing comparison [1]
-  // - inserted [2]
-  //
-  auto ids_inserted_buf =
-      util::TempVectorHolder<uint16_t>(temp_stack_, *inout_num_selected);
-  auto ids_for_comparison_buf =
-      util::TempVectorHolder<uint16_t>(temp_stack_, *inout_num_selected);
-  constexpr int category_nomatch = 0;
-  constexpr int category_cmp = 1;
-  constexpr int category_inserted = 2;
-  int num_ids[3];
-  num_ids[0] = num_ids[1] = num_ids[2] = 0;
-  uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(),
-                   ids_inserted_buf.mutable_data()};
-  auto push_id = [&num_ids, &ids](int category, int id) {
-    util::SafeStore(&ids[category][num_ids[category]++], static_cast<uint16_t>(id));
-  };
-
-  uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
-  uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1;
-  constexpr uint64_t stamp_mask = 0x7f;
-  uint64_t num_block_bytes = (8 + num_groupid_bits);
+  size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t);
+  auto match_bitvector_buf = util::TempVectorHolder<uint8_t>(
+      temp_stack_, static_cast<uint32_t>(num_bytes_for_bits));
+  uint8_t* match_bitvector = match_bitvector_buf.mutable_data();
+  memset(match_bitvector, 0xff, num_bytes_for_bits);
 
+  // Check the alignment of the input selection vector
+  ARROW_DCHECK((reinterpret_cast<uint64_t>(inout_selection) & 1) == 0);
+
+  uint32_t num_inserted_new = 0;
   uint32_t num_processed;
-  for (num_processed = 0;
-       // Second condition in for loop:
-       // We need to break processing and have the caller of this function
-       // resize hash table if we reach the limit of the number of groups present.
-       num_processed < *inout_num_selected &&
-       num_inserted_ + num_ids[category_inserted] < num_groups_limit;
-       ++num_processed) {
+  for (num_processed = 0; num_processed < *inout_num_selected; ++num_processed) {
     // row id in original batch
-    int id = util::SafeLoad(&inout_selection[num_processed]);
-
-    uint64_t slot_id = wrap_global_slot_id(util::SafeLoad(&inout_next_slot_ids[id]));
-    uint64_t block_id = slot_id >> 3;
-    uint32_t hash = hashes[id];
-    uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
-    uint64_t block = *reinterpret_cast<uint64_t*>(blockbase);
-    uint64_t stamp = (hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask;
-    int start_slot = (slot_id & 7);
-
-    bool isempty = (blockbase[7 - start_slot] == 0x80);
-    if (isempty) {
+    int id = inout_selection[num_processed];
+    bool match_found =
+        find_next_stamp_match(hashes[id], inout_next_slot_ids[id],
+                              &inout_next_slot_ids[id], &out_group_ids[id]);
+    if (!match_found) {
       // If we reach the empty slot we insert key for new group
+      //
+      out_group_ids[id] = num_inserted_ + num_inserted_new;
+      insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]);
+      ::arrow::BitUtil::ClearBit(match_bitvector, num_processed);
+      ++num_inserted_new;
 
-      blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
-      uint32_t group_id = num_inserted_ + num_ids[category_inserted];
-      int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);
-
-      // We assume here that the number of bits is rounded up to 8, 16, 32 or 64.
-      // In that case we can insert group id value using aligned 64-bit word access.
-      ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
-                   num_groupid_bits == 32 || num_groupid_bits == 64);
-      uint64_t* ptr =
-          &reinterpret_cast<uint64_t*>(blockbase + 8)[groupid_bit_offset >> 6];
-      util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast<uint64_t>(group_id)
-                                                  << (groupid_bit_offset & 63)));
-
-      hashes_[slot_id] = hash;
-      util::SafeStore(&out_group_ids[id], group_id);
-      push_id(category_inserted, id);
-    } else {
-      // We search for a slot with a matching stamp within a single block.
-      // We append row id to the appropriate sequence of ids based on
-      // whether the match has been found or not.
-
-      int new_match_found;
-      int new_slot;
-      search_block<true>(block, static_cast<int>(stamp), start_slot, &new_slot,
-                         &new_match_found);
-      auto new_groupid =
-          static_cast<uint32_t>(extract_group_id(blockbase, new_slot, groupid_mask));
-      ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]);
-      new_slot =
-          static_cast<int>(next_slot_to_visit(block_id, new_slot, new_match_found));
-      util::SafeStore(&inout_next_slot_ids[id], new_slot);
-      util::SafeStore(&out_group_ids[id], new_groupid);
-      push_id(new_match_found, id);
+      // We need to break processing and have the caller of this function
+      // resize hash table if we reach the limit of the number of groups present.
+      //
+      if (num_inserted_ + num_inserted_new == num_groups_limit) {
+        ++num_processed;
+        break;
+      }
     }
   }
 
+  auto temp_ids_buffer =
+      util::TempVectorHolder<uint16_t>(temp_stack_, *inout_num_selected);
+  uint16_t* temp_ids = temp_ids_buffer.mutable_data();
+  int num_temp_ids = 0;
+
   // Copy keys for newly inserted rows using callback
-  RETURN_NOT_OK(append_impl_(num_ids[category_inserted], ids[category_inserted]));
-  num_inserted_ += num_ids[category_inserted];
+  //
+  util::BitUtil::bits_filter_indexes(0, hardware_flags_, num_processed, match_bitvector,
+                                     inout_selection, &num_temp_ids, temp_ids);
+  ARROW_DCHECK(static_cast<int>(num_inserted_new) == num_temp_ids);
+  RETURN_NOT_OK(append_impl_(num_inserted_new, temp_ids));
+  num_inserted_ += num_inserted_new;
 
   // Evaluate comparisons and append ids of rows that failed it to the non-match set.
-  uint32_t num_not_equal;
-  equal_impl_(num_ids[category_cmp], ids[category_cmp], out_group_ids, &num_not_equal,
-              ids[category_nomatch] + num_ids[category_nomatch]);
-  num_ids[category_nomatch] += num_not_equal;
+  util::BitUtil::bits_filter_indexes(1, hardware_flags_, num_processed, match_bitvector,
+                                     inout_selection, &num_temp_ids, temp_ids);
+  run_comparisons(num_temp_ids, temp_ids, nullptr, out_group_ids, &num_temp_ids,
+                  temp_ids);
+  memcpy(inout_selection, temp_ids, sizeof(uint16_t) * num_temp_ids);
 
   // Append ids of any unprocessed entries if we aborted processing due to the need
   // to resize.
   if (num_processed < *inout_num_selected) {
-    memmove(ids[category_nomatch] + num_ids[category_nomatch],
-            inout_selection + num_processed,
+    memmove(inout_selection + num_temp_ids, inout_selection + num_processed,
            sizeof(uint16_t) * (*inout_num_selected - num_processed));
-    num_ids[category_nomatch] += (*inout_num_selected - num_processed);
   }
+  *inout_num_selected = num_temp_ids + (*inout_num_selected - num_processed);
 
   *out_need_resize = (num_inserted_ == num_groups_limit);
-  *inout_num_selected = num_ids[category_nomatch];
   return Status::OK();
 }
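The rewritten helper leans on `util::BitUtil::bits_filter_indexes` to split processed ids into inserted rows (searched bit 0) and rows still needing comparison (searched bit 1). As used here, its semantics reduce to the following scalar reference (a sketch; the library routine itself is vectorized):

```cpp
#include <cstdint>

void BitsFilterIndexesRef(int bit_to_search, int num_bits, const uint8_t* bits,
                          const uint16_t* input_indexes, int* num_out,
                          uint16_t* out_indexes) {
  int n = 0;
  for (int i = 0; i < num_bits; ++i) {
    int bit = (bits[i / 8] >> (i % 8)) & 1;
    if (bit == bit_to_search) out_indexes[n++] = input_indexes[i];
  }
  *num_out = n;
}
```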
 
-// Use hashes and callbacks to find group ids for already existing keys and
-// to insert and report newly assigned group ids for new keys.
+// Do inserts and find group ids for a set of new keys (with possible duplicates within
+// this set).
 //
-Status SwissTable::map(const int num_keys, const uint32_t* hashes,
-                       uint32_t* out_groupids) {
-  // Temporary buffers have limited size.
-  // Caller is responsible for splitting larger input arrays into smaller chunks.
-  ARROW_DCHECK(num_keys <= (1 << log_minibatch_));
-
-  // Allocate temporary buffers with a lifetime of this function
-  auto match_bitvector_buf = util::TempVectorHolder<uint8_t>(temp_stack_, num_keys);
-  uint8_t* match_bitvector = match_bitvector_buf.mutable_data();
-  auto slot_ids_buf = util::TempVectorHolder<uint32_t>(temp_stack_, num_keys);
-  uint32_t* slot_ids = slot_ids_buf.mutable_data();
-  auto ids_buf = util::TempVectorHolder<uint16_t>(temp_stack_, num_keys);
-  uint16_t* ids = ids_buf.mutable_data();
-  uint32_t num_ids;
+Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes,
+                                uint32_t* group_ids) {
+  if (num_ids == 0) {
+    return Status::OK();
+  }
 
-  // First-pass processing.
-  // Optimistically use simplified lookup involving only a start block to find
-  // a single group id candidate for every input.
-#if defined(ARROW_HAVE_AVX2)
-  if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
-    if (log_blocks_ <= 4) {
-      int tail = num_keys % 32;
-      int delta = num_keys - tail;
-      lookup_1_avx2_x32(num_keys - tail, hashes, match_bitvector, out_groupids, slot_ids);
-      lookup_1_avx2_x8(tail, hashes + delta, match_bitvector + delta / 8,
-                       out_groupids + delta, slot_ids + delta);
-    } else {
-      lookup_1_avx2_x8(num_keys, hashes, match_bitvector, out_groupids, slot_ids);
-    }
-  } else {
-#endif
-    lookup_1<false>(nullptr, num_keys, hashes, match_bitvector, out_groupids, slot_ids);
-#if defined(ARROW_HAVE_AVX2)
+  uint16_t max_id = ids[0];
+  for (uint32_t i = 1; i < num_ids; ++i) {
+    max_id = std::max(max_id, ids[i]);
   }
-#endif
 
-  int64_t num_matches =
-      arrow::internal::CountSetBits(match_bitvector, /*offset=*/0, num_keys);
+  // Temporary buffers have limited size.
+  // Caller is responsible for splitting larger input arrays into smaller chunks.
+  ARROW_DCHECK(static_cast<int>(num_ids) <= (1 << log_minibatch_));
+  ARROW_DCHECK(static_cast<int>(max_id + 1) <= (1 << log_minibatch_));
 
-  // After the first-pass processing count rows with matches (based on stamp comparison)
-  // and decide based on their percentage whether to call dense or sparse comparison
-  // function. Dense comparison means evaluating it for all inputs, even if the matching
-  // stamp was not found. It may be cheaper to evaluate comparison for all inputs if the
-  // extra cost of filtering is higher than the wasted processing of rows with no match.
-  //
-  // Dense comparison can only be used if there is at least one inserted key,
-  // because otherwise there is no key to compare to.
-  //
-  if (num_inserted_ > 0 && num_matches > 0 && num_matches > 3 * num_keys / 4) {
-    // Dense comparisons
-    equal_impl_(num_keys, nullptr, out_groupids, &num_ids, ids);
-  } else {
-    // Sparse comparisons that involve filtering the input set of keys
-    auto ids_cmp_buf = util::TempVectorHolder<uint16_t>(temp_stack_, num_keys);
-    uint16_t* ids_cmp = ids_cmp_buf.mutable_data();
-    int num_ids_result;
-    util::BitUtil::bits_split_indexes(hardware_flags_, num_keys, match_bitvector,
-                                      &num_ids_result, ids, ids_cmp);
-    num_ids = num_ids_result;
-    uint32_t num_not_equal;
-    equal_impl_(num_keys - num_ids, ids_cmp, out_groupids, &num_not_equal,
-                ids + num_ids);
-    num_ids += num_not_equal;
-  }
+  // Allocate temporary buffers for slot ids and initialize them
+  auto slot_ids_buf = util::TempVectorHolder<uint32_t>(temp_stack_, max_id + 1);
+  uint32_t* slot_ids = slot_ids_buf.mutable_data();
+  init_slot_ids_for_new_keys(num_ids, ids, hashes, slot_ids);
 
   do {
     // A single round of slow-pass (robust) lookup or insert.
-    // A single round ends with either a single comparison verifying the match candidate
-    // or inserting a new key. A single round of slow-pass may return early if we reach
-    // the limit of the number of groups due to inserts of new keys. In that case we need
-    // to resize and recalculating starting global slot ids for new bigger hash table.
+    // A single round ends with either a single comparison verifying the match
+    // candidate or inserting a new key. A single round of slow-pass may return early
+    // if we reach the limit of the number of groups due to inserts of new keys. In
+    // that case we need to resize and recalculate starting global slot ids for the
+    // new bigger hash table.
     bool out_of_capacity;
-    RETURN_NOT_OK(
-        lookup_2(hashes, &num_ids, ids, &out_of_capacity, out_groupids, slot_ids));
+    RETURN_NOT_OK(map_new_keys_helper(hashes, &num_ids, ids, &out_of_capacity, group_ids,
+                                      slot_ids));
     if (out_of_capacity) {
       RETURN_NOT_OK(grow_double());
       // Reset start slot ids for still unprocessed input keys.
       //
       for (uint32_t i = 0; i < num_ids; ++i) {
         // First slot in the new starting block
-        const int16_t id = util::SafeLoad(&ids[i]);
-        util::SafeStore(&slot_ids[id], (hashes[id] >> (bits_hash_ - log_blocks_)) * 8);
+        const int16_t id = ids[i];
+        slot_ids[id] = (hashes[id] >> (bits_hash_ - log_blocks_)) * 8;
       }
     }
   } while (num_ids > 0);
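With `map()` gone, callers compose the three new entry points themselves. A condensed usage sketch (the function name is made up; scratch buffers are plain pointers here, whereas the real caller in hash_aggregate.cc further below takes them from a `TempVectorStack`):

```cpp
Status LookupOrInsert(SwissTable* table, int num_keys, const uint32_t* hashes,
                      uint8_t* match_bitvector, uint8_t* local_slots, uint16_t* ids,
                      uint32_t* group_ids, int64_t hardware_flags) {
  // Optimistic pass; may leave false positives in the match bit-vector.
  table->early_filter(num_keys, hashes, match_bitvector, local_slots);
  // Verify candidates, resolving false positives and filling in group ids.
  table->find(num_keys, hashes, match_bitvector, local_slots, group_ids);
  // Whatever is still unmatched is a new key (possibly duplicated within the batch).
  int num_ids;
  util::BitUtil::bits_to_indexes(0, hardware_flags, num_keys, match_bitvector, &num_ids,
                                 ids);
  return table->map_new_keys(static_cast<uint32_t>(num_ids), ids, hashes, group_ids);
}
```

Splitting the API this way lets a caller that only needs membership (for example a join probe) stop after `find`, paying for inserts only when it actually maps new keys.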
diff --git a/cpp/src/arrow/compute/exec/key_map.h b/cpp/src/arrow/compute/exec/key_map.h
index 7ee28b82898..cf539f4a99b 100644
--- a/cpp/src/arrow/compute/exec/key_map.h
+++ b/cpp/src/arrow/compute/exec/key_map.h
@@ -40,9 +40,17 @@ class SwissTable {
   Status init(int64_t hardware_flags, MemoryPool* pool,
               util::TempVectorStack* temp_stack, int log_minibatch,
               EqualImpl equal_impl, AppendImpl append_impl);
+
   void cleanup();
 
-  Status map(const int ckeys, const uint32_t* hashes, uint32_t* outgroupids);
+  void early_filter(const int num_keys, const uint32_t* hashes,
+                    uint8_t* out_match_bitvector, uint8_t* out_local_slots) const;
+
+  void find(const int num_keys, const uint32_t* hashes, uint8_t* inout_match_bitvector,
+            const uint8_t* local_slots, uint32_t* out_group_ids) const;
+
+  Status map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* hashes,
+                      uint32_t* group_ids);
 
  private:
   // Lookup helpers
@@ -55,6 +63,9 @@ class SwissTable {
   /// b) first empty slot is encountered,
   /// c) we reach the end of the block.
   ///
+  /// Optionally an index of the first slot to start the search from can be specified.
+  /// In this case slots before it will be ignored.
+  ///
   /// \param[in] block 8 byte block of hash table
   /// \param[in] stamp 7 bits of hash used as a stamp
   /// \param[in] start_slot Index of the first slot in the block to start search from. We
@@ -63,55 +74,78 @@ class SwissTable {
   ///            variant.)
   /// \param[out] out_slot index corresponding to the discovered position of interest (8
   /// represents end of block).
-  /// \param[out] out_match_found an integer flag (0 or 1) indicating if we found a
-  /// matching stamp.
+  /// \param[out] out_match_found an integer flag (0 or 1) indicating if we reached an
+  /// empty slot (0) or not (1). Therefore 1 can mean that either actual match was found
+  /// (case a) above) or we reached the end of full block (case c) above).
+  ///
   template <bool use_start_slot>
   inline void search_block(uint64_t block, int stamp, int start_slot, int* out_slot,
-                           int* out_match_found);
+                           int* out_match_found) const;
 
   /// \brief Extract group id for a given slot in a given block.
   ///
-  /// Group ids follow in memory after 64-bit block data.
-  /// Maximum number of groups inserted is equal to the number
-  /// of all slots in all blocks, which is 8 * the number of blocks.
-  /// Group ids are bit packed using that maximum to determine the necessary number of
-  /// bits.
   inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot,
                                    uint64_t group_id_mask) const;
+
+  void extract_group_ids(const int num_keys, const uint16_t* optional_selection,
+                         const uint32_t* hashes, const uint8_t* local_slots,
+                         uint32_t* out_group_ids) const;
 
-  inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, int match_found);
+  template <typename T, bool use_selection>
+  void extract_group_ids_imp(const int num_keys, const uint16_t* selection,
+                             const uint32_t* hashes, const uint8_t* local_slots,
+                             uint32_t* out_group_ids, int elements_offset,
+                             int element_multiplier) const;
 
-  inline void insert(uint8_t* block_base, uint64_t slot_id, uint32_t hash, uint8_t stamp,
-                     uint32_t group_id);
+  inline uint64_t next_slot_to_visit(uint64_t block_index, int slot,
+                                     int match_found) const;
 
   inline uint64_t num_groups_for_resize() const;
 
-  inline uint64_t wrap_global_slot_id(uint64_t global_slot_id);
+  inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const;
+
+  void init_slot_ids(const int num_keys, const uint16_t* selection,
+                     const uint32_t* hashes, const uint8_t* local_slots,
+                     const uint8_t* match_bitvector, uint32_t* out_slot_ids) const;
 
-  // First hash table access
-  // Find first match in the start block if exists.
-  // Possible cases:
-  // 1. Stamp match in a block
-  // 2. No stamp match in a block, no empty buckets in a block
-  // 3. No stamp match in a block, empty buckets in a block
+  void init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* ids,
+                                  const uint32_t* hashes, uint32_t* slot_ids) const;
+
+  // Quickly filter out keys that have no matches based only on hash value and the
+  // corresponding starting 64-bit block of slot status bytes. May return false
+  // positives.
   //
-  template <bool use_selection>
-  void lookup_1(const uint16_t* selection, const int num_keys, const uint32_t* hashes,
-                uint8_t* out_match_bitvector, uint32_t* out_group_ids,
-                uint32_t* out_slot_ids);
+  void early_filter_imp(const int num_keys, const uint32_t* hashes,
+                        uint8_t* out_match_bitvector, uint8_t* out_local_slots) const;
 
 #if defined(ARROW_HAVE_AVX2)
-  void lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes,
-                        uint8_t* out_match_bitvector, uint32_t* out_group_ids,
-                        uint32_t* out_next_slot_ids);
-  void lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes,
-                         uint8_t* out_match_bitvector, uint32_t* out_group_ids,
-                         uint32_t* out_next_slot_ids);
+  void early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
+                                uint8_t* out_match_bitvector,
+                                uint8_t* out_local_slots) const;
+  void early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
+                                 uint8_t* out_match_bitvector,
+                                 uint8_t* out_local_slots) const;
+  void extract_group_ids_avx2(const int num_keys, const uint32_t* hashes,
+                              const uint8_t* local_slots, uint32_t* out_group_ids,
+                              int byte_offset, int byte_multiplier, int byte_size) const;
 #endif
 
-  // Completing hash table lookup post first access
-  Status lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected,
-                  uint16_t* inout_selection, bool* out_need_resize,
-                  uint32_t* out_group_ids, uint32_t* out_next_slot_ids);
+  void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids,
+                       const uint8_t* optional_selection_bitvector,
+                       const uint32_t* groupids, int* out_num_not_equal,
+                       uint16_t* out_not_equal_selection) const;
+
+  inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id,
+                                    uint32_t* out_slot_id, uint32_t* out_group_id) const;
+
+  inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);
+
+  // Slow processing of input keys in the most generic case.
+  // Handles inserting new keys.
+  // Pre-existing keys will be handled correctly, although the intended use is for this
+  // call to follow a call to find() method, which would only pass on new keys that were
+  // not present in the hash table.
+  //
+  Status map_new_keys_helper(const uint32_t* hashes, uint32_t* inout_num_selected,
+                             uint16_t* inout_selection, bool* out_need_resize,
+                             uint32_t* out_group_ids, uint32_t* out_next_slot_ids);
 
   // Resize small hash tables when 50% full (up to 8KB).
   // Resize large hash tables when 75% full.
diff --git a/cpp/src/arrow/compute/exec/key_map_avx2.cc b/cpp/src/arrow/compute/exec/key_map_avx2.cc
index a2efb4d1bb9..2fca6bf6c10 100644
--- a/cpp/src/arrow/compute/exec/key_map_avx2.cc
+++ b/cpp/src/arrow/compute/exec/key_map_avx2.cc
@@ -36,14 +36,13 @@ namespace compute {
 // This is more or less translation of equivalent scalar code, adjusted for a different
 // instruction set (e.g. missing leading zero count instruction).
 //
-void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes,
-                                  uint8_t* out_match_bitvector, uint32_t* out_group_ids,
-                                  uint32_t* out_next_slot_ids) {
+void SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
+                                          uint8_t* out_match_bitvector,
+                                          uint8_t* out_local_slots) const {
   // Number of inputs processed together in a loop
   constexpr int unroll = 8;
 
   const int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_);
-  uint32_t group_id_mask = ~static_cast<uint32_t>(0) >> (32 - num_group_id_bits);
 
   const __m256i* vhash_ptr = reinterpret_cast<const __m256i*>(hashes);
   const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1);
@@ -85,6 +84,15 @@ void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes,
         vstamp_B, _mm256_or_si256(vbyte_repeat_pattern, vblock_highbits_B));
     __m256i vmatches_A = _mm256_cmpeq_epi8(vblock_A, vstamp_A);
     __m256i vmatches_B = _mm256_cmpeq_epi8(vblock_B, vstamp_B);
+
+    // In case when there are no matches in slots and the block is full (no empty slots),
+    // pretend that there is a match in the last slot.
+    //
+    vmatches_A = _mm256_or_si256(
+        vmatches_A, _mm256_andnot_si256(vblock_highbits_A, _mm256_set1_epi64x(0xff)));
+    vmatches_B = _mm256_or_si256(
+        vmatches_B, _mm256_andnot_si256(vblock_highbits_B, _mm256_set1_epi64x(0xff)));
+
     __m256i vmatch_found = _mm256_andnot_si256(
         _mm256_blend_epi32(_mm256_cmpeq_epi64(vmatches_A, _mm256_setzero_si256()),
                            _mm256_cmpeq_epi64(vmatches_B, _mm256_setzero_si256()),
@@ -106,46 +114,30 @@ void SwissTable::lookup_1_avx2_x8(const int num_hashes, const uint32_t* hashes,
     //
     // Emulating lzcnt in lowest bytes of 32-bit elements
     __m256i vgt = _mm256_cmpgt_epi32(_mm256_set1_epi32(16), vmatches);
-    __m256i vnext_slot_id =
+    __m256i vlocal_slot =
         _mm256_blendv_epi8(_mm256_srli_epi32(vmatches, 4),
                            _mm256_and_si256(vmatches, _mm256_set1_epi32(0x0f)), vgt);
-    vnext_slot_id = _mm256_shuffle_epi8(
+    vlocal_slot = _mm256_shuffle_epi8(
         _mm256_setr_epi8(4, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 2, 1,
                          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0),
-        vnext_slot_id);
-    vnext_slot_id =
-        _mm256_add_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(0xff)),
-                         _mm256_and_si256(vgt, _mm256_set1_epi32(4)));
-
-    // Lookup group ids
-    //
-    __m256i vgroupid_bit_offset =
-        _mm256_mullo_epi32(_mm256_and_si256(vnext_slot_id, _mm256_set1_epi32(7)),
-                           _mm256_set1_epi32(num_group_id_bits));
-
-    // This only works for up to 25 bits per group id, since it uses 32-bit gather
-    // TODO: make sure this will never get called when there are more than 2^25 groups.
-    __m256i vgroupid =
-        _mm256_add_epi32(_mm256_srli_epi32(vgroupid_bit_offset, 3),
-                         _mm256_add_epi32(vblock_offset, _mm256_set1_epi32(8)));
-    vgroupid = _mm256_i32gather_epi32(reinterpret_cast<const int*>(blocks_), vgroupid, 1);
-    vgroupid = _mm256_srlv_epi32(
-        vgroupid, _mm256_and_si256(vgroupid_bit_offset, _mm256_set1_epi32(7)));
-    vgroupid = _mm256_and_si256(vgroupid, _mm256_set1_epi32(group_id_mask));
+        vlocal_slot);
+    vlocal_slot =
+        _mm256_add_epi32(_mm256_and_si256(vlocal_slot, _mm256_set1_epi32(0xff)),
+                         _mm256_and_si256(vgt, _mm256_set1_epi32(4)));
 
     // Convert slot id relative to the block to slot id relative to the beginning of the
     // table
     //
-    vnext_slot_id = _mm256_add_epi32(
-        _mm256_add_epi32(vnext_slot_id,
-                         _mm256_and_si256(vmatch_found, _mm256_set1_epi32(1))),
-        _mm256_slli_epi32(vblock_id, 3));
+    uint64_t local_slot = _mm256_extract_epi64(
+        _mm256_permutevar8x32_epi32(
+            _mm256_shuffle_epi8(
+                vlocal_slot, _mm256_setr_epi32(0x0c080400, 0, 0, 0, 0x0c080400, 0, 0, 0)),
+            _mm256_setr_epi32(0, 4, 0, 0, 0, 0, 0, 0)),
+        0);
+    (reinterpret_cast<uint64_t*>(out_local_slots))[i] = local_slot;
 
     // Convert match found vector from 32-bit elements to bit vector
     out_match_bitvector[i] = _pext_u32(_mm256_movemask_epi8(vmatch_found),
                                        0x11111111);  // 0b00010001 repeated 4x
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, vgroupid);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + i,
-                        vnext_slot_id);
   }
 }
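The "emulating lzcnt" arithmetic above reproduces, eight keys at a time, the scalar slot computation in `search_block`. For reference, the scalar form looks like this (a sketch; `__builtin_clzll` stands in for Arrow's `CountLeadingZeros`):

```cpp
#include <cstdint>

// Index of the first interesting slot, scanning from the most significant status
// byte (slot 0) downward; 8 means nothing was found in this block.
int FirstInterestingSlot(uint64_t matches, uint64_t block_high_bits) {
  uint64_t combined = matches | block_high_bits;  // stamp matches or empty slots
  return combined == 0 ? 8 : static_cast<int>(__builtin_clzll(combined) >> 3);
}
```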
@@ -214,9 +206,9 @@ inline void split_bytes_avx2(__m256i word0, __m256i word1, __m256i word2, __m256
 // using a different method.
 // TODO: Explain the idea behind storing arrays in SIMD registers.
 // Explain why it is faster with SIMD than using memory loads.
-void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes,
-                                   uint8_t* out_match_bitvector, uint32_t* out_group_ids,
-                                   uint32_t* out_next_slot_ids) {
+void SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
+                                           uint8_t* out_match_bitvector,
+                                           uint8_t* out_local_slots) const {
   constexpr int unroll = 32;
 
   // There is a limit on the number of input blocks,
@@ -227,8 +219,6 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes,
   // table. We put them in the same order.
   __m256i vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3, vblock_byte4,
       vblock_byte5, vblock_byte6, vblock_byte7;
-  __m256i vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4,
-      vgroupid_byte5, vgroupid_byte6, vgroupid_byte7;
 
   // What we output if there is no match in the block
   __m256i vslot_empty_or_end;
 
   constexpr uint32_t k4ByteSequence_1_5_9_13 = 0x0d090501;
   constexpr uint32_t k4ByteSequence_2_6_10_14 = 0x0e0a0602;
   constexpr uint32_t k4ByteSequence_3_7_11_15 = 0x0f0b0703;
-  constexpr uint64_t kEachByteIs1 = 0x0101010101010101ULL;
   constexpr uint64_t kByteSequence7DownTo0 = 0x0001020304050607ULL;
   constexpr uint64_t kByteSequence15DownTo8 = 0x08090A0B0C0D0E0FULL;
 
   // Bit unpack group ids into 1B.
   // Assemble the sequence of block bytes.
   uint64_t block_bytes[16];
-  uint64_t groupid_bytes[16];
   const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
-  uint64_t bit_unpack_mask = ((1 << num_groupid_bits) - 1) * kEachByteIs1;
   for (int i = 0; i < (1 << log_blocks_); ++i) {
-    uint64_t in_groupids =
-        *reinterpret_cast<const uint64_t*>(blocks_ + (8 + num_groupid_bits) * i + 8);
     uint64_t in_blockbytes =
         *reinterpret_cast<const uint64_t*>(blocks_ + (8 + num_groupid_bits) * i);
-    groupid_bytes[i] = _pdep_u64(in_groupids, bit_unpack_mask);
     block_bytes[i] = in_blockbytes;
   }
@@ -275,18 +259,11 @@ void SwissTable::lookup_1_avx2_x32(const int num_hashes, const uint32_t* hashes,
     split_bytes_avx2(vblock_words0, vblock_words1, vblock_words2, vblock_words3,
                      vblock_byte0, vblock_byte1, vblock_byte2, vblock_byte3,
                      vblock_byte4, vblock_byte5, vblock_byte6, vblock_byte7);
-    split_bytes_avx2(
-        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(groupid_bytes) + 0),
-        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(groupid_bytes) + 1),
-        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(groupid_bytes) + 2),
-        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(groupid_bytes) + 3),
-        vgroupid_byte0, vgroupid_byte1, vgroupid_byte2, vgroupid_byte3, vgroupid_byte4,
-        vgroupid_byte5, vgroupid_byte6, vgroupid_byte7);
 
   // Calculate the slot to output when there is no match in a block.
-  // It will be the index of the first empty slot or 8 (the number of slots in block)
+  // It will be the index of the first empty slot or 7 (the index of the last slot in block)
   // if there are no empty slots.
-  vslot_empty_or_end = _mm256_set1_epi8(8);
+  vslot_empty_or_end = _mm256_set1_epi8(7);
   {
     __m256i vis_empty;
 #define CMP(VBLOCKBYTE, BYTENUM) \
     CMP(vblock_byte0, 0);
 #undef CMP
   }
+  __m256i vblock_is_full = _mm256_andnot_si256(
+      _mm256_cmpeq_epi8(vblock_byte7, _mm256_set1_epi8(static_cast<char>(0x80))),
+      _mm256_set1_epi8(static_cast<char>(0xff)));
 
   const int block_id_mask = (1 << log_blocks_) - 1;
@@ -339,29 +319,28 @@
     __m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));
 
     // Visit all block bytes in reverse order (overwriting data on multiple matches)
-    __m256i vmatch_found = _mm256_setzero_si256();
+    //
+    // Always set match found to true for full blocks.
+    //
+    __m256i vmatch_found = _mm256_shuffle_epi8(vblock_is_full, vblock_id);
     __m256i vslot_id = _mm256_shuffle_epi8(vslot_empty_or_end, vblock_id);
-    __m256i vgroup_id = _mm256_setzero_si256();
-#define CMP(VBLOCK_BYTE, VGROUPID_BYTE, BYTENUM)                                         \
-  {                                                                                      \
-    __m256i vcmp =                                                                       \
-        _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp);          \
-    vmatch_found = _mm256_or_si256(vmatch_found, vcmp);                                  \
-    vgroup_id = _mm256_blendv_epi8(vgroup_id,                                            \
-                                   _mm256_shuffle_epi8(VGROUPID_BYTE, vblock_id), vcmp); \
-    vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM + 1), vcmp);        \
-  }
+#define CMP(VBLOCK_BYTE, BYTENUM)                                               \
+  {                                                                             \
+    __m256i vcmp =                                                              \
+        _mm256_cmpeq_epi8(_mm256_shuffle_epi8(VBLOCK_BYTE, vblock_id), vstamp); \
+    vmatch_found = _mm256_or_si256(vmatch_found, vcmp);                         \
+    vslot_id = _mm256_blendv_epi8(vslot_id, _mm256_set1_epi8(BYTENUM), vcmp);   \
+  }
-    CMP(vblock_byte7, vgroupid_byte7, 7);
-    CMP(vblock_byte6, vgroupid_byte6, 6);
-    CMP(vblock_byte5, vgroupid_byte5, 5);
-    CMP(vblock_byte4, vgroupid_byte4, 4);
-    CMP(vblock_byte3, vgroupid_byte3, 3);
-    CMP(vblock_byte2, vgroupid_byte2, 2);
-    CMP(vblock_byte1, vgroupid_byte1, 1);
-    CMP(vblock_byte0, vgroupid_byte0, 0);
+    CMP(vblock_byte7, 7);
+    CMP(vblock_byte6, 6);
+    CMP(vblock_byte5, 5);
+    CMP(vblock_byte4, 4);
+    CMP(vblock_byte3, 3);
+    CMP(vblock_byte2, 2);
+    CMP(vblock_byte1, 1);
+    CMP(vblock_byte0, 0);
 #undef CMP
-    vslot_id = _mm256_add_epi8(vslot_id, _mm256_slli_epi32(vblock_id, 3));
 
     // So far the output is in the order: [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...]
     vmatch_found = _mm256_shuffle_epi8(
         vmatch_found,
@@ -374,30 +353,58 @@
     vmatch_found = _mm256_permutevar8x32_epi32(vmatch_found,
                                                _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
 
+    // Repeat the same permutation for slot ids
+    vslot_id = _mm256_shuffle_epi8(
+        vslot_id, _mm256_setr_epi32(k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+                                    k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15,
+                                    k4ByteSequence_0_4_8_12, k4ByteSequence_1_5_9_13,
+                                    k4ByteSequence_2_6_10_14, k4ByteSequence_3_7_11_15));
+    vslot_id =
+        _mm256_permutevar8x32_epi32(vslot_id, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_local_slots) + i, vslot_id);
+
     reinterpret_cast<uint32_t*>(out_match_bitvector)[i] =
        _mm256_movemask_epi8(vmatch_found);
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 0,
-                        _mm256_and_si256(vgroup_id, _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 1,
-        _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 8), _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 2,
-        _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 16), _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_group_ids) + 4 * i + 3,
-        _mm256_and_si256(_mm256_srli_epi32(vgroup_id, 24), _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 0,
-                        _mm256_and_si256(vslot_id, _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 1,
-        _mm256_and_si256(_mm256_srli_epi32(vslot_id, 8), _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 2,
-        _mm256_and_si256(_mm256_srli_epi32(vslot_id, 16), _mm256_set1_epi32(0xff)));
-    _mm256_storeu_si256(
-        reinterpret_cast<__m256i*>(out_next_slot_ids) + 4 * i + 3,
-        _mm256_and_si256(_mm256_srli_epi32(vslot_id, 24), _mm256_set1_epi32(0xff)));
   }
 }
 
+void SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes,
+                                        const uint8_t* local_slots,
+                                        uint32_t* out_group_ids, int byte_offset,
+                                        int byte_multiplier, int byte_size) const {
+  ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4);
+  uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 0xFFFF : 0xFFFFFFFF;
+  auto elements = reinterpret_cast<const int*>(blocks_ + byte_offset);
+  constexpr int unroll = 8;
+  if (log_blocks_ == 0) {
+    ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16);
+    __m256i block_group_ids =
+        _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(blocks_)[1]);
+    for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) {
+      __m256i local_slot =
+          _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
+      __m256i group_id = _mm256_shuffle_epi8(block_group_ids, local_slot);
+      group_id = _mm256_shuffle_epi8(
+          group_id, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
+                                      0x80808004, 0x80808005, 0x80808006, 0x80808007));
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
+    }
+  } else {
+    for (int i = 0; i < (num_keys + unroll - 1) / unroll; ++i) {
+      __m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
+      __m256i local_slot =
+          _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
+      local_slot = _mm256_shuffle_epi8(
+          local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
+                                        0x80808004, 0x80808005, 0x80808006, 0x80808007));
+      local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size));
+      __m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_));
+      pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier));
+      pos = _mm256_add_epi32(pos, local_slot);
+      __m256i group_id = _mm256_i32gather_epi32(elements, pos, 1);
+      group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask));
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
+    }
+  }
+}
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 843d62911a7..3c9e3f43a48 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -600,9 +600,25 @@ struct GrouperFastImpl : Grouper {
                                 minibatch_hashes_.data());
 
       // Map
-      RETURN_NOT_OK(
-          map_.map(batch_size_next, minibatch_hashes_.data(),
-                   reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row));
+      auto match_bitvector =
+          util::TempVectorHolder<uint8_t>(&temp_stack_, (batch_size_next + 7) / 8);
+      {
+        auto local_slots =
+            util::TempVectorHolder<uint8_t>(&temp_stack_, batch_size_next);
+        map_.early_filter(batch_size_next, minibatch_hashes_.data(),
+                          match_bitvector.mutable_data(), local_slots.mutable_data());
+        map_.find(batch_size_next, minibatch_hashes_.data(),
+                  match_bitvector.mutable_data(), local_slots.mutable_data(),
+                  reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row);
+      }
+      auto ids = util::TempVectorHolder<uint16_t>(&temp_stack_, batch_size_next);
+      int num_ids;
+      util::BitUtil::bits_to_indexes(0, encode_ctx_.hardware_flags, batch_size_next,
+                                     match_bitvector.mutable_data(), &num_ids,
+                                     ids.mutable_data());
+
+      RETURN_NOT_OK(map_.map_new_keys(
+          num_ids, ids.mutable_data(), minibatch_hashes_.data(),
+          reinterpret_cast<uint32_t*>(group_ids->mutable_data()) + start_row));
 
       start_row += batch_size_next;
 
@@ -2188,7 +2204,8 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
   }
 
   Status Consume(const ExecBatch& batch) override {
-    return grouper_->Consume(batch).status();
+    ARROW_ASSIGN_OR_RAISE(std::ignore, grouper_->Consume(batch));
+    return Status::OK();
   }
 
   Status Merge(GroupedAggregator&& raw_other,
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index f90e71bf670..369d1d8066f 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -57,6 +57,7 @@
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/vector.h"
 
+using testing::Eq;
 using testing::HasSubstr;
 
 namespace arrow {
@@ -319,6 +320,14 @@ struct TestGrouper {
     AssertEquivalentIds(expected, ids);
   }
 
+  void ExpectUniques(const ExecBatch& uniques) {
+    EXPECT_THAT(grouper_->GetUniques(), ResultWith(Eq(uniques)));
+  }
+
+  void ExpectUniques(const std::string& uniques_json) {
+    ExpectUniques(ExecBatchFromJSON(descrs_, uniques_json));
+  }
+
   void AssertEquivalentIds(const Datum& expected, const Datum& actual) {
     auto left = expected.make_array();
     auto right = actual.make_array();
@@ -437,13 +446,17 @@ TEST(Grouper, NumericKey) {
     TestGrouper g({ty});
 
     g.ExpectConsume("[[3], [3]]", "[0, 0]");
+    g.ExpectUniques("[[3]]");
 
     g.ExpectConsume("[[3], [3]]", "[0, 0]");
+    g.ExpectUniques("[[3]]");
 
-    g.ExpectConsume("[[27], [81]]", "[1, 2]");
+    g.ExpectConsume("[[27], [81], [81]]", "[1, 2, 2]");
+    g.ExpectUniques("[[3], [27], [81]]");
 
     g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]",
                     "[0, 1, 0, 1, 3, 2, 1, 2]");
+    g.ExpectUniques("[[3], [27], [81], [null]]");
   }
 }
 
@@ -1900,7 +1913,7 @@ TEST(GroupBy, Distinct) {
   CountOptions all(CountOptions::ALL);
   CountOptions only_valid(CountOptions::ONLY_VALID);
   CountOptions only_null(CountOptions::ONLY_NULL);
-  for (bool use_threads : {true, false}) {
+  for (bool use_threads : {false}) {
     SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
 
     auto table =