Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions zstd_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"io"
"runtime"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -146,6 +147,48 @@ func (w *Writer) Close() error {
return nil
}

// cSize is the recommended size of reader.compressionBuffer, as reported by
// the zstd library. Evaluating it once at package init means an invalid
// value is caught immediately (with a panic) rather than on first use.
var cSize = func() int {
	size := int(C.ZBUFF_recommendedDInSize())
	if size <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", size))
	}
	return size
}()

// dSize is the recommended size of reader.decompressionBuffer, as reported
// by the zstd library. Evaluating it once at package init means an invalid
// value is caught immediately (with a panic) rather than on first use.
var dSize = func() int {
	size := int(C.ZBUFF_recommendedDOutSize())
	if size <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", size))
	}
	return size
}()

// cPool holds reusable buffers for reader.compressionBuffer. A buffer is
// taken in NewReaderDict and handed back in reader.Close. The pool stores
// *[]byte rather than []byte so that boxing the value into an interface{}
// does not allocate a fresh slice header each time.
var cPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, cSize)
		return &b
	},
}

// dPool holds reusable buffers for reader.decompressionBuffer. A buffer is
// taken in NewReaderDict and handed back in reader.Close. The pool stores
// *[]byte rather than []byte so that boxing the value into an interface{}
// does not allocate a fresh slice header each time.
var dPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, dSize)
		return &b
	},
}

// reader is an io.ReadCloser that decompresses when read from.
type reader struct {
ctx *C.ZBUFF_DCtx
Expand Down Expand Up @@ -181,22 +224,13 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
unsafe.Pointer(&dict[0]),
C.size_t(len(dict)))))
}
cSize := int(C.ZBUFF_recommendedDInSize())
dSize := int(C.ZBUFF_recommendedDOutSize())
if cSize <= 0 {
panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize))
}
if dSize <= 0 {
panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize))
}

compressionBuffer := make([]byte, cSize)
decompressionBuffer := make([]byte, dSize)
compressionBufferP := cPool.Get().(*[]byte)
decompressionBufferP := dPool.Get().(*[]byte)
return &reader{
ctx: ctx,
dict: dict,
compressionBuffer: compressionBuffer,
decompressionBuffer: decompressionBuffer,
compressionBuffer: *compressionBufferP,
decompressionBuffer: *decompressionBufferP,
firstError: err,
recommendedSrcSize: cSize,
underlyingReader: r,
Expand All @@ -205,6 +239,10 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {

// Close returns the reader's buffers to their pools and frees the allocated
// C decompression context. After Close the reader must not be read from.
//
// The buffer fields are nilled out before the buffers are returned so that
// a second Close (or a stray Read after Close) cannot hand the same backing
// array to the pool twice — which would let two readers share one buffer.
func (r *reader) Close() error {
	if cb := r.compressionBuffer; cb != nil {
		r.compressionBuffer = nil
		cPool.Put(&cb)
	}
	if db := r.decompressionBuffer; db != nil {
		r.decompressionBuffer = nil
		dPool.Put(&db)
	}
	return getError(int(C.ZBUFF_freeDCtx(r.ctx)))
}

Expand Down
10 changes: 10 additions & 0 deletions zstd_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func BenchmarkStreamDecompression(b *testing.B) {
if err != nil {
b.Fatalf("Failed to decompress: %s", err)
}
r.Close()
}
}

Expand All @@ -208,3 +209,12 @@ func TestUnexpectedEOFHandling(t *testing.T) {
t.Error("Underlying error was handled silently")
}
}

// TestStreamCompressionDecompressionParallel runs the round-trip stream test
// many times concurrently, exercising the buffer pools under contention.
func TestStreamCompressionDecompressionParallel(t *testing.T) {
	const iterations = 200
	for i := 0; i < iterations; i++ {
		t.Run("", func(t *testing.T) {
			t.Parallel()
			TestStreamCompressionDecompression(t)
		})
	}
}