diff --git a/compress.go b/compress.go index 12cd7c3d5..3439247ab 100644 --- a/compress.go +++ b/compress.go @@ -187,7 +187,7 @@ func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) { } return buf.Bytes(), nil case CompressionZSTD: - return zstdCompress(nil, data) + return zstdCompress(ZstdEncoderParams{level}, nil, data) default: return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)} } diff --git a/decompress.go b/decompress.go index 5565e36cf..953147d01 100644 --- a/decompress.go +++ b/decompress.go @@ -54,7 +54,7 @@ func decompress(cc CompressionCodec, data []byte) ([]byte, error) { return io.ReadAll(reader) case CompressionZSTD: - return zstdDecompress(nil, data) + return zstdDecompress(ZstdDecoderParams{}, nil, data) default: return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)} } diff --git a/zstd.go b/zstd.go index e23bfc477..80507e14e 100644 --- a/zstd.go +++ b/zstd.go @@ -1,18 +1,50 @@ package sarama import ( + "sync" + "github.com/klauspost/compress/zstd" ) -var ( - zstdDec, _ = zstd.NewReader(nil) - zstdEnc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true)) -) +type ZstdEncoderParams struct { + Level int +} +type ZstdDecoderParams struct { +} + +var zstdEncMap, zstdDecMap sync.Map + +func getEncoder(params ZstdEncoderParams) *zstd.Encoder { + if ret, ok := zstdEncMap.Load(params); ok { + return ret.(*zstd.Encoder) + } + // It's possible to race and create multiple new writers. + // Only one will survive GC after use. + encoderLevel := zstd.SpeedDefault + if params.Level != CompressionLevelDefault { + encoderLevel = zstd.EncoderLevelFromZstd(params.Level) + } + zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(encoderLevel)) + zstdEncMap.Store(params, zstdEnc) + return zstdEnc +} + +func getDecoder(params ZstdDecoderParams) *zstd.Decoder { + if ret, ok := zstdDecMap.Load(params); ok { + return ret.(*zstd.Decoder) + } + // It's possible to race and create multiple new readers. + // Only one will survive GC after use. + zstdDec, _ := zstd.NewReader(nil) + zstdDecMap.Store(params, zstdDec) + return zstdDec +} -func zstdDecompress(dst, src []byte) ([]byte, error) { - return zstdDec.DecodeAll(src, dst) +func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) { + return getDecoder(params).DecodeAll(src, dst) } -func zstdCompress(dst, src []byte) ([]byte, error) { - return zstdEnc.EncodeAll(src, dst), nil +func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) { + return getEncoder(params).EncodeAll(src, dst), nil }