From c9e413163323e4138f4e8f0e545220ccc3aedfb9 Mon Sep 17 00:00:00 2001 From: Liz Fong-Jones Date: Thu, 30 Sep 2021 16:37:43 -0700 Subject: [PATCH 1/2] pass level param through --- compress.go | 2 +- zstd.go | 25 ++++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/compress.go b/compress.go index 12cd7c3d5..8436dc20e 100644 --- a/compress.go +++ b/compress.go @@ -187,7 +187,7 @@ func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) { } return buf.Bytes(), nil case CompressionZSTD: - return zstdCompress(nil, data) + return zstdCompress(level, nil, data) default: return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)} } diff --git a/zstd.go b/zstd.go index e23bfc477..336564ac0 100644 --- a/zstd.go +++ b/zstd.go @@ -1,18 +1,37 @@ package sarama import ( + "sync" + "github.com/klauspost/compress/zstd" ) +var zstdEncMap sync.Map + var ( zstdDec, _ = zstd.NewReader(nil) - zstdEnc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true)) ) +func getEncoder(level int) *zstd.Encoder { + if ret, ok := zstdEncMap.Load(level); ok { + return ret.(*zstd.Encoder) + } + // It's possible to race and create multiple new writers. + // Only one will survive GC after use. + encoderLevel := zstd.SpeedDefault + if level != CompressionLevelDefault { + encoderLevel = zstd.EncoderLevelFromZstd(level) + } + zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(encoderLevel)) + zstdEncMap.Store(level, zstdEnc) + return zstdEnc +} + func zstdDecompress(dst, src []byte) ([]byte, error) { return zstdDec.DecodeAll(src, dst) } -func zstdCompress(dst, src []byte) ([]byte, error) { - return zstdEnc.EncodeAll(src, dst), nil +func zstdCompress(level int, dst, src []byte) ([]byte, error) { + return getEncoder(level).EncodeAll(src, dst), nil } From 60e10ba6f76c89375231a9780392c0feb3660b23 Mon Sep 17 00:00:00 2001 From: Liz Fong-Jones Date: Thu, 14 Oct 2021 20:36:51 -0700 Subject: [PATCH 2/2] respond to review feedback --- compress.go | 2 +- decompress.go | 2 +- zstd.go | 39 ++++++++++++++++++++++++++------------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/compress.go b/compress.go index 8436dc20e..3439247ab 100644 --- a/compress.go +++ b/compress.go @@ -187,7 +187,7 @@ func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) { } return buf.Bytes(), nil case CompressionZSTD: - return zstdCompress(level, nil, data) + return zstdCompress(ZstdEncoderParams{level}, nil, data) default: return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)} } diff --git a/decompress.go b/decompress.go index 5565e36cf..953147d01 100644 --- a/decompress.go +++ b/decompress.go @@ -54,7 +54,7 @@ func decompress(cc CompressionCodec, data []byte) ([]byte, error) { return io.ReadAll(reader) case CompressionZSTD: - return zstdDecompress(nil, data) + return zstdDecompress(ZstdDecoderParams{}, nil, data) default: return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)} } diff --git a/zstd.go b/zstd.go index 336564ac0..80507e14e 100644 --- a/zstd.go +++ b/zstd.go @@ -6,32 +6,45 @@ import ( "github.com/klauspost/compress/zstd" ) -var zstdEncMap sync.Map +type ZstdEncoderParams struct { + Level int +} +type ZstdDecoderParams struct { +} -var ( - zstdDec, _ = zstd.NewReader(nil) -) +var zstdEncMap, zstdDecMap sync.Map -func getEncoder(level int) *zstd.Encoder { - if ret, ok := zstdEncMap.Load(level); ok { +func getEncoder(params ZstdEncoderParams) *zstd.Encoder { + if ret, ok := zstdEncMap.Load(params); ok { return ret.(*zstd.Encoder) } // It's possible to race and create multiple new writers. // Only one will survive GC after use. encoderLevel := zstd.SpeedDefault - if level != CompressionLevelDefault { - encoderLevel = zstd.EncoderLevelFromZstd(level) + if params.Level != CompressionLevelDefault { + encoderLevel = zstd.EncoderLevelFromZstd(params.Level) } zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel(encoderLevel)) - zstdEncMap.Store(level, zstdEnc) + zstdEncMap.Store(params, zstdEnc) return zstdEnc } -func zstdDecompress(dst, src []byte) ([]byte, error) { - return zstdDec.DecodeAll(src, dst) +func getDecoder(params ZstdDecoderParams) *zstd.Decoder { + if ret, ok := zstdDecMap.Load(params); ok { + return ret.(*zstd.Decoder) + } + // It's possible to race and create multiple new readers. + // Only one will survive GC after use. + zstdDec, _ := zstd.NewReader(nil) + zstdDecMap.Store(params, zstdDec) + return zstdDec +} + +func zstdDecompress(params ZstdDecoderParams, dst, src []byte) ([]byte, error) { + return getDecoder(params).DecodeAll(src, dst) } -func zstdCompress(level int, dst, src []byte) ([]byte, error) { - return getEncoder(level).EncodeAll(src, dst), nil +func zstdCompress(params ZstdEncoderParams, dst, src []byte) ([]byte, error) { + return getEncoder(params).EncodeAll(src, dst), nil }