From c4d989b585c5eb871a86a9b8e71ed6ac6fb965d0 Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 10 Apr 2024 20:21:49 -0500 Subject: [PATCH] perf allow for dirty padding of decompression output (#1100) --- std/compress/lzss/snark.go | 36 +++++++++++++++++++++++++-------- std/compress/lzss/snark_test.go | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/std/compress/lzss/snark.go b/std/compress/lzss/snark.go index 47f38d6f6e..68b9773b36 100644 --- a/std/compress/lzss/snark.go +++ b/std/compress/lzss/snark.go @@ -18,7 +18,12 @@ import ( // it is recommended to pack the dictionary using compress.Pack and take a MiMC checksum of it. // d will consist of bytes // It returns the length of d as a frontend.Variable; if the decompressed stream doesn't fit in d, dLength will be "-1" -func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d, dict []frontend.Variable) (dLength frontend.Variable, err error) { +func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d, dict []frontend.Variable, options ...DecompressionOption) (dLength frontend.Variable, err error) { + + var aux decompressionAux + for _, opt := range options { + opt(&aux) + } api.AssertIsLessOrEqual(cLength, len(c)) // sanity check @@ -100,8 +105,11 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab // write to output outVal := api.Select(copying, toCopy, curr) - // TODO previously the last byte of the output kept getting repeated. That can be worked with. 
If there was a reason to save some 600K constraints in the zkEVM decompressor, take this out again - d[outI] = plonk.EvaluateExpression(api, outVal, eof, 1, 0, -1, 0) // write zeros past eof + if aux.noZeroPaddingOutput { + d[outI] = outVal + } else { + d[outI] = plonk.EvaluateExpression(api, outVal, eof, 1, 0, -1, 0) // write zeros past eof + } // WARNING: curr modified by MulAcc outTable.Insert(d[outI]) @@ -114,11 +122,8 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab // TODO Try removing this check and requiring the user to pad the input with nonzeros // TODO Change inner to mulacc once https://github.com/Consensys/gnark/pull/859 is merged // inI = inI + inIDelta * (1 - eof) - if eof == 0 { - inI = api.Add(inI, inIDelta) - } else { - inI = api.Add(inI, plonk.EvaluateExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put - } + + inI = api.Add(inI, plonk.EvaluateExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put eofNow := rangeChecker.IsLessThan(8, api.Sub(cLength, inI)) // less than a byte left; meaning we are at the end of the input @@ -179,3 +184,18 @@ func RegisterHints() { hint.RegisterHint(internal.BreakUpBytesIntoHalfHint) hint.RegisterHint(compress.UnpackIntoBytesHint) } + +// options and other auxiliary input +type decompressionAux struct { + noZeroPaddingOutput bool +} + +type DecompressionOption func(*decompressionAux) + +// WithoutZeroPaddingOutput disables the feature where all decompressor output past the end is zeroed out. +// It saves one constraint per byte of output but necessitates more assignment work. +// If using this option, the output will be padded by the first byte of the input past the end. +// Furthermore, if the input is not padded, the output will still be padded with zeros. +func WithoutZeroPaddingOutput(aux *decompressionAux) { + aux.noZeroPaddingOutput = true +} diff --git a/std/compress/lzss/snark_test.go b/std/compress/lzss/snark_test.go index c45db3778d..ab1fbfc232 100644 --- 
a/std/compress/lzss/snark_test.go +++ b/std/compress/lzss/snark_test.go @@ -366,3 +366,38 @@ func TestBuildDecompress1KBto9KB(t *testing.T) { assert.NoError(t, err) fmt.Println(cs.GetNbConstraints()) } + +func TestNoZeroPaddingOutput(t *testing.T) { + assignment := testNoZeroPaddingOutputCircuit{ + C: []frontend.Variable{0, 1, 0, 2, 3, 0, 0, 0}, + D: []frontend.Variable{2, 3, 3}, + CLen: 4, + DLen: 1, + } + circuit := testNoZeroPaddingOutputCircuit{ + C: make([]frontend.Variable, len(assignment.C)), + D: make([]frontend.Variable, len(assignment.D)), + } + + RegisterHints() + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377)) +} + +type testNoZeroPaddingOutputCircuit struct { + CLen, DLen frontend.Variable + C, D []frontend.Variable +} + +func (c *testNoZeroPaddingOutputCircuit) Define(api frontend.API) error { + dict := []frontend.Variable{254, 255} + d := make([]frontend.Variable, len(c.D)) + dLen, err := Decompress(api, c.C, c.CLen, d, dict, WithoutZeroPaddingOutput) + if err != nil { + return err + } + api.AssertIsEqual(c.DLen, dLen) + for i := range c.D { + api.AssertIsEqual(c.D[i], d[i]) + } + return nil +}