diff --git a/zstd/blockenc.go b/zstd/blockenc.go
index 2cfe925ade..32a7f401d5 100644
--- a/zstd/blockenc.go
+++ b/zstd/blockenc.go
@@ -427,6 +427,18 @@ func (b *blockEnc) encodeLits(lits []byte, raw bool) error {
 	return nil
 }
 
+// encodeRLE appends an RLE block to b.output: a block header of type
+// blockTypeRLE whose size is length and whose last flag is taken from b.last,
+// followed by the single byte val.
+func (b *blockEnc) encodeRLE(val byte, length uint32) {
+	var bh blockHeader
+	bh.setLast(b.last)
+	bh.setSize(length)
+	bh.setType(blockTypeRLE)
+	b.output = bh.appendTo(b.output)
+	b.output = append(b.output, val)
+}
+
 // fuzzFseEncoder can be used to fuzz the FSE encoder.
 func fuzzFseEncoder(data []byte) int {
 	if len(data) > maxSequences || len(data) < 2 {
@@ -479,6 +491,16 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error {
 	if len(b.sequences) == 0 {
 		return b.encodeLits(b.literals, rawAllLits)
 	}
+	if len(b.sequences) == 1 && len(org) > 0 && len(b.literals) <= 1 {
+		// Check common RLE cases.
+		seq := b.sequences[0]
+		if seq.litLen == uint32(len(b.literals)) && seq.offset-3 == 1 {
+			// Offset == 1 and 0 or 1 literals.
+			b.encodeRLE(org[0], b.sequences[0].matchLen+zstdMinMatch+seq.litLen)
+			return nil
+		}
+	}
+
 	// We want some difference to at least account for the headers.
 	saved := b.size - len(b.literals) - (b.size >> 6)
 	if saved < 16 {
diff --git a/zstd/enc_best.go b/zstd/enc_best.go
index 87f42879a8..4613724e9d 100644
--- a/zstd/enc_best.go
+++ b/zstd/enc_best.go
@@ -135,8 +135,20 @@ func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
 		break
 	}
 
+	// Add block to history
 	s := e.addBlock(src)
 	blk.size = len(src)
+
+	// Check RLE first
+	if len(src) > zstdMinMatch {
+		ml := matchLen(src[1:], src)
+		if ml == len(src)-1 {
+			blk.literals = append(blk.literals, src[0])
+			blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
+			return
+		}
+	}
+
 	if len(src) < minNonLiteralBlockSize {
 		blk.extraLits = len(src)
 		blk.literals = blk.literals[:len(src)]
diff --git a/zstd/enc_better.go b/zstd/enc_better.go
index 20d25b0e05..a4f5bf91fc 100644
--- a/zstd/enc_better.go
+++ b/zstd/enc_better.go
@@ -102,9 +102,20 @@ func (e *betterFastEncoder) Encode(blk *blockEnc, src []byte) {
 		e.cur = e.maxMatchOff
 		break
 	}
-
+	// Add block to history
 	s := e.addBlock(src)
 	blk.size = len(src)
+
+	// Check RLE first
+	if len(src) > zstdMinMatch {
+		ml := matchLen(src[1:], src)
+		if ml == len(src)-1 {
+			blk.literals = append(blk.literals, src[0])
+			blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
+			return
+		}
+	}
+
 	if len(src) < minNonLiteralBlockSize {
 		blk.extraLits = len(src)
 		blk.literals = blk.literals[:len(src)]
diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go
index 0f13aac5ba..1b2569f1db 100644
--- a/zstd/encoder_test.go
+++ b/zstd/encoder_test.go
@@ -342,6 +342,69 @@ func TestEncoder_EncodeAllTwain(t *testing.T) {
 	}
 }
 
+// TestEncoder_EncodeRLE round-trips 1 MiB of zero bytes through EncodeAll and
+// DecodeAll for every encoder level and every tested window size, and fails
+// if the decoded output differs from the input.
+func TestEncoder_EncodeRLE(t *testing.T) {
+	// make returns zeroed memory, so the input is one long run of a single byte.
+	in := make([]byte, 1<<20)
+	testWindowSizes := testWindowSizes
+	if testing.Short() {
+		testWindowSizes = []int{1 << 20}
+	}
+
+	dec, err := NewReader(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer dec.Close()
+
+	for level := speedNotSet + 1; level < speedLast; level++ {
+		t.Run(level.String(), func(t *testing.T) {
+			if isRaceTest && level >= SpeedBestCompression {
+				t.SkipNow()
+			}
+			for _, windowSize := range testWindowSizes {
+				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
+					e, err := NewWriter(nil, WithEncoderLevel(level), WithWindowSize(windowSize))
+					if err != nil {
+						t.Fatal(err)
+					}
+					defer e.Close()
+					start := time.Now()
+					dst := e.EncodeAll(in, nil)
+					t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
+					mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
+					t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
+
+					decoded, err := dec.DecodeAll(dst, nil)
+					if err != nil {
+						t.Error(err, len(decoded))
+					}
+					if !bytes.Equal(decoded, in) {
+						// t.Name() is "Test/level/window:N". Path separators would
+						// point into directories that do not exist, and ':' is not
+						// valid in file names on every platform, so flatten the name
+						// before using it as a file name.
+						name := []byte(t.Name())
+						for i, c := range name {
+							if c == '/' || c == '\\' || c == ':' {
+								name[i] = '_'
+							}
+						}
+						fn := "testdata/" + string(name) + "-RLE.got"
+						if werr := os.WriteFile(fn, decoded, os.ModePerm); werr != nil {
+							t.Logf("unable to write %s: %v", fn, werr)
+						}
+						t.Fatal("Decoded does not match")
+					}
+					t.Log("Encoded content matched")
+				})
+			}
+		})
+	}
+}
+
 func TestEncoder_EncodeAllPi(t *testing.T) {
 	in, err := os.ReadFile("../testdata/pi.txt")
 	if err != nil {