Skip to content
146 changes: 134 additions & 12 deletions zstd/src/encoding/match_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
use alloc::collections::VecDeque;
use alloc::vec::Vec;
#[cfg(all(target_arch = "aarch64", target_endian = "little"))]
use core::arch::aarch64::{uint8x16_t, vceqq_u8, vgetq_lane_u64, vld1q_u8, vreinterpretq_u64_u8};
use core::arch::aarch64::{
__crc32d, uint8x16_t, vceqq_u8, vgetq_lane_u64, vld1q_u8, vreinterpretq_u64_u8,
};
#[cfg(target_arch = "x86")]
use core::arch::x86::{
__m128i, __m256i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm256_cmpeq_epi8,
_mm256_loadu_si256, _mm256_movemask_epi8,
};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m128i, __m256i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm256_cmpeq_epi8,
_mm256_loadu_si256, _mm256_movemask_epi8,
__m128i, __m256i, _mm_cmpeq_epi8, _mm_crc32_u64, _mm_loadu_si128, _mm_movemask_epi8,
_mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8,
};
use core::convert::TryInto;
use core::num::NonZeroUsize;
Expand Down Expand Up @@ -59,6 +61,7 @@ const ROW_TARGET_LEN: usize = 48;
const ROW_TAG_BITS: usize = 8;
const ROW_EMPTY_SLOT: usize = usize::MAX;
const ROW_HASH_KEY_LEN: usize = 4;
// 64-bit golden-ratio multiplier shared by every hash-mix kernel; the
// multiplicative mix concentrates entropy in the high bits, which the hash
// tables extract via `>> (64 - bits)`.
const HASH_MIX_PRIME: u64 = 0x9E37_79B1_85EB_CA87;

Comment thread
polaz marked this conversation as resolved.
const HC_HASH_LOG: usize = 20;
const HC_CHAIN_LOG: usize = 19;
Expand All @@ -73,6 +76,79 @@ const HC_EMPTY: u32 = 0;
// fixed-length candidate array returned by chain_candidates().
const MAX_HC_SEARCH_DEPTH: usize = 32;

/// Dispatch tag selecting which hash-mix implementation to run.
///
/// Chosen once by `detect_hash_mix_kernel()` and stored on the match
/// generators, so the per-hash hot path pays a predictable branch instead of
/// a repeated CPU-feature probe.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
enum HashMixKernel {
    /// Portable multiplicative mix; available on every target.
    Scalar = 0,
    /// x86-64 SSE4.2 CRC32-based mix (`_mm_crc32_u64`).
    #[cfg(target_arch = "x86_64")]
    X86Sse42 = 1,
    /// Little-endian AArch64 CRC-extension mix (`__crc32d`).
    #[cfg(all(target_arch = "aarch64", target_endian = "little"))]
    Aarch64Crc = 2,
}

/// Mix `value` with the kernel previously selected by
/// `detect_hash_mix_kernel()`.
///
/// The accelerated arms only exist on targets where the matching variant of
/// [`HashMixKernel`] is compiled in, so the match stays exhaustive everywhere.
#[inline(always)]
fn hash_mix_u64_with_kernel(value: u64, kernel: HashMixKernel) -> u64 {
    match kernel {
        #[cfg(target_arch = "x86_64")]
        HashMixKernel::X86Sse42 => {
            // SAFETY: this variant is only produced by runtime/static
            // detection, so SSE4.2 is guaranteed to be available here.
            unsafe { hash_mix_u64_sse42(value) }
        }
        #[cfg(all(target_arch = "aarch64", target_endian = "little"))]
        HashMixKernel::Aarch64Crc => {
            // SAFETY: this variant is only produced by runtime/static
            // detection, so the CRC extension is guaranteed to be available.
            unsafe { hash_mix_u64_crc(value) }
        }
        HashMixKernel::Scalar => value.wrapping_mul(HASH_MIX_PRIME),
    }
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/// Pick the best hash-mix kernel for the current target.
///
/// With the `std` feature, CPU capabilities are probed at runtime; without
/// it, only features baked in at compile time (`target_feature`) are
/// trusted. Falls back to the portable scalar mix in every other case.
#[inline(always)]
fn detect_hash_mix_kernel() -> HashMixKernel {
    // x86-64: prefer the SSE4.2 CRC32 path when available.
    #[cfg(target_arch = "x86_64")]
    {
        #[cfg(feature = "std")]
        if is_x86_feature_detected!("sse4.2") {
            return HashMixKernel::X86Sse42;
        }
        #[cfg(not(feature = "std"))]
        if cfg!(target_feature = "sse4.2") {
            return HashMixKernel::X86Sse42;
        }
    }

    // Little-endian AArch64: prefer the ARMv8 CRC extension path.
    #[cfg(all(target_arch = "aarch64", target_endian = "little"))]
    {
        #[cfg(feature = "std")]
        if is_aarch64_feature_detected!("crc") {
            return HashMixKernel::Aarch64Crc;
        }
        #[cfg(not(feature = "std"))]
        if cfg!(target_feature = "crc") {
            return HashMixKernel::Aarch64Crc;
        }
    }

    HashMixKernel::Scalar
}

/// Hash-mix kernel using the x86-64 SSE4.2 `crc32` instruction.
///
/// The 32-bit CRC digest is shifted into the upper half and XORed with a
/// rotated copy of the input so the high bits — the ones the hash tables
/// extract via `>> (64 - bits)` — depend on the whole input, before the
/// shared multiplicative finish.
///
/// # Safety
///
/// Callers must guarantee the CPU supports SSE4.2 (required by
/// `#[target_feature]`); kernel detection establishes this before
/// `HashMixKernel::X86Sse42` is ever constructed.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn hash_mix_u64_sse42(value: u64) -> u64 {
    // `_mm_crc32_u64` returns the 32-bit CRC zero-extended to u64, so the
    // shift places the digest entirely in the upper 32 bits.
    let crc = _mm_crc32_u64(0, value);
    ((crc << 32) ^ value.rotate_left(13)).wrapping_mul(HASH_MIX_PRIME)
}

