diff --git a/zstd/src/encoding/streaming_encoder.rs b/zstd/src/encoding/streaming_encoder.rs
index c140bc38..a60a5b6b 100644
--- a/zstd/src/encoding/streaming_encoder.rs
+++ b/zstd/src/encoding/streaming_encoder.rs
@@ -26,6 +26,7 @@ pub struct StreamingEncoder {
     compression_level: CompressionLevel,
     state: CompressState,
     pending: Vec<u8>,
+    encoded_scratch: Vec<u8>,
     errored: bool,
     last_error_kind: Option<ErrorKind>,
     last_error_message: Option<String>,
@@ -66,6 +67,7 @@ impl StreamingEncoder {
                 offset_hist: [1, 4, 8],
             },
             pending: Vec::new(),
+            encoded_scratch: Vec::new(),
             errored: false,
             last_error_kind: None,
             last_error_message: None,
@@ -338,8 +340,11 @@ impl StreamingEncoder {
         last_block: bool,
     ) -> Result<(), (Error, Vec<u8>)> {
         let mut raw_block = Some(uncompressed_data);
-        // TODO: reuse scratch buffer across blocks to reduce allocation churn (#47)
-        let mut encoded = Vec::with_capacity(self.block_capacity() + 3);
+        // Reuse the allocation left behind by the previous block instead of allocating anew.
+        let mut encoded = mem::take(&mut self.encoded_scratch);
+        encoded.clear();
+        // Block payload plus the 3-byte block header; a no-op once capacity suffices.
+        encoded.reserve(self.block_capacity() + 3);
         let mut moved_into_matcher = false;
         if raw_block.as_ref().is_some_and(|block| block.is_empty()) {
             let header = BlockHeader {
@@ -374,6 +379,8 @@ impl StreamingEncoder {
         }

         if let Err(err) = self.drain_mut().and_then(|drain| drain.write_all(&encoded)) {
+            encoded.clear();
+            mem::swap(&mut encoded, &mut self.encoded_scratch);
             let restored = if moved_into_matcher {
                 self.state.matcher.get_last_space().to_vec()
             } else {
@@ -390,6 +397,8 @@ impl StreamingEncoder {
         } else {
             self.hash_block(raw_block.as_deref().unwrap_or(&[]));
         }
+        encoded.clear();
+        mem::swap(&mut encoded, &mut self.encoded_scratch);
         Ok(())
     }

@@ -1009,6 +1018,45 @@ mod tests {
         assert_eq!(err.kind(), ErrorKind::InvalidInput);
     }

+    #[test]
+    fn encoded_scratch_capacity_is_reused_across_blocks() {
+        let payload = vec![0xAB; 64 * 3];
+        let mut encoder = StreamingEncoder::new_with_matcher(
+            TinyMatcher::new(64),
+            Vec::new(),
+            CompressionLevel::Uncompressed,
+        );
+
+        encoder.write_all(&payload[..64]).unwrap();
+        let first_capacity = encoder.encoded_scratch.capacity();
+        let first_ptr = encoder.encoded_scratch.as_ptr();
+        assert!(
+            first_capacity >= 67,
+            "expected encoded scratch to keep block header + payload capacity",
+        );
+
+        encoder.write_all(&payload[64..128]).unwrap();
+        let second_capacity = encoder.encoded_scratch.capacity();
+        assert!(
+            second_capacity >= first_capacity,
+            "encoded scratch capacity should be reused across block emits",
+        );
+        // Capacity alone would also hold for a fresh buffer per block; an
+        // unchanged data pointer shows the same allocation was kept.
+        assert_eq!(
+            encoder.encoded_scratch.as_ptr(),
+            first_ptr,
+            "encoded scratch allocation should be reused across block emits",
+        );
+
+        encoder.write_all(&payload[128..]).unwrap();
+        let compressed = encoder.finish().unwrap();
+        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
+        let mut decoded = Vec::new();
+        decoder.read_to_end(&mut decoded).unwrap();
+        assert_eq!(decoded, payload);
+    }
+
     #[test]
     fn pledged_content_size_after_write_returns_error() {
         let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);