diff --git a/zstd_stream.go b/zstd_stream.go index 0e2a76d..28dfd3e 100644 --- a/zstd_stream.go +++ b/zstd_stream.go @@ -8,7 +8,6 @@ package zstd */ import "C" import ( - "bytes" "fmt" "io" "unsafe" @@ -147,15 +146,14 @@ func (w *Writer) Close() error { type reader struct { ctx *C.ZBUFF_DCtx compressionBuffer []byte + compressionLeft int decompressionBuffer []byte + decompOff int + decompSize int dict []byte firstError error - // Reuse previous bytes from source that were not consumed - // Hopefully because we use recommended size, we will never need to use that - srcBuffer bytes.Buffer - dstBuffer bytes.Buffer - recommendedSrcSize int - underlyingReader io.Reader + recommendedSrcSize int + underlyingReader io.Reader } // NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser @@ -209,24 +207,28 @@ func (r *reader) Close() error { func (r *reader) Read(p []byte) (int, error) { // If we already have enough bytes, return - if r.dstBuffer.Len() >= len(p) { - return r.dstBuffer.Read(p) + if r.decompSize-r.decompOff >= len(p) { + copy(p, r.decompressionBuffer[r.decompOff:]) + r.decompOff += len(p) + return len(p), nil } - for r.dstBuffer.Len() < len(p) { + copy(p, r.decompressionBuffer[r.decompOff:r.decompSize]) + got := r.decompSize - r.decompOff + r.decompSize = 0 + r.decompOff = 0 + + for got < len(p) { // Populate src src := r.compressionBuffer reader := r.underlyingReader - if r.srcBuffer.Len() != 0 { - reader = io.MultiReader(&r.srcBuffer, r.underlyingReader) - } - n, err := io.ReadFull(reader, src) - if err == io.EOF { - break - } else if err != nil && err != io.ErrUnexpectedEOF { + n, err := io.ReadFull(reader, src[r.compressionLeft:]) + if err == io.EOF && r.compressionLeft == 0 { + return got, io.EOF + } else if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { return 0, fmt.Errorf("failed to read from underlying reader: %s", err) } - src = src[:n] + src = src[:r.compressionLeft+n] // C code cSrcSize := C.size_t(len(src)) @@ -243,25 +245,25 @@ func (r *reader) Read(p []byte) (int, error) { } // Put everything in buffer - if int(cSrcSize) < len(src) { // We did not read everything, put in buffer - toSave := src[int(cSrcSize):] - _, err = r.srcBuffer.Write(toSave) - if err != nil { - return 0, fmt.Errorf("failed to store temporary src buffer: %s", err) - } - } - _, err = r.dstBuffer.Write(r.decompressionBuffer[:int(cDstSize)]) - if err != nil { - return 0, fmt.Errorf("failed to store temporary result: %s", err) + if int(cSrcSize) < len(src) { + left := src[int(cSrcSize):] + copy(r.compressionBuffer, left) } + r.compressionLeft = len(src) - int(cSrcSize) + r.decompSize = int(cDstSize) + r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize]) + got += r.decompOff // Resize buffers - if retCode > 0 { // Hint for next src buffer size - r.compressionBuffer = resize(r.compressionBuffer, retCode) - } else { // Reset to recommended size - r.compressionBuffer = resize(r.compressionBuffer, r.recommendedSrcSize) + nsize := retCode // Hint for next src buffer size + if nsize <= 0 { + // Reset to recommended size + nsize = r.recommendedSrcSize + } + if nsize < r.compressionLeft { + nsize = r.compressionLeft } + r.compressionBuffer = resize(r.compressionBuffer, nsize) } - // Write to dst - return r.dstBuffer.Read(p) + return got, nil }