Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zstd/benches/support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// rand 0.10: SmallRng is available with default features (no `small_rng` flag needed).
// Use Rng::fill() instead of RngCore::fill_bytes(); RngCore removed from rand's public root in 0.10.
use rand::{Rng, SeedableRng, rngs::SmallRng};
// Use RngExt::fill() instead of RngCore::fill_bytes(); RngCore removed from rand's public root in 0.10.
use rand::{RngExt, SeedableRng, rngs::SmallRng};
use std::{collections::HashSet, env, fs, path::Path};
use structured_zstd::encoding::CompressionLevel;

Expand Down
85 changes: 85 additions & 0 deletions zstd/src/bit_io/bit_reader_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,36 @@ impl<'s> BitReaderReversed<'s> {
value
}

/// Ensure at least `n` bits are available for subsequent unchecked reads.
/// After calling this, it is safe to call [`get_bits_unchecked`](Self::get_bits_unchecked)
/// for a combined total of up to `n` bits without individual refill checks.
///
/// `n` must be at most 56.
Comment thread
polaz marked this conversation as resolved.
#[inline(always)]
pub fn ensure_bits(&mut self, n: u8) {
debug_assert!(n <= 56);
if self.bits_consumed + n > 64 {
self.refill();
}
}

/// Read `n` bits from the source **without** checking whether a refill is
/// needed. The caller **must** guarantee enough bits are available (e.g. via
/// a prior [`ensure_bits`](Self::ensure_bits) call).
#[inline(always)]
pub fn get_bits_unchecked(&mut self, n: u8) -> u64 {
debug_assert!(n <= 56);
debug_assert!(
self.bits_consumed + n <= 64,
"get_bits_unchecked: not enough bits (consumed={}, requested={})",
self.bits_consumed,
n
);
let value = self.peek_bits(n);
self.consume(n);
value
Comment thread
polaz marked this conversation as resolved.
}
Comment thread
polaz marked this conversation as resolved.

/// Get the next `n` bits from the source without consuming them.
/// Caller is responsible for making sure that `n` many bits have been refilled.
#[inline(always)]
Expand Down Expand Up @@ -181,4 +211,59 @@ mod test {
assert_eq!(br.get_bits(4), 0b0000);
assert_eq!(br.bits_remaining(), -7);
}

/// Verify that `ensure_bits(n)` + `get_bits_unchecked(..)` returns the same
/// values as plain `get_bits(..)`, including across refill boundaries and
/// for edge cases like n=0.
#[test]
fn ensure_and_unchecked_match_get_bits() {
    // 10 bytes = 80 bits — enough to force multiple refills.
    let data: [u8; 10] = [0xDE, 0xAD, 0xBE, 0xEF, 0x42, 0x13, 0x37, 0xCA, 0xFE, 0x01];

    // Widths exercised one at a time: the n=0 edge case plus assorted
    // single reads totalling 39 bits.
    const SINGLE_READS: [u8; 6] = [0, 7, 13, 9, 8, 2];
    // Widths covered by ONE ensure_bits call (9 + 9 + 8 = 26 bits). With 39
    // bits already consumed, 39 + 26 = 65 > 64 forces a real refill.
    const BATCHED_READS: [u8; 3] = [9, 9, 8];

    // `checked` is the reference reader (plain get_bits); `fast` uses the
    // ensure_bits + get_bits_unchecked path. Both must agree everywhere.
    let mut checked = super::BitReaderReversed::new(&data);
    let mut fast = super::BitReaderReversed::new(&data);

    // Single reads: ensure exactly the width of each upcoming read.
    for &width in &SINGLE_READS {
        let expected = checked.get_bits(width);
        fast.ensure_bits(width);
        assert_eq!(fast.get_bits_unchecked(width), expected);
    }

    // Batched reads: a single ensure_bits covers all three.
    fast.ensure_bits(BATCHED_READS.iter().sum());
    for &width in &BATCHED_READS {
        assert_eq!(fast.get_bits_unchecked(width), checked.get_bits(width));
    }

    assert_eq!(checked.bits_remaining(), fast.bits_remaining());
}
}
71 changes: 44 additions & 27 deletions zstd/src/decoding/sequence_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,25 @@ fn decode_sequences_with_rle(
target.clear();
target.reserve(section.num_sequences as usize);

// Only non-RLE decoders need state updates; compute their combined worst-case.
let max_update_bits = if scratch.ll_rle.is_none() {
scratch.literal_lengths.accuracy_log
} else {
0
} + if scratch.ml_rle.is_none() {
scratch.match_lengths.accuracy_log
} else {
0
} + if scratch.of_rle.is_none() {
scratch.offsets.accuracy_log
} else {
0
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.
debug_assert!(
max_update_bits <= 56,
"sequence section update bits exceed 56-bit budget"
);

for _seq_idx in 0..section.num_sequences {
//get the codes from either the RLE byte or from the decoder
let ll_code = if let Some(ll_rle) = scratch.ll_rle {
Expand All @@ -90,17 +109,6 @@ fn decode_sequences_with_rle(
let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);

//println!("Sequence: {}", i);
//println!("of stat: {}", of_dec.state);
//println!("of Code: {}", of_code);
//println!("ll stat: {}", ll_dec.state);
//println!("ll bits: {}", ll_num_bits);
//println!("ll Code: {}", ll_value);
//println!("ml stat: {}", ml_dec.state);
//println!("ml bits: {}", ml_num_bits);
//println!("ml Code: {}", ml_value);
//println!("");

if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
Expand All @@ -121,19 +129,18 @@ fn decode_sequences_with_rle(
});

if target.len() < section.num_sequences as usize {
//println!(
// "Bits left: {} ({} bytes)",
// br.bits_remaining(),
// br.bits_remaining() / 8,
//);
// One refill check for all non-RLE state updates (batched fast path).
if max_update_bits > 0 {
br.ensure_bits(max_update_bits);
}
if scratch.ll_rle.is_none() {
ll_dec.update_state(br);
ll_dec.update_state_fast(br);
}
if scratch.ml_rle.is_none() {
ml_dec.update_state(br);
ml_dec.update_state_fast(br);
}
if scratch.of_rle.is_none() {
of_dec.update_state(br);
of_dec.update_state_fast(br);
}
}

Expand Down Expand Up @@ -168,6 +175,19 @@ fn decode_sequences_without_rle(
target.clear();
target.reserve(section.num_sequences as usize);

// Maximum bits consumed by the three state updates combined.
// LL and ML accuracy logs are at most 9, OF at most 8, so the ceiling is 26.
// A single ensure_bits call (which guarantees ≥56 bits after refill) replaces
// three individual per-update refill checks, eliminating two branches per
// iteration on the hot decode path.
let max_update_bits = scratch.literal_lengths.accuracy_log
+ scratch.match_lengths.accuracy_log
+ scratch.offsets.accuracy_log;
debug_assert!(
max_update_bits <= 56,
"sequence section update bits exceed 56-bit budget"
);

for _seq_idx in 0..section.num_sequences {
let ll_code = ll_dec.decode_symbol();
let ml_code = ml_dec.decode_symbol();
Expand Down Expand Up @@ -196,14 +216,11 @@ fn decode_sequences_without_rle(
});

if target.len() < section.num_sequences as usize {
//println!(
// "Bits left: {} ({} bytes)",
// br.bits_remaining(),
// br.bits_remaining() / 8,
//);
ll_dec.update_state(br);
ml_dec.update_state(br);
of_dec.update_state(br);
// One refill check for all three state updates (batched fast path).
br.ensure_bits(max_update_bits);
ll_dec.update_state_fast(br);
ml_dec.update_state_fast(br);
of_dec.update_state_fast(br);
}

if br.bits_remaining() < 0 {
Expand Down
16 changes: 16 additions & 0 deletions zstd/src/fse/fse_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ impl<'t> FSEDecoder<'t> {

//println!("Update: {}, {} -> {}", base_line, add, self.state);
}

/// Advance the internal state **without** an individual refill check.
///
/// The caller **must** guarantee that enough bits are buffered in the bit
/// reader (e.g. via [`BitReaderReversed::ensure_bits`] with a budget that
/// covers this and every other unchecked read in the same batch).
///
/// This is the "fast path" used in the interleaved sequence decode loop,
/// where one refill check covers all three FSE state updates.
#[inline(always)]
pub fn update_state_fast(&mut self, bits: &mut BitReaderReversed<'_>) {
    // Pull the transition bits and index straight into the decode table;
    // by contract, no refill check happens here.
    let add = bits.get_bits_unchecked(self.state.num_bits) as u32;
    let next_state = self.state.base_line + add;
    self.state = self.table.decode[next_state as usize];
}
}

/// FSE decoding involves a decoding table that describes the probabilities of
Expand Down
Loading