diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 50bfe13e..8c25fb26 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -10,6 +10,6 @@ Pure Rust zstd implementation — managed fork of [ruzstd (KillingSpark/zstd-rs) ## Rust Code Standards -- **Clippy:** Must pass `cargo clippy --all-features -- -D warnings` +- **Clippy:** Must pass `cargo clippy -p structured-zstd --features hash,std,dict_builder -- -D warnings` (`rustc-dep-of-std` is excluded — it's an internal feature for Rust stdlib builds only; `fuzz_exports` is excluded — fuzzing-specific entry points are validated separately from the regular lint gate) - This is a fork — avoid suggesting architectural changes that diverge too far from upstream - Performance-critical code: benchmark before/after any changes diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a54082ab..cb77893b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,9 +59,9 @@ jobs: steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - with: - targets: i686-unknown-linux-gnu - uses: taiki-e/install-action@nextest + - name: Install i686 target + run: rustup target add i686-unknown-linux-gnu - name: Install 32-bit libs run: sudo apt-get update && sudo apt-get install -y gcc-multilib - uses: Swatinem/rust-cache@v2 @@ -79,7 +79,7 @@ jobs: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: - toolchain: "1.92" + toolchain: "1.92.0" - uses: taiki-e/install-action@nextest - uses: Swatinem/rust-cache@v2 with: diff --git a/cli/Cargo.toml b/cli/Cargo.toml index ecfbad3e..b269e83a 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -4,7 +4,7 @@ name = "structured-zstd-cli" version = "0.8.2" rust-version = "1.92" authors = ["Moritz Borcherding "] -edition = "2018" +edition = "2024" license = "Apache-2.0" homepage = "https://github.com/structured-world/structured-zstd" repository = "https://github.com/structured-world/structured-zstd" diff --git a/cli/src/progress.rs b/cli/src/progress.rs index c0971f76..7830b27d 100644 --- a/cli/src/progress.rs +++ b/cli/src/progress.rs @@ -144,7 +144,7 @@ mod tests { assert_eq!(&fmt_duration(Duration::from_secs(5 * 60)), "5m"); assert_eq!(&fmt_duration(Duration::from_secs(3 * 60 * 60)), "3h"); assert_eq!( - &fmt_duration(Duration::from_secs(1 * 60 * 60 + 20 * 60 + 30)), + &fmt_duration(Duration::from_secs(60 * 60 + 20 * 60 + 30)), "1h 20m 30s" ); } diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..1b6261ee --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,5 @@ +[toolchain] +# Follow the latest stable toolchain by default. +# MSRV remains 1.92.0 and is verified separately via `rust-version` plus the CI msrv job. +channel = "stable" +components = ["clippy", "rustfmt"] diff --git a/zstd/Cargo.toml b/zstd/Cargo.toml index 538eae27..30779b0f 100644 --- a/zstd/Cargo.toml +++ b/zstd/Cargo.toml @@ -6,7 +6,7 @@ authors = [ "Moritz Borcherding ", "Structured World Foundation ", ] -edition = "2018" +edition = "2024" license = "Apache-2.0" homepage = "https://github.com/structured-world/structured-zstd" repository = "https://github.com/structured-world/structured-zstd" diff --git a/zstd/benches/compare_ffi.rs b/zstd/benches/compare_ffi.rs index 5f5bef7c..2660702f 100644 --- a/zstd/benches/compare_ffi.rs +++ b/zstd/benches/compare_ffi.rs @@ -3,7 +3,7 @@ //! Five variations: decompress (pure Rust/C FFI), compress (pure Rust/C FFI L1/L3). //! Both decompress benchmarks allocate per-iteration for symmetric comparison. -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; /// Compressed corpus for decompression benchmarks. const COMPRESSED_CORPUS: &[u8] = include_bytes!("../decodecorpus_files/z000033.zst"); diff --git a/zstd/benches/decode_all.rs b/zstd/benches/decode_all.rs index a17d2351..63a17485 100644 --- a/zstd/benches/decode_all.rs +++ b/zstd/benches/decode_all.rs @@ -1,4 +1,4 @@ -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use structured_zstd::decoding::FrameDecoder; fn criterion_benchmark(c: &mut Criterion) { diff --git a/zstd/src/bit_io/bit_writer.rs b/zstd/src/bit_io/bit_writer.rs index 7ce228a5..612eafbc 100644 --- a/zstd/src/bit_io/bit_writer.rs +++ b/zstd/src/bit_io/bit_writer.rs @@ -194,7 +194,10 @@ impl>> BitWriter { /// dumping pub fn dump(mut self) -> V { if self.misaligned() != 0 { - panic!("`dump` was called on a bit writer but an even number of bytes weren't written into the buffer. Was: {}", self.index()) + panic!( + "`dump` was called on a bit writer but an even number of bytes weren't written into the buffer. Was: {}", + self.index() + ) } self.flush(); debug_assert_eq!(self.partial, 0); @@ -248,7 +251,11 @@ mod tests { bw.write_bits(0b1111u8, 4); bw.write_bits(0b0000u8, 4); let output = bw.dump(); - assert!(output.len() == 1, "Single byte written into writer returned a vec that wasn't one byte, vec was {} elements long", output.len()); + assert!( + output.len() == 1, + "Single byte written into writer returned a vec that wasn't one byte, vec was {} elements long", + output.len() + ); assert_eq!( 0b0000_1111, output[0], "4 bits and 4 bits written into buffer" @@ -262,7 +269,11 @@ mod tests { bw.write_bits(0b111u8, 3); bw.write_bits(0b0_0000u8, 5); let output = bw.dump(); - assert!(output.len() == 1, "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", output.len()); + assert!( + output.len() == 1, + "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", + output.len() + ); assert_eq!(0b0000_0111, output[0], "3 and 5 bits written into buffer"); } @@ -273,7 +284,11 @@ mod tests { bw.write_bits(0b1u8, 1); bw.write_bits(0u8, 7); let output = bw.dump(); - assert!(output.len() == 1, "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", output.len()); + assert!( + output.len() == 1, + "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", + output.len() + ); assert_eq!(0b0000_0001, output[0], "1 and 7 bits written into buffer"); } @@ -283,7 +298,11 @@ mod tests { let mut bw = BitWriter::new(); bw.write_bits(1u8, 8); let output = bw.dump(); - assert!(output.len() == 1, "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", output.len()); + assert!( + output.len() == 1, + "Single byte written into writer return a vec that wasn't one byte, vec was {} elements long", + output.len() + ); assert_eq!(1, output[0], "1 and 7 bits written into buffer"); } diff --git a/zstd/src/decoding/block_decoder.rs b/zstd/src/decoding/block_decoder.rs index 08345f14..023ad964 100644 --- a/zstd/src/decoding/block_decoder.rs +++ b/zstd/src/decoding/block_decoder.rs @@ -46,7 +46,7 @@ impl BlockDecoder { DecoderState::ReadyToDecodeNextBody => { /* Happy :) */ } DecoderState::Failed => return Err(DecodeBlockContentError::DecoderStateIsFailed), DecoderState::ReadyToDecodeNextHeader => { - return Err(DecodeBlockContentError::ExpectedHeaderOfPreviousBlock) + return Err(DecodeBlockContentError::ExpectedHeaderOfPreviousBlock); } } @@ -108,7 +108,9 @@ impl BlockDecoder { } BlockType::Reserved => { - panic!("How did you even get this. The decoder should error out if it detects a reserved-type block"); + panic!( + "How did you even get this. The decoder should error out if it detects a reserved-type block" + ); } BlockType::Compressed => { diff --git a/zstd/src/decoding/errors.rs b/zstd/src/decoding/errors.rs index 59dc2026..9b1c6bb0 100644 --- a/zstd/src/decoding/errors.rs +++ b/zstd/src/decoding/errors.rs @@ -219,8 +219,9 @@ impl core::fmt::Display for BlockTypeError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { BlockTypeError::InvalidBlocktypeNumber { num } => { - write!(f, - "Invalid Blocktype number. Is: {num} Should be one of: 0, 1, 2, 3 (3 is reserved though", + write!( + f, + "Invalid Blocktype number. Is: {num}. Should be one of: 0, 1, 2, 3 (3 is reserved).", ) } } @@ -291,7 +292,8 @@ impl core::fmt::Display for DecompressBlockError { expected_len, remaining_bytes, } => { - write!(f, + write!( + f, "Malformed section header. Says literals would be this long: {expected_len} but there are only {remaining_bytes} bytes left", ) } @@ -370,9 +372,10 @@ impl core::fmt::Display for DecodeBlockContentError { ) } DecodeBlockContentError::ExpectedHeaderOfPreviousBlock => { - write!(f, - "Can't decode next block body, while expecting to decode the header of the previous block. Results will be nonsense", - ) + write!( + f, + "Can't decode next block body, while expecting to decode the header of the previous block. Results will be nonsense", + ) } DecodeBlockContentError::ReadError { step, source } => { write!(f, "Error while reading bytes for {step}: {source}",) @@ -545,10 +548,16 @@ impl core::fmt::Display for FrameDecoderError { ) } FrameDecoderError::TargetTooSmall => { - write!(f, "Target must have at least as many bytes as the contentsize of the frame reports") + write!( + f, + "Target must have at least as many bytes as the content size reported by the frame" + ) } FrameDecoderError::DictNotProvided { dict_id } => { - write!(f, "Frame header specified dictionary id 0x{dict_id:X} that wasnt provided by add_dict() or reset_with_dict()") + write!( + f, + "Frame header specified dictionary id 0x{dict_id:X} that wasn't provided via add_dict() or reset_with_dict()" + ) } } } @@ -609,12 +618,14 @@ impl core::fmt::Display for DecompressLiteralsError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { DecompressLiteralsError::MissingCompressedSize => { - write!(f, + write!( + f, "compressed size was none even though it must be set to something for compressed literals", ) } DecompressLiteralsError::MissingNumStreams => { - write!(f, + write!( + f, "num_streams was none even though it must be set to something (1 or 4) for compressed literals", ) } @@ -637,7 +648,8 @@ impl core::fmt::Display for DecompressLiteralsError { ) } DecompressLiteralsError::ExtraPadding { skipped_bits } => { - write!(f, + write!( + f, "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption", ) } @@ -754,7 +766,8 @@ impl core::fmt::Display for DecodeSequenceError { DecodeSequenceError::FSEDecoderError(e) => write!(f, "{e:?}"), DecodeSequenceError::FSETableError(e) => write!(f, "{e:?}"), DecodeSequenceError::ExtraPadding { skipped_bits } => { - write!(f, + write!( + f, "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption", ) } @@ -929,7 +942,8 @@ impl core::fmt::Display for FSETableError { expected_sum, symbol_probabilities, } => { - write!(f, + write!( + f, "The counter ({got}) exceeded the expected sum: {expected_sum}. This means an error or corrupted data \n {symbol_probabilities:?}", ) } @@ -1047,10 +1061,14 @@ impl core::fmt::Display for HuffmanTableError { got_bytes, expected_bytes, } => { - write!(f, "Header says there should be {expected_bytes} bytes for the weights but there are only {got_bytes} bytes in the stream") + write!( + f, + "Header says there should be {expected_bytes} bytes for the weights but there are only {got_bytes} bytes in the stream" + ) } HuffmanTableError::ExtraPadding { skipped_bits } => { - write!(f, + write!( + f, "Padding at the end of the sequence_section was more than a byte long: {skipped_bits} bits. Probably caused by data corruption", ) } @@ -1076,7 +1094,8 @@ impl core::fmt::Display for HuffmanTableError { used, available_bytes, } => { - write!(f, + write!( + f, "FSE table used more bytes: {used} than were meant to be used for the whole stream of huffman weights ({available_bytes})", ) } @@ -1149,3 +1168,98 @@ impl From for HuffmanDecoderError { Self::GetBitsError(val) } } + +#[cfg(test)] +mod tests { + use alloc::{string::ToString, vec}; + + use super::{ + BlockTypeError, DecodeBlockContentError, DecodeSequenceError, DecompressBlockError, + DecompressLiteralsError, FSETableError, FrameDecoderError, HuffmanTableError, + }; + + #[test] + fn block_and_sequence_display_messages_are_specific() { + assert_eq!( + BlockTypeError::InvalidBlocktypeNumber { num: 7 }.to_string(), + "Invalid Blocktype number. Is: 7. Should be one of: 0, 1, 2, 3 (3 is reserved)." + ); + assert_eq!( + DecompressBlockError::MalformedSectionHeader { + expected_len: 12, + remaining_bytes: 3, + } + .to_string(), + "Malformed section header. Says literals would be this long: 12 but there are only 3 bytes left" + ); + assert_eq!( + DecodeBlockContentError::ExpectedHeaderOfPreviousBlock.to_string(), + "Can't decode next block body, while expecting to decode the header of the previous block. Results will be nonsense" + ); + assert_eq!( + DecodeSequenceError::ExtraPadding { skipped_bits: 11 }.to_string(), + "Padding at the end of the sequence_section was more than a byte long: 11 bits. Probably caused by data corruption" + ); + } + + #[test] + fn frame_decoder_display_messages_are_specific() { + assert_eq!( + FrameDecoderError::TargetTooSmall.to_string(), + "Target must have at least as many bytes as the content size reported by the frame" + ); + assert_eq!( + FrameDecoderError::DictNotProvided { dict_id: 0xABCD }.to_string(), + "Frame header specified dictionary id 0xABCD that wasn't provided via add_dict() or reset_with_dict()" + ); + } + + #[test] + fn literal_display_messages_are_specific() { + assert_eq!( + DecompressLiteralsError::MissingCompressedSize.to_string(), + "compressed size was none even though it must be set to something for compressed literals" + ); + assert_eq!( + DecompressLiteralsError::MissingNumStreams.to_string(), + "num_streams was none even though it must be set to something (1 or 4) for compressed literals" + ); + assert_eq!( + DecompressLiteralsError::ExtraPadding { skipped_bits: 9 }.to_string(), + "Padding at the end of the sequence_section was more than a byte long: 9 bits. Probably caused by data corruption" + ); + } + + #[test] + fn fse_and_huffman_display_messages_are_specific() { + assert_eq!( + FSETableError::ProbabilityCounterMismatch { + got: 4, + expected_sum: 3, + symbol_probabilities: vec![1, -1], + } + .to_string(), + "The counter (4) exceeded the expected sum: 3. This means an error or corrupted data \n [1, -1]" + ); + assert_eq!( + HuffmanTableError::NotEnoughBytesForWeights { + got_bytes: 2, + expected_bytes: 5, + } + .to_string(), + "Header says there should be 5 bytes for the weights but there are only 2 bytes in the stream" + ); + assert_eq!( + HuffmanTableError::ExtraPadding { skipped_bits: 13 }.to_string(), + "Padding at the end of the sequence_section was more than a byte long: 13 bits. Probably caused by data corruption" + ); + assert_eq!( + HuffmanTableError::FSETableUsedTooManyBytes { + used: 7, + available_bytes: 6, + } + .to_string(), + "FSE table used more bytes: 7 than were meant to be used for the whole stream of huffman weights (6)" + ); + } +} diff --git a/zstd/src/decoding/ringbuffer.rs b/zstd/src/decoding/ringbuffer.rs index cf25bc6f..408a1088 100644 --- a/zstd/src/decoding/ringbuffer.rs +++ b/zstd/src/decoding/ringbuffer.rs @@ -470,56 +470,60 @@ impl RingBuffer { /// Needs start + len <= self.len() /// And more then len reserved space pub unsafe fn extend_from_within_unchecked_branchless(&mut self, start: usize, len: usize) { - // data slices in raw parts - let ((s1_ptr, s1_len), (s2_ptr, s2_len)) = self.data_slice_parts(); + // SAFETY: caller guarantees the source range is valid and enough free + // space exists; the raw-pointer arithmetic and copy stay within those bounds. + unsafe { + // data slices in raw parts + let ((s1_ptr, s1_len), (s2_ptr, s2_len)) = self.data_slice_parts(); - debug_assert!(len <= s1_len + s2_len, "{} > {} + {}", len, s1_len, s2_len); + debug_assert!(len <= s1_len + s2_len, "{} > {} + {}", len, s1_len, s2_len); - // calc the actually wanted slices in raw parts - let start_in_s1 = usize::min(s1_len, start); - let end_in_s1 = usize::min(s1_len, start + len); - let m1_ptr = s1_ptr.add(start_in_s1); - let m1_len = end_in_s1 - start_in_s1; + // calc the actually wanted slices in raw parts + let start_in_s1 = usize::min(s1_len, start); + let end_in_s1 = usize::min(s1_len, start + len); + let m1_ptr = s1_ptr.add(start_in_s1); + let m1_len = end_in_s1 - start_in_s1; - debug_assert!(end_in_s1 <= s1_len); - debug_assert!(start_in_s1 <= s1_len); + debug_assert!(end_in_s1 <= s1_len); + debug_assert!(start_in_s1 <= s1_len); - let start_in_s2 = start.saturating_sub(s1_len); - let end_in_s2 = start_in_s2 + (len - m1_len); - let m2_ptr = s2_ptr.add(start_in_s2); - let m2_len = end_in_s2 - start_in_s2; + let start_in_s2 = start.saturating_sub(s1_len); + let end_in_s2 = start_in_s2 + (len - m1_len); + let m2_ptr = s2_ptr.add(start_in_s2); + let m2_len = end_in_s2 - start_in_s2; - debug_assert!(start_in_s2 <= s2_len); - debug_assert!(end_in_s2 <= s2_len); + debug_assert!(start_in_s2 <= s2_len); + debug_assert!(end_in_s2 <= s2_len); - debug_assert_eq!(len, m1_len + m2_len); + debug_assert_eq!(len, m1_len + m2_len); - // the free slices, must hold: f1_len + f2_len >= m1_len + m2_len - let ((f1_ptr, f1_len), (f2_ptr, f2_len)) = self.free_slice_parts(); + // the free slices, must hold: f1_len + f2_len >= m1_len + m2_len + let ((f1_ptr, f1_len), (f2_ptr, f2_len)) = self.free_slice_parts(); - debug_assert!(f1_len + f2_len >= m1_len + m2_len); + debug_assert!(f1_len + f2_len >= m1_len + m2_len); - // calc how many from where bytes go where - let m1_in_f1 = usize::min(m1_len, f1_len); - let m1_in_f2 = m1_len - m1_in_f1; - let m2_in_f1 = usize::min(f1_len - m1_in_f1, m2_len); - let m2_in_f2 = m2_len - m2_in_f1; + // calc how many from where bytes go where + let m1_in_f1 = usize::min(m1_len, f1_len); + let m1_in_f2 = m1_len - m1_in_f1; + let m2_in_f1 = usize::min(f1_len - m1_in_f1, m2_len); + let m2_in_f2 = m2_len - m2_in_f1; - debug_assert_eq!(m1_len, m1_in_f1 + m1_in_f2); - debug_assert_eq!(m2_len, m2_in_f1 + m2_in_f2); - debug_assert!(f1_len >= m1_in_f1 + m2_in_f1); - debug_assert!(f2_len >= m1_in_f2 + m2_in_f2); - debug_assert_eq!(len, m1_in_f1 + m2_in_f1 + m1_in_f2 + m2_in_f2); + debug_assert_eq!(m1_len, m1_in_f1 + m1_in_f2); + debug_assert_eq!(m2_len, m2_in_f1 + m2_in_f2); + debug_assert!(f1_len >= m1_in_f1 + m2_in_f1); + debug_assert!(f2_len >= m1_in_f2 + m2_in_f2); + debug_assert_eq!(len, m1_in_f1 + m2_in_f1 + m1_in_f2 + m2_in_f2); - debug_assert!(self.buf.as_ptr().add(self.cap) > f1_ptr.add(m1_in_f1 + m2_in_f1)); - debug_assert!(self.buf.as_ptr().add(self.cap) > f2_ptr.add(m1_in_f2 + m2_in_f2)); + debug_assert!(self.buf.as_ptr().add(self.cap) >= f1_ptr.add(m1_in_f1 + m2_in_f1)); + debug_assert!(self.buf.as_ptr().add(self.cap) >= f2_ptr.add(m1_in_f2 + m2_in_f2)); - debug_assert!((m1_in_f2 > 0) ^ (m2_in_f1 > 0) || (m1_in_f2 == 0 && m2_in_f1 == 0)); + debug_assert!((m1_in_f2 > 0) ^ (m2_in_f1 > 0) || (m1_in_f2 == 0 && m2_in_f1 == 0)); - copy_with_checks( - m1_ptr, m2_ptr, f1_ptr, f2_ptr, m1_in_f1, m2_in_f1, m1_in_f2, m2_in_f2, - ); - self.tail = (self.tail + len) % self.cap; + copy_with_nobranch_check( + m1_ptr, m2_ptr, f1_ptr, f2_ptr, m1_in_f1, m2_in_f1, m1_in_f2, m2_in_f2, + ); + self.tail = (self.tail + len) % self.cap; + } } } @@ -572,33 +576,35 @@ unsafe fn copy_bytes_overshooting( let min_buffer_size = usize::min(src.1, dst.1); // Can copy in just one read+write, very common case - if min_buffer_size >= COPY_AT_ONCE_SIZE && copy_at_least <= COPY_AT_ONCE_SIZE { - dst.0 - .cast::() - .write_unaligned(src.0.cast::().read_unaligned()) - } else { - let copy_multiple = copy_at_least.next_multiple_of(COPY_AT_ONCE_SIZE); - // Can copy in multiple simple instructions - if min_buffer_size >= copy_multiple { - let mut src_ptr = src.0.cast::(); - let src_ptr_end = src.0.add(copy_multiple).cast::(); - let mut dst_ptr = dst.0.cast::(); - - while src_ptr < src_ptr_end { - dst_ptr.write_unaligned(src_ptr.read_unaligned()); - src_ptr = src_ptr.add(1); - dst_ptr = dst_ptr.add(1); - } + unsafe { + if min_buffer_size >= COPY_AT_ONCE_SIZE && copy_at_least <= COPY_AT_ONCE_SIZE { + dst.0 + .cast::() + .write_unaligned(src.0.cast::().read_unaligned()) } else { - // Fall back to standard memcopy - dst.0.copy_from_nonoverlapping(src.0, copy_at_least); + let copy_multiple = copy_at_least.next_multiple_of(COPY_AT_ONCE_SIZE); + // Can copy in multiple simple instructions + if min_buffer_size >= copy_multiple { + let mut src_ptr = src.0.cast::(); + let src_ptr_end = src.0.add(copy_multiple).cast::(); + let mut dst_ptr = dst.0.cast::(); + + while src_ptr < src_ptr_end { + dst_ptr.write_unaligned(src_ptr.read_unaligned()); + src_ptr = src_ptr.add(1); + dst_ptr = dst_ptr.add(1); + } + } else { + // Fall back to standard memcopy + dst.0.copy_from_nonoverlapping(src.0, copy_at_least); + } } - } - debug_assert_eq!( - slice::from_raw_parts(src.0, copy_at_least), - slice::from_raw_parts(dst.0, copy_at_least) - ); + debug_assert_eq!( + slice::from_raw_parts(src.0, copy_at_least), + slice::from_raw_parts(dst.0, copy_at_least) + ); + } } #[allow(dead_code)] @@ -614,43 +620,13 @@ unsafe fn copy_without_checks( m1_in_f2: usize, m2_in_f2: usize, ) { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); - f1_ptr - .add(m1_in_f1) - .copy_from_nonoverlapping(m2_ptr, m2_in_f1); - - f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); - f2_ptr - .add(m1_in_f2) - .copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); -} - -#[allow(dead_code)] -#[inline(always)] -#[allow(clippy::too_many_arguments)] -unsafe fn copy_with_checks( - m1_ptr: *const u8, - m2_ptr: *const u8, - f1_ptr: *mut u8, - f2_ptr: *mut u8, - m1_in_f1: usize, - m2_in_f1: usize, - m1_in_f2: usize, - m2_in_f2: usize, -) { - if m1_in_f1 != 0 { + unsafe { f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); - } - if m2_in_f1 != 0 { f1_ptr .add(m1_in_f1) .copy_from_nonoverlapping(m2_ptr, m2_in_f1); - } - if m1_in_f2 != 0 { f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); - } - if m2_in_f2 != 0 { f2_ptr .add(m1_in_f2) .copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); @@ -660,7 +636,7 @@ unsafe fn copy_with_checks( #[allow(dead_code)] #[inline(always)] #[allow(clippy::too_many_arguments)] -unsafe fn copy_with_nobranch_check( +unsafe fn copy_with_checks( m1_ptr: *const u8, m2_ptr: *const u8, f1_ptr: *mut u8, @@ -670,80 +646,147 @@ unsafe fn copy_with_nobranch_check( m1_in_f2: usize, m2_in_f2: usize, ) { - let case = (m1_in_f1 > 0) as usize - | (((m2_in_f1 > 0) as usize) << 1) - | (((m1_in_f2 > 0) as usize) << 2) - | (((m2_in_f2 > 0) as usize) << 3); - - match case { - 0 => {} - - // one bit set - 1 => { + unsafe { + if m1_in_f1 != 0 { f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); } - 2 => { - f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1); - } - 4 => { - f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2); - } - 8 => { - f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2); - } - - // two bit set - 3 => { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + if m2_in_f1 != 0 { f1_ptr .add(m1_in_f1) .copy_from_nonoverlapping(m2_ptr, m2_in_f1); } - 5 => { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + + if m1_in_f2 != 0 { f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); } - 6 => core::hint::unreachable_unchecked(), - 7 => core::hint::unreachable_unchecked(), - 9 => { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); - f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2); - } - 10 => { - f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1); - f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); - } - 12 => { - f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2); + if m2_in_f2 != 0 { f2_ptr .add(m1_in_f2) - .copy_from_nonoverlapping(m2_ptr, m2_in_f2); + .copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); } + } +} - // three bit set - 11 => { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); - f1_ptr - .add(m1_in_f1) - .copy_from_nonoverlapping(m2_ptr, m2_in_f1); - f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); - } - 13 => { - f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); - f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); - f2_ptr - .add(m1_in_f2) - .copy_from_nonoverlapping(m2_ptr, m2_in_f2); +#[allow(dead_code)] +#[inline(always)] +#[allow(clippy::too_many_arguments)] +unsafe fn copy_with_nobranch_check( + m1_ptr: *const u8, + m2_ptr: *const u8, + f1_ptr: *mut u8, + f2_ptr: *mut u8, + m1_in_f1: usize, + m2_in_f1: usize, + m1_in_f2: usize, + m2_in_f2: usize, +) { + unsafe { + let case = (m1_in_f1 > 0) as usize + | (((m2_in_f1 > 0) as usize) << 1) + | (((m1_in_f2 > 0) as usize) << 2) + | (((m2_in_f2 > 0) as usize) << 3); + + match case { + 0 => {} + + // one bit set + 1 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + } + 2 => { + f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1); + } + 4 => { + f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2); + } + 8 => { + f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2); + } + + // two bit set + 3 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + f1_ptr + .add(m1_in_f1) + .copy_from_nonoverlapping(m2_ptr, m2_in_f1); + } + 5 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); + } + 6 => core::hint::unreachable_unchecked(), + 7 => core::hint::unreachable_unchecked(), + 9 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2); + } + 10 => { + f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1); + f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); + } + 12 => { + f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2); + f2_ptr + .add(m1_in_f2) + .copy_from_nonoverlapping(m2_ptr, m2_in_f2); + } + + // three bit set + 11 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + f1_ptr + .add(m1_in_f1) + .copy_from_nonoverlapping(m2_ptr, m2_in_f1); + f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2); + } + 13 => { + f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); + f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); + f2_ptr + .add(m1_in_f2) + .copy_from_nonoverlapping(m2_ptr, m2_in_f2); + } + 14 => core::hint::unreachable_unchecked(), + 15 => core::hint::unreachable_unchecked(), + _ => core::hint::unreachable_unchecked(), } - 14 => core::hint::unreachable_unchecked(), - 15 => core::hint::unreachable_unchecked(), - _ => core::hint::unreachable_unchecked(), } } #[cfg(test)] mod tests { - use super::RingBuffer; + use super::{RingBuffer, copy_bytes_overshooting, copy_with_checks, copy_with_nobranch_check}; + use core::mem::size_of; + + fn assert_buffers_equal(expected: &RingBuffer, actual: &RingBuffer) { + assert_eq!(expected.len(), actual.len()); + assert_eq!(expected.as_slices(), actual.as_slices()); + assert_eq!(expected.head, actual.head); + assert_eq!(expected.tail, actual.tail); + assert_eq!(expected.cap, actual.cap); + } + + fn assert_branchless_matches_checked( + mut checked: RingBuffer, + mut branchless: RingBuffer, + start: usize, + len: usize, + ) { + assert!(checked.free() >= len); + assert!(branchless.free() >= len); + + unsafe { + checked.extend_from_within_unchecked(start, len); + branchless.extend_from_within_unchecked_branchless(start, len); + } + + assert_buffers_equal(&checked, &branchless); + } + + #[cfg(all(not(target_feature = "sse2"), not(target_feature = "neon")))] + const COPY_CHUNK_SIZE: usize = size_of::(); + #[cfg(any(target_feature = "sse2", target_feature = "neon"))] + const COPY_CHUNK_SIZE: usize = size_of::(); #[test] fn smoke() { @@ -884,4 +927,130 @@ mod tests { assert_eq!(b"11", rb.as_slices().0); assert_eq!(b"111111", rb.as_slices().1); } + + #[test] + fn extend_from_within_branchless_matches_checked_across_layouts() { + let contiguous = || { + let mut rb = RingBuffer::new(); + rb.reserve(16); + rb.extend(b"0123456789"); + rb + }; + assert_branchless_matches_checked(contiguous(), contiguous(), 2, 5); + + let wrapped_write = || { + let mut rb = RingBuffer::new(); + rb.reserve(16); + rb.extend(b"0123456789ABC"); + rb.drop_first_n(2); + rb + }; + assert_branchless_matches_checked(wrapped_write(), wrapped_write(), 1, 5); + + let wrapped_data = || { + let mut rb = RingBuffer::new(); + rb.reserve(32); + rb.extend(b"0123456789abcdefghijklmn"); + rb.drop_first_n(18); + rb.extend(b"wxyz012345"); + rb + }; + assert_branchless_matches_checked(wrapped_data(), wrapped_data(), 8, 2); + assert_branchless_matches_checked(wrapped_data(), wrapped_data(), 2, 8); + } + + #[test] + fn copy_with_nobranch_check_matches_checked_for_all_valid_case_masks() { + let cases = [ + (0, 0, 0, 0), + (1, 0, 0, 0), + (0, 1, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1), + (1, 1, 0, 0), + (1, 0, 1, 0), + (1, 0, 0, 1), + (0, 1, 0, 1), + (0, 0, 1, 1), + (1, 1, 0, 1), + (1, 0, 1, 1), + ]; + + for (m1_in_f1, m2_in_f1, m1_in_f2, m2_in_f2) in cases { + let m1 = [11_u8, 12, 13, 14]; + let m2 = [21_u8, 22, 23, 24]; + let mut expected = [0_u8; 8]; + let mut actual = [0_u8; 8]; + + unsafe { + copy_with_checks( + m1.as_ptr(), + m2.as_ptr(), + expected.as_mut_ptr(), + expected.as_mut_ptr().add(4), + m1_in_f1, + m2_in_f1, + m1_in_f2, + m2_in_f2, + ); + copy_with_nobranch_check( + m1.as_ptr(), + m2.as_ptr(), + actual.as_mut_ptr(), + actual.as_mut_ptr().add(4), + m1_in_f1, + m2_in_f1, + m1_in_f2, + m2_in_f2, + ); + } + + assert_eq!( + expected, actual, + "case=({}, {}, {}, {})", + m1_in_f1, m2_in_f1, m1_in_f2, m2_in_f2 + ); + } + } + + #[test] + fn copy_bytes_overshooting_covers_all_copy_strategies() { + let src_single = [1_u8; 64]; + let mut dst_single = [0_u8; 64]; + unsafe { + copy_bytes_overshooting( + (src_single.as_ptr(), COPY_CHUNK_SIZE), + (dst_single.as_mut_ptr(), COPY_CHUNK_SIZE), + COPY_CHUNK_SIZE, + ); + } + assert_eq!( + &dst_single[..COPY_CHUNK_SIZE], + &src_single[..COPY_CHUNK_SIZE] + ); + + let multi_len = COPY_CHUNK_SIZE * 2; + let src_multi = [2_u8; 64]; + let mut dst_multi = [0_u8; 64]; + unsafe { + copy_bytes_overshooting( + (src_multi.as_ptr(), multi_len), + (dst_multi.as_mut_ptr(), multi_len), + multi_len, + ); + } + assert_eq!(&dst_multi[..multi_len], &src_multi[..multi_len]); + + let fallback_len = COPY_CHUNK_SIZE + 1; + let src_fallback = [3_u8; 64]; + let mut dst_fallback = [0_u8; 64]; + unsafe { + copy_bytes_overshooting( + (src_fallback.as_ptr(), fallback_len), + (dst_fallback.as_mut_ptr(), fallback_len), + fallback_len, + ); + } + assert_eq!(&dst_fallback[..fallback_len], &src_fallback[..fallback_len]); + } } diff --git a/zstd/src/dictionary/mod.rs b/zstd/src/dictionary/mod.rs index 3dfd287e..117bfe7a 100644 --- a/zstd/src/dictionary/mod.rs +++ b/zstd/src/dictionary/mod.rs @@ -33,7 +33,6 @@ use cover::*; use std::{ boxed::Box, collections::{BinaryHeap, HashMap}, - dbg, fs::{self, File}, io::{self, BufReader, Read}, path::{Path, PathBuf}, @@ -160,9 +159,9 @@ pub fn create_raw_dict_from_source( }; // Score each segment in the epoch and select the highest scoring segment // for the pool - while dbg!(buffered_source + while buffered_source .read(&mut current_epoch) - .expect("can read input")) + .expect("can read input") != 0 { epoch_counter += 1; diff --git a/zstd/src/encoding/blocks/compressed.rs b/zstd/src/encoding/blocks/compressed.rs index 1d0fbac2..54fae524 100644 --- a/zstd/src/encoding/blocks/compressed.rs +++ b/zstd/src/encoding/blocks/compressed.rs @@ -1,10 +1,10 @@ -use alloc::vec::Vec; +use alloc::{boxed::Box, vec::Vec}; use crate::{ bit_io::BitWriter, - encoding::frame_compressor::CompressState, + encoding::frame_compressor::{CompressState, FseTables, PreviousFseTable}, encoding::{Matcher, Sequence}, - fse::fse_encoder::{build_table_from_data, FSETable, State}, + fse::fse_encoder::{FSETable, State, build_table_from_symbol_counts}, huff0::huff0_encoder, }; @@ -15,21 +15,23 @@ const _: () = assert!(crate::common::MAX_BLOCK_SIZE <= 262_143); pub fn compress_block(state: &mut CompressState, output: &mut Vec) { let mut literals_vec = Vec::new(); let mut sequences = Vec::new(); - state.matcher.start_matching(|seq| { - match seq { - Sequence::Literals { literals } => literals_vec.extend_from_slice(literals), - Sequence::Triple { - literals, - offset, - match_len, - } => { - literals_vec.extend_from_slice(literals); - sequences.push(crate::blocks::sequence_section::Sequence { - ll: literals.len() as u32, - ml: match_len as u32, - of: (offset + 3) as u32, // TODO make use of the offset history - }); - } + let offset_hist = &mut state.offset_hist; + state.matcher.start_matching(|seq| match seq { + Sequence::Literals { literals } => literals_vec.extend_from_slice(literals), + Sequence::Triple { + literals, + offset, + match_len, + } => { + let ll = literals.len() as u32; + literals_vec.extend_from_slice(literals); + let actual_offset = offset as u32; + let of = encode_offset_with_history(actual_offset, ll, offset_hist); + sequences.push(crate::blocks::sequence_section::Sequence { + ll, + ml: match_len as u32, + of, + }); } }); @@ -54,21 +56,29 @@ pub fn compress_block(state: &mut CompressState, output: &mut Vec encode_seqnum(sequences.len(), &mut writer); // Choose the tables - // TODO store previously used tables let ll_mode = choose_table( - state.fse_tables.ll_previous.as_ref(), + previous_table( + state.fse_tables.ll_previous.as_ref(), + &state.fse_tables.ll_default, + ), &state.fse_tables.ll_default, sequences.iter().map(|seq| encode_literal_length(seq.ll).0), 9, ); let ml_mode = choose_table( - state.fse_tables.ml_previous.as_ref(), + previous_table( + state.fse_tables.ml_previous.as_ref(), + &state.fse_tables.ml_default, + ), &state.fse_tables.ml_default, sequences.iter().map(|seq| encode_match_len(seq.ml).0), 9, ); let of_mode = choose_table( - state.fse_tables.of_previous.as_ref(), + previous_table( + state.fse_tables.of_previous.as_ref(), + &state.fse_tables.of_default, + ), &state.fse_tables.of_default, sequences.iter().map(|seq| encode_offset(seq.of).0), 8, @@ -88,15 +98,10 @@ pub fn compress_block(state: &mut CompressState, output: &mut Vec of_mode.as_ref(), ); - if let FseTableMode::Encoded(table) = ll_mode { - state.fse_tables.ll_previous = Some(table) - } - if let FseTableMode::Encoded(table) = ml_mode { - state.fse_tables.ml_previous = Some(table) - } - if let FseTableMode::Encoded(table) = of_mode { - state.fse_tables.of_previous = Some(table) - } + let ll_last = into_last_used_table(ll_mode); + let ml_last = into_last_used_table(ml_mode); + let of_last = into_last_used_table(of_mode); + remember_last_used_tables(&mut state.fse_tables, ll_last, ml_last, of_last); } writer.flush(); } @@ -106,41 +111,145 @@ pub fn compress_block(state: &mut CompressState, output: &mut Vec enum FseTableMode<'a> { Predefined(&'a FSETable), Encoded(FSETable), - RepeateLast(&'a FSETable), + RepeatLast(&'a FSETable), } impl FseTableMode<'_> { pub fn as_ref(&self) -> &FSETable { match self { Self::Predefined(t) => t, - Self::RepeateLast(t) => t, + Self::RepeatLast(t) => t, Self::Encoded(t) => t, } } } +/// Estimate the encoding cost (in bits) of the given symbol distribution using a table. +/// Returns `None` if the table cannot encode all symbols present in the data. +fn estimate_encoding_cost(counts: &[usize; 256], total: usize, table: &FSETable) -> Option { + if total == 0 { + return Some(0); + } + let table_size = table.table_size as f64; + let mut cost_bits = 0.0f64; + for (symbol, &count) in counts.iter().enumerate() { + if count == 0 { + continue; + } + let prob = table.symbol_probability(symbol as u8); + if prob == 0 { + // Table cannot encode this symbol + return None; + } + let effective_prob = if prob == -1 { 1 } else { prob as usize }; + // Keep the same entropy-style estimate for every candidate table. A + // cheaper integer proxy perturbs close Encoded vs Repeat/Predefined + // decisions, which is harder to recover from than this small setup cost. + // Bits per symbol ≈ log2(table_size / probability) + let bits_per_symbol = (table_size / effective_prob as f64).log2(); + cost_bits += count as f64 * bits_per_symbol; + } + Some(cost_bits as usize) +} + fn choose_table<'a>( previous: Option<&'a FSETable>, default_table: &'a FSETable, data: impl Iterator, max_log: u8, ) -> FseTableMode<'a> { - // TODO check if the new table is better than the predefined and previous table - let use_new_table = true; - let use_previous_table = false; - if use_previous_table { - FseTableMode::RepeateLast(previous.unwrap()) - } else if use_new_table { - FseTableMode::Encoded(build_table_from_data(data, max_log, true)) - } else { - FseTableMode::Predefined(default_table) + // Collect symbol distribution + let mut counts = [0usize; 256]; + let mut total = 0usize; + for symbol in data { + counts[symbol as usize] += 1; + total += 1; + } + + if total == 0 { + return FseTableMode::Predefined(default_table); + } + + // Build a new table from the actual data distribution + let max_symbol = counts + .iter() + .rposition(|&count| count > 0) + .unwrap_or_default(); + let distinct_symbols = counts.iter().filter(|&&count| count > 0).take(2).count(); + // Sequence-section RLE mode is not implemented yet, so one-symbol streams + // stay on the Predefined/Repeat paths unless those tables cannot encode the + // stream at all. For non-degenerate inputs we still build the dynamic candidate + // here instead of adding a heuristic short-circuit: exact cost comparison is + // what lets Repeat, Predefined, and Encoded compete without hard-coded ratio + // regressions. + let new_table = (distinct_symbols > 1) + .then(|| build_table_from_symbol_counts(&counts[..=max_symbol], max_log, true)); + + // Estimate costs: encoding cost + table header cost. `None` means the + // candidate table cannot encode the observed distribution. + let new_total_cost = new_table.as_ref().and_then(|table| { + estimate_encoding_cost(&counts, total, table) + .map(|cost| cost.saturating_add(table.table_header_bits())) + }); + + // Predefined table: zero header cost + let predefined_cost = estimate_encoding_cost(&counts, total, default_table); + + // Previous table: zero header cost (repeat mode) + let previous_cost = previous.and_then(|table| estimate_encoding_cost(&counts, total, table)); + + enum Choice { + Previous, + Predefined, + New, + } + + let mut best: Option<(usize, Choice)> = None; + + if let Some(cost) = previous_cost { + best = Some((cost, Choice::Previous)); + } + + if let Some(cost) = predefined_cost { + match best { + Some((best_cost, _)) if best_cost <= cost => {} + _ => best = Some((cost, Choice::Predefined)), + } + } + + if let Some(cost) = new_total_cost { + match best { + Some((best_cost, _)) if best_cost <= cost => {} + _ => best = Some((cost, Choice::New)), + } + } + + match best.map(|(_, choice)| choice) { + Some(Choice::Previous) => previous + .map(FseTableMode::RepeatLast) + .unwrap_or(FseTableMode::Predefined(default_table)), + Some(Choice::Predefined) => FseTableMode::Predefined(default_table), + Some(Choice::New) => new_table + .map(FseTableMode::Encoded) + .unwrap_or(FseTableMode::Predefined(default_table)), + None => { + let fallback_counts = [counts[0], 0]; + let fallback = if max_symbol == 0 { + // `build_table_from_symbol_counts` needs at least two entries, so + // single-symbol streams use a phantom zero-count second slot here. + build_table_from_symbol_counts(&fallback_counts, max_log, true) + } else { + build_table_from_symbol_counts(&counts[..=max_symbol], max_log, true) + }; + FseTableMode::Encoded(fallback) + } } } fn encode_table(mode: &FseTableMode<'_>, writer: &mut BitWriter<&mut Vec>) { match mode { FseTableMode::Predefined(_) => {} - FseTableMode::RepeateLast(_) => {} + FseTableMode::RepeatLast(_) => {} FseTableMode::Encoded(table) => table.write_table(writer), } } @@ -154,12 +263,44 @@ fn encode_fse_table_modes( match mode { FseTableMode::Predefined(_) => 0, FseTableMode::Encoded(_) => 2, - FseTableMode::RepeateLast(_) => 3, + FseTableMode::RepeatLast(_) => 3, } } mode_to_bits(ll_mode) << 6 | mode_to_bits(of_mode) << 4 | mode_to_bits(ml_mode) << 2 } +fn remember_last_used_tables( + fse_tables: &mut FseTables, + ll_last: Option, + ml_last: Option, + of_last: Option, +) { + remember_last_used_table(&mut fse_tables.ll_previous, ll_last); + remember_last_used_table(&mut fse_tables.ml_previous, ml_last); + remember_last_used_table(&mut fse_tables.of_previous, of_last); +} + +fn previous_table<'a>( + previous: Option<&'a PreviousFseTable>, + default: &'a FSETable, +) -> Option<&'a FSETable> { + previous.map(|previous| previous.as_table(default)) +} + +fn remember_last_used_table(slot: &mut Option, next: Option) { + if let Some(next) = next { + *slot = Some(next); + } +} + +fn into_last_used_table(mode: FseTableMode<'_>) -> Option { + match mode { + FseTableMode::Encoded(table) => Some(PreviousFseTable::Custom(Box::new(table))), + FseTableMode::Predefined(_) => Some(PreviousFseTable::Default), + FseTableMode::RepeatLast(_) => None, + } +} + fn encode_sequences( sequences: &[crate::blocks::sequence_section::Sequence], writer: &mut BitWriter<&mut Vec>, @@ -301,6 +442,71 @@ fn encode_match_len(len: u32) -> (u8, u32, usize) { } } +/// Convert an actual byte offset into the encoded offset code, using repeat offset +/// history per RFC 8878 §3.1.2.5. Updates `offset_hist` in place. +/// +/// Encoded offset codes: 1/2/3 = repeat offsets, N+3 = new absolute offset N. +fn encode_offset_with_history(actual_offset: u32, lit_len: u32, offset_hist: &mut [u32; 3]) -> u32 { + let encoded = if lit_len > 0 { + if actual_offset == offset_hist[0] { + 1 + } else if actual_offset == offset_hist[1] { + 2 + } else if actual_offset == offset_hist[2] { + 3 + } else { + actual_offset + 3 + } + } else { + // When lit_len == 0, repeat offset mapping shifts per RFC 8878: + // code 1 → rep[1], code 2 → rep[2], code 3 → rep[0]-1 + if actual_offset == offset_hist[1] { + 1 + } else if actual_offset == offset_hist[2] { + 2 + } else if actual_offset == offset_hist[0].wrapping_sub(1) && offset_hist[0] > 1 { + 3 + } else { + actual_offset + 3 + } + }; + + // Update history (same rules as decoder) + if lit_len > 0 { + match encoded { + 1 => { /* rep[0] stays the same */ } + 2 => { + offset_hist[1] = offset_hist[0]; + offset_hist[0] = actual_offset; + } + _ => { + offset_hist[2] = offset_hist[1]; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = actual_offset; + } + } + } else { + match encoded { + 1 => { + offset_hist[1] = offset_hist[0]; + offset_hist[0] = actual_offset; + } + 2 => { + offset_hist[2] = offset_hist[1]; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = actual_offset; + } + _ => { + offset_hist[2] = offset_hist[1]; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = actual_offset; + } + } + } + + encoded +} + fn encode_offset(len: u32) -> (u8, u32, usize) { let log = len.ilog2(); let lower = len & ((1 << log) - 1); @@ -391,3 +597,147 @@ fn compress_literals( None } } + +#[cfg(test)] +mod tests { + use alloc::boxed::Box; + + use super::{ + FseTableMode, choose_table, encode_offset_with_history, previous_table, + remember_last_used_tables, + }; + use crate::encoding::frame_compressor::{FseTables, PreviousFseTable}; + use crate::fse::fse_encoder::build_table_from_symbol_counts; + + fn tables_match( + lhs: &crate::fse::fse_encoder::FSETable, + rhs: &crate::fse::fse_encoder::FSETable, + ) -> bool { + lhs.table_size == rhs.table_size + && (0..=255u8) + .all(|symbol| lhs.symbol_probability(symbol) == rhs.symbol_probability(symbol)) + } + + #[test] + fn repeat_offset_codes_follow_rfc_mapping() { + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(10, 5, &mut hist), 1); + assert_eq!(hist, [10, 20, 30]); + + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(20, 5, &mut hist), 2); + assert_eq!(hist, [20, 10, 30]); + + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(30, 5, &mut hist), 3); + assert_eq!(hist, [30, 10, 20]); + + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(20, 0, &mut hist), 1); + assert_eq!(hist, [20, 10, 30]); + + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(30, 0, &mut hist), 2); + assert_eq!(hist, [30, 10, 20]); + + let mut hist = [10, 20, 30]; + assert_eq!(encode_offset_with_history(9, 0, &mut hist), 3); + assert_eq!(hist, [9, 10, 20]); + } + + #[test] + fn remember_last_used_tables_keeps_predefined_and_repeat_modes() { + let mut fse_tables = FseTables::new(); + + remember_last_used_tables( + &mut fse_tables, + Some(PreviousFseTable::Default), + Some(PreviousFseTable::Default), + Some(PreviousFseTable::Default), + ); + + assert!(tables_match( + previous_table(fse_tables.ll_previous.as_ref(), &fse_tables.ll_default).unwrap(), + &fse_tables.ll_default + )); + assert!(tables_match( + previous_table(fse_tables.ml_previous.as_ref(), &fse_tables.ml_default).unwrap(), + &fse_tables.ml_default + )); + assert!(tables_match( + previous_table(fse_tables.of_previous.as_ref(), &fse_tables.of_default).unwrap(), + &fse_tables.of_default + )); + + let sample_codes = [0u8, 1u8]; + let ll_repeat = choose_table( + previous_table(fse_tables.ll_previous.as_ref(), &fse_tables.ll_default), + &fse_tables.ll_default, + sample_codes.iter().copied(), + 9, + ); + let ml_repeat = choose_table( + previous_table(fse_tables.ml_previous.as_ref(), &fse_tables.ml_default), + &fse_tables.ml_default, + sample_codes.iter().copied(), + 9, + ); + let of_repeat = choose_table( + previous_table(fse_tables.of_previous.as_ref(), &fse_tables.of_default), + &fse_tables.of_default, + sample_codes.iter().copied(), + 8, + ); + + assert!(matches!(ll_repeat, FseTableMode::RepeatLast(_))); + assert!(matches!(ml_repeat, FseTableMode::RepeatLast(_))); + assert!(matches!(of_repeat, FseTableMode::RepeatLast(_))); + } + + #[test] + fn remember_last_used_tables_reuses_existing_custom_slot_for_repeat() { + let mut fse_tables = FseTables::new(); + let custom = build_table_from_symbol_counts(&[1, 1], 5, false); + fse_tables.ll_previous = Some(PreviousFseTable::Custom(Box::new(custom))); + + let before = core::ptr::from_ref( + previous_table(fse_tables.ll_previous.as_ref(), &fse_tables.ll_default).unwrap(), + ); + + remember_last_used_tables( + &mut fse_tables, + None, + Some(PreviousFseTable::Default), + Some(PreviousFseTable::Default), + ); + + let after = core::ptr::from_ref( + previous_table(fse_tables.ll_previous.as_ref(), &fse_tables.ll_default).unwrap(), + ); + + assert_eq!(before, after); + assert!(matches!( + fse_tables.ll_previous.as_ref(), + Some(PreviousFseTable::Custom(_)) + )); + } + + #[test] + fn choose_table_handles_single_symbol_distribution() { + let fse_tables = FseTables::new(); + let mode = choose_table( + None, + &fse_tables.ll_default, + core::iter::repeat_n(0u8, 32), + 9, + ); + assert!(matches!(mode, FseTableMode::Predefined(_))); + } + + #[test] + fn choose_table_without_previous_does_not_unwrap_none() { + let only_zero_table = build_table_from_symbol_counts(&[1], 5, false); + let mode = choose_table(None, &only_zero_table, core::iter::repeat_n(1u8, 32), 5); + assert!(matches!(mode, FseTableMode::Encoded(_))); + } +} diff --git a/zstd/src/encoding/frame_compressor.rs b/zstd/src/encoding/frame_compressor.rs index 231885c3..a50d392b 100644 --- a/zstd/src/encoding/frame_compressor.rs +++ b/zstd/src/encoding/frame_compressor.rs @@ -1,6 +1,6 @@ //! Utilities and interfaces for encoding an entire frame. Allows reusing resources -use alloc::vec::Vec; +use alloc::{boxed::Box, vec::Vec}; use core::convert::TryInto; #[cfg(feature = "hash")] use twox_hash::XxHash64; @@ -9,10 +9,10 @@ use twox_hash::XxHash64; use core::hash::Hasher; use super::{ - block_header::BlockHeader, frame_header::FrameHeader, levels::*, - match_generator::MatchGeneratorDriver, CompressionLevel, Matcher, + CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*, + match_generator::MatchGeneratorDriver, }; -use crate::fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable}; +use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table}; use crate::io::{Read, Write}; @@ -44,13 +44,30 @@ pub struct FrameCompressor { hasher: XxHash64, } +#[derive(Clone)] +pub(crate) enum PreviousFseTable { + // Default tables are immutable and already stored alongside the state, so + // repeating them only needs a lightweight marker instead of cloning FSETable. + Default, + Custom(Box), +} + +impl PreviousFseTable { + pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable { + match self { + Self::Default => default, + Self::Custom(table) => table, + } + } +} + pub(crate) struct FseTables { pub(crate) ll_default: FSETable, - pub(crate) ll_previous: Option, + pub(crate) ll_previous: Option, pub(crate) ml_default: FSETable, - pub(crate) ml_previous: Option, + pub(crate) ml_previous: Option, pub(crate) of_default: FSETable, - pub(crate) of_previous: Option, + pub(crate) of_previous: Option, } impl FseTables { @@ -70,6 +87,9 @@ pub(crate) struct CompressState { pub(crate) matcher: M, pub(crate) last_huff_table: Option, pub(crate) fse_tables: FseTables, + /// Offset history for repeat offset encoding: [rep0, rep1, rep2]. + /// Initialized to [1, 4, 8] per RFC 8878 §3.1.2.5. + pub(crate) offset_hist: [u32; 3], } impl FrameCompressor { @@ -83,6 +103,7 @@ impl FrameCompressor { matcher: MatchGeneratorDriver::new(1024 * 128, 1), last_huff_table: None, fse_tables: FseTables::new(), + offset_hist: [1, 4, 8], }, #[cfg(feature = "hash")] hasher: XxHash64::with_seed(0), @@ -100,6 +121,7 @@ impl FrameCompressor { matcher, last_huff_table: None, fse_tables: FseTables::new(), + offset_hist: [1, 4, 8], }, compression_level, #[cfg(feature = "hash")] @@ -132,6 +154,10 @@ impl FrameCompressor { // Clearing buffers to allow re-using of the compressor self.state.matcher.reset(self.compression_level); self.state.last_huff_table = None; + self.state.fse_tables.ll_previous = None; + self.state.fse_tables.ml_previous = None; + self.state.fse_tables.of_previous = None; + self.state.offset_hist = [1, 4, 8]; #[cfg(feature = "hash")] { self.hasher = XxHash64::with_seed(0); diff --git a/zstd/src/encoding/frame_header.rs b/zstd/src/encoding/frame_header.rs index 9e957eee..186f3504 100644 --- a/zstd/src/encoding/frame_header.rs +++ b/zstd/src/encoding/frame_header.rs @@ -42,12 +42,12 @@ impl FrameHeader { // `Window_Descriptor // TODO: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor - if !self.single_segment { - if let Some(window_size) = self.window_size { - let log = window_size.next_power_of_two().ilog2(); - let exponent = if log > 10 { log - 10 } else { 1 } as u8; - output.push(exponent << 3); - } + if !self.single_segment + && let Some(window_size) = self.window_size + { + let log = window_size.next_power_of_two().ilog2(); + let exponent = if log > 10 { log - 10 } else { 1 } as u8; + output.push(exponent << 3); } if let Some(id) = self.dictionary_id { @@ -116,7 +116,10 @@ impl FrameHeader { // and the `Frame_Content_Size` field must be present in the header. // If this flag is not set, the `Window_Descriptor` field must be present in the frame header. if self.single_segment { - assert!(self.frame_content_size.is_some(), "if the `single_segment` flag is set to true, then a frame content size must be provided"); + assert!( + self.frame_content_size.is_some(), + "if the `single_segment` flag is set to true, then a frame content size must be provided" + ); bw.write_bits(1u8, 1); } else { assert!( @@ -163,7 +166,7 @@ fn minify_val_fcs(val: u64) -> Vec { #[cfg(test)] mod tests { use super::FrameHeader; - use crate::decoding::frame::{read_frame_header, FrameDescriptor}; + use crate::decoding::frame::{FrameDescriptor, read_frame_header}; use alloc::vec::Vec; #[test] diff --git a/zstd/src/encoding/levels/fastest.rs b/zstd/src/encoding/levels/fastest.rs index 4ec87572..dc3e4e8b 100644 --- a/zstd/src/encoding/levels/fastest.rs +++ b/zstd/src/encoding/levels/fastest.rs @@ -1,7 +1,7 @@ use crate::{ common::MAX_BLOCK_SIZE, encoding::{ - block_header::BlockHeader, blocks::compress_block, frame_compressor::CompressState, Matcher, + Matcher, block_header::BlockHeader, blocks::compress_block, frame_compressor::CompressState, }, }; use alloc::vec::Vec; diff --git a/zstd/src/encoding/match_generator.rs b/zstd/src/encoding/match_generator.rs index 5d765e6f..456867be 100644 --- a/zstd/src/encoding/match_generator.rs +++ b/zstd/src/encoding/match_generator.rs @@ -442,7 +442,9 @@ fn matches() { assert!(!matcher.next_sequence(|_| {})); matcher.add_data( - alloc::vec![1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0,], + alloc::vec![ + 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, + ], SuffixStore::with_capacity(100), |_, _| {}, ); diff --git a/zstd/src/fse/fse_encoder.rs b/zstd/src/fse/fse_encoder.rs index 974eecf5..958e04ca 100644 --- a/zstd/src/fse/fse_encoder.rs +++ b/zstd/src/fse/fse_encoder.rs @@ -144,7 +144,69 @@ impl FSETable { self.table_size.ilog2() as u8 } + /// Get the probability assigned to a symbol (0 means absent, -1 means less-than-1). + pub(crate) fn symbol_probability(&self, symbol: u8) -> i32 { + self.states[symbol as usize].probability + } + + /// Compute the exact serialized size (in bits) of the FSE table header, + /// including the byte-alignment padding at the end. + /// Mirrors `write_table` but counts bits instead of writing them. + /// + /// The result assumes the header starts at a byte boundary, which matches + /// all current encoder call sites. + pub(crate) fn table_header_bits(&self) -> usize { + let mut bits = 4; // acc_log - 5 + let mut probability_counter = 0usize; + let probability_sum = 1 << self.acc_log(); + + let mut prob_idx = 0; + while probability_counter < probability_sum { + let max_remaining_value = probability_sum - probability_counter + 1; + let bits_to_write = max_remaining_value.ilog2() + 1; + let low_threshold = ((1 << bits_to_write) - 1) - max_remaining_value; + + let prob = self.states[prob_idx].probability; + prob_idx += 1; + let value = (prob + 1) as u32; + if value < low_threshold as u32 { + bits += bits_to_write as usize - 1; + } else { + bits += bits_to_write as usize; + } + + if prob == -1 { + probability_counter += 1; + } else if prob > 0 { + probability_counter += prob as usize; + } else { + let mut zeros = 0u8; + while prob_idx < self.states.len() && self.states[prob_idx].probability == 0 { + zeros += 1; + prob_idx += 1; + if zeros == 3 { + bits += 2; + zeros = 0; + } + } + bits += 2; + } + } + // Byte-alignment padding + let misaligned = bits % 8; + if misaligned != 0 { + bits += 8 - misaligned; + } + bits + } + pub(crate) fn write_table>>(&self, writer: &mut BitWriter) { + assert!( + writer.index().is_multiple_of(8), + "FSE table headers must start on a byte boundary" + ); + #[cfg(debug_assertions)] + let start_idx = writer.index(); writer.write_bits(self.acc_log() - 5, 4); let mut probability_counter = 0usize; let probability_sum = 1 << self.acc_log(); @@ -173,7 +235,7 @@ impl FSETable { probability_counter += prob as usize; } else { let mut zeros = 0u8; - while self.states[prob_idx].probability == 0 { + while prob_idx < self.states.len() && self.states[prob_idx].probability == 0 { zeros += 1; prob_idx += 1; if zeros == 3 { @@ -185,6 +247,15 @@ impl FSETable { } } writer.write_bits(0u8, writer.misaligned()); + #[cfg(debug_assertions)] + { + let written_bits = writer.index() - start_idx; + let computed = self.table_header_bits(); + debug_assert_eq!( + written_bits, computed, + "table_header_bits() mismatch: written={written_bits}, computed={computed}" + ); + } } } @@ -241,6 +312,14 @@ pub fn build_table_from_data( build_table_from_counts(&counts[..=max_symbol], max_log, avoid_0_numbit) } +pub(crate) fn build_table_from_symbol_counts( + counts: &[usize], + max_log: u8, + avoid_0_numbit: bool, +) -> FSETable { + build_table_from_counts(counts, max_log, avoid_0_numbit) +} + fn build_table_from_counts(counts: &[usize], max_log: u8, avoid_0_numbit: bool) -> FSETable { let mut probs = [0; 256]; let probs = &mut probs[..counts.len()]; diff --git a/zstd/src/fse/mod.rs b/zstd/src/fse/mod.rs index 569007ad..f2a8b7b2 100644 --- a/zstd/src/fse/mod.rs +++ b/zstd/src/fse/mod.rs @@ -42,6 +42,54 @@ fn check_tables(dec_table: &fse_decoder::FSETable, enc_table: &fse_encoder::FSET } } +/// Verify `table_header_bits()` matches the actual byte count written by `write_table()`. +#[test] +fn table_header_bits_exact() { + use crate::bit_io::BitWriter; + use fse_encoder::{ + build_table_from_data, build_table_from_probabilities, default_ll_table, default_ml_table, + default_of_table, + }; + + let check = |table: &fse_encoder::FSETable| { + let mut buf = alloc::vec::Vec::new(); + let mut writer = BitWriter::from(&mut buf); + table.write_table(&mut writer); + writer.flush(); + let written_bits = buf.len() * 8; // flush pads to byte boundary + let computed_bits = table.table_header_bits(); + assert_eq!( + computed_bits, written_bits, + "table_header_bits() mismatch: computed={computed_bits}, written={written_bits}" + ); + }; + + // Predefined tables + check(&default_ll_table()); + check(&default_ml_table()); + check(&default_of_table()); + + // Tables built from synthetic data + let data: alloc::vec::Vec = (0u8..32).cycle().take(1000).collect(); + check(&build_table_from_data(data.iter().copied(), 9, true)); + + let data2: alloc::vec::Vec = alloc::vec![0, 1, 2, 3] + .into_iter() + .cycle() + .take(500) + .collect(); + check(&build_table_from_data(data2.iter().copied(), 8, true)); + + // Uniform distribution: 32 symbols × prob=2 = 64 = 1<<6 (acc_log=6 requires sum=64) + check(&build_table_from_probabilities( + &[ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, + ], + 6, + )); +} + #[test] fn roundtrip() { round_trip(&(0..64).collect::>()); diff --git a/zstd/src/huff0/huff0_encoder.rs b/zstd/src/huff0/huff0_encoder.rs index 28b45ce4..7d35fc32 100644 --- a/zstd/src/huff0/huff0_encoder.rs +++ b/zstd/src/huff0/huff0_encoder.rs @@ -105,15 +105,12 @@ impl>> HuffmanEncoder<'_, '_, V> { pub(super) fn weights(&self) -> Vec { let max = self.table.codes.iter().map(|(_, nb)| nb).max().unwrap(); - let weights = self - .table + self.table .codes .iter() .copied() .map(|(_, nb)| if nb == 0 { 0 } else { max - nb + 1 }) - .collect::>(); - - weights + .collect::>() } fn write_table(&mut self) { diff --git a/zstd/src/tests/mod.rs b/zstd/src/tests/mod.rs index 6cacdaa5..e1b75c7f 100644 --- a/zstd/src/tests/mod.rs +++ b/zstd/src/tests/mod.rs @@ -489,8 +489,8 @@ fn test_streaming_no_std() { #[test] fn test_decode_all() { - use crate::decoding::errors::FrameDecoderError; use crate::decoding::FrameDecoder; + use crate::decoding::errors::FrameDecoderError; let skip_frame = |input: &mut Vec, length: usize| { input.extend_from_slice(&0x184D2A50u32.to_le_bytes()); @@ -587,5 +587,6 @@ pub mod roundtrip_integrity; #[test] fn verbose_disabled() { use crate::VERBOSE; - assert_eq!(VERBOSE, false); + use core::hint::black_box; + assert!(!black_box(VERBOSE)); } diff --git a/zstd/src/tests/roundtrip_integrity.rs b/zstd/src/tests/roundtrip_integrity.rs index 6aeb7350..7912fa10 100644 --- a/zstd/src/tests/roundtrip_integrity.rs +++ b/zstd/src/tests/roundtrip_integrity.rs @@ -10,7 +10,7 @@ use alloc::vec; use alloc::vec::Vec; use crate::decoding::StreamingDecoder; -use crate::encoding::{compress_to_vec, CompressionLevel, FrameCompressor}; +use crate::encoding::{CompressionLevel, FrameCompressor, compress_to_vec}; use crate::io::Read; /// Generate deterministic pseudo-random data using a simple LCG. @@ -46,14 +46,18 @@ fn roundtrip_simple(data: &[u8]) -> Vec { result } -/// Roundtrip using FrameCompressor (streaming API). -fn roundtrip_streaming(data: &[u8]) -> Vec { +fn compress_streaming(data: &[u8]) -> Vec { let mut compressed = Vec::new(); let mut compressor = FrameCompressor::new(CompressionLevel::Fastest); compressor.set_source(data); compressor.set_drain(&mut compressed); compressor.compress(); + compressed +} +/// Roundtrip using FrameCompressor (streaming API). +fn roundtrip_streaming(data: &[u8]) -> Vec { + let compressed = compress_streaming(data); let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap(); let mut result = Vec::new(); decoder.read_to_end(&mut result).unwrap(); @@ -75,6 +79,15 @@ fn generate_huffman_friendly(seed: u64, len: usize, alphabet_size: u8) -> Vec Vec { + let mut data = Vec::with_capacity(chunks * (pattern.len() + 2)); + for i in 0..chunks { + data.extend_from_slice(pattern); + data.extend_from_slice(&(i as u16).to_le_bytes()); + } + data +} + // Cross-validation tests (pure Rust ↔ C FFI) are in tests/cross_validation.rs // because dev-dependencies (zstd) aren't available in library test modules. @@ -124,7 +137,7 @@ fn roundtrip_streaming_api_1000_iterations() { #[test] fn roundtrip_edge_cases() { // Empty data - assert_eq!(roundtrip_simple(&[]), vec![]); + assert_eq!(roundtrip_simple(&[]), Vec::::new()); // Single byte assert_eq!(roundtrip_simple(&[0x42]), vec![0x42]); @@ -187,3 +200,140 @@ fn roundtrip_multi_block_large_literals() { assert_eq!(roundtrip_simple(&data), data); assert_eq!(roundtrip_streaming(&data), data); } + +/// Repeat offset encoding: data with many repeated match offsets should compress +/// better than data where every offset is unique, and must roundtrip correctly. +#[test] +fn roundtrip_repeat_offsets() { + // Break each repeated chunk with a changing 2-byte sentinel so the matcher + // has to re-emit the same offset instead of collapsing everything into one + // maximal match. + let data = repeat_offset_fixture(b"ABCDE12345", 10_000); + let result = roundtrip_simple(&data); + assert_eq!(data, result, "Repeat offset roundtrip failed"); + + // Also verify with streaming API + let result = roundtrip_streaming(&data); + assert_eq!(data, result, "Repeat offset streaming roundtrip failed"); +} + +/// Verify that highly repetitive data compresses significantly better than random data. +#[test] +fn repetitive_data_compresses_better_than_random() { + // Repetitive data: fixed-offset matches separated by a changing sentinel. + let repetitive = repeat_offset_fixture(b"ABCDE12345", 5_000); + let compressed_repetitive = compress_to_vec(&repetitive[..], CompressionLevel::Fastest); + + // Random data of same size (incompressible) + let random = generate_data(999, repetitive.len()); + let compressed_random = compress_to_vec(&random[..], CompressionLevel::Fastest); + + // Repetitive data should still beat random data, without pinning an exact + // ratio that may drift as encoder heuristics evolve. + assert!( + compressed_repetitive.len() < compressed_random.len(), + "Repetitive input should compress better than random input. \ + repetitive={} bytes, random={} bytes", + compressed_repetitive.len(), + compressed_random.len() + ); +} + +/// Multi-block data exercises FSE table reuse across blocks and offset history +/// persistence across block boundaries. +#[test] +fn roundtrip_multi_block_repeat_offsets() { + // 512KB of data with fixed-offset repeats broken by a changing sentinel — + // spans multiple 128KB blocks, so offset history and FSE tables must + // persist correctly across block boundaries. + let mut data = repeat_offset_fixture(b"HelloWorld", (512 * 1024) / 12 + 1); + data.truncate(512 * 1024); + + let result = roundtrip_simple(&data); + assert_eq!(data, result, "Multi-block repeat offset roundtrip failed"); + + let result = roundtrip_streaming(&data); + assert_eq!( + data, result, + "Multi-block repeat offset streaming roundtrip failed" + ); + + let whole_frame = compress_streaming(&data); + let frame_overhead = compress_to_vec(&[][..], CompressionLevel::Fastest).len(); + let independent_chunks: usize = data + .chunks(128 * 1024) + .map(|chunk| { + compress_to_vec(chunk, CompressionLevel::Fastest) + .len() + .saturating_sub(frame_overhead) + }) + .sum::() + .saturating_add(frame_overhead); + assert!( + whole_frame.len() < independent_chunks, + "Cross-block reuse should beat per-block resets. whole={} bytes, split={} bytes", + whole_frame.len(), + independent_chunks + ); +} + +/// Zero literal length sequences (back-to-back matches with no literals between them) +/// exercise the shifted repeat-offset remap path instead of only generic new offsets. +#[test] +fn roundtrip_zero_literal_length_sequences() { + // Alternate a base prefix with a one-byte-shifted version so the encoder + // sees back-to-back zero-literal matches that must use a shifted repeat + // remap path instead of only generic new offsets. + let mut data = Vec::with_capacity(10_000); + // Initial unique segment + for i in 0..100u8 { + data.push(i); + } + // Repeat the first 50 bytes, then alternate with a shifted 50-byte window. + let prefix = data[..50].to_vec(); + let shifted_prefix = data[1..51].to_vec(); + data.extend_from_slice(&prefix); + for _ in 0..100 { + data.extend_from_slice(&shifted_prefix); + data.extend_from_slice(&prefix); + } + + let result = roundtrip_simple(&data); + assert_eq!(data, result, "Zero ll sequence roundtrip failed"); +} + +/// Reusing the same `FrameCompressor` across frames must reset per-frame FSE repeat tables. +#[test] +fn roundtrip_reused_frame_compressor_across_frames() { + let first = generate_huffman_friendly(700, 512 * 1024, 48); + let second = generate_huffman_friendly(701, 512 * 1024, 48); + + let mut first_compressed = Vec::new(); + let mut second_compressed = Vec::new(); + { + let mut compressor = FrameCompressor::new(CompressionLevel::Fastest); + compressor.set_source(first.as_slice()); + compressor.set_drain(&mut first_compressed); + compressor.compress(); + + compressor.set_source(second.as_slice()); + compressor.set_drain(&mut second_compressed); + compressor.compress(); + } + + let mut decoder = StreamingDecoder::new(first_compressed.as_slice()).unwrap(); + let mut first_roundtrip = Vec::new(); + decoder.read_to_end(&mut first_roundtrip).unwrap(); + assert_eq!( + first, first_roundtrip, + "First reused-frame roundtrip failed" + ); + + let mut decoder = StreamingDecoder::new(second_compressed.as_slice()).unwrap(); + let mut second_roundtrip = Vec::new(); + decoder.read_to_end(&mut second_roundtrip).unwrap(); + assert_eq!( + second, second_roundtrip, + "Second reused-frame roundtrip failed" + ); +} diff --git a/zstd/tests/cross_validation.rs b/zstd/tests/cross_validation.rs index 83b6e783..e60db61e 100644 --- a/zstd/tests/cross_validation.rs +++ b/zstd/tests/cross_validation.rs @@ -5,7 +5,7 @@ //! - C FFI compress → Pure Rust decompress use structured_zstd::decoding::StreamingDecoder; -use structured_zstd::encoding::{compress_to_vec, CompressionLevel}; +use structured_zstd::encoding::{CompressionLevel, compress_to_vec}; use structured_zstd::io::Read; /// Generate deterministic pseudo-random data using a simple LCG. @@ -115,3 +115,84 @@ fn cross_ffi_compress_rust_decompress_large_blocks() { decoder.read_to_end(&mut result).unwrap(); assert_eq!(data, result, "ffi→rust multi-block roundtrip failed"); } + +/// Cross-validate Rust compress (seed=100, 512KB) → C FFI decompress for the +/// same Huffman-heavy multi-block input used in roundtrip_multi_block_large_literals. +#[test] +fn cross_rust_compress_ffi_decompress_huffman_seed100() { + let data = generate_huffman_friendly(100, 512 * 1024, 48); + let compressed = compress_to_vec(&data[..], CompressionLevel::Fastest); + let result = zstd::decode_all(compressed.as_slice()).unwrap(); + assert_eq!(data, result, "rust→ffi seed=100 512KB roundtrip failed"); +} + +/// Cross-validate the same Huffman-heavy 512KB input in the opposite direction: +/// C FFI compress (seed=100) → Rust decompress. +#[test] +fn cross_ffi_compress_rust_decompress_huffman_seed100() { + let data = generate_huffman_friendly(100, 512 * 1024, 48); + let compressed = zstd::encode_all(&data[..], 1).unwrap(); + let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap(); + let mut result = Vec::new(); + decoder.read_to_end(&mut result).unwrap(); + assert_eq!(data, result, "ffi→rust seed=100 512KB roundtrip failed"); +} + +/// Cross-validate repeat offset encoding: Rust compress → C FFI decompress. +/// Exercises repeat offset codes (1/2/3) and offset history across blocks. +#[test] +fn cross_rust_compress_ffi_decompress_repeat_offsets() { + // Single-block: repeating pattern at fixed offset + let pattern = b"ABCDE12345"; + let mut data = Vec::with_capacity(50_000); + for _ in 0..5_000 { + data.extend_from_slice(pattern); + } + let compressed = compress_to_vec(&data[..], CompressionLevel::Fastest); + let result = zstd::decode_all(compressed.as_slice()).unwrap(); + assert_eq!(data, result, "rust→ffi repeat offset roundtrip failed"); + + // Multi-block: 512KB with repeating patterns spanning block boundaries + let mut multi_block = Vec::with_capacity(512 * 1024); + while multi_block.len() < 512 * 1024 { + multi_block.extend_from_slice(pattern); + } + multi_block.truncate(512 * 1024); + let compressed = compress_to_vec(&multi_block[..], CompressionLevel::Fastest); + let result = zstd::decode_all(compressed.as_slice()).unwrap(); + assert_eq!( + multi_block, result, + "rust→ffi multi-block repeat offset roundtrip failed" + ); +} + +/// Cross-validate repeat-offset-friendly inputs in the opposite direction: +/// C FFI compress → Rust decompress. +#[test] +fn cross_ffi_compress_rust_decompress_repeat_offsets() { + let pattern = b"ABCDE12345"; + + let mut data = Vec::with_capacity(50_000); + for _ in 0..5_000 { + data.extend_from_slice(pattern); + } + let compressed = zstd::encode_all(&data[..], 1).unwrap(); + let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap(); + let mut result = Vec::new(); + decoder.read_to_end(&mut result).unwrap(); + assert_eq!(data, result, "ffi→rust repeat offset roundtrip failed"); + + let mut multi_block = Vec::with_capacity(512 * 1024); + while multi_block.len() < 512 * 1024 { + multi_block.extend_from_slice(pattern); + } + multi_block.truncate(512 * 1024); + let compressed = zstd::encode_all(&multi_block[..], 1).unwrap(); + let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap(); + let mut result = Vec::new(); + decoder.read_to_end(&mut result).unwrap(); + assert_eq!( + multi_block, result, + "ffi→rust multi-block repeat offset roundtrip failed" + ); +}