diff --git a/zstd/benches/support/mod.rs b/zstd/benches/support/mod.rs index 08875d11..6305b6ab 100644 --- a/zstd/benches/support/mod.rs +++ b/zstd/benches/support/mod.rs @@ -1,6 +1,6 @@ // rand 0.10: SmallRng is available with default features (no `small_rng` flag needed). -// Use Rng::fill() instead of RngCore::fill_bytes(); RngCore removed from rand's public root in 0.10. -use rand::{Rng, SeedableRng, rngs::SmallRng}; +// Use RngExt::fill() instead of RngCore::fill_bytes(); RngCore removed from rand's public root in 0.10. +use rand::{RngExt, SeedableRng, rngs::SmallRng}; use std::{collections::HashSet, env, fs, path::Path}; use structured_zstd::encoding::CompressionLevel; diff --git a/zstd/src/bit_io/bit_reader_reverse.rs b/zstd/src/bit_io/bit_reader_reverse.rs index b6a1de5c..99f5a1df 100644 --- a/zstd/src/bit_io/bit_reader_reverse.rs +++ b/zstd/src/bit_io/bit_reader_reverse.rs @@ -99,6 +99,36 @@ impl<'s> BitReaderReversed<'s> { value } + /// Ensure at least `n` bits are available for subsequent unchecked reads. + /// After calling this, it is safe to call [`get_bits_unchecked`](Self::get_bits_unchecked) + /// for a combined total of up to `n` bits without individual refill checks. + /// + /// `n` must be at most 56. + #[inline(always)] + pub fn ensure_bits(&mut self, n: u8) { + debug_assert!(n <= 56); + if self.bits_consumed + n > 64 { + self.refill(); + } + } + + /// Read `n` bits from the source **without** checking whether a refill is + /// needed. The caller **must** guarantee enough bits are available (e.g. via + /// a prior [`ensure_bits`](Self::ensure_bits) call). + #[inline(always)] + pub fn get_bits_unchecked(&mut self, n: u8) -> u64 { + debug_assert!(n <= 56); + debug_assert!( + self.bits_consumed + n <= 64, + "get_bits_unchecked: not enough bits (consumed={}, requested={})", + self.bits_consumed, + n + ); + let value = self.peek_bits(n); + self.consume(n); + value + } + /// Get the next `n` bits from the source without consuming them. /// Caller is responsible for making sure that `n` many bits have been refilled. #[inline(always)] @@ -181,4 +211,59 @@ mod test { assert_eq!(br.get_bits(4), 0b0000); assert_eq!(br.bits_remaining(), -7); } + + /// Verify that `ensure_bits(n)` + `get_bits_unchecked(..)` returns the same + /// values as plain `get_bits(..)`, including across refill boundaries and + /// for edge cases like n=0. + #[test] + fn ensure_and_unchecked_match_get_bits() { + // 10 bytes = 80 bits — enough to force multiple refills + let data: [u8; 10] = [0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01]; + + // Reference: read with get_bits + let mut ref_br = super::BitReaderReversed::new(&data); + let r1 = ref_br.get_bits(0); + let r2 = ref_br.get_bits(7); + let r3 = ref_br.get_bits(13); + let r4 = ref_br.get_bits(9); + let r5 = ref_br.get_bits(8); + let r5b = ref_br.get_bits(2); + // After 39 bits consumed, ensure_bits(26) triggers a real refill + // because 39 + 26 = 65 > 64. + let r6 = ref_br.get_bits(9); + let r7 = ref_br.get_bits(9); + let r8 = ref_br.get_bits(8); + + // Unchecked path: same reads via ensure_bits + get_bits_unchecked + let mut fast_br = super::BitReaderReversed::new(&data); + + // n=0 edge case + fast_br.ensure_bits(0); + assert_eq!(fast_br.get_bits_unchecked(0), r1); + + // Single reads + fast_br.ensure_bits(7); + assert_eq!(fast_br.get_bits_unchecked(7), r2); + + fast_br.ensure_bits(13); + assert_eq!(fast_br.get_bits_unchecked(13), r3); + + fast_br.ensure_bits(9); + assert_eq!(fast_br.get_bits_unchecked(9), r4); + + fast_br.ensure_bits(8); + assert_eq!(fast_br.get_bits_unchecked(8), r5); + + fast_br.ensure_bits(2); + assert_eq!(fast_br.get_bits_unchecked(2), r5b); + + // Batched: one ensure covering 9+9+8 = 26 bits. + // At 39 bits consumed, this forces a real refill (39+26=65 > 64). + fast_br.ensure_bits(26); + assert_eq!(fast_br.get_bits_unchecked(9), r6); + assert_eq!(fast_br.get_bits_unchecked(9), r7); + assert_eq!(fast_br.get_bits_unchecked(8), r8); + + assert_eq!(ref_br.bits_remaining(), fast_br.bits_remaining()); + } } diff --git a/zstd/src/decoding/sequence_section_decoder.rs b/zstd/src/decoding/sequence_section_decoder.rs index 2753fe0c..74a3baf7 100644 --- a/zstd/src/decoding/sequence_section_decoder.rs +++ b/zstd/src/decoding/sequence_section_decoder.rs @@ -69,6 +69,25 @@ fn decode_sequences_with_rle( target.clear(); target.reserve(section.num_sequences as usize); + // Only non-RLE decoders need state updates; compute their combined worst-case. + let max_update_bits = if scratch.ll_rle.is_none() { + scratch.literal_lengths.accuracy_log + } else { + 0 + } + if scratch.ml_rle.is_none() { + scratch.match_lengths.accuracy_log + } else { + 0 + } + if scratch.of_rle.is_none() { + scratch.offsets.accuracy_log + } else { + 0 + }; + debug_assert!( + max_update_bits <= 56, + "sequence section update bits exceed 56-bit budget" + ); + for _seq_idx in 0..section.num_sequences { //get the codes from either the RLE byte or from the decoder let ll_code = if let Some(ll_rle) = scratch.ll_rle { @@ -90,17 +109,6 @@ fn decode_sequences_with_rle( let (ll_value, ll_num_bits) = lookup_ll_code(ll_code); let (ml_value, ml_num_bits) = lookup_ml_code(ml_code); - //println!("Sequence: {}", i); - //println!("of stat: {}", of_dec.state); - //println!("of Code: {}", of_code); - //println!("ll stat: {}", ll_dec.state); - //println!("ll bits: {}", ll_num_bits); - //println!("ll Code: {}", ll_value); - //println!("ml stat: {}", ml_dec.state); - //println!("ml bits: {}", ml_num_bits); - //println!("ml Code: {}", ml_value); - //println!(""); - if of_code > MAX_OFFSET_CODE { return Err(DecodeSequenceError::UnsupportedOffset { offset_code: of_code, @@ -121,19 +129,18 @@ fn decode_sequences_with_rle( }); if target.len() < section.num_sequences as usize { - //println!( - // "Bits left: {} ({} bytes)", - // br.bits_remaining(), - // br.bits_remaining() / 8, - //); + // One refill check for all non-RLE state updates (batched fast path). + if max_update_bits > 0 { + br.ensure_bits(max_update_bits); + } if scratch.ll_rle.is_none() { - ll_dec.update_state(br); + ll_dec.update_state_fast(br); } if scratch.ml_rle.is_none() { - ml_dec.update_state(br); + ml_dec.update_state_fast(br); } if scratch.of_rle.is_none() { - of_dec.update_state(br); + of_dec.update_state_fast(br); } } @@ -168,6 +175,19 @@ fn decode_sequences_without_rle( target.clear(); target.reserve(section.num_sequences as usize); + // Maximum bits consumed by the three state updates combined. + // LL and ML accuracy logs are at most 9, OF at most 8, so the ceiling is 26. + // A single ensure_bits call (which guarantees ≥56 bits after refill) replaces + // three individual per-update refill checks, eliminating two branches per + // iteration on the hot decode path. + let max_update_bits = scratch.literal_lengths.accuracy_log + + scratch.match_lengths.accuracy_log + + scratch.offsets.accuracy_log; + debug_assert!( + max_update_bits <= 56, + "sequence section update bits exceed 56-bit budget" + ); + for _seq_idx in 0..section.num_sequences { let ll_code = ll_dec.decode_symbol(); let ml_code = ml_dec.decode_symbol(); @@ -196,14 +216,11 @@ fn decode_sequences_without_rle( }); if target.len() < section.num_sequences as usize { - //println!( - // "Bits left: {} ({} bytes)", - // br.bits_remaining(), - // br.bits_remaining() / 8, - //); - ll_dec.update_state(br); - ml_dec.update_state(br); - of_dec.update_state(br); + // One refill check for all three state updates (batched fast path). + br.ensure_bits(max_update_bits); + ll_dec.update_state_fast(br); + ml_dec.update_state_fast(br); + of_dec.update_state_fast(br); } if br.bits_remaining() < 0 { diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 8d05e142..d39810d2 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -49,6 +49,22 @@ impl<'t> FSEDecoder<'t> { //println!("Update: {}, {} -> {}", base_line, add, self.state); } + + /// Advance the internal state **without** an individual refill check. + /// + /// The caller **must** guarantee that enough bits are available in the bit + /// reader (e.g. via [`BitReaderReversed::ensure_bits`] with a budget that + /// covers this and any other unchecked reads in the same batch). + /// + /// This is the "fast path" used in the interleaved sequence decode loop + /// where a single refill check covers all three FSE state updates. + #[inline(always)] + pub fn update_state_fast(&mut self, bits: &mut BitReaderReversed<'_>) { + let num_bits = self.state.num_bits; + let add = bits.get_bits_unchecked(num_bits); + let new_state = self.state.base_line + add as u32; + self.state = self.table.decode[new_state as usize]; + } } /// FSE decoding involves a decoding table that describes the probabilities of