diff --git a/zstd/src/decoding/scratch.rs b/zstd/src/decoding/scratch.rs index cf963459..d06b2045 100644 --- a/zstd/src/decoding/scratch.rs +++ b/zstd/src/decoding/scratch.rs @@ -6,6 +6,7 @@ use crate::decoding::dictionary::Dictionary; use crate::fse::FSETable; use crate::huff0::HuffmanTable; use alloc::vec::Vec; +use core::ops::{Deref, DerefMut}; use crate::blocks::sequence_section::{ MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE, @@ -33,11 +34,11 @@ impl DecoderScratch { table: HuffmanTable::new(), }, fse: FSEScratch { - offsets: FSETable::new(MAX_OFFSET_CODE), + offsets: AlignedFSETable::new(MAX_OFFSET_CODE), of_rle: None, - literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE), + literal_lengths: AlignedFSETable::new(MAX_LITERAL_LENGTH_CODE), ll_rle: None, - match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE), + match_lengths: AlignedFSETable::new(MAX_MATCH_LENGTH_CODE), ml_rle: None, }, buffer: DecodeBuffer::new(window_size), @@ -97,22 +98,22 @@ impl Default for HuffmanScratch { } pub struct FSEScratch { - pub offsets: FSETable, + pub offsets: AlignedFSETable, pub of_rle: Option, - pub literal_lengths: FSETable, + pub literal_lengths: AlignedFSETable, pub ll_rle: Option, - pub match_lengths: FSETable, + pub match_lengths: AlignedFSETable, pub ml_rle: Option, } impl FSEScratch { pub fn new() -> FSEScratch { FSEScratch { - offsets: FSETable::new(MAX_OFFSET_CODE), + offsets: AlignedFSETable::new(MAX_OFFSET_CODE), of_rle: None, - literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE), + literal_lengths: AlignedFSETable::new(MAX_LITERAL_LENGTH_CODE), ll_rle: None, - match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE), + match_lengths: AlignedFSETable::new(MAX_MATCH_LENGTH_CODE), ml_rle: None, } } @@ -132,3 +133,30 @@ impl Default for FSEScratch { Self::new() } } + +// Keep LL/ML/OF table *objects* cache-line aligned to avoid cross-table placement +// effects in DecoderScratch when they are accessed in the same decode hot loop. +// Note: this aligns the table containers, not the `Vec` backing allocations. +#[cfg_attr(target_arch = "aarch64", repr(align(128)))] +#[cfg_attr(not(target_arch = "aarch64"), repr(align(64)))] +pub struct AlignedFSETable(FSETable); + +impl AlignedFSETable { + fn new(max_symbol: u8) -> Self { + Self(FSETable::new(max_symbol)) + } +} + +impl Deref for AlignedFSETable { + type Target = FSETable; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for AlignedFSETable { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/zstd/src/decoding/sequence_section_decoder.rs b/zstd/src/decoding/sequence_section_decoder.rs index 74a3baf7..f203e966 100644 --- a/zstd/src/decoding/sequence_section_decoder.rs +++ b/zstd/src/decoding/sequence_section_decoder.rs @@ -475,7 +475,7 @@ fn test_ll_default() { idx, table.decode[idx].symbol, table.decode[idx].num_bits, - table.decode[idx].base_line + table.decode[idx].new_state ); } @@ -484,21 +484,21 @@ fn test_ll_default() { //just test a few values. TODO test all values assert!(table.decode[0].symbol == 0); assert!(table.decode[0].num_bits == 4); - assert!(table.decode[0].base_line == 0); + assert!(table.decode[0].new_state == 0); assert!(table.decode[19].symbol == 27); assert!(table.decode[19].num_bits == 6); - assert!(table.decode[19].base_line == 0); + assert!(table.decode[19].new_state == 0); assert!(table.decode[39].symbol == 25); assert!(table.decode[39].num_bits == 4); - assert!(table.decode[39].base_line == 16); + assert!(table.decode[39].new_state == 16); assert!(table.decode[60].symbol == 35); assert!(table.decode[60].num_bits == 6); - assert!(table.decode[60].base_line == 0); + assert!(table.decode[60].new_state == 0); assert!(table.decode[59].symbol == 24); assert!(table.decode[59].num_bits == 5); - assert!(table.decode[59].base_line == 32); + assert!(table.decode[59].new_state == 32); } diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index d39810d2..07670551 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -1,6 +1,7 @@ use crate::bit_io::{BitReader, BitReaderReversed}; use crate::decoding::errors::{FSEDecoderError, FSETableError}; use alloc::vec::Vec; +use core::ptr; pub struct FSEDecoder<'table> { /// An FSE state value represents an index in the FSE table. @@ -14,9 +15,9 @@ impl<'t> FSEDecoder<'t> { pub fn new(table: &'t FSETable) -> FSEDecoder<'t> { FSEDecoder { state: table.decode.first().copied().unwrap_or(Entry { - base_line: 0, - num_bits: 0, + new_state: 0, symbol: 0, + num_bits: 0, }), table, } @@ -43,11 +44,8 @@ impl<'t> FSEDecoder<'t> { pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits(num_bits); - let base_line = self.state.base_line; - let new_state = base_line + add as u32; - self.state = self.table.decode[new_state as usize]; - - //println!("Update: {}, {} -> {}", base_line, add, self.state); + let next_state = usize::from(self.state.new_state) + add as usize; + self.state = self.table.decode[next_state]; } /// Advance the internal state **without** an individual refill check. @@ -62,8 +60,8 @@ impl<'t> FSEDecoder<'t> { pub fn update_state_fast(&mut self, bits: &mut BitReaderReversed<'_>) { let num_bits = self.state.num_bits; let add = bits.get_bits_unchecked(num_bits); - let new_state = self.state.base_line + add as u32; - self.state = self.table.decode[new_state as usize]; + let next_state = usize::from(self.state.new_state) + add as usize; + self.state = self.table.decode[next_state]; } } @@ -78,6 +76,8 @@ pub struct FSETable { /// The actual table containing the decoded symbol and the compression data /// connected to that symbol. pub decode: Vec, //used to decode symbols, and calculate the next state + /// Reused scratch buffer for symbol spreading to avoid per-build allocations. + symbol_spread_buffer: Vec, /// The size of the table is stored in logarithm base 2 format, /// with the **size of the table** being equal to `(1 << accuracy_log)`. /// This value is used so that the decoder knows how many bits to read from the bitstream. @@ -105,7 +105,8 @@ impl FSETable { max_symbol, symbol_probabilities: Vec::with_capacity(256), //will never be more than 256 symbols because u8 symbol_counter: Vec::with_capacity(256), //will never be more than 256 symbols because u8 - decode: Vec::new(), //depending on acc_log. + symbol_spread_buffer: Vec::new(), + decode: Vec::new(), //depending on acc_log. accuracy_log: 0, } } @@ -116,6 +117,8 @@ impl FSETable { self.symbol_counter.extend_from_slice(&other.symbol_counter); self.symbol_probabilities .extend_from_slice(&other.symbol_probabilities); + self.symbol_spread_buffer + .reserve(other.symbol_spread_buffer.len()); self.decode.extend_from_slice(&other.decode); self.accuracy_log = other.accuracy_log; } @@ -124,6 +127,7 @@ impl FSETable { pub fn reset(&mut self) { self.symbol_counter.clear(); self.symbol_probabilities.clear(); + self.symbol_spread_buffer.clear(); self.decode.clear(); self.accuracy_log = 0; } @@ -142,6 +146,7 @@ impl FSETable { /// returns how many BYTEs (not bits) were read while building the decoder pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result { + let max_log = max_log.min(ENTRY_MAX_ACCURACY_LOG); self.accuracy_log = 0; let bytes_read = self.read_probabilities(source, max_log)?; @@ -159,6 +164,12 @@ impl FSETable { if acc_log == 0 { return Err(FSETableError::AccLogIsZero); } + if acc_log > ENTRY_MAX_ACCURACY_LOG { + return Err(FSETableError::AccLogTooBig { + got: acc_log, + max: ENTRY_MAX_ACCURACY_LOG, + }); + } self.symbol_probabilities = probs.to_vec(); self.accuracy_log = acc_log; self.build_decoding_table() @@ -183,45 +194,54 @@ impl FSETable { self.decode.resize( table_size, Entry { - base_line: 0, - num_bits: 0, + new_state: 0, symbol: 0, + num_bits: 0, }, ); - let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol - - //first scan for all -1 probabilities and place them at the top of the table - for symbol in 0..self.symbol_probabilities.len() { - if self.symbol_probabilities[symbol] == -1 { - negative_idx -= 1; - let entry = &mut self.decode[negative_idx]; - entry.symbol = symbol as u8; - entry.base_line = 0; - entry.num_bits = self.accuracy_log; + let mut table_symbols = core::mem::take(&mut self.symbol_spread_buffer); + table_symbols.clear(); + table_symbols.resize(table_size, 0); + let negative_idx = { + let table_symbols = &mut table_symbols; + let mut negative_idx = table_size; //will point to the highest index with is already occupied by a negative-probability-symbol + + //first scan for all -1 probabilities and place them at the top of the table + for symbol in 0..self.symbol_probabilities.len() { + if self.symbol_probabilities[symbol] == -1 { + negative_idx -= 1; + table_symbols[negative_idx] = symbol as u8; + } } - } - //then place in a semi-random order all of the other symbols - let mut position = 0; - for idx in 0..self.symbol_probabilities.len() { - let symbol = idx as u8; - if self.symbol_probabilities[idx] <= 0 { - continue; - } + //then place in a semi-random order all of the other symbols + let mut position = 0; + for idx in 0..self.symbol_probabilities.len() { + let symbol = idx as u8; + if self.symbol_probabilities[idx] <= 0 { + continue; + } - //for each probability point the symbol gets on slot - let prob = self.symbol_probabilities[idx]; - for _ in 0..prob { - let entry = &mut self.decode[position]; - entry.symbol = symbol; + //for each probability point the symbol gets on slot + let prob = self.symbol_probabilities[idx]; + for _ in 0..prob { + table_symbols[position] = symbol; - position = next_position(position, table_size); - while position >= negative_idx { position = next_position(position, table_size); - //everything above negative_idx is already taken + while position >= negative_idx { + position = next_position(position, table_size); + //everything above negative_idx is already taken + } } } + negative_idx + }; + + self.copy_symbols_into_decode(&table_symbols); + self.symbol_spread_buffer = table_symbols; + for idx in negative_idx..table_size { + self.decode[idx].num_bits = self.accuracy_log; } // baselines and num_bits can only be calculated when all symbols have been spread @@ -241,12 +261,51 @@ impl FSETable { assert!(nb <= self.accuracy_log); self.symbol_counter[symbol as usize] += 1; - entry.base_line = bl; + entry.new_state = u16::try_from(bl).map_err(|_| FSETableError::AccLogTooBig { + got: self.accuracy_log, + max: ENTRY_MAX_ACCURACY_LOG, + })?; entry.num_bits = nb; } Ok(()) } + fn copy_symbols_into_decode(&mut self, table_symbols: &[u8]) { + debug_assert_eq!(table_symbols.len(), self.decode.len()); + + #[cfg(target_endian = "little")] + { + debug_assert_eq!(core::mem::size_of::(), 4); + debug_assert_eq!(core::mem::offset_of!(Entry, new_state), 0); + debug_assert_eq!(core::mem::offset_of!(Entry, symbol), 2); + debug_assert_eq!(core::mem::offset_of!(Entry, num_bits), 3); + // Write two packed entries (8 bytes) at once: + // Entry bytes are [new_state_lo, new_state_hi, symbol, num_bits]. + let mut idx = 0usize; + while idx + 1 < table_symbols.len() { + let packed = + ((table_symbols[idx] as u64) << 16) | ((table_symbols[idx + 1] as u64) << 48); + // SAFETY: `idx + 1 < table_symbols.len()` and `table_symbols.len() == self.decode.len()` + // ensure `idx` and `idx + 1` are valid `self.decode` entries (2 x 4 bytes = 8 bytes). + // Unaligned writes are intentional because `Entry` alignment may be < 8. + unsafe { + ptr::write_unaligned(self.decode.as_mut_ptr().add(idx).cast::(), packed); + } + idx += 2; + } + if idx < table_symbols.len() { + self.decode[idx].symbol = table_symbols[idx]; + } + } + + #[cfg(not(target_endian = "little"))] + { + for (entry, symbol) in self.decode.iter_mut().zip(table_symbols.iter().copied()) { + entry.symbol = symbol; + } + } + } + /// Read the accuracy log and the probability table from the source and return the number of bytes /// read. If the size of the table is larger than the provided `max_log`, return an error. fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result { @@ -254,6 +313,12 @@ impl FSETable { let mut br = BitReader::new(source); self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8); + if self.accuracy_log > ENTRY_MAX_ACCURACY_LOG { + return Err(FSETableError::AccLogTooBig { + got: self.accuracy_log, + max: ENTRY_MAX_ACCURACY_LOG, + }); + } if self.accuracy_log > max_log { return Err(FSETableError::AccLogTooBig { got: self.accuracy_log, @@ -336,20 +401,31 @@ impl FSETable { } /// A single entry in an FSE table. +#[repr(C)] #[derive(Copy, Clone, Debug)] pub struct Entry { - /// This value is used as an offset value, and it is added - /// to a value read from the stream to determine the next state value. - pub base_line: u32, - /// How many bits should be read from the stream when decoding this entry. - pub num_bits: u8, + /// Base index for the next state. The low bits read from the bitstream are + /// added to this value to produce the final state index. + pub new_state: u16, /// The byte that should be put in the decode output when encountering this state. pub symbol: u8, + /// How many bits should be read from the stream when decoding this entry. + pub num_bits: u8, } +#[cfg(target_endian = "little")] +const _: [(); 0] = [(); core::mem::offset_of!(Entry, new_state)]; +#[cfg(target_endian = "little")] +const _: [(); 2] = [(); core::mem::offset_of!(Entry, symbol)]; +#[cfg(target_endian = "little")] +const _: [(); 3] = [(); core::mem::offset_of!(Entry, num_bits)]; +#[cfg(target_endian = "little")] +const _: [(); 4] = [(); core::mem::size_of::()]; + /// This value is added to the first 4 bits of the stream to determine the /// `Accuracy_Log` const ACC_LOG_OFFSET: u8 = 5; +const ENTRY_MAX_ACCURACY_LOG: u8 = 16; fn highest_bit_set(x: u32) -> u32 { assert!(x > 0); diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index f2a8b7b2..5e410322 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -18,6 +18,36 @@ pub use fse_decoder::*; pub mod fse_encoder; +#[test] +fn decoder_entry_is_packed_4_bytes() { + assert_eq!(core::mem::size_of::(), 4); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, new_state), 0); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, symbol), 2); + assert_eq!(core::mem::offset_of!(fse_decoder::Entry, num_bits), 3); +} + +#[test] +fn build_from_probabilities_rejects_acc_log_over_entry_limit() { + let mut dec_table = FSETable::new(255); + let err = dec_table + .build_from_probabilities(17, &[1, 1, 1, 1]) + .unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::FSETableError::AccLogTooBig { got: 17, max: 16 } + )); +} + +#[test] +fn build_decoder_empty_input_reports_bits_error_with_large_max_log() { + let mut dec_table = FSETable::new(255); + let err = dec_table.build_decoder(&[], 17).unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::FSETableError::GetBitsError(_) + )); +} + #[test] fn tables_equal() { let probs = &[0, 0, -1, 3, 2, 2, (1 << 6) - 8]; @@ -37,7 +67,7 @@ fn check_tables(dec_table: &fse_decoder::FSETable, enc_table: &fse_encoder::FSET .iter() .find(|state| state.index == idx) .unwrap(); - assert_eq!(enc_state.baseline, dec_state.base_line as usize); + assert_eq!(enc_state.baseline, dec_state.new_state as usize); assert_eq!(enc_state.num_bits, dec_state.num_bits); } }