Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cpp/src/arrow/compute/exec/key_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cstdint>

#include "arrow/compute/exec/util.h"
#include "arrow/util/ubsan.h"

namespace arrow {
namespace compute {
Expand Down Expand Up @@ -170,19 +171,19 @@ void KeyCompare::CompareFixedLengthImp(uint32_t num_rows_already_processed,
//
if (num_64bit_words == 0) {
for (; istripe < num_loops_less_one; ++istripe) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
}
} else if (num_64bit_words == 2) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
++istripe;
}

uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (tail_mask & (key_left ^ key_right));

int result = (result_or == 0 ? 0xff : 0);
Expand Down Expand Up @@ -246,16 +247,16 @@ void KeyCompare::CompareVaryingLengthImp(
int32_t istripe;
// length can be zero
for (istripe = 0; istripe < (static_cast<int32_t>(length) + 7) / 8 - 1; ++istripe) {
uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (key_left ^ key_right);
}

uint32_t length_remaining = length - static_cast<uint32_t>(istripe) * 8;
uint64_t tail_mask = tail_masks[length_remaining];

uint64_t key_left = key_left_ptr[istripe];
uint64_t key_right = key_right_ptr[istripe];
uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
result_or |= (tail_mask & (key_left ^ key_right));

int result = (result_or == 0 ? 0xff : 0);
Expand Down
71 changes: 39 additions & 32 deletions cpp/src/arrow/compute/exec/key_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"

namespace arrow {

Expand Down Expand Up @@ -153,7 +154,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
for (int i = 0; i < num_keys; ++i) {
int id;
if (use_selection) {
id = selection[i];
id = util::SafeLoad(&selection[i]);
} else {
id = i;
}
Expand All @@ -168,7 +169,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
uint32_t num_block_bytes = num_groupid_bits + 8;
const uint8_t* blockbase = reinterpret_cast<const uint8_t*>(blocks_) +
static_cast<uint64_t>(iblock) * num_block_bytes;
uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
uint64_t block = util::SafeLoadAs<uint64_t>(blockbase);

// Call helper functions to obtain the output triplet:
// - match (of a stamp) found flag
Expand All @@ -182,8 +183,8 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,
uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found);

out_match_bitvector[id / 8] |= match_found << (id & 7);
out_groupids[id] = static_cast<uint32_t>(groupid);
out_slot_ids[id] = static_cast<uint32_t>(islot);
util::SafeStore(&out_groupids[id], static_cast<uint32_t>(groupid));
util::SafeStore(&out_slot_ids[id], static_cast<uint32_t>(islot));
}
}

Expand Down Expand Up @@ -239,7 +240,7 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
uint16_t* ids[3]{inout_selection, ids_for_comparison_buf.mutable_data(),
ids_inserted_buf.mutable_data()};
auto push_id = [&num_ids, &ids](int category, int id) {
ids[category][num_ids[category]++] = static_cast<uint16_t>(id);
util::SafeStore(&ids[category][num_ids[category]++], static_cast<uint16_t>(id));
};

uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
Expand All @@ -256,9 +257,9 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
num_inserted_ + num_ids[category_inserted] < num_groups_limit;
++num_processed) {
// row id in original batch
int id = inout_selection[num_processed];
int id = util::SafeLoad(&inout_selection[num_processed]);

uint64_t slot_id = wrap_global_slot_id(inout_next_slot_ids[id]);
uint64_t slot_id = wrap_global_slot_id(util::SafeLoad(&inout_next_slot_ids[id]));
uint64_t block_id = slot_id >> 3;
uint32_t hash = hashes[id];
uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
Expand All @@ -278,11 +279,13 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
// In that case we can insert group id value using aligned 64-bit word access.
ARROW_DCHECK(num_groupid_bits == 8 || num_groupid_bits == 16 ||
num_groupid_bits == 32 || num_groupid_bits == 64);
reinterpret_cast<uint64_t*>(blockbase + 8)[groupid_bit_offset >> 6] |=
(static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
uint64_t* ptr =
&reinterpret_cast<uint64_t*>(blockbase + 8)[groupid_bit_offset >> 6];
util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast<uint64_t>(group_id)
<< (groupid_bit_offset & 63)));

hashes_[slot_id] = hash;
out_group_ids[id] = group_id;
util::SafeStore(&out_group_ids[id], group_id);
push_id(category_inserted, id);
} else {
// We search for a slot with a matching stamp within a single block.
Expand All @@ -298,8 +301,8 @@ Status SwissTable::lookup_2(const uint32_t* hashes, uint32_t* inout_num_selected
ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]);
new_slot =
static_cast<int>(next_slot_to_visit(block_id, new_slot, new_match_found));
inout_next_slot_ids[id] = new_slot;
out_group_ids[id] = new_groupid;
util::SafeStore(&inout_next_slot_ids[id], new_slot);
util::SafeStore(&out_group_ids[id], new_groupid);
push_id(new_match_found, id);
}
}
Expand Down Expand Up @@ -410,7 +413,8 @@ Status SwissTable::map(const int num_keys, const uint32_t* hashes,
//
for (uint32_t i = 0; i < num_ids; ++i) {
// First slot in the new starting block
slot_ids[ids[i]] = (hashes[ids[i]] >> (bits_hash_ - log_blocks_)) * 8;
const int16_t id = util::SafeLoad(&ids[i]);
util::SafeStore(&slot_ids[id], (hashes[id] >> (bits_hash_ - log_blocks_)) * 8);
}
}
} while (num_ids > 0);
Expand Down Expand Up @@ -457,9 +461,8 @@ Status SwissTable::grow_double() {
static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);
int full_slots_new[2];
full_slots_new[0] = full_slots_new[1] = 0;
*reinterpret_cast<uint64_t*>(double_block_base_new) = kHighBitOfEachByte;
*reinterpret_cast<uint64_t*>(double_block_base_new + block_size_after) =
kHighBitOfEachByte;
util::SafeStore(double_block_base_new, kHighBitOfEachByte);
util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte);

