Skip to content

Commit

Permalink
ATX handler rejects invalid ATXs on pubsub lvl (#6142)
Browse files Browse the repository at this point in the history
In order to drop peers sending invalid ATXs, the handler must return `pubsub.ErrValidationReject`
  • Loading branch information
poszu authored and fasmat committed Jul 22, 2024
1 parent 69618c8 commit 90c8a9b
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 14 deletions.
2 changes: 1 addition & 1 deletion activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (h *Handler) handleAtx(

opaqueAtx, err := h.decodeATX(msg)
if err != nil {
return nil, fmt.Errorf("decoding ATX: %w", err)
return nil, fmt.Errorf("%w: decoding ATX: %w", pubsub.ErrValidationReject, err)
}
id := opaqueAtx.ID()

Expand Down
4 changes: 4 additions & 0 deletions activation/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ func TestHandler_HandleSyncedAtx(t *testing.T) {
err := atxHdlr.HandleSyncedAtx(context.Background(), atx.ID().Hash32(), p2p.NoPeer, buf)
require.ErrorIs(t, err, errMalformedData)
require.ErrorContains(t, err, "invalid atx signature")
require.ErrorIs(t, err, pubsub.ErrValidationReject)
})
t.Run("atx V2", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -857,12 +858,14 @@ func TestHandler_DecodeATX(t *testing.T) {
atxHdlr := newTestHandler(t, types.RandomATXID())
_, err := atxHdlr.decodeATX(nil)
require.ErrorIs(t, err, errMalformedData)
require.ErrorIs(t, err, pubsub.ErrValidationReject)
})
t.Run("malformed atx", func(t *testing.T) {
t.Parallel()
atxHdlr := newTestHandler(t, types.RandomATXID())
_, err := atxHdlr.decodeATX([]byte("malformed"))
require.ErrorIs(t, err, errMalformedData)
require.ErrorIs(t, err, pubsub.ErrValidationReject)
})
t.Run("v1", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -893,5 +896,6 @@ func TestHandler_DecodeATX(t *testing.T) {
atx.PublishEpoch = 9
_, err := atxHdlr.decodeATX(codec.MustEncode(atx))
require.ErrorIs(t, err, errMalformedData)
require.ErrorIs(t, err, pubsub.ErrValidationReject)
})
}
9 changes: 5 additions & 4 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/spacemeshos/go-spacemesh/log"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/p2p"
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
Expand Down Expand Up @@ -622,7 +623,7 @@ func (h *HandlerV1) processATX(
received time.Time,
) (*mwire.MalfeasanceProof, error) {
if !h.edVerifier.Verify(signing.ATX, watx.SmesherID, watx.SignedBytes(), watx.Signature) {
return nil, fmt.Errorf("invalid atx signature: %w", errMalformedData)
return nil, fmt.Errorf("%w: invalid atx signature: %w", pubsub.ErrValidationReject, errMalformedData)
}

existing, _ := h.cdb.GetAtx(watx.ID())
Expand All @@ -638,7 +639,7 @@ func (h *HandlerV1) processATX(
)

if err := h.syntacticallyValidate(ctx, watx); err != nil {
return nil, fmt.Errorf("atx %s syntactically invalid: %w", watx.ID(), err)
return nil, fmt.Errorf("%w: validating atx %s: %w", pubsub.ErrValidationReject, watx.ID(), err)
}

poetRef, atxIDs := collectAtxDeps(h.goldenATXID, watx)
Expand All @@ -649,7 +650,7 @@ func (h *HandlerV1) processATX(

leaves, effectiveNumUnits, proof, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return nil, fmt.Errorf("atx %s syntactically invalid based on deps: %w", watx.ID(), err)
return nil, fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
if proof != nil {
return proof, nil
Expand Down Expand Up @@ -712,7 +713,7 @@ func (h *HandlerV1) registerHashes(peer p2p.Peer, poetRef types.Hash32, atxIDs [
// fetchReferences makes sure that the referenced poet proof and ATXs are available.
func (h *HandlerV1) fetchReferences(ctx context.Context, poetRef types.Hash32, atxIDs []types.ATXID) error {
if err := h.fetcher.GetPoetProof(ctx, poetRef); err != nil {
return fmt.Errorf("missing poet proof (%s): %w", poetRef.ShortString(), err)
return fmt.Errorf("fetching poet proof (%s): %w", poetRef.ShortString(), err)
}

if len(atxIDs) == 0 {
Expand Down
16 changes: 16 additions & 0 deletions activation/handler_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import (
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/datastore"
"github.com/spacemeshos/go-spacemesh/fetch"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/p2p"
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
Expand Down Expand Up @@ -857,4 +859,18 @@ func TestHandlerV1_FetchesReferences(t *testing.T) {
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any()).Return(errors.New("oh"))
require.Error(t, atxHdlr.fetchReferences(context.Background(), poet, atxs))
})
t.Run("reject ATX when dependency ATX is rejected", func(t *testing.T) {
t.Parallel()
atxHdlr := newV1TestHandler(t, goldenATXID)

poet := types.RandomHash()
atxs := []types.ATXID{types.RandomATXID(), types.RandomATXID()}
var batchErr fetch.BatchError
batchErr.Add(atxs[0].Hash32(), pubsub.ErrValidationReject)

atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), poet)
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any()).Return(&batchErr)

require.ErrorIs(t, atxHdlr.fetchReferences(context.Background(), poet, atxs), pubsub.ErrValidationReject)
})
}
9 changes: 5 additions & 4 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/spacemeshos/go-spacemesh/log"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/p2p"
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
Expand Down Expand Up @@ -90,7 +91,7 @@ func (h *HandlerV2) processATX(
)

if err := h.syntacticallyValidate(ctx, watx); err != nil {
return nil, fmt.Errorf("atx %s syntactically invalid: %w", watx.ID(), err)
return nil, fmt.Errorf("%w: validating atx %s: %w", pubsub.ErrValidationReject, watx.ID(), err)
}

poetRef, atxIDs := h.collectAtxDeps(watx)
Expand All @@ -101,17 +102,17 @@ func (h *HandlerV2) processATX(

baseTickHeight, err := h.validatePositioningAtx(watx.PublishEpoch, h.goldenATXID, watx.PositioningATX)
if err != nil {
return nil, fmt.Errorf("validating positioning atx: %w", err)
return nil, fmt.Errorf("%w: validating positioning atx: %w", pubsub.ErrValidationReject, err)
}

marrying, err := h.validateMarriages(watx)
if err != nil {
return nil, fmt.Errorf("validating marriages: %w", err)
return nil, fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, proof, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return nil, fmt.Errorf("atx %s syntactically invalid based on deps: %w", watx.ID(), err)
return nil, fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}

if proof != nil {
Expand Down
33 changes: 33 additions & 0 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/datastore"
"github.com/spacemeshos/go-spacemesh/fetch"
mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire"
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/atxs"
Expand Down Expand Up @@ -539,6 +541,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) {

_, err = atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now())
require.ErrorContains(t, err, "vrf nonce is not valid")
require.ErrorIs(t, err, pubsub.ErrValidationReject)

_, err = atxs.Get(atxHandler.cdb, atx.ID())
require.ErrorIs(t, err, sql.ErrNotFound)
Expand Down Expand Up @@ -581,6 +584,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) {
atxHandler.expectFetchDeps(atx)
_, err := atxHandler.processATX(context.Background(), peer, atx, codec.MustEncode(atx), time.Now())
require.ErrorContains(t, err, "validating positioning atx")
require.ErrorIs(t, err, pubsub.ErrValidationReject)

_, err = atxs.Get(atxHandler.cdb, atx.ID())
require.ErrorIs(t, err, sql.ErrNotFound)
Expand Down Expand Up @@ -706,6 +710,20 @@ func TestHandlerV2_FetchesReferences(t *testing.T) {
atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), poets[1]).Return(errors.New("pooh"))
require.Error(t, atxHdlr.fetchReferences(context.Background(), poets, nil))
})
t.Run("reject ATX when dependency poet proof is rejected", func(t *testing.T) {
t.Parallel()
atxHdlr := newV2TestHandler(t, golden)

poets := []types.Hash32{types.RandomHash()}
atxs := []types.ATXID{types.RandomATXID()}
var batchErr fetch.BatchError
batchErr.Add(atxs[0].Hash32(), pubsub.ErrValidationReject)

atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), poets[0]).Return(&batchErr)
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any())

