Skip to content

Commit

Permalink
perf allow for dirty padding of decompression output (#1100)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabaie authored Apr 11, 2024
1 parent acd3529 commit c4d989b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 8 deletions.
36 changes: 28 additions & 8 deletions std/compress/lzss/snark.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ import (
// it is recommended to pack the dictionary using compress.Pack and take a MiMC checksum of it.
// d will consist of bytes
// It returns the length of d as a frontend.Variable; if the decompressed stream doesn't fit in d, dLength will be "-1"
func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d, dict []frontend.Variable) (dLength frontend.Variable, err error) {
func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d, dict []frontend.Variable, options ...DecompressionOption) (dLength frontend.Variable, err error) {

var aux decompressionAux
for _, opt := range options {
opt(&aux)
}

api.AssertIsLessOrEqual(cLength, len(c)) // sanity check

Expand Down Expand Up @@ -100,8 +105,11 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab

// write to output
outVal := api.Select(copying, toCopy, curr)
// TODO previously the last byte of the output kept getting repeated. That can be worked with. If there was a reason to save some 600K constraints in the zkEVM decompressor, take this out again
d[outI] = plonk.EvaluateExpression(api, outVal, eof, 1, 0, -1, 0) // write zeros past eof
if aux.noZeroPaddingOutput {
d[outI] = outVal
} else {
d[outI] = plonk.EvaluateExpression(api, outVal, eof, 1, 0, -1, 0) // write zeros past eof
}
// WARNING: curr modified by MulAcc
outTable.Insert(d[outI])

Expand All @@ -114,11 +122,8 @@ func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variab
// TODO Try removing this check and requiring the user to pad the input with nonzeros
// TODO Change inner to mulacc once https://github.com/Consensys/gnark/pull/859 is merged
// inI = inI + inIDelta * (1 - eof)
if eof == 0 {
inI = api.Add(inI, inIDelta)
} else {
inI = api.Add(inI, plonk.EvaluateExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put
}

inI = api.Add(inI, plonk.EvaluateExpression(api, inIDelta, eof, 1, 0, -1, 0)) // if eof, stay put

eofNow := rangeChecker.IsLessThan(8, api.Sub(cLength, inI)) // less than a byte left; meaning we are at the end of the input

Expand Down Expand Up @@ -179,3 +184,18 @@ func RegisterHints() {
hint.RegisterHint(internal.BreakUpBytesIntoHalfHint)
hint.RegisterHint(compress.UnpackIntoBytesHint)
}

// options and other auxiliary input
type decompressionAux struct {
noZeroPaddingOutput bool
}

type DecompressionOption func(*decompressionAux)

// WithoutZeroPaddingOutput disables the feature where all decompressor output past the end is zeroed out
// It saves one constraint per byte of output but necessitates more assignment work
// If using this option, the output will be padded by the first byte of the input past the end
// If further the input is not padded, the output still will be padded with zeros
func WithoutZeroPaddingOutput(aux *decompressionAux) {
aux.noZeroPaddingOutput = true
}
35 changes: 35 additions & 0 deletions std/compress/lzss/snark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,38 @@ func TestBuildDecompress1KBto9KB(t *testing.T) {
assert.NoError(t, err)
fmt.Println(cs.GetNbConstraints())
}

func TestNoZeroPaddingOutput(t *testing.T) {
assignment := testNoZeroPaddingOutputCircuit{
C: []frontend.Variable{0, 1, 0, 2, 3, 0, 0, 0},
D: []frontend.Variable{2, 3, 3},
CLen: 4,
DLen: 1,
}
circuit := testNoZeroPaddingOutputCircuit{
C: make([]frontend.Variable, len(assignment.C)),
D: make([]frontend.Variable, len(assignment.D)),
}

RegisterHints()
test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BLS12_377))
}

type testNoZeroPaddingOutputCircuit struct {
CLen, DLen frontend.Variable
C, D []frontend.Variable
}

func (c *testNoZeroPaddingOutputCircuit) Define(api frontend.API) error {
dict := []frontend.Variable{254, 255}
d := make([]frontend.Variable, len(c.D))
dLen, err := Decompress(api, c.C, c.CLen, d, dict, WithoutZeroPaddingOutput)
if err != nil {
return err
}
api.AssertIsEqual(c.DLen, dLen)
for i := range c.D {
api.AssertIsEqual(c.D[i], d[i])
}
return nil
}

0 comments on commit c4d989b

Please sign in to comment.