diff --git a/share/eds/retriever.go b/share/eds/retriever.go index a7b9ab16f3..2e8713c72f 100644 --- a/share/eds/retriever.go +++ b/share/eds/retriever.go @@ -104,23 +104,18 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader // quadrant request retries. Also, provides an API // to reconstruct the block once enough shares are fetched. type retrievalSession struct { + dah *da.DataAvailabilityHeader bget blockservice.BlockGetter adder *ipld.NmtNodeAdder - treeFn rsmt2d.TreeConstructorFn - codec rsmt2d.Codec - - dah *da.DataAvailabilityHeader - squareImported *rsmt2d.ExtendedDataSquare - - quadrants []*quadrant - sharesLks []sync.Mutex - sharesCount uint32 - - squareLk sync.RWMutex - square [][]byte - squareSig chan struct{} - squareDn chan struct{} + // TODO(@Wondertan): Extract into a separate data structure https://github.com/celestiaorg/rsmt2d/issues/135 + squareQuadrants []*quadrant + squareCellsLks [][]sync.Mutex + squareCellsCount uint32 + squareSig chan struct{} + squareDn chan struct{} + squareLk sync.RWMutex + square *rsmt2d.ExtendedDataSquare span trace.Span } @@ -133,29 +128,31 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead r.bServ, ipld.MaxSizeBatchOption(size), ) - ses := &retrievalSession{ - bget: blockservice.NewSession(ctx, r.bServ), - adder: adder, - treeFn: func(_ rsmt2d.Axis, index uint) rsmt2d.Tree { - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit)) - return &tree - }, - codec: share.DefaultRSMT2DCodec(), - dah: dah, - quadrants: newQuadrants(dah), - sharesLks: make([]sync.Mutex, size*size), - square: make([][]byte, size*size), - squareSig: make(chan struct{}, 1), - squareDn: make(chan struct{}), - span: trace.SpanFromContext(ctx), + + treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit)) + return &tree } - square, err := 
rsmt2d.ImportExtendedDataSquare(ses.square, ses.codec, ses.treeFn) + square, err := rsmt2d.ImportExtendedDataSquare(make([][]byte, size*size), share.DefaultRSMT2DCodec(), treeFn) if err != nil { return nil, err } - ses.squareImported = square + ses := &retrievalSession{ + dah: dah, + bget: blockservice.NewSession(ctx, r.bServ), + adder: adder, + squareQuadrants: newQuadrants(dah), + squareCellsLks: make([][]sync.Mutex, size), + squareSig: make(chan struct{}, 1), + squareDn: make(chan struct{}), + square: square, + span: trace.SpanFromContext(ctx), + } + for i := range ses.squareCellsLks { + ses.squareCellsLks[i] = make([]sync.Mutex, size) + } go ses.request(ctx) return ses, nil } @@ -170,36 +167,24 @@ func (rs *retrievalSession) Done() <-chan struct{} { // Reconstruct tries to reconstruct the data square and returns it on success. func (rs *retrievalSession) Reconstruct(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) { if rs.isReconstructed() { - return rs.squareImported, nil + return rs.square, nil } // prevent further writes to the square rs.squareLk.Lock() defer rs.squareLk.Unlock() - // TODO(@Wondertan): This is bad! - // * We should not reimport the square multiple times - // * We should set shares into imported square via - // SetShare(https://github.com/celestiaorg/rsmt2d/issues/83) to accomplish the above point. 
- { - squareImported, err := rsmt2d.ImportExtendedDataSquare(rs.square, rs.codec, rs.treeFn) - if err != nil { - return nil, err - } - rs.squareImported = squareImported - } - _, span := tracer.Start(ctx, "reconstruct-square") defer span.End() // and try to repair with what we have - err := rs.squareImported.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots) + err := rs.square.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots) if err != nil { span.RecordError(err) return nil, err } log.Infow("data square reconstructed", "data_hash", hex.EncodeToString(rs.dah.Hash()), "size", len(rs.dah.RowsRoots)) close(rs.squareDn) - return rs.squareImported, nil + return rs.square, nil } // isReconstructed report true whether the square attached to the session @@ -232,8 +217,8 @@ func (rs *retrievalSession) Close() error { func (rs *retrievalSession) request(ctx context.Context) { t := time.NewTicker(RetrieveQuadrantTimeout) defer t.Stop() - for retry := 0; retry < len(rs.quadrants); retry++ { - q := rs.quadrants[retry] + for retry := 0; retry < len(rs.squareQuadrants); retry++ { + q := rs.squareQuadrants[retry] log.Debugw("requesting quadrant", "axis", q.source, "x", q.x, @@ -241,7 +226,7 @@ func (rs *retrievalSession) request(ctx context.Context) { "size", len(q.roots), ) rs.span.AddEvent("requesting quadrant", trace.WithAttributes( - attribute.Int("axis", q.source), + attribute.Int("axis", int(q.source)), attribute.Int("x", q.x), attribute.Int("y", q.y), attribute.Int("size", len(q.roots)), @@ -260,7 +245,7 @@ func (rs *retrievalSession) request(ctx context.Context) { "size", len(q.roots), ) rs.span.AddEvent("quadrant request timeout", trace.WithAttributes( - attribute.Int("axis", q.source), + attribute.Int("axis", int(q.source)), attribute.Int("x", q.x), attribute.Int("y", q.y), attribute.Int("size", len(q.roots)), @@ -292,10 +277,10 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) { // in the square. 
// NOTE-2: We never actually fetch shares from the network *twice*. // Once a share is downloaded from the network it is cached on the IPLD(blockservice) level. - // calc index of the share - idx := q.index(i, j) + // calc position of the share + x, y := q.pos(i, j) // try to lock the share - ok := rs.sharesLks[idx].TryLock() + ok := rs.squareCellsLks[x][y].TryLock() if !ok { // if already locked and written - do nothing return @@ -312,14 +297,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) { if rs.isReconstructed() { return } - rs.square[idx] = share + if rs.square.GetCell(uint(x), uint(y)) != nil { + return + } + rs.square.SetCell(uint(x), uint(y), share) // if we have >= 1/4 of the square we can start trying to Reconstruct // TODO(@Wondertan): This is not an ideal way to know when to start // reconstruction and can cause idle reconstruction tries in some cases, // but it is totally fine for the happy case and for now. // The earlier we correctly know that we have the full square - the earlier // we cancel ongoing requests - the less data is being wastedly transferred. 
- if atomic.AddUint32(&rs.sharesCount, 1) >= uint32(size*size) { + if atomic.AddUint32(&rs.squareCellsCount, 1) >= uint32(size*size) { select { case rs.squareSig <- struct{}{}: default: diff --git a/share/eds/retriever_quadrant.go b/share/eds/retriever_quadrant.go index 9a637a27be..8b8037ce85 100644 --- a/share/eds/retriever_quadrant.go +++ b/share/eds/retriever_quadrant.go @@ -1,13 +1,13 @@ package eds import ( - "math" "math/rand" "time" "github.com/ipfs/go-cid" "github.com/celestiaorg/celestia-app/pkg/da" + "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share/ipld" ) @@ -42,10 +42,8 @@ type quadrant struct { // |(0;1)| |(1;1)| // ------ ------- x, y int - // source defines the axis for quadrant - // it can be either 1 or 0 similar to x and y - // where 0 is Row source and 1 is Col respectively - source int + // source defines the axis(Row or Col) to fetch the quadrant from + source rsmt2d.Axis } // newQuadrants constructs a slice of quadrants from DAHeader. @@ -70,17 +68,13 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant { } for i := range quadrants { - // convert quadrant index into coordinates + // convert quadrant 1D index into 2D coordinates x, y := i%2, i/2 - if source == 1 { // swap coordinates for column - x, y = y, x - } - quadrants[i] = &quadrant{ roots: roots[qsize*y : qsize*(y+1)], x: x, y: y, - source: source, + source: rsmt2d.Axis(source), } } } @@ -93,31 +87,16 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant { return quadrants } -// index calculates index for a share in a data square slice flattened by rows. 
-// -// NOTE: The complexity of the formula below comes from: -// - Goal to avoid share copying -// - Goal to make formula generic for both rows and cols -// - While data square is flattened by rows only -// -// TODO(@Wondertan): This can be simplified by making rsmt2d working over 3D byte slice(not -// flattened) -func (q *quadrant) index(rootIdx, cellIdx int) int { - size := len(q.roots) - // half square offsets, e.g. share is from Q3, - // so we add to index Q1+Q2 - halfSquareOffsetCol := pow(size*2, q.source) - halfSquareOffsetRow := pow(size*2, q.source^1) - // offsets for the axis, e.g. share is from Q4. - // so we add to index Q3 - offsetX := q.x * halfSquareOffsetCol * size - offsetY := q.y * halfSquareOffsetRow * size - - rootIdx *= halfSquareOffsetRow - cellIdx *= halfSquareOffsetCol - return rootIdx + cellIdx + offsetX + offsetY -} - -func pow(x, y int) int { - return int(math.Pow(float64(x), float64(y))) +// pos calculates position of a share in a data square. +func (q *quadrant) pos(rootIdx, cellIdx int) (int, int) { + cellIdx += len(q.roots) * q.x + rootIdx += len(q.roots) * q.y + switch q.source { + case rsmt2d.Row: + return rootIdx, cellIdx + case rsmt2d.Col: + return cellIdx, rootIdx + default: + panic("unknown axis") + } }