diff --git a/zstd/Cargo.toml b/zstd/Cargo.toml index 0d435fc7..ee02ca61 100644 --- a/zstd/Cargo.toml +++ b/zstd/Cargo.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" homepage = "https://github.com/structured-world/structured-zstd" repository = "https://github.com/structured-world/structured-zstd" description = "Pure Rust zstd implementation — managed fork of ruzstd. Dictionary decompression, no FFI." -exclude = ["dict_tests/*", "fuzz_decodecorpus/*", "decodecorpus_files/*"] +exclude = ["fuzz_decodecorpus/*", "decodecorpus_files/*", "dict_tests/files/**"] # Package metadata points at a crate-local symlink so the packaged crate and repo root README stay in sync. readme = "README.md" keywords = ["zstd", "zstandard", "decompression", "compression", "pure-rust"] diff --git a/zstd/src/decoding/dictionary.rs b/zstd/src/decoding/dictionary.rs index f0f7b7ad..4d3030de 100644 --- a/zstd/src/decoding/dictionary.rs +++ b/zstd/src/decoding/dictionary.rs @@ -40,15 +40,50 @@ pub struct Dictionary { pub const MAGIC_NUM: [u8; 4] = [0x37, 0xA4, 0x30, 0xEC]; impl Dictionary { - /// Parses the dictionary from `raw` and set the tables - /// it returns the dict_id for checking with the frame's `dict_id`` + /// Build a dictionary from raw content bytes (without entropy table sections). + /// + /// This is primarily intended for dictionaries produced by the `dict_builder` + /// module, which currently emits raw-content dictionaries. 
+ pub fn from_raw_content( + id: u32, + dict_content: Vec, + ) -> Result { + if id == 0 { + return Err(DictionaryDecodeError::ZeroDictionaryId); + } + if dict_content.is_empty() { + return Err(DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 1 }); + } + + Ok(Dictionary { + id, + fse: FSEScratch::new(), + huf: HuffmanScratch::new(), + dict_content, + offset_hist: [1, 4, 8], + }) + } + + /// Parses the dictionary from `raw`, initializes its tables, + /// and returns a fully constructed [`Dictionary`] whose `id` can be + /// checked against the frame's `dict_id`. pub fn decode_dict(raw: &[u8]) -> Result { + const MIN_MAGIC_AND_ID_LEN: usize = 8; + const OFFSET_HISTORY_LEN: usize = 12; + + if raw.len() < MIN_MAGIC_AND_ID_LEN { + return Err(DictionaryDecodeError::DictionaryTooSmall { + got: raw.len(), + need: MIN_MAGIC_AND_ID_LEN, + }); + } + let mut new_dict = Dictionary { id: 0, fse: FSEScratch::new(), huf: HuffmanScratch::new(), dict_content: Vec::new(), - offset_hist: [2, 4, 8], + offset_hist: [1, 4, 8], }; let magic_num: [u8; 4] = raw[..4].try_into().expect("optimized away"); @@ -58,6 +93,9 @@ impl Dictionary { let dict_id = raw[4..8].try_into().expect("optimized away"); let dict_id = u32::from_le_bytes(dict_id); + if dict_id == 0 { + return Err(DictionaryDecodeError::ZeroDictionaryId); + } new_dict.id = dict_id; let raw_tables = &raw[8..]; @@ -83,6 +121,13 @@ impl Dictionary { )?; let raw_tables = &raw_tables[ll_size..]; + if raw_tables.len() < OFFSET_HISTORY_LEN { + return Err(DictionaryDecodeError::DictionaryTooSmall { + got: raw_tables.len(), + need: OFFSET_HISTORY_LEN, + }); + } + let offset1 = raw_tables[0..4].try_into().expect("optimized away"); let offset1 = u32::from_le_bytes(offset1); @@ -92,6 +137,16 @@ impl Dictionary { let offset3 = raw_tables[8..12].try_into().expect("optimized away"); let offset3 = u32::from_le_bytes(offset3); + if offset1 == 0 { + return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 0 }); + } + if offset2 
== 0 { + return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 1 }); + } + if offset3 == 0 { + return Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 2 }); + } + new_dict.offset_hist[0] = offset1; new_dict.offset_hist[1] = offset2; new_dict.offset_hist[2] = offset3; @@ -102,3 +157,102 @@ impl Dictionary { Ok(new_dict) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn offset_history_start(raw: &[u8]) -> usize { + let mut huf = crate::decoding::scratch::HuffmanScratch::new(); + let mut fse = crate::decoding::scratch::FSEScratch::new(); + let mut cursor = 8usize; + + let huf_size = huf + .table + .build_decoder(&raw[cursor..]) + .expect("reference dictionary huffman table should decode"); + cursor += huf_size as usize; + + let of_size = fse + .offsets + .build_decoder( + &raw[cursor..], + crate::decoding::sequence_section_decoder::OF_MAX_LOG, + ) + .expect("reference dictionary OF table should decode"); + cursor += of_size; + + let ml_size = fse + .match_lengths + .build_decoder( + &raw[cursor..], + crate::decoding::sequence_section_decoder::ML_MAX_LOG, + ) + .expect("reference dictionary ML table should decode"); + cursor += ml_size; + + let ll_size = fse + .literal_lengths + .build_decoder( + &raw[cursor..], + crate::decoding::sequence_section_decoder::LL_MAX_LOG, + ) + .expect("reference dictionary LL table should decode"); + cursor += ll_size; + + cursor + } + + #[test] + fn decode_dict_rejects_short_buffer_before_magic_and_id() { + let err = match Dictionary::decode_dict(&[]) { + Ok(_) => panic!("expected short dictionary to fail"), + Err(err) => err, + }; + assert!(matches!( + err, + DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 8 } + )); + } + + #[test] + fn decode_dict_malformed_input_returns_error_instead_of_panicking() { + let mut raw = Vec::new(); + raw.extend_from_slice(&MAGIC_NUM); + raw.extend_from_slice(&1u32.to_le_bytes()); + raw.extend_from_slice(&[0u8; 7]); + + let result = 
std::panic::catch_unwind(|| Dictionary::decode_dict(&raw)); + assert!( + result.is_ok(), + "decode_dict must not panic on malformed input" + ); + assert!( + result.unwrap().is_err(), + "malformed dictionary must return error" + ); + } + + #[test] + fn decode_dict_rejects_zero_repeat_offsets() { + let mut raw = include_bytes!("../../dict_tests/dictionary").to_vec(); + let offset_start = offset_history_start(&raw); + + // Corrupt rep0 to zero. + raw[offset_start..offset_start + 4].copy_from_slice(&0u32.to_le_bytes()); + let decoded = Dictionary::decode_dict(&raw); + assert!(matches!( + decoded, + Err(DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index: 0 }) + )); + } + + #[test] + fn from_raw_content_rejects_empty_dictionary_content() { + let result = Dictionary::from_raw_content(1, Vec::new()); + assert!(matches!( + result, + Err(DictionaryDecodeError::DictionaryTooSmall { got: 0, need: 1 }) + )); + } +} diff --git a/zstd/src/decoding/errors.rs b/zstd/src/decoding/errors.rs index 466ffe1a..06a0085b 100644 --- a/zstd/src/decoding/errors.rs +++ b/zstd/src/decoding/errors.rs @@ -425,6 +425,9 @@ impl core::fmt::Display for DecodeBufferError { #[non_exhaustive] pub enum DictionaryDecodeError { BadMagicNum { got: [u8; 4] }, + DictionaryTooSmall { got: usize, need: usize }, + ZeroDictionaryId, + ZeroRepeatOffsetInDictionary { index: u8 }, FSETableError(FSETableError), HuffmanTableError(HuffmanTableError), } @@ -451,6 +454,18 @@ impl core::fmt::Display for DictionaryDecodeError { crate::decoding::dictionary::MAGIC_NUM, ) } + DictionaryDecodeError::DictionaryTooSmall { got, need } => { + write!( + f, + "Dictionary is too small: got {got} bytes, need at least {need} bytes", + ) + } + DictionaryDecodeError::ZeroDictionaryId => { + write!(f, "Dictionary id must be non-zero") + } + DictionaryDecodeError::ZeroRepeatOffsetInDictionary { index } => { + write!(f, "Dictionary repeat offset rep{index} must be non-zero") + } DictionaryDecodeError::FSETableError(e) => write!(f, 
"{e:?}"), DictionaryDecodeError::HuffmanTableError(e) => write!(f, "{e:?}"), } diff --git a/zstd/src/encoding/frame_compressor.rs b/zstd/src/encoding/frame_compressor.rs index c87806b7..ed4c2973 100644 --- a/zstd/src/encoding/frame_compressor.rs +++ b/zstd/src/encoding/frame_compressor.rs @@ -39,11 +39,21 @@ pub struct FrameCompressor { uncompressed_data: Option, compressed_data: Option, compression_level: CompressionLevel, + dictionary: Option, + dictionary_entropy_cache: Option, state: CompressState, #[cfg(feature = "hash")] hasher: XxHash64, } +#[derive(Clone, Default)] +struct CachedDictionaryEntropy { + huff: Option, + ll_previous: Option, + ml_previous: Option, + of_previous: Option, +} + #[derive(Clone)] pub(crate) enum PreviousFseTable { // Default tables are immutable and already stored alongside the state, so @@ -99,6 +109,8 @@ impl FrameCompressor { uncompressed_data: None, compressed_data: None, compression_level, + dictionary: None, + dictionary_entropy_cache: None, state: CompressState { matcher: MatchGeneratorDriver::new(1024 * 128, 1), last_huff_table: None, @@ -117,6 +129,8 @@ impl FrameCompressor { Self { uncompressed_data: None, compressed_data: None, + dictionary: None, + dictionary_entropy_cache: None, state: CompressState { matcher, last_huff_table: None, @@ -153,11 +167,48 @@ impl FrameCompressor { pub fn compress(&mut self) { // Clearing buffers to allow re-using of the compressor self.state.matcher.reset(self.compression_level); - self.state.last_huff_table = None; - self.state.fse_tables.ll_previous = None; - self.state.fse_tables.ml_previous = None; - self.state.fse_tables.of_previous = None; self.state.offset_hist = [1, 4, 8]; + let use_dictionary_state = + !matches!(self.compression_level, CompressionLevel::Uncompressed) + && self.state.matcher.supports_dictionary_priming(); + let cached_entropy = if use_dictionary_state { + self.dictionary_entropy_cache.as_ref() + } else { + None + }; + if use_dictionary_state && let Some(dict) = 
self.dictionary.as_ref() { + // This state drives sequence encoding, while matcher priming below updates + // the match generator's internal repeat-offset history for match finding. + self.state.offset_hist = dict.offset_hist; + self.state + .matcher + .prime_with_dictionary(dict.dict_content.as_slice(), dict.offset_hist); + } + if let Some(cache) = cached_entropy { + self.state.last_huff_table.clone_from(&cache.huff); + } else { + self.state.last_huff_table = None; + } + // `clone_from` keeps frame-to-frame seeding cheap for reused compressors by + // reusing existing allocations where possible instead of reallocating every frame. + if let Some(cache) = cached_entropy { + self.state + .fse_tables + .ll_previous + .clone_from(&cache.ll_previous); + self.state + .fse_tables + .ml_previous + .clone_from(&cache.ml_previous); + self.state + .fse_tables + .of_previous + .clone_from(&cache.of_previous); + } else { + self.state.fse_tables.ll_previous = None; + self.state.fse_tables.ml_previous = None; + self.state.fse_tables.of_previous = None; + } #[cfg(feature = "hash")] { self.hasher = XxHash64::with_seed(0); @@ -171,7 +222,11 @@ impl FrameCompressor { frame_content_size: None, single_segment: false, content_checksum: cfg!(feature = "hash"), - dictionary_id: None, + dictionary_id: if use_dictionary_state { + self.dictionary.as_ref().map(|dict| dict.id as u64) + } else { + None + }, window_size: Some(self.state.matcher.window_size()), }; header.serialize(output); @@ -301,17 +356,122 @@ impl FrameCompressor { pub fn compression_level(&self) -> CompressionLevel { self.compression_level } + + /// Attach a pre-parsed dictionary to be used for subsequent compressions. + /// + /// In compressed modes, the dictionary id is written only when the active + /// matcher supports dictionary priming. + /// Uncompressed mode and non-priming matchers ignore the attached dictionary + /// at encode time. 
+ pub fn set_dictionary( + &mut self, + dictionary: crate::decoding::Dictionary, + ) -> Result, crate::decoding::errors::DictionaryDecodeError> + { + if dictionary.id == 0 { + return Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId); + } + if let Some(index) = dictionary.offset_hist.iter().position(|&rep| rep == 0) { + return Err( + crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary { + index: index as u8, + }, + ); + } + self.dictionary_entropy_cache = Some(CachedDictionaryEntropy { + huff: dictionary.huf.table.to_encoder_table(), + ll_previous: dictionary + .fse + .literal_lengths + .to_encoder_table() + .map(|table| PreviousFseTable::Custom(Box::new(table))), + ml_previous: dictionary + .fse + .match_lengths + .to_encoder_table() + .map(|table| PreviousFseTable::Custom(Box::new(table))), + of_previous: dictionary + .fse + .offsets + .to_encoder_table() + .map(|table| PreviousFseTable::Custom(Box::new(table))), + }); + Ok(self.dictionary.replace(dictionary)) + } + + /// Parse and attach a serialized dictionary blob. + pub fn set_dictionary_from_bytes( + &mut self, + raw_dictionary: &[u8], + ) -> Result, crate::decoding::errors::DictionaryDecodeError> + { + let dictionary = crate::decoding::Dictionary::decode_dict(raw_dictionary)?; + self.set_dictionary(dictionary) + } + + /// Remove the attached dictionary. 
+ pub fn clear_dictionary(&mut self) -> Option { + self.dictionary_entropy_cache = None; + self.dictionary.take() + } } #[cfg(test)] mod tests { + #[cfg(all(feature = "dict_builder", feature = "std"))] + use alloc::format; use alloc::vec; use super::FrameCompressor; use crate::common::MAGIC_NUM; use crate::decoding::FrameDecoder; + use crate::encoding::{Matcher, Sequence}; use alloc::vec::Vec; + struct NoDictionaryMatcher { + last_space: Vec, + window_size: u64, + } + + impl NoDictionaryMatcher { + fn new(window_size: u64) -> Self { + Self { + last_space: Vec::new(), + window_size, + } + } + } + + impl Matcher for NoDictionaryMatcher { + fn get_next_space(&mut self) -> Vec { + vec![0; self.window_size as usize] + } + + fn get_last_space(&mut self) -> &[u8] { + self.last_space.as_slice() + } + + fn commit_space(&mut self, space: Vec) { + self.last_space = space; + } + + fn skip_matching(&mut self) {} + + fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) { + handle_sequence(Sequence::Literals { + literals: self.last_space.as_slice(), + }); + } + + fn reset(&mut self, _level: super::CompressionLevel) { + self.last_space.clear(); + } + + fn window_size(&self) -> u64 { + self.window_size + } + } + #[test] fn frame_starts_with_magic_num() { let mock_data = [1_u8, 2, 3].as_slice(); @@ -395,6 +555,342 @@ mod tests { assert_eq!(mock_data, decoded); } + #[test] + fn dictionary_compression_sets_required_dict_id_and_roundtrips() { + let dict_raw = include_bytes!("../../dict_tests/dictionary"); + let dict_for_encoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap(); + let dict_for_decoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap(); + + let mut data = Vec::new(); + for _ in 0..8 { + data.extend_from_slice(&dict_for_decoder.dict_content[..2048]); + } + + let mut with_dict = Vec::new(); + let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest); + let previous = compressor + 
.set_dictionary_from_bytes(dict_raw) + .expect("dictionary bytes should parse"); + assert!( + previous.is_none(), + "first dictionary insert should return None" + ); + assert_eq!( + compressor + .set_dictionary(dict_for_encoder) + .expect("valid dictionary should attach") + .expect("set_dictionary_from_bytes inserted previous dictionary") + .id, + dict_for_decoder.id + ); + compressor.set_source(data.as_slice()); + compressor.set_drain(&mut with_dict); + compressor.compress(); + + let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice()) + .expect("encoded stream should have a frame header"); + assert_eq!(frame_header.dictionary_id(), Some(dict_for_decoder.id)); + + let mut decoder = FrameDecoder::new(); + let mut missing_dict_target = Vec::with_capacity(data.len()); + let err = decoder + .decode_all_to_vec(&with_dict, &mut missing_dict_target) + .unwrap_err(); + assert!( + matches!( + &err, + crate::decoding::errors::FrameDecoderError::DictNotProvided { .. 
} + ), + "dict-compressed stream should require dictionary id, got: {err:?}" + ); + + let mut decoder = FrameDecoder::new(); + decoder.add_dict(dict_for_decoder).unwrap(); + let mut decoded = Vec::with_capacity(data.len()); + decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap(); + assert_eq!(decoded, data); + + let mut ffi_decoder = zstd::bulk::Decompressor::with_dictionary(dict_raw).unwrap(); + let mut ffi_decoded = Vec::with_capacity(data.len()); + let ffi_written = ffi_decoder + .decompress_to_buffer(with_dict.as_slice(), &mut ffi_decoded) + .unwrap(); + assert_eq!(ffi_written, data.len()); + assert_eq!(ffi_decoded, data); + } + + #[cfg(all(feature = "dict_builder", feature = "std"))] + #[test] + fn dictionary_compression_roundtrips_with_dict_builder_dictionary() { + use std::io::Cursor; + + let mut training = Vec::new(); + for idx in 0..256u32 { + training.extend_from_slice( + format!("tenant=demo table=orders key={idx} region=eu\n").as_bytes(), + ); + } + let mut raw_dict = Vec::new(); + crate::dictionary::create_raw_dict_from_source( + Cursor::new(training.as_slice()), + training.len(), + &mut raw_dict, + 4096, + ); + assert!( + !raw_dict.is_empty(), + "dict_builder produced an empty dictionary" + ); + + let dict_id = 0xD1C7_0008; + let encoder_dict = + crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap(); + let decoder_dict = + crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap(); + + let mut payload = Vec::new(); + for idx in 0..96u32 { + payload.extend_from_slice( + format!( + "tenant=demo table=orders op=put key={idx} value=aaaaabbbbbcccccdddddeeeee\n" + ) + .as_bytes(), + ); + } + + let mut without_dict = Vec::new(); + let mut baseline = FrameCompressor::new(super::CompressionLevel::Fastest); + baseline.set_source(payload.as_slice()); + baseline.set_drain(&mut without_dict); + baseline.compress(); + + let mut with_dict = Vec::new(); + let mut compressor = 
FrameCompressor::new(super::CompressionLevel::Fastest); + compressor + .set_dictionary(encoder_dict) + .expect("valid dict_builder dictionary should attach"); + compressor.set_source(payload.as_slice()); + compressor.set_drain(&mut with_dict); + compressor.compress(); + + let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice()) + .expect("encoded stream should have a frame header"); + assert_eq!(frame_header.dictionary_id(), Some(dict_id)); + let mut decoder = FrameDecoder::new(); + decoder.add_dict(decoder_dict).unwrap(); + let mut decoded = Vec::with_capacity(payload.len()); + decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap(); + assert_eq!(decoded, payload); + assert!( + with_dict.len() < without_dict.len(), + "trained dictionary should improve compression for this small payload" + ); + } + + #[test] + fn set_dictionary_from_bytes_seeds_entropy_tables_for_first_block() { + let dict_raw = include_bytes!("../../dict_tests/dictionary"); + let mut output = Vec::new(); + let input = b""; + + let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest); + let previous = compressor + .set_dictionary_from_bytes(dict_raw) + .expect("dictionary bytes should parse"); + assert!(previous.is_none()); + + compressor.set_source(input.as_slice()); + compressor.set_drain(&mut output); + compressor.compress(); + + assert!( + compressor.state.last_huff_table.is_some(), + "dictionary entropy should seed previous huffman table before first block" + ); + assert!( + compressor.state.fse_tables.ll_previous.is_some(), + "dictionary entropy should seed previous ll table before first block" + ); + assert!( + compressor.state.fse_tables.ml_previous.is_some(), + "dictionary entropy should seed previous ml table before first block" + ); + assert!( + compressor.state.fse_tables.of_previous.is_some(), + "dictionary entropy should seed previous of table before first block" + ); + } + + #[test] + fn 
set_dictionary_rejects_zero_dictionary_id() { + let invalid = crate::decoding::Dictionary { + id: 0, + fse: crate::decoding::scratch::FSEScratch::new(), + huf: crate::decoding::scratch::HuffmanScratch::new(), + dict_content: vec![1, 2, 3], + offset_hist: [1, 4, 8], + }; + + let mut compressor: FrameCompressor< + &[u8], + Vec, + crate::encoding::match_generator::MatchGeneratorDriver, + > = FrameCompressor::new(super::CompressionLevel::Fastest); + let result = compressor.set_dictionary(invalid); + assert!(matches!( + result, + Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId) + )); + } + + #[test] + fn set_dictionary_rejects_zero_repeat_offsets() { + let invalid = crate::decoding::Dictionary { + id: 1, + fse: crate::decoding::scratch::FSEScratch::new(), + huf: crate::decoding::scratch::HuffmanScratch::new(), + dict_content: vec![1, 2, 3], + offset_hist: [0, 4, 8], + }; + + let mut compressor: FrameCompressor< + &[u8], + Vec, + crate::encoding::match_generator::MatchGeneratorDriver, + > = FrameCompressor::new(super::CompressionLevel::Fastest); + let result = compressor.set_dictionary(invalid); + assert!(matches!( + result, + Err( + crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary { + index: 0 + } + ) + )); + } + + #[test] + fn uncompressed_mode_does_not_require_dictionary() { + let dict_id = 0xABCD_0001; + let dict = + crate::decoding::Dictionary::from_raw_content(dict_id, b"shared-history".to_vec()) + .expect("raw dictionary should be valid"); + + let payload = b"plain-bytes-that-should-stay-raw"; + let mut output = Vec::new(); + let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed); + compressor + .set_dictionary(dict) + .expect("dictionary should attach in uncompressed mode"); + compressor.set_source(payload.as_slice()); + compressor.set_drain(&mut output); + compressor.compress(); + + let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice()) + .expect("encoded 
frame should have a header"); + assert_eq!( + frame_header.dictionary_id(), + None, + "raw/uncompressed frames must not advertise dictionary dependency" + ); + + let mut decoder = FrameDecoder::new(); + let mut decoded = Vec::with_capacity(payload.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(decoded, payload); + } + + #[test] + fn dictionary_roundtrip_stays_valid_after_output_exceeds_window() { + use crate::encoding::match_generator::MatchGeneratorDriver; + + let dict_id = 0xABCD_0002; + let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec()) + .expect("raw dictionary should be valid"); + let dict_for_decoder = + crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec()) + .expect("raw dictionary should be valid"); + + let payload = b"abcdefgh".repeat(512); + let matcher = MatchGeneratorDriver::new(8, 1); + + let mut no_dict_output = Vec::new(); + let mut no_dict_compressor = + FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest); + no_dict_compressor.set_source(payload.as_slice()); + no_dict_compressor.set_drain(&mut no_dict_output); + no_dict_compressor.compress(); + let (no_dict_frame_header, _) = + crate::decoding::frame::read_frame_header(no_dict_output.as_slice()) + .expect("baseline frame should have a header"); + let no_dict_window = no_dict_frame_header + .window_size() + .expect("window size should be present"); + + let mut output = Vec::new(); + let matcher = MatchGeneratorDriver::new(8, 1); + let mut compressor = + FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest); + compressor + .set_dictionary(dict) + .expect("dictionary should attach"); + compressor.set_source(payload.as_slice()); + compressor.set_drain(&mut output); + compressor.compress(); + + let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice()) + .expect("encoded frame should have a header"); + let advertised_window = frame_header 
+ .window_size() + .expect("window size should be present"); + assert_eq!( + advertised_window, no_dict_window, + "dictionary priming must not inflate advertised window size" + ); + assert!( + payload.len() > advertised_window as usize, + "test must cross the advertised window boundary" + ); + + let mut decoder = FrameDecoder::new(); + decoder.add_dict(dict_for_decoder).unwrap(); + let mut decoded = Vec::with_capacity(payload.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(decoded, payload); + } + + #[test] + fn custom_matcher_without_dictionary_priming_does_not_advertise_dict_id() { + let dict_id = 0xABCD_0003; + let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec()) + .expect("raw dictionary should be valid"); + let payload = b"abcdefghabcdefgh"; + + let mut output = Vec::new(); + let matcher = NoDictionaryMatcher::new(64); + let mut compressor = + FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest); + compressor + .set_dictionary(dict) + .expect("dictionary should attach"); + compressor.set_source(payload.as_slice()); + compressor.set_drain(&mut output); + compressor.compress(); + + let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice()) + .expect("encoded frame should have a header"); + assert_eq!( + frame_header.dictionary_id(), + None, + "matchers that do not support dictionary priming must not advertise dictionary dependency" + ); + + let mut decoder = FrameDecoder::new(); + let mut decoded = Vec::with_capacity(payload.len()); + decoder.decode_all_to_vec(&output, &mut decoded).unwrap(); + assert_eq!(decoded, payload); + } + #[cfg(feature = "hash")] #[test] fn checksum_two_frames_reused_compressor() { diff --git a/zstd/src/encoding/match_generator.rs b/zstd/src/encoding/match_generator.rs index d0c91245..e15905e3 100644 --- a/zstd/src/encoding/match_generator.rs +++ b/zstd/src/encoding/match_generator.rs @@ -42,6 +42,12 @@ pub struct 
MatchGeneratorDriver { slice_size: usize, base_slice_size: usize, base_window_size: usize, + // Frame header window size must stay at the configured live-window budget. + // Dictionary retention expands internal matcher capacity only. + reported_window_size: usize, + // Tracks currently retained bytes that originated from primed dictionary + // history and have not been evicted yet. + dictionary_retained_budget: usize, } impl MatchGeneratorDriver { @@ -58,6 +64,8 @@ impl MatchGeneratorDriver { slice_size, base_slice_size: slice_size, base_window_size: max_window_size, + reported_window_size: max_window_size, + dictionary_retained_budget: 0, } } @@ -107,11 +115,71 @@ impl MatchGeneratorDriver { .as_mut() .expect("dfast backend must be initialized by reset() before use") } + + fn retire_dictionary_budget(&mut self, evicted_bytes: usize) { + let reclaimed = evicted_bytes.min(self.dictionary_retained_budget); + if reclaimed == 0 { + return; + } + self.dictionary_retained_budget -= reclaimed; + match self.active_backend { + MatcherBackend::Simple => { + self.match_generator.max_window_size = self + .match_generator + .max_window_size + .saturating_sub(reclaimed); + } + MatcherBackend::Dfast => { + let matcher = self.dfast_matcher_mut(); + matcher.max_window_size = matcher.max_window_size.saturating_sub(reclaimed); + } + } + } + + fn trim_after_budget_retire(&mut self) { + loop { + let mut evicted_bytes = 0usize; + match self.active_backend { + MatcherBackend::Simple => { + let vec_pool = &mut self.vec_pool; + let suffix_pool = &mut self.suffix_pool; + self.match_generator.reserve(0, |mut data, mut suffixes| { + evicted_bytes += data.len(); + data.resize(data.capacity(), 0); + vec_pool.push(data); + suffixes.slots.clear(); + suffixes.slots.resize(suffixes.slots.capacity(), None); + suffix_pool.push(suffixes); + }); + } + MatcherBackend::Dfast => { + let mut retired = Vec::new(); + self.dfast_matcher_mut().trim_to_window(|data| { + evicted_bytes += data.len(); + 
retired.push(data); + }); + for mut data in retired { + data.resize(data.capacity(), 0); + self.vec_pool.push(data); + } + } + } + if evicted_bytes == 0 { + break; + } + self.retire_dictionary_budget(evicted_bytes); + } + } } impl Matcher for MatchGeneratorDriver { + fn supports_dictionary_priming(&self) -> bool { + true + } + fn reset(&mut self, level: CompressionLevel) { let (backend, slice_size, max_window_size, hash_fill_step) = self.level_config(level); + self.dictionary_retained_budget = 0; if self.active_backend != backend { match self.active_backend { MatcherBackend::Simple => { @@ -139,6 +207,7 @@ impl Matcher for MatchGeneratorDriver { self.active_backend = backend; self.slice_size = slice_size; + self.reported_window_size = max_window_size; match self.active_backend { MatcherBackend::Simple => { let vec_pool = &mut self.vec_pool; @@ -167,11 +236,79 @@ impl Matcher for MatchGeneratorDriver { } } - fn window_size(&self) -> u64 { + fn prime_with_dictionary(&mut self, dict_content: &[u8], offset_hist: [u32; 3]) { + match self.active_backend { + MatcherBackend::Simple => self.match_generator.offset_hist = offset_hist, + MatcherBackend::Dfast => self.dfast_matcher_mut().offset_hist = offset_hist, + } + + if dict_content.is_empty() { + return; + } + + // Dictionary bytes should stay addressable until produced frame output + // itself exceeds the live window size. 
+ let retained_dict_budget = dict_content.len(); match self.active_backend { - MatcherBackend::Simple => self.match_generator.max_window_size as u64, - MatcherBackend::Dfast => self.dfast_matcher().max_window_size as u64, + MatcherBackend::Simple => { + self.match_generator.max_window_size = self + .match_generator + .max_window_size + .saturating_add(retained_dict_budget); + } + MatcherBackend::Dfast => { + let matcher = self.dfast_matcher_mut(); + matcher.max_window_size = + matcher.max_window_size.saturating_add(retained_dict_budget); + } + } + + let mut start = 0usize; + let mut committed_dict_budget = 0usize; + let min_primed_tail = match self.active_backend { + MatcherBackend::Simple => MIN_MATCH_LEN, + MatcherBackend::Dfast => 4, + }; + while start < dict_content.len() { + let end = (start + self.slice_size).min(dict_content.len()); + if end - start < min_primed_tail { + break; + } + let mut space = self.get_next_space(); + space.clear(); + space.extend_from_slice(&dict_content[start..end]); + self.commit_space(space); + self.skip_matching(); + committed_dict_budget += end - start; + start = end; } + + let uncommitted_tail_budget = retained_dict_budget.saturating_sub(committed_dict_budget); + if uncommitted_tail_budget > 0 { + match self.active_backend { + MatcherBackend::Simple => { + self.match_generator.max_window_size = self + .match_generator + .max_window_size + .saturating_sub(uncommitted_tail_budget); + } + MatcherBackend::Dfast => { + let matcher = self.dfast_matcher_mut(); + matcher.max_window_size = matcher + .max_window_size + .saturating_sub(uncommitted_tail_budget); + } + } + } + if committed_dict_budget > 0 { + self.dictionary_retained_budget = self + .dictionary_retained_budget + .saturating_add(committed_dict_budget); + } + } + + fn window_size(&self) -> u64 { + self.reported_window_size as u64 } fn get_next_space(&mut self) -> Vec { @@ -193,6 +330,7 @@ impl Matcher for MatchGeneratorDriver { match self.active_backend { 
MatcherBackend::Simple => { let vec_pool = &mut self.vec_pool; + let mut evicted_bytes = 0usize; let suffixes = self .suffix_pool .pop() @@ -200,22 +338,29 @@ impl Matcher for MatchGeneratorDriver { let suffix_pool = &mut self.suffix_pool; self.match_generator .add_data(space, suffixes, |mut data, mut suffixes| { + evicted_bytes += data.len(); data.resize(data.capacity(), 0); vec_pool.push(data); suffixes.slots.clear(); suffixes.slots.resize(suffixes.slots.capacity(), None); suffix_pool.push(suffixes); }); + self.retire_dictionary_budget(evicted_bytes); + self.trim_after_budget_retire(); } MatcherBackend::Dfast => { let vec_pool = &mut self.vec_pool; + let mut evicted_bytes = 0usize; self.dfast_match_generator .as_mut() .expect("dfast backend must be initialized by reset() before use") .add_data(space, |mut data| { + evicted_bytes += data.len(); data.resize(data.capacity(), 0); vec_pool.push(data); }); + self.retire_dictionary_budget(evicted_bytes); + self.trim_after_budget_retire(); } } } @@ -276,6 +421,11 @@ impl SuffixStore { #[inline(always)] fn key(&self, suffix: &[u8]) -> usize { + // Capacity=1 yields len_log=0; shifting by 64 would panic. 
+ if self.len_log == 0 { + return 0; + } + let s0 = suffix[0] as u64; let s1 = suffix[1] as u64; let s2 = suffix[2] as u64; @@ -767,11 +917,10 @@ impl DfastMatchGenerator { fn add_data(&mut self, data: Vec<u8>, mut reuse_space: impl FnMut(Vec<u8>)) { assert!(data.len() <= self.max_window_size); while self.window_size + data.len() > self.max_window_size { - let mut removed = self.window.pop_front().unwrap(); + let removed = self.window.pop_front().unwrap(); self.window_size -= removed.len(); self.history_start += removed.len(); self.history_abs_start += removed.len(); - removed.resize(removed.capacity(), 0); reuse_space(removed); } self.compact_history(); @@ -780,6 +929,16 @@ impl DfastMatchGenerator { self.window.push_back(data); } + fn trim_to_window(&mut self, mut reuse_space: impl FnMut(Vec<u8>)) { + while self.window_size > self.max_window_size { + let removed = self.window.pop_front().unwrap(); + self.window_size -= removed.len(); + self.history_start += removed.len(); + self.history_abs_start += removed.len(); + reuse_space(removed); + } + } + fn skip_matching(&mut self) { self.ensure_hash_tables(); let current_len = self.window.back().unwrap().len(); @@ -1292,6 +1451,249 @@ fn driver_switches_backends_and_initializes_dfast_via_reset() { assert_eq!(driver.window_size(), 64); } +#[test] +fn prime_with_dictionary_preserves_history_for_first_full_block() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + driver.prime_with_dictionary(b"abcdefgh", [1, 4, 8]); + + let mut space = driver.get_next_space(); + space.clear(); + space.extend_from_slice(b"abcdefgh"); + driver.commit_space(space); + + let mut saw_match = false; + driver.start_matching(|seq| { + if let Sequence::Triple { + literals, + offset, + match_len, + } = seq + && literals.is_empty() + && offset == 8 + && match_len >= MIN_MATCH_LEN + { + saw_match = true; + } + }); + + assert!( + saw_match, + "first full block should still match dictionary-primed history" + ); +}
+ +#[test] +fn prime_with_large_dictionary_preserves_early_history_until_first_block() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + driver.prime_with_dictionary(b"abcdefghABCDEFGHijklmnop", [1, 4, 8]); + + let mut space = driver.get_next_space(); + space.clear(); + space.extend_from_slice(b"abcdefgh"); + driver.commit_space(space); + + let mut saw_match = false; + driver.start_matching(|seq| { + if let Sequence::Triple { + literals, + offset, + match_len, + } = seq + && literals.is_empty() + && offset == 24 + && match_len >= MIN_MATCH_LEN + { + saw_match = true; + } + }); + + assert!( + saw_match, + "dictionary bytes should remain addressable until frame output exceeds the live window" + ); +} + +#[test] +fn prime_with_dictionary_applies_offset_history_even_when_content_is_empty() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + driver.prime_with_dictionary(&[], [11, 7, 3]); + + assert_eq!(driver.match_generator.offset_hist, [11, 7, 3]); +} + +#[test] +fn dfast_prime_with_dictionary_preserves_history_for_first_full_block() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Default); + + driver.prime_with_dictionary(b"abcdefgh", [1, 4, 8]); + + let mut space = driver.get_next_space(); + space.clear(); + space.extend_from_slice(b"abcdefgh"); + driver.commit_space(space); + + let mut saw_match = false; + driver.start_matching(|seq| { + if let Sequence::Triple { + literals, + offset, + match_len, + } = seq + && literals.is_empty() + && offset == 8 + && match_len >= DFAST_MIN_MATCH_LEN + { + saw_match = true; + } + }); + + assert!( + saw_match, + "dfast backend should match dictionary-primed history in first full block" + ); +} + +#[test] +fn prime_with_dictionary_does_not_inflate_reported_window_size() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + let before = 
driver.window_size(); + driver.prime_with_dictionary(b"abcdefghABCDEFGHijklmnop", [1, 4, 8]); + let after = driver.window_size(); + + assert_eq!( + after, before, + "dictionary retention budget must not change reported frame window size" + ); +} + +#[test] +fn prime_with_dictionary_does_not_reuse_tiny_suffix_store() { + let mut driver = MatchGeneratorDriver::new(8, 2); + driver.reset(CompressionLevel::Fastest); + + // This dictionary leaves a 1-byte tail chunk (capacity=1 suffix table), + // which should never be committed to the matcher window. + driver.prime_with_dictionary(b"abcdefghi", [1, 4, 8]); + + assert!( + driver + .match_generator + .window + .iter() + .all(|entry| entry.data.len() >= MIN_MATCH_LEN), + "dictionary priming must not commit tails shorter than MIN_MATCH_LEN" + ); +} + +#[test] +fn prime_with_dictionary_counts_only_committed_tail_budget() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + let before = driver.match_generator.max_window_size; + // One full slice plus a 1-byte tail that cannot be committed. + driver.prime_with_dictionary(b"abcdefghi", [1, 4, 8]); + + assert_eq!( + driver.match_generator.max_window_size, + before + 8, + "retention budget must account only for dictionary bytes actually committed to history" + ); +} + +#[test] +fn dfast_prime_with_dictionary_counts_four_byte_tail_budget() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Default); + + let before = driver.dfast_matcher().max_window_size; + // One full slice plus a 4-byte tail. Dfast can still use this tail through + // short-hash overlap into the next block, so it should stay retained. 
+ driver.prime_with_dictionary(b"abcdefghijkl", [1, 4, 8]); + + assert_eq!( + driver.dfast_matcher().max_window_size, + before + 12, + "dfast retention budget should include 4-byte dictionary tails" + ); +} + +#[test] +fn prime_with_dictionary_budget_shrinks_after_simple_eviction() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Fastest); + + let base_window = driver.match_generator.max_window_size; + driver.prime_with_dictionary(b"abcdefghABCDEFGHijklmnop", [1, 4, 8]); + assert_eq!(driver.match_generator.max_window_size, base_window + 24); + + for block in [b"AAAAAAAA", b"BBBBBBBB"] { + let mut space = driver.get_next_space(); + space.clear(); + space.extend_from_slice(block); + driver.commit_space(space); + driver.skip_matching(); + } + + assert_eq!( + driver.dictionary_retained_budget, 0, + "dictionary budget should be fully retired once primed dict slices are evicted" + ); + assert_eq!( + driver.match_generator.max_window_size, base_window, + "retired dictionary budget must not remain reusable for live history" + ); +} + +#[test] +fn prime_with_dictionary_budget_shrinks_after_dfast_eviction() { + let mut driver = MatchGeneratorDriver::new(8, 1); + driver.reset(CompressionLevel::Default); + // Use a small live window in this regression so dictionary-primed slices are + // evicted quickly and budget retirement can be asserted deterministically. 
+ driver.dfast_matcher_mut().max_window_size = 8; + driver.reported_window_size = 8; + + let base_window = driver.dfast_matcher().max_window_size; + driver.prime_with_dictionary(b"abcdefghABCDEFGHijklmnop", [1, 4, 8]); + assert_eq!(driver.dfast_matcher().max_window_size, base_window + 24); + + for block in [b"AAAAAAAA", b"BBBBBBBB"] { + let mut space = driver.get_next_space(); + space.clear(); + space.extend_from_slice(block); + driver.commit_space(space); + driver.skip_matching(); + } + + assert_eq!( + driver.dictionary_retained_budget, 0, + "dictionary budget should be fully retired once primed dict slices are evicted" + ); + assert_eq!( + driver.dfast_matcher().max_window_size, + base_window, + "retired dictionary budget must not remain reusable for live history" + ); +} + +#[test] +fn suffix_store_with_single_slot_does_not_panic_on_keying() { + let mut suffixes = SuffixStore::with_capacity(1); + suffixes.insert(b"abcde", 0); + assert!(suffixes.contains_key(b"abcde")); + assert_eq!(suffixes.get(b"abcde"), Some(0)); +} + #[test] fn fastest_reset_uses_interleaved_hash_fill_step() { let mut driver = MatchGeneratorDriver::new(32, 2); @@ -1531,6 +1933,55 @@ fn dfast_skip_matching_handles_window_eviction() { assert_eq!(reconstructed, [7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12]); } +#[test] +fn dfast_add_data_callback_reports_evicted_len_not_capacity() { + let mut matcher = DfastMatchGenerator::new(8); + + let mut first = Vec::with_capacity(64); + first.extend_from_slice(b"abcdefgh"); + matcher.add_data(first, |_| {}); + + let mut second = Vec::with_capacity(64); + second.extend_from_slice(b"ijklmnop"); + + let mut observed_evicted_len = None; + matcher.add_data(second, |data| { + observed_evicted_len = Some(data.len()); + }); + + assert_eq!( + observed_evicted_len, + Some(8), + "eviction callback must report evicted byte length, not backing capacity" + ); +} + +#[test] +fn dfast_trim_to_window_callback_reports_evicted_len_not_capacity() { + let mut matcher = 
DfastMatchGenerator::new(16); + + let mut first = Vec::with_capacity(64); + first.extend_from_slice(b"abcdefgh"); + matcher.add_data(first, |_| {}); + + let mut second = Vec::with_capacity(64); + second.extend_from_slice(b"ijklmnop"); + matcher.add_data(second, |_| {}); + + matcher.max_window_size = 8; + + let mut observed_evicted_len = None; + matcher.trim_to_window(|data| { + observed_evicted_len = Some(data.len()); + }); + + assert_eq!( + observed_evicted_len, + Some(8), + "trim callback must report evicted byte length, not backing capacity" + ); +} + #[test] fn dfast_inserts_tail_positions_for_next_block_matching() { let mut matcher = DfastMatchGenerator::new(DFAST_DEFAULT_WINDOW_SIZE); diff --git a/zstd/src/encoding/mod.rs b/zstd/src/encoding/mod.rs index aa640f66..49c6e36a 100644 --- a/zstd/src/encoding/mod.rs +++ b/zstd/src/encoding/mod.rs @@ -95,6 +95,14 @@ pub trait Matcher { fn start_matching(&mut self, handle_sequence: impl for<'a> FnMut(Sequence<'a>)); /// Reset this matcher so it can be used for the next new frame fn reset(&mut self, level: CompressionLevel); + /// Prime matcher state with dictionary history before compressing the next frame. + /// Default implementation is a no-op for custom matchers that do not support this. + fn prime_with_dictionary(&mut self, _dict_content: &[u8], _offset_hist: [u32; 3]) {} + /// Returns whether this matcher can consume dictionary priming state and produce + /// dictionary-dependent sequences. Defaults to `false` for custom matchers. 
+ fn supports_dictionary_priming(&self) -> bool { + false + } /// The size of the window the decoder will need to execute all sequences produced by this matcher /// /// May change after a call to reset with a different compression level diff --git a/zstd/src/fse/fse_decoder.rs b/zstd/src/fse/fse_decoder.rs index 7cd59dc6..8d05e142 100644 --- a/zstd/src/fse/fse_decoder.rs +++ b/zstd/src/fse/fse_decoder.rs @@ -112,6 +112,18 @@ impl FSETable { self.accuracy_log = 0; } + /// Build the equivalent encoder-side table from a parsed decoder table. + pub(crate) fn to_encoder_table(&self) -> Option<crate::fse::fse_encoder::FSETable> { + if self.accuracy_log == 0 || self.symbol_probabilities.is_empty() { + return None; + } + + Some(crate::fse::fse_encoder::build_table_from_probabilities( + &self.symbol_probabilities, + self.accuracy_log, + )) + } + /// returns how many BYTEs (not bits) were read while building the decoder pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> { self.accuracy_log = 0; diff --git a/zstd/src/huff0/huff0_decoder.rs b/zstd/src/huff0/huff0_decoder.rs index 1952aea3..b220cdc5 100644 --- a/zstd/src/huff0/huff0_decoder.rs +++ b/zstd/src/huff0/huff0_decoder.rs @@ -111,6 +111,28 @@ impl HuffmanTable { self.fse_table.reset(); } + /// Build the equivalent encoder-side Huffman table from parsed weights. + pub(crate) fn to_encoder_table(&self) -> Option<crate::huff0::huff0_encoder::HuffmanTable> { + if self.bits.is_empty() || self.max_num_bits == 0 { + return None; + } + + let max_bits = usize::from(self.max_num_bits); + let weights = self + .bits + .iter() + .copied() + .map(|num_bits| { + if num_bits == 0 { + 0 + } else { + max_bits - usize::from(num_bits) + 1 + } + }) + .collect::<Vec<usize>>(); + Some(crate::huff0::huff0_encoder::HuffmanTable::build_from_weights(&weights)) + } + /// Read from `source` and decode the input, populating the huffman decoding table. /// /// Returns the number of bytes read.
diff --git a/zstd/src/huff0/huff0_encoder.rs b/zstd/src/huff0/huff0_encoder.rs index 7d35fc32..828c056f 100644 --- a/zstd/src/huff0/huff0_encoder.rs +++ b/zstd/src/huff0/huff0_encoder.rs @@ -150,6 +150,7 @@ impl<V: AsMut<Vec<u8>>> HuffmanEncoder<'_, '_, V> { } } +#[derive(Clone)] pub struct HuffmanTable { /// Index is the symbol, values are the bitstring in the lower bits of the u32 and the amount of bits in the u8 codes: Vec<(u32, u8)>,