/// Hash-mix kernel using the AArch64 CRC32 extension (`__crc32d`).
///
/// # Safety
///
/// Callers must guarantee the CPU supports the `crc` feature (required by
/// `#[target_feature]`); kernel detection establishes this before
/// `HashMixKernel::Aarch64Crc` is ever constructed.
#[cfg(all(target_arch = "aarch64", target_endian = "little"))]
#[target_feature(enable = "crc")]
unsafe fn hash_mix_u64_crc(value: u64) -> u64 {
    // Feed the full 64-bit lane through ARM CRC32 and then mix back with a
    // rotated copy of the source to keep dispersion in the upper bits used by
    // hash table indexing.
    let crc = __crc32d(0, value) as u64;
    ((crc << 32) ^ value.rotate_left(17)).wrapping_mul(HASH_MIX_PRIME)
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum PrefixKernel {
Scalar,
Expand Down Expand Up @@ -1469,6 +1545,7 @@ struct DfastMatchGenerator {
short_hash: Vec<[usize; DFAST_SEARCH_DEPTH]>,
long_hash: Vec<[usize; DFAST_SEARCH_DEPTH]>,
hash_bits: usize,
hash_mix_kernel: HashMixKernel,
use_fast_loop: bool,
// Lazy match lookahead depth (internal tuning parameter).
lazy_depth: u8,
Expand Down Expand Up @@ -1640,6 +1717,7 @@ impl DfastMatchGenerator {
short_hash: Vec::new(),
long_hash: Vec::new(),
hash_bits: DFAST_HASH_BITS,
hash_mix_kernel: detect_hash_mix_kernel(),
use_fast_loop: false,
lazy_depth: 1,
}
Expand Down Expand Up @@ -2191,8 +2269,7 @@ impl DfastMatchGenerator {
}

fn hash_index(&self, value: u64) -> usize {
const PRIME: u64 = 0x9E37_79B1_85EB_CA87;
((value.wrapping_mul(PRIME)) >> (64 - self.hash_bits)) as usize
(hash_mix_u64_with_kernel(value, self.hash_mix_kernel) >> (64 - self.hash_bits)) as usize
}
}

Expand All @@ -2209,6 +2286,7 @@ struct RowMatchGenerator {
search_depth: usize,
target_len: usize,
lazy_depth: u8,
hash_mix_kernel: HashMixKernel,
row_heads: Vec<u8>,
row_positions: Vec<usize>,
row_tags: Vec<u8>,
Expand All @@ -2229,6 +2307,7 @@ impl RowMatchGenerator {
search_depth: ROW_SEARCH_DEPTH,
target_len: ROW_TARGET_LEN,
lazy_depth: 1,
hash_mix_kernel: detect_hash_mix_kernel(),
row_heads: Vec::new(),
row_positions: Vec::new(),
row_tags: Vec::new(),
Expand Down Expand Up @@ -2421,8 +2500,7 @@ impl RowMatchGenerator {
}
let value =
u32::from_le_bytes(concat[idx..idx + ROW_HASH_KEY_LEN].try_into().unwrap()) as u64;
const PRIME: u64 = 0x9E37_79B1_85EB_CA87;
let hash = value.wrapping_mul(PRIME);
let hash = hash_mix_u64_with_kernel(value, self.hash_mix_kernel);
let total_bits = self.row_hash_log + ROW_TAG_BITS;
let combined = hash >> (u64::BITS as usize - total_bits);
let row_mask = (1usize << self.row_hash_log) - 1;
Expand Down Expand Up @@ -2794,8 +2872,7 @@ impl HcMatchGenerator {

fn hash_position(&self, data: &[u8]) -> usize {
let value = u32::from_le_bytes(data[..4].try_into().unwrap()) as u64;
const PRIME: u64 = 0x9E37_79B1_85EB_CA87;
((value.wrapping_mul(PRIME)) >> (64 - self.hash_log)) as usize
((value.wrapping_mul(HASH_MIX_PRIME)) >> (64 - self.hash_log)) as usize
}

fn relative_position(&self, abs_pos: usize) -> Option<u32> {
Expand Down Expand Up @@ -4170,7 +4247,7 @@ fn row_pick_lazy_depth2_keeps_best_when_next2_is_only_one_byte_better() {
assert_eq!(chosen.match_len, best.match_len);
}

/// Verifies row/tag extraction uses the high bits of the multiplicative hash.
/// Verifies row/tag extraction uses the shared hash mix bit-splitting contract.
#[test]
fn row_hash_and_row_extracts_high_bits() {
let mut matcher = RowMatchGenerator::new(1 << 22);
Expand All @@ -4192,8 +4269,7 @@ fn row_hash_and_row_extracts_high_bits() {
let idx = pos - matcher.history_abs_start;
let concat = matcher.live_history();
let value = u32::from_le_bytes(concat[idx..idx + ROW_HASH_KEY_LEN].try_into().unwrap()) as u64;
const PRIME: u64 = 0x9E37_79B1_85EB_CA87;
let hash = value.wrapping_mul(PRIME);
let hash = hash_mix_u64_with_kernel(value, matcher.hash_mix_kernel);
let total_bits = matcher.row_hash_log + ROW_TAG_BITS;
let combined = hash >> (u64::BITS as usize - total_bits);
let expected_row =
Expand Down Expand Up @@ -4228,6 +4304,52 @@ fn row_repcode_returns_none_when_position_too_close_to_history_end() {
assert!(matcher.repcode_candidate(4, 1).is_none());
}

/// On x86-64 hosts with SSE4.2, detection must select the accelerated kernel
/// and the dispatcher must produce the same digest as calling that kernel
/// directly.
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[test]
fn hash_mix_sse42_path_is_available_and_matches_accelerated_impl_when_supported() {
    if !is_x86_feature_detected!("sse4.2") {
        // Host CPU lacks SSE4.2; the accelerated path cannot be exercised.
        return;
    }

    let sample = 0x0123_4567_89AB_CDEFu64;
    let kernel = detect_hash_mix_kernel();
    assert_eq!(kernel, HashMixKernel::X86Sse42);
    // SAFETY: SSE4.2 support was confirmed by the runtime check above.
    let direct = unsafe { hash_mix_u64_sse42(sample) };
    assert_eq!(hash_mix_u64_with_kernel(sample, kernel), direct);
}

/// The scalar kernel is pure, portable arithmetic, so run this on every
/// target (the previous `x86_64` + `std` gate silently skipped it
/// elsewhere); forcing `HashMixKernel::Scalar` must reproduce the plain
/// multiplicative formula exactly.
#[test]
fn hash_mix_scalar_path_can_be_forced_for_coverage_and_matches_formula() {
    let v = 0x0123_4567_89AB_CDEFu64;
    let expected = v.wrapping_mul(HASH_MIX_PRIME);
    let mixed = hash_mix_u64_with_kernel(v, HashMixKernel::Scalar);
    assert_eq!(mixed, expected);
}

/// On little-endian AArch64 hosts with the CRC extension, detection must
/// pick the accelerated kernel and the dispatcher must agree with calling
/// that kernel directly.
#[cfg(all(feature = "std", target_arch = "aarch64", target_endian = "little"))]
#[test]
fn hash_mix_crc_path_is_available_and_matches_accelerated_impl_when_supported() {
    if !is_aarch64_feature_detected!("crc") {
        // No CRC hardware on this host; the accelerated path cannot be exercised.
        return;
    }

    let sample = 0x0123_4567_89AB_CDEFu64;
    let kernel = detect_hash_mix_kernel();
    assert_eq!(kernel, HashMixKernel::Aarch64Crc);
    // SAFETY: CRC support was confirmed by the runtime check above.
    let direct = unsafe { hash_mix_u64_crc(sample) };
    assert_eq!(hash_mix_u64_with_kernel(sample, kernel), direct);
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/// Forcing `HashMixKernel::Scalar` on AArch64 must reproduce the plain
/// multiplicative formula regardless of CRC hardware support.
#[cfg(all(feature = "std", target_arch = "aarch64", target_endian = "little"))]
#[test]
fn hash_mix_scalar_path_can_be_forced_on_aarch64_and_matches_formula() {
    let sample = 0x0123_4567_89AB_CDEFu64;
    assert_eq!(
        hash_mix_u64_with_kernel(sample, HashMixKernel::Scalar),
        sample.wrapping_mul(HASH_MIX_PRIME)
    );
}

#[test]
fn row_candidate_returns_none_when_abs_pos_near_end_of_history() {
let mut matcher = RowMatchGenerator::new(1 << 22);
Expand Down
Loading