From f266624efba4db4d6a05399e84b1f5f4b9f89b57 Mon Sep 17 00:00:00 2001 From: Vlad <13818348+walldiss@users.noreply.github.com> Date: Sat, 6 Jul 2024 21:30:34 +0200 Subject: [PATCH] integrate shwap into shrex --- libs/utils/close.go | 13 + nodebuilder/blob/blob.go | 4 +- nodebuilder/share/opts.go | 2 +- share/availability/full/testing.go | 2 +- share/shwap/namespace_data_id.go | 123 +++ share/shwap/namespace_data_id_test.go | 28 + share/shwap/p2p/discovery/backoff.go | 116 +++ share/shwap/p2p/discovery/backoff_test.go | 47 + share/shwap/p2p/discovery/discovery.go | 387 ++++++++ share/shwap/p2p/discovery/discovery_test.go | 210 ++++ share/shwap/p2p/discovery/metrics.go | 172 ++++ share/shwap/p2p/discovery/options.go | 67 ++ share/shwap/p2p/discovery/set.go | 93 ++ share/shwap/p2p/discovery/set_test.go | 87 ++ share/shwap/p2p/shrex/doc.go | 18 + share/shwap/p2p/shrex/errors.go | 17 + share/shwap/p2p/shrex/metrics.go | 73 ++ share/shwap/p2p/shrex/middleware.go | 48 + share/shwap/p2p/shrex/params.go | 69 ++ share/shwap/p2p/shrex/peers/doc.go | 52 + share/shwap/p2p/shrex/peers/manager.go | 526 ++++++++++ share/shwap/p2p/shrex/peers/manager_test.go | 569 +++++++++++ share/shwap/p2p/shrex/peers/metrics.go | 276 ++++++ share/shwap/p2p/shrex/peers/options.go | 84 ++ share/shwap/p2p/shrex/peers/pool.go | 226 +++++ share/shwap/p2p/shrex/peers/pool_test.go | 184 ++++ share/shwap/p2p/shrex/peers/timedqueue.go | 91 ++ .../shwap/p2p/shrex/peers/timedqueue_test.go | 60 ++ share/shwap/p2p/shrex/recovery.go | 21 + share/shwap/p2p/shrex/shrexeds/client.go | 219 +++++ share/shwap/p2p/shrex/shrexeds/doc.go | 51 + .../shwap/p2p/shrex/shrexeds/exchange_test.go | 159 +++ share/shwap/p2p/shrex/shrexeds/params.go | 54 ++ .../shrexeds/pb/extended_data_square.pb.go | 338 +++++++ .../shrexeds/pb/extended_data_square.proto | 14 + share/shwap/p2p/shrex/shrexeds/server.go | 194 ++++ share/shwap/p2p/shrex/shrexnd/client.go | 233 +++++ share/shwap/p2p/shrex/shrexnd/doc.go | 43 + .../shwap/p2p/shrex/shrexnd/exchange_test.go | 125 +++ share/shwap/p2p/shrex/shrexnd/params.go | 38 + .../shrex/shrexnd/pb/row_namespace_data.pb.go | 576 +++++++++++ .../shrex/shrexnd/pb/row_namespace_data.proto | 20 + share/shwap/p2p/shrex/shrexnd/server.go | 252 +++++ share/shwap/p2p/shrex/shrexsub/doc.go | 58 ++ .../p2p/shrex/shrexsub/pb/notification.pb.go | 355 +++++++ .../p2p/shrex/shrexsub/pb/notification.proto | 9 + share/shwap/p2p/shrex/shrexsub/pubsub.go | 146 +++ share/shwap/p2p/shrex/shrexsub/pubsub_test.go | 123 +++ .../shwap/p2p/shrex/shrexsub/subscription.go | 51 + share/shwap/pb/shwap.pb.go | 903 +++++++++++++++++- share/shwap/pb/shwap.proto | 19 + state/core_access.go | 9 +- store/store.go | 11 +- 53 files changed, 7608 insertions(+), 57 deletions(-) create mode 100644 libs/utils/close.go create mode 100644 share/shwap/namespace_data_id.go create mode 100644 share/shwap/namespace_data_id_test.go create mode 100644 share/shwap/p2p/discovery/backoff.go create mode 100644 share/shwap/p2p/discovery/backoff_test.go create mode 100644 share/shwap/p2p/discovery/discovery.go create mode 100644 share/shwap/p2p/discovery/discovery_test.go create mode 100644 share/shwap/p2p/discovery/metrics.go create mode 100644 share/shwap/p2p/discovery/options.go create mode 100644 share/shwap/p2p/discovery/set.go create mode 100644 share/shwap/p2p/discovery/set_test.go create mode 100644 share/shwap/p2p/shrex/doc.go create mode 100644 share/shwap/p2p/shrex/errors.go create mode 100644 share/shwap/p2p/shrex/metrics.go create mode 100644 
share/shwap/p2p/shrex/middleware.go create mode 100644 share/shwap/p2p/shrex/params.go create mode 100644 share/shwap/p2p/shrex/peers/doc.go create mode 100644 share/shwap/p2p/shrex/peers/manager.go create mode 100644 share/shwap/p2p/shrex/peers/manager_test.go create mode 100644 share/shwap/p2p/shrex/peers/metrics.go create mode 100644 share/shwap/p2p/shrex/peers/options.go create mode 100644 share/shwap/p2p/shrex/peers/pool.go create mode 100644 share/shwap/p2p/shrex/peers/pool_test.go create mode 100644 share/shwap/p2p/shrex/peers/timedqueue.go create mode 100644 share/shwap/p2p/shrex/peers/timedqueue_test.go create mode 100644 share/shwap/p2p/shrex/recovery.go create mode 100644 share/shwap/p2p/shrex/shrexeds/client.go create mode 100644 share/shwap/p2p/shrex/shrexeds/doc.go create mode 100644 share/shwap/p2p/shrex/shrexeds/exchange_test.go create mode 100644 share/shwap/p2p/shrex/shrexeds/params.go create mode 100644 share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.pb.go create mode 100644 share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto create mode 100644 share/shwap/p2p/shrex/shrexeds/server.go create mode 100644 share/shwap/p2p/shrex/shrexnd/client.go create mode 100644 share/shwap/p2p/shrex/shrexnd/doc.go create mode 100644 share/shwap/p2p/shrex/shrexnd/exchange_test.go create mode 100644 share/shwap/p2p/shrex/shrexnd/params.go create mode 100644 share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.pb.go create mode 100644 share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto create mode 100644 share/shwap/p2p/shrex/shrexnd/server.go create mode 100644 share/shwap/p2p/shrex/shrexsub/doc.go create mode 100644 share/shwap/p2p/shrex/shrexsub/pb/notification.pb.go create mode 100644 share/shwap/p2p/shrex/shrexsub/pb/notification.proto create mode 100644 share/shwap/p2p/shrex/shrexsub/pubsub.go create mode 100644 share/shwap/p2p/shrex/shrexsub/pubsub_test.go create mode 100644 share/shwap/p2p/shrex/shrexsub/subscription.go diff --git a/libs/utils/close.go b/libs/utils/close.go new file mode 100644 index 0000000000..cbcc7b6f67 --- /dev/null +++ b/libs/utils/close.go @@ -0,0 +1,13 @@ +package utils + +import ( + "io" + + logging "github.com/ipfs/go-log/v2" +) + +func CloseAndLog(log logging.StandardLogger, name string, closer io.Closer) { + if err := closer.Close(); err != nil { + log.Warnf("closing %s: %s", name, err) + } +} diff --git a/nodebuilder/blob/blob.go b/nodebuilder/blob/blob.go index c4c0352516..0837aa23d2 100644 --- a/nodebuilder/blob/blob.go +++ b/nodebuilder/blob/blob.go @@ -23,8 +23,8 @@ type Module interface { // If all blobs were found without any errors, the user will receive a list of blobs. // If the BlobService couldn't find any blobs under the requested namespaces, // the user will receive an empty list of blobs along with an empty error. - // If some of the requested namespaces were not found, the user will receive all the found blobs and an empty error. - // If there were internal errors during some of the requests, + // If some of the requested namespaces were not found, the user will receive all the found blobs + // and an empty error. If there were internal errors during some of the requests, // the user will receive all found blobs along with a combined error message. // // All blobs will preserve the order of the namespaces that were requested. 
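The retrieval semantics described in the comment above can be summarized from the caller's side. A minimal sketch, assuming the documented method is the blob Module's GetAll(ctx, height, namespaces) accessor (the method signature itself is outside this hunk, so the names here are illustrative):

	blobs, err := blobMod.GetAll(ctx, height, []share.Namespace{nsA, nsB})
	if err != nil {
		// Partial success is possible: err may be a combined error while blobs
		// still holds everything that was found.
		log.Warnw("some namespace requests failed", "err", err)
	}
	if len(blobs) == 0 && err == nil {
		// Nothing was found under any requested namespace; per the contract above,
		// this is not an error.
	}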
diff --git a/nodebuilder/share/opts.go b/nodebuilder/share/opts.go
index 9c122b7b0f..60ceabbb2c 100644
--- a/nodebuilder/share/opts.go
+++ b/nodebuilder/share/opts.go
@@ -5,10 +5,10 @@ import (
	"github.com/celestiaorg/celestia-node/share/eds"
	"github.com/celestiaorg/celestia-node/share/getters"
-	disc "github.com/celestiaorg/celestia-node/share/p2p/discovery"
	"github.com/celestiaorg/celestia-node/share/p2p/peers"
	"github.com/celestiaorg/celestia-node/share/p2p/shrexeds"
	"github.com/celestiaorg/celestia-node/share/p2p/shrexnd"
+	disc "github.com/celestiaorg/celestia-node/share/shwap/p2p/discovery"
)

// WithPeerManagerMetrics is a utility function to turn on peer manager metrics and that is
diff --git a/share/availability/full/testing.go b/share/availability/full/testing.go
index 7379c83441..4928250dc1 100644
--- a/share/availability/full/testing.go
+++ b/share/availability/full/testing.go
@@ -13,7 +13,7 @@ import (
	"github.com/celestiaorg/celestia-node/share/eds"
	"github.com/celestiaorg/celestia-node/share/getters"
	"github.com/celestiaorg/celestia-node/share/ipld"
-	"github.com/celestiaorg/celestia-node/share/p2p/discovery"
+	"github.com/celestiaorg/celestia-node/share/shwap/p2p/discovery"
)

// GetterWithRandSquare provides a share.Getter filled with 'n' NMT
diff --git a/share/shwap/namespace_data_id.go b/share/shwap/namespace_data_id.go
new file mode 100644
index 0000000000..b5d51449cb
--- /dev/null
+++ b/share/shwap/namespace_data_id.go
@@ -0,0 +1,123 @@
+package shwap
+
+import (
+	"encoding/binary"
+	"fmt"
+
+	"github.com/celestiaorg/celestia-node/share"
+)
+
+// NamespaceDataIDSize defines the total size of a NamespaceDataID in bytes, combining the
+// size of an EdsID, two 2-byte row indexes, and the size of a Namespace.
+const NamespaceDataIDSize = EdsIDSize + 4 + share.NamespaceSize
+
+// NamespaceDataID uniquely identifies a range of namespaced data rows within an Extended
+// Data Square (EDS).
+type NamespaceDataID struct {
+	// Embedding EdsID to include the block height the data belongs to.
+	EdsID
+	// FromRowIndex and ToRowIndex specify the range of rows within the data square.
+	FromRowIndex, ToRowIndex int
+	// DataNamespace is a string representation of the namespace to facilitate comparisons.
+	DataNamespace share.Namespace
+}
+
+// NewNamespaceDataID creates a new NamespaceDataID with the specified parameters. It
+// validates the NamespaceDataID against the provided EDS size before returning.
+func NewNamespaceDataID(
+	height uint64,
+	fromRowIndex, toRowIndex int,
+	namespace share.Namespace,
+	edsSize int,
+) (NamespaceDataID, error) {
+	ndid := NamespaceDataID{
+		EdsID: EdsID{
+			Height: height,
+		},
+		FromRowIndex:  fromRowIndex,
+		ToRowIndex:    toRowIndex,
+		DataNamespace: namespace,
+	}
+
+	if err := ndid.Verify(edsSize); err != nil {
+		return NamespaceDataID{}, err
+	}
+	return ndid, nil
+}
+
+// NamespaceDataIDFromBinary deserializes a NamespaceDataID from its binary form. It returns
+// an error if the binary data's length does not match the expected size.
+func NamespaceDataIDFromBinary(data []byte) (NamespaceDataID, error) {
+	if len(data) != NamespaceDataIDSize {
+		return NamespaceDataID{},
+			fmt.Errorf("invalid NamespaceDataID length: expected %d, got %d", NamespaceDataIDSize, len(data))
+	}
+
+	edsID, err := EdsIDFromBinary(data[:EdsIDSize])
+	if err != nil {
+		return NamespaceDataID{}, fmt.Errorf("error unmarshaling EdsID: %w", err)
+	}
+
+	fromRowIndex := int(binary.BigEndian.Uint16(data[EdsIDSize:]))
+	toRowIndex := int(binary.BigEndian.Uint16(data[EdsIDSize+2:]))
+	ns := share.Namespace(data[EdsIDSize+4:])
+	if err := ns.ValidateForData(); err != nil {
+		return NamespaceDataID{}, fmt.Errorf("error validating DataNamespace: %w", err)
+	}
+
+	return NamespaceDataID{
+		EdsID:         edsID,
+		FromRowIndex:  fromRowIndex,
+		ToRowIndex:    toRowIndex,
+		DataNamespace: ns,
+	}, nil
+}
+
+// MarshalBinary encodes NamespaceDataID into binary form.
+// NOTE: Proto is avoided because
+// * Its size is not deterministic which is required for IPLD.
+// * No support for uint16
+func (ndid NamespaceDataID) MarshalBinary() ([]byte, error) {
+	data := make([]byte, 0, NamespaceDataIDSize)
+	return ndid.appendTo(data), nil
+}
+
+// Verify checks the validity of NamespaceDataID's fields, including the EdsID and the
+// namespace, and ensures the row indexes are within the bounds of the given EDS size.
+func (ndid NamespaceDataID) Verify(edsSize int) error {
+	if ndid.FromRowIndex >= edsSize {
+		return fmt.Errorf("FromRowIndex: %w: %d >= %d", ErrOutOfBounds, ndid.FromRowIndex, edsSize)
+	}
+	if ndid.ToRowIndex >= edsSize {
+		return fmt.Errorf("ToRowIndex: %w: %d >= %d", ErrOutOfBounds, ndid.ToRowIndex, edsSize)
+	}
+	return ndid.Validate()
+}
+
+func (ndid NamespaceDataID) Validate() error {
+	if err := ndid.EdsID.Validate(); err != nil {
+		return fmt.Errorf("error validating EdsID: %w", err)
+	}
+	if ndid.FromRowIndex > ndid.ToRowIndex {
+		return fmt.Errorf("%w: FromRowIndex %d is greater than ToRowIndex %d",
+			ErrInvalidShwapID, ndid.FromRowIndex, ndid.ToRowIndex)
+	}
+	if ndid.FromRowIndex < 0 {
+		return fmt.Errorf("%w: FromRowIndex %d", ErrInvalidShwapID, ndid.FromRowIndex)
+	}
+	if ndid.ToRowIndex < 0 {
+		return fmt.Errorf("%w: ToRowIndex %d", ErrInvalidShwapID, ndid.ToRowIndex)
+	}
+	if err := ndid.DataNamespace.ValidateForData(); err != nil {
+		return fmt.Errorf("%w: error validating DataNamespace: %w", ErrInvalidShwapID, err)
+	}
+	return nil
+}
+
+// appendTo appends the binary form of the row indexes and DataNamespace to the serialized EdsID data.
+func (ndid NamespaceDataID) appendTo(data []byte) []byte {
+	data = ndid.EdsID.appendTo(data)
+	data = binary.BigEndian.AppendUint16(data, uint16(ndid.FromRowIndex))
+	data = binary.BigEndian.AppendUint16(data, uint16(ndid.ToRowIndex))
+	return append(data, ndid.DataNamespace...)
+} diff --git a/share/shwap/namespace_data_id_test.go b/share/shwap/namespace_data_id_test.go new file mode 100644 index 0000000000..6430cd5be6 --- /dev/null +++ b/share/shwap/namespace_data_id_test.go @@ -0,0 +1,28 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestNamespaceDataID(t *testing.T) { + odsSize := 4 + ns := sharetest.RandV0Namespace() + + id, err := NewNamespaceDataID(1, 1, 2, ns, odsSize*2) + require.NoError(t, err) + + data, err := id.MarshalBinary() + require.NoError(t, err) + + sidOut, err := NamespaceDataIDFromBinary(data) + require.NoError(t, err) + assert.EqualValues(t, id, sidOut) + + err = sidOut.Verify(odsSize * 2) + require.NoError(t, err) +} diff --git a/share/shwap/p2p/discovery/backoff.go b/share/shwap/p2p/discovery/backoff.go new file mode 100644 index 0000000000..7294915727 --- /dev/null +++ b/share/shwap/p2p/discovery/backoff.go @@ -0,0 +1,116 @@ +package discovery + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/discovery/backoff" +) + +const ( + // gcInterval is a default period after which disconnected peers will be removed from cache + gcInterval = time.Minute + // connectTimeout is the timeout used for dialing peers and discovering peer addresses. + connectTimeout = time.Minute * 2 +) + +var ( + defaultBackoffFactory = backoff.NewFixedBackoff(time.Minute * 10) + errBackoffNotEnded = errors.New("share/discovery: backoff period has not ended") +) + +// backoffConnector wraps a libp2p.Host to establish a connection with peers +// with adding a delay for the next connection attempt. +type backoffConnector struct { + h host.Host + backoff backoff.BackoffFactory + + cacheLk sync.Mutex + cacheData map[peer.ID]backoffData +} + +// backoffData stores time when next connection attempt with the remote peer. +type backoffData struct { + nexttry time.Time + backoff backoff.BackoffStrategy +} + +func newBackoffConnector(h host.Host, factory backoff.BackoffFactory) *backoffConnector { + return &backoffConnector{ + h: h, + backoff: factory, + cacheData: make(map[peer.ID]backoffData), + } +} + +// Connect puts peer to the backoffCache and tries to establish a connection with it. +func (b *backoffConnector) Connect(ctx context.Context, p peer.AddrInfo) error { + if b.HasBackoff(p.ID) { + return errBackoffNotEnded + } + + ctx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + + err := b.h.Connect(ctx, p) + // we don't want to add backoff when the context is canceled. + if !errors.Is(err, context.Canceled) { + b.Backoff(p.ID) + } + return err +} + +// Backoff adds or extends backoff delay for the peer. +func (b *backoffConnector) Backoff(p peer.ID) { + b.cacheLk.Lock() + defer b.cacheLk.Unlock() + + data, ok := b.cacheData[p] + if !ok { + data = backoffData{} + data.backoff = b.backoff() + b.cacheData[p] = data + } + + data.nexttry = time.Now().Add(data.backoff.Delay()) + b.cacheData[p] = data +} + +// HasBackoff checks if peer is in backoff. +func (b *backoffConnector) HasBackoff(p peer.ID) bool { + b.cacheLk.Lock() + cache, ok := b.cacheData[p] + b.cacheLk.Unlock() + return ok && time.Now().Before(cache.nexttry) +} + +// GC is a perpetual GCing loop. 
+func (b *backoffConnector) GC(ctx context.Context) { + ticker := time.NewTicker(gcInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + b.cacheLk.Lock() + for id, cache := range b.cacheData { + if cache.nexttry.Before(time.Now()) { + delete(b.cacheData, id) + } + } + b.cacheLk.Unlock() + } + } +} + +func (b *backoffConnector) Size() int { + b.cacheLk.Lock() + defer b.cacheLk.Unlock() + return len(b.cacheData) +} diff --git a/share/shwap/p2p/discovery/backoff_test.go b/share/shwap/p2p/discovery/backoff_test.go new file mode 100644 index 0000000000..24814ed199 --- /dev/null +++ b/share/shwap/p2p/discovery/backoff_test.go @@ -0,0 +1,47 @@ +package discovery + +import ( + "context" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/p2p/discovery/backoff" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" +) + +func TestBackoff_ConnectPeer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + t.Cleanup(cancel) + m, err := mocknet.FullMeshLinked(2) + require.NoError(t, err) + b := newBackoffConnector(m.Hosts()[0], backoff.NewFixedBackoff(time.Minute)) + info := host.InfoFromHost(m.Hosts()[1]) + require.NoError(t, b.Connect(ctx, *info)) +} + +func TestBackoff_ConnectPeerFails(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + t.Cleanup(cancel) + m, err := mocknet.FullMeshLinked(2) + require.NoError(t, err) + b := newBackoffConnector(m.Hosts()[0], backoff.NewFixedBackoff(time.Minute)) + info := host.InfoFromHost(m.Hosts()[1]) + require.NoError(t, b.Connect(ctx, *info)) + + require.Error(t, b.Connect(ctx, *info)) +} + +func TestBackoff_ResetBackoffPeriod(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + t.Cleanup(cancel) + m, err := mocknet.FullMeshLinked(2) + require.NoError(t, err) + b := newBackoffConnector(m.Hosts()[0], backoff.NewFixedBackoff(time.Minute)) + info := host.InfoFromHost(m.Hosts()[1]) + require.NoError(t, b.Connect(ctx, *info)) + nexttry := b.cacheData[info.ID].nexttry + b.Backoff(info.ID) + require.True(t, b.cacheData[info.ID].nexttry.After(nexttry)) +} diff --git a/share/shwap/p2p/discovery/discovery.go b/share/shwap/p2p/discovery/discovery.go new file mode 100644 index 0000000000..f2ca04bbbe --- /dev/null +++ b/share/shwap/p2p/discovery/discovery.go @@ -0,0 +1,387 @@ +package discovery + +import ( + "context" + "errors" + "fmt" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/discovery" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "golang.org/x/sync/errgroup" +) + +var log = logging.Logger("share/discovery") + +const ( + // eventbusBufSize is the size of the buffered channel to handle + // events in libp2p. We specify a larger buffer size for the channel + // to avoid overflowing and blocking subscription during disconnection bursts. + // (by default it is 16) + eventbusBufSize = 64 + + // findPeersTimeout limits the FindPeers operation in time + findPeersTimeout = time.Minute + + // retryTimeout defines time interval between discovery and advertise attempts. 
+ retryTimeout = time.Second + + // logInterval defines the time interval at which a warning message will be logged + // if the desired number of nodes is not detected. + logInterval = 5 * time.Minute +) + +// discoveryRetryTimeout defines time interval between discovery attempts, needed for tests +var discoveryRetryTimeout = retryTimeout + +// Discovery combines advertise and discover services and allows to store discovered nodes. +// TODO: The code here gets horribly hairy, so we should refactor this at some point +type Discovery struct { + // Tag is used as rondezvous point for discovery service + tag string + set *limitedSet + host host.Host + disc discovery.Discovery + connector *backoffConnector + // onUpdatedPeers will be called on peer set changes + onUpdatedPeers OnUpdatedPeers + + triggerDisc chan struct{} + + metrics *metrics + + cancel context.CancelFunc + + params *Parameters +} + +type OnUpdatedPeers func(peerID peer.ID, isAdded bool) + +func (f OnUpdatedPeers) add(next OnUpdatedPeers) OnUpdatedPeers { + return func(peerID peer.ID, isAdded bool) { + f(peerID, isAdded) + next(peerID, isAdded) + } +} + +// NewDiscovery constructs a new discovery. +func NewDiscovery( + params *Parameters, + h host.Host, + d discovery.Discovery, + tag string, + opts ...Option, +) (*Discovery, error) { + if err := params.Validate(); err != nil { + return nil, err + } + + if tag == "" { + return nil, fmt.Errorf("discovery: tag cannot be empty") + } + o := newOptions(opts...) + return &Discovery{ + tag: tag, + set: newLimitedSet(params.PeersLimit), + host: h, + disc: d, + connector: newBackoffConnector(h, defaultBackoffFactory), + onUpdatedPeers: o.onUpdatedPeers, + params: params, + triggerDisc: make(chan struct{}), + }, nil +} + +func (d *Discovery) Start(context.Context) error { + ctx, cancel := context.WithCancel(context.Background()) + d.cancel = cancel + + sub, err := d.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}, eventbus.BufSize(eventbusBufSize)) + if err != nil { + return fmt.Errorf("subscribing for connection events: %w", err) + } + + go d.discoveryLoop(ctx) + go d.disconnectsLoop(ctx, sub) + go d.connector.GC(ctx) + return nil +} + +func (d *Discovery) Stop(context.Context) error { + d.cancel() + + if err := d.metrics.close(); err != nil { + log.Warnw("failed to close metrics", "err", err) + } + + return nil +} + +// Peers provides a list of discovered peers in the given topic. +// If Discovery hasn't found any peers, it blocks until at least one peer is found. +func (d *Discovery) Peers(ctx context.Context) ([]peer.ID, error) { + return d.set.Peers(ctx) +} + +// Discard removes the peer from the peer set and rediscovers more if soft peer limit is not +// reached. Reports whether peer was removed with bool. +func (d *Discovery) Discard(id peer.ID) bool { + if !d.set.Contains(id) { + return false + } + + d.host.ConnManager().Unprotect(id, d.tag) + d.connector.Backoff(id) + d.set.Remove(id) + d.onUpdatedPeers(id, false) + log.Debugw("removed peer from the peer set", "peer", id.String()) + + if d.set.Size() < d.set.Limit() { + // trigger discovery + select { + case d.triggerDisc <- struct{}{}: + default: + } + } + + return true +} + +// Advertise is a utility function that persistently advertises a service through an Advertiser. 
+// TODO: Start advertising only after the reachability is confirmed by AutoNAT +func (d *Discovery) Advertise(ctx context.Context) { + timer := time.NewTimer(d.params.AdvertiseInterval) + defer timer.Stop() + for { + _, err := d.disc.Advertise(ctx, d.tag) + d.metrics.observeAdvertise(ctx, err) + if err != nil { + if ctx.Err() != nil { + return + } + log.Warnw("error advertising", "rendezvous", d.tag, "err", err) + + // we don't want retry indefinitely in busy loop + // internal discovery mechanism may need some time before attempts + errTimer := time.NewTimer(retryTimeout) + select { + case <-errTimer.C: + errTimer.Stop() + if !timer.Stop() { + <-timer.C + } + continue + case <-ctx.Done(): + errTimer.Stop() + return + } + } + + log.Debugf("advertised") + if !timer.Stop() { + <-timer.C + } + timer.Reset(d.params.AdvertiseInterval) + select { + case <-timer.C: + case <-ctx.Done(): + return + } + } +} + +// discoveryLoop ensures we always have '~peerLimit' connected peers. +// It initiates peer discovery upon request and restarts the process until the soft limit is +// reached. +func (d *Discovery) discoveryLoop(ctx context.Context) { + t := time.NewTicker(discoveryRetryTimeout) + defer t.Stop() + + warnTicker := time.NewTicker(logInterval) + defer warnTicker.Stop() + + for { + // drain all previous ticks from the channel + drainChannel(t.C) + select { + case <-t.C: + if !d.discover(ctx) { + // rerun discovery if the number of peers hasn't reached the limit + continue + } + case <-warnTicker.C: + if d.set.Size() < d.set.Limit() { + log.Warnf( + "Potentially degraded connectivity, unable to discover the desired amount of %s peers in %v. "+ + "Number of peers discovered: %d. Required: %d.", + d.tag, logInterval, d.set.Size(), d.set.Limit(), + ) + } + // Do not break the loop; just continue + continue + case <-ctx.Done(): + return + } + } +} + +// disconnectsLoop listen for disconnect events and ensures Discovery state +// is updated. +func (d *Discovery) disconnectsLoop(ctx context.Context, sub event.Subscription) { + defer sub.Close() + + for { + select { + case <-ctx.Done(): + return + case e, ok := <-sub.Out(): + if !ok { + log.Error("connection subscription was closed unexpectedly") + return + } + + if evnt := e.(event.EvtPeerConnectednessChanged); evnt.Connectedness == network.NotConnected { + d.Discard(evnt.Peer) + } + } + } +} + +// discover finds new peers and reports whether it succeeded. +func (d *Discovery) discover(ctx context.Context) bool { + size := d.set.Size() + want := d.set.Limit() - size + if want == 0 { + log.Debugw("reached soft peer limit, skipping discovery", "size", size) + return true + } + // TODO @renaynay: eventually, have a mechanism to catch if wanted amount of peers + // has not been discovered in X amount of time so that users are warned of degraded + // FN connectivity. 
+ log.Debugw("discovering peers", "want", want) + + // we use errgroup as it provide limits + var wg errgroup.Group + // limit to minimize chances of overreaching the limit + wg.SetLimit(int(d.set.Limit())) + + findCtx, findCancel := context.WithTimeout(ctx, findPeersTimeout) + defer func() { + // some workers could still be running, wait them to finish before canceling findCtx + wg.Wait() //nolint:errcheck + findCancel() + }() + + peers, err := d.disc.FindPeers(findCtx, d.tag) + if err != nil { + log.Error("unable to start discovery", "err", err) + return false + } + + for { + select { + case p, ok := <-peers: + if !ok { + break + } + + peer := p + wg.Go(func() error { + if findCtx.Err() != nil { + log.Debug("find has been canceled, skip peer") + return nil //nolint:nilerr + } + + // we don't pass findCtx so that we don't cancel in progress connections + // that are likely to be valuable + if !d.handleDiscoveredPeer(ctx, peer) { + return nil + } + + size := d.set.Size() + log.Debugw("found peer", "peer", peer.ID.String(), "found_amount", size) + if size < d.set.Limit() { + return nil + } + + log.Infow("discovered wanted peers", "amount", size) + findCancel() // stop discovery when we are done + return nil + }) + + continue + case <-findCtx.Done(): + } + + isEnoughPeers := d.set.Size() >= d.set.Limit() + d.metrics.observeFindPeers(ctx, isEnoughPeers) + log.Debugw("discovery finished", "discovered_wanted", isEnoughPeers) + return isEnoughPeers + } +} + +// handleDiscoveredPeer adds peer to the internal if can connect or is connected. +// Report whether it succeeded. +func (d *Discovery) handleDiscoveredPeer(ctx context.Context, peer peer.AddrInfo) bool { + logger := log.With("peer", peer.ID.String()) + switch { + case peer.ID == d.host.ID(): + d.metrics.observeHandlePeer(ctx, handlePeerSkipSelf) + logger.Debug("skip handle: self discovery") + return false + case d.set.Size() >= d.set.Limit(): + d.metrics.observeHandlePeer(ctx, handlePeerEnoughPeers) + logger.Debug("skip handle: enough peers found") + return false + } + + switch d.host.Network().Connectedness(peer.ID) { + case network.Connected: + d.connector.Backoff(peer.ID) // we still have to backoff the connected peer + case network.NotConnected: + err := d.connector.Connect(ctx, peer) + if errors.Is(err, errBackoffNotEnded) { + d.metrics.observeHandlePeer(ctx, handlePeerBackoff) + logger.Debug("skip handle: backoff") + return false + } + if err != nil { + d.metrics.observeHandlePeer(ctx, handlePeerConnErr) + logger.Debugw("unable to connect", "err", err) + return false + } + default: + panic("unknown connectedness") + } + + if !d.set.Add(peer.ID) { + d.metrics.observeHandlePeer(ctx, handlePeerInSet) + logger.Debug("peer is already in discovery set") + return false + } + d.onUpdatedPeers(peer.ID, true) + d.metrics.observeHandlePeer(ctx, handlePeerConnected) + logger.Debug("added peer to set") + + // Tag to protect peer from being killed by ConnManager + // NOTE: This is does not protect from remote killing the connection. + // In the future, we should design a protocol that keeps bidirectional agreement on whether + // connection should be kept or not, similar to mesh link in GossipSub. 
+ d.host.ConnManager().Protect(peer.ID, d.tag) + return true +} + +func drainChannel(c <-chan time.Time) { + for { + select { + case <-c: + default: + return + } + } +} diff --git a/share/shwap/p2p/discovery/discovery_test.go b/share/shwap/p2p/discovery/discovery_test.go new file mode 100644 index 0000000000..8214a2bbe0 --- /dev/null +++ b/share/shwap/p2p/discovery/discovery_test.go @@ -0,0 +1,210 @@ +//go:build !race + +package discovery + +import ( + "context" + "testing" + "time" + + dht "github.com/libp2p/go-libp2p-kad-dht" + "github.com/libp2p/go-libp2p/core/discovery" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/discovery/routing" + basic "github.com/libp2p/go-libp2p/p2p/host/basic" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + fullNodesTag = "full" +) + +func TestDiscovery(t *testing.T) { + const nodes = 10 // higher number brings higher coverage + + discoveryRetryTimeout = time.Millisecond * 100 // defined in discovery.go + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + t.Cleanup(cancel) + + tn := newTestnet(ctx, t) + + type peerUpdate struct { + peerID peer.ID + isAdded bool + } + updateCh := make(chan peerUpdate) + submit := func(peerID peer.ID, isAdded bool) { + updateCh <- peerUpdate{peerID: peerID, isAdded: isAdded} + } + + host, routingDisc := tn.peer() + params := DefaultParameters() + params.PeersLimit = nodes + + // start discovery listener service for peerA + peerA := tn.startNewDiscovery(params, host, routingDisc, fullNodesTag, + WithOnPeersUpdate(submit), + ) + + // start discovery advertisement services for other peers + params.AdvertiseInterval = time.Millisecond * 100 + discs := make([]*Discovery, nodes) + for i := range discs { + host, routingDisc := tn.peer() + disc, err := NewDiscovery(params, host, routingDisc, fullNodesTag) + require.NoError(t, err) + go disc.Advertise(tn.ctx) + discs[i] = tn.startNewDiscovery(params, host, routingDisc, fullNodesTag) + + select { + case res := <-updateCh: + require.Equal(t, discs[i].host.ID(), res.peerID) + require.True(t, res.isAdded) + case <-ctx.Done(): + t.Fatal("did not discover peer in time") + } + } + + assert.EqualValues(t, nodes, peerA.set.Size()) + + // disconnect peerA from all peers and check that notifications are received on updateCh channel + for _, disc := range discs { + peerID := disc.host.ID() + err := peerA.host.Network().ClosePeer(peerID) + require.NoError(t, err) + + select { + case res := <-updateCh: + require.Equal(t, peerID, res.peerID) + require.False(t, res.isAdded) + case <-ctx.Done(): + t.Fatal("did not disconnect from peer in time") + } + } + + assert.EqualValues(t, 0, peerA.set.Size()) +} + +func TestDiscoveryTagged(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(cancel) + + tn := newTestnet(ctx, t) + + // launch 2 peers, that advertise with different tags + adv1, routingDisc1 := tn.peer() + adv2, routingDisc2 := tn.peer() + + // sub will discover both peers, but on different tags + sub, routingDisc := tn.peer() + + params := DefaultParameters() + + // create 2 discovery services for sub, each with a different tag + done1 := make(chan struct{}) + tn.startNewDiscovery(params, sub, routingDisc, "tag1", + WithOnPeersUpdate(checkPeer(t, adv1.ID(), done1))) + + done2 := make(chan struct{}) 
+ tn.startNewDiscovery(params, sub, routingDisc, "tag2", + WithOnPeersUpdate(checkPeer(t, adv2.ID(), done2))) + + // run discovery services for advertisers + ds1 := tn.startNewDiscovery(params, adv1, routingDisc1, "tag1") + go ds1.Advertise(tn.ctx) + + ds2 := tn.startNewDiscovery(params, adv2, routingDisc2, "tag2") + go ds2.Advertise(tn.ctx) + + // wait for discovery services to discover each other on different tags + select { + case <-done1: + case <-ctx.Done(): + t.Fatal("did not discover peer in time") + } + + select { + case <-done2: + case <-ctx.Done(): + t.Fatal("did not discover peer in time") + } +} + +type testnet struct { + ctx context.Context + T *testing.T + + bootstrapper peer.AddrInfo +} + +func newTestnet(ctx context.Context, t *testing.T) *testnet { + bus := eventbus.NewBus() + swarm := swarmt.GenSwarm(t, swarmt.OptDisableTCP, swarmt.EventBus(bus)) + hst, err := basic.NewHost(swarm, &basic.HostOpts{EventBus: bus}) + require.NoError(t, err) + hst.Start() + + _, err = dht.New(ctx, hst, + dht.Mode(dht.ModeServer), + dht.BootstrapPeers(), + dht.ProtocolPrefix("/test"), + ) + require.NoError(t, err) + + return &testnet{ctx: ctx, T: t, bootstrapper: *host.InfoFromHost(hst)} +} + +func (t *testnet) startNewDiscovery( + params *Parameters, + hst host.Host, + routingDisc discovery.Discovery, + tag string, + opts ...Option, +) *Discovery { + disc, err := NewDiscovery(params, hst, routingDisc, tag, opts...) + require.NoError(t.T, err) + err = disc.Start(t.ctx) + require.NoError(t.T, err) + t.T.Cleanup(func() { + err := disc.Stop(t.ctx) + require.NoError(t.T, err) + }) + return disc +} + +func (t *testnet) peer() (host.Host, discovery.Discovery) { + bus := eventbus.NewBus() + swarm := swarmt.GenSwarm(t.T, swarmt.OptDisableTCP, swarmt.EventBus(bus)) + hst, err := basic.NewHost(swarm, &basic.HostOpts{EventBus: bus}) + require.NoError(t.T, err) + hst.Start() + + err = hst.Connect(t.ctx, t.bootstrapper) + require.NoError(t.T, err) + + dht, err := dht.New(t.ctx, hst, + dht.Mode(dht.ModeServer), + dht.ProtocolPrefix("/test"), + // needed to reduce connections to peers on DHT level + dht.BucketSize(1), + ) + require.NoError(t.T, err) + + err = dht.Bootstrap(t.ctx) + require.NoError(t.T, err) + + return hst, routing.NewRoutingDiscovery(dht) +} + +func checkPeer(t *testing.T, expected peer.ID, done chan struct{}) func(peerID peer.ID, isAdded bool) { + return func(peerID peer.ID, isAdded bool) { + defer close(done) + require.Equal(t, expected, peerID) + require.True(t, isAdded) + } +} diff --git a/share/shwap/p2p/discovery/metrics.go b/share/shwap/p2p/discovery/metrics.go new file mode 100644 index 0000000000..5847fbdd90 --- /dev/null +++ b/share/shwap/p2p/discovery/metrics.go @@ -0,0 +1,172 @@ +package discovery + +import ( + "context" + "fmt" + + "github.com/libp2p/go-libp2p/core/peer" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/celestiaorg/celestia-node/libs/utils" +) + +const ( + discoveryEnoughPeersKey = "enough_peers" + + handlePeerResultKey = "result" + handlePeerSkipSelf handlePeerResult = "skip_self" + handlePeerEnoughPeers handlePeerResult = "skip_enough_peers" + handlePeerBackoff handlePeerResult = "skip_backoff" + handlePeerConnected handlePeerResult = "connected" + handlePeerConnErr handlePeerResult = "conn_err" + handlePeerInSet handlePeerResult = "in_set" + + advertiseFailedKey = "failed" +) + +var meter = otel.Meter("share_discovery") + +type handlePeerResult string + +type metrics struct { + peersAmount 
metric.Int64ObservableGauge + discoveryResult metric.Int64Counter // attributes: enough_peers[bool],is_canceled[bool] + handlePeerResult metric.Int64Counter // attributes: result[string] + advertise metric.Int64Counter // attributes: failed[bool] + peerAdded metric.Int64Counter + peerRemoved metric.Int64Counter + + clientReg metric.Registration +} + +// WithMetrics turns on metric collection in discoery. +func (d *Discovery) WithMetrics() error { + metrics, err := initMetrics(d) + if err != nil { + return fmt.Errorf("discovery: init metrics: %w", err) + } + d.metrics = metrics + d.onUpdatedPeers = d.onUpdatedPeers.add(metrics.observeOnPeersUpdate) + return nil +} + +func initMetrics(d *Discovery) (*metrics, error) { + peersAmount, err := meter.Int64ObservableGauge("discovery_amount_of_peers", + metric.WithDescription("amount of peers in discovery set")) + if err != nil { + return nil, err + } + + discoveryResult, err := meter.Int64Counter("discovery_find_peers_result", + metric.WithDescription("result of find peers run")) + if err != nil { + return nil, err + } + + handlePeerResultCounter, err := meter.Int64Counter("discovery_handler_peer_result", + metric.WithDescription("result handling found peer")) + if err != nil { + return nil, err + } + + advertise, err := meter.Int64Counter("discovery_advertise_event", + metric.WithDescription("advertise events counter")) + if err != nil { + return nil, err + } + + peerAdded, err := meter.Int64Counter("discovery_add_peer", + metric.WithDescription("add peer to discovery set counter")) + if err != nil { + return nil, err + } + + peerRemoved, err := meter.Int64Counter("discovery_remove_peer", + metric.WithDescription("remove peer from discovery set counter")) + if err != nil { + return nil, err + } + + backOffSize, err := meter.Int64ObservableGauge("discovery_backoff_amount", + metric.WithDescription("amount of peers in backoff")) + if err != nil { + return nil, err + } + + metrics := &metrics{ + peersAmount: peersAmount, + discoveryResult: discoveryResult, + handlePeerResult: handlePeerResultCounter, + advertise: advertise, + peerAdded: peerAdded, + peerRemoved: peerRemoved, + } + + callback := func(_ context.Context, observer metric.Observer) error { + observer.ObserveInt64(peersAmount, int64(d.set.Size())) + observer.ObserveInt64(backOffSize, int64(d.connector.Size())) + return nil + } + + metrics.clientReg, err = meter.RegisterCallback(callback, peersAmount, backOffSize) + if err != nil { + return nil, fmt.Errorf("registering metrics callback: %w", err) + } + + return metrics, nil +} + +func (m *metrics) close() error { + if m == nil { + return nil + } + return m.clientReg.Unregister() +} + +func (m *metrics) observeFindPeers(ctx context.Context, isEnoughPeers bool) { + if m == nil { + return + } + ctx = utils.ResetContextOnError(ctx) + + m.discoveryResult.Add(ctx, 1, + metric.WithAttributes( + attribute.Bool(discoveryEnoughPeersKey, isEnoughPeers))) +} + +func (m *metrics) observeHandlePeer(ctx context.Context, result handlePeerResult) { + if m == nil { + return + } + ctx = utils.ResetContextOnError(ctx) + + m.handlePeerResult.Add(ctx, 1, + metric.WithAttributes( + attribute.String(handlePeerResultKey, string(result)))) +} + +func (m *metrics) observeAdvertise(ctx context.Context, err error) { + if m == nil { + return + } + ctx = utils.ResetContextOnError(ctx) + + m.advertise.Add(ctx, 1, + metric.WithAttributes( + attribute.Bool(advertiseFailedKey, err != nil))) +} + +func (m *metrics) observeOnPeersUpdate(_ peer.ID, isAdded bool) { + if m == 
nil { + return + } + ctx := context.Background() + + if isAdded { + m.peerAdded.Add(ctx, 1) + return + } + m.peerRemoved.Add(ctx, 1) +} diff --git a/share/shwap/p2p/discovery/options.go b/share/shwap/p2p/discovery/options.go new file mode 100644 index 0000000000..de4b13a7db --- /dev/null +++ b/share/shwap/p2p/discovery/options.go @@ -0,0 +1,67 @@ +package discovery + +import ( + "fmt" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +// Parameters is the set of Parameters that must be configured for the Discovery module +type Parameters struct { + // PeersLimit defines the soft limit of FNs to connect to via discovery. + // Set 0 to disable. + PeersLimit uint + // AdvertiseInterval is a interval between advertising sessions. + // Set -1 to disable. + // NOTE: only full and bridge can advertise themselves. + AdvertiseInterval time.Duration +} + +// options is the set of options that can be configured for the Discovery module +type options struct { + // onUpdatedPeers will be called on peer set changes + onUpdatedPeers OnUpdatedPeers +} + +// Option is a function that configures Discovery Parameters +type Option func(*options) + +// DefaultParameters returns the default Parameters' configuration values +// for the Discovery module +func DefaultParameters() *Parameters { + return &Parameters{ + PeersLimit: 5, + AdvertiseInterval: time.Hour, + } +} + +// Validate validates the values in Parameters +func (p *Parameters) Validate() error { + if p.PeersLimit <= 0 { + return fmt.Errorf("discovery: peers limit cannot be zero or negative") + } + + if p.AdvertiseInterval <= 0 { + return fmt.Errorf("discovery: advertise interval cannot be zero or negative") + } + return nil +} + +// WithOnPeersUpdate chains OnPeersUpdate callbacks on every update of discovered peers list. +func WithOnPeersUpdate(f OnUpdatedPeers) Option { + return func(p *options) { + p.onUpdatedPeers = p.onUpdatedPeers.add(f) + } +} + +func newOptions(opts ...Option) *options { + defaults := &options{ + onUpdatedPeers: func(peer.ID, bool) {}, + } + + for _, opt := range opts { + opt(defaults) + } + return defaults +} diff --git a/share/shwap/p2p/discovery/set.go b/share/shwap/p2p/discovery/set.go new file mode 100644 index 0000000000..a22e10f06e --- /dev/null +++ b/share/shwap/p2p/discovery/set.go @@ -0,0 +1,93 @@ +package discovery + +import ( + "context" + "sync" + + "github.com/libp2p/go-libp2p/core/peer" +) + +// limitedSet is a thread safe set of peers with given limit. +// Inspired by libp2p peer.Set but extended with Remove method. +type limitedSet struct { + lk sync.RWMutex + ps map[peer.ID]struct{} + + limit uint + waitPeer chan peer.ID +} + +// newLimitedSet constructs a set with the maximum peers amount. +func newLimitedSet(limit uint) *limitedSet { + ps := new(limitedSet) + ps.ps = make(map[peer.ID]struct{}) + ps.limit = limit + ps.waitPeer = make(chan peer.ID) + return ps +} + +func (ps *limitedSet) Contains(p peer.ID) bool { + ps.lk.RLock() + _, ok := ps.ps[p] + ps.lk.RUnlock() + return ok +} + +func (ps *limitedSet) Limit() uint { + return ps.limit +} + +func (ps *limitedSet) Size() uint { + ps.lk.RLock() + defer ps.lk.RUnlock() + return uint(len(ps.ps)) +} + +// Add attempts to add the given peer into the set. +func (ps *limitedSet) Add(p peer.ID) (added bool) { + ps.lk.Lock() + if _, ok := ps.ps[p]; ok { + ps.lk.Unlock() + return false + } + ps.ps[p] = struct{}{} + ps.lk.Unlock() + + for { + // peer will be pushed to the channel only when somebody is reading from it. 
+ // this is done to handle case when Peers() was called on empty set. + select { + case ps.waitPeer <- p: + default: + return true + } + } +} + +func (ps *limitedSet) Remove(id peer.ID) { + ps.lk.Lock() + delete(ps.ps, id) + ps.lk.Unlock() +} + +// Peers returns all discovered peers from the set. +func (ps *limitedSet) Peers(ctx context.Context) ([]peer.ID, error) { + ps.lk.RLock() + if len(ps.ps) > 0 { + out := make([]peer.ID, 0, len(ps.ps)) + for p := range ps.ps { + out = append(out, p) + } + ps.lk.RUnlock() + return out, nil + } + ps.lk.RUnlock() + + // block until a new peer will be discovered + select { + case <-ctx.Done(): + return nil, ctx.Err() + case p := <-ps.waitPeer: + return []peer.ID{p}, nil + } +} diff --git a/share/shwap/p2p/discovery/set_test.go b/share/shwap/p2p/discovery/set_test.go new file mode 100644 index 0000000000..d5113a2291 --- /dev/null +++ b/share/shwap/p2p/discovery/set_test.go @@ -0,0 +1,87 @@ +package discovery + +import ( + "context" + "testing" + "time" + + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" +) + +func TestSet_TryAdd(t *testing.T) { + m := mocknet.New() + h, err := m.GenPeer() + require.NoError(t, err) + + set := newLimitedSet(1) + set.Add(h.ID()) + require.True(t, set.Contains(h.ID())) +} + +func TestSet_Remove(t *testing.T) { + m := mocknet.New() + h, err := m.GenPeer() + require.NoError(t, err) + + set := newLimitedSet(1) + set.Add(h.ID()) + set.Remove(h.ID()) + require.False(t, set.Contains(h.ID())) +} + +func TestSet_Peers(t *testing.T) { + m := mocknet.New() + h1, err := m.GenPeer() + require.NoError(t, err) + h2, err := m.GenPeer() + require.NoError(t, err) + + set := newLimitedSet(2) + set.Add(h1.ID()) + set.Add(h2.ID()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) + t.Cleanup(cancel) + + peers, err := set.Peers(ctx) + require.NoError(t, err) + require.True(t, len(peers) == 2) +} + +// TestSet_WaitPeers ensures that `Peers` will be unblocked once +// a new peer was discovered. +func TestSet_WaitPeers(t *testing.T) { + m := mocknet.New() + h1, err := m.GenPeer() + require.NoError(t, err) + + set := newLimitedSet(2) + go func() { + time.Sleep(time.Millisecond * 500) + set.Add(h1.ID()) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + t.Cleanup(cancel) + + // call `Peers` on empty set will block until a new peer will be discovered + peers, err := set.Peers(ctx) + require.NoError(t, err) + require.True(t, len(peers) == 1) +} + +func TestSet_Size(t *testing.T) { + m := mocknet.New() + h1, err := m.GenPeer() + require.NoError(t, err) + h2, err := m.GenPeer() + require.NoError(t, err) + + set := newLimitedSet(2) + set.Add(h1.ID()) + set.Add(h2.ID()) + require.EqualValues(t, 2, set.Size()) + set.Remove(h2.ID()) + require.EqualValues(t, 1, set.Size()) +} diff --git a/share/shwap/p2p/shrex/doc.go b/share/shwap/p2p/shrex/doc.go new file mode 100644 index 0000000000..9654532842 --- /dev/null +++ b/share/shwap/p2p/shrex/doc.go @@ -0,0 +1,18 @@ +// Package shrex provides functionality that powers the share exchange protocols used by celestia-node. +// The available protocols are: +// +// - shrexsub : a floodsub-based pubsub protocol that is used to broadcast/subscribe to the event +// of new EDS in the network to peers. +// +// - shrexnd: a request/response protocol that is used to request shares by namespace or namespace data from peers. 
+// +// - shrexeds: a request/response protocol that is used to request extended data square shares from peers. +// This protocol exchanges the original data square in between the client and server, and it's up to the +// receiver to compute the extended data square. +// +// This package also defines a peer manager that is used to manage network peers that can be used to exchange +// shares. The peer manager is primarily responsible for providing peers to request shares from, +// and is primarily used by `getters.ShrexGetter` in share/getters/shrex.go. +// +// Find out more about each protocol in their respective sub-packages. +package shrex diff --git a/share/shwap/p2p/shrex/errors.go b/share/shwap/p2p/shrex/errors.go new file mode 100644 index 0000000000..79ff0ed2b2 --- /dev/null +++ b/share/shwap/p2p/shrex/errors.go @@ -0,0 +1,17 @@ +package shrex + +import ( + "errors" +) + +// ErrNotFound is returned when a peer is unable to find the requested data or resource. +// It is used to signal that the peer couldn't serve the data successfully, and it's not +// available at the moment. The request may be retried later, but it's unlikely to succeed. +var ErrNotFound = errors.New("the requested data or resource could not be found") + +var ErrRateLimited = errors.New("server is overloaded and rate limited the request") + +// ErrInvalidResponse is returned when a peer returns an invalid response or caused an internal +// error. It is used to signal that the peer couldn't serve the data successfully, and should not be +// retried. +var ErrInvalidResponse = errors.New("server returned an invalid response or caused an internal error") diff --git a/share/shwap/p2p/shrex/metrics.go b/share/shwap/p2p/shrex/metrics.go new file mode 100644 index 0000000000..9d5c605139 --- /dev/null +++ b/share/shwap/p2p/shrex/metrics.go @@ -0,0 +1,73 @@ +package shrex + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/celestiaorg/celestia-node/libs/utils" +) + +var meter = otel.Meter("shrex/eds") + +type status string + +const ( + StatusBadRequest status = "bad_request" + StatusSendRespErr status = "send_resp_err" + StatusSendReqErr status = "send_req_err" + StatusReadRespErr status = "read_resp_err" + StatusInternalErr status = "internal_err" + StatusNotFound status = "not_found" + StatusTimeout status = "timeout" + StatusSuccess status = "success" + StatusRateLimited status = "rate_limited" +) + +type Metrics struct { + totalRequestCounter metric.Int64Counter +} + +// ObserveRequests increments the total number of requests sent with the given status as an +// attribute. 
+func (m *Metrics) ObserveRequests(ctx context.Context, count int64, status status) { + if m == nil { + return + } + ctx = utils.ResetContextOnError(ctx) + m.totalRequestCounter.Add(ctx, count, + metric.WithAttributes( + attribute.String("status", string(status)), + )) +} + +func InitClientMetrics(protocol string) (*Metrics, error) { + totalRequestCounter, err := meter.Int64Counter( + fmt.Sprintf("shrex_%s_client_total_requests", protocol), + metric.WithDescription(fmt.Sprintf("Total count of sent shrex/%s requests", protocol)), + ) + if err != nil { + return nil, err + } + + return &Metrics{ + totalRequestCounter: totalRequestCounter, + }, nil +} + +func InitServerMetrics(protocol string) (*Metrics, error) { + totalRequestCounter, err := meter.Int64Counter( + fmt.Sprintf("shrex_%s_server_total_responses", protocol), + metric.WithDescription(fmt.Sprintf("Total count of sent shrex/%s responses", protocol)), + ) + if err != nil { + return nil, err + } + + return &Metrics{ + totalRequestCounter: totalRequestCounter, + }, nil +} diff --git a/share/shwap/p2p/shrex/middleware.go b/share/shwap/p2p/shrex/middleware.go new file mode 100644 index 0000000000..c53a996eec --- /dev/null +++ b/share/shwap/p2p/shrex/middleware.go @@ -0,0 +1,48 @@ +package shrex + +import ( + "sync/atomic" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/network" +) + +var log = logging.Logger("shrex/middleware") + +type Middleware struct { + // concurrencyLimit is the maximum number of requests that can be processed at once. + concurrencyLimit int64 + // parallelRequests is the number of requests currently being processed. + parallelRequests atomic.Int64 + // numRateLimited is the number of requests that were rate limited. + numRateLimited atomic.Int64 +} + +func NewMiddleware(concurrencyLimit int) *Middleware { + return &Middleware{ + concurrencyLimit: int64(concurrencyLimit), + } +} + +// DrainCounter returns the current value of the rate limit counter and resets it to 0. +func (m *Middleware) DrainCounter() int64 { + return m.numRateLimited.Swap(0) +} + +func (m *Middleware) RateLimitHandler(handler network.StreamHandler) network.StreamHandler { + return func(stream network.Stream) { + current := m.parallelRequests.Add(1) + defer m.parallelRequests.Add(-1) + + if current > m.concurrencyLimit { + m.numRateLimited.Add(1) + log.Debug("concurrency limit reached") + err := stream.Close() + if err != nil { + log.Debugw("server: closing stream", "err", err) + } + return + } + handler(stream) + } +} diff --git a/share/shwap/p2p/shrex/params.go b/share/shwap/p2p/shrex/params.go new file mode 100644 index 0000000000..f36221d548 --- /dev/null +++ b/share/shwap/p2p/shrex/params.go @@ -0,0 +1,69 @@ +package shrex + +import ( + "fmt" + "time" + + "github.com/libp2p/go-libp2p/core/protocol" +) + +// Parameters is the set of parameters that must be configured for the shrex/eds protocol. +type Parameters struct { + // ServerReadTimeout sets the timeout for reading messages from the stream. + ServerReadTimeout time.Duration + + // ServerWriteTimeout sets the timeout for writing messages to the stream. + ServerWriteTimeout time.Duration + + // HandleRequestTimeout defines the deadline for handling request. + HandleRequestTimeout time.Duration + + // ConcurrencyLimit is the maximum number of concurrently handled streams + ConcurrencyLimit int + + // networkID is prepended to the protocolID and represents the network the protocol is + // running on. 
+ networkID string +} + +func DefaultParameters() *Parameters { + return &Parameters{ + ServerReadTimeout: 5 * time.Second, + ServerWriteTimeout: time.Minute, // based on max observed sample time for 256 blocks (~50s) + HandleRequestTimeout: time.Minute, + ConcurrencyLimit: 10, + } +} + +const errSuffix = "value should be positive and non-zero" + +func (p *Parameters) Validate() error { + if p.ServerReadTimeout <= 0 { + return fmt.Errorf("invalid stream read timeout: %v, %s", p.ServerReadTimeout, errSuffix) + } + if p.ServerWriteTimeout <= 0 { + return fmt.Errorf("invalid write timeout: %v, %s", p.ServerWriteTimeout, errSuffix) + } + if p.HandleRequestTimeout <= 0 { + return fmt.Errorf("invalid handle request timeout: %v, %s", p.HandleRequestTimeout, errSuffix) + } + if p.ConcurrencyLimit <= 0 { + return fmt.Errorf("invalid concurrency limit: %s", errSuffix) + } + return nil +} + +// WithNetworkID sets the value of networkID in params +func (p *Parameters) WithNetworkID(networkID string) { + p.networkID = networkID +} + +// NetworkID returns the value of networkID stored in params +func (p *Parameters) NetworkID() string { + return p.networkID +} + +// ProtocolID creates a protocol ID string according to common format +func ProtocolID(networkID, protocolString string) protocol.ID { + return protocol.ID(fmt.Sprintf("/%s%s", networkID, protocolString)) +} diff --git a/share/shwap/p2p/shrex/peers/doc.go b/share/shwap/p2p/shrex/peers/doc.go new file mode 100644 index 0000000000..bc1647eb42 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/doc.go @@ -0,0 +1,52 @@ +// Package peers provides a peer manager that handles peer discovery and peer selection for the shrex getter. +// +// The peer manager is responsible for: +// - Discovering peers +// - Selecting peers for data retrieval +// - Validating peers +// - Blacklisting peers +// - Garbage collecting peers +// +// The peer manager is not responsible for: +// - Connecting to peers +// - Disconnecting from peers +// - Sending data to peers +// - Receiving data from peers +// +// The peer manager is a mechanism to store peers from shrexsub, a mechanism that +// handles "peer discovery" and "peer selection" by relying on a shrexsub subscription +// and header subscriptions, such that it listens for new headers and +// new shares and uses this information to pool peers by shares. +// +// This gives the peer manager an ability to block peers that gossip invalid shares, but also access a list of peers +// that are known to have been gossiping valid shares. +// The peers are then returned on request using a round-robin algorithm to return a different peer each time. +// If no peers are found, the peer manager will rely on full nodes retrieved from discovery. +// +// The peer manager is only concerned with recent heights, thus it retrieves peers that +// were active since `initialHeight`. +// The peer manager will also garbage collect peers such that it blacklists peers that +// have been active since `initialHeight` but have been found to be invalid. +// +// The peer manager is passed to the shrex getter and is used at request time to +// select peers for a given data hash for data retrieval. +// +// # Usage +// +// The peer manager is created using [NewManager] constructor: +// +// peerManager := peers.NewManager(headerSub, shrexSub, discovery, host, connGater, opts...) 
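+//
+// If the manager should fall back to full nodes found via discovery, its UpdateNodePool
+// method can be registered as the discovery service's peer-update callback. A sketch;
+// the surrounding constructor arguments and the "full" tag are illustrative, not the
+// exact wiring used by nodebuilder:
+//
+//	disc, err := discovery.NewDiscovery(discParams, host, routingDisc, "full",
+//		discovery.WithOnPeersUpdate(peerManager.UpdateNodePool))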
+// +// After creating the peer manager, it should be started to kick off listening and +// validation routines that enable peer selection and retrieval: +// +// err := peerManager.Start(ctx) +// +// The peer manager can be stopped at any time to stop all peer discovery and validation routines: +// +// err := peerManager.Stop(ctx) +// +// The peer manager can be used to select peers for a given datahash for shares retrieval: +// +// peer, err := peerManager.Peer(ctx, hash) +package peers diff --git a/share/shwap/p2p/shrex/peers/manager.go b/share/shwap/p2p/shrex/peers/manager.go new file mode 100644 index 0000000000..ca85a85ea6 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/manager.go @@ -0,0 +1,526 @@ +package peers + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + logging "github.com/ipfs/go-log/v2" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/net/conngater" + + libhead "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" +) + +const ( + // ResultNoop indicates operation was successful and no extra action is required + ResultNoop result = "result_noop" + // ResultCooldownPeer will put returned peer on cooldown, meaning it won't be available by Peer + // method for some time + ResultCooldownPeer = "result_cooldown_peer" + // ResultBlacklistPeer will blacklist peer. Blacklisted peers will be disconnected and blocked from + // any p2p communication in future by libp2p Gater + ResultBlacklistPeer = "result_blacklist_peer" + + // eventbusBufSize is the size of the buffered channel to handle + // events in libp2p + eventbusBufSize = 32 + + // storedPoolsAmount is the amount of pools for recent headers that will be stored in the peer + // manager + storedPoolsAmount = 10 +) + +type result string + +var log = logging.Logger("shrex/peer-manager") + +// Manager keeps track of peers coming from shrex.Sub and from discovery +type Manager struct { + lock sync.Mutex + params Parameters + + // header subscription is necessary in order to Validate the inbound eds hash + headerSub libhead.Subscriber[*header.ExtendedHeader] + shrexSub *shrexsub.PubSub + host host.Host + connGater *conngater.BasicConnectionGater + + // pools collecting peers from shrexSub and stores them by datahash + pools map[string]*syncPool + + // initialHeight is the height of the first header received from headersub + initialHeight atomic.Uint64 + // messages from shrex.Sub with height below storeFrom will be ignored, since we don't need to + // track peers for those headers + storeFrom atomic.Uint64 + + // nodes collects nodes' peer.IDs found via discovery + nodes *pool + + // hashes that are not in the chain + blacklistedHashes map[string]bool + + metrics *metrics + + headerSubDone chan struct{} + disconnectedPeersDone chan struct{} + cancel context.CancelFunc +} + +// DoneFunc updates internal state depending on call results. 
Should be called once per returned +// peer from Peer method +type DoneFunc func(result) + +type syncPool struct { + *pool + + // isValidatedDataHash indicates if datahash was validated by receiving corresponding extended + // header from headerSub + isValidatedDataHash atomic.Bool + // height is the height of the header that corresponds to datahash + height uint64 + // createdAt is the syncPool creation time + createdAt time.Time +} + +func NewManager( + params Parameters, + host host.Host, + connGater *conngater.BasicConnectionGater, + options ...Option, +) (*Manager, error) { + if err := params.Validate(); err != nil { + return nil, err + } + + s := &Manager{ + params: params, + connGater: connGater, + host: host, + pools: make(map[string]*syncPool), + blacklistedHashes: make(map[string]bool), + headerSubDone: make(chan struct{}), + disconnectedPeersDone: make(chan struct{}), + } + + for _, opt := range options { + err := opt(s) + if err != nil { + return nil, err + } + } + + s.nodes = newPool(s.params.PeerCooldown) + return s, nil +} + +func (m *Manager) Start(startCtx context.Context) error { + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + + // pools will only be populated with senders of shrexsub notifications if the WithShrexSubPools + // option is used. + if m.shrexSub == nil && m.headerSub == nil { + return nil + } + + validatorFn := m.metrics.validationObserver(m.Validate) + err := m.shrexSub.AddValidator(validatorFn) + if err != nil { + return fmt.Errorf("registering validator: %w", err) + } + err = m.shrexSub.Start(startCtx) + if err != nil { + return fmt.Errorf("starting shrexsub: %w", err) + } + + headerSub, err := m.headerSub.Subscribe() + if err != nil { + return fmt.Errorf("subscribing to headersub: %w", err) + } + + sub, err := m.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}, eventbus.BufSize(eventbusBufSize)) + if err != nil { + return fmt.Errorf("subscribing to libp2p events: %w", err) + } + + go m.subscribeHeader(ctx, headerSub) + go m.subscribeDisconnectedPeers(ctx, sub) + go m.GC(ctx) + return nil +} + +func (m *Manager) Stop(ctx context.Context) error { + m.cancel() + + if err := m.metrics.close(); err != nil { + log.Warnw("closing metrics", "err", err) + } + + // we do not need to wait for headersub and disconnected peers to finish + // here, since they were never started + if m.headerSub == nil && m.shrexSub == nil { + return nil + } + + select { + case <-m.headerSubDone: + case <-ctx.Done(): + return ctx.Err() + } + + select { + case <-m.disconnectedPeersDone: + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} + +// Peer returns peer collected from shrex.Sub for given datahash if any available. +// If there is none, it will look for nodes collected from discovery. If there is no discovered +// nodes, it will wait until any peer appear in either source or timeout happen. 
+// After fetching data using given peer, caller is required to call returned DoneFunc using +// appropriate result value +func (m *Manager) Peer(ctx context.Context, datahash share.DataHash, height uint64, +) (peer.ID, DoneFunc, error) { + p := m.validatedPool(datahash.String(), height) + + // first, check if a peer is available for the given datahash + peerID, ok := p.tryGet() + if ok { + if m.removeIfUnreachable(p, peerID) { + return m.Peer(ctx, datahash, height) + } + return m.newPeer(ctx, datahash, peerID, sourceShrexSub, p.len(), 0) + } + + // if no peer for datahash is currently available, try to use node + // obtained from discovery + peerID, ok = m.nodes.tryGet() + if ok { + return m.newPeer(ctx, datahash, peerID, sourceFullNodes, m.nodes.len(), 0) + } + + // no peers are available right now, wait for the first one + start := time.Now() + select { + case peerID = <-p.next(ctx): + if m.removeIfUnreachable(p, peerID) { + return m.Peer(ctx, datahash, height) + } + return m.newPeer(ctx, datahash, peerID, sourceShrexSub, p.len(), time.Since(start)) + case peerID = <-m.nodes.next(ctx): + return m.newPeer(ctx, datahash, peerID, sourceFullNodes, m.nodes.len(), time.Since(start)) + case <-ctx.Done(): + return "", nil, ctx.Err() + } +} + +// UpdateNodePool is called by discovery when new node is discovered or removed. +func (m *Manager) UpdateNodePool(peerID peer.ID, isAdded bool) { + if isAdded { + if m.isBlacklistedPeer(peerID) { + log.Debugw("got blacklisted peer from discovery", "peer", peerID.String()) + return + } + m.nodes.add(peerID) + log.Debugw("added to discovered nodes pool", "peer", peerID) + return + } + + log.Debugw("removing peer from discovered nodes pool", "peer", peerID.String()) + m.nodes.remove(peerID) +} + +func (m *Manager) newPeer( + ctx context.Context, + datahash share.DataHash, + peerID peer.ID, + source peerSource, + poolSize int, + waitTime time.Duration, +) (peer.ID, DoneFunc, error) { + log.Debugw("got peer", + "hash", datahash.String(), + "peer", peerID.String(), + "source", source, + "pool_size", poolSize, + "wait (s)", waitTime) + m.metrics.observeGetPeer(ctx, source, poolSize, waitTime) + return peerID, m.doneFunc(datahash, peerID, source), nil +} + +func (m *Manager) doneFunc(datahash share.DataHash, peerID peer.ID, source peerSource) DoneFunc { + return func(result result) { + log.Debugw("set peer result", + "hash", datahash.String(), + "peer", peerID.String(), + "source", source, + "result", result) + m.metrics.observeDoneResult(source, result) + switch result { + case ResultNoop: + case ResultCooldownPeer: + if source == sourceFullNodes { + m.nodes.putOnCooldown(peerID) + return + } + m.getPool(datahash.String()).putOnCooldown(peerID) + case ResultBlacklistPeer: + m.blacklistPeers(reasonMisbehave, peerID) + } + } +} + +// subscribeHeader takes datahash from received header and validates corresponding peer pool. 
+func (m *Manager) subscribeHeader(ctx context.Context, headerSub libhead.Subscription[*header.ExtendedHeader]) { + defer close(m.headerSubDone) + defer headerSub.Cancel() + + for { + h, err := headerSub.NextHeader(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + log.Errorw("get next header from sub", "err", err) + continue + } + m.validatedPool(h.DataHash.String(), h.Height()) + + // store first header for validation purposes + if m.initialHeight.CompareAndSwap(0, h.Height()) { + log.Debugw("stored initial height", "height", h.Height()) + } + + // update storeFrom if header height + m.storeFrom.Store(uint64(max(0, int(h.Height())-storedPoolsAmount))) + log.Debugw("updated lowest stored height", "height", h.Height()) + } +} + +// subscribeDisconnectedPeers subscribes to libp2p connectivity events and removes disconnected +// peers from nodes pool +func (m *Manager) subscribeDisconnectedPeers(ctx context.Context, sub event.Subscription) { + defer close(m.disconnectedPeersDone) + defer sub.Close() + for { + select { + case <-ctx.Done(): + return + case e, ok := <-sub.Out(): + if !ok { + log.Fatal("Subscription for connectedness events is closed.") //nolint:gocritic + return + } + // listen to disconnect event to remove peer from nodes pool + connStatus := e.(event.EvtPeerConnectednessChanged) + if connStatus.Connectedness == network.NotConnected { + peer := connStatus.Peer + if m.nodes.has(peer) { + log.Debugw("peer disconnected, removing from discovered nodes pool", + "peer", peer.String()) + m.nodes.remove(peer) + } + } + } + } +} + +// Validate will collect peer.ID into corresponding peer pool +func (m *Manager) Validate(_ context.Context, peerID peer.ID, msg shrexsub.Notification) pubsub.ValidationResult { + logger := log.With("peer", peerID.String(), "hash", msg.DataHash.String()) + + // messages broadcast from self should bypass the validation with Accept + if peerID == m.host.ID() { + logger.Debug("received datahash from self") + return pubsub.ValidationAccept + } + + // punish peer for sending invalid hash if it has misbehaved in the past + if m.isBlacklistedHash(msg.DataHash) { + logger.Debug("received blacklisted hash, reject validation") + return pubsub.ValidationReject + } + + if m.isBlacklistedPeer(peerID) { + logger.Debug("received message from blacklisted peer, reject validation") + return pubsub.ValidationReject + } + + if msg.Height < m.storeFrom.Load() { + logger.Debug("received message for past header") + return pubsub.ValidationIgnore + } + + p := m.getOrCreatePool(msg.DataHash.String(), msg.Height) + logger.Debugw("got hash from shrex-sub") + + p.add(peerID) + if p.isValidatedDataHash.Load() { + // add peer to discovered nodes pool only if datahash has been already validated + m.nodes.add(peerID) + } + return pubsub.ValidationIgnore +} + +func (m *Manager) getPool(datahash string) *syncPool { + m.lock.Lock() + defer m.lock.Unlock() + return m.pools[datahash] +} + +func (m *Manager) getOrCreatePool(datahash string, height uint64) *syncPool { + m.lock.Lock() + defer m.lock.Unlock() + + p, ok := m.pools[datahash] + if !ok { + p = &syncPool{ + height: height, + pool: newPool(m.params.PeerCooldown), + createdAt: time.Now(), + } + m.pools[datahash] = p + } + + return p +} + +func (m *Manager) blacklistPeers(reason blacklistPeerReason, peerIDs ...peer.ID) { + m.metrics.observeBlacklistPeers(reason, len(peerIDs)) + + for _, peerID := range peerIDs { + // blacklisted peers will be logged regardless of EnableBlackListing whether option being is + // 
enabled, until blacklisting is not properly tested and enabled by default. + log.Debugw("blacklisting peer", "peer", peerID.String(), "reason", reason) + if !m.params.EnableBlackListing { + continue + } + + m.nodes.remove(peerID) + // add peer to the blacklist, so we can't connect to it in the future. + err := m.connGater.BlockPeer(peerID) + if err != nil { + log.Warnw("failed to block peer", "peer", peerID, "err", err) + } + // close connections to peer. + err = m.host.Network().ClosePeer(peerID) + if err != nil { + log.Warnw("failed to close connection with peer", "peer", peerID, "err", err) + } + } +} + +func (m *Manager) isBlacklistedPeer(peerID peer.ID) bool { + return !m.connGater.InterceptPeerDial(peerID) +} + +func (m *Manager) isBlacklistedHash(hash share.DataHash) bool { + m.lock.Lock() + defer m.lock.Unlock() + return m.blacklistedHashes[hash.String()] +} + +func (m *Manager) validatedPool(hashStr string, height uint64) *syncPool { + p := m.getOrCreatePool(hashStr, height) + if p.isValidatedDataHash.CompareAndSwap(false, true) { + log.Debugw("pool marked validated", "datahash", hashStr) + // if pool is proven to be valid, add all collected peers to discovered nodes + m.nodes.add(p.peers()...) + } + return p +} + +// removeIfUnreachable removes peer from some pool if it is blacklisted or disconnected +func (m *Manager) removeIfUnreachable(pool *syncPool, peerID peer.ID) bool { + if m.isBlacklistedPeer(peerID) || !m.nodes.has(peerID) { + log.Debugw("removing outdated peer from pool", "peer", peerID.String()) + pool.remove(peerID) + return true + } + return false +} + +func (m *Manager) GC(ctx context.Context) { + ticker := time.NewTicker(m.params.GcInterval) + defer ticker.Stop() + + var blacklist []peer.ID + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + + blacklist = m.cleanUp() + if len(blacklist) > 0 { + m.blacklistPeers(reasonInvalidHash, blacklist...) 
+ } + } +} + +func (m *Manager) cleanUp() []peer.ID { + if m.initialHeight.Load() == 0 { + // can't blacklist peers until initialHeight is set + return nil + } + + m.lock.Lock() + defer m.lock.Unlock() + + addToBlackList := make(map[peer.ID]struct{}) + for h, p := range m.pools { + if p.isValidatedDataHash.Load() { + // remove pools that are outdated + if p.height < m.storeFrom.Load() { + delete(m.pools, h) + } + continue + } + + // can't validate datahashes below initial height + if p.height < m.initialHeight.Load() { + delete(m.pools, h) + continue + } + + // find pools that are not validated in time + if time.Since(p.createdAt) > m.params.PoolValidationTimeout { + delete(m.pools, h) + + log.Debug("blacklisting datahash with all corresponding peers", + "hash", h, + "peer_list", p.peersList) + // blacklist hash + m.blacklistedHashes[h] = true + + // blacklist peers + for _, peer := range p.peersList { + addToBlackList[peer] = struct{}{} + } + } + } + + blacklist := make([]peer.ID, 0, len(addToBlackList)) + for peerID := range addToBlackList { + blacklist = append(blacklist, peerID) + } + return blacklist +} diff --git a/share/shwap/p2p/shrex/peers/manager_test.go b/share/shwap/p2p/shrex/peers/manager_test.go new file mode 100644 index 0000000000..c7193d8780 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/manager_test.go @@ -0,0 +1,569 @@ +package peers + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/ipfs/go-datastore" + dssync "github.com/ipfs/go-datastore/sync" + dht "github.com/libp2p/go-libp2p-kad-dht" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + routingdisc "github.com/libp2p/go-libp2p/p2p/discovery/routing" + "github.com/libp2p/go-libp2p/p2p/net/conngater" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/rand" + + libhead "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" + discovery2 "github.com/celestiaorg/celestia-node/share/shwap/p2p/discovery" +) + +func TestManager(t *testing.T) { + t.Run("Validate pool by headerSub", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // wait until header is requested from header sub + err = headerSub.wait(ctx, 1) + require.NoError(t, err) + + // check validation + require.True(t, manager.pools[h.DataHash.String()].isValidatedDataHash.Load()) + stopManager(t, manager) + }) + + t.Run("Validate pool by shrex.Getter", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + peerID, msg := peer.ID("peer1"), newShrexSubMsg(h) + result := manager.Validate(ctx, peerID, msg) + require.Equal(t, pubsub.ValidationIgnore, result) + + pID, _, err := manager.Peer(ctx, h.DataHash.Bytes(), h.Height()) + require.NoError(t, err) + require.Equal(t, peerID, pID) + + // check pool validation + require.True(t, manager.getPool(h.DataHash.String()).isValidatedDataHash.Load()) + }) + 
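+	// The subtests above discard the DoneFunc returned by Peer. Outside of tests, callers are
+	// expected to report the retrieval outcome exactly once per obtained peer, roughly
+	// (illustrative sketch; the result value depends on how the retrieval went):
+	//
+	//	pID, done, err := manager.Peer(ctx, h.DataHash.Bytes(), h.Height())
+	//	if err == nil {
+	//		defer done(ResultNoop)
+	//	}
+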
+ t.Run("validator", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // own messages should be accepted + msg := newShrexSubMsg(h) + result := manager.Validate(ctx, manager.host.ID(), msg) + require.Equal(t, pubsub.ValidationAccept, result) + + // normal messages should be ignored + peerID := peer.ID("peer1") + result = manager.Validate(ctx, peerID, msg) + require.Equal(t, pubsub.ValidationIgnore, result) + + // mark peer as misbehaved to blacklist it + pID, done, err := manager.Peer(ctx, h.DataHash.Bytes(), h.Height()) + require.NoError(t, err) + require.Equal(t, peerID, pID) + manager.params.EnableBlackListing = true + done(ResultBlacklistPeer) + + // new messages from misbehaved peer should be Rejected + result = manager.Validate(ctx, pID, msg) + require.Equal(t, pubsub.ValidationReject, result) + + stopManager(t, manager) + }) + + t.Run("cleanup", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + require.NoError(t, headerSub.wait(ctx, 1)) + + // set syncTimeout to 0 to allow cleanup to find outdated datahash + manager.params.PoolValidationTimeout = 0 + + // create unvalidated pool + peerID := peer.ID("peer1") + msg := shrexsub.Notification{ + DataHash: share.DataHash("datahash1datahash1datahash1datahash1datahash1"), + Height: 2, + } + manager.Validate(ctx, peerID, msg) + + // create validated pool + validDataHash := share.DataHash("datahash2") + manager.nodes.add("full") // add FN to unblock Peer call + manager.Peer(ctx, validDataHash, h.Height()) //nolint:errcheck + require.Len(t, manager.pools, 3) + + // trigger cleanup + blacklisted := manager.cleanUp() + require.Contains(t, blacklisted, peerID) + require.Len(t, manager.pools, 2) + + // messages with blacklisted hash should be rejected right away + peerID2 := peer.ID("peer2") + result := manager.Validate(ctx, peerID2, msg) + require.Equal(t, pubsub.ValidationReject, result) + + // check blacklisted pools + require.True(t, manager.isBlacklistedHash(msg.DataHash)) + require.False(t, manager.isBlacklistedHash(validDataHash)) + }) + + t.Run("no peers from shrex.Sub, get from discovery", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // add peers to fullnodes, imitating discovery add + peers := []peer.ID{"peer1", "peer2", "peer3"} + manager.nodes.add(peers...) + + peerID, _, err := manager.Peer(ctx, h.DataHash.Bytes(), h.Height()) + require.NoError(t, err) + require.Contains(t, peers, peerID) + + stopManager(t, manager) + }) + + t.Run("no peers from shrex.Sub and from discovery. 
Wait", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // make sure peers are not returned before timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + t.Cleanup(cancel) + _, _, err = manager.Peer(timeoutCtx, h.DataHash.Bytes(), h.Height()) + require.ErrorIs(t, err, context.DeadlineExceeded) + + peers := []peer.ID{"peer1", "peer2", "peer3"} + + // launch wait routine + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + peerID, _, err := manager.Peer(ctx, h.DataHash.Bytes(), h.Height()) + require.NoError(t, err) + require.Contains(t, peers, peerID) + }() + + // send peers + manager.nodes.add(peers...) + + // wait for peer to be received + select { + case <-doneCh: + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + + stopManager(t, manager) + }) + + t.Run("shrexSub sends a message lower than first headerSub header height, headerSub first", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + h := testHeader() + h.RawHeader.Height = 100 + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // unlock headerSub to read first header + require.NoError(t, headerSub.wait(ctx, 1)) + // pool will be created for first headerSub header datahash + require.Len(t, manager.pools, 1) + + // create shrexSub msg with height lower than first header from headerSub + msg := shrexsub.Notification{ + DataHash: share.DataHash("datahash"), + Height: h.Height() - 1, + } + result := manager.Validate(ctx, "peer", msg) + require.Equal(t, pubsub.ValidationIgnore, result) + // pool will be created for first shrexSub message + require.Len(t, manager.pools, 2) + + blacklisted := manager.cleanUp() + require.Empty(t, blacklisted) + // trigger cleanup and outdated pool should be removed + require.Len(t, manager.pools, 1) + }) + + t.Run("shrexSub sends a message lower than first headerSub header height, shrexSub first", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + h := testHeader() + h.RawHeader.Height = 100 + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // create shrexSub msg with height lower than first header from headerSub + msg := shrexsub.Notification{ + DataHash: share.DataHash("datahash"), + Height: h.Height() - 1, + } + result := manager.Validate(ctx, "peer", msg) + require.Equal(t, pubsub.ValidationIgnore, result) + + // pool will be created for first shrexSub message + require.Len(t, manager.pools, 1) + + // unlock headerSub to allow it to send next message + require.NoError(t, headerSub.wait(ctx, 1)) + // second pool should be created + require.Len(t, manager.pools, 2) + + // trigger cleanup and outdated pool should be removed + blacklisted := manager.cleanUp() + require.Len(t, manager.pools, 1) + + // check that no peers or hashes were blacklisted + manager.params.PoolValidationTimeout = 0 + require.Len(t, blacklisted, 0) + require.Len(t, manager.blacklistedHashes, 0) + }) + + t.Run("pools store window", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + 
t.Cleanup(cancel) + + h := testHeader() + h.RawHeader.Height = storedPoolsAmount * 2 + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // unlock headerSub to read first header + require.NoError(t, headerSub.wait(ctx, 1)) + // pool will be created for first headerSub header datahash + require.Len(t, manager.pools, 1) + + // create shrexSub msg with height lower than storedPoolsAmount + msg := shrexsub.Notification{ + DataHash: share.DataHash("datahash"), + Height: h.Height() - storedPoolsAmount - 3, + } + result := manager.Validate(ctx, "peer", msg) + require.Equal(t, pubsub.ValidationIgnore, result) + + // shrexSub message should be discarded and amount of pools should not change + require.Len(t, manager.pools, 1) + }) +} + +func TestIntegration(t *testing.T) { + t.Run("get peer from shrexsub", func(t *testing.T) { + nw, err := mocknet.FullMeshLinked(2) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + bnPubSub, err := shrexsub.NewPubSub(ctx, nw.Hosts()[0], "test") + require.NoError(t, err) + + fnPubSub, err := shrexsub.NewPubSub(ctx, nw.Hosts()[1], "test") + require.NoError(t, err) + + require.NoError(t, bnPubSub.Start(ctx)) + require.NoError(t, fnPubSub.Start(ctx)) + + fnPeerManager, err := testManager(ctx, newSubLock()) + require.NoError(t, err) + fnPeerManager.host = nw.Hosts()[1] + + require.NoError(t, fnPubSub.AddValidator(fnPeerManager.Validate)) + _, err = fnPubSub.Subscribe() + require.NoError(t, err) + + time.Sleep(time.Millisecond * 100) + require.NoError(t, nw.ConnectAllButSelf()) + time.Sleep(time.Millisecond * 100) + + // broadcast from BN + randHash := rand.Bytes(32) + require.NoError(t, bnPubSub.Broadcast(ctx, shrexsub.Notification{ + DataHash: randHash, + Height: 1, + })) + + // FN should get message + gotPeer, _, err := fnPeerManager.Peer(ctx, randHash, 13) + require.NoError(t, err) + + // check that gotPeer matched bridge node + require.Equal(t, nw.Hosts()[0].ID(), gotPeer) + }) + + t.Run("get peer from discovery", func(t *testing.T) { + fullNodesTag := "fullNodes" + nw, err := mocknet.FullMeshConnected(3) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(cancel) + + // set up bootstrapper + bsHost := nw.Hosts()[0] + bs := host.InfoFromHost(bsHost) + opts := []dht.Option{ + dht.Mode(dht.ModeAuto), + dht.BootstrapPeers(*bs), + dht.RoutingTableRefreshPeriod(time.Second), + } + + bsOpts := opts + bsOpts = append(bsOpts, + dht.Mode(dht.ModeServer), // it must accept incoming connections + dht.BootstrapPeers(), // no bootstrappers for a bootstrapper ¯\_(ツ)_/¯ + ) + bsRouter, err := dht.New(ctx, bsHost, bsOpts...) + require.NoError(t, err) + require.NoError(t, bsRouter.Bootstrap(ctx)) + + // set up broadcaster node + bnHost := nw.Hosts()[1] + bnRouter, err := dht.New(ctx, bnHost, opts...) + require.NoError(t, err) + + params := discovery2.DefaultParameters() + params.AdvertiseInterval = time.Second + + bnDisc, err := discovery2.NewDiscovery( + params, + bnHost, + routingdisc.NewRoutingDiscovery(bnRouter), + fullNodesTag, + ) + require.NoError(t, err) + + // set up full node / receiver node + fnHost := nw.Hosts()[2] + fnRouter, err := dht.New(ctx, fnHost, opts...) 
+ require.NoError(t, err) + + // init peer manager for full node + connGater, err := conngater.NewBasicConnectionGater(dssync.MutexWrap(datastore.NewMapDatastore())) + require.NoError(t, err) + fnPeerManager, err := NewManager( + DefaultParameters(), + nil, + connGater, + ) + require.NoError(t, err) + + waitCh := make(chan struct{}) + checkDiscoveredPeer := func(peerID peer.ID, isAdded bool) { + defer close(waitCh) + // check that obtained peer id is BN + require.Equal(t, bnHost.ID(), peerID) + } + + // set up discovery for full node with hook to peer manager and check discovered peer + params = discovery2.DefaultParameters() + params.AdvertiseInterval = time.Second + params.PeersLimit = 10 + + fnDisc, err := discovery2.NewDiscovery( + params, + fnHost, + routingdisc.NewRoutingDiscovery(fnRouter), + fullNodesTag, + discovery2.WithOnPeersUpdate(fnPeerManager.UpdateNodePool), + discovery2.WithOnPeersUpdate(checkDiscoveredPeer), + ) + require.NoError(t, fnDisc.Start(ctx)) + t.Cleanup(func() { + err = fnDisc.Stop(ctx) + require.NoError(t, err) + }) + + require.NoError(t, bnRouter.Bootstrap(ctx)) + require.NoError(t, fnRouter.Bootstrap(ctx)) + + go bnDisc.Advertise(ctx) + + select { + case <-waitCh: + require.Contains(t, fnPeerManager.nodes.peersList, bnHost.ID()) + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + }) +} + +func testManager(ctx context.Context, headerSub libhead.Subscriber[*header.ExtendedHeader]) (*Manager, error) { + host, err := mocknet.New().GenPeer() + if err != nil { + return nil, err + } + shrexSub, err := shrexsub.NewPubSub(ctx, host, "test") + if err != nil { + return nil, err + } + + connGater, err := conngater.NewBasicConnectionGater(dssync.MutexWrap(datastore.NewMapDatastore())) + if err != nil { + return nil, err + } + manager, err := NewManager( + DefaultParameters(), + host, + connGater, + WithShrexSubPools(shrexSub, headerSub), + ) + if err != nil { + return nil, err + } + err = manager.Start(ctx) + return manager, err +} + +func stopManager(t *testing.T, m *Manager) { + closeCtx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + require.NoError(t, m.Stop(closeCtx)) +} + +func testHeader() *header.ExtendedHeader { + return &header.ExtendedHeader{ + RawHeader: header.RawHeader{ + Height: 1, + DataHash: rand.Bytes(32), + }, + } +} + +type subLock struct { + next chan struct{} + wg *sync.WaitGroup + expected []*header.ExtendedHeader +} + +func (s subLock) wait(ctx context.Context, count int) error { + s.wg.Add(count) + for i := 0; i < count; i++ { + err := s.release(ctx) + if err != nil { + return err + } + } + s.wg.Wait() + return nil +} + +func (s subLock) release(ctx context.Context) error { + select { + case s.next <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func newSubLock(expected ...*header.ExtendedHeader) *subLock { + wg := &sync.WaitGroup{} + wg.Add(1) + return &subLock{ + next: make(chan struct{}), + expected: expected, + wg: wg, + } +} + +func (s *subLock) Subscribe() (libhead.Subscription[*header.ExtendedHeader], error) { + return s, nil +} + +func (s *subLock) SetVerifier(func(context.Context, *header.ExtendedHeader) error) error { + panic("implement me") +} + +func (s *subLock) NextHeader(ctx context.Context) (*header.ExtendedHeader, error) { + s.wg.Done() + + // wait for call to be unlocked by release + select { + case <-s.next: + h := s.expected[0] + s.expected = s.expected[1:] + return h, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (s *subLock) 
Cancel() { +} + +func newShrexSubMsg(h *header.ExtendedHeader) shrexsub.Notification { + return shrexsub.Notification{ + DataHash: h.DataHash.Bytes(), + Height: h.Height(), + } +} diff --git a/share/shwap/p2p/shrex/peers/metrics.go b/share/shwap/p2p/shrex/peers/metrics.go new file mode 100644 index 0000000000..da52856425 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/metrics.go @@ -0,0 +1,276 @@ +package peers + +import ( + "context" + "fmt" + "sync" + "time" + + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/peer" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/celestiaorg/celestia-node/libs/utils" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" +) + +const ( + isInstantKey = "is_instant" + doneResultKey = "done_result" + + sourceKey = "source" + sourceShrexSub peerSource = "shrexsub" + sourceFullNodes peerSource = "full_nodes" + + blacklistPeerReasonKey = "blacklist_reason" + reasonInvalidHash blacklistPeerReason = "invalid_hash" + reasonMisbehave blacklistPeerReason = "misbehave" + + validationResultKey = "validation_result" + validationAccept = "accept" + validationReject = "reject" + validationIgnore = "ignore" + + peerStatusKey = "peer_status" + peerStatusActive peerStatus = "active" + peerStatusCooldown peerStatus = "cooldown" + + poolStatusKey = "pool_status" + poolStatusCreated poolStatus = "created" + poolStatusValidated poolStatus = "validated" + poolStatusBlacklisted poolStatus = "blacklisted" + // Pool status model: + // created(unvalidated) + // / \ + // validated blacklisted +) + +var meter = otel.Meter("shrex_peer_manager") + +type blacklistPeerReason string + +type peerStatus string + +type poolStatus string + +type peerSource string + +type metrics struct { + getPeer metric.Int64Counter // attributes: source, is_instant + getPeerWaitTimeHistogram metric.Int64Histogram // attributes: source + getPeerPoolSizeHistogram metric.Int64Histogram // attributes: source + doneResult metric.Int64Counter // attributes: source, done_result + validationResult metric.Int64Counter // attributes: validation_result + + shrexPools metric.Int64ObservableGauge // attributes: pool_status + fullNodesPool metric.Int64ObservableGauge // attributes: pool_status + blacklistedPeersByReason sync.Map + blacklistedPeers metric.Int64ObservableGauge // attributes: blacklist_reason + + clientReg metric.Registration +} + +func initMetrics(manager *Manager) (*metrics, error) { + getPeer, err := meter.Int64Counter("peer_manager_get_peer_counter", + metric.WithDescription("get peer counter")) + if err != nil { + return nil, err + } + + getPeerWaitTimeHistogram, err := meter.Int64Histogram("peer_manager_get_peer_ms_time_hist", + metric.WithDescription("get peer time histogram(ms), observed only for async get(is_instant = false)")) + if err != nil { + return nil, err + } + + getPeerPoolSizeHistogram, err := meter.Int64Histogram("peer_manager_get_peer_pool_size_hist", + metric.WithDescription("amount of available active peers in pool at time when get was called")) + if err != nil { + return nil, err + } + + doneResult, err := meter.Int64Counter("peer_manager_done_result_counter", + metric.WithDescription("done results counter")) + if err != nil { + return nil, err + } + + validationResult, err := meter.Int64Counter("peer_manager_validation_result_counter", + metric.WithDescription("validation result counter")) + if err != nil { + return nil, err + } + + shrexPools, err := 
meter.Int64ObservableGauge("peer_manager_pools_gauge", + metric.WithDescription("pools amount")) + if err != nil { + return nil, err + } + + fullNodesPool, err := meter.Int64ObservableGauge("peer_manager_full_nodes_gauge", + metric.WithDescription("full nodes pool peers amount")) + if err != nil { + return nil, err + } + + blacklisted, err := meter.Int64ObservableGauge("peer_manager_blacklisted_peers", + metric.WithDescription("blacklisted peers amount")) + if err != nil { + return nil, err + } + + metrics := &metrics{ + getPeer: getPeer, + getPeerWaitTimeHistogram: getPeerWaitTimeHistogram, + doneResult: doneResult, + validationResult: validationResult, + shrexPools: shrexPools, + fullNodesPool: fullNodesPool, + getPeerPoolSizeHistogram: getPeerPoolSizeHistogram, + blacklistedPeers: blacklisted, + } + + callback := func(_ context.Context, observer metric.Observer) error { + for poolStatus, count := range manager.shrexPools() { + observer.ObserveInt64(shrexPools, count, + metric.WithAttributes( + attribute.String(poolStatusKey, string(poolStatus)))) + } + + observer.ObserveInt64(fullNodesPool, int64(manager.nodes.len()), + metric.WithAttributes( + attribute.String(peerStatusKey, string(peerStatusActive)))) + observer.ObserveInt64(fullNodesPool, int64(manager.nodes.cooldown.len()), + metric.WithAttributes( + attribute.String(peerStatusKey, string(peerStatusCooldown)))) + + metrics.blacklistedPeersByReason.Range(func(key, value any) bool { + reason := key.(blacklistPeerReason) + amount := value.(int) + observer.ObserveInt64(blacklisted, int64(amount), + metric.WithAttributes( + attribute.String(blacklistPeerReasonKey, string(reason)))) + return true + }) + return nil + } + metrics.clientReg, err = meter.RegisterCallback(callback, shrexPools, fullNodesPool, blacklisted) + if err != nil { + return nil, fmt.Errorf("registering metrics callback: %w", err) + } + return metrics, nil +} + +func (m *metrics) close() error { + if m == nil { + return nil + } + return m.clientReg.Unregister() +} + +func (m *metrics) observeGetPeer( + ctx context.Context, + source peerSource, poolSize int, waitTime time.Duration, +) { + if m == nil { + return + } + ctx = utils.ResetContextOnError(ctx) + m.getPeer.Add(ctx, 1, + metric.WithAttributes( + attribute.String(sourceKey, string(source)), + attribute.Bool(isInstantKey, waitTime == 0))) + if source == sourceShrexSub { + m.getPeerPoolSizeHistogram.Record(ctx, int64(poolSize), + metric.WithAttributes( + attribute.String(sourceKey, string(source)))) + } + + // record wait time only for async gets + if waitTime > 0 { + m.getPeerWaitTimeHistogram.Record(ctx, waitTime.Milliseconds(), + metric.WithAttributes( + attribute.String(sourceKey, string(source)))) + } +} + +func (m *metrics) observeDoneResult(source peerSource, result result) { + if m == nil { + return + } + + ctx := context.Background() + m.doneResult.Add(ctx, 1, + metric.WithAttributes( + attribute.String(sourceKey, string(source)), + attribute.String(doneResultKey, string(result)))) +} + +// validationObserver is a middleware that observes validation results as metrics +func (m *metrics) validationObserver(validator shrexsub.ValidatorFn) shrexsub.ValidatorFn { + if m == nil { + return validator + } + return func(ctx context.Context, id peer.ID, n shrexsub.Notification) pubsub.ValidationResult { + res := validator(ctx, id, n) + + var resStr string + switch res { + case pubsub.ValidationAccept: + resStr = validationAccept + case pubsub.ValidationReject: + resStr = validationReject + case 
pubsub.ValidationIgnore: + resStr = validationIgnore + default: + resStr = "unknown" + } + + ctx = utils.ResetContextOnError(ctx) + + m.validationResult.Add(ctx, 1, + metric.WithAttributes( + attribute.String(validationResultKey, resStr))) + return res + } +} + +// observeBlacklistPeers stores amount of blacklisted peers by reason +func (m *metrics) observeBlacklistPeers(reason blacklistPeerReason, amount int) { + if m == nil { + return + } + for { + prevVal, loaded := m.blacklistedPeersByReason.LoadOrStore(reason, amount) + if !loaded { + return + } + + newVal := prevVal.(int) + amount + if m.blacklistedPeersByReason.CompareAndSwap(reason, prevVal, newVal) { + return + } + } +} + +// shrexPools collects amount of shrex pools by poolStatus +func (m *Manager) shrexPools() map[poolStatus]int64 { + m.lock.Lock() + defer m.lock.Unlock() + + shrexPools := make(map[poolStatus]int64) + for _, p := range m.pools { + if !p.isValidatedDataHash.Load() { + shrexPools[poolStatusCreated]++ + continue + } + + // pool is validated but not synced + shrexPools[poolStatusValidated]++ + } + + shrexPools[poolStatusBlacklisted] = int64(len(m.blacklistedHashes)) + return shrexPools +} diff --git a/share/shwap/p2p/shrex/peers/options.go b/share/shwap/p2p/shrex/peers/options.go new file mode 100644 index 0000000000..2970dd2465 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/options.go @@ -0,0 +1,84 @@ +package peers + +import ( + "fmt" + "time" + + libhead "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" +) + +type Parameters struct { + // PoolValidationTimeout is the timeout used for validating incoming datahashes. Pools that have + // been created for datahashes from shrexsub that do not see this hash from headersub after this + // timeout will be garbage collected. + PoolValidationTimeout time.Duration + + // PeerCooldown is the time a peer is put on cooldown after a ResultCooldownPeer. + PeerCooldown time.Duration + + // GcInterval is the interval at which the manager will garbage collect unvalidated pools. + GcInterval time.Duration + + // EnableBlackListing turns on blacklisting for misbehaved peers + EnableBlackListing bool +} + +type Option func(*Manager) error + +// Validate validates the values in Parameters +func (p *Parameters) Validate() error { + if p.PoolValidationTimeout <= 0 { + return fmt.Errorf("peer-manager: validation timeout must be positive") + } + + if p.PeerCooldown <= 0 { + return fmt.Errorf("peer-manager: peer cooldown must be positive") + } + + if p.GcInterval <= 0 { + return fmt.Errorf("peer-manager: garbage collection interval must be positive") + } + + return nil +} + +// DefaultParameters returns the default configuration values for the peer manager parameters +func DefaultParameters() Parameters { + return Parameters{ + // PoolValidationTimeout's default value is based on the default daser sampling timeout of 1 minute. + // If a received datahash has not tried to be sampled within these two minutes, the pool will be + // removed. + PoolValidationTimeout: 2 * time.Minute, + // PeerCooldown's default value is based on initial network tests that showed a ~3.5 second + // sync time for large blocks. This value gives our (discovery) peers enough time to sync + // the new block before we ask them again. 
+ PeerCooldown: 3 * time.Second, + GcInterval: time.Second * 30, + // blacklisting is off by default //TODO(@walldiss): enable blacklisting once all related issues + // are resolved + EnableBlackListing: false, + } +} + +// WithShrexSubPools passes a shrexsub and headersub instance to be used to populate and validate +// pools from shrexsub notifications. +func WithShrexSubPools(shrexSub *shrexsub.PubSub, headerSub libhead.Subscriber[*header.ExtendedHeader]) Option { + return func(m *Manager) error { + m.shrexSub = shrexSub + m.headerSub = headerSub + return nil + } +} + +// WithMetrics turns on metric collection in peer manager. +func (m *Manager) WithMetrics() error { + metrics, err := initMetrics(m) + if err != nil { + return fmt.Errorf("peer-manager: init metrics: %w", err) + } + m.metrics = metrics + return nil +} diff --git a/share/shwap/p2p/shrex/peers/pool.go b/share/shwap/p2p/shrex/peers/pool.go new file mode 100644 index 0000000000..365ef0306d --- /dev/null +++ b/share/shwap/p2p/shrex/peers/pool.go @@ -0,0 +1,226 @@ +package peers + +import ( + "context" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +const defaultCleanupThreshold = 2 + +// pool stores peers and provides methods for simple round-robin access. +type pool struct { + m sync.RWMutex + peersList []peer.ID + statuses map[peer.ID]status + cooldown *timedQueue + activeCount int + nextIdx int + + hasPeer bool + hasPeerCh chan struct{} + + cleanupThreshold int +} + +type status int + +const ( + active status = iota + cooldown + removed +) + +// newPool returns new empty pool. +func newPool(peerCooldownTime time.Duration) *pool { + p := &pool{ + peersList: make([]peer.ID, 0), + statuses: make(map[peer.ID]status), + hasPeerCh: make(chan struct{}), + cleanupThreshold: defaultCleanupThreshold, + } + p.cooldown = newTimedQueue(peerCooldownTime, p.afterCooldown) + return p +} + +// tryGet returns peer along with bool flag indicating success of operation. +func (p *pool) tryGet() (peer.ID, bool) { + p.m.Lock() + defer p.m.Unlock() + + if p.activeCount == 0 { + return "", false + } + + // if pointer is out of range, point to first element + if p.nextIdx > len(p.peersList)-1 { + p.nextIdx = 0 + } + + start := p.nextIdx + for { + peerID := p.peersList[p.nextIdx] + + p.nextIdx++ + if p.nextIdx == len(p.peersList) { + p.nextIdx = 0 + } + + if p.statuses[peerID] == active { + return peerID, true + } + + // full circle passed + if p.nextIdx == start { + return "", false + } + } +} + +// next sends a peer to the returned channel when it becomes available. 
+func (p *pool) next(ctx context.Context) <-chan peer.ID { + peerCh := make(chan peer.ID, 1) + go func() { + for { + if peerID, ok := p.tryGet(); ok { + peerCh <- peerID + return + } + + p.m.RLock() + hasPeerCh := p.hasPeerCh + p.m.RUnlock() + select { + case <-hasPeerCh: + case <-ctx.Done(): + return + } + } + }() + return peerCh +} + +func (p *pool) add(peers ...peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + for _, peerID := range peers { + status, ok := p.statuses[peerID] + if ok && status != removed { + continue + } + + if !ok { + p.peersList = append(p.peersList, peerID) + } + + p.statuses[peerID] = active + p.activeCount++ + } + p.checkHasPeers() +} + +func (p *pool) remove(peers ...peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + for _, peerID := range peers { + if status, ok := p.statuses[peerID]; ok && status != removed { + p.statuses[peerID] = removed + if status == active { + p.activeCount-- + } + } + } + + // do cleanup if too much garbage + if len(p.peersList) >= p.activeCount+p.cleanupThreshold { + p.cleanup() + } + p.checkHasPeers() +} + +func (p *pool) has(peer peer.ID) bool { + p.m.RLock() + defer p.m.RUnlock() + + status, ok := p.statuses[peer] + return ok && status != removed +} + +func (p *pool) peers() []peer.ID { + p.m.RLock() + defer p.m.RUnlock() + + peers := make([]peer.ID, 0, len(p.peersList)) + for peer, status := range p.statuses { + if status != removed { + peers = append(peers, peer) + } + } + return peers +} + +// cleanup will reduce memory footprint of pool. +func (p *pool) cleanup() { + newList := make([]peer.ID, 0, p.activeCount) + for _, peerID := range p.peersList { + status := p.statuses[peerID] + switch status { + case active, cooldown: + newList = append(newList, peerID) + case removed: + delete(p.statuses, peerID) + } + } + p.peersList = newList +} + +func (p *pool) putOnCooldown(peerID peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + if status, ok := p.statuses[peerID]; ok && status == active { + p.cooldown.push(peerID) + + p.statuses[peerID] = cooldown + p.activeCount-- + p.checkHasPeers() + } +} + +func (p *pool) afterCooldown(peerID peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + // item could have been already removed by the time afterCooldown is called + if status, ok := p.statuses[peerID]; !ok || status != cooldown { + return + } + + p.statuses[peerID] = active + p.activeCount++ + p.checkHasPeers() +} + +// checkHasPeers will check and indicate if there are peers in the pool. +func (p *pool) checkHasPeers() { + if p.activeCount > 0 && !p.hasPeer { + p.hasPeer = true + close(p.hasPeerCh) + return + } + + if p.activeCount == 0 && p.hasPeer { + p.hasPeerCh = make(chan struct{}) + p.hasPeer = false + } +} + +func (p *pool) len() int { + p.m.RLock() + defer p.m.RUnlock() + return p.activeCount +} diff --git a/share/shwap/p2p/shrex/peers/pool_test.go b/share/shwap/p2p/shrex/peers/pool_test.go new file mode 100644 index 0000000000..ac9d38f261 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/pool_test.go @@ -0,0 +1,184 @@ +package peers + +import ( + "context" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestPool(t *testing.T) { + t.Run("add / remove peers", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer1", "peer2", "peer3"} + // adding same peer twice should not produce copies + p.add(peers...) 
+ require.Equal(t, len(peers)-1, p.activeCount) + + p.remove("peer1", "peer2") + require.Equal(t, len(peers)-3, p.activeCount) + + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peers[3], peerID) + + p.remove("peer3") + p.remove("peer3") + require.Equal(t, 0, p.activeCount) + _, ok = p.tryGet() + require.False(t, ok) + }) + + t.Run("round robin", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer1", "peer2", "peer3"} + // adding same peer twice should not produce copies + p.add(peers...) + require.Equal(t, 3, p.activeCount) + + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer2"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer3"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + + p.remove("peer2", "peer3") + require.Equal(t, 1, p.activeCount) + + // pointer should skip removed items until found active one + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + }) + + t.Run("wait for peer", func(t *testing.T) { + timeout := time.Second + shortCtx, cancel := context.WithTimeout(context.Background(), timeout/10) + t.Cleanup(cancel) + + longCtx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + p := newPool(time.Second) + done := make(chan struct{}) + + go func() { + select { + case <-p.next(shortCtx): + case <-shortCtx.Done(): + require.Error(t, shortCtx.Err()) + // unlock longCtx waiter by adding new peer + p.add("peer1") + } + }() + + go func() { + defer close(done) + select { + case peerID := <-p.next(longCtx): + require.Equal(t, peer.ID("peer1"), peerID) + case <-longCtx.Done(): + require.NoError(t, longCtx.Err()) + } + }() + + select { + case <-done: + case <-longCtx.Done(): + require.NoError(t, longCtx.Err()) + } + }) + + t.Run("nextIdx got removed", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer2", "peer3"} + p.add(peers...) + p.nextIdx = 2 + p.remove(peers[p.nextIdx]) + + // if previous nextIdx was removed, tryGet should iterate until available peer found + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peers[0], peerID) + }) + + t.Run("cleanup", func(t *testing.T) { + p := newPool(time.Second) + p.cleanupThreshold = 3 + + peers := []peer.ID{"peer1", "peer2", "peer3", "peer4", "peer5"} + p.add(peers...) + require.Equal(t, len(peers), p.activeCount) + + // point to last element that will be removed, to check how pointer will be updated + p.nextIdx = len(peers) - 1 + + // remove some, but not trigger cleanup yet + p.remove(peers[3:]...) 
+ require.Equal(t, len(peers)-2, p.activeCount) + require.Equal(t, len(peers), len(p.statuses)) + + // trigger cleanup + p.remove(peers[2]) + require.Equal(t, len(peers)-3, p.activeCount) + require.Equal(t, len(peers)-3, len(p.statuses)) + + // nextIdx pointer should be updated after next tryGet + p.tryGet() + require.Equal(t, 1, p.nextIdx) + }) + + t.Run("cooldown blocks get", func(t *testing.T) { + ttl := time.Second / 10 + p := newPool(ttl) + + peerID := peer.ID("peer1") + p.add(peerID) + + _, ok := p.tryGet() + require.True(t, ok) + + p.putOnCooldown(peerID) + // item should be unavailable + _, ok = p.tryGet() + require.False(t, ok) + + ctx, cancel := context.WithTimeout(context.Background(), ttl*5) + defer cancel() + select { + case <-p.next(ctx): + case <-ctx.Done(): + t.Fatal("item should be already available") + } + }) + + t.Run("put on cooldown removed item should be noop", func(t *testing.T) { + p := newPool(time.Second) + p.cleanupThreshold = 3 + + peerID := peer.ID("peer1") + p.add(peerID) + + p.remove(peerID) + p.cleanup() + p.putOnCooldown(peerID) + + _, ok := p.tryGet() + require.False(t, ok) + }) +} diff --git a/share/shwap/p2p/shrex/peers/timedqueue.go b/share/shwap/p2p/shrex/peers/timedqueue.go new file mode 100644 index 0000000000..3ed7e29a2c --- /dev/null +++ b/share/shwap/p2p/shrex/peers/timedqueue.go @@ -0,0 +1,91 @@ +package peers + +import ( + "sync" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/peer" +) + +// timedQueue store items for ttl duration and releases it with calling onPop callback. Each item +// is tracked independently +type timedQueue struct { + sync.Mutex + items []item + + // ttl is the amount of time each item exist in the timedQueue + ttl time.Duration + clock clock.Clock + after *clock.Timer + // onPop will be called on item peer.ID after it is released + onPop func(peer.ID) +} + +type item struct { + peer.ID + createdAt time.Time +} + +func newTimedQueue(ttl time.Duration, onPop func(peer.ID)) *timedQueue { + return &timedQueue{ + items: make([]item, 0), + clock: clock.New(), + ttl: ttl, + onPop: onPop, + } +} + +// releaseExpired will release all expired items +func (q *timedQueue) releaseExpired() { + q.Lock() + defer q.Unlock() + q.releaseUnsafe() +} + +func (q *timedQueue) releaseUnsafe() { + if len(q.items) == 0 { + return + } + + var i int + for _, next := range q.items { + timeIn := q.clock.Since(next.createdAt) + if timeIn < q.ttl { + // item is not expired yet, create a timer that will call releaseExpired + q.after.Stop() + q.after = q.clock.AfterFunc(q.ttl-timeIn, q.releaseExpired) + break + } + + // item is expired + q.onPop(next.ID) + i++ + } + + if i > 0 { + copy(q.items, q.items[i:]) + q.items = q.items[:len(q.items)-i] + } +} + +func (q *timedQueue) push(peerID peer.ID) { + q.Lock() + defer q.Unlock() + + q.items = append(q.items, item{ + ID: peerID, + createdAt: q.clock.Now(), + }) + + // if it is the first item in queue, create a timer to call releaseExpired after its expiration + if len(q.items) == 1 { + q.after = q.clock.AfterFunc(q.ttl, q.releaseExpired) + } +} + +func (q *timedQueue) len() int { + q.Lock() + defer q.Unlock() + return len(q.items) +} diff --git a/share/shwap/p2p/shrex/peers/timedqueue_test.go b/share/shwap/p2p/shrex/peers/timedqueue_test.go new file mode 100644 index 0000000000..9cfae0e6b2 --- /dev/null +++ b/share/shwap/p2p/shrex/peers/timedqueue_test.go @@ -0,0 +1,60 @@ +package peers + +import ( + "testing" + "time" + + "github.com/benbjohnson/clock" + 
"github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestTimedQueue(t *testing.T) { + t.Run("push item", func(t *testing.T) { + peers := []peer.ID{"peer1", "peer2"} + ttl := time.Second + + popCh := make(chan struct{}, 1) + queue := newTimedQueue(ttl, func(id peer.ID) { + go func() { + require.Contains(t, peers, id) + popCh <- struct{}{} + }() + }) + mock := clock.NewMock() + queue.clock = mock + + // push first item | global time : 0 + queue.push(peers[0]) + require.Equal(t, queue.len(), 1) + + // push second item with ttl/2 gap | global time : ttl/2 + mock.Add(ttl / 2) + queue.push(peers[1]) + require.Equal(t, queue.len(), 2) + + // advance clock 1 nano sec before first item should expire | global time : ttl - 1 + mock.Add(ttl/2 - 1) + // check that releaseExpired doesn't remove items + queue.releaseExpired() + require.Equal(t, queue.len(), 2) + // first item should be released after its own timeout | global time : ttl + mock.Add(1) + + select { + case <-popCh: + case <-time.After(ttl): + t.Fatal("first item is not released") + } + require.Equal(t, queue.len(), 1) + + // first item should be released after ttl/2 gap timeout | global time : 3/2*ttl + mock.Add(ttl / 2) + select { + case <-popCh: + case <-time.After(ttl): + t.Fatal("second item is not released") + } + require.Equal(t, queue.len(), 0) + }) +} diff --git a/share/shwap/p2p/shrex/recovery.go b/share/shwap/p2p/shrex/recovery.go new file mode 100644 index 0000000000..67bcb98d73 --- /dev/null +++ b/share/shwap/p2p/shrex/recovery.go @@ -0,0 +1,21 @@ +package shrex + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/network" +) + +// RecoveryMiddleware is a middleware that recovers from panics in the handler. +func RecoveryMiddleware(handler network.StreamHandler) network.StreamHandler { + return func(stream network.Stream) { + defer func() { + r := recover() + if r != nil { + err := fmt.Errorf("PANIC while handling request: %s", r) + log.Error(err) + } + }() + handler(stream) + } +} diff --git a/share/shwap/p2p/shrex/shrexeds/client.go b/share/shwap/p2p/shrex/shrexeds/client.go new file mode 100644 index 0000000000..c19bdce480 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/client.go @@ -0,0 +1,219 @@ +package shrexeds + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + + "github.com/celestiaorg/go-libp2p-messenger/serde" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexeds/pb" +) + +// Client is responsible for requesting EDSs for blocksync over the ShrEx/EDS protocol. +type Client struct { + params *Parameters + protocolID protocol.ID + host host.Host + + metrics *shrex.Metrics +} + +// NewClient creates a new ShrEx/EDS client. 
+func NewClient(params *Parameters, host host.Host) (*Client, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-eds: client creation failed: %w", err) + } + + return &Client{ + params: params, + host: host, + protocolID: shrex.ProtocolID(params.NetworkID(), protocolString), + }, nil +} + +// RequestEDS requests the ODS from the given peers and returns the EDS upon success. +func (c *Client) RequestEDS( + ctx context.Context, + root *share.Root, + height uint64, + peer peer.ID, +) (*rsmt2d.ExtendedDataSquare, error) { + eds, err := c.doRequest(ctx, root, height, peer) + if err == nil { + return eds, nil + } + log.Debugw("client: eds request to peer failed", + "height", height, + "peer", peer.String(), + "error", err) + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusTimeout) + return nil, err + } + // some net.Errors also mean the context deadline was exceeded, but yamux/mocknet do not + // unwrap to a ctx err + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + if deadline, _ := ctx.Deadline(); deadline.Before(time.Now()) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusTimeout) + return nil, context.DeadlineExceeded + } + } + if !errors.Is(err, shrex.ErrNotFound) { + log.Warnw("client: eds request to peer failed", + "peer", peer.String(), + "height", height, + "err", err) + } + + return nil, err +} + +func (c *Client) doRequest( + ctx context.Context, + root *share.Root, + height uint64, + to peer.ID, +) (*rsmt2d.ExtendedDataSquare, error) { + streamOpenCtx, cancel := context.WithTimeout(ctx, c.params.ServerReadTimeout) + defer cancel() + stream, err := c.host.NewStream(streamOpenCtx, to, c.protocolID) + if err != nil { + return nil, fmt.Errorf("open stream: %w", err) + } + defer stream.Close() + + c.setStreamDeadlines(ctx, stream) + + req, err := shwap.NewEdsID(height) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + rb, err := req.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + // request ODS + log.Debugw("client: requesting ods", + "height", height, + "peer", to.String()) + _, err = stream.Write(rb) + if err != nil { + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("write request to stream: %w", err) + } + err = stream.CloseWrite() + if err != nil { + log.Debugw("client: error closing write", "err", err) + } + + // read and parse status from peer + resp := new(pb.EDSResponse) + err = stream.SetReadDeadline(time.Now().Add(c.params.ServerReadTimeout)) + if err != nil { + log.Debugw("client: failed to set read deadline for reading status", "err", err) + } + _, err = serde.Read(stream, resp) + if err != nil { + // server closes the stream here if we are rate limited + if errors.Is(err, io.EOF) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusRateLimited) + return nil, shrex.ErrNotFound + } + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("read status from stream: %w", err) + } + + switch resp.Status { + case pb.Status_OK: + // reset stream deadlines to original values, since read deadline was changed during status read + c.setStreamDeadlines(ctx, stream) + // use header and ODS bytes to construct EDS and verify it against dataHash + eds, err := readEds(ctx, stream, root) + if err != nil { + return nil, fmt.Errorf("read eds from stream: %w", err) + } + c.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess) + return eds, nil + case pb.Status_NOT_FOUND: + 
c.metrics.ObserveRequests(ctx, 1, shrex.StatusNotFound) + return nil, shrex.ErrNotFound + case pb.Status_INVALID: + log.Debug("client: invalid request") + fallthrough + case pb.Status_INTERNAL: + fallthrough + default: + c.metrics.ObserveRequests(ctx, 1, shrex.StatusInternalErr) + return nil, shrex.ErrInvalidResponse + } +} + +func readEds(ctx context.Context, stream network.Stream, root *share.Root) (*rsmt2d.ExtendedDataSquare, error) { + odsSize := len(root.RowRoots) / 2 + shares, err := eds.ReadShares(stream, share.Size, odsSize) + if err != nil { + return nil, fmt.Errorf("failed to read eds from ods bytes: %w", err) + } + + // verify that the EDS hash matches the expected hash + rsmt2d, err := eds.Rsmt2DFromShares(shares, odsSize) + if err != nil { + return nil, fmt.Errorf("failed to create rsmt2d from shares: %w", err) + } + datahash, err := rsmt2d.DataHash(ctx) + if err != nil { + return nil, fmt.Errorf("failed to calculate data hash: %w", err) + } + if !bytes.Equal(datahash, root.Hash()) { + return nil, fmt.Errorf( + "content integrity mismatch: imported root %s doesn't match expected root %s", + datahash, + root.Hash(), + ) + } + return rsmt2d.ExtendedDataSquare, nil +} + +func (c *Client) setStreamDeadlines(ctx context.Context, stream network.Stream) { + // set read/write deadline to use context deadline if it exists + if dl, ok := ctx.Deadline(); ok { + err := stream.SetDeadline(dl) + if err == nil { + return + } + log.Debugw("client: setting deadline: %s", "err", err) + } + + // if deadline not set, client read deadline defaults to server write deadline + if c.params.ServerWriteTimeout != 0 { + err := stream.SetReadDeadline(time.Now().Add(c.params.ServerWriteTimeout)) + if err != nil { + log.Debugw("client: setting read deadline", "err", err) + } + } + + // if deadline not set, client write deadline defaults to server read deadline + if c.params.ServerReadTimeout != 0 { + err := stream.SetWriteDeadline(time.Now().Add(c.params.ServerReadTimeout)) + if err != nil { + log.Debugw("client: setting write deadline", "err", err) + } + } +} diff --git a/share/shwap/p2p/shrex/shrexeds/doc.go b/share/shwap/p2p/shrex/shrexeds/doc.go new file mode 100644 index 0000000000..6bad3061f9 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/doc.go @@ -0,0 +1,51 @@ +// This package defines a protocol that is used to request +// extended data squares from peers in the network. +// +// This protocol is a request/response protocol that allows for sending requests for extended data squares by data root +// to the peers in the network and receiving a response containing the original data square(s), which is used +// to recompute the extended data square. +// +// The streams are established using the protocol ID: +// +// - "{networkID}/shrex/eds/v0.0.1" where networkID is the network ID of the network. (e.g. "arabica") +// +// When a peer receives a request for extended data squares, it will read +// the original data square from the EDS store by retrieving the underlying +// CARv1 file containing the full extended data square, but will limit reading +// to the original data square shares only. +// The client on the other hand will take care of computing the extended data squares from +// the original data square on receipt. 
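+//
+// In terms of messages on the wire, the exchange implemented by the client and
+// server in this package is roughly the following (a summary, not a normative spec):
+//
+//  1. The client opens a stream and writes a binary-encoded EdsID identifying the
+//     requested square by height.
+//  2. The server answers with an EDSResponse status (OK, NOT_FOUND, INVALID or INTERNAL).
+//  3. On OK, the server streams the raw ODS shares, from which the client recomputes
+//     the extended data square and verifies it against the expected data root.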
+//
+// # Usage
+//
+// To use a shrexeds client to request extended data squares from a peer, you must
+// first create a new `shrexeds.Client` instance by:
+//
+//	client, err := shrexeds.NewClient(params, host)
+//
+// where `params` is a `shrexeds.Parameters` instance and `host` is a `libp2p.Host` instance.
+//
+// To request an extended data square from a peer, call `Client.RequestEDS`:
+//
+//	eds, err := client.RequestEDS(ctx, root, height, peer)
+//
+// where:
+//   - `ctx` is a `context.Context` instance,
+//   - `root` is the data root of the extended data square,
+//   - `height` is the height of the block the square belongs to and
+//   - `peer` is the peer ID of the peer to request the extended data square from.
+//
+// To use a shrexeds server to respond to requests for extended data squares from peers
+// you must first create a new `shrexeds.Server` instance by:
+//
+//	server, err := shrexeds.NewServer(params, host, store)
+//
+// where `params` is a [Parameters] instance, `host` is a libp2p.Host instance and `store` is a [store.Store] instance.
+//
+// To start the server, you must call `Start` on the server:
+//
+//	err := server.Start(ctx)
+//
+// To stop the server, you must call `Stop` on the server:
+//
+//	err := server.Stop(ctx)
+package shrexeds
diff --git a/share/shwap/p2p/shrex/shrexeds/exchange_test.go b/share/shwap/p2p/shrex/shrexeds/exchange_test.go
new file mode 100644
index 0000000000..d5f62be71f
--- /dev/null
+++ b/share/shwap/p2p/shrex/shrexeds/exchange_test.go
@@ -0,0 +1,159 @@
+package shrexeds
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	libhost "github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/network"
+	mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/celestiaorg/celestia-node/share"
+	"github.com/celestiaorg/celestia-node/share/eds/edstest"
+	"github.com/celestiaorg/celestia-node/share/shwap"
+	"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
+	"github.com/celestiaorg/celestia-node/store"
+)
+
+func TestExchange_RequestEDS(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+	store, client, server := makeExchange(t)
+
+	err := server.Start(ctx)
+	require.NoError(t, err)
+
+	// Testcase: EDS is immediately available
+	t.Run("EDS_Available", func(t *testing.T) {
+		eds := edstest.RandEDS(t, 4)
+		dah, err := share.NewRoot(eds)
+		require.NoError(t, err)
+		height := uint64(1)
+		f, err := store.Put(ctx, dah.Hash(), height, eds)
+		require.NoError(t, err)
+		require.NoError(t, f.Close())
+
+		requestedEDS, err := client.RequestEDS(ctx, dah, height, server.host.ID())
+		assert.NoError(t, err)
+		assert.Equal(t, eds.Flattened(), requestedEDS.Flattened())
+	})
+
+	// Testcase: EDS is unavailable initially, but is found after multiple requests
+	t.Run("EDS_AvailableAfterDelay", func(t *testing.T) {
+		eds := edstest.RandEDS(t, 4)
+		dah, err := share.NewRoot(eds)
+		require.NoError(t, err)
+		height := uint64(666)
+
+		lock := make(chan struct{})
+		go func() {
+			<-lock
+			f, err := store.Put(ctx, dah.Hash(), height, eds)
+			require.NoError(t, err)
+			require.NoError(t, f.Close())
+			lock <- struct{}{}
+		}()
+
+		requestedEDS, err := client.RequestEDS(ctx, dah, height, server.host.ID())
+		assert.ErrorIs(t, err, shrex.ErrNotFound)
+		assert.Nil(t, requestedEDS)
+
+		// unlock write
+		lock <- struct{}{}
+		// wait for write to finish
+		<-lock
+
+		requestedEDS, err = client.RequestEDS(ctx, dah, height, server.host.ID())
+		assert.NoError(t,
err) + assert.Equal(t, eds.Flattened(), requestedEDS.Flattened()) + }) + + // Testcase: Invalid request excludes peer from round-robin, stopping request + t.Run("EDS_InvalidRequest", func(t *testing.T) { + emptyRoot := share.EmptyRoot() + height := uint64(0) + requestedEDS, err := client.RequestEDS(ctx, emptyRoot, height, server.host.ID()) + assert.ErrorIs(t, err, shwap.ErrInvalidShwapID) + assert.Nil(t, requestedEDS) + }) + + t.Run("EDS_err_not_found", func(t *testing.T) { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + t.Cleanup(cancel) + eds := edstest.RandEDS(t, 4) + dah, err := share.NewRoot(eds) + require.NoError(t, err) + height := uint64(668) + _, err = client.RequestEDS(timeoutCtx, dah, height, server.host.ID()) + require.ErrorIs(t, err, shrex.ErrNotFound) + }) + + // Testcase: Concurrency limit reached + t.Run("EDS_concurrency_limit", func(t *testing.T) { + _, client, server := makeExchange(t) + + require.NoError(t, server.Start(ctx)) + + ctx, cancel := context.WithTimeout(ctx, time.Second) + t.Cleanup(cancel) + + rateLimit := 2 + wg := sync.WaitGroup{} + wg.Add(rateLimit) + + // mockHandler will block requests on server side until test is over + lock := make(chan struct{}) + defer close(lock) + mockHandler := func(network.Stream) { + wg.Done() + select { + case <-lock: + case <-ctx.Done(): + t.Fatal("timeout") + } + } + middleware := shrex.NewMiddleware(rateLimit) + server.host.SetStreamHandler(server.protocolID, + middleware.RateLimitHandler(mockHandler)) + + // take server concurrency slots with blocked requests + emptyRoot := share.EmptyRoot() + for i := 0; i < rateLimit; i++ { + go func(i int) { + client.RequestEDS(ctx, emptyRoot, 1, server.host.ID()) //nolint:errcheck + }(i) + } + + // wait until all server slots are taken + wg.Wait() + _, err = client.RequestEDS(ctx, emptyRoot, 1, server.host.ID()) + require.ErrorIs(t, err, shrex.ErrNotFound) + }) +} + +func createMocknet(t *testing.T, amount int) []libhost.Host { + t.Helper() + + net, err := mocknet.FullMeshConnected(amount) + require.NoError(t, err) + // get host and peer + return net.Hosts() +} + +func makeExchange(t *testing.T) (*store.Store, *Client, *Server) { + t.Helper() + store, err := store.NewStore(store.DefaultParameters(), t.TempDir()) + require.NoError(t, err) + hosts := createMocknet(t, 2) + + client, err := NewClient(DefaultParameters(), hosts[0]) + require.NoError(t, err) + server, err := NewServer(DefaultParameters(), hosts[1], store) + require.NoError(t, err) + + return store, client, server +} diff --git a/share/shwap/p2p/shrex/shrexeds/params.go b/share/shwap/p2p/shrex/shrexeds/params.go new file mode 100644 index 0000000000..4c3667033d --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/params.go @@ -0,0 +1,54 @@ +package shrexeds + +import ( + "fmt" + + logging "github.com/ipfs/go-log/v2" + + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" +) + +const protocolString = "/shrex/eds/v0.0.2" + +var log = logging.Logger("shrex/eds") + +// Parameters is the set of parameters that must be configured for the shrex/eds protocol. +type Parameters struct { + *shrex.Parameters + + // BufferSize defines the size of the buffer used for writing an ODS over the stream. 
+ BufferSize uint64 +} + +func DefaultParameters() *Parameters { + return &Parameters{ + Parameters: shrex.DefaultParameters(), + BufferSize: 32 * 1024, + } +} + +func (p *Parameters) Validate() error { + if p.BufferSize <= 0 { + return fmt.Errorf("invalid buffer size: %v, value should be positive and non-zero", p.BufferSize) + } + + return p.Parameters.Validate() +} + +func (c *Client) WithMetrics() error { + metrics, err := shrex.InitClientMetrics("eds") + if err != nil { + return fmt.Errorf("shrex/eds: init Metrics: %w", err) + } + c.metrics = metrics + return nil +} + +func (s *Server) WithMetrics() error { + metrics, err := shrex.InitServerMetrics("eds") + if err != nil { + return fmt.Errorf("shrex/eds: init Metrics: %w", err) + } + s.metrics = metrics + return nil +} diff --git a/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.pb.go b/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.pb.go new file mode 100644 index 0000000000..60d26abdc2 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.pb.go @@ -0,0 +1,338 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto + +package pb + +import ( + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type Status int32 + +const ( + Status_INVALID Status = 0 + Status_OK Status = 1 + Status_NOT_FOUND Status = 2 + Status_INTERNAL Status = 3 +) + +var Status_name = map[int32]string{ + 0: "INVALID", + 1: "OK", + 2: "NOT_FOUND", + 3: "INTERNAL", +} + +var Status_value = map[string]int32{ + "INVALID": 0, + "OK": 1, + "NOT_FOUND": 2, + "INTERNAL": 3, +} + +func (x Status) String() string { + return proto.EnumName(Status_name, int32(x)) +} + +func (Status) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_5176f06f10cac3fd, []int{0} +} + +type EDSResponse struct { + Status Status `protobuf:"varint,1,opt,name=status,proto3,enum=Status" json:"status,omitempty"` +} + +func (m *EDSResponse) Reset() { *m = EDSResponse{} } +func (m *EDSResponse) String() string { return proto.CompactTextString(m) } +func (*EDSResponse) ProtoMessage() {} +func (*EDSResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_5176f06f10cac3fd, []int{0} +} +func (m *EDSResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EDSResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_EDSResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *EDSResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_EDSResponse.Merge(m, src) +} +func (m *EDSResponse) XXX_Size() int { + return m.Size() +} +func (m *EDSResponse) XXX_DiscardUnknown() { + xxx_messageInfo_EDSResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_EDSResponse proto.InternalMessageInfo + +func (m *EDSResponse) GetStatus() Status { + if m != nil { + return m.Status + } + return 
Status_INVALID +} + +func init() { + proto.RegisterEnum("Status", Status_name, Status_value) + proto.RegisterType((*EDSResponse)(nil), "EDSResponse") +} + +func init() { + proto.RegisterFile("share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto", fileDescriptor_5176f06f10cac3fd) +} + +var fileDescriptor_5176f06f10cac3fd = []byte{ + // 244 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xb2, 0x29, 0xce, 0x48, 0x2c, + 0x4a, 0xd5, 0x2f, 0xce, 0x28, 0x4f, 0x2c, 0xd0, 0x2f, 0x30, 0x2a, 0xd0, 0x2f, 0xce, 0x28, 0x4a, + 0xad, 0x80, 0x90, 0xa9, 0x29, 0xc5, 0xfa, 0x05, 0x49, 0xfa, 0xa9, 0x15, 0x25, 0xa9, 0x79, 0x29, + 0xa9, 0x29, 0xf1, 0x29, 0x89, 0x25, 0x89, 0xf1, 0xc5, 0x85, 0xa5, 0x89, 0x45, 0xa9, 0x7a, 0x05, + 0x45, 0xf9, 0x25, 0xf9, 0x4a, 0x7a, 0x5c, 0xdc, 0xae, 0x2e, 0xc1, 0x41, 0xa9, 0xc5, 0x05, 0xf9, + 0x79, 0xc5, 0xa9, 0x42, 0xf2, 0x5c, 0x6c, 0xc5, 0x25, 0x89, 0x25, 0xa5, 0xc5, 0x12, 0x8c, 0x0a, + 0x8c, 0x1a, 0x7c, 0x46, 0xec, 0x7a, 0xc1, 0x60, 0x6e, 0x10, 0x54, 0x58, 0xcb, 0x8a, 0x8b, 0x0d, + 0x22, 0x22, 0xc4, 0xcd, 0xc5, 0xee, 0xe9, 0x17, 0xe6, 0xe8, 0xe3, 0xe9, 0x22, 0xc0, 0x20, 0xc4, + 0xc6, 0xc5, 0xe4, 0xef, 0x2d, 0xc0, 0x28, 0xc4, 0xcb, 0xc5, 0xe9, 0xe7, 0x1f, 0x12, 0xef, 0xe6, + 0x1f, 0xea, 0xe7, 0x22, 0xc0, 0x24, 0xc4, 0xc3, 0xc5, 0xe1, 0xe9, 0x17, 0xe2, 0x1a, 0xe4, 0xe7, + 0xe8, 0x23, 0xc0, 0xec, 0x94, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, 0x1e, + 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c, 0x51, + 0x6e, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, 0xfa, 0xc9, 0xa9, 0x39, 0xa9, + 0xc5, 0x25, 0x99, 0x89, 0xf9, 0x45, 0xe9, 0x70, 0xb6, 0x6e, 0x5e, 0x7e, 0x0a, 0xc8, 0x8b, 0x04, + 0x3c, 0x9a, 0xc4, 0x06, 0xf6, 0x94, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0x91, 0x62, 0x7c, 0xa1, + 0x14, 0x01, 0x00, 0x00, +} + +func (m *EDSResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EDSResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EDSResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Status != 0 { + i = encodeVarintExtendedDataSquare(dAtA, i, uint64(m.Status)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintExtendedDataSquare(dAtA []byte, offset int, v uint64) int { + offset -= sovExtendedDataSquare(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *EDSResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Status != 0 { + n += 1 + sovExtendedDataSquare(uint64(m.Status)) + } + return n +} + +func sovExtendedDataSquare(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozExtendedDataSquare(x uint64) (n int) { + return sovExtendedDataSquare(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *EDSResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + 
fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EDSResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EDSResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Status", wireType) + } + m.Status = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Status |= Status(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipExtendedDataSquare(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthExtendedDataSquare + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipExtendedDataSquare(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthExtendedDataSquare + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupExtendedDataSquare + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthExtendedDataSquare + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthExtendedDataSquare = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowExtendedDataSquare = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupExtendedDataSquare = fmt.Errorf("proto: unexpected end of group") +) diff --git a/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto b/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto new file mode 100644 index 0000000000..d8d2cf6a52 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/pb/extended_data_square.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +option go_package = "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexeds/pb"; + +enum Status { + INVALID = 0; + OK = 1; // data found + NOT_FOUND = 2; // data not found + INTERNAL = 3; // internal server error +} + +message EDSResponse { + Status status = 1; +} diff --git a/share/shwap/p2p/shrex/shrexeds/server.go b/share/shwap/p2p/shrex/shrexeds/server.go new file mode 100644 index 0000000000..dbcfcadc4a --- /dev/null +++ b/share/shwap/p2p/shrex/shrexeds/server.go @@ 
-0,0 +1,194 @@ +package shrexeds + +import ( + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/protocol" + "go.uber.org/zap" + + "github.com/celestiaorg/go-libp2p-messenger/serde" + + "github.com/celestiaorg/celestia-node/libs/utils" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" + p2p_pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexeds/pb" + "github.com/celestiaorg/celestia-node/store" +) + +// Server is responsible for serving ODSs for blocksync over the ShrEx/EDS protocol. +type Server struct { + ctx context.Context + cancel context.CancelFunc + + host host.Host + protocolID protocol.ID + + store *store.Store + + params *Parameters + middleware *shrex.Middleware + metrics *shrex.Metrics +} + +// NewServer creates a new ShrEx/EDS server. +func NewServer(params *Parameters, host host.Host, store *store.Store) (*Server, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-eds: server creation failed: %w", err) + } + + return &Server{ + host: host, + store: store, + protocolID: shrex.ProtocolID(params.NetworkID(), protocolString), + params: params, + middleware: shrex.NewMiddleware(params.ConcurrencyLimit), + }, nil +} + +func (s *Server) Start(context.Context) error { + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.host.SetStreamHandler(s.protocolID, s.middleware.RateLimitHandler(s.handleStream)) + return nil +} + +func (s *Server) Stop(context.Context) error { + defer s.cancel() + s.host.RemoveStreamHandler(s.protocolID) + return nil +} + +func (s *Server) observeRateLimitedRequests() { + numRateLimited := s.middleware.DrainCounter() + if numRateLimited > 0 { + s.metrics.ObserveRequests(context.Background(), numRateLimited, shrex.StatusRateLimited) + } +} + +func (s *Server) handleStream(stream network.Stream) { + logger := log.With("peer", stream.Conn().RemotePeer().String()) + logger.Debug("server: handling eds request") + + s.observeRateLimitedRequests() + + // read request from stream to get the dataHash for store lookup + id, err := s.readRequest(logger, stream) + if err != nil { + logger.Warnw("server: reading request from stream", "err", err) + stream.Reset() //nolint:errcheck + return + } + + logger = logger.With("height", id.Height) + + ctx, cancel := context.WithTimeout(s.ctx, s.params.HandleRequestTimeout) + defer cancel() + + // determine whether the EDS is available in our store + // we do not close the reader, so that other requests will not need to re-open the file. + // closing is handled by the LRU cache. 
+	file, err := s.store.GetByHeight(ctx, id.Height)
+	var status p2p_pb.Status
+	switch {
+	case err == nil:
+		defer utils.CloseAndLog(logger, "file", file)
+		status = p2p_pb.Status_OK
+	case errors.Is(err, store.ErrNotFound):
+		logger.Warnw("server: request height not found")
+		s.metrics.ObserveRequests(ctx, 1, shrex.StatusNotFound)
+		status = p2p_pb.Status_NOT_FOUND
+	case err != nil:
+		logger.Errorw("server: get file", "err", err)
+		status = p2p_pb.Status_INTERNAL
+	}
+
+	// inform the client of our status
+	err = s.writeStatus(logger, status, stream)
+	if err != nil {
+		logger.Warnw("server: writing status to stream", "err", err)
+		stream.Reset() //nolint:errcheck
+		return
+	}
+	// if we cannot serve the EDS, we are already done
+	if status != p2p_pb.Status_OK {
+		err = stream.Close()
+		if err != nil {
+			logger.Debugw("server: closing stream", "err", err)
+		}
+		return
+	}
+
+	// start streaming the ODS to the client
+	err = s.writeODS(logger, file, stream)
+	if err != nil {
+		logger.Warnw("server: writing ods to stream", "err", err)
+		stream.Reset() //nolint:errcheck
+		return
+	}
+
+	s.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess)
+	err = stream.Close()
+	if err != nil {
+		logger.Debugw("server: closing stream", "err", err)
+	}
+}
+
+func (s *Server) readRequest(logger *zap.SugaredLogger, stream network.Stream) (shwap.EdsID, error) {
+	err := stream.SetReadDeadline(time.Now().Add(s.params.ServerReadTimeout))
+	if err != nil {
+		logger.Debugw("server: set read deadline", "err", err)
+	}
+
+	req := make([]byte, shwap.EdsIDSize)
+	_, err = io.ReadFull(stream, req)
+	if err != nil {
+		return shwap.EdsID{}, fmt.Errorf("reading request: %w", err)
+	}
+	id, err := shwap.EdsIDFromBinary(req)
+	if err != nil {
+		return shwap.EdsID{}, fmt.Errorf("parsing request: %w", err)
+	}
+	err = stream.CloseRead()
+	if err != nil {
+		logger.Debugw("server: closing read", "err", err)
+	}
+
+	return id, id.Validate()
+}
+
+func (s *Server) writeStatus(logger *zap.SugaredLogger, status p2p_pb.Status, stream network.Stream) error {
+	err := stream.SetWriteDeadline(time.Now().Add(s.params.ServerWriteTimeout))
+	if err != nil {
+		logger.Debugw("server: set write deadline", "err", err)
+	}
+
+	resp := &p2p_pb.EDSResponse{Status: status}
+	_, err = serde.Write(stream, resp)
+	return err
+}
+
+func (s *Server) writeODS(logger *zap.SugaredLogger, streamer eds.Streamer, stream network.Stream) error {
+	reader, err := streamer.Reader()
+	if err != nil {
+		return fmt.Errorf("getting ODS reader: %w", err)
+	}
+	err = stream.SetWriteDeadline(time.Now().Add(s.params.ServerWriteTimeout))
+	if err != nil {
+		logger.Debugw("server: set write deadline", "err", err)
+	}
+
+	buf := make([]byte, s.params.BufferSize)
+	n, err := io.CopyBuffer(stream, reader, buf)
+	if err != nil {
+		return fmt.Errorf("written: %v, writing ODS bytes: %w", n, err)
+	}
+
+	logger.Debugw("server: wrote ODS", "bytes", n)
+	return nil
+}
diff --git a/share/shwap/p2p/shrex/shrexnd/client.go b/share/shwap/p2p/shrex/shrexnd/client.go
new file mode 100644
index 0000000000..5c8889fd5c
--- /dev/null
+++ b/share/shwap/p2p/shrex/shrexnd/client.go
@@ -0,0 +1,233 @@
+package shrexnd
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"time"
+
+	"github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"github.com/libp2p/go-libp2p/core/protocol"
+
+	"github.com/celestiaorg/go-libp2p-messenger/serde"
+	"github.com/celestiaorg/nmt"
+
+	"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap" + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexnd/pb" +) + +// Client implements client side of shrex/nd protocol to obtain namespaced shares data from remote +// peers. +type Client struct { + params *Parameters + protocolID protocol.ID + + host host.Host + metrics *shrex.Metrics +} + +// NewClient creates a new shrEx/nd client +func NewClient(params *Parameters, host host.Host) (*Client, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-nd: client creation failed: %w", err) + } + + return &Client{ + host: host, + protocolID: shrex.ProtocolID(params.NetworkID(), protocolString), + params: params, + }, nil +} + +// RequestND requests namespaced data from the given peer. +// Returns NamespacedData with unverified inclusion proofs against the share.Root. +func (c *Client) RequestND( + ctx context.Context, + root *share.Root, + height uint64, + fromRow, toRow int, + namespace share.Namespace, + peer peer.ID, +) (shwap.NamespacedData, error) { + if err := namespace.ValidateForData(); err != nil { + return nil, err + } + + shares, err := c.doRequest(ctx, height, root, fromRow, toRow, namespace, peer) + if err == nil { + return shares, nil + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusTimeout) + return nil, err + } + // some net.Errors also mean the context deadline was exceeded, but yamux/mocknet do not + // unwrap to a ctx err + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + if deadline, _ := ctx.Deadline(); deadline.Before(time.Now()) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusTimeout) + return nil, context.DeadlineExceeded + } + } + if !errors.Is(err, shrex.ErrNotFound) && errors.Is(err, shrex.ErrRateLimited) { + log.Warnw("client-nd: peer returned err", "err", err) + } + return nil, err +} + +func (c *Client) doRequest( + ctx context.Context, + height uint64, + root *share.Root, + fromRow, toRow int, + namespace share.Namespace, + peerID peer.ID, +) (shwap.NamespacedData, error) { + stream, err := c.host.NewStream(ctx, peerID, c.protocolID) + if err != nil { + return nil, err + } + defer stream.Close() + + c.setStreamDeadlines(ctx, stream) + + req, err := shwap.NewNamespaceDataID(height, fromRow, toRow, namespace, len(root.RowRoots)) + if err != nil { + return nil, fmt.Errorf("client-nd: creating request: %w", err) + } + + br, err := req.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("client-nd: marshaling request: %w", err) + } + + _, err = stream.Write(br) + if err != nil { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusSendReqErr) + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("client-nd: writing request: %w", err) + } + + err = stream.CloseWrite() + if err != nil { + log.Debugw("client-nd: closing write side of the stream", "err", err) + } + + if err := c.readStatus(ctx, stream); err != nil { + return nil, err + } + return c.readNamespacedShares(ctx, stream) +} + +func (c *Client) readStatus(ctx context.Context, stream network.Stream) error { + var resp pb.GetSharesByNamespaceStatusResponse + _, err := serde.Read(stream, &resp) + if err != nil { + // server is overloaded and closed the stream + if errors.Is(err, io.EOF) { + c.metrics.ObserveRequests(ctx, 1, shrex.StatusRateLimited) + return shrex.ErrRateLimited + } + c.metrics.ObserveRequests(ctx, 1, shrex.StatusReadRespErr) + 
stream.Reset() //nolint:errcheck + return fmt.Errorf("client-nd: reading status response: %w", err) + } + + return c.convertStatusToErr(ctx, resp.Status) +} + +// readNamespacedShares converts proto Rows to share.NamespacedData +func (c *Client) readNamespacedShares( + ctx context.Context, + stream network.Stream, +) (shwap.NamespacedData, error) { + var shares shwap.NamespacedData + for { + var row pb.NamespaceRowResponse + _, err := serde.Read(stream, &row) + if err != nil { + if errors.Is(err, io.EOF) { + // all data is received and steam is closed by server + return shares, nil + } + c.metrics.ObserveRequests(ctx, 1, shrex.StatusReadRespErr) + return nil, err + } + var proof nmt.Proof + if row.Proof != nil { + if len(row.Shares) != 0 { + proof = nmt.NewInclusionProof( + int(row.Proof.Start), + int(row.Proof.End), + row.Proof.Nodes, + row.Proof.IsMaxNamespaceIgnored, + ) + } else { + proof = nmt.NewAbsenceProof( + int(row.Proof.Start), + int(row.Proof.End), + row.Proof.Nodes, + row.Proof.LeafHash, + row.Proof.IsMaxNamespaceIgnored, + ) + } + } + shares = append(shares, shwap.RowNamespaceData{ + Shares: row.Shares, + Proof: &proof, + }) + } +} + +func (c *Client) setStreamDeadlines(ctx context.Context, stream network.Stream) { + // set read/write deadline to use context deadline if it exists + deadline, ok := ctx.Deadline() + if ok { + err := stream.SetDeadline(deadline) + if err == nil { + return + } + log.Debugw("client-nd: set stream deadline", "err", err) + } + + // if deadline not set, client read deadline defaults to server write deadline + if c.params.ServerWriteTimeout != 0 { + err := stream.SetReadDeadline(time.Now().Add(c.params.ServerWriteTimeout)) + if err != nil { + log.Debugw("client-nd: set read deadline", "err", err) + } + } + + // if deadline not set, client write deadline defaults to server read deadline + if c.params.ServerReadTimeout != 0 { + err := stream.SetWriteDeadline(time.Now().Add(c.params.ServerReadTimeout)) + if err != nil { + log.Debugw("client-nd: set write deadline", "err", err) + } + } +} + +func (c *Client) convertStatusToErr(ctx context.Context, status pb.StatusCode) error { + switch status { + case pb.StatusCode_OK: + c.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess) + return nil + case pb.StatusCode_NOT_FOUND: + c.metrics.ObserveRequests(ctx, 1, shrex.StatusNotFound) + return shrex.ErrNotFound + case pb.StatusCode_INVALID: + log.Warn("client-nd: invalid request") + fallthrough + case pb.StatusCode_INTERNAL: + fallthrough + default: + return shrex.ErrInvalidResponse + } +} diff --git a/share/shwap/p2p/shrex/shrexnd/doc.go b/share/shwap/p2p/shrex/shrexnd/doc.go new file mode 100644 index 0000000000..74ba7397e8 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexnd/doc.go @@ -0,0 +1,43 @@ +// This package defines a protocol that is used to request namespaced data from peers in the network. +// +// This protocol is a request/response protocol that sends a request for specific data that +// lives in a specific namespace ID and receives a response with the data. +// +// The streams are established using the protocol ID: +// +// - "{networkID}/shrex/nd/0.0.1" where networkID is the network ID of the network. (e.g. "arabica") +// +// The protocol uses protobuf to serialize and deserialize messages. +// +// # Usage +// +// To use a shrexnd client to request data from a peer, you must first create a new `shrexnd.Client` instance by: +// +// 1. 
Create a new client using `NewClient` and pass in the parameters of the protocol and the host:
+//
+//	client, err := shrexnd.NewClient(params, host)
+//
+// 2. Request data from a peer by calling [Client.RequestND] on the client and
+// pass in the context, the data root, the height, the row range, the namespace and the peer ID:
+//
+//	data, err := client.RequestND(ctx, dataRoot, height, fromRow, toRow, namespace, peerID)
+//
+// where data is of type [shwap.NamespacedData]
+//
+// To use a shrexnd server to respond to requests from peers, you must first create a new `shrexnd.Server` instance by:
+//
+// 1. Create a new server using `NewServer` and pass in the parameters of
+// the protocol, the host and the store:
+//
+//	server, err := shrexnd.NewServer(params, host, store)
+//
+// where store is of type [store.Store]
+//
+// 2. Start the server by calling `Start` on the server:
+//
+//	err := server.Start(ctx)
+//
+// 3. Stop the server by calling `Stop` on the server:
+//
+//	err := server.Stop(ctx)
+package shrexnd
diff --git a/share/shwap/p2p/shrex/shrexnd/exchange_test.go b/share/shwap/p2p/shrex/shrexnd/exchange_test.go
new file mode 100644
index 0000000000..6c23f41c9c
--- /dev/null
+++ b/share/shwap/p2p/shrex/shrexnd/exchange_test.go
@@ -0,0 +1,125 @@
+package shrexnd
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	libhost "github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/network"
+	mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
+	"github.com/stretchr/testify/require"
+
+	"github.com/celestiaorg/celestia-node/share"
+	"github.com/celestiaorg/celestia-node/share/eds/edstest"
+	"github.com/celestiaorg/celestia-node/share/sharetest"
+	"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
+	"github.com/celestiaorg/celestia-node/store"
+)
+
+func TestExchange_RequestND_NotFound(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	t.Cleanup(cancel)
+	edsStore, client, server := makeExchange(t)
+	require.NoError(t, server.Start(ctx))
+
+	t.Run("CAR_not_exist", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(ctx, time.Second)
+		t.Cleanup(cancel)
+
+		root := share.EmptyRoot()
+		namespace := sharetest.RandV0Namespace()
+		_, err := client.RequestND(ctx, root, 1, 1, 1, namespace, server.host.ID())
+		require.ErrorIs(t, err, shrex.ErrNotFound)
+	})
+
+	t.Run("ErrNamespaceNotFound", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(ctx, time.Second)
+		t.Cleanup(cancel)
+
+		eds := edstest.RandEDS(t, 4)
+		dah, err := share.NewRoot(eds)
+		require.NoError(t, err)
+		height := uint64(1)
+		f, err := edsStore.Put(ctx, dah.Hash(), height, eds)
+		require.NoError(t, err)
+		require.NoError(t, f.Close())
+
+		namespace := sharetest.RandV0Namespace()
+		emptyShares, err := client.RequestND(ctx, dah, 1, 1, 1, namespace, server.host.ID())
+		require.NoError(t, err)
+		require.Empty(t, emptyShares.Flatten())
+	})
+}
+
+func TestExchange_RequestND(t *testing.T) {
+	t.Run("ND_concurrency_limit", func(t *testing.T) {
+		net, err := mocknet.FullMeshConnected(2)
+		require.NoError(t, err)
+
+		client, err := NewClient(DefaultParameters(), net.Hosts()[0])
+		require.NoError(t, err)
+		server, err := NewServer(DefaultParameters(), net.Hosts()[1], nil)
+		require.NoError(t, err)
+
+		require.NoError(t, server.Start(context.Background()))
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		t.Cleanup(cancel)
+
+		rateLimit := 2
+		wg := sync.WaitGroup{}
+
wg.Add(rateLimit) + + // mockHandler will block requests on server side until test is over + lock := make(chan struct{}) + defer close(lock) + mockHandler := func(network.Stream) { + wg.Done() + select { + case <-lock: + case <-ctx.Done(): + t.Fatal("timeout") + } + } + middleware := shrex.NewMiddleware(rateLimit) + server.host.SetStreamHandler(server.protocolID, + middleware.RateLimitHandler(mockHandler)) + + // take server concurrency slots with blocked requests + for i := 0; i < rateLimit; i++ { + go func(i int) { + client.RequestND(ctx, share.EmptyRoot(), 1, 1, 1, sharetest.RandV0Namespace(), server.host.ID()) //nolint:errcheck + }(i) + } + + // wait until all server slots are taken + wg.Wait() + _, err = client.RequestND(ctx, share.EmptyRoot(), 1, 1, 1, sharetest.RandV0Namespace(), server.host.ID()) + require.ErrorIs(t, err, shrex.ErrRateLimited) + }) +} + +func createMocknet(t *testing.T, amount int) []libhost.Host { + t.Helper() + + net, err := mocknet.FullMeshConnected(amount) + require.NoError(t, err) + // get host and peer + return net.Hosts() +} + +func makeExchange(t *testing.T) (*store.Store, *Client, *Server) { + t.Helper() + s, err := store.NewStore(store.DefaultParameters(), t.TempDir()) + require.NoError(t, err) + hosts := createMocknet(t, 2) + + client, err := NewClient(DefaultParameters(), hosts[0]) + require.NoError(t, err) + server, err := NewServer(DefaultParameters(), hosts[1], s) + require.NoError(t, err) + + return s, client, server +} diff --git a/share/shwap/p2p/shrex/shrexnd/params.go b/share/shwap/p2p/shrex/shrexnd/params.go new file mode 100644 index 0000000000..2e1acd6010 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexnd/params.go @@ -0,0 +1,38 @@ +package shrexnd + +import ( + "fmt" + + logging "github.com/ipfs/go-log/v2" + + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" +) + +const protocolString = "/shrex/nd/v0.0.4" + +var log = logging.Logger("shrex/nd") + +// Parameters is the set of parameters that must be configured for the shrex/eds protocol. +type Parameters = shrex.Parameters + +func DefaultParameters() *Parameters { + return shrex.DefaultParameters() +} + +func (c *Client) WithMetrics() error { + metrics, err := shrex.InitClientMetrics("nd") + if err != nil { + return fmt.Errorf("shrex/nd: init Metrics: %w", err) + } + c.metrics = metrics + return nil +} + +func (srv *Server) WithMetrics() error { + metrics, err := shrex.InitServerMetrics("nd") + if err != nil { + return fmt.Errorf("shrex/nd: init Metrics: %w", err) + } + srv.metrics = metrics + return nil +} diff --git a/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.pb.go b/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.pb.go new file mode 100644 index 0000000000..071046a85a --- /dev/null +++ b/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.pb.go @@ -0,0 +1,576 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto + +package share_p2p_shrex_nd + +import ( + fmt "fmt" + pb "github.com/celestiaorg/nmt/pb" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type StatusCode int32 + +const ( + StatusCode_INVALID StatusCode = 0 + StatusCode_OK StatusCode = 1 + StatusCode_NOT_FOUND StatusCode = 2 + StatusCode_INTERNAL StatusCode = 3 +) + +var StatusCode_name = map[int32]string{ + 0: "INVALID", + 1: "OK", + 2: "NOT_FOUND", + 3: "INTERNAL", +} + +var StatusCode_value = map[string]int32{ + "INVALID": 0, + "OK": 1, + "NOT_FOUND": 2, + "INTERNAL": 3, +} + +func (x StatusCode) String() string { + return proto.EnumName(StatusCode_name, int32(x)) +} + +func (StatusCode) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_e8097b1aa3ae2e25, []int{0} +} + +type GetSharesByNamespaceStatusResponse struct { + Status StatusCode `protobuf:"varint,1,opt,name=status,proto3,enum=share.p2p.shrex.nd.StatusCode" json:"status,omitempty"` +} + +func (m *GetSharesByNamespaceStatusResponse) Reset() { *m = GetSharesByNamespaceStatusResponse{} } +func (m *GetSharesByNamespaceStatusResponse) String() string { return proto.CompactTextString(m) } +func (*GetSharesByNamespaceStatusResponse) ProtoMessage() {} +func (*GetSharesByNamespaceStatusResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_e8097b1aa3ae2e25, []int{0} +} +func (m *GetSharesByNamespaceStatusResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GetSharesByNamespaceStatusResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_GetSharesByNamespaceStatusResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *GetSharesByNamespaceStatusResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetSharesByNamespaceStatusResponse.Merge(m, src) +} +func (m *GetSharesByNamespaceStatusResponse) XXX_Size() int { + return m.Size() +} +func (m *GetSharesByNamespaceStatusResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetSharesByNamespaceStatusResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetSharesByNamespaceStatusResponse proto.InternalMessageInfo + +func (m *GetSharesByNamespaceStatusResponse) GetStatus() StatusCode { + if m != nil { + return m.Status + } + return StatusCode_INVALID +} + +type NamespaceRowResponse struct { + Shares [][]byte `protobuf:"bytes,1,rep,name=shares,proto3" json:"shares,omitempty"` + Proof *pb.Proof `protobuf:"bytes,2,opt,name=proof,proto3" json:"proof,omitempty"` +} + +func (m *NamespaceRowResponse) Reset() { *m = NamespaceRowResponse{} } +func (m *NamespaceRowResponse) String() string { return proto.CompactTextString(m) } +func (*NamespaceRowResponse) ProtoMessage() {} +func (*NamespaceRowResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_e8097b1aa3ae2e25, []int{1} +} +func (m *NamespaceRowResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *NamespaceRowResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_NamespaceRowResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *NamespaceRowResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_NamespaceRowResponse.Merge(m, src) +} +func (m *NamespaceRowResponse) XXX_Size() int { + return m.Size() +} +func (m *NamespaceRowResponse) XXX_DiscardUnknown() { + 
xxx_messageInfo_NamespaceRowResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_NamespaceRowResponse proto.InternalMessageInfo + +func (m *NamespaceRowResponse) GetShares() [][]byte { + if m != nil { + return m.Shares + } + return nil +} + +func (m *NamespaceRowResponse) GetProof() *pb.Proof { + if m != nil { + return m.Proof + } + return nil +} + +func init() { + proto.RegisterEnum("share.p2p.shrex.nd.StatusCode", StatusCode_name, StatusCode_value) + proto.RegisterType((*GetSharesByNamespaceStatusResponse)(nil), "share.p2p.shrex.nd.GetSharesByNamespaceStatusResponse") + proto.RegisterType((*NamespaceRowResponse)(nil), "share.p2p.shrex.nd.NamespaceRowResponse") +} + +func init() { + proto.RegisterFile("share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto", fileDescriptor_e8097b1aa3ae2e25) +} + +var fileDescriptor_e8097b1aa3ae2e25 = []byte{ + // 301 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xb2, 0x2c, 0xce, 0x48, 0x2c, + 0x4a, 0xd5, 0x2f, 0xce, 0x28, 0x4f, 0x2c, 0xd0, 0x2f, 0x30, 0x2a, 0xd0, 0x2f, 0xce, 0x28, 0x4a, + 0xad, 0x80, 0x90, 0x79, 0x29, 0xfa, 0x05, 0x49, 0xfa, 0x45, 0xf9, 0xe5, 0xf1, 0x79, 0x89, 0xb9, + 0xa9, 0xc5, 0x05, 0x89, 0xc9, 0xa9, 0xf1, 0x29, 0x89, 0x25, 0x89, 0x7a, 0x05, 0x45, 0xf9, 0x25, + 0xf9, 0x42, 0x42, 0x60, 0xad, 0x7a, 0x05, 0x46, 0x05, 0x7a, 0x60, 0xe5, 0x7a, 0x79, 0x29, 0x52, + 0x7c, 0x05, 0x49, 0xfa, 0x05, 0x45, 0xf9, 0xf9, 0x69, 0x10, 0x35, 0x4a, 0x31, 0x5c, 0x4a, 0xee, + 0xa9, 0x25, 0xc1, 0x20, 0x85, 0xc5, 0x4e, 0x95, 0x7e, 0x30, 0x63, 0x82, 0x4b, 0x12, 0x4b, 0x4a, + 0x8b, 0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0xcc, 0xb8, 0xd8, 0x8a, 0xc1, 0x22, + 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x7c, 0x46, 0x72, 0x7a, 0x98, 0x46, 0xeb, 0x41, 0xf4, 0x38, 0xe7, + 0xa7, 0xa4, 0x06, 0x41, 0x55, 0x2b, 0x85, 0x72, 0x89, 0xc0, 0x8d, 0x0c, 0xca, 0x2f, 0x87, 0x9b, + 0x27, 0xc6, 0xc5, 0x06, 0x36, 0x00, 0x64, 0x1e, 0xb3, 0x06, 0x4f, 0x10, 0x94, 0x27, 0xa4, 0xca, + 0xc5, 0x0a, 0x76, 0x9c, 0x04, 0x93, 0x02, 0xa3, 0x06, 0xb7, 0x11, 0xbf, 0x1e, 0xd4, 0xa9, 0x49, + 0x7a, 0x01, 0x20, 0x46, 0x10, 0x44, 0x56, 0xcb, 0x8e, 0x8b, 0x0b, 0x61, 0x99, 0x10, 0x37, 0x17, + 0xbb, 0xa7, 0x5f, 0x98, 0xa3, 0x8f, 0xa7, 0x8b, 0x00, 0x83, 0x10, 0x1b, 0x17, 0x93, 0xbf, 0xb7, + 0x00, 0xa3, 0x10, 0x2f, 0x17, 0xa7, 0x9f, 0x7f, 0x48, 0xbc, 0x9b, 0x7f, 0xa8, 0x9f, 0x8b, 0x00, + 0x93, 0x10, 0x0f, 0x17, 0x87, 0xa7, 0x5f, 0x88, 0x6b, 0x90, 0x9f, 0xa3, 0x8f, 0x00, 0xb3, 0x93, + 0xc4, 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x38, 0xe1, 0xb1, + 0x1c, 0xc3, 0x85, 0xc7, 0x72, 0x0c, 0x37, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x43, 0xc5, 0x18, + 0x10, 0x00, 0x00, 0xff, 0xff, 0x82, 0x3a, 0xfe, 0x72, 0x76, 0x01, 0x00, 0x00, +} + +func (m *GetSharesByNamespaceStatusResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GetSharesByNamespaceStatusResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GetSharesByNamespaceStatusResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Status != 0 { + i = encodeVarintRowNamespaceData(dAtA, i, uint64(m.Status)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *NamespaceRowResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + 
n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *NamespaceRowResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *NamespaceRowResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Proof != nil { + { + size, err := m.Proof.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRowNamespaceData(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + if len(m.Shares) > 0 { + for iNdEx := len(m.Shares) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Shares[iNdEx]) + copy(dAtA[i:], m.Shares[iNdEx]) + i = encodeVarintRowNamespaceData(dAtA, i, uint64(len(m.Shares[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func encodeVarintRowNamespaceData(dAtA []byte, offset int, v uint64) int { + offset -= sovRowNamespaceData(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *GetSharesByNamespaceStatusResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Status != 0 { + n += 1 + sovRowNamespaceData(uint64(m.Status)) + } + return n +} + +func (m *NamespaceRowResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Shares) > 0 { + for _, b := range m.Shares { + l = len(b) + n += 1 + l + sovRowNamespaceData(uint64(l)) + } + } + if m.Proof != nil { + l = m.Proof.Size() + n += 1 + l + sovRowNamespaceData(uint64(l)) + } + return n +} + +func sovRowNamespaceData(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozRowNamespaceData(x uint64) (n int) { + return sovRowNamespaceData(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *GetSharesByNamespaceStatusResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GetSharesByNamespaceStatusResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GetSharesByNamespaceStatusResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Status", wireType) + } + m.Status = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Status |= StatusCode(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipRowNamespaceData(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthRowNamespaceData + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *NamespaceRowResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift 
+= 7 { + if shift >= 64 { + return ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: NamespaceRowResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: NamespaceRowResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Shares", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRowNamespaceData + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRowNamespaceData + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Shares = append(m.Shares, make([]byte, postIndex-iNdEx)) + copy(m.Shares[len(m.Shares)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Proof", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRowNamespaceData + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRowNamespaceData + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Proof == nil { + m.Proof = &pb.Proof{} + } + if err := m.Proof.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRowNamespaceData(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthRowNamespaceData + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipRowNamespaceData(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowRowNamespaceData + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthRowNamespaceData + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupRowNamespaceData + } + depth-- + case 5: 
+ iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthRowNamespaceData + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthRowNamespaceData = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowRowNamespaceData = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupRowNamespaceData = fmt.Errorf("proto: unexpected end of group") +) diff --git a/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto b/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto new file mode 100644 index 0000000000..4f484a8c2b --- /dev/null +++ b/share/shwap/p2p/shrex/shrexnd/pb/row_namespace_data.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package share.p2p.shrex.nd; +import "pb/proof.proto"; + +message GetSharesByNamespaceStatusResponse{ + StatusCode status = 1; +} + +enum StatusCode { + INVALID = 0; + OK = 1; + NOT_FOUND = 2; + INTERNAL = 3; +}; + +message NamespaceRowResponse { + repeated bytes shares = 1; + proof.pb.Proof proof = 2; +} diff --git a/share/shwap/p2p/shrex/shrexnd/server.go b/share/shwap/p2p/shrex/shrexnd/server.go new file mode 100644 index 0000000000..f177daf52c --- /dev/null +++ b/share/shwap/p2p/shrex/shrexnd/server.go @@ -0,0 +1,252 @@ +package shrexnd + +import ( + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/protocol" + "go.uber.org/zap" + + "github.com/celestiaorg/go-libp2p-messenger/serde" + nmt_pb "github.com/celestiaorg/nmt/pb" + + "github.com/celestiaorg/celestia-node/libs/utils" + "github.com/celestiaorg/celestia-node/share/shwap" + "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexnd/pb" + "github.com/celestiaorg/celestia-node/store" +) + +// Server implements server side of shrex/nd protocol to serve namespaced share to remote +// peers. 
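+//
+// A minimal wiring sketch (illustrative only; it assumes a running libp2p host
+// and an initialized *store.Store referred to here as edsStore):
+//
+//	srv, err := NewServer(DefaultParameters(), host, edsStore)
+//	if err != nil {
+//		return err
+//	}
+//	if err := srv.Start(ctx); err != nil {
+//		return err
+//	}
+//	defer srv.Stop(ctx)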
+type Server struct { + cancel context.CancelFunc + + host host.Host + protocolID protocol.ID + + handler network.StreamHandler + store *store.Store + + params *Parameters + middleware *shrex.Middleware + metrics *shrex.Metrics +} + +// NewServer creates new Server +func NewServer(params *Parameters, host host.Host, store *store.Store) (*Server, error) { + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-nd: server creation failed: %w", err) + } + + srv := &Server{ + store: store, + host: host, + params: params, + protocolID: shrex.ProtocolID(params.NetworkID(), protocolString), + middleware: shrex.NewMiddleware(params.ConcurrencyLimit), + } + + ctx, cancel := context.WithCancel(context.Background()) + srv.cancel = cancel + + srv.handler = srv.middleware.RateLimitHandler(srv.streamHandler(ctx)) + return srv, nil +} + +// Start starts the server +func (srv *Server) Start(context.Context) error { + srv.host.SetStreamHandler(srv.protocolID, srv.handler) + return nil +} + +// Stop stops the server +func (srv *Server) Stop(context.Context) error { + srv.cancel() + srv.host.RemoveStreamHandler(srv.protocolID) + return nil +} + +func (srv *Server) streamHandler(ctx context.Context) network.StreamHandler { + return func(s network.Stream) { + err := srv.handleNamespacedData(ctx, s) + if err != nil { + s.Reset() //nolint:errcheck + return + } + if err = s.Close(); err != nil { + log.Debugw("server: closing stream", "err", err) + } + } +} + +// SetHandler sets server handler +func (srv *Server) SetHandler(handler network.StreamHandler) { + srv.handler = handler +} + +func (srv *Server) observeRateLimitedRequests() { + numRateLimited := srv.middleware.DrainCounter() + if numRateLimited > 0 { + srv.metrics.ObserveRequests(context.Background(), numRateLimited, shrex.StatusRateLimited) + } +} + +func (srv *Server) handleNamespacedData(ctx context.Context, stream network.Stream) error { + logger := log.With("source", "server", "peer", stream.Conn().RemotePeer().String()) + logger.Debug("handling nd request") + + srv.observeRateLimitedRequests() + ndid, err := srv.readRequest(logger, stream) + if err != nil { + logger.Warnw("read request", "err", err) + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusBadRequest) + return err + } + + logger = logger.With( + "namespace", ndid.DataNamespace.String(), + "height", ndid.Height, + ) + + ctx, cancel := context.WithTimeout(ctx, srv.params.HandleRequestTimeout) + defer cancel() + + shares, status, err := srv.getNamespaceData(ctx, ndid) + if err != nil { + // server should respond with status regardless if there was an error getting data + sendErr := srv.respondStatus(ctx, logger, stream, status) + if sendErr != nil { + logger.Errorw("sending response", "err", sendErr) + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusSendRespErr) + } + logger.Errorw("handling request", "err", err) + return errors.Join(err, sendErr) + } + + err = srv.respondStatus(ctx, logger, stream, status) + if err != nil { + logger.Errorw("sending response", "err", err) + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusSendRespErr) + return err + } + + err = srv.sendNamespacedShares(shares, stream) + if err != nil { + logger.Errorw("send nd data", "err", err) + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusSendRespErr) + return err + } + return nil +} + +func (srv *Server) readRequest( + logger *zap.SugaredLogger, + stream network.Stream, +) (shwap.NamespaceDataID, error) { + err := stream.SetReadDeadline(time.Now().Add(srv.params.ServerReadTimeout)) + if err != nil { + 
logger.Debugw("setting read deadline", "err", err) + } + + req := make([]byte, shwap.NamespaceDataIDSize) + _, err = io.ReadFull(stream, req) + if err != nil { + return shwap.NamespaceDataID{}, fmt.Errorf("reading request: %w", err) + } + id, err := shwap.NamespaceDataIDFromBinary(req) + if err != nil { + return shwap.NamespaceDataID{}, fmt.Errorf("decoding request: %w", err) + } + + logger.Debugw("new request") + err = stream.CloseRead() + if err != nil { + logger.Debugw("closing read side of the stream", "err", err) + } + + return id, nil +} + +func (srv *Server) getNamespaceData( + ctx context.Context, + id shwap.NamespaceDataID, +) (shwap.NamespacedData, pb.StatusCode, error) { + file, err := srv.store.GetByHeight(ctx, id.Height) + if err != nil { + if errors.Is(err, store.ErrNotFound) { + return nil, pb.StatusCode_NOT_FOUND, nil + } + return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving DAH: %w", err) + } + defer utils.CloseAndLog(log, "file", file) + + namespacedRows := make(shwap.NamespacedData, 0, id.ToRowIndex-id.FromRowIndex+1) + for rowIdx := id.FromRowIndex; rowIdx < id.ToRowIndex; rowIdx++ { + data, err := file.RowNamespaceData(ctx, id.DataNamespace, rowIdx) + if err != nil { + return nil, pb.StatusCode_INTERNAL, fmt.Errorf("retrieving data: %w", err) + } + namespacedRows = append(namespacedRows, data) + } + + return namespacedRows, pb.StatusCode_OK, nil +} + +func (srv *Server) respondStatus( + ctx context.Context, + logger *zap.SugaredLogger, + stream network.Stream, + status pb.StatusCode, +) error { + srv.observeStatus(ctx, status) + + err := stream.SetWriteDeadline(time.Now().Add(srv.params.ServerWriteTimeout)) + if err != nil { + logger.Debugw("setting write deadline", "err", err) + } + + _, err = serde.Write(stream, &pb.GetSharesByNamespaceStatusResponse{Status: status}) + if err != nil { + return fmt.Errorf("writing response: %w", err) + } + + return nil +} + +// sendNamespacedShares encodes shares into proto messages and sends it to client +func (srv *Server) sendNamespacedShares(data shwap.NamespacedData, stream network.Stream) error { + for _, row := range data { + row := &pb.NamespaceRowResponse{ + Shares: row.Shares, + Proof: &nmt_pb.Proof{ + Start: int64(row.Proof.Start()), + End: int64(row.Proof.End()), + Nodes: row.Proof.Nodes(), + LeafHash: row.Proof.LeafHash(), + IsMaxNamespaceIgnored: row.Proof.IsMaxNamespaceIDIgnored(), + }, + } + _, err := serde.Write(stream, row) + if err != nil { + return fmt.Errorf("writing nd data to stream: %w", err) + } + } + return nil +} + +func (srv *Server) observeStatus(ctx context.Context, status pb.StatusCode) { + switch { + case status == pb.StatusCode_OK: + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusSuccess) + case status == pb.StatusCode_NOT_FOUND: + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusNotFound) + case status == pb.StatusCode_INTERNAL: + srv.metrics.ObserveRequests(ctx, 1, shrex.StatusInternalErr) + } +} diff --git a/share/shwap/p2p/shrex/shrexsub/doc.go b/share/shwap/p2p/shrex/shrexsub/doc.go new file mode 100644 index 0000000000..95d08361a2 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/doc.go @@ -0,0 +1,58 @@ +// This package defines a protocol that is used to broadcast shares to peers over a pubsub network. +// +// This protocol runs on a rudimentary floodsub network is primarily a pubsub protocol +// that broadcasts and listens for shares over a pubsub topic. 
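+// Only a lightweight notification (the height and data hash of a recently produced EDS) is
+// propagated over the topic; peers are expected to fetch the actual share data through other
+// protocols, such as shrex-eds or shrex-nd.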
+// +// The pubsub topic used by this protocol is: +// +// "{networkID}/eds-sub/v0.1.0" +// +// where networkID is the network ID of the celestia-node that is running the protocol. (e.g. "arabica") +// +// # Usage +// +// To use this protocol, you must first create a new `shrexsub.PubSub` instance by: +// +// pubsub, err := shrexsub.NewPubSub(ctx, host, networkID) +// +// where host is the libp2p host that is running the protocol, and networkID is the network ID of the celestia-node +// that is running the protocol. +// +// After this, you can start the pubsub protocol by: +// +// err := pubsub.Start(ctx) +// +// Once you have started the `shrexsub.PubSub` instance, you can broadcast a share by: +// +// err := pubsub.Broadcast(ctx, notification) +// +// where `notification` is of type [shrexsub.Notification]. +// +// and `DataHash` is the hash of the share that you want to broadcast, and `Height` is the height of the share. +// +// You can also subscribe to the pubsub topic by: +// +// sub, err := pubsub.Subscribe(ctx) +// +// and then receive notifications by: +// +// for { +// select { +// case <-ctx.Done(): +// sub.Cancel() +// return +// case notification, err := <-sub.Next(): +// // handle notification or err +// } +// } +// +// You can also manipulate the received pubsub messages by using the [PubSub.AddValidator] method: +// +// pubsub.AddValidator(validator ValidatorFn) +// +// where `validator` is of type [shrexsub.ValidatorFn] and `Notification` is the same as above. +// +// You can also stop the pubsub protocol by: +// +// err := pubsub.Stop(ctx) +package shrexsub diff --git a/share/shwap/p2p/shrex/shrexsub/pb/notification.pb.go b/share/shwap/p2p/shrex/shrexsub/pb/notification.pb.go new file mode 100644 index 0000000000..c7cddbba5c --- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/pb/notification.pb.go @@ -0,0 +1,355 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: share/shwap/p2p/shrex/shrexsub/pb/notification.proto + +package share_p2p_shrex_sub + +import ( + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type RecentEDSNotification struct { + Height uint64 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` + DataHash []byte `protobuf:"bytes,2,opt,name=data_hash,json=dataHash,proto3" json:"data_hash,omitempty"` +} + +func (m *RecentEDSNotification) Reset() { *m = RecentEDSNotification{} } +func (m *RecentEDSNotification) String() string { return proto.CompactTextString(m) } +func (*RecentEDSNotification) ProtoMessage() {} +func (*RecentEDSNotification) Descriptor() ([]byte, []int) { + return fileDescriptor_c16b670e7e556100, []int{0} +} +func (m *RecentEDSNotification) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *RecentEDSNotification) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_RecentEDSNotification.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *RecentEDSNotification) XXX_Merge(src proto.Message) { + xxx_messageInfo_RecentEDSNotification.Merge(m, src) +} +func (m *RecentEDSNotification) XXX_Size() int { + return m.Size() +} +func (m *RecentEDSNotification) XXX_DiscardUnknown() { + xxx_messageInfo_RecentEDSNotification.DiscardUnknown(m) +} + +var xxx_messageInfo_RecentEDSNotification proto.InternalMessageInfo + +func (m *RecentEDSNotification) GetHeight() uint64 { + if m != nil { + return m.Height + } + return 0 +} + +func (m *RecentEDSNotification) GetDataHash() []byte { + if m != nil { + return m.DataHash + } + return nil +} + +func init() { + proto.RegisterType((*RecentEDSNotification)(nil), "share.p2p.shrex.sub.RecentEDSNotification") +} + +func init() { + proto.RegisterFile("share/shwap/p2p/shrex/shrexsub/pb/notification.proto", fileDescriptor_c16b670e7e556100) +} + +var fileDescriptor_c16b670e7e556100 = []byte{ + // 183 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x32, 0x29, 0xce, 0x48, 0x2c, + 0x4a, 0xd5, 0x2f, 0xce, 0x28, 0x4f, 0x2c, 0xd0, 0x2f, 0x30, 0x2a, 0xd0, 0x2f, 0xce, 0x28, 0x4a, + 0xad, 0x80, 0x90, 0xc5, 0xa5, 0x49, 0xfa, 0x05, 0x49, 0xfa, 0x79, 0xf9, 0x25, 0x99, 0x69, 0x99, + 0xc9, 0x89, 0x25, 0x99, 0xf9, 0x79, 0x7a, 0x05, 0x45, 0xf9, 0x25, 0xf9, 0x42, 0xc2, 0x60, 0x5d, + 0x7a, 0x05, 0x46, 0x05, 0x7a, 0x60, 0x95, 0x7a, 0xc5, 0xa5, 0x49, 0x4a, 0x3e, 0x5c, 0xa2, 0x41, + 0xa9, 0xc9, 0xa9, 0x79, 0x25, 0xae, 0x2e, 0xc1, 0x7e, 0x48, 0x7a, 0x84, 0xc4, 0xb8, 0xd8, 0x32, + 0x52, 0x33, 0xd3, 0x33, 0x4a, 0x24, 0x18, 0x15, 0x18, 0x35, 0x58, 0x82, 0xa0, 0x3c, 0x21, 0x69, + 0x2e, 0xce, 0x94, 0xc4, 0x92, 0xc4, 0xf8, 0x8c, 0xc4, 0xe2, 0x0c, 0x09, 0x26, 0x05, 0x46, 0x0d, + 0x9e, 0x20, 0x0e, 0x90, 0x80, 0x47, 0x62, 0x71, 0x86, 0x93, 0xc4, 0x89, 0x47, 0x72, 0x8c, 0x17, + 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x38, 0xe1, 0xb1, 0x1c, 0xc3, 0x85, 0xc7, 0x72, 0x0c, + 0x37, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0xdd, 0x60, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0xc0, + 0x55, 0x8a, 0x06, 0xbb, 0x00, 0x00, 0x00, +} + +func (m *RecentEDSNotification) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RecentEDSNotification) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *RecentEDSNotification) 
MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.DataHash) > 0 { + i -= len(m.DataHash) + copy(dAtA[i:], m.DataHash) + i = encodeVarintNotification(dAtA, i, uint64(len(m.DataHash))) + i-- + dAtA[i] = 0x12 + } + if m.Height != 0 { + i = encodeVarintNotification(dAtA, i, uint64(m.Height)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintNotification(dAtA []byte, offset int, v uint64) int { + offset -= sovNotification(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *RecentEDSNotification) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Height != 0 { + n += 1 + sovNotification(uint64(m.Height)) + } + l = len(m.DataHash) + if l > 0 { + n += 1 + l + sovNotification(uint64(l)) + } + return n +} + +func sovNotification(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozNotification(x uint64) (n int) { + return sovNotification(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *RecentEDSNotification) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNotification + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RecentEDSNotification: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RecentEDSNotification: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Height", wireType) + } + m.Height = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNotification + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Height |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field DataHash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNotification + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthNotification + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthNotification + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.DataHash = append(m.DataHash[:0], dAtA[iNdEx:postIndex]...) 
+ if m.DataHash == nil { + m.DataHash = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipNotification(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNotification + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipNotification(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowNotification + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowNotification + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowNotification + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthNotification + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupNotification + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthNotification + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthNotification = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowNotification = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupNotification = fmt.Errorf("proto: unexpected end of group") +) diff --git a/share/shwap/p2p/shrex/shrexsub/pb/notification.proto b/share/shwap/p2p/shrex/shrexsub/pb/notification.proto new file mode 100644 index 0000000000..d96cf3369e --- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/pb/notification.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package share.p2p.shrex.sub; + +message RecentEDSNotification { + uint64 height = 1; + bytes data_hash = 2; +} + diff --git a/share/shwap/p2p/shrex/shrexsub/pubsub.go b/share/shwap/p2p/shrex/shrexsub/pubsub.go new file mode 100644 index 0000000000..88f7efc15b --- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/pubsub.go @@ -0,0 +1,146 @@ +package shrexsub + +import ( + "context" + "fmt" + + logging "github.com/ipfs/go-log/v2" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + + "github.com/celestiaorg/celestia-node/share" + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexsub/pb" +) + +var log = logging.Logger("shrex-sub") + +// pubsubTopic hardcodes the name of the EDS floodsub topic with the provided networkID. +func pubsubTopicID(networkID string) string { + return fmt.Sprintf("%s/eds-sub/v0.2.0", networkID) +} + +// ValidatorFn is an injectable func and governs EDS notification msg validity. +// It receives the notification and sender peer and expects the validation result. 
+// ValidatorFn is allowed to be blocking for an indefinite time or until the context is canceled. +type ValidatorFn func(context.Context, peer.ID, Notification) pubsub.ValidationResult + +// BroadcastFn aliases the function that broadcasts the DataHash. +type BroadcastFn func(context.Context, Notification) error + +// Notification is the format of message sent by Broadcaster +type Notification struct { + DataHash share.DataHash + Height uint64 +} + +// PubSub manages receiving and propagating the EDS from/to the network +// over "eds-sub" subscription. +type PubSub struct { + pubSub *pubsub.PubSub + topic *pubsub.Topic + + pubsubTopic string + cancelRelay pubsub.RelayCancelFunc +} + +// NewPubSub creates a libp2p.PubSub wrapper. +func NewPubSub(ctx context.Context, h host.Host, networkID string) (*PubSub, error) { + pubsub, err := pubsub.NewFloodSub(ctx, h) + if err != nil { + return nil, err + } + return &PubSub{ + pubSub: pubsub, + pubsubTopic: pubsubTopicID(networkID), + }, nil +} + +// Start creates an instances of FloodSub and joins specified topic. +func (s *PubSub) Start(context.Context) error { + topic, err := s.pubSub.Join(s.pubsubTopic) + if err != nil { + return err + } + + cancel, err := topic.Relay() + if err != nil { + return err + } + + s.cancelRelay = cancel + s.topic = topic + return nil +} + +// Stop completely stops the PubSub: +// * Unregisters all the added Validators +// * Closes the `ShrEx/Sub` topic +func (s *PubSub) Stop(context.Context) error { + s.cancelRelay() + err := s.pubSub.UnregisterTopicValidator(s.pubsubTopic) + if err != nil { + log.Warnw("unregistering topic", "err", err) + } + return s.topic.Close() +} + +// AddValidator registers given ValidatorFn for EDS notifications. +// Any amount of Validators can be registered. +func (s *PubSub) AddValidator(v ValidatorFn) error { + return s.pubSub.RegisterTopicValidator(s.pubsubTopic, v.validate) +} + +func (v ValidatorFn) validate(ctx context.Context, p peer.ID, msg *pubsub.Message) (res pubsub.ValidationResult) { + defer func() { + r := recover() + if r != nil { + err := fmt.Errorf("PANIC while processing shrexsub msg: %s", r) + log.Error(err) + res = pubsub.ValidationReject + } + }() + + var pbmsg pb.RecentEDSNotification + if err := pbmsg.Unmarshal(msg.Data); err != nil { + log.Debugw("validator: unmarshal error", "err", err) + return pubsub.ValidationReject + } + + n := Notification{ + DataHash: pbmsg.DataHash, + Height: pbmsg.Height, + } + if n.Height == 0 || n.DataHash.IsEmptyRoot() || n.DataHash.Validate() != nil { + // hard reject malicious height (height 0 does not exist) and + // empty/invalid datahashes + return pubsub.ValidationReject + } + return v(ctx, p, n) +} + +// Subscribe provides a new Subscription for EDS notifications. +func (s *PubSub) Subscribe() (*Subscription, error) { + if s.topic == nil { + return nil, fmt.Errorf("shrex-sub: topic is not started") + } + return newSubscription(s.topic) +} + +// Broadcast sends the EDS notification (DataHash) to every connected peer. 
+func (s *PubSub) Broadcast(ctx context.Context, notification Notification) error { + if notification.DataHash.IsEmptyRoot() { + // no need to broadcast datahash of an empty block EDS + return nil + } + + msg := pb.RecentEDSNotification{ + Height: notification.Height, + DataHash: notification.DataHash, + } + data, err := msg.Marshal() + if err != nil { + return fmt.Errorf("shrex-sub: marshal notification, %w", err) + } + return s.topic.Publish(ctx, data) +} diff --git a/share/shwap/p2p/shrex/shrexsub/pubsub_test.go b/share/shwap/p2p/shrex/shrexsub/pubsub_test.go new file mode 100644 index 0000000000..78c2141852 --- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/pubsub_test.go @@ -0,0 +1,123 @@ +package shrexsub + +import ( + "context" + "testing" + "time" + + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/peer" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/rand" + + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexsub/pb" +) + +func TestPubSub(t *testing.T) { + h, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + pSub1, err := NewPubSub(ctx, h.Hosts()[0], "test") + require.NoError(t, err) + + pSub2, err := NewPubSub(ctx, h.Hosts()[1], "test") + require.NoError(t, err) + err = pSub2.AddValidator( + func(ctx context.Context, p peer.ID, n Notification) pubsub.ValidationResult { + // only testing shrexsub validation here + return pubsub.ValidationAccept + }, + ) + require.NoError(t, err) + + require.NoError(t, pSub1.Start(ctx)) + require.NoError(t, pSub2.Start(ctx)) + + subs, err := pSub2.Subscribe() + require.NoError(t, err) + + tests := []struct { + name string + notif Notification + errExpected bool + }{ + { + name: "valid height, valid hash", + notif: Notification{ + Height: 1, + DataHash: rand.Bytes(32), + }, + errExpected: false, + }, + { + name: "valid height, invalid hash (<32 bytes)", + notif: Notification{ + Height: 2, + DataHash: rand.Bytes(20), + }, + errExpected: true, + }, + { + name: "valid height, invalid hash (>32 bytes)", + notif: Notification{ + Height: 2, + DataHash: rand.Bytes(64), + }, + errExpected: true, + }, + { + name: "invalid height, valid hash", + notif: Notification{ + Height: 0, + DataHash: rand.Bytes(32), + }, + errExpected: true, + }, + { + name: "invalid height, nil hash", + notif: Notification{ + Height: 0, + DataHash: nil, + }, + errExpected: true, + }, + { + name: "valid height, nil hash", + notif: Notification{ + Height: 30, + DataHash: nil, + }, + errExpected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := pb.RecentEDSNotification{ + Height: tt.notif.Height, + DataHash: tt.notif.DataHash, + } + data, err := msg.Marshal() + require.NoError(t, err) + + err = pSub1.topic.Publish(ctx, data, pubsub.WithReadiness(pubsub.MinTopicSize(1))) + require.NoError(t, err) + + reqCtx, reqCtxCancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer reqCtxCancel() + + got, err := subs.Next(reqCtx) + if tt.errExpected { + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + return + } + require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, tt.notif, got) + }) + } +} diff --git a/share/shwap/p2p/shrex/shrexsub/subscription.go b/share/shwap/p2p/shrex/shrexsub/subscription.go new file mode 100644 index 0000000000..5021f090c2 
--- /dev/null +++ b/share/shwap/p2p/shrex/shrexsub/subscription.go @@ -0,0 +1,51 @@ +package shrexsub + +import ( + "context" + "fmt" + + pubsub "github.com/libp2p/go-libp2p-pubsub" + + pb "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex/shrexsub/pb" +) + +// Subscription is a wrapper over pubsub.Subscription that handles +// receiving an EDS DataHash from other peers. +type Subscription struct { + subscription *pubsub.Subscription +} + +func newSubscription(t *pubsub.Topic) (*Subscription, error) { + subs, err := t.Subscribe() + if err != nil { + return nil, err + } + + return &Subscription{subscription: subs}, nil +} + +// Next blocks the caller until any new EDS DataHash notification arrives. +// Returns only notifications which successfully pass validation. +func (subs *Subscription) Next(ctx context.Context) (Notification, error) { + msg, err := subs.subscription.Next(ctx) + if err != nil { + log.Errorw("listening for the next eds hash", "err", err) + return Notification{}, err + } + + log.Debugw("received message", "topic", msg.Message.GetTopic(), "sender", msg.ReceivedFrom) + var pbmsg pb.RecentEDSNotification + if err := pbmsg.Unmarshal(msg.Data); err != nil { + log.Debugw("unmarshal error", "err", err) + return Notification{}, fmt.Errorf("shrex-sub: unmarshal notification, %w", err) + } + return Notification{ + DataHash: pbmsg.DataHash, + Height: pbmsg.Height, + }, nil +} + +// Cancel stops the subscription. +func (subs *Subscription) Cancel() { + subs.subscription.Cancel() +} diff --git a/share/shwap/pb/shwap.pb.go b/share/shwap/pb/shwap.pb.go index 000bf78ca7..a731a18511 100644 --- a/share/shwap/pb/shwap.pb.go +++ b/share/shwap/pb/shwap.pb.go @@ -70,7 +70,207 @@ func (x Row_HalfSide) String() string { } func (Row_HalfSide) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_9431653f3c9f0bcb, []int{0, 0} + return fileDescriptor_9431653f3c9f0bcb, []int{4, 0} +} + +type EDSID struct { + Height uint64 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` +} + +func (m *EDSID) Reset() { *m = EDSID{} } +func (m *EDSID) String() string { return proto.CompactTextString(m) } +func (*EDSID) ProtoMessage() {} +func (*EDSID) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{0} +} +func (m *EDSID) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EDSID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_EDSID.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *EDSID) XXX_Merge(src proto.Message) { + xxx_messageInfo_EDSID.Merge(m, src) +} +func (m *EDSID) XXX_Size() int { + return m.Size() +} +func (m *EDSID) XXX_DiscardUnknown() { + xxx_messageInfo_EDSID.DiscardUnknown(m) +} + +var xxx_messageInfo_EDSID proto.InternalMessageInfo + +func (m *EDSID) GetHeight() uint64 { + if m != nil { + return m.Height + } + return 0 +} + +type RowID struct { + EdsId *EDSID `protobuf:"bytes,1,opt,name=eds_id,json=edsId,proto3" json:"eds_id,omitempty"` + RowIndex uint64 `protobuf:"varint,2,opt,name=row_index,json=rowIndex,proto3" json:"row_index,omitempty"` +} + +func (m *RowID) Reset() { *m = RowID{} } +func (m *RowID) String() string { return proto.CompactTextString(m) } +func (*RowID) ProtoMessage() {} +func (*RowID) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{1} +} +func (m *RowID) XXX_Unmarshal(b []byte) error { + 
return m.Unmarshal(b) +} +func (m *RowID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_RowID.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *RowID) XXX_Merge(src proto.Message) { + xxx_messageInfo_RowID.Merge(m, src) +} +func (m *RowID) XXX_Size() int { + return m.Size() +} +func (m *RowID) XXX_DiscardUnknown() { + xxx_messageInfo_RowID.DiscardUnknown(m) +} + +var xxx_messageInfo_RowID proto.InternalMessageInfo + +func (m *RowID) GetEdsId() *EDSID { + if m != nil { + return m.EdsId + } + return nil +} + +func (m *RowID) GetRowIndex() uint64 { + if m != nil { + return m.RowIndex + } + return 0 +} + +type SampleID struct { + RowId *RowID `protobuf:"bytes,1,opt,name=row_id,json=rowId,proto3" json:"row_id,omitempty"` + ColIndex uint64 `protobuf:"varint,2,opt,name=col_index,json=colIndex,proto3" json:"col_index,omitempty"` +} + +func (m *SampleID) Reset() { *m = SampleID{} } +func (m *SampleID) String() string { return proto.CompactTextString(m) } +func (*SampleID) ProtoMessage() {} +func (*SampleID) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{2} +} +func (m *SampleID) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *SampleID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_SampleID.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *SampleID) XXX_Merge(src proto.Message) { + xxx_messageInfo_SampleID.Merge(m, src) +} +func (m *SampleID) XXX_Size() int { + return m.Size() +} +func (m *SampleID) XXX_DiscardUnknown() { + xxx_messageInfo_SampleID.DiscardUnknown(m) +} + +var xxx_messageInfo_SampleID proto.InternalMessageInfo + +func (m *SampleID) GetRowId() *RowID { + if m != nil { + return m.RowId + } + return nil +} + +func (m *SampleID) GetColIndex() uint64 { + if m != nil { + return m.ColIndex + } + return 0 +} + +type RowNamespaceDataID struct { + RowId *RowID `protobuf:"bytes,1,opt,name=row_id,json=rowId,proto3" json:"row_id,omitempty"` + Namespace []byte `protobuf:"bytes,2,opt,name=namespace,proto3" json:"namespace,omitempty"` +} + +func (m *RowNamespaceDataID) Reset() { *m = RowNamespaceDataID{} } +func (m *RowNamespaceDataID) String() string { return proto.CompactTextString(m) } +func (*RowNamespaceDataID) ProtoMessage() {} +func (*RowNamespaceDataID) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{3} +} +func (m *RowNamespaceDataID) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *RowNamespaceDataID) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_RowNamespaceDataID.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *RowNamespaceDataID) XXX_Merge(src proto.Message) { + xxx_messageInfo_RowNamespaceDataID.Merge(m, src) +} +func (m *RowNamespaceDataID) XXX_Size() int { + return m.Size() +} +func (m *RowNamespaceDataID) XXX_DiscardUnknown() { + xxx_messageInfo_RowNamespaceDataID.DiscardUnknown(m) +} + +var xxx_messageInfo_RowNamespaceDataID proto.InternalMessageInfo + +func (m *RowNamespaceDataID) GetRowId() *RowID { + if m != nil { + return m.RowId + } + 
return nil +} + +func (m *RowNamespaceDataID) GetNamespace() []byte { + if m != nil { + return m.Namespace + } + return nil } type Row struct { @@ -82,7 +282,7 @@ func (m *Row) Reset() { *m = Row{} } func (m *Row) String() string { return proto.CompactTextString(m) } func (*Row) ProtoMessage() {} func (*Row) Descriptor() ([]byte, []int) { - return fileDescriptor_9431653f3c9f0bcb, []int{0} + return fileDescriptor_9431653f3c9f0bcb, []int{4} } func (m *Row) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -135,7 +335,7 @@ func (m *Sample) Reset() { *m = Sample{} } func (m *Sample) String() string { return proto.CompactTextString(m) } func (*Sample) ProtoMessage() {} func (*Sample) Descriptor() ([]byte, []int) { - return fileDescriptor_9431653f3c9f0bcb, []int{1} + return fileDescriptor_9431653f3c9f0bcb, []int{5} } func (m *Sample) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -194,7 +394,7 @@ func (m *RowNamespaceData) Reset() { *m = RowNamespaceData{} } func (m *RowNamespaceData) String() string { return proto.CompactTextString(m) } func (*RowNamespaceData) ProtoMessage() {} func (*RowNamespaceData) Descriptor() ([]byte, []int) { - return fileDescriptor_9431653f3c9f0bcb, []int{2} + return fileDescriptor_9431653f3c9f0bcb, []int{6} } func (m *RowNamespaceData) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -245,7 +445,7 @@ func (m *Share) Reset() { *m = Share{} } func (m *Share) String() string { return proto.CompactTextString(m) } func (*Share) ProtoMessage() {} func (*Share) Descriptor() ([]byte, []int) { - return fileDescriptor_9431653f3c9f0bcb, []int{3} + return fileDescriptor_9431653f3c9f0bcb, []int{7} } func (m *Share) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -284,6 +484,10 @@ func (m *Share) GetData() []byte { func init() { proto.RegisterEnum("shwap.AxisType", AxisType_name, AxisType_value) proto.RegisterEnum("shwap.Row_HalfSide", Row_HalfSide_name, Row_HalfSide_value) + proto.RegisterType((*EDSID)(nil), "shwap.EDSID") + proto.RegisterType((*RowID)(nil), "shwap.RowID") + proto.RegisterType((*SampleID)(nil), "shwap.SampleID") + proto.RegisterType((*RowNamespaceDataID)(nil), "shwap.RowNamespaceDataID") proto.RegisterType((*Row)(nil), "shwap.Row") proto.RegisterType((*Sample)(nil), "shwap.Sample") proto.RegisterType((*RowNamespaceData)(nil), "shwap.RowNamespaceData") @@ -293,31 +497,188 @@ func init() { func init() { proto.RegisterFile("share/shwap/pb/shwap.proto", fileDescriptor_9431653f3c9f0bcb) } var fileDescriptor_9431653f3c9f0bcb = []byte{ - // 381 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0x4f, 0x6b, 0xe2, 0x40, - 0x18, 0xc6, 0x33, 0x1b, 0xe3, 0xc6, 0x57, 0xd1, 0x30, 0x7b, 0x09, 0xee, 0x92, 0x95, 0xb0, 0x0b, - 0xb2, 0x60, 0xb2, 0xe8, 0x27, 0xd8, 0xbf, 0xb5, 0x60, 0x6b, 0x19, 0x85, 0x42, 0x2f, 0x61, 0x62, - 0x46, 0x13, 0x88, 0x9d, 0x21, 0x49, 0x49, 0x3d, 0xf7, 0xd0, 0x6b, 0x3f, 0x56, 0x8f, 0x1e, 0x7b, - 0x2c, 0xfa, 0x45, 0x4a, 0x26, 0xb1, 0x14, 0xda, 0x43, 0x6f, 0xbf, 0xcc, 0xf3, 0xcc, 0xbc, 0xcf, - 0x13, 0x5e, 0xe8, 0xa6, 0x21, 0x4d, 0x98, 0x9b, 0x86, 0x39, 0x15, 0xae, 0xf0, 0x4b, 0x70, 0x44, - 0xc2, 0x33, 0x8e, 0x35, 0xf9, 0xd1, 0x6d, 0x0b, 0xdf, 0x15, 0x09, 0xe7, 0xcb, 0xf2, 0xd8, 0xbe, - 0x45, 0xa0, 0x12, 0x9e, 0xe3, 0x01, 0x34, 0xe5, 0xe5, 0xd4, 0x0b, 0x69, 0xbc, 0x34, 0x51, 0x4f, - 0xed, 0x37, 0x87, 0x2d, 0xa7, 0x7c, 0x61, 0x56, 0x28, 0x04, 0x4a, 0xc3, 0x98, 0xc6, 0x4b, 0xfc, - 0x13, 0x1a, 0x85, 0xcf, 0x4b, 0xa3, 0x80, 0x99, 0x1f, 0x7a, 0xa8, 0xdf, 0x1e, 0x7e, 0xaa, 
0xcc, - 0x84, 0xe7, 0x4e, 0xe1, 0x99, 0x45, 0x01, 0x23, 0x7a, 0x58, 0x91, 0xfd, 0x15, 0xf4, 0xc3, 0x29, - 0xd6, 0xa1, 0x36, 0xf9, 0xf7, 0x7f, 0x6e, 0x28, 0xb8, 0x01, 0x1a, 0x39, 0x3e, 0x1a, 0xcf, 0x0d, - 0x64, 0xdf, 0x20, 0xa8, 0xcf, 0xe8, 0x5a, 0xc4, 0x0c, 0xdb, 0xa0, 0xc9, 0x59, 0x26, 0xea, 0xa1, - 0x57, 0x31, 0x4a, 0x09, 0x7f, 0x07, 0x4d, 0xf6, 0x90, 0xd3, 0x9b, 0xc3, 0x8e, 0x53, 0xb5, 0xf2, - 0x9d, 0xb3, 0x02, 0x48, 0xa9, 0x62, 0x07, 0x40, 0x82, 0x97, 0x6d, 0x04, 0x33, 0x55, 0x99, 0xb4, - 0x53, 0xbd, 0xf7, 0xeb, 0x3a, 0x4a, 0xe7, 0x1b, 0xc1, 0x48, 0x43, 0x5a, 0x0a, 0xb4, 0x3d, 0x30, - 0x08, 0xcf, 0x4f, 0xe9, 0x9a, 0xa5, 0x82, 0x2e, 0xd8, 0x5f, 0x9a, 0x51, 0xfc, 0x0d, 0xea, 0x65, - 0xf5, 0x37, 0x7f, 0x4b, 0xa5, 0xbd, 0x33, 0x90, 0xfd, 0x19, 0x34, 0x79, 0x0f, 0x63, 0xa8, 0x05, - 0x34, 0xa3, 0xb2, 0x63, 0x8b, 0x48, 0xfe, 0xf1, 0x05, 0xf4, 0x43, 0x28, 0xfc, 0x11, 0x54, 0x32, - 0x3d, 0x37, 0x94, 0x02, 0xfe, 0x4c, 0x27, 0x06, 0xfa, 0x7d, 0x72, 0xbf, 0xb3, 0xd0, 0x76, 0x67, - 0xa1, 0xc7, 0x9d, 0x85, 0xee, 0xf6, 0x96, 0xb2, 0xdd, 0x5b, 0xca, 0xc3, 0xde, 0x52, 0x2e, 0x46, - 0xab, 0x28, 0x0b, 0xaf, 0x7c, 0x67, 0xc1, 0xd7, 0xee, 0x82, 0xc5, 0x2c, 0xcd, 0x22, 0xca, 0x93, - 0xd5, 0x33, 0x0f, 0x2e, 0x79, 0x50, 0xec, 0xc5, 0xcb, 0xed, 0xf0, 0xeb, 0x72, 0x03, 0x46, 0x4f, - 0x01, 0x00, 0x00, 0xff, 0xff, 0x67, 0xb6, 0xc0, 0x8b, 0x36, 0x02, 0x00, 0x00, + // 485 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x93, 0xc1, 0x6e, 0xd3, 0x40, + 0x10, 0x86, 0xbd, 0x24, 0x36, 0xf6, 0x24, 0x6a, 0xad, 0x45, 0x42, 0x51, 0x5b, 0xb9, 0x95, 0x01, + 0x09, 0x21, 0xd5, 0x41, 0xe9, 0x13, 0x00, 0x29, 0xd4, 0x52, 0xa0, 0x68, 0x13, 0xa9, 0x12, 0x17, + 0x6b, 0x6d, 0x6f, 0x62, 0x4b, 0x4e, 0xd6, 0xb2, 0x8d, 0xdc, 0x9e, 0x39, 0x70, 0xe5, 0xb1, 0x38, + 0xf6, 0xc8, 0x11, 0x25, 0x2f, 0x82, 0x3c, 0xeb, 0xd0, 0x12, 0x71, 0xe8, 0xed, 0xdf, 0x9d, 0x7f, + 0xbe, 0x99, 0x9d, 0xd5, 0xc0, 0x41, 0x99, 0xf0, 0x42, 0x0c, 0xcb, 0xa4, 0xe6, 0xf9, 0x30, 0x0f, + 0x95, 0xf0, 0xf2, 0x42, 0x56, 0x92, 0xea, 0x78, 0x38, 0xd8, 0xcb, 0xc3, 0x61, 0x5e, 0x48, 0x39, + 0x57, 0xd7, 0xee, 0x31, 0xe8, 0xe7, 0xe3, 0xa9, 0x3f, 0xa6, 0x4f, 0xc1, 0x48, 0x44, 0xba, 0x48, + 0xaa, 0x01, 0x39, 0x21, 0x2f, 0xbb, 0xac, 0x3d, 0xb9, 0x3e, 0xe8, 0x4c, 0xd6, 0xfe, 0x98, 0x3e, + 0x03, 0x43, 0xc4, 0x65, 0x90, 0xc6, 0x68, 0xe8, 0x8d, 0xfa, 0x9e, 0xc2, 0x63, 0x3a, 0xd3, 0x45, + 0x5c, 0xfa, 0x31, 0x3d, 0x04, 0xab, 0x90, 0x75, 0x90, 0xae, 0x62, 0x71, 0x3d, 0x78, 0x84, 0x20, + 0xb3, 0x90, 0xb5, 0xdf, 0x9c, 0xdd, 0x09, 0x98, 0x53, 0xbe, 0xcc, 0x33, 0xa1, 0x68, 0x68, 0xdc, + 0xa5, 0x61, 0x2d, 0xa6, 0x37, 0x39, 0x48, 0x8b, 0x64, 0xf6, 0x2f, 0x2d, 0x92, 0x99, 0xa2, 0x5d, + 0x01, 0x65, 0xb2, 0xfe, 0xc4, 0x97, 0xa2, 0xcc, 0x79, 0x24, 0xc6, 0xbc, 0xe2, 0x0f, 0xe5, 0x1e, + 0x81, 0xb5, 0xda, 0xe6, 0x21, 0xb7, 0xcf, 0xee, 0x2e, 0xdc, 0xef, 0x04, 0x3a, 0x4c, 0xd6, 0xf4, + 0x14, 0x7a, 0x38, 0xcf, 0x32, 0x48, 0x78, 0x36, 0x1f, 0x90, 0x93, 0xce, 0x3d, 0xde, 0xb4, 0x89, + 0x30, 0x50, 0x86, 0x0b, 0x9e, 0xcd, 0xe9, 0x6b, 0xb0, 0x1a, 0x5f, 0x50, 0xa6, 0xb1, 0x82, 0xee, + 0x8d, 0x9e, 0xdc, 0x15, 0xf7, 0x1a, 0xcf, 0x34, 0x8d, 0x05, 0x33, 0x93, 0x56, 0xb9, 0xc7, 0x60, + 0x6e, 0x6f, 0xa9, 0x09, 0xdd, 0xc9, 0xf9, 0xfb, 0x99, 0xad, 0x51, 0x0b, 0x74, 0xe6, 0x7f, 0xb8, + 0x98, 0xd9, 0xc4, 0xfd, 0x46, 0xc0, 0x50, 0x13, 0xa3, 0x2e, 0xe8, 0x58, 0x6b, 0xe7, 0x59, 0xaa, + 0x0d, 0x15, 0xa2, 0x2f, 0x40, 0xc7, 0xaf, 0xc5, 0xea, 0xbd, 0xd1, 0xbe, 0xd7, 0x7e, 0x74, 0xe8, + 0x7d, 0x6e, 0x04, 0x53, 0x51, 0xea, 0x01, 0xa0, 0x08, 0xaa, 0x9b, 0x5c, 0x0c, 0x3a, 
0xd8, 0xe9, + 0x7e, 0xcb, 0x7b, 0x73, 0x9d, 0x96, 0xb3, 0x9b, 0x5c, 0x30, 0x0b, 0x2d, 0x8d, 0x74, 0x03, 0xb0, + 0x77, 0x07, 0x4d, 0x9f, 0x83, 0xa1, 0x9e, 0xfe, 0xdf, 0xb1, 0xb4, 0xb1, 0x07, 0x36, 0xe4, 0x1e, + 0x82, 0x8e, 0x79, 0x94, 0x42, 0x37, 0xe6, 0x15, 0xc7, 0x37, 0xf6, 0x19, 0xea, 0x57, 0x47, 0x60, + 0x6e, 0x9b, 0xa2, 0x8f, 0xa1, 0xc3, 0x2e, 0xaf, 0x6c, 0xad, 0x11, 0xef, 0x2e, 0x27, 0x36, 0x79, + 0xfb, 0xf1, 0xe7, 0xda, 0x21, 0xb7, 0x6b, 0x87, 0xfc, 0x5e, 0x3b, 0xe4, 0xc7, 0xc6, 0xd1, 0x6e, + 0x37, 0x8e, 0xf6, 0x6b, 0xe3, 0x68, 0x5f, 0xce, 0x16, 0x69, 0x95, 0x7c, 0x0d, 0xbd, 0x48, 0x2e, + 0x87, 0x91, 0xc8, 0x44, 0x59, 0xa5, 0x5c, 0x16, 0x8b, 0xbf, 0xfa, 0x74, 0x25, 0xe3, 0x66, 0x55, + 0xee, 0x2f, 0x4c, 0x68, 0xe0, 0x52, 0x9c, 0xfd, 0x09, 0x00, 0x00, 0xff, 0xff, 0x3b, 0x0a, 0x5d, + 0x31, 0x49, 0x03, 0x00, 0x00, +} + +func (m *EDSID) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EDSID) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EDSID) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Height != 0 { + i = encodeVarintShwap(dAtA, i, uint64(m.Height)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *RowID) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RowID) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *RowID) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.RowIndex != 0 { + i = encodeVarintShwap(dAtA, i, uint64(m.RowIndex)) + i-- + dAtA[i] = 0x10 + } + if m.EdsId != nil { + { + size, err := m.EdsId.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *SampleID) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SampleID) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *SampleID) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.ColIndex != 0 { + i = encodeVarintShwap(dAtA, i, uint64(m.ColIndex)) + i-- + dAtA[i] = 0x10 + } + if m.RowId != nil { + { + size, err := m.RowId.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *RowNamespaceDataID) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RowNamespaceDataID) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *RowNamespaceDataID) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Namespace) > 0 { + i -= 
len(m.Namespace) + copy(dAtA[i:], m.Namespace) + i = encodeVarintShwap(dAtA, i, uint64(len(m.Namespace))) + i-- + dAtA[i] = 0x12 + } + if m.RowId != nil { + { + size, err := m.RowId.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil } func (m *Row) Marshal() (dAtA []byte, err error) { @@ -504,31 +865,92 @@ func encodeVarintShwap(dAtA []byte, offset int, v uint64) int { dAtA[offset] = uint8(v) return base } -func (m *Row) Size() (n int) { +func (m *EDSID) Size() (n int) { if m == nil { return 0 } var l int _ = l - if len(m.SharesHalf) > 0 { - for _, e := range m.SharesHalf { - l = e.Size() - n += 1 + l + sovShwap(uint64(l)) - } - } - if m.HalfSide != 0 { - n += 1 + sovShwap(uint64(m.HalfSide)) + if m.Height != 0 { + n += 1 + sovShwap(uint64(m.Height)) } return n } -func (m *Sample) Size() (n int) { +func (m *RowID) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Share != nil { + if m.EdsId != nil { + l = m.EdsId.Size() + n += 1 + l + sovShwap(uint64(l)) + } + if m.RowIndex != 0 { + n += 1 + sovShwap(uint64(m.RowIndex)) + } + return n +} + +func (m *SampleID) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.RowId != nil { + l = m.RowId.Size() + n += 1 + l + sovShwap(uint64(l)) + } + if m.ColIndex != 0 { + n += 1 + sovShwap(uint64(m.ColIndex)) + } + return n +} + +func (m *RowNamespaceDataID) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.RowId != nil { + l = m.RowId.Size() + n += 1 + l + sovShwap(uint64(l)) + } + l = len(m.Namespace) + if l > 0 { + n += 1 + l + sovShwap(uint64(l)) + } + return n +} + +func (m *Row) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.SharesHalf) > 0 { + for _, e := range m.SharesHalf { + l = e.Size() + n += 1 + l + sovShwap(uint64(l)) + } + } + if m.HalfSide != 0 { + n += 1 + sovShwap(uint64(m.HalfSide)) + } + return n +} + +func (m *Sample) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Share != nil { l = m.Share.Size() n += 1 + l + sovShwap(uint64(l)) } @@ -580,6 +1002,405 @@ func sovShwap(x uint64) (n int) { func sozShwap(x uint64) (n int) { return sovShwap(uint64((x << 1) ^ uint64((int64(x) >> 63)))) } +func (m *EDSID) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EDSID: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EDSID: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Height", wireType) + } + m.Height = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Height |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + 
skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RowID) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RowID: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RowID: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field EdsId", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.EdsId == nil { + m.EdsId = &EDSID{} + } + if err := m.EdsId.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field RowIndex", wireType) + } + m.RowIndex = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.RowIndex |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SampleID) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SampleID: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SampleID: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RowId", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RowId == nil { + m.RowId = &RowID{} + } + if err := m.RowId.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + 
return err + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ColIndex", wireType) + } + m.ColIndex = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ColIndex |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RowNamespaceDataID) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RowNamespaceDataID: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RowNamespaceDataID: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RowId", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RowId == nil { + m.RowId = &RowID{} + } + if err := m.RowId.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Namespace", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Namespace = append(m.Namespace[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Namespace == nil { + m.Namespace = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func (m *Row) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 diff --git a/share/shwap/pb/shwap.proto b/share/shwap/pb/shwap.proto index d7daea568a..1190fb2332 100644 --- a/share/shwap/pb/shwap.proto +++ b/share/shwap/pb/shwap.proto @@ -5,6 +5,25 @@ option go_package = "github.com/celestiaorg/celestia-node/share/shwap/pb"; import "pb/proof.proto"; // celestiaorg/nmt/pb/proof.proto +message EDSID { + uint64 height = 1; +} + +message RowID { + EDSID eds_id = 1; + uint64 row_index = 2; +} + +message SampleID { + RowID row_id = 1; + uint64 col_index = 2; +} + +message RowNamespaceDataID { + RowID row_id = 1; + bytes namespace = 2; +} + message Row { repeated Share shares_half = 1; HalfSide half_side= 2; diff --git a/state/core_access.go b/state/core_access.go index dca866b35c..9e46bcfc08 100644 --- a/state/core_access.go +++ b/state/core_access.go @@ -184,9 +184,9 @@ func (ca *CoreAccessor) cancelCtx() { ca.cancel = nil } -// SubmitPayForBlob builds, signs, and synchronously submits a MsgPayForBlob with additional options defined -// in `TxConfig`. It blocks until the transaction is committed and returns the TxResponse. -// The user can specify additional options that can bee applied to the Tx. +// SubmitPayForBlob builds, signs, and synchronously submits a MsgPayForBlob with additional +// options defined in `TxConfig`. It blocks until the transaction is committed and returns the +// TxResponse. The user can specify additional options that can bee applied to the Tx. func (ca *CoreAccessor) SubmitPayForBlob( ctx context.Context, appblobs []*Blob, @@ -583,7 +583,8 @@ func (ca *CoreAccessor) queryMinimumGasPrice( func (ca *CoreAccessor) setupTxClient(ctx context.Context, keyName string) (*user.TxClient, error) { encCfg := encoding.MakeConfig(app.ModuleEncodingRegisters...) - // explicitly set default address. Otherwise, there could be a mismatch between defaultKey and defaultAddress. + // explicitly set default address. Otherwise, there could be a mismatch between defaultKey and + // defaultAddress. 
rec, err := ca.keyring.Key(keyName) if err != nil { return nil, err diff --git a/store/store.go b/store/store.go index 17eade46fe..04b3b549d1 100644 --- a/store/store.go +++ b/store/store.go @@ -16,6 +16,7 @@ import ( "github.com/celestiaorg/rsmt2d" + "github.com/celestiaorg/celestia-node/libs/utils" "github.com/celestiaorg/celestia-node/share" eds "github.com/celestiaorg/celestia-node/share/new_eds" "github.com/celestiaorg/celestia-node/store/cache" @@ -471,7 +472,7 @@ func (s *Store) storeEmptyHeights() error { if err != nil { return fmt.Errorf("opening empty heights file: %w", err) } - defer closeAndLog(log, "empty heights file", file) + defer utils.CloseAndLog(log, "empty heights file", file) s.emptyHeightsLock.RLock() defer s.emptyHeightsLock.RUnlock() @@ -489,7 +490,7 @@ func loadEmptyHeights(basepath string) (map[uint64]struct{}, error) { if err != nil { return nil, fmt.Errorf("opening empty heights file: %w", err) } - defer closeAndLog(log, "empty heights file", file) + defer utils.CloseAndLog(log, "empty heights file", file) emptyHeights := make(map[uint64]struct{}) err = gob.NewDecoder(file).Decode(&emptyHeights) @@ -511,9 +512,3 @@ func (s *Store) addEmptyHeight(height uint64) { defer s.emptyHeightsLock.Unlock() s.emptyHeights[height] = struct{}{} } - -func closeAndLog(log logging.StandardLogger, name string, closer io.Closer) { - if err := closer.Close(); err != nil { - log.Warnf("closing %s: %s", name, err) - } -}