From d568b17e509441c72e88dd006731a85df068132c Mon Sep 17 00:00:00 2001 From: Bao Pham <145053932+bao1029p@users.noreply.github.com> Date: Wed, 1 Nov 2023 02:57:22 +0700 Subject: [PATCH] feat: add roots helper function (#270) ## Overview close [#249](https://github.com/celestiaorg/rsmt2d/issues/249) ## Changes Add helper function and unit test for new function --------- Co-authored-by: Rootul Patel --- datasquare_test.go | 4 ++-- extendeddatasquare.go | 18 ++++++++++++++++ extendeddatasquare_test.go | 43 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/datasquare_test.go b/datasquare_test.go index 42514f7..78331db 100644 --- a/datasquare_test.go +++ b/datasquare_test.go @@ -188,9 +188,9 @@ func TestInvalidSquareExtension(t *testing.T) { } } -// TestRoots verifies that the row roots and column roots are equal for a 1x1 +// Test_getRoots verifies that the row roots and column roots are equal for a 1x1 // square. -func TestRoots(t *testing.T) { +func Test_getRoots(t *testing.T) { result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree, 2) assert.NoError(t, err) diff --git a/extendeddatasquare.go b/extendeddatasquare.go index 9193d87..f684020 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -322,6 +322,24 @@ func (eds *ExtendedDataSquare) Equals(other *ExtendedDataSquare) bool { return true } +// Roots returns a byte slice with this eds's RowRoots and ColRoots +// concatenated. +func (eds *ExtendedDataSquare) Roots() (roots [][]byte, err error) { + rowRoots, err := eds.RowRoots() + if err != nil { + return nil, err + } + colRoots, err := eds.ColRoots() + if err != nil { + return nil, err + } + + roots = make([][]byte, 0, len(rowRoots)+len(colRoots)) + roots = append(roots, rowRoots...) + roots = append(roots, colRoots...) + return roots, nil +} + // validateEdsWidth returns an error if edsWidth is not a valid width for an // extended data square. func validateEdsWidth(edsWidth uint) error { diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 0815418..2cad1b2 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -426,6 +426,49 @@ func TestEquals(t *testing.T) { }) } +func TestRoots(t *testing.T) { + t.Run("returns roots for a 4x4 EDS", func(t *testing.T) { + eds, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + + roots, err := eds.Roots() + require.NoError(t, err) + assert.Len(t, roots, 8) + + rowRoots, err := eds.RowRoots() + require.NoError(t, err) + + colRoots, err := eds.ColRoots() + require.NoError(t, err) + + assert.Equal(t, roots[0], rowRoots[0]) + assert.Equal(t, roots[1], rowRoots[1]) + assert.Equal(t, roots[2], rowRoots[2]) + assert.Equal(t, roots[3], rowRoots[3]) + assert.Equal(t, roots[4], colRoots[0]) + assert.Equal(t, roots[5], colRoots[1]) + assert.Equal(t, roots[6], colRoots[2]) + assert.Equal(t, roots[7], colRoots[3]) + }) + + t.Run("returns an error for an incomplete EDS", func(t *testing.T) { + eds, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + + // set a cell to nil to make the EDS incomplete + eds.setCell(0, 0, nil) + + _, err = eds.Roots() + assert.Error(t, err) + }) +} + func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) { ones := bytes.Repeat([]byte{1}, chunkSize) twos := bytes.Repeat([]byte{2}, chunkSize)