diff --git a/compress/compress.go b/compress/compress.go index 72d7ae08..1f7bf402 100644 --- a/compress/compress.go +++ b/compress/compress.go @@ -21,17 +21,6 @@ func (c CompressionAlgorithm) String() string { } } -func NewCompressionAlgorithm(s string) (CompressionAlgorithm, error) { - switch s { - case "zstd": - return CompressionAlgoZstd, nil - case "zstd-cgo": - return CompressionAlgoZstdCgo, nil - default: - return 0, fmt.Errorf("unknown compression algorithm: %s", s) - } -} - // CompressionLevel is the interface that wraps the compression level method. type CompressionLevel int @@ -50,18 +39,34 @@ func (c CompressionLevel) String() string { } } -func NewCompressionLevel(s string) (CompressionLevel, error) { - switch s { - case "fastest": - return CompressionLevelZstdFastest, nil - case "default": - return CompressionLevelZstdDefault, nil - case "better": - return CompressionLevelZstdBetter, nil - case "best": - return CompressionLevelZstdBest, nil +func NewSettings(algo, level string) (CompressionAlgorithm, CompressionLevel, error) { + switch algo { + case "zstd": + switch level { + case "fastest": + return CompressionAlgoZstd, CompressionLevelZstdFastest, nil + case "default": + return CompressionAlgoZstd, CompressionLevelZstdDefault, nil + case "better": + return CompressionAlgoZstd, CompressionLevelZstdBetter, nil + case "best": + return CompressionAlgoZstd, CompressionLevelZstdBest, nil + default: + return 0, 0, fmt.Errorf("unknown compression level for %s: %s", algo, level) + } + case "zstd-cgo": + switch level { + case "fastest": + return CompressionAlgoZstdCgo, CompressionLevelZstdCgoFastest, nil + case "default": + return CompressionAlgoZstdCgo, CompressionLevelZstdCgoDefault, nil + case "best": + return CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest, nil + default: + return 0, 0, fmt.Errorf("unknown compression level for %s: %s", algo, level) + } default: - return 0, fmt.Errorf("unknown compression level: %s", s) + return 0, 0, fmt.Errorf("unknown compression algorithm: %s", algo) } } @@ -79,17 +84,14 @@ var ( ) func New(algo CompressionAlgorithm, level CompressionLevel) (*Compressor, error) { + var err error + algo, level, err = NewSettings(algo.String(), level.String()) + if err != nil { + return nil, err + } + switch algo { case CompressionAlgoZstd: - switch level { - case CompressionLevelZstdFastest, - CompressionLevelZstdDefault, - CompressionLevelZstdBetter, - CompressionLevelZstdBest: - default: - return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level) - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) if err != nil { return nil, fmt.Errorf("cannot create zstd encoder: %w", err) @@ -105,20 +107,8 @@ func New(algo CompressionAlgorithm, level CompressionLevel) (*Compressor, error) decoder: decoder, }}, nil case CompressionAlgoZstdCgo: - var cgoLevel int - switch level { - case CompressionLevelZstdCgoFastest: - cgoLevel = zstdcgo.BestSpeed - case CompressionLevelZstdCgoDefault: - cgoLevel = zstdcgo.DefaultCompression - case CompressionLevelZstdCgoBest: - cgoLevel = zstdcgo.BestCompression - default: - return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level) - } - return &Compressor{ - compressorZstdCgo: &compressorZstdCgo{level: cgoLevel}, + compressorZstdCgo: &compressorZstdCgo{level: int(level)}, }, nil default: return nil, fmt.Errorf("unknown compression algorithm: %d", algo) diff --git a/compress/compress_test.go b/compress/compress_test.go index de3706f1..9be81f44 100644 --- a/compress/compress_test.go +++ b/compress/compress_test.go @@ -46,21 +46,39 @@ func TestCompress(t *testing.T) { } func TestSerialization(t *testing.T) { - algo, err := NewCompressionAlgorithm("zstd") - require.NoError(t, err) - require.Equal(t, CompressionAlgoZstd, algo) + type testCase struct { + algo, level string + expectedSerialized string + expectedAlgo CompressionAlgorithm + expectedLevel CompressionLevel + } + testCases := []testCase{ + {"zstd", "fastest", "1:1", CompressionAlgoZstd, CompressionLevelZstdFastest}, + {"zstd", "default", "1:2", CompressionAlgoZstd, CompressionLevelZstdDefault}, + {"zstd", "better", "1:3", CompressionAlgoZstd, CompressionLevelZstdBetter}, + {"zstd", "best", "1:4", CompressionAlgoZstd, CompressionLevelZstdBest}, + + {"zstd-cgo", "fastest", "2:1", CompressionAlgoZstdCgo, CompressionLevelZstdCgoFastest}, + {"zstd-cgo", "default", "2:5", CompressionAlgoZstdCgo, CompressionLevelZstdCgoDefault}, + {"zstd-cgo", "best", "2:20", CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest}, + } - level, err := NewCompressionLevel("best") - require.NoError(t, err) - require.Equal(t, CompressionLevelZstdBest, level) + for _, tc := range testCases { + t.Run(tc.algo+"-"+tc.level, func(t *testing.T) { + algo, level, err := NewSettings(tc.algo, tc.level) + require.NoError(t, err) + require.Equal(t, tc.expectedAlgo, algo) + require.Equal(t, tc.expectedLevel, level) - serialized := SerializeSettings(algo, level) - require.Equal(t, "1:4", serialized) + serialized := SerializeSettings(algo, level) + require.Equal(t, tc.expectedSerialized, serialized) - algo, level, err = DeserializeSettings(serialized) - require.NoError(t, err) - require.Equal(t, CompressionAlgoZstd, algo) - require.Equal(t, CompressionLevelZstdBest, level) + algo, level, err = DeserializeSettings(serialized) + require.NoError(t, err) + require.Equal(t, tc.expectedAlgo, algo) + require.Equal(t, tc.expectedLevel, level) + }) + } } func TestDeserializationError(t *testing.T) {