for (int j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8 + j;
Expand All @@ -474,18 +477,20 @@ Status SwissTable::grow_double() {
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id = (*reinterpret_cast<const uint64_t*>(block_base + 8 +
(group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;

uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf];
hashes_new[slot_id_new] = hash;
uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
block_base_new[7 - full_slots_new[ihalf]] = stamp_new;
int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after;
*reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |=
(group_id << (group_id_bit_offs_new & 7));
uint64_t* ptr =
reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
util::SafeStore(ptr,
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
full_slots_new[ihalf]++;
}
}
Expand All @@ -495,7 +500,7 @@ Status SwissTable::grow_double() {
for (int i = 0; i < (1 << log_blocks_); ++i) {
// How many full slots in this block
uint8_t* block_base = blocks_ + i * block_size_before;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
uint64_t block = util::SafeLoadAs<uint64_t>(block_base);
int full_slots = static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);

for (int j = 0; j < full_slots; ++j) {
Expand All @@ -508,30 +513,32 @@ Status SwissTable::grow_double() {
}

uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id = (*reinterpret_cast<const uint64_t*>(block_base + 8 +
(group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
uint8_t stamp_new =
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;

uint8_t* block_base_new = blocks_new + block_id_new * block_size_after;
uint64_t block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
uint64_t block_new = util::SafeLoadAs<uint64_t>(block_base_new);
int full_slots_new =
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
while (full_slots_new == 8) {
block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1);
block_base_new = blocks_new + block_id_new * block_size_after;
block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
block_new = util::SafeLoadAs<uint64_t>(block_base_new);
full_slots_new =
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
}

hashes_new[block_id_new * 8 + full_slots_new] = hash;
block_base_new[7 - full_slots_new] = stamp_new;
int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after;
*reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3)) |=
(group_id << (group_id_bit_offs_new & 7));
uint64_t* ptr =
reinterpret_cast<uint64_t*>(block_base_new + 8 + (group_id_bit_offs_new >> 3));
util::SafeStore(ptr,
util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7)));
}
}

Expand Down Expand Up @@ -567,7 +574,7 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool,

// Initialize all status bytes to represent an empty slot.
for (uint64_t i = 0; i < (static_cast<uint64_t>(1) << log_blocks_); ++i) {
*reinterpret_cast<uint64_t*>(blocks_ + i * block_bytes) = kHighBitOfEachByte;
util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte);
}

uint64_t num_slots = 1ULL << (log_blocks_ + 3);
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/arrow/compute/exec/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/ubsan.h"

namespace arrow {

Expand Down Expand Up @@ -66,7 +67,7 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit
#endif
*num_indexes = 0;
for (int i = 0; i < num_bits / unroll; ++i) {
uint64_t word = reinterpret_cast<const uint64_t*>(bits)[i];
uint64_t word = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[i]);
if (bit_to_search == 0) {
word = ~word;
}
Expand All @@ -81,7 +82,8 @@ void BitUtil::bits_to_indexes_internal(int64_t hardware_flags, const int num_bit
#endif
// Optionally process the last partial word with masking out bits outside range
if (tail) {
uint64_t word = reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll];
uint64_t word =
util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
if (bit_to_search == 0) {
word = ~word;
}
Expand Down Expand Up @@ -144,7 +146,7 @@ void BitUtil::bits_to_bytes_internal(const int num_bits, const uint8_t* bits,
unpacked |= (bits_next & 1);
unpacked &= 0x0101010101010101ULL;
unpacked *= 255;
reinterpret_cast<uint64_t*>(bytes)[i] = unpacked;
util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
}
}

Expand All @@ -153,7 +155,7 @@ void BitUtil::bytes_to_bits_internal(const int num_bits, const uint8_t* bytes,
constexpr int unroll = 8;
// Process 8 bits at a time
for (int i = 0; i < (num_bits + unroll - 1) / unroll; ++i) {
uint64_t bytes_next = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
Expand Down Expand Up @@ -184,7 +186,7 @@ void BitUtil::bits_to_bytes(int64_t hardware_flags, const int num_bits,
unpacked |= (bits_next & 1);
unpacked &= 0x0101010101010101ULL;
unpacked *= 255;
reinterpret_cast<uint64_t*>(bytes)[i] = unpacked;
util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
}
}

Expand All @@ -201,7 +203,7 @@ void BitUtil::bytes_to_bits(int64_t hardware_flags, const int num_bits,
// Process 8 bits at a time
constexpr int unroll = 8;
for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
uint64_t bytes_next = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
Expand All @@ -220,7 +222,7 @@ bool BitUtil::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
uint64_t result_or = 0;
uint32_t i;
for (i = 0; i < num_bytes / 8; ++i) {
uint64_t x = reinterpret_cast<const uint64_t*>(bytes)[i];
uint64_t x = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
result_or |= x;
}
if (num_bytes % 8 > 0) {
Expand Down