require.ErrorIs(t, atxHdlr.fetchReferences(context.Background(), poets, atxs), pubsub.ErrValidationReject)
})

t.Run("failed to fetch atxs", func(t *testing.T) {
t.Parallel()
Expand All @@ -719,6 +737,20 @@ func TestHandlerV2_FetchesReferences(t *testing.T) {
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any()).Return(errors.New("oh"))
require.Error(t, atxHdlr.fetchReferences(context.Background(), poets, atxs))
})
t.Run("reject ATX when dependency ATX is rejected", func(t *testing.T) {
t.Parallel()
atxHdlr := newV2TestHandler(t, golden)

poets := []types.Hash32{types.RandomHash()}
atxs := []types.ATXID{types.RandomATXID(), types.RandomATXID()}
var batchErr fetch.BatchError
batchErr.Add(atxs[0].Hash32(), pubsub.ErrValidationReject)

atxHdlr.mockFetch.EXPECT().GetPoetProof(gomock.Any(), poets[0])
atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any()).Return(&batchErr)

require.ErrorIs(t, atxHdlr.fetchReferences(context.Background(), poets, atxs), pubsub.ErrValidationReject)
})
t.Run("no atxs to fetch", func(t *testing.T) {
t.Parallel()
atxHdlr := newV2TestHandler(t, golden)
Expand Down Expand Up @@ -1215,6 +1247,7 @@ func Test_Marriages(t *testing.T) {
atxHandler.mclock.EXPECT().CurrentLayer().AnyTimes()
_, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now())
require.ErrorContains(t, err, "signer must marry itself")
require.ErrorIs(t, err, pubsub.ErrValidationReject)
})
}

Expand Down
3 changes: 2 additions & 1 deletion activation/poet.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/spacemeshos/go-spacemesh/activation/metrics"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/log"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/localsql/certifier"
)

Expand Down Expand Up @@ -505,7 +506,7 @@ func (c *poetService) Proof(ctx context.Context, roundID string) (*types.PoetPro
return nil, nil, fmt.Errorf("getting proof: %w", err)
}

if err := c.db.ValidateAndStore(ctx, proof); err != nil && !errors.Is(err, ErrObjectExists) {
if err := c.db.ValidateAndStore(ctx, proof); err != nil && !errors.Is(err, sql.ErrObjectExists) {
c.logger.Warn("failed to validate and store proof", zap.Error(err), zap.Object("proof", proof))
return nil, nil, fmt.Errorf("validating and storing proof: %w", err)
}
Expand Down
2 changes: 0 additions & 2 deletions activation/poetdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import (
"github.com/spacemeshos/go-spacemesh/sql/poets"
)

var ErrObjectExists = sql.ErrObjectExists

// PoetDb is a database for PoET proofs.
type PoetDb struct {
sqlDB *sql.Database
Expand Down
13 changes: 11 additions & 2 deletions fetch/mesh_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
"github.com/spacemeshos/go-scale"
"golang.org/x/sync/errgroup"

"github.com/spacemeshos/go-spacemesh/activation"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/datastore"
"github.com/spacemeshos/go-spacemesh/log"
"github.com/spacemeshos/go-spacemesh/p2p"
"github.com/spacemeshos/go-spacemesh/p2p/pubsub"
"github.com/spacemeshos/go-spacemesh/p2p/server"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/system"
)

Expand Down Expand Up @@ -211,7 +211,7 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error {
switch {
case pm.err == nil:
return nil
case errors.Is(pm.err, activation.ErrObjectExists):
case errors.Is(pm.err, sql.ErrObjectExists):
// PoET proofs are concurrently stored in DB in two places:
// fetcher and nipost builder. Hence, it might happen that
// a proof had been inserted into the DB while the fetcher
Expand Down Expand Up @@ -400,6 +400,15 @@ func (b *BatchError) Empty() bool {
return len(b.Errors) == 0
}

func (b *BatchError) Is(target error) bool {
for _, err := range b.Errors {
if errors.Is(err, target) {
return true
}
}
return false
}

func (b *BatchError) Add(id types.Hash32, err error) {
if b.Errors == nil {
b.Errors = map[types.Hash32]error{}
Expand Down

0 comments on commit 90c8a9b

Please sign in to comment.