Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 216 additions & 22 deletions zstd/src/decoding/decode_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use alloc::vec::Vec;
#[cfg(feature = "hash")]
use core::hash::Hasher;

use super::prefetch;
use super::ringbuffer::RingBuffer;
use crate::decoding::errors::DecodeBufferError;

Expand Down Expand Up @@ -65,6 +66,14 @@ impl DecodeBuffer {
}

pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
if offset == 0 {
return Err(DecodeBufferError::ZeroOffset);
}

if match_length == 0 {
return Ok(());
}
Comment thread
polaz marked this conversation as resolved.

if offset > self.buffer.len() {
self.repeat_from_dict(offset, match_length)
Comment thread
polaz marked this conversation as resolved.
} else {
Expand All @@ -73,9 +82,9 @@ impl DecodeBuffer {
let end_idx = start_idx + match_length;

self.buffer.reserve(match_length);
self.prefetch_match_source(start_idx);
if end_idx > buf_len {
// We need to copy in chunks.
self.repeat_in_chunks(offset, match_length, start_idx);
self.repeat_overlapping(offset, match_length, start_idx);
} else {
// can just copy parts of the existing buffer
// SAFETY: Requirements checked:
Expand All @@ -88,8 +97,13 @@ impl DecodeBuffer {
//
// 2. explicitly reserved enough memory for the whole match_length
unsafe {
self.buffer
.extend_from_within_unchecked(start_idx, match_length)
if offset >= 16 && use_branchless_wildcopy() {
self.buffer
.extend_from_within_unchecked_branchless(start_idx, match_length);
} else {
self.buffer
.extend_from_within_unchecked(start_idx, match_length);
}
};
}

Expand All @@ -98,36 +112,102 @@ impl DecodeBuffer {
}
}

fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
// We have at max offset bytes in one chunk, the last one can be smaller
/// Dispatch an overlapping match copy to the fastest strategy for `offset`.
///
/// Offsets of 16 and above may use the branchless wildcopy where the target
/// opts in (see `use_branchless_wildcopy`); offsets in `8..16` take the plain
/// chunked copy; anything shorter goes through the precomputed-pattern
/// short-offset path.
#[inline(always)]
fn repeat_overlapping(&mut self, offset: usize, match_length: usize, start_idx: usize) {
    match offset {
        16.. => self.repeat_in_chunks(offset, match_length, start_idx, use_branchless_wildcopy()),
        8..=15 => self.repeat_in_chunks(offset, match_length, start_idx, false),
        _ => self.repeat_short_offset(offset, match_length, start_idx),
    }
}

#[inline(always)]
fn repeat_in_chunks(
&mut self,
offset: usize,
match_length: usize,
start_idx: usize,
use_branchless_copy: bool,
) {
let mut start_idx = start_idx;
let mut copied_counter_left = match_length;
// TODO this can be optimized further I think.
// Each time we copy a chunk we have a repetiton of length 'offset', so we can copy offset * iteration many bytes from start_idx
while copied_counter_left > 0 {
let chunksize = usize::min(offset, copied_counter_left);

// SAFETY: Requirements checked:
// 1. start_idx + chunksize must be <= self.buffer.len()
// We know that:
// 1. start_idx starts at buffer.len() - offset
// 2. chunksize <= offset (== offset for each iteration but the last, and match_length modulo offset in the last iteration)
// 3. the buffer grows by offset many bytes each iteration but the last
// 4. start_idx is increased by the same amount as the buffer grows each iteration
//
// Thus follows: start_idx + chunksize == self.buffer.len() in each iteration but the last, where match_length modulo offset == chunksize < offset
// Meaning: start_idx + chunksize <= self.buffer.len()
//
// 2. explicitly reserved enough memory for the whole match_length
// SAFETY: chunksize <= offset keeps each single copy in the currently readable
// source range, and repeat() reserved enough destination capacity.
unsafe {
self.buffer
.extend_from_within_unchecked(start_idx, chunksize)
if use_branchless_copy {
self.buffer
.extend_from_within_unchecked_branchless(start_idx, chunksize);
} else {
self.buffer
.extend_from_within_unchecked(start_idx, chunksize);
}
};
copied_counter_left -= chunksize;
start_idx += chunksize;
}
}

/// Expand an overlapping match whose `offset` is shorter than 8 bytes.
///
/// Rather than copying byte-by-byte, this reads the `offset`-long repeating
/// pattern once, precomputes every rotation ("phase") of it as an 8-byte
/// block, and then appends whole 8-byte blocks. After each block the phase
/// advances by `8 % offset`, keeping the output aligned with the infinite
/// repetition of the pattern. Callers guarantee `1 <= offset < 8`.
#[inline(always)]
fn repeat_short_offset(&mut self, offset: usize, match_length: usize, start_idx: usize) {
    debug_assert!(
        offset > 0,
        "offset must be non-zero to avoid modulo by zero in short-offset path"
    );
    // Read the `offset` pattern bytes out of the ring buffer; byte_at handles
    // indices that fall into either of its two physical slices.
    let mut base = [0u8; 8];
    for (i, slot) in base.iter_mut().take(offset).enumerate() {
        *slot = self.byte_at(start_idx + i);
    }

    // phase_patterns[p][i] == base[(p + i) % offset]: the pattern as it looks
    // starting from position p. Only indices 0..offset are used, and
    // offset <= 7 here, hence exactly 7 entries.
    let mut phase_patterns = [[0u8; 8]; 7];
    for phase in 0..offset {
        for i in 0..8 {
            phase_patterns[phase][i] = base[(phase + i) % offset];
        }
    }

    // Emit full 8-byte blocks; each one advances the phase by 8 % offset.
    let phase_step = 8 % offset;
    let mut phase = 0usize;
    let mut copied = 0usize;
    while copied + 8 <= match_length {
        self.buffer.extend(&phase_patterns[phase]);
        copied += 8;
        phase = (phase + phase_step) % offset;
    }

    // Tail: fewer than 8 bytes remain; take a prefix of the current phase.
    if copied < match_length {
        let tail = match_length - copied;
        self.buffer.extend(&phase_patterns[phase][..tail]);
    }
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/// Read the byte at logical index `idx` of the decode buffer.
///
/// The ring buffer exposes its contents as two physical slices; an index past
/// the end of the first slice continues into the second. Panics (via slice
/// indexing) if `idx` is out of range of both, same as direct indexing would.
#[inline(always)]
fn byte_at(&self, idx: usize) -> u8 {
    let (head, tail) = self.buffer.as_slices();
    match idx.checked_sub(head.len()) {
        None => head[idx],
        Some(wrapped) => tail[wrapped],
    }
}

/// Hint the CPU to pull the match source bytes into cache before copying.
///
/// `start_idx` is a logical index into the ring buffer and is mapped onto
/// whichever of the two physical slices contains it. An index past the end of
/// both slices is silently ignored — this is only a prefetch hint, never an
/// access.
#[inline(always)]
fn prefetch_match_source(&self, start_idx: usize) {
    let (head, tail) = self.buffer.as_slices();
    match start_idx.checked_sub(head.len()) {
        None => prefetch::prefetch_slice(&head[start_idx..]),
        Some(rel) if rel < tail.len() => prefetch::prefetch_slice(&tail[rel..]),
        Some(_) => {}
    }
}

#[cold]
fn repeat_from_dict(
&mut self,
Expand All @@ -147,6 +227,7 @@ impl DecodeBuffer {

if bytes_from_dict < match_length {
let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
prefetch::prefetch_slice(dict_slice);
self.buffer.extend(dict_slice);

self.total_output_counter += bytes_from_dict as u64;
Expand All @@ -155,7 +236,9 @@ impl DecodeBuffer {
let low = self.dict_content.len() - bytes_from_dict;
let high = low + match_length;
let dict_slice = &self.dict_content[low..high];
prefetch::prefetch_slice(dict_slice);
self.buffer.extend(dict_slice);
self.total_output_counter += match_length as u64;
}
Ok(())
} else {
Expand Down Expand Up @@ -315,6 +398,11 @@ fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error
(written, Ok(()))
}

/// Whether the branchless wildcopy variant should be used on this target.
///
/// Returns `true` only on x86-family targets; other architectures fall back
/// to the plain chunked copy.
#[inline(always)]
fn use_branchless_wildcopy() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        true
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        false
    }
}

#[cfg(test)]
mod tests {
use super::DecodeBuffer;
Expand Down Expand Up @@ -448,4 +536,110 @@ mod tests {
}
assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
}

#[test]
fn repeat_overlap_fast_paths_match_reference_behavior() {
    // 32-byte seed, long enough that a 16-byte offset is always in range.
    let seed = b"0123456789abcdef0123456789abcdef";
    // (offset, match_length) pairs chosen to hit every dispatch path in
    // repeat_overlapping: offset >= 16 (branchless-capable), 8..16, and < 8,
    // including the degenerate offset-1 run and a non-overlapping boundary.
    let cases = [
        (16usize, 16usize), // non-overlapping boundary
        (16usize, 211usize),
        (8usize, 173usize),
        (7usize, 149usize),
        (3usize, 160usize),
        (1usize, 255usize),
    ];

    for (offset, match_len) in cases {
        let mut decode_buf = DecodeBuffer::new(4 * 1024);
        decode_buf.push(seed);
        decode_buf.repeat(offset, match_len).unwrap();
        let got = decode_buf.drain();
        // Compare against the naive byte-at-a-time reference expansion.
        let expected = expected_match_expansion(seed, offset, match_len);
        assert_eq!(got, expected, "offset={offset}, match_len={match_len}");
    }
}
Comment thread
polaz marked this conversation as resolved.

#[test]
fn repeat_zero_offset_returns_error() {
    // A zero offset must be rejected by repeat() up front; otherwise the
    // short-offset path would compute `8 % offset` with offset == 0.
    let mut decode_buf = DecodeBuffer::new(1024);
    decode_buf.push(b"abcdef");
    let err = decode_buf.repeat(0, 5).unwrap_err();
    assert!(matches!(
        err,
        crate::decoding::errors::DecodeBufferError::ZeroOffset
    ));
}

#[test]
fn repeat_from_dict_full_copy_updates_total_output_counter() {
    let mut decode_buf = DecodeBuffer::new(1);
    decode_buf.dict_content = b"0123456789".to_vec();

    // Buffer is empty, so offset 10 is served entirely from the dictionary.
    decode_buf.repeat(10, 2).unwrap();
    // NOTE(review): the second repeat is expected to fail with OffsetTooBig;
    // the exact condition lives in repeat_from_dict (not fully visible here) —
    // confirm the test name ("updates_total_output_counter") still matches
    // what is actually asserted.
    let err = decode_buf.repeat(10, 1).unwrap_err();
    assert!(matches!(
        err,
        crate::decoding::errors::DecodeBufferError::OffsetTooBig { .. }
    ));
}

#[test]
fn repeat_overlap_fast_paths_match_reference_behavior_with_wrapped_ringbuffer() {
    // A small window forces the ring buffer to wrap, so match sources can
    // straddle the boundary between its two physical slices.
    let window = 32usize;
    let seed = b"0123456789abcdef0123456789abcdef";
    let mut decode_buf = DecodeBuffer::new(window);
    // `model` is a plain Vec mirroring every operation as the reference.
    let mut model = Vec::new();

    decode_buf.push(seed);
    model_push(&mut model, seed);
    decode_buf.repeat(16, 16).unwrap();
    model_repeat(&mut model, 16, 16);

    // Draining down to the window size is what triggers the wrap-around.
    let drained = decode_buf.drain_to_window_size().unwrap();
    let model_drained = model_drain_to_window(&mut model, window);
    assert_eq!(drained, model_drained);

    // Exercise each dispatch path (< 8, >= 16, 8..16) on the wrapped buffer.
    let cases = [(3usize, 97usize), (16usize, 64usize), (7usize, 73usize)];
    for (offset, match_len) in cases {
        decode_buf.repeat(offset, match_len).unwrap();
        model_repeat(&mut model, offset, match_len);

        if let Some(got) = decode_buf.drain_to_window_size() {
            let expected = model_drain_to_window(&mut model, window);
            assert_eq!(got, expected, "offset={offset}, match_len={match_len}");
        }
    }

    // Whatever remains in the buffer must also agree byte-for-byte.
    assert_eq!(decode_buf.drain(), model);
}

/// Reference expansion: append `match_len` bytes, each read `offset` bytes
/// behind the current end, byte by byte — the naive spec-level behavior the
/// fast paths must reproduce.
fn expected_match_expansion(seed: &[u8], offset: usize, match_len: usize) -> Vec<u8> {
    let mut out = seed.to_vec();
    let start = out.len() - offset;
    for i in start..start + match_len {
        let byte = out[i];
        out.push(byte);
    }
    out
}

/// Plain-Vec mirror of `DecodeBuffer::push`: append the given bytes.
fn model_push(model: &mut Vec<u8>, bytes: &[u8]) {
    model.extend(bytes.iter().copied());
}

/// Plain-Vec mirror of `DecodeBuffer::repeat`: re-append `match_len` bytes
/// starting `offset` bytes back from the end, one byte at a time so that
/// overlapping matches naturally repeat.
fn model_repeat(model: &mut Vec<u8>, offset: usize, match_len: usize) {
    let mut src = model.len() - offset;
    for _ in 0..match_len {
        let byte = model[src];
        model.push(byte);
        src += 1;
    }
}

/// Plain-Vec mirror of `DecodeBuffer::drain_to_window_size`: remove and
/// return every byte beyond the most recent `window` bytes, or an empty Vec
/// when nothing exceeds the window.
fn model_drain_to_window(model: &mut Vec<u8>, window: usize) -> Vec<u8> {
    match model.len().checked_sub(window) {
        None | Some(0) => Vec::new(),
        Some(excess) => model.drain(..excess).collect(),
    }
}
}
4 changes: 4 additions & 0 deletions zstd/src/decoding/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ impl From<DecompressBlockError> for DecodeBlockContentError {
pub enum DecodeBufferError {
NotEnoughBytesInDictionary { got: usize, need: usize },
OffsetTooBig { offset: usize, buf_len: usize },
ZeroOffset,
}

#[cfg(feature = "std")]
Expand All @@ -413,6 +414,9 @@ impl core::fmt::Display for DecodeBufferError {
DecodeBufferError::OffsetTooBig { offset, buf_len } => {
write!(f, "offset: {offset} bigger than buffer: {buf_len}",)
}
DecodeBufferError::ZeroOffset => {
write!(f, "Illegal offset: 0 found")
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions zstd/src/decoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(crate) mod decode_buffer;
pub(crate) mod dictionary;
pub(crate) mod frame;
pub(crate) mod literals_section_decoder;
pub(crate) mod prefetch;
mod ringbuffer;
#[allow(dead_code)]
pub(crate) mod scratch;
Expand Down
31 changes: 31 additions & 0 deletions zstd/src/decoding/prefetch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/// Best-effort hint to pull the start of `slice` into cache before it is read.
///
/// Dispatches to a per-architecture implementation; on targets without a
/// prefetch intrinsic this compiles to a no-op, so it is always safe to call.
#[inline(always)]
pub(crate) fn prefetch_slice(slice: &[u8]) {
    prefetch_slice_impl(slice);
}

/// x86_64 variant: the prefetch intrinsic is unconditionally available.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn prefetch_slice_impl(slice: &[u8]) {
    use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};

    if let Some(first) = slice.first() {
        // SAFETY: prefetch is only a cache hint; the pointer refers to a
        // live byte of `slice`.
        unsafe { _mm_prefetch((first as *const u8).cast(), _MM_HINT_T0) };
    }
}

// 32-bit x86 variant: _mm_prefetch requires SSE, so only compile this when
// the target guarantees the feature.
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
#[inline(always)]
fn prefetch_slice_impl(slice: &[u8]) {
    use core::arch::x86::{_MM_HINT_T0, _mm_prefetch};

    if !slice.is_empty() {
        // SAFETY: prefetch is only a cache hint on a pointer into `slice`.
        unsafe { _mm_prefetch(slice.as_ptr().cast(), _MM_HINT_T0) };
    }
}

// Fallback for every other target: prefetching is a pure optimization, so
// doing nothing is always correct.
#[cfg(not(any(
    target_arch = "x86_64",
    all(target_arch = "x86", target_feature = "sse"),
)))]
#[inline(always)]
fn prefetch_slice_impl(_slice: &[u8]) {}
7 changes: 7 additions & 0 deletions zstd/src/decoding/sequence_execution.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::prefetch;
use super::scratch::DecoderScratch;
use crate::decoding::errors::ExecuteSequencesError;

Expand All @@ -19,6 +20,7 @@ pub fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), ExecuteSequ
});
}
let literals = &scratch.literals_buffer[literals_copy_counter..high];
prefetch_literals(literals);
literals_copy_counter += seq.ll as usize;

scratch.buffer.push(literals);
Expand Down Expand Up @@ -113,3 +115,8 @@ fn do_offset_history(offset_value: u32, lit_len: u32, scratch: &mut [u32; 3]) ->

actual_offset
}

/// Hint the CPU to pull the literal bytes into cache before they are copied
/// into the decode buffer by the sequence-execution loop.
#[inline(always)]
fn prefetch_literals(slice: &[u8]) {
    prefetch::prefetch_slice(slice);
}
Loading