From 45db0d7035b6393e93d6c73a321b099323d99d0b Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Mon, 6 Apr 2026 18:54:05 +0300 Subject: [PATCH 1/8] perf(fse): pack decoder entry to 4-byte layout - replace Entry.base_line(u32) with Entry.new_state(u16) - keep decode transition semantics (new_state + low bits) - update FSE/sequence tests and add size assertion for packed entry --- zstd/src/decoding/sequence_section_decoder.rs | 12 ++++----- zstd/src/fse/fse_decoder.rs | 27 ++++++++++--------- zstd/src/fse/mod.rs | 7 ++++- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/zstd/src/decoding/sequence_section_decoder.rs b/zstd/src/decoding/sequence_section_decoder.rs index 74a3baf7..f203e966 100644 --- a/zstd/src/decoding/sequence_section_decoder.rs +++ b/zstd/src/decoding/sequence_section_decoder.rs @@ -475,7 +475,7 @@ fn test_ll_default() { idx, table.decode[idx].symbol, table.decode[idx].num_bits, - table.decode[idx].base_line + table.decode[idx].new_state ); } @@ -484,21 +484,21 @@ fn test_ll_default() { //just test a few values. TODO test all values assert!(table.decode[0].symbol == 0); assert!(table.decode[0].num_bits == 4); - assert!(table.decode[0].base_line == 0); + assert!(table.decode[0].new_state == 0); assert!(table.decode[19].symbol == 27); assert!(table.decode[19].num_bits == 6); - assert!(table.decode[19].base_line == 0); + assert!(table.decode[19].new_state == 0); assert!(table.decode[39].symbol == 25); assert!(table.decode[39].num_bits == 4); - assert!(table.decode[39].base_line == 16); + assert!(table.decode[39].new_state == 16); assert!(table.decode[60].symbol == 35); assert!(table.decode[60].num_bits == 6); - assert!(table.decode[60].base_line == 0); + assert!(table.decode[60].new_state == 0); assert!(table.decode[59].symbol == 24); assert!(table.decode[59].num_bits == 5); - assert!(table.decode[59].base_line == 32); + assert!(table.decode[59].new_state == 32); } diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index d39810d2..0a31402f 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -14,7 +14,7 @@ impl<'t> FSEDecoder<'t> { pub fn new(table: &'t FSETable) -> FSEDecoder<'t> { FSEDecoder { state: table.decode.first().copied().unwrap_or(Entry { - base_line: 0, + new_state: 0, num_bits: 0, symbol: 0, }), @@ -43,11 +43,10 @@ impl<'t> FSEDecoder<'t> { pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits(num_bits); - let base_line = self.state.base_line; - let new_state = base_line + add as u32; - self.state = self.table.decode[new_state as usize]; + let next_state = self.state.new_state + add as u16; + self.state = self.table.decode[next_state as usize]; - //println!("Update: {}, {} -> {}", base_line, add, self.state); + //println!("Update: {}, {} -> {}", self.state.new_state, add, self.state); } /// Advance the internal state **without** an individual refill check. @@ -62,8 +61,8 @@ impl<'t> FSEDecoder<'t> { pub fn update_state_fast(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits_unchecked(num_bits); - let new_state = self.state.base_line + add as u32; - self.state = self.table.decode[new_state as usize]; + let next_state = self.state.new_state + add as u16; + self.state = self.table.decode[next_state as usize]; } } @@ -183,7 +182,7 @@ impl FSETable { self.decode.resize( table_size, Entry { - base_line: 0, + new_state: 0, num_bits: 0, symbol: 0, }, @@ -197,7 +196,7 @@ impl FSETable { negative_idx -= 1; let entry = &mut self.decode[negative_idx]; entry.symbol = symbol as u8; - entry.base_line = 0; + entry.new_state = 0; entry.num_bits = self.accuracy_log; } } @@ -241,7 +240,8 @@ impl FSETable { assert!(nb <= self.accuracy_log); self.symbol_counter[symbol as usize] += 1; - entry.base_line = bl; + assert!(u16::try_from(bl).is_ok(), "next_state must fit in u16"); + entry.new_state = bl as u16; entry.num_bits = nb; } Ok(()) @@ -336,11 +336,12 @@ impl FSETable { } /// A single entry in an FSE table. +#[repr(C)] #[derive(Copy, Clone, Debug)] pub struct Entry { - /// This value is used as an offset value, and it is added - /// to a value read from the stream to determine the next state value. - pub base_line: u32, + /// This value is used as an offset value, and it is added to the bits read + /// from the stream to determine the next state value. + pub new_state: u16, /// How many bits should be read from the stream when decoding this entry. pub num_bits: u8, /// The byte that should be put in the decode output when encountering this state. diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index f2a8b7b2..10022387 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -18,6 +18,11 @@ pub use fse_decoder::*; pub mod fse_encoder; +#[test] +fn decoder_entry_is_packed_4_bytes() { + assert_eq!(core::mem::size_of::(), 4); +} + #[test] fn tables_equal() { let probs = &[0, 0, -1, 3, 2, 2, (1 << 6) - 8]; @@ -37,7 +42,7 @@ fn check_tables(dec_table: &fse_decoder::FSETable, enc_table: &fse_encoder::FSET .iter() .find(|state| state.index == idx) .unwrap(); - assert_eq!(enc_state.baseline, dec_state.base_line as usize); + assert_eq!(enc_state.baseline, dec_state.new_state as usize); assert_eq!(enc_state.num_bits, dec_state.num_bits); } } From 1a69aab492ce5c6e549b7eb76fa14f2d062e9ed7 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Mon, 6 Apr 2026 18:58:33 +0300 Subject: [PATCH 2/8] perf(fse): align decode tables and bulk spread symbols --- zstd/src/decoding/scratch.rs | 45 ++++++++++++++++++++++------ zstd/src/fse/fse_decoder.rs | 58 ++++++++++++++++++++++++++++-------- zstd/src/fse/mod.rs | 3 ++ 3 files changed, 85 insertions(+), 21 deletions(-) diff --git a/zstd/src/decoding/scratch.rs b/zstd/src/decoding/scratch.rs index cf963459..0d503fbb 100644 --- a/zstd/src/decoding/scratch.rs +++ b/zstd/src/decoding/scratch.rs @@ -6,6 +6,7 @@ use crate::decoding::dictionary::Dictionary; use crate::fse::FSETable; use crate::huff0::HuffmanTable; use alloc::vec::Vec; +use core::ops::{Deref, DerefMut}; use crate::blocks::sequence_section::{ MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE, @@ -33,11 +34,11 @@ impl DecoderScratch { table: HuffmanTable::new(), }, fse: FSEScratch { - offsets: FSETable::new(MAX_OFFSET_CODE), + offsets: AlignedFSETable::new(MAX_OFFSET_CODE), of_rle: None, - literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE), + literal_lengths: AlignedFSETable::new(MAX_LITERAL_LENGTH_CODE), ll_rle: None, - match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE), + match_lengths: AlignedFSETable::new(MAX_MATCH_LENGTH_CODE), ml_rle: None, }, buffer: DecodeBuffer::new(window_size), @@ -97,22 +98,22 @@ impl Default for HuffmanScratch { } pub struct FSEScratch { - pub offsets: FSETable, + pub offsets: AlignedFSETable, pub of_rle: Option, - pub literal_lengths: FSETable, + pub literal_lengths: AlignedFSETable, pub ll_rle: Option, - pub match_lengths: FSETable, + pub match_lengths: AlignedFSETable, pub ml_rle: Option, } impl FSEScratch { pub fn new() -> FSEScratch { FSEScratch { - offsets: FSETable::new(MAX_OFFSET_CODE), + offsets: AlignedFSETable::new(MAX_OFFSET_CODE), of_rle: None, - literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE), + literal_lengths: AlignedFSETable::new(MAX_LITERAL_LENGTH_CODE), ll_rle: None, - match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE), + match_lengths: AlignedFSETable::new(MAX_MATCH_LENGTH_CODE), ml_rle: None, } } @@ -132,3 +133,29 @@ impl Default for FSEScratch { Self::new() } } + +// Keep LL/ML/OF table objects cache-line aligned to avoid cross-table placement +// effects when they are accessed in the same decode hot loop. +#[cfg_attr(target_arch = "aarch64", repr(align(128)))] +#[cfg_attr(not(target_arch = "aarch64"), repr(align(64)))] +pub struct AlignedFSETable(FSETable); + +impl AlignedFSETable { + fn new(max_symbol: u8) -> Self { + Self(FSETable::new(max_symbol)) + } +} + +impl Deref for AlignedFSETable { + type Target = FSETable; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for AlignedFSETable { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 0a31402f..5cf6e11b 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -1,6 +1,8 @@ use crate::bit_io::{BitReader, BitReaderReversed}; use crate::decoding::errors::{FSEDecoderError, FSETableError}; +use alloc::vec; use alloc::vec::Vec; +use core::ptr; pub struct FSEDecoder<'table> { /// An FSE state value represents an index in the FSE table. @@ -15,8 +17,8 @@ impl<'t> FSEDecoder<'t> { FSEDecoder { state: table.decode.first().copied().unwrap_or(Entry { new_state: 0, - num_bits: 0, symbol: 0, + num_bits: 0, }), table, } @@ -183,21 +185,19 @@ impl FSETable { table_size, Entry { new_state: 0, - num_bits: 0, symbol: 0, + num_bits: 0, }, ); + let mut table_symbols = vec![0u8; table_size]; let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol //first scan for all -1 probabilities and place them at the top of the table for symbol in 0..self.symbol_probabilities.len() { if self.symbol_probabilities[symbol] == -1 { negative_idx -= 1; - let entry = &mut self.decode[negative_idx]; - entry.symbol = symbol as u8; - entry.new_state = 0; - entry.num_bits = self.accuracy_log; + table_symbols[negative_idx] = symbol as u8; } } @@ -212,8 +212,7 @@ impl FSETable { //for each probability point the symbol gets on slot let prob = self.symbol_probabilities[idx]; for _ in 0..prob { - let entry = &mut self.decode[position]; - entry.symbol = symbol; + table_symbols[position] = symbol; position = next_position(position, table_size); while position >= negative_idx { @@ -223,6 +222,11 @@ impl FSETable { } } + self.copy_symbols_into_decode(&table_symbols); + for idx in negative_idx..table_size { + self.decode[idx].num_bits = self.accuracy_log; + } + // baselines and num_bits can only be calculated when all symbols have been spread self.symbol_counter.clear(); self.symbol_counter @@ -247,6 +251,36 @@ impl FSETable { Ok(()) } + fn copy_symbols_into_decode(&mut self, table_symbols: &[u8]) { + debug_assert_eq!(table_symbols.len(), self.decode.len()); + + #[cfg(target_endian = "little")] + { + // Write two packed entries (8 bytes) at once: + // Entry bytes are [new_state_lo, new_state_hi, symbol, num_bits]. + let mut idx = 0usize; + while idx + 1 < table_symbols.len() { + let packed = + ((table_symbols[idx] as u64) << 16) | ((table_symbols[idx + 1] as u64) << 48); + // SAFETY: `idx + 1 < len`, so at least 8 bytes remain. Unaligned writes are intentional. + unsafe { + ptr::write_unaligned(self.decode.as_mut_ptr().add(idx).cast::(), packed); + } + idx += 2; + } + if idx < table_symbols.len() { + self.decode[idx].symbol = table_symbols[idx]; + } + } + + #[cfg(not(target_endian = "little"))] + { + for (entry, symbol) in self.decode.iter_mut().zip(table_symbols.iter().copied()) { + entry.symbol = symbol; + } + } + } + /// Read the accuracy log and the probability table from the source and return the number of bytes /// read. If the size of the table is larger than the provided `max_log`, return an error. fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result { @@ -339,13 +373,13 @@ impl FSETable { #[repr(C)] #[derive(Copy, Clone, Debug)] pub struct Entry { - /// This value is used as an offset value, and it is added to the bits read - /// from the stream to determine the next state value. + /// Base index for the next state. The low bits read from the bitstream are + /// added to this value to produce the final state index. pub new_state: u16, - /// How many bits should be read from the stream when decoding this entry. - pub num_bits: u8, /// The byte that should be put in the decode output when encountering this state. pub symbol: u8, + /// How many bits should be read from the stream when decoding this entry. + pub num_bits: u8, } /// This value is added to the first 4 bits of the stream to determine the diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index 10022387..4ddf3818 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -21,6 +21,9 @@ pub mod fse_encoder; #[test] fn decoder_entry_is_packed_4_bytes() { assert_eq!(core::mem::size_of::(), 4); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, new_state), 0); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, symbol), 2); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, num_bits), 3); } #[test] From e7489003d6d7643318a310249bdde4ce6f005505 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Mon, 6 Apr 2026 23:39:49 +0300 Subject: [PATCH 3/8] fix(fse): validate acc log bounds and reuse spread buffer --- zstd/src/decoding/scratch.rs | 5 +- zstd/src/fse/fse_decoder.rs | 95 ++++++++++++++++++++++++------------ zstd/src/fse/mod.rs | 22 +++++++++ 3 files changed, 88 insertions(+), 34 deletions(-) diff --git a/zstd/src/decoding/scratch.rs b/zstd/src/decoding/scratch.rs index 0d503fbb..d06b2045 100644 --- a/zstd/src/decoding/scratch.rs +++ b/zstd/src/decoding/scratch.rs @@ -134,8 +134,9 @@ impl Default for FSEScratch { } } -// Keep LL/ML/OF table objects cache-line aligned to avoid cross-table placement -// effects when they are accessed in the same decode hot loop. +// Keep LL/ML/OF table *objects* cache-line aligned to avoid cross-table placement +// effects in DecoderScratch when they are accessed in the same decode hot loop. +// Note: this aligns the table containers, not the `Vec` backing allocations. #[cfg_attr(target_arch = "aarch64", repr(align(128)))] #[cfg_attr(not(target_arch = "aarch64"), repr(align(64)))] pub struct AlignedFSETable(FSETable); diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 5cf6e11b..5ef837fe 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -1,6 +1,5 @@ use crate::bit_io::{BitReader, BitReaderReversed}; use crate::decoding::errors::{FSEDecoderError, FSETableError}; -use alloc::vec; use alloc::vec::Vec; use core::ptr; @@ -45,8 +44,8 @@ impl<'t> FSEDecoder<'t> { pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits(num_bits); - let next_state = self.state.new_state + add as u16; - self.state = self.table.decode[next_state as usize]; + let next_state = usize::from(self.state.new_state) + add as usize; + self.state = self.table.decode[next_state]; //println!("Update: {}, {} -> {}", self.state.new_state, add, self.state); } @@ -63,8 +62,8 @@ impl<'t> FSEDecoder<'t> { pub fn update_state_fast(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits_unchecked(num_bits); - let next_state = self.state.new_state + add as u16; - self.state = self.table.decode[next_state as usize]; + let next_state = usize::from(self.state.new_state) + add as usize; + self.state = self.table.decode[next_state]; } } @@ -79,6 +78,8 @@ pub struct FSETable { /// The actual table containing the decoded symbol and the compression data /// connected to that symbol. pub decode: Vec, //used to decode symbols, and calculate the next state + /// Reused scratch buffer for symbol spreading to avoid per-build allocations. + symbol_spread_buffer: Vec, /// The size of the table is stored in logarithm base 2 format, /// with the **size of the table** being equal to `(1 << accuracy_log)`. /// This value is used so that the decoder knows how many bits to read from the bitstream. @@ -106,7 +107,8 @@ impl FSETable { max_symbol, symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8 symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8 - decode: Vec::new(), //depending on acc_log. + symbol_spread_buffer: Vec::new(), + decode: Vec::new(), //depending on acc_log. accuracy_log: 0, } } @@ -125,6 +127,7 @@ impl FSETable { pub fn reset(&mut self) { self.symbol_counter.clear(); self.symbol_probabilities.clear(); + self.symbol_spread_buffer.clear(); self.decode.clear(); self.accuracy_log = 0; } @@ -143,6 +146,12 @@ impl FSETable { /// returns how many BYTEs (not bits) were read while building the decoder pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result { + if max_log > ENTRY_MAX_ACCURACY_LOG { + return Err(FSETableError::AccLogTooBig { + got: max_log, + max: ENTRY_MAX_ACCURACY_LOG, + }); + } self.accuracy_log = 0; let bytes_read = self.read_probabilities(source, max_log)?; @@ -160,6 +169,12 @@ impl FSETable { if acc_log == 0 { return Err(FSETableError::AccLogIsZero); } + if acc_log > ENTRY_MAX_ACCURACY_LOG { + return Err(FSETableError::AccLogTooBig { + got: acc_log, + max: ENTRY_MAX_ACCURACY_LOG, + }); + } self.symbol_probabilities = probs.to_vec(); self.accuracy_log = acc_log; self.build_decoding_table() @@ -190,39 +205,46 @@ impl FSETable { }, ); - let mut table_symbols = vec![0u8; table_size]; - let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol - - //first scan for all -1 probabilities and place them at the top of the table - for symbol in 0..self.symbol_probabilities.len() { - if self.symbol_probabilities[symbol] == -1 { - negative_idx -= 1; - table_symbols[negative_idx] = symbol as u8; + let mut table_symbols = core::mem::take(&mut self.symbol_spread_buffer); + table_symbols.clear(); + table_symbols.resize(table_size, 0); + let negative_idx = { + let table_symbols = &mut table_symbols; + let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol + + //first scan for all -1 probabilities and place them at the top of the table + for symbol in 0..self.symbol_probabilities.len() { + if self.symbol_probabilities[symbol] == -1 { + negative_idx -= 1; + table_symbols[negative_idx] = symbol as u8; + } } - } - //then place in a semi-random order all of the other symbols - let mut position = 0; - for idx in 0..self.symbol_probabilities.len() { - let symbol = idx as u8; - if self.symbol_probabilities[idx] <= 0 { - continue; - } + //then place in a semi-random order all of the other symbols + let mut position = 0; + for idx in 0..self.symbol_probabilities.len() { + let symbol = idx as u8; + if self.symbol_probabilities[idx] <= 0 { + continue; + } - //for each probability point the symbol gets on slot - let prob = self.symbol_probabilities[idx]; - for _ in 0..prob { - table_symbols[position] = symbol; + //for each probability point the symbol gets on slot + let prob = self.symbol_probabilities[idx]; + for _ in 0..prob { + table_symbols[position] = symbol; - position = next_position(position, table_size); - while position >= negative_idx { position = next_position(position, table_size); - //everything above negative_idx is already taken + while position >= negative_idx { + position = next_position(position, table_size); + //everything above negative_idx is already taken + } } } - } + negative_idx + }; self.copy_symbols_into_decode(&table_symbols); + self.symbol_spread_buffer = table_symbols; for idx in negative_idx..table_size { self.decode[idx].num_bits = self.accuracy_log; } @@ -244,8 +266,10 @@ impl FSETable { assert!(nb <= self.accuracy_log); self.symbol_counter[symbol as usize] += 1; - assert!(u16::try_from(bl).is_ok(), "next_state must fit in u16"); - entry.new_state = bl as u16; + entry.new_state = u16::try_from(bl).map_err(|_| FSETableError::AccLogTooBig { + got: self.accuracy_log, + max: ENTRY_MAX_ACCURACY_LOG, + })?; entry.num_bits = nb; } Ok(()) @@ -288,6 +312,12 @@ impl FSETable { let mut br = BitReader::new(source); self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8); + if self.accuracy_log > ENTRY_MAX_ACCURACY_LOG { + return Err(FSETableError::AccLogTooBig { + got: self.accuracy_log, + max: ENTRY_MAX_ACCURACY_LOG, + }); + } if self.accuracy_log > max_log { return Err(FSETableError::AccLogTooBig { got: self.accuracy_log, @@ -385,6 +415,7 @@ pub struct Entry { /// This value is added to the first 4 bits of the stream to determine the /// `Accuracy_Log` const ACC_LOG_OFFSET: u8 = 5; +const ENTRY_MAX_ACCURACY_LOG: u8 = 16; fn highest_bit_set(x: u32) -> u32 { assert!(x > 0); diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index 4ddf3818..7830458f 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -26,6 +26,28 @@ fn decoder_entry_is_packed_4_bytes() { assert_eq!(core::mem::offset_of!(fse_decoder::Entry, num_bits), 3); } +#[test] +fn build_from_probabilities_rejects_acc_log_over_entry_limit() { + let mut dec_table = FSETable::new(255); + let err = dec_table + .build_from_probabilities(17, &[1, 1, 1, 1]) + .unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::FSETableError::AccLogTooBig { got: 17, max: 16 } + )); +} + +#[test] +fn build_decoder_rejects_max_log_over_entry_limit() { + let mut dec_table = FSETable::new(255); + let err = dec_table.build_decoder(&[], 17).unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::FSETableError::AccLogTooBig { got: 17, max: 16 } + )); +} + #[test] fn tables_equal() { let probs = &[0, 0, -1, 3, 2, 2, (1 << 6) - 8]; From cb6f0c7dcba9aa0429c1c976ce14363b5edb6717 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Tue, 7 Apr 2026 00:08:43 +0300 Subject: [PATCH 4/8] fix(fse): address review feedback for packed decode path - Clamp build_decoder max_log to entry layout limit instead of early reject - Add explicit layout assertions and tighten unsafe write safety invariants - Update regression test to validate decoder path behavior --- zstd/src/fse/fse_decoder.rs | 15 ++++++++------- zstd/src/fse/mod.rs | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 5ef837fe..73e42ff8 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -146,12 +146,7 @@ impl FSETable { /// returns how many BYTEs (not bits) were read while building the decoder pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result { - if max_log > ENTRY_MAX_ACCURACY_LOG { - return Err(FSETableError::AccLogTooBig { - got: max_log, - max: ENTRY_MAX_ACCURACY_LOG, - }); - } + let max_log = max_log.min(ENTRY_MAX_ACCURACY_LOG); self.accuracy_log = 0; let bytes_read = self.read_probabilities(source, max_log)?; @@ -280,13 +275,19 @@ impl FSETable { #[cfg(target_endian = "little")] { + debug_assert_eq!(core::mem::size_of::(), 4); + debug_assert_eq!(core::mem::offset_of!(Entry, new_state), 0); + debug_assert_eq!(core::mem::offset_of!(Entry, symbol), 2); + debug_assert_eq!(core::mem::offset_of!(Entry, num_bits), 3); // Write two packed entries (8 bytes) at once: // Entry bytes are [new_state_lo, new_state_hi, symbol, num_bits]. let mut idx = 0usize; while idx + 1 < table_symbols.len() { let packed = ((table_symbols[idx] as u64) << 16) | ((table_symbols[idx + 1] as u64) << 48); - // SAFETY: `idx + 1 < len`, so at least 8 bytes remain. Unaligned writes are intentional. + // SAFETY: `idx + 1 < table_symbols.len()` and `table_symbols.len() == self.decode.len()` + // ensure `idx` and `idx + 1` are valid `self.decode` entries (2 x 4 bytes = 8 bytes). + // Unaligned writes are intentional because `Entry` alignment may be < 8. unsafe { ptr::write_unaligned(self.decode.as_mut_ptr().add(idx).cast::(), packed); } diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index 7830458f..ca1d08d1 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -39,12 +39,12 @@ fn build_from_probabilities_rejects_acc_log_over_entry_limit() { } #[test] -fn build_decoder_rejects_max_log_over_entry_limit() { +fn build_decoder_clamps_max_log_over_entry_limit() { let mut dec_table = FSETable::new(255); - let err = dec_table.build_decoder(&[], 17).unwrap_err(); + let err = dec_table.build_decoder(&[], 16).unwrap_err(); assert!(matches!( err, - crate::decoding::errors::FSETableError::AccLogTooBig { got: 17, max: 16 } + crate::decoding::errors::FSETableError::GetBitsError(_) )); } From 61e4ac1b98cb1706b092e761d8066a1808257bf4 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Tue, 7 Apr 2026 00:41:10 +0300 Subject: [PATCH 5/8] test(fse): cover clamp path and enforce entry layout - exercise build_decoder clamp branch with max_log > 16 - add compile-time size and field-offset assertions for Entry on little-endian --- zstd/src/fse/fse_decoder.rs | 9 +++++++++ zstd/src/fse/mod.rs | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 73e42ff8..b66916f0 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -413,6 +413,15 @@ pub struct Entry { pub num_bits: u8, } +#[cfg(target_endian = "little")] +const _: [(); 0] = [(); core::mem::offset_of!(Entry, new_state)]; +#[cfg(target_endian = "little")] +const _: [(); 2] = [(); core::mem::offset_of!(Entry, symbol)]; +#[cfg(target_endian = "little")] +const _: [(); 3] = [(); core::mem::offset_of!(Entry, num_bits)]; +#[cfg(target_endian = "little")] +const _: [(); 4] = [(); core::mem::size_of::()]; + /// This value is added to the first 4 bits of the stream to determine the /// `Accuracy_Log` const ACC_LOG_OFFSET: u8 = 5; diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index ca1d08d1..b24d6b26 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -41,7 +41,7 @@ fn build_from_probabilities_rejects_acc_log_over_entry_limit() { #[test] fn build_decoder_clamps_max_log_over_entry_limit() { let mut dec_table = FSETable::new(255); - let err = dec_table.build_decoder(&[], 16).unwrap_err(); + let err = dec_table.build_decoder(&[], 17).unwrap_err(); assert!(matches!( err, crate::decoding::errors::FSETableError::GetBitsError(_) From ee56e95aa340e04bc0044cc8819d86f59783e512 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Tue, 7 Apr 2026 02:16:28 +0300 Subject: [PATCH 6/8] perf(fse): preserve spread-buffer reuse on table reinit - copy symbol_spread_buffer in reinit_from to retain allocated capacity --- zstd/src/fse/fse_decoder.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index b66916f0..94c961fd 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -119,6 +119,8 @@ impl FSETable { self.symbol_counter.extend_from_slice(&other.symbol_counter); self.symbol_probabilities .extend_from_slice(&other.symbol_probabilities); + self.symbol_spread_buffer + .extend_from_slice(&other.symbol_spread_buffer); self.decode.extend_from_slice(&other.decode); self.accuracy_log = other.accuracy_log; } From 0e022ecec10cb9a5530009b9611038adfbcbf0f6 Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Tue, 7 Apr 2026 02:31:58 +0300 Subject: [PATCH 7/8] perf(fse): avoid spread-buffer copy on table reinit - preserve only symbol_spread_buffer capacity via reserve - rename empty-input test to match asserted behavior --- zstd/src/fse/fse_decoder.rs | 2 +- zstd/src/fse/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 94c961fd..ef80b706 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -120,7 +120,7 @@ impl FSETable { self.symbol_probabilities .extend_from_slice(&other.symbol_probabilities); self.symbol_spread_buffer - .extend_from_slice(&other.symbol_spread_buffer); + .reserve(other.symbol_spread_buffer.len()); self.decode.extend_from_slice(&other.decode); self.accuracy_log = other.accuracy_log; } diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index b24d6b26..5e410322 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -39,7 +39,7 @@ fn build_from_probabilities_rejects_acc_log_over_entry_limit() { } #[test] -fn build_decoder_clamps_max_log_over_entry_limit() { +fn build_decoder_empty_input_reports_bits_error_with_large_max_log() { let mut dec_table = FSETable::new(255); let err = dec_table.build_decoder(&[], 17).unwrap_err(); assert!(matches!( From 25be69c3a22680136500707ec6b065446fe1683b Mon Sep 17 00:00:00 2001 From: Dmitry Prudnikov Date: Tue, 7 Apr 2026 02:43:24 +0300 Subject: [PATCH 8/8] refactor(fse): remove stale debug print in update_state - drop commented println that logged post-update state and could mislead debugging --- zstd/src/fse/fse_decoder.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index ef80b706..07670551 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -46,8 +46,6 @@ impl<'t> FSEDecoder<'t> { let add = bits.get_bits(num_bits); let next_state = usize::from(self.state.new_state) + add as usize; self.state = self.table.decode[next_state]; - - //println!("Update: {}, {} -> {}", self.state.new_state, add, self.state); } /// Advance the internal state **without** an individual refill check.