Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ipld): use Set/GetCell API from rsmt2d #1173

Merged
merged 1 commit into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 43 additions & 55 deletions share/eds/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,18 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader
// quadrant request retries. Also, provides an API
// to reconstruct the block once enough shares are fetched.
type retrievalSession struct {
dah *da.DataAvailabilityHeader
bget blockservice.BlockGetter
adder *ipld.NmtNodeAdder

treeFn rsmt2d.TreeConstructorFn
codec rsmt2d.Codec

dah *da.DataAvailabilityHeader
squareImported *rsmt2d.ExtendedDataSquare

quadrants []*quadrant
sharesLks []sync.Mutex
sharesCount uint32

squareLk sync.RWMutex
square [][]byte
squareSig chan struct{}
squareDn chan struct{}
// TODO(@Wondertan): Extract into a separate data structure https://github.com/celestiaorg/rsmt2d/issues/135
squareQuadrants []*quadrant
squareCellsLks [][]sync.Mutex
squareCellsCount uint32
squareSig chan struct{}
squareDn chan struct{}
squareLk sync.RWMutex
square *rsmt2d.ExtendedDataSquare

span trace.Span
}
Expand All @@ -133,29 +128,31 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead
r.bServ,
ipld.MaxSizeBatchOption(size),
)
ses := &retrievalSession{
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
treeFn: func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
return &tree
},
codec: share.DefaultRSMT2DCodec(),
dah: dah,
quadrants: newQuadrants(dah),
sharesLks: make([]sync.Mutex, size*size),
square: make([][]byte, size*size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
span: trace.SpanFromContext(ctx),

treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
return &tree
}

square, err := rsmt2d.ImportExtendedDataSquare(ses.square, ses.codec, ses.treeFn)
square, err := rsmt2d.ImportExtendedDataSquare(make([][]byte, size*size), share.DefaultRSMT2DCodec(), treeFn)
if err != nil {
return nil, err
}

ses.squareImported = square
ses := &retrievalSession{
dah: dah,
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
squareQuadrants: newQuadrants(dah),
squareCellsLks: make([][]sync.Mutex, size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
square: square,
span: trace.SpanFromContext(ctx),
}
for i := range ses.squareCellsLks {
ses.squareCellsLks[i] = make([]sync.Mutex, size)
}
go ses.request(ctx)
return ses, nil
}
Expand All @@ -170,36 +167,24 @@ func (rs *retrievalSession) Done() <-chan struct{} {
// Reconstruct tries to reconstruct the data square and returns it on success.
func (rs *retrievalSession) Reconstruct(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) {
if rs.isReconstructed() {
return rs.squareImported, nil
return rs.square, nil
}
// prevent further writes to the square
rs.squareLk.Lock()
defer rs.squareLk.Unlock()

// TODO(@Wondertan): This is bad!
// * We should not reimport the square multiple times
// * We should set shares into imported square via
// SetShare(https://github.com/celestiaorg/rsmt2d/issues/83) to accomplish the above point.
{
squareImported, err := rsmt2d.ImportExtendedDataSquare(rs.square, rs.codec, rs.treeFn)
if err != nil {
return nil, err
}
rs.squareImported = squareImported
}

_, span := tracer.Start(ctx, "reconstruct-square")
defer span.End()

// and try to repair with what we have
err := rs.squareImported.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
err := rs.square.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
if err != nil {
span.RecordError(err)
return nil, err
}
log.Infow("data square reconstructed", "data_hash", hex.EncodeToString(rs.dah.Hash()), "size", len(rs.dah.RowsRoots))
close(rs.squareDn)
return rs.squareImported, nil
return rs.square, nil
}

// isReconstructed reports whether the square attached to the session
Expand Down Expand Up @@ -232,16 +217,16 @@ func (rs *retrievalSession) Close() error {
func (rs *retrievalSession) request(ctx context.Context) {
t := time.NewTicker(RetrieveQuadrantTimeout)
defer t.Stop()
for retry := 0; retry < len(rs.quadrants); retry++ {
q := rs.quadrants[retry]
for retry := 0; retry < len(rs.squareQuadrants); retry++ {
q := rs.squareQuadrants[retry]
log.Debugw("requesting quadrant",
"axis", q.source,
"x", q.x,
"y", q.y,
"size", len(q.roots),
)
rs.span.AddEvent("requesting quadrant", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand All @@ -260,7 +245,7 @@ func (rs *retrievalSession) request(ctx context.Context) {
"size", len(q.roots),
)
rs.span.AddEvent("quadrant request timeout", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand Down Expand Up @@ -292,10 +277,10 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
// in the square.
// NOTE-2: We never actually fetch shares from the network *twice*.
// Once a share is downloaded from the network it is cached on the IPLD(blockservice) level.
// calc index of the share
idx := q.index(i, j)
// calc position of the share
x, y := q.pos(i, j)
// try to lock the share
ok := rs.sharesLks[idx].TryLock()
ok := rs.squareCellsLks[x][y].TryLock()
if !ok {
// if already locked and written - do nothing
return
Expand All @@ -312,14 +297,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
if rs.isReconstructed() {
return
}
rs.square[idx] = share
if rs.square.GetCell(uint(x), uint(y)) != nil {
return
}
rs.square.SetCell(uint(x), uint(y), share)
// if we have >= 1/4 of the square we can start trying to Reconstruct
// TODO(@Wondertan): This is not an ideal way to know when to start
// reconstruction and can cause idle reconstruction tries in some cases,
// but it is totally fine for the happy case and for now.
// The earlier we correctly know that we have the full square - the earlier
// we cancel ongoing requests - the less data is transferred wastefully.
if atomic.AddUint32(&rs.sharesCount, 1) >= uint32(size*size) {
if atomic.AddUint32(&rs.squareCellsCount, 1) >= uint32(size*size) {
select {
case rs.squareSig <- struct{}{}:
default:
Expand Down
55 changes: 17 additions & 38 deletions share/eds/retriever_quadrant.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package eds

import (
"math"
"math/rand"
"time"

"github.com/ipfs/go-cid"

"github.com/celestiaorg/celestia-app/pkg/da"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share/ipld"
)
Expand Down Expand Up @@ -42,10 +42,8 @@ type quadrant struct {
// |(0;1)| |(1;1)|
// ------ -------
x, y int
// source defines the axis for quadrant
// it can be either 1 or 0 similar to x and y
// where 0 is Row source and 1 is Col respectively
source int
// source defines the axis (Row or Col) to fetch the quadrant from
source rsmt2d.Axis
}

// newQuadrants constructs a slice of quadrants from DAHeader.
Expand All @@ -70,17 +68,13 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
}

for i := range quadrants {
// convert quadrant index into coordinates
// convert quadrant 1D index into 2D coordinates
x, y := i%2, i/2
if source == 1 { // swap coordinates for column
x, y = y, x
}

quadrants[i] = &quadrant{
roots: roots[qsize*y : qsize*(y+1)],
x: x,
y: y,
source: source,
source: rsmt2d.Axis(source),
}
}
}
Expand All @@ -93,31 +87,16 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
return quadrants
}

// index calculates index for a share in a data square slice flattened by rows.
//
// NOTE: The complexity of the formula below comes from:
// - Goal to avoid share copying
// - Goal to make formula generic for both rows and cols
// - While data square is flattened by rows only
//
// TODO(@Wondertan): This can be simplified by making rsmt2d working over 3D byte slice(not
// flattened)
func (q *quadrant) index(rootIdx, cellIdx int) int {
size := len(q.roots)
// half square offsets, e.g. share is from Q3,
// so we add to index Q1+Q2
halfSquareOffsetCol := pow(size*2, q.source)
halfSquareOffsetRow := pow(size*2, q.source^1)
// offsets for the axis, e.g. share is from Q4.
// so we add to index Q3
offsetX := q.x * halfSquareOffsetCol * size
offsetY := q.y * halfSquareOffsetRow * size

rootIdx *= halfSquareOffsetRow
cellIdx *= halfSquareOffsetCol
return rootIdx + cellIdx + offsetX + offsetY
}

func pow(x, y int) int {
return int(math.Pow(float64(x), float64(y)))
// pos calculates the position of a share in the data square from its root
// index and cell index within the quadrant.
func (q *quadrant) pos(rootIdx, cellIdx int) (int, int) {
cellIdx += len(q.roots) * q.x
rootIdx += len(q.roots) * q.y
switch q.source {
case rsmt2d.Row:
return rootIdx, cellIdx
case rsmt2d.Col:
return cellIdx, rootIdx
default:
panic("unknown axis")
}
distractedm1nd marked this conversation as resolved.
Show resolved Hide resolved
}