Skip to content

Commit

Permalink
Merge pull request #157 from rootulp/rp/errors
Browse files Browse the repository at this point in the history
chore!: add error return params to tree interface
  • Loading branch information
rootulp authored Apr 14, 2023
2 parents 1e85aab + 1956b16 commit 6515446
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 57 deletions.
48 changes: 31 additions & 17 deletions datasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"math"
"sync"

"golang.org/x/sync/errgroup"
)

// ErrUnevenChunks is thrown when non-nil chunks are not all of equal size.
Expand Down Expand Up @@ -187,29 +189,41 @@ func (ds *dataSquare) resetRoots() {
}
}

func (ds *dataSquare) computeRoots() {
var wg sync.WaitGroup
func (ds *dataSquare) computeRoots() error {
var g errgroup.Group

rowRoots := make([][]byte, ds.width)
colRoots := make([][]byte, ds.width)

for i := uint(0); i < ds.width; i++ {
wg.Add(2)

go func(i uint) {
defer wg.Done()
rowRoots[i] = ds.getRowRoot(i)
}(i)
i := i // https://go.dev/doc/faq#closures_and_goroutines
g.Go(func() error {
rowRoot, err := ds.getRowRoot(i)
if err != nil {
return err
}
rowRoots[i] = rowRoot
return nil
})

g.Go(func() error {
colRoot, err := ds.getColRoot(i)
if err != nil {
return err
}
colRoots[i] = colRoot
return nil
})
}

go func(i uint) {
defer wg.Done()
colRoots[i] = ds.getColRoot(i)
}(i)
err := g.Wait()
if err != nil {
return err
}

wg.Wait()
ds.rowRoots = rowRoots
ds.colRoots = colRoots
return nil
}

// getRowRoots returns the Merkle roots of all the rows in the square.
Expand All @@ -223,9 +237,9 @@ func (ds *dataSquare) getRowRoots() [][]byte {

// getRowRoot calculates and returns the root of the selected row. Note: unlike the
// getRowRoots method, getRowRoot does not write to the built-in cache.
func (ds *dataSquare) getRowRoot(x uint) []byte {
func (ds *dataSquare) getRowRoot(x uint) ([]byte, error) {
if ds.rowRoots != nil {
return ds.rowRoots[x]
return ds.rowRoots[x], nil
}

tree := ds.createTreeFn(Row, x)
Expand All @@ -247,9 +261,9 @@ func (ds *dataSquare) getColRoots() [][]byte {

// getColRoot calculates and returns the root of the selected row. Note: unlike the
// getColRoots method, getColRoot does not write to the built-in cache.
func (ds *dataSquare) getColRoot(y uint) []byte {
func (ds *dataSquare) getColRoot(y uint) ([]byte, error) {
if ds.colRoots != nil {
return ds.colRoots[y]
return ds.colRoots[y], nil
}

tree := ds.createTreeFn(Col, y)
Expand Down
77 changes: 57 additions & 20 deletions datasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"reflect"
"testing"

"github.com/celestiaorg/merkletree"
"github.com/minio/sha256-simd"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -138,36 +140,60 @@ func TestLazyRootGeneration(t *testing.T) {
var colRoots [][]byte

for i := uint(0); i < square.width; i++ {
rowRoots = append(rowRoots, square.getRowRoot(i))
colRoots = append(rowRoots, square.getColRoot(i))
rowRoot, err := square.getRowRoot(i)
assert.NoError(t, err)
colRoot, err := square.getColRoot(i)
assert.NoError(t, err)
rowRoots = append(rowRoots, rowRoot)
colRoots = append(colRoots, colRoot)
}

square.computeRoots()
err = square.computeRoots()
assert.NoError(t, err)

if !reflect.DeepEqual(square.rowRoots, rowRoots) && !reflect.DeepEqual(square.colRoots, colRoots) {
t.Error("getRowRoot or getColRoot did not produce identical roots to computeRoots")
}
}

func TestComputeRoots(t *testing.T) {
t.Run("default tree computeRoots() returns no error", func(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
assert.NoError(t, err)
err = square.computeRoots()
assert.NoError(t, err)
})
t.Run("error tree computeRoots() returns an error", func(t *testing.T) {
square, err := newDataSquare([][]byte{{1}}, newErrorTree)
assert.NoError(t, err)
err = square.computeRoots()
assert.Error(t, err)
})
}

func TestRootAPI(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

for i := uint(0); i < square.width; i++ {
if !reflect.DeepEqual(square.getRowRoots()[i], square.getRowRoot(i)) {
rowRoot, err := square.getRowRoot(i)
assert.NoError(t, err)
if !reflect.DeepEqual(square.getRowRoots()[i], rowRoot) {
t.Errorf(
"Row root API results in different roots, expected %v got %v",
square.getRowRoots()[i],
square.getRowRoot(i),
rowRoot,
)
}
if !reflect.DeepEqual(square.getColRoots()[i], square.getColRoot(i)) {
colRoot, err := square.getColRoot(i)
assert.NoError(t, err)
if !reflect.DeepEqual(square.getColRoots()[i], colRoot) {
t.Errorf(
"Column root API results in different roots, expected %v got %v",
square.getColRoots()[i],
square.getColRoot(i),
colRoot,
)
}
}
Expand Down Expand Up @@ -205,7 +231,8 @@ func BenchmarkEDSRoots(b *testing.B) {
func(b *testing.B) {
for n := 0; n < b.N; n++ {
square.resetRoots()
square.computeRoots()
err := square.computeRoots()
assert.NoError(b, err)
}
},
)
Expand All @@ -224,18 +251,6 @@ func computeRowProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, ui
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
}

func computeColProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
tree := ds.createTreeFn(Col, y)
data := ds.col(y)

for i := uint(0); i < ds.width; i++ {
tree.Push(data[i])
}
// TODO(ismail): check for overflow when casting from uint -> int
merkleRoot, proof, proofIndex, numLeaves := treeProve(tree.(*DefaultTree), int(x))
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
}

func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, proofIndex uint64, numLeaves uint64) {
if err := d.Tree.SetIndex(uint64(idx)); err != nil {
panic(fmt.Sprintf("don't call prove on a already used tree: %v", err))
Expand All @@ -245,3 +260,25 @@ func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, p
}
return d.Tree.Prove()
}

type errorTree struct {
*merkletree.Tree
leaves [][]byte
}

func newErrorTree(axis Axis, index uint) Tree {
return &errorTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
}
}

func (d *errorTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *errorTree) Root() ([]byte, error) {
return nil, fmt.Errorf("error")
}
40 changes: 28 additions & 12 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ func (e *ErrByzantineData) Error() string {
// square (EDS), comparing repaired rows and columns against expected Merkle
// roots.
//
// Input
// # Input
//
// Missing shares must be nil.
//
// Output
// # Output
//
// The EDS is modified in-place. If repairing is successful, the EDS will be
// complete. If repairing is unsuccessful, the EDS will be the most-repaired
Expand Down Expand Up @@ -282,10 +282,14 @@ func (eds *ExtendedDataSquare) verifyAgainstRowRoots(
rebuiltShare []byte,
) error {
var root []byte
var err error
if rebuiltIndex < 0 || rebuiltShare == nil {
root = eds.computeSharesRoot(oldShares, Row, r)
root, err = eds.computeSharesRoot(oldShares, Row, r)
} else {
root = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
}
if err != nil {
return err
}

if !bytes.Equal(root, rowRoots[r]) {
Expand All @@ -303,10 +307,14 @@ func (eds *ExtendedDataSquare) verifyAgainstColRoots(
rebuiltShare []byte,
) error {
var root []byte
var err error
if rebuiltIndex < 0 || rebuiltShare == nil {
root = eds.computeSharesRoot(oldShares, Col, c)
root, err = eds.computeSharesRoot(oldShares, Col, c)
} else {
root = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
}
if err != nil {
return err
}

if !bytes.Equal(root, colRoots[c]) {
Expand All @@ -331,8 +339,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
if rowIsComplete {
errs.Go(func() error {
// ensure that the roots are equal
if !bytes.Equal(rowRoots[i], eds.getRowRoot(i)) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], eds.getRowRoot(i))
rowRoot, err := eds.getRowRoot(i)
if err != nil {
return err
}
if !bytes.Equal(rowRoots[i], rowRoot) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], rowRoot)
}
return nil
})
Expand All @@ -342,8 +354,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
if colIsComplete {
errs.Go(func() error {
// ensure that the roots are equal
if !bytes.Equal(colRoots[i], eds.getColRoot(i)) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], eds.getColRoot(i))
colRoot, err := eds.getColRoot(i)
if err != nil {
return err
}
if !bytes.Equal(colRoots[i], colRoot) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], colRoot)
}
return nil
})
Expand Down Expand Up @@ -391,15 +407,15 @@ func noMissingData(input [][]byte, rebuiltIndex int) bool {
return true
}

func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) []byte {
func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) ([]byte, error) {
tree := eds.createTreeFn(axis, i)
for _, d := range shares {
tree.Push(d)
}
return tree.Root()
}

func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) []byte {
func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) ([]byte, error) {
tree := eds.createTreeFn(axis, i)
for _, d := range shares[:rebuiltIndex] {
tree.Push(d)
Expand Down
7 changes: 5 additions & 2 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,11 @@ func TestValidFraudProof(t *testing.T) {
if err != nil {
t.Errorf("could not decode fraud proof shares; got: %v", err)
}
root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) {
root, err := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
assert.NoError(t, err)
rowRoot, err := corrupted.getRowRoot(fraudProof.Index)
assert.NoError(t, err)
if bytes.Equal(root, rowRoot) {
// If the roots match, then the fraud proof should be for invalid erasure coding.
parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth])
if err != nil {
Expand Down
14 changes: 8 additions & 6 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"github.com/celestiaorg/merkletree"
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle inside of rsmt2d.
// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
// inside of rsmt2d.
type TreeConstructorFn = func(axis Axis, index uint) Tree

// SquareIndex contains all information needed to identify the cell that is being
Expand All @@ -17,8 +18,8 @@ type SquareIndex struct {

// Tree wraps Merkle tree implementations to work with rsmt2d
type Tree interface {
Push(data []byte)
Root() []byte
Push(data []byte) error
Root() ([]byte, error)
}

var _ Tree = &DefaultTree{}
Expand All @@ -36,17 +37,18 @@ func NewDefaultTree(axis Axis, index uint) Tree {
}
}

func (d *DefaultTree) Push(data []byte) {
func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *DefaultTree) Root() []byte {
func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
}
d.root = d.Tree.Root()
}
return d.root
return d.root, nil
}

0 comments on commit 6515446

Please sign in to comment.