diff --git a/zstd/src/decoding/decode_buffer.rs b/zstd/src/decoding/decode_buffer.rs index 46c4f333..51f7b7be 100644 --- a/zstd/src/decoding/decode_buffer.rs +++ b/zstd/src/decoding/decode_buffer.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; #[cfg(feature = "hash")] use core::hash::Hasher; +use super::prefetch; use super::ringbuffer::RingBuffer; use crate::decoding::errors::DecodeBufferError; @@ -65,6 +66,14 @@ impl DecodeBuffer { } pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> { + if offset == 0 { + return Err(DecodeBufferError::ZeroOffset); + } + + if match_length == 0 { + return Ok(()); + } + if offset > self.buffer.len() { self.repeat_from_dict(offset, match_length) } else { @@ -73,9 +82,9 @@ impl DecodeBuffer { let end_idx = start_idx + match_length; self.buffer.reserve(match_length); + self.prefetch_match_source(start_idx); if end_idx > buf_len { - // We need to copy in chunks. - self.repeat_in_chunks(offset, match_length, start_idx); + self.repeat_overlapping(offset, match_length, start_idx); } else { // can just copy parts of the existing buffer // SAFETY: Requirements checked: @@ -88,8 +97,13 @@ impl DecodeBuffer { // // 2. 
explicitly reserved enough memory for the whole match_length unsafe { - self.buffer - .extend_from_within_unchecked(start_idx, match_length) + if offset >= 16 && use_branchless_wildcopy() { + self.buffer + .extend_from_within_unchecked_branchless(start_idx, match_length); + } else { + self.buffer + .extend_from_within_unchecked(start_idx, match_length); + } }; } @@ -98,36 +112,102 @@ impl DecodeBuffer { } } - fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) { - // We have at max offset bytes in one chunk, the last one can be smaller + #[inline(always)] + fn repeat_overlapping(&mut self, offset: usize, match_length: usize, start_idx: usize) { + if offset >= 16 { + self.repeat_in_chunks(offset, match_length, start_idx, use_branchless_wildcopy()); + } else if offset >= 8 { + self.repeat_in_chunks(offset, match_length, start_idx, false); + } else { + self.repeat_short_offset(offset, match_length, start_idx); + } + } + + #[inline(always)] + fn repeat_in_chunks( + &mut self, + offset: usize, + match_length: usize, + start_idx: usize, + use_branchless_copy: bool, + ) { let mut start_idx = start_idx; let mut copied_counter_left = match_length; - // TODO this can be optimized further I think. - // Each time we copy a chunk we have a repetiton of length 'offset', so we can copy offset * iteration many bytes from start_idx while copied_counter_left > 0 { let chunksize = usize::min(offset, copied_counter_left); - // SAFETY: Requirements checked: - // 1. start_idx + chunksize must be <= self.buffer.len() - // We know that: - // 1. start_idx starts at buffer.len() - offset - // 2. chunksize <= offset (== offset for each iteration but the last, and match_length modulo offset in the last iteration) - // 3. the buffer grows by offset many bytes each iteration but the last - // 4. 
start_idx is increased by the same amount as the buffer grows each iteration - // - // Thus follows: start_idx + chunksize == self.buffer.len() in each iteration but the last, where match_length modulo offset == chunksize < offset - // Meaning: start_idx + chunksize <= self.buffer.len() - // - // 2. explicitly reserved enough memory for the whole match_length + // SAFETY: chunksize <= offset keeps each single copy in the currently readable + // source range, and repeat() reserved enough destination capacity. unsafe { - self.buffer - .extend_from_within_unchecked(start_idx, chunksize) + if use_branchless_copy { + self.buffer + .extend_from_within_unchecked_branchless(start_idx, chunksize); + } else { + self.buffer + .extend_from_within_unchecked(start_idx, chunksize); + } }; copied_counter_left -= chunksize; start_idx += chunksize; } } + #[inline(always)] + fn repeat_short_offset(&mut self, offset: usize, match_length: usize, start_idx: usize) { + debug_assert!( + offset > 0, + "offset must be non-zero to avoid modulo by zero in short-offset path" + ); + let mut base = [0u8; 8]; + for (i, slot) in base.iter_mut().take(offset).enumerate() { + *slot = self.byte_at(start_idx + i); + } + + let mut phase_patterns = [[0u8; 8]; 7]; + for phase in 0..offset { + for i in 0..8 { + phase_patterns[phase][i] = base[(phase + i) % offset]; + } + } + + let phase_step = 8 % offset; + let mut phase = 0usize; + let mut copied = 0usize; + while copied + 8 <= match_length { + self.buffer.extend(&phase_patterns[phase]); + copied += 8; + phase = (phase + phase_step) % offset; + } + + if copied < match_length { + let tail = match_length - copied; + self.buffer.extend(&phase_patterns[phase][..tail]); + } + } + + #[inline(always)] + fn byte_at(&self, idx: usize) -> u8 { + let (s1, s2) = self.buffer.as_slices(); + if idx < s1.len() { + s1[idx] + } else { + s2[idx - s1.len()] + } + } + + #[inline(always)] + fn prefetch_match_source(&self, start_idx: usize) { + let (s1, s2) = 
self.buffer.as_slices(); + if start_idx < s1.len() { + prefetch::prefetch_slice(&s1[start_idx..]); + } else { + let idx = start_idx - s1.len(); + if idx < s2.len() { + prefetch::prefetch_slice(&s2[idx..]); + } + } + } + #[cold] fn repeat_from_dict( &mut self, @@ -147,6 +227,7 @@ impl DecodeBuffer { if bytes_from_dict < match_length { let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..]; + prefetch::prefetch_slice(dict_slice); self.buffer.extend(dict_slice); self.total_output_counter += bytes_from_dict as u64; @@ -155,7 +236,9 @@ impl DecodeBuffer { let low = self.dict_content.len() - bytes_from_dict; let high = low + match_length; let dict_slice = &self.dict_content[low..high]; + prefetch::prefetch_slice(dict_slice); self.buffer.extend(dict_slice); + self.total_output_counter += match_length as u64; } Ok(()) } else { @@ -315,6 +398,11 @@ fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error (written, Ok(())) } +#[inline(always)] +fn use_branchless_wildcopy() -> bool { + cfg!(any(target_arch = "x86", target_arch = "x86_64")) +} + #[cfg(test)] mod tests { use super::DecodeBuffer; @@ -448,4 +536,110 @@ mod tests { } assert_eq!(short_writer.buf.len(), repeats * 50 + 100); } + + #[test] + fn repeat_overlap_fast_paths_match_reference_behavior() { + let seed = b"0123456789abcdef0123456789abcdef"; + let cases = [ + (16usize, 16usize), // non-overlapping boundary + (16usize, 211usize), + (8usize, 173usize), + (7usize, 149usize), + (3usize, 160usize), + (1usize, 255usize), + ]; + + for (offset, match_len) in cases { + let mut decode_buf = DecodeBuffer::new(4 * 1024); + decode_buf.push(seed); + decode_buf.repeat(offset, match_len).unwrap(); + let got = decode_buf.drain(); + let expected = expected_match_expansion(seed, offset, match_len); + assert_eq!(got, expected, "offset={offset}, match_len={match_len}"); + } + } + + #[test] + fn repeat_zero_offset_returns_error() { + let mut decode_buf = DecodeBuffer::new(1024); + 
decode_buf.push(b"abcdef"); + let err = decode_buf.repeat(0, 5).unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::DecodeBufferError::ZeroOffset + )); + } + + #[test] + fn repeat_from_dict_full_copy_updates_total_output_counter() { + let mut decode_buf = DecodeBuffer::new(1); + decode_buf.dict_content = b"0123456789".to_vec(); + + decode_buf.repeat(10, 2).unwrap(); + let err = decode_buf.repeat(10, 1).unwrap_err(); + assert!(matches!( + err, + crate::decoding::errors::DecodeBufferError::OffsetTooBig { .. } + )); + } + + #[test] + fn repeat_overlap_fast_paths_match_reference_behavior_with_wrapped_ringbuffer() { + let window = 32usize; + let seed = b"0123456789abcdef0123456789abcdef"; + let mut decode_buf = DecodeBuffer::new(window); + let mut model = Vec::new(); + + decode_buf.push(seed); + model_push(&mut model, seed); + decode_buf.repeat(16, 16).unwrap(); + model_repeat(&mut model, 16, 16); + + let drained = decode_buf.drain_to_window_size().unwrap(); + let model_drained = model_drain_to_window(&mut model, window); + assert_eq!(drained, model_drained); + + let cases = [(3usize, 97usize), (16usize, 64usize), (7usize, 73usize)]; + for (offset, match_len) in cases { + decode_buf.repeat(offset, match_len).unwrap(); + model_repeat(&mut model, offset, match_len); + + if let Some(got) = decode_buf.drain_to_window_size() { + let expected = model_drain_to_window(&mut model, window); + assert_eq!(got, expected, "offset={offset}, match_len={match_len}"); + } + } + + assert_eq!(decode_buf.drain(), model); + } + + fn expected_match_expansion(seed: &[u8], offset: usize, match_len: usize) -> Vec<u8> { + let mut out = seed.to_vec(); + let start = out.len() - offset; + for i in 0..match_len { + let byte = out[start + i]; + out.push(byte); + } + out + } + + fn model_push(model: &mut Vec<u8>, bytes: &[u8]) { + model.extend_from_slice(bytes); + } + + fn model_repeat(model: &mut Vec<u8>, offset: usize, match_len: usize) { + let start = model.len() - offset; + for i in
0..match_len { + let byte = model[start + i]; + model.push(byte); + } + } + + fn model_drain_to_window(model: &mut Vec<u8>, window: usize) -> Vec<u8> { + if model.len() <= window { + return Vec::new(); + } + let drain_len = model.len() - window; + model.drain(0..drain_len).collect() + } } diff --git a/zstd/src/decoding/errors.rs b/zstd/src/decoding/errors.rs index 9b1c6bb0..466ffe1a 100644 --- a/zstd/src/decoding/errors.rs +++ b/zstd/src/decoding/errors.rs @@ -396,6 +396,7 @@ impl From for DecodeBlockContentError { pub enum DecodeBufferError { NotEnoughBytesInDictionary { got: usize, need: usize }, OffsetTooBig { offset: usize, buf_len: usize }, + ZeroOffset, } #[cfg(feature = "std")] @@ -413,6 +414,9 @@ impl core::fmt::Display for DecodeBufferError { DecodeBufferError::OffsetTooBig { offset, buf_len } => { write!(f, "offset: {offset} bigger than buffer: {buf_len}",) } + DecodeBufferError::ZeroOffset => { + write!(f, "Illegal offset: 0 found") + } } } } diff --git a/zstd/src/decoding/mod.rs b/zstd/src/decoding/mod.rs index f3e5323f..e0ef937b 100644 --- a/zstd/src/decoding/mod.rs +++ b/zstd/src/decoding/mod.rs @@ -13,6 +13,7 @@ pub(crate) mod decode_buffer; pub(crate) mod dictionary; pub(crate) mod frame; pub(crate) mod literals_section_decoder; +pub(crate) mod prefetch; mod ringbuffer; #[allow(dead_code)] pub(crate) mod scratch; diff --git a/zstd/src/decoding/prefetch.rs b/zstd/src/decoding/prefetch.rs new file mode 100644 index 00000000..13542436 --- /dev/null +++ b/zstd/src/decoding/prefetch.rs @@ -0,0 +1,31 @@ +#[inline(always)] +pub(crate) fn prefetch_slice(slice: &[u8]) { + prefetch_slice_impl(slice); +} + +#[cfg(target_arch = "x86_64")] +#[inline(always)] +fn prefetch_slice_impl(slice: &[u8]) { + use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch}; + + if !slice.is_empty() { + unsafe { _mm_prefetch(slice.as_ptr().cast(), _MM_HINT_T0) }; + } +} + +#[cfg(all(target_arch = "x86", target_feature = "sse"))] +#[inline(always)] +fn prefetch_slice_impl(slice: &[u8]) { + use
core::arch::x86::{_MM_HINT_T0, _mm_prefetch}; + + if !slice.is_empty() { + unsafe { _mm_prefetch(slice.as_ptr().cast(), _MM_HINT_T0) }; + } +} + +#[cfg(not(any( + target_arch = "x86_64", + all(target_arch = "x86", target_feature = "sse"), +)))] +#[inline(always)] +fn prefetch_slice_impl(_slice: &[u8]) {} diff --git a/zstd/src/decoding/sequence_execution.rs b/zstd/src/decoding/sequence_execution.rs index 108ae89e..afaf9b87 100644 --- a/zstd/src/decoding/sequence_execution.rs +++ b/zstd/src/decoding/sequence_execution.rs @@ -1,3 +1,4 @@ +use super::prefetch; use super::scratch::DecoderScratch; use crate::decoding::errors::ExecuteSequencesError; @@ -19,6 +20,7 @@ pub fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), ExecuteSequ }); } let literals = &scratch.literals_buffer[literals_copy_counter..high]; + prefetch_literals(literals); literals_copy_counter += seq.ll as usize; scratch.buffer.push(literals); @@ -113,3 +115,8 @@ fn do_offset_history(offset_value: u32, lit_len: u32, scratch: &mut [u32; 3]) -> actual_offset } + +#[inline(always)] +fn prefetch_literals(slice: &[u8]) { + prefetch::prefetch_slice(slice); +}