142 changes: 74 additions & 68 deletions cpp/src/arrow/compute/key_hash.cc
@@ -105,23 +105,23 @@ inline void Hashing32::StripeMask(int i, uint32_t* mask1, uint32_t* mask2,
}

template <bool T_COMBINE_HASHES>
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
uint32_t* hashes) {
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
const uint8_t* keys, uint32_t* hashes) {
// Calculate the number of rows that skip the last 16 bytes
//
uint32_t num_rows_safe = num_rows;
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
--num_rows_safe;
}

// Compute masks for the last 16 byte stripe
//
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);

for (uint32_t i = 0; i < num_rows_safe; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
@@ -138,11 +138,11 @@ void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_

uint32_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
ProcessLastStripe(mask1, mask2, mask3, mask4,
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
&acc3, &acc4);
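The loop above handles the rows excluded from `num_rows_safe`: for them a full 16-byte load of the final stripe could read past the end of `keys`, so the remaining bytes are first copied into the stack buffer `last_stripe_copy` and the stripe mask discards everything past the key. A minimal sketch of the same staging idea, assuming a 16-byte `kStripeSize` as in the comments above and zero-padding where the real code relies on the mask (hypothetical helper, not Arrow's API):

```cpp
#include <cstdint>
#include <cstring>

constexpr uint64_t kStripeSize = 16;  // assumed to match the 16-byte stripes above

// Stage the last (possibly partial) stripe of a key into a local buffer so
// that a full 16-byte read never touches memory past the end of the key.
// Assumes key_length > 0.
inline void StageLastStripe(const uint8_t* key, uint64_t key_length,
                            uint8_t stripe[kStripeSize]) {
  const uint64_t num_stripes = (key_length + kStripeSize - 1) / kStripeSize;
  const uint64_t tail_offset = (num_stripes - 1) * kStripeSize;
  std::memset(stripe, 0, kStripeSize);  // zero-pad instead of masking
  std::memcpy(stripe, key + tail_offset, key_length - tail_offset);
}
```

In the code above the buffer is not zeroed; the `mask1` through `mask4` words produced by `StripeMask` zero out the trailing bytes during `ProcessLastStripe` instead.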
@@ -168,15 +168,16 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets,
}

for (uint32_t i = 0; i < num_rows_safe; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 16 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
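For variable-length keys each length comes from two neighbouring offsets, and the `is_non_empty` trick gives an empty string one fully-masked stripe instead of zero stripes, so `ProcessLastStripe` always has something well-defined to read. The arithmetic is easiest to see with concrete numbers; the checks below assume `kStripeSize == 16` (the "16 byte stripe" of the comments) and use a plain stand-in for `bit_util::CeilDiv`:

```cpp
#include <cassert>
#include <cstdint>

constexpr uint64_t kStripeSize = 16;  // assumption: 16-byte stripes as above

uint64_t CeilDiv(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

// Mirrors: num_stripes = CeilDiv(key_length, kStripeSize) + (1 - is_non_empty)
uint64_t NumStripes(uint64_t key_length) {
  const int is_non_empty = key_length == 0 ? 0 : 1;
  return CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
}

// Mirrors the first argument of StripeMask: how many bytes of the last stripe
// are kept, with everything else masked to zero.
uint64_t BytesKeptInLastStripe(uint64_t key_length) {
  const int is_non_empty = key_length == 0 ? 0 : 1;
  return ((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty;
}

int main() {
  assert(NumStripes(0) == 1 && BytesKeptInLastStripe(0) == 0);     // empty key
  assert(NumStripes(5) == 1 && BytesKeptInLastStripe(5) == 5);
  assert(NumStripes(16) == 1 && BytesKeptInLastStripe(16) == 16);  // exact stripe
  assert(NumStripes(17) == 2 && BytesKeptInLastStripe(17) == 1);
  return 0;
}
```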
@@ -198,23 +199,24 @@

uint32_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 16 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
if (length > 0) {
if (key_length > 0) {
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
}
if (num_stripes > 0) {
ProcessLastStripe(mask1, mask2, mask3, mask4,
@@ -309,9 +311,9 @@ void Hashing32::HashIntImp(uint32_t num_keys, const T* keys, uint32_t* hashes) {
}
}

void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint32_t* hashes) {
switch (length_key) {
switch (key_length) {
case sizeof(uint8_t):
if (combine_hashes) {
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
@@ -352,27 +354,27 @@ void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
}
}

void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows,
uint64_t length, const uint8_t* keys, uint32_t* hashes,
uint32_t* hashes_temp_for_combine) {
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_rows, length, keys, hashes);
void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys,
uint64_t key_length, const uint8_t* keys, uint32_t* hashes,
uint32_t* temp_hashes_for_combine) {
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
return;
}

uint32_t num_processed = 0;
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = HashFixedLen_avx2(combine_hashes, num_rows, length, keys, hashes,
hashes_temp_for_combine);
num_processed = HashFixedLen_avx2(combine_hashes, num_keys, key_length, keys, hashes,
temp_hashes_for_combine);
}
#endif
if (combine_hashes) {
HashFixedLenImp<true>(num_rows - num_processed, length, keys + length * num_processed,
hashes + num_processed);
HashFixedLenImp<true>(num_keys - num_processed, key_length,
keys + key_length * num_processed, hashes + num_processed);
} else {
HashFixedLenImp<false>(num_rows - num_processed, length,
keys + length * num_processed, hashes + num_processed);
HashFixedLenImp<false>(num_keys - num_processed, key_length,
keys + key_length * num_processed, hashes + num_processed);
}
}

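`HashFixed` above first tests `ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)`, which is true exactly for key widths of 1, 2, 4, and 8 bytes; those go to the `HashInt` fast path, while every other width (including 16-byte keys) falls through to the striped hash, optionally letting the AVX2 kernel consume a prefix of the rows first. A standalone check of that width test, with a portable popcount standing in for the `ARROW_POPCOUNT64` macro:

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

// True exactly when a fixed-length key can be reinterpreted as a 1/2/4/8-byte
// integer, mirroring the popcount test in HashFixed above.
bool TakesIntegerFastPath(uint64_t key_length) {
  return std::bitset<64>(key_length).count() == 1 &&
         key_length <= sizeof(uint64_t);
}

int main() {
  const uint64_t fast[] = {1, 2, 4, 8};
  const uint64_t striped[] = {3, 5, 6, 7, 9, 10, 16};
  for (uint64_t w : fast) assert(TakesIntegerFastPath(w));
  for (uint64_t w : striped) assert(!TakesIntegerFastPath(w));
  return 0;
}
```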
@@ -423,13 +425,13 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
}

if (cols[icol].metadata().is_fixed_length) {
uint32_t col_width = cols[icol].metadata().fixed_length;
if (col_width == 0) {
uint32_t key_length = cols[icol].metadata().fixed_length;
if (key_length == 0) {
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
cols[icol].data(1) + first_row / 8, hashes + first_row);
} else {
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, col_width,
cols[icol].data(1) + first_row * col_width, hashes + first_row,
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, key_length,
cols[icol].data(1) + first_row * key_length, hashes + first_row,
hash_temp);
}
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
@@ -463,8 +465,9 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
std::vector<KeyColumnArray>& column_arrays,
int64_t hardware_flags, util::TempVectorStack* temp_stack,
int64_t offset, int64_t length) {
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
int64_t start_rows, int64_t num_rows) {
RETURN_NOT_OK(
ColumnArraysFromExecBatch(key_batch, start_rows, num_rows, &column_arrays));

LightContext ctx;
ctx.hardware_flags = hardware_flags;
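In `HashMultiColumn` above, every per-column call receives `icol > 0` as its combine flag: the first key column writes the hash array directly, and each later column is folded into the values already there, with `hash_temp` passed along as scratch space. A miniature version of that control flow, using deliberately simple stand-ins for both the per-value hash and the combining step (neither is Arrow's actual formula):

```cpp
#include <cstdint>
#include <vector>

// Stand-in per-value hash; the real code dispatches to HashFixed, HashVarLen,
// or HashBit depending on the column type.
uint32_t ToyHash(uint32_t value) { return value * 2654435761u; }

// First column initializes the hashes, later columns are combined into them,
// mirroring the icol > 0 flag passed by HashMultiColumn above.
void ToyHashMultiColumn(const std::vector<std::vector<uint32_t>>& cols,
                        std::vector<uint32_t>* hashes) {
  hashes->assign(cols.empty() ? 0 : cols[0].size(), 0);
  for (size_t icol = 0; icol < cols.size(); ++icol) {
    const bool combine = icol > 0;
    for (size_t row = 0; row < cols[icol].size(); ++row) {
      const uint32_t h = ToyHash(cols[icol][row]);
      (*hashes)[row] = combine ? ((*hashes)[row] * 31u) ^ h : h;
    }
  }
}

int main() {
  std::vector<uint32_t> out;
  ToyHashMultiColumn({{1, 2, 3}, {10, 20, 30}}, &out);
  return out.size() == 3 ? 0 : 1;
}
```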
@@ -574,23 +577,23 @@ inline void Hashing64::StripeMask(int i, uint64_t* mask1, uint64_t* mask2,
}

template <bool T_COMBINE_HASHES>
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
uint64_t* hashes) {
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
// Calculate the number of rows that skip the last 32 bytes
//
uint32_t num_rows_safe = num_rows;
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
--num_rows_safe;
}

// Compute masks for the last 32 byte stripe
//
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);

for (uint32_t i = 0; i < num_rows_safe; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
@@ -607,11 +610,11 @@ void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_

uint64_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
ProcessLastStripe(mask1, mask2, mask3, mask4,
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
&acc3, &acc4);
@@ -637,15 +640,16 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets,
}

for (uint32_t i = 0; i < num_rows_safe; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 32 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
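Both `HashVarLenImp` loops read their keys out of one concatenated buffer: key `i` occupies `concatenated_keys[offsets[i] .. offsets[i + 1])`, so `key_length` is just the difference of two neighbouring offsets, and an empty key (length 0) is exactly what the `is_non_empty` guard above protects against. A small self-contained illustration of that layout, using toy data rather than Arrow buffers:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

int main() {
  // Concatenated representation of the keys "hash", "" and "arrow".
  const std::vector<std::string> keys = {"hash", "", "arrow"};
  std::string concatenated_keys;
  std::vector<uint32_t> offsets = {0};
  for (const std::string& k : keys) {
    concatenated_keys += k;
    offsets.push_back(static_cast<uint32_t>(concatenated_keys.size()));
  }
  // offsets == {0, 4, 4, 9}; key i spans [offsets[i], offsets[i + 1]).
  for (size_t i = 0; i < keys.size(); ++i) {
    const uint64_t key_length = offsets[i + 1] - offsets[i];
    assert(key_length == keys[i].size());
    assert(std::memcmp(concatenated_keys.data() + offsets[i], keys[i].data(),
                       key_length) == 0);
  }
  return 0;
}
```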
@@ -667,22 +671,23 @@

uint64_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 32 byte stripe
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
if (length > 0) {
if (key_length > 0) {
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
}
if (num_stripes > 0) {
ProcessLastStripe(mask1, mask2, mask3, mask4,
@@ -759,9 +764,9 @@ void Hashing64::HashIntImp(uint32_t num_keys, const T* keys, uint64_t* hashes) {
}
}

void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
switch (length_key) {
switch (key_length) {
case sizeof(uint8_t):
if (combine_hashes) {
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
@@ -802,17 +807,17 @@ void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
}
}

void Hashing64::HashFixed(bool combine_hashes, uint32_t num_rows, uint64_t length,
void Hashing64::HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_rows, length, keys, hashes);
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
return;
}

if (combine_hashes) {
HashFixedLenImp<true>(num_rows, length, keys, hashes);
HashFixedLenImp<true>(num_keys, key_length, keys, hashes);
} else {
HashFixedLenImp<false>(num_rows, length, keys, hashes);
HashFixedLenImp<false>(num_keys, key_length, keys, hashes);
}
}

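`Hashing64::HashFixed` above has no AVX2 prefix pass or temporary hash buffer (unlike its 32-bit counterpart); it goes straight to either `HashInt` or the scalar striped loop. The `HashInt` switch, partially visible earlier in the diff, then picks a typed loop per key width. A compact sketch of that dispatch, with a placeholder mixing step so the example stays self-contained:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

template <typename T>
void HashIntImpSketch(uint32_t num_keys, const uint8_t* keys, uint64_t* hashes) {
  for (uint32_t i = 0; i < num_keys; ++i) {
    T value;
    std::memcpy(&value, keys + i * sizeof(T), sizeof(T));
    hashes[i] = static_cast<uint64_t>(value);  // stand-in for the real mixing
  }
}

void HashIntSketch(uint32_t num_keys, uint64_t key_length, const uint8_t* keys,
                   uint64_t* hashes) {
  switch (key_length) {
    case sizeof(uint8_t):  HashIntImpSketch<uint8_t>(num_keys, keys, hashes); break;
    case sizeof(uint16_t): HashIntImpSketch<uint16_t>(num_keys, keys, hashes); break;
    case sizeof(uint32_t): HashIntImpSketch<uint32_t>(num_keys, keys, hashes); break;
    case sizeof(uint64_t): HashIntImpSketch<uint64_t>(num_keys, keys, hashes); break;
    default: break;  // non-power-of-two widths take the striped path instead
  }
}

int main() {
  const uint16_t keys[3] = {7, 8, 9};
  uint64_t hashes[3];
  HashIntSketch(3, sizeof(uint16_t), reinterpret_cast<const uint8_t*>(keys), hashes);
  assert(hashes[0] == 7 && hashes[2] == 9);
  return 0;
}
```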
@@ -860,13 +865,13 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
}

if (cols[icol].metadata().is_fixed_length) {
uint64_t col_width = cols[icol].metadata().fixed_length;
if (col_width == 0) {
uint64_t key_length = cols[icol].metadata().fixed_length;
if (key_length == 0) {
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
cols[icol].data(1) + first_row / 8, hashes + first_row);
} else {
HashFixed(icol > 0, batch_size_next, col_width,
cols[icol].data(1) + first_row * col_width, hashes + first_row);
HashFixed(icol > 0, batch_size_next, key_length,
cols[icol].data(1) + first_row * key_length, hashes + first_row);
}
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
HashVarLen(icol > 0, batch_size_next, cols[icol].offsets() + first_row,
@@ -897,8 +902,9 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
std::vector<KeyColumnArray>& column_arrays,
int64_t hardware_flags, util::TempVectorStack* temp_stack,
int64_t offset, int64_t length) {
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
int64_t start_row, int64_t num_rows) {
RETURN_NOT_OK(
ColumnArraysFromExecBatch(key_batch, start_row, num_rows, &column_arrays));

LightContext ctx;
ctx.hardware_flags = hardware_flags;