diff --git a/activation/activation.go b/activation/activation.go index 7923f4ef46..dd71dcef88 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -874,7 +874,7 @@ func (b *Builder) createAtx( PositioningATX: challenge.PositioningATX, Coinbase: b.Coinbase(), VRFNonce: (uint64)(nipostState.VRFNonce), - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { Membership: wire.MerkleProofV2{ Nodes: nipostState.Membership.Nodes, diff --git a/activation/certifier.go b/activation/certifier.go index 636a14ca78..cd106ab9b7 100644 --- a/activation/certifier.go +++ b/activation/certifier.go @@ -344,7 +344,7 @@ func loadPost(ctx context.Context, db sql.Executor, id types.ATXID) (*types.Post if err := codec.Decode(blob.Bytes, &atx); err != nil { return nil, nil, fmt.Errorf("decoding ATX blob: %w", err) } - return wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), atx.NiPosts[0].Challenge[:], nil + return wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), atx.NIPosts[0].Challenge[:], nil } panic("unsupported ATX version") } diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index dca9e733b6..a97c44b74e 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -97,7 +97,7 @@ func createInitialAtx( Post: *wire.PostToWireV1(initial), }, VRFNonce: uint64(nipost.VRFNonce), - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { Membership: wire.MerkleProofV2{ Nodes: nipost.Membership.Nodes, @@ -121,7 +121,7 @@ func createSoloAtx(publish types.EpochID, prev, pos types.ATXID, nipost *nipost. PreviousATXs: []types.ATXID{prev}, PositioningATX: pos, VRFNonce: uint64(nipost.VRFNonce), - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { Membership: wire.MerkleProofV2{ Nodes: nipost.Membership.Nodes, @@ -152,7 +152,7 @@ func createMerged( PreviousATXs: previous, MarriageATX: &marriage, PositioningATX: positioning, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { Membership: membership, Challenge: types.Hash32(niposts[0].PostMetadata.Challenge), @@ -163,7 +163,7 @@ func createMerged( for i, nipost := range niposts { idx := slices.IndexFunc(previous, func(a types.ATXID) bool { return a == nipost.previous }) require.NotEqual(tb, -1, idx) - atx.NiPosts[0].Posts = append(atx.NiPosts[0].Posts, wire.SubPostV2{ + atx.NIPosts[0].Posts = append(atx.NIPosts[0].Posts, wire.SubPostV2{ MarriageIndex: uint32(i), PrevATXIndex: uint32(idx), MembershipLeafIndex: nipost.Membership.LeafIndex, diff --git a/activation/handler_test.go b/activation/handler_test.go index 013501e1ad..c080012133 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -119,6 +119,7 @@ func toAtx(tb testing.TB, watx *wire.ActivationTxV1) *types.ActivationTx { } type handlerMocks struct { + ctrl *gomock.Controller goldenATXID types.ATXID mclock *MocklayerClock @@ -184,6 +185,8 @@ func (h *handlerMocks) expectAtxV1(atx *wire.ActivationTxV1, nodeId types.NodeID func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { ctrl := gomock.NewController(tb) return handlerMocks{ + ctrl: ctrl, + goldenATXID: golden, mclock: NewMocklayerClock(ctrl), mpub: pubsubmocks.NewMockPublisher(ctrl), diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 3f3e6606a2..3ec7abcca0 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -182,13 +182,13 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat } if atx.MarriageATX == nil { - if len(atx.NiPosts) != 1 { + if len(atx.NIPosts) != 1 { return errors.New("solo atx must have one nipost") } - if len(atx.NiPosts[0].Posts) != 1 { + if len(atx.NIPosts[0].Posts) != 1 { return errors.New("solo atx must have one post") } - if atx.NiPosts[0].Posts[0].PrevATXIndex != 0 { + if atx.NIPosts[0].Posts[0].PrevATXIndex != 0 { return errors.New("solo atx post must have prevATXIndex 0") } } @@ -204,7 +204,7 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat return errors.New("initial atx must not have previous atxs") } - numUnits := atx.NiPosts[0].Posts[0].NumUnits + numUnits := atx.NIPosts[0].Posts[0].NumUnits if err := h.nipostValidator.VRFNonceV2( atx.SmesherID, atx.Initial.CommitmentATX, atx.VRFNonce, numUnits, ); err != nil { @@ -309,7 +309,7 @@ func (h *HandlerV2) collectAtxDeps(atx *wire.ActivationTxV2) ([]types.Hash32, [] } poetRefs := make(map[types.Hash32]struct{}) - for _, nipost := range atx.NiPosts { + for _, nipost := range atx.NIPosts { poetRefs[nipost.Challenge] = struct{}{} } @@ -495,7 +495,7 @@ func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) { func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error { seen := make(map[uint32]struct{}) - for _, niposts := range atx.NiPosts { + for _, niposts := range atx.NIPosts { for _, post := range niposts.Posts { if _, ok := seen[post.MarriageIndex]; ok { return fmt.Errorf("ID present twice (duplicated marriage index): %d", post.MarriageIndex) @@ -540,8 +540,8 @@ func (h *HandlerV2) syntacticallyValidateDeps( } // validate previous ATXs - nipostSizes := make(nipostSizes, len(atx.NiPosts)) - for i, niposts := range atx.NiPosts { + nipostSizes := make(nipostSizes, len(atx.NIPosts)) + for i, niposts := range atx.NIPosts { nipostSizes[i] = new(nipostSize) for _, post := range niposts.Posts { if post.MarriageIndex >= uint32(len(equivocationSet)) { @@ -563,7 +563,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( } // validate poet membership proofs - for i, niposts := range atx.NiPosts { + for i, niposts := range atx.NIPosts { // verify PoET memberships in a single go indexedChallenges := make(map[uint64][]byte) @@ -608,7 +608,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( // validate all niposts var smesherCommitment *types.ATXID - for _, niposts := range atx.NiPosts { + for idx, niposts := range atx.NIPosts { for _, post := range niposts.Posts { id := equivocationSet[post.MarriageIndex] var commitment types.ATXID @@ -632,24 +632,18 @@ func (h *HandlerV2) syntacticallyValidateDeps( id, commitment, wire.PostFromWireV1(&post.Post), - niposts.Challenge[:], + niposts.Challenge.Bytes(), post.NumUnits, PostSubset([]byte(h.local)), ) invalidIdx := &verifying.ErrInvalidIndex{} - if errors.As(err, invalidIdx) { - h.logger.Debug( - "ATX with invalid post index", - zap.Stringer("id", atx.ID()), - zap.Int("index", invalidIdx.Index), - ) - // TODO(mafa): finish proof - var proof wire.Proof - if err := h.malPublisher.Publish(ctx, id, proof); err != nil { - return nil, fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) + switch { + case errors.As(err, invalidIdx): + if err := h.publishInvalidPostProof(ctx, atx, id, idx, uint32(invalidIdx.Index)); err != nil { + return nil, fmt.Errorf("publishing invalid post proof: %w", err) } - } - if err != nil { + return nil, fmt.Errorf("invalid post for ID %s: %w", id.ShortString(), err) + case err != nil: return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) } result.ids[id] = idData{ @@ -674,6 +668,45 @@ func (h *HandlerV2) syntacticallyValidateDeps( return &result, nil } +func (h *HandlerV2) publishInvalidPostProof( + ctx context.Context, + atx *wire.ActivationTxV2, + nodeID types.NodeID, + nipostIndex int, + invalidPostIndex uint32, +) error { + initialAtx := atx + if initialAtx.Initial == nil { + initialID, err := atxs.GetFirstIDByNodeID(h.cdb, nodeID) + if err != nil { + return fmt.Errorf("fetch initial ATX for ID %s: %w", nodeID.ShortString(), err) + } + + // TODO(mafa): implement for v1 initial ATXs: https://github.com/spacemeshos/go-spacemesh/issues/6433 + initialAtx, err = h.fetchWireAtx(ctx, h.cdb, initialID) + if err != nil { + return fmt.Errorf("fetch initial ATX blob for ID %s: %w", nodeID.ShortString(), err) + } + } + + // TODO(mafa): checkpoints need to include all initial ATXs in full to be able to create this malfeasance proof: + // + // see https://github.com/spacemeshos/go-spacemesh/issues/6436 + // + // TODO(mafa): checkpoints need to include all marriage ATXs in full to be able to create malfeasance proofs + // like this one (but also others) + // + // see https://github.com/spacemeshos/go-spacemesh/issues/6435 + proof, err := wire.NewInvalidPostProof(h.cdb, atx, initialAtx, nodeID, nipostIndex, invalidPostIndex) + if err != nil { + return fmt.Errorf("creating invalid post proof: %w", err) + } + if err := h.malPublisher.Publish(ctx, nodeID, proof); err != nil { + return fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) + } + return nil +} + func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { malicious, err := malfeasance.IsMalicious(tx, atx.SmesherID) if err != nil { @@ -717,7 +750,7 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx func (h *HandlerV2) fetchWireAtx( ctx context.Context, - tx sql.Transaction, + tx sql.Executor, id types.ATXID, ) (*wire.ActivationTxV2, error) { var blob sql.Blob @@ -808,6 +841,7 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, at zap.Stringer("smesher_id", atx.SmesherID), ) + // TODO(mafa): finish proof var proof wire.Proof return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 81c6d40fc9..a03058ab3d 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -75,7 +75,7 @@ func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { func (h *handlerMocks) expectFetchDeps(atx *wire.ActivationTxV2) { h.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) - h.mockFetch.EXPECT().GetPoetProof(gomock.Any(), atx.NiPosts[0].Challenge) + h.mockFetch.EXPECT().GetPoetProof(gomock.Any(), atx.NIPosts[0].Challenge) _, atxDeps := (&HandlerV2{goldenATXID: h.goldenATXID}).collectAtxDeps(atx) if len(atxDeps) != 0 { h.mockFetch.EXPECT().GetAtxs(gomock.Any(), gomock.InAnyOrder(atxDeps), gomock.Any()) @@ -87,15 +87,15 @@ func (h *handlerMocks) expectVerifyNIPoST(atx *wire.ActivationTxV2) { gomock.Any(), atx.SmesherID, gomock.Any(), - wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), - atx.NiPosts[0].Challenge.Bytes(), - atx.NiPosts[0].Posts[0].NumUnits, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.NIPosts[0].Posts[0].NumUnits, gomock.Any(), ) h.mValidator.EXPECT().PoetMembership( gomock.Any(), gomock.Any(), - atx.NiPosts[0].Challenge, + atx.NIPosts[0].Challenge, gomock.Any(), ).Return(poetLeaves, nil) } @@ -105,7 +105,7 @@ func (h *handlerMocks) expectVerifyNIPoSTs( equivocationSet []types.NodeID, poetLeaves []uint64, ) { - for i, nipost := range atx.NiPosts { + for i, nipost := range atx.NIPosts { for _, post := range nipost.Posts { h.mValidator.EXPECT().PostV2( gomock.Any(), @@ -140,7 +140,7 @@ func (h *handlerMocks) expectInitialAtxV2(atx *wire.ActivationTxV2) { atx.SmesherID, atx.Initial.CommitmentATX, atx.VRFNonce, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ) h.mValidator.EXPECT().PostV2( gomock.Any(), @@ -148,7 +148,7 @@ func (h *handlerMocks) expectInitialAtxV2(atx *wire.ActivationTxV2) { atx.Initial.CommitmentATX, wire.PostFromWireV1(&atx.Initial.Post), shared.ZeroChallenge, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, gomock.Any(), ) @@ -163,7 +163,7 @@ func (h *handlerMocks) expectAtxV2(atx *wire.ActivationTxV2) { atx.SmesherID, gomock.Any(), atx.VRFNonce, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ) h.expectFetchDeps(atx) h.expectVerifyNIPoST(atx) @@ -282,7 +282,7 @@ func TestHandlerV2_SyntacticallyValidate_InitialAtx(t *testing.T) { sig.NodeID(), atx.Initial.CommitmentATX, atx.VRFNonce, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ) atxHandler.mValidator.EXPECT().PostV2( context.Background(), @@ -290,7 +290,7 @@ func TestHandlerV2_SyntacticallyValidate_InitialAtx(t *testing.T) { atx.Initial.CommitmentATX, wire.PostFromWireV1(&atx.Initial.Post), shared.ZeroChallenge, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ) require.NoError(t, atxHandler.syntacticallyValidate(context.Background(), atx)) }) @@ -346,7 +346,7 @@ func TestHandlerV2_SyntacticallyValidate_InitialAtx(t *testing.T) { sig.NodeID(), atx.Initial.CommitmentATX, atx.VRFNonce, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ). Return(errors.New("invalid nonce")) @@ -363,7 +363,7 @@ func TestHandlerV2_SyntacticallyValidate_InitialAtx(t *testing.T) { sig.NodeID(), atx.Initial.CommitmentATX, atx.VRFNonce, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ) atxHandler.mValidator.EXPECT(). PostV2( @@ -372,7 +372,7 @@ func TestHandlerV2_SyntacticallyValidate_InitialAtx(t *testing.T) { atx.Initial.CommitmentATX, wire.PostFromWireV1(&atx.Initial.Post), shared.ZeroChallenge, - atx.NiPosts[0].Posts[0].NumUnits, + atx.NIPosts[0].Posts[0].NumUnits, ). Return(errors.New("invalid post")) require.ErrorContains(t, atxHandler.syntacticallyValidate(context.Background(), atx), "invalid post") @@ -406,7 +406,7 @@ func TestHandlerV2_SyntacticallyValidate_SoloAtx(t *testing.T) { t.Run("rejects when len(NIPoSTs) != 1", func(t *testing.T) { t.Parallel() atx := newInitialATXv2(t, golden) - atx.NiPosts = append(atx.NiPosts, wire.NiPostsV2{}) + atx.NIPosts = append(atx.NIPosts, wire.NIPostV2{}) atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer() @@ -416,7 +416,7 @@ func TestHandlerV2_SyntacticallyValidate_SoloAtx(t *testing.T) { t.Run("rejects when contains more than 1 ID", func(t *testing.T) { t.Parallel() atx := newInitialATXv2(t, golden) - atx.NiPosts[0].Posts = append(atx.NiPosts[0].Posts, wire.SubPostV2{}) + atx.NIPosts[0].Posts = append(atx.NIPosts[0].Posts, wire.SubPostV2{}) atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer() @@ -426,7 +426,7 @@ func TestHandlerV2_SyntacticallyValidate_SoloAtx(t *testing.T) { t.Run("rejects when PrevATXIndex != 0", func(t *testing.T) { t.Parallel() atx := newInitialATXv2(t, golden) - atx.NiPosts[0].Posts[0].PrevATXIndex = 1 + atx.NIPosts[0].Posts[0].PrevATXIndex = 1 atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer() @@ -484,8 +484,8 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.Equal(t, atx.Coinbase, atxFromDb.Coinbase) require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) require.EqualValues(t, 0+atxFromDb.TickCount, atxFromDb.TickHeight()) // positioning is golden - require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) - require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) + require.Equal(t, atx.NIPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NIPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) // processing ATX for the second time should skip checks err = atxHandler.processATX(context.Background(), peer, atx, time.Now()) @@ -511,8 +511,8 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) require.EqualValues(t, prevAtx.TickHeight(), atxFromDb.BaseTickHeight) require.EqualValues(t, prevAtx.TickHeight()+atxFromDb.TickCount, atxFromDb.TickHeight()) - require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) - require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) + require.Equal(t, atx.NIPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NIPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) }) t.Run("second ATX, previous checkpointed", func(t *testing.T) { t.Parallel() @@ -543,7 +543,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { prev := atxHandler.createAndProcessInitial(sig) atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) - atx.NiPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() * 10 + atx.NIPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() * 10 atx.VRFNonce = 7779989 atx.Sign(sig) atxHandler.expectAtxV2(atx) @@ -562,7 +562,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { prev := atxHandler.createAndProcessInitial(sig) atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) - atx.NiPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() * 10 + atx.NIPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() * 10 atx.VRFNonce = 7779989 atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) @@ -589,7 +589,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) atx.VRFNonce = uint64(123) - atx.NiPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() - 1 + atx.NIPosts[0].Posts[0].NumUnits = prev.TotalNumUnits() - 1 atx.Sign(sig) atxHandler.expectAtxV2(atx) @@ -675,7 +675,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // Process a merged ATX merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - totalNumUnits := merged.NiPosts[0].Posts[0].NumUnits + totalNumUnits := merged.NIPosts[0].Posts[0].NumUnits for i, atx := range otherATXs { post := wire.SubPostV2{ MarriageIndex: uint32(i + 1), @@ -683,7 +683,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { PrevATXIndex: uint32(i + 1), } totalNumUnits += post.NumUnits - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID := mATX.ID() merged.MarriageATX = &mATXID @@ -720,7 +720,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { PositioningATX: mATX.ID(), Coinbase: types.GenerateAddress([]byte("aaaa")), VRFNonce: uint64(999), - NiPosts: make([]wire.NiPostsV2, 4), + NIPosts: make([]wire.NIPostV2, 4), } atxsPerPoet := [][]*wire.ActivationTxV2{ append([]*wire.ActivationTxV2{mATX}, otherATXs[0]), @@ -740,7 +740,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { } unitsPerPoet[nipostId] += post.NumUnits totalNumUnits += post.NumUnits - merged.NiPosts[nipostId].Posts = append(merged.NiPosts[nipostId].Posts, post) + merged.NIPosts[nipostId].Posts = append(merged.NIPosts[nipostId].Posts, post) idx++ } } @@ -786,14 +786,14 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // Process a merged ATX merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - merged.NiPosts[0].Posts = []wire.SubPostV2{} // remove signer's PoST + merged.NIPosts[0].Posts = []wire.SubPostV2{} // remove signer's PoST for i, atx := range otherATXs { post := wire.SubPostV2{ MarriageIndex: uint32(i + 1), NumUnits: atx.TotalNumUnits(), PrevATXIndex: uint32(i), } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID := mATX.ID() merged.MarriageATX = &mATXID @@ -828,7 +828,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { PrevATXIndex: 1, NumUnits: otherATXs[0].TotalNumUnits(), } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID := mATX.ID() merged.MarriageATX = &mATXID @@ -858,7 +858,7 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { PrevATXIndex: 0, // use wrong previous ATX NumUnits: otherATXs[0].TotalNumUnits(), } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) mATXID := mATX.ID() merged.MarriageATX = &mATXID @@ -892,13 +892,13 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // Process a merged ATX merged := newSoloATXv2(t, prev.Epoch+1, prev.ID, golden) - merged.NiPosts[0].Posts = []wire.SubPostV2{} + merged.NIPosts[0].Posts = []wire.SubPostV2{} for marriageIdx := range equivocationSet { post := wire.SubPostV2{ MarriageIndex: uint32(marriageIdx), NumUnits: 7, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID := mATX.ID() @@ -919,13 +919,13 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { require.NoError(t, atxs.AddCheckpointed(atxHandler.cdb, &prev)) merged = newSoloATXv2(t, prev.Epoch+1, prev.ID, golden) - merged.NiPosts[0].Posts = []wire.SubPostV2{} + merged.NIPosts[0].Posts = []wire.SubPostV2{} for marriageIdx := range equivocationSet { post := wire.SubPostV2{ MarriageIndex: uint32(marriageIdx), NumUnits: 7, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } merged.MarriageATX = &mATXID merged.Sign(sig) @@ -947,14 +947,14 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // Process a merged ATX for 2 IDs merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - merged.NiPosts[0].Posts = []wire.SubPostV2{} + merged.NIPosts[0].Posts = []wire.SubPostV2{} for i := range equivocationSet[:2] { post := wire.SubPostV2{ MarriageIndex: uint32(i), PrevATXIndex: uint32(i), NumUnits: 4, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID := mATX.ID() @@ -969,14 +969,14 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // Process a second merged ATX for the same equivocation set, but different IDs merged = newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - merged.NiPosts[0].Posts = []wire.SubPostV2{} + merged.NIPosts[0].Posts = []wire.SubPostV2{} for i := range equivocationSet[:2] { post := wire.SubPostV2{ MarriageIndex: uint32(i + 2), PrevATXIndex: uint32(i), NumUnits: 4, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } mATXID = mATX.ID() @@ -1006,14 +1006,14 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { // create and process another merged ATX merged := newSoloATXv2(t, checkpointedATX.Epoch, mATX.ID(), golden) - merged.NiPosts[0].Posts = []wire.SubPostV2{} + merged.NIPosts[0].Posts = []wire.SubPostV2{} for i := range equivocationSet[2:] { post := wire.SubPostV2{ MarriageIndex: uint32(i + 2), PrevATXIndex: uint32(i), NumUnits: 4, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) } merged.MarriageATX = &mATXID @@ -1049,7 +1049,7 @@ func TestCollectDeps_AtxV2(t *testing.T) { PositioningATX: positioning, Initial: &wire.InitialAtxPartsV2{CommitmentATX: commitment}, MarriageATX: &marriage, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ {Challenge: poetA}, {Challenge: poetB}, }, @@ -1071,7 +1071,7 @@ func TestCollectDeps_AtxV2(t *testing.T) { PositioningATX: atxA, Initial: &wire.InitialAtxPartsV2{CommitmentATX: atxA}, MarriageATX: &atxA, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ {Challenge: poetA}, {Challenge: poetA}, }, @@ -1086,7 +1086,7 @@ func TestCollectDeps_AtxV2(t *testing.T) { PreviousATXs: []types.ATXID{prev0, prev1}, PositioningATX: positioning, MarriageATX: &marriage, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ {Challenge: poetA}, {Challenge: poetB}, }, @@ -1513,8 +1513,8 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { gomock.Any(), sig.NodeID(), golden, - wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), - atx.NiPosts[0].Challenge.Bytes(), + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), atx.TotalNumUnits(), gomock.Any(), ). @@ -1522,26 +1522,193 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) require.ErrorContains(t, err, "post failure") }) - t.Run("invalid PoST index - generates a malfeasance proof", func(t *testing.T) { + t.Run("invalid PoST index initial ATX - generates a malfeasance proof", func(t *testing.T) { atxHandler := newV2TestHandler(t, golden) atx := newInitialATXv2(t, golden) atx.Sign(sig) atxHandler.mValidator.EXPECT().PoetMembership(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - atxHandler.mValidator.EXPECT(). - PostV2( + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Any(), + ).Return(verifying.ErrInvalidIndex{Index: 7}) + + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + 7, + ).Return(errors.New("invalid post index")) + + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.Cond(func(data wire.Proof) bool { + _, ok := data.(*wire.ProofInvalidPost) + return ok + }), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPost) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nId) + return nil + }) + _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) + vErr := &verifying.ErrInvalidIndex{} + require.ErrorAs(t, err, vErr) + require.Equal(t, 7, vErr.Index) + }) + t.Run("invalid PoST index solo ATX - generates a malfeasance proof", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + initialAtx := atxHandler.createAndProcessInitial(sig) + + atx := newSoloATXv2(t, initialAtx.PublishEpoch+1, initialAtx.ID(), initialAtx.ID()) + atx.Sign(sig) + + atxHandler.mValidator.EXPECT().PoetMembership(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + initialAtx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Any(), + ).Return(verifying.ErrInvalidIndex{Index: 7}) + + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + atx.SmesherID, + initialAtx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + 7, + ).Return(errors.New("invalid post index")) + + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.Cond(func(data wire.Proof) bool { + _, ok := data.(*wire.ProofInvalidPost) + return ok + }), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPost) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nId) + return nil + }) + _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) + vErr := &verifying.ErrInvalidIndex{} + require.ErrorAs(t, err, vErr) + require.Equal(t, 7, vErr.Index) + }) + t.Run("invalid PoST index merged ATX - generates a malfeasance proof", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, []*signing.EdSigner{marrySig, sig, pubSig}, golden) + previousATXs := []types.ATXID{mATX.ID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + } + + // Process a merged ATX + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + for i, atx := range otherATXs { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 1), + NumUnits: atx.TotalNumUnits(), + PrevATXIndex: uint32(i + 1), + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + } + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + equivocationSet := []types.NodeID{marrySig.NodeID(), sig.NodeID(), pubSig.NodeID()} + atxHandler.mValidator.EXPECT().PoetMembership(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + for _, post := range merged.NIPosts[0].Posts { + call := atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + equivocationSet[post.MarriageIndex], gomock.Any(), - sig.NodeID(), - golden, - wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), - atx.NiPosts[0].Challenge.Bytes(), - atx.TotalNumUnits(), + wire.PostFromWireV1(&post.Post), + merged.NIPosts[0].Challenge.Bytes(), + post.NumUnits, gomock.Any(), - ). - Return(verifying.ErrInvalidIndex{Index: 7}) - atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), sig.NodeID(), gomock.Any()) - _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) + ) + if equivocationSet[post.MarriageIndex] == sig.NodeID() { + call.Return(verifying.ErrInvalidIndex{Index: 7}) + } else { + call.AnyTimes() + } + } + + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + gomock.Any(), + gomock.Any(), + merged.NIPosts[0].Challenge.Bytes(), + gomock.Any(), + 7, + ).Return(errors.New("invalid post index")) + + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.Cond(func(data wire.Proof) bool { + _, ok := data.(*wire.ProofInvalidPost) + return ok + }), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPost) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nId) + return nil + }) + _, err = atxHandler.syntacticallyValidateDeps(context.Background(), merged) vErr := &verifying.ErrInvalidIndex{} require.ErrorAs(t, err, vErr) require.Equal(t, 7, vErr.Index) @@ -1553,7 +1720,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atx.Sign(sig) atxHandler.mValidator.EXPECT(). - PoetMembership(gomock.Any(), gomock.Any(), atx.NiPosts[0].Challenge, gomock.Any()). + PoetMembership(gomock.Any(), gomock.Any(), atx.NIPosts[0].Challenge, gomock.Any()). Return(0, errors.New("poet failure")) _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) require.ErrorContains(t, err, "poet failure") @@ -1655,6 +1822,13 @@ func Test_Marriages(t *testing.T) { } atx2.Sign(sig) atxHandler.expectAtxV2(atx2) + + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + atxHandler.mMalPublish.EXPECT().Publish( gomock.Any(), sig.NodeID(), @@ -1662,9 +1836,9 @@ func Test_Marriages(t *testing.T) { _, ok := data.(*wire.ProofDoubleMarry) return ok }), - ).DoAndReturn(func(_ context.Context, _ types.NodeID, proof wire.Proof) error { + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { malProof := proof.(*wire.ProofDoubleMarry) - nId, err := malProof.Valid(atxHandler.edVerifier) + nId, err := malProof.Valid(ctx, verifier) require.NoError(t, err) require.Equal(t, sig.NodeID(), nId) return nil @@ -1828,7 +2002,7 @@ func TestContextualValidation_DoublePost(t *testing.T) { NumUnits: othersAtx.TotalNumUnits(), PrevATXIndex: 1, } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) mATXID := mATX.ID() merged.MarriageATX = &mATXID @@ -1906,7 +2080,7 @@ func TestContextual_PreviousATX(t *testing.T) { PrevATXIndex: 1, NumUnits: soloAtx.TotalNumUnits(), } - merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) // Pass a wrong previous ATX for signer 1. It's already been used for soloATX // (which should be used for the previous ATX for signer 1). merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) @@ -1950,8 +2124,11 @@ func newInitialATXv2(tb testing.TB, golden types.ATXID) *wire.ActivationTxV2 { atx := &wire.ActivationTxV2{ PositioningATX: golden, Initial: &wire.InitialAtxPartsV2{CommitmentATX: golden}, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { + Membership: wire.MerkleProofV2{ + Nodes: make([]types.Hash32, 32), + }, Challenge: types.RandomHash(), Posts: []wire.SubPostV2{ { @@ -1974,7 +2151,7 @@ func newSoloATXv2(tb testing.TB, publish types.EpochID, prev, pos types.ATXID) * PublishEpoch: publish, PreviousATXs: []types.ATXID{prev}, PositioningATX: pos, - NiPosts: []wire.NiPostsV2{ + NIPosts: []wire.NIPostV2{ { Challenge: types.RandomHash(), Posts: []wire.SubPostV2{ diff --git a/activation/malfeasance.go b/activation/malfeasance.go index 81d855ce58..042da18e03 100644 --- a/activation/malfeasance.go +++ b/activation/malfeasance.go @@ -152,8 +152,8 @@ func (mh *InvalidPostIndexHandler) Validate(ctx context.Context, data wire.Proof } post := (*shared.Proof)(atx.NIPost.Post) meta := &shared.ProofMetadata{ - NodeId: atx.SmesherID[:], - CommitmentAtxId: commitmentAtx[:], + NodeId: atx.SmesherID.Bytes(), + CommitmentAtxId: commitmentAtx.Bytes(), NumUnits: atx.NumUnits, Challenge: atx.NIPost.PostMetadata.Challenge, LabelsPerUnit: atx.NIPost.PostMetadata.LabelsPerUnit, diff --git a/activation/validation.go b/activation/validation.go index d6d070d895..cb9a9a8885 100644 --- a/activation/validation.go +++ b/activation/validation.go @@ -41,6 +41,7 @@ func (e *ErrAtxNotFound) Is(target error) bool { } type validatorOptions struct { + postIdx *int postSubsetSeed []byte prioritized bool } @@ -53,6 +54,14 @@ func PostSubset(seed []byte) validatorOption { } } +// PostIndex configures the validator to validate only the POST index at the given `idx`. +func PostIndex(idx int) validatorOption { + return func(o *validatorOptions) { + o.postIdx = new(int) + *o.postIdx = idx + } +} + func PrioritizeCall() validatorOption { return func(o *validatorOptions) { o.prioritized = true @@ -204,6 +213,9 @@ func (v *Validator) Post( } verifyOpts := []verifying.OptionFunc{verifying.WithLabelScryptParams(v.scrypt)} + if options.postIdx != nil { + verifyOpts = append(verifyOpts, verifying.SelectedIndex(*options.postIdx)) + } if options.postSubsetSeed != nil { verifyOpts = append(verifyOpts, verifying.Subset(v.cfg.K3, options.postSubsetSeed)) } @@ -486,7 +498,7 @@ func (v *Validator) getAtxDeps(ctx context.Context, id types.ATXID) (*atxDeps, e previous: atx.PreviousATXs, commitment: commitment, } - for _, nipost := range atx.NiPosts { + for _, nipost := range atx.NIPosts { for _, post := range nipost.Posts { deps.niposts = append(deps.niposts, types.NIPost{ Post: wire.PostFromWireV1(&post.Post), diff --git a/activation/validation_test.go b/activation/validation_test.go index 59278548dd..4ba72911f3 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -588,7 +588,7 @@ func TestVerifyChainDeps(t *testing.T) { require.NoError(t, atxs.Add(db, atx, watx.Blob())) v := NewMockPostVerifier(gomock.NewController(t)) - expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) + expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NIPosts[0].Posts[0].Post)) v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any()) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) err = validator.VerifyChain(ctx, watx.ID(), goldenATXID) @@ -609,7 +609,7 @@ func TestVerifyChainDeps(t *testing.T) { require.NoError(t, atxs.Add(db, atx, watx.Blob())) v := NewMockPostVerifier(gomock.NewController(t)) - expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) + expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NIPosts[0].Posts[0].Post)) v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx.NIPost.Post), gomock.Any(), gomock.Any()) v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any()) validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) @@ -629,7 +629,7 @@ func TestVerifyChainDeps(t *testing.T) { require.NoError(t, atxs.Add(db, toAtx(t, initialAtx2), initialAtx2.Blob())) watx := newSoloATXv2(t, initialAtx.PublishEpoch+1, initialAtx.ID(), initialAtx.ID()) - watx.NiPosts[0].Posts = append(watx.NiPosts[0].Posts, wire.SubPostV2{ + watx.NIPosts[0].Posts = append(watx.NIPosts[0].Posts, wire.SubPostV2{ MarriageIndex: 1, PrevATXIndex: 1, Post: wire.PostV1{ @@ -649,8 +649,8 @@ func TestVerifyChainDeps(t *testing.T) { require.NoError(t, atxs.Add(db, atx, watx.Blob())) v := NewMockPostVerifier(gomock.NewController(t)) - expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) - expectedPost2 := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[1].Post)) + expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NIPosts[0].Posts[0].Post)) + expectedPost2 := (*shared.Proof)(wire.PostFromWireV1(&watx.NIPosts[0].Posts[1].Post)) v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx.NIPost.Post), gomock.Any(), gomock.Any()) v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx2.NIPost.Post), gomock.Any(), gomock.Any()) v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any()) diff --git a/activation/wire/interface.go b/activation/wire/interface.go new file mode 100644 index 0000000000..ba5006e3cd --- /dev/null +++ b/activation/wire/interface.go @@ -0,0 +1,27 @@ +package wire + +import ( + "context" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" +) + +//go:generate mockgen -typed -package=wire -destination=./mocks.go -source=./interface.go + +type MalfeasanceValidator interface { + // PostIndex validates the given post against for the provided index. + // It returns an error if the post is invalid. + PostIndex( + ctx context.Context, + smesherID types.NodeID, + commitment types.ATXID, + post *types.Post, + challenge []byte, + numUnits uint32, + idx int, + ) error + + // Signature validates the given signature against the given message and public key. + Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool +} diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index c857dd075b..7e2aa97b98 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -1,22 +1,23 @@ package wire import ( + "context" + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/signing" ) //go:generate scalegen // MerkleTreeIndex is the index of the leaf containing the given field in the merkle tree. -type MerkleTreeIndex uint16 +type MerkleTreeIndex uint64 const ( PublishEpochIndex MerkleTreeIndex = iota PositioningATXIndex CoinbaseIndex - InitialPostIndex + InitialPostRootIndex PreviousATXsRootIndex NIPostsRootIndex VRFNonceIndex @@ -24,6 +25,38 @@ const ( MarriageATXIndex ) +type InitialPostTreeIndex uint64 + +const ( + CommitmentATXIndex InitialPostTreeIndex = iota + InitialPostIndex +) + +type NIPostTreeIndex uint64 + +const ( + MembershipIndex NIPostTreeIndex = iota + ChallengeIndex + PostsRootIndex +) + +type MarriageCertificateIndex uint64 + +const ( + ReferenceATXIndex MarriageCertificateIndex = iota + SignatureIndex +) + +type SubPostTreeIndex uint64 + +const ( + MarriageIndex SubPostTreeIndex = iota + PrevATXIndex + MembershipLeafIndex + PostIndex + NumUnitsIndex +) + // ProofType is an identifier for the type of proof that is encoded in the ATXProof. type ProofType byte @@ -33,11 +66,10 @@ const ( LegacyInvalidPost ProofType = 0x01 LegacyInvalidPrevATX ProofType = 0x02 - DoublePublish ProofType = 0x10 - DoubleMarry ProofType = 0x11 - DoubleMerge ProofType = 0x12 - InvalidPost ProofType = 0x13 - InvalidPrevious ProofType = 0x14 + DoubleMarry ProofType = 0x10 + DoubleMerge ProofType = 0x11 + InvalidPost ProofType = 0x12 + InvalidPrevious ProofType = 0x13 ) // ProofVersion is an identifier for the version of the proof that is encoded in the ATXProof. @@ -57,5 +89,5 @@ type ATXProof struct { type Proof interface { scale.Encodable - Valid(edVerifier *signing.EdVerifier) (types.NodeID, error) + Valid(ctx context.Context, malHandler MalfeasanceValidator) (types.NodeID, error) } diff --git a/activation/wire/malfeasance_double_marry.go b/activation/wire/malfeasance_double_marry.go index ea946b29ea..ac7a760c1f 100644 --- a/activation/wire/malfeasance_double_marry.go +++ b/activation/wire/malfeasance_double_marry.go @@ -1,16 +1,13 @@ package wire import ( + "context" "errors" "fmt" - "slices" - - "github.com/spacemeshos/merkle-tree" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/atxs" ) //go:generate scalegen @@ -28,7 +25,23 @@ type ProofDoubleMarry struct { // NodeID is the node ID that married twice. NodeID types.NodeID - Proofs [2]MarryProof + // ATX1 is the ID of the ATX being proven to have the marriage certificate of interest. + ATX1 types.ATXID + // SmesherID1 is the ID of the smesher that published ATX1. + SmesherID1 types.NodeID + // Signature1 is the signature of the ATXID by the smesher. + Signature1 types.EdSignature + // Proof1 is the proof that the marriage certificate is contained in the ATX1. + Proof1 MarryProof + + // ATX2 is the ID of the ATX being proven to have the marriage certificate of interest. + ATX2 types.ATXID + // SmesherID2 is the ID of the smesher that published ATX2. + SmesherID2 types.NodeID + // Signature2 is the signature of the ATXID by the smesher. + Signature2 types.EdSignature + // Proof2 is the proof that the marriage certificate is contained in the ATX2. + Proof2 MarryProof } var _ Proof = &ProofDoubleMarry{} @@ -42,180 +55,41 @@ func NewDoubleMarryProof(db sql.Executor, atx1, atx2 *ActivationTxV2, nodeID typ if err != nil { return nil, fmt.Errorf("proof for atx1: %w", err) } - proof2, err := createMarryProof(db, atx2, nodeID) if err != nil { return nil, fmt.Errorf("proof for atx2: %w", err) } - proof := &ProofDoubleMarry{ + return &ProofDoubleMarry{ NodeID: nodeID, - Proofs: [2]MarryProof{proof1, proof2}, - } - return proof, nil -} - -func createMarryProof(db sql.Executor, atx *ActivationTxV2, nodeID types.NodeID) (MarryProof, error) { - marriageProof, err := marriageProof(atx) - if err != nil { - return MarryProof{}, fmt.Errorf("failed to create proof for ATX 1: %w", err) - } - - marriageIndex := slices.IndexFunc(atx.Marriages, func(cert MarriageCertificate) bool { - if cert.ReferenceAtx == types.EmptyATXID && atx.SmesherID == nodeID { - // special case of the self signed certificate of the ATX publisher - return true - } - refATX, err := atxs.Get(db, cert.ReferenceAtx) - if err != nil { - return false - } - return refATX.SmesherID == nodeID - }) - if marriageIndex == -1 { - return MarryProof{}, fmt.Errorf("does not contain a marriage certificate signed by %s", nodeID.ShortString()) - } - certProof, err := certificateProof(atx.Marriages, uint64(marriageIndex)) - if err != nil { - return MarryProof{}, fmt.Errorf("failed to create certificate proof for ATX 1: %w", err) - } - - proof := MarryProof{ - ATXID: atx.ID(), - - MarriageRoot: types.Hash32(atx.Marriages.Root()), - MarriageProof: marriageProof, - CertificateReference: atx.Marriages[marriageIndex].ReferenceAtx, - CertificateSignature: atx.Marriages[marriageIndex].Signature, - CertificateIndex: uint64(marriageIndex), - CertificateProof: certProof, + ATX1: atx1.ID(), + SmesherID1: atx1.SmesherID, + Signature1: atx1.Signature, + Proof1: proof1, - SmesherID: atx.SmesherID, - Signature: atx.Signature, - } - return proof, nil + ATX2: atx2.ID(), + SmesherID2: atx2.SmesherID, + Signature2: atx2.Signature, + Proof2: proof2, + }, nil } -func marriageProof(atx *ActivationTxV2) ([]types.Hash32, error) { - tree, err := merkle.NewTreeBuilder(). - WithLeavesToProve(map[uint64]bool{uint64(MarriagesRootIndex): true}). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - return nil, err - } - atx.merkleTree(tree) - proof := tree.Proof() - - proofHashes := make([]types.Hash32, len(proof)) - for i, p := range proof { - proofHashes[i] = types.Hash32(p) +func (p ProofDoubleMarry) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { + if p.ATX1 == p.ATX2 { + return types.EmptyNodeID, errors.New("proofs have the same ATX ID") } - return proofHashes, nil -} - -func certificateProof(certs MarriageCertificates, index uint64) ([]types.Hash32, error) { - tree, err := merkle.NewTreeBuilder(). - WithLeavesToProve(map[uint64]bool{index: true}). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - return nil, err + if !malValidator.Signature(signing.ATX, p.SmesherID1, p.ATX1.Bytes(), p.Signature1) { + return types.EmptyNodeID, errors.New("invalid signature for ATX1") } - certs.merkleTree(tree) - proof := tree.Proof() - - proofHashes := make([]types.Hash32, len(proof)) - for i, p := range proof { - proofHashes[i] = types.Hash32(p) + if !malValidator.Signature(signing.ATX, p.SmesherID2, p.ATX2.Bytes(), p.Signature2) { + return types.EmptyNodeID, errors.New("invalid signature for ATX2") } - return proofHashes, nil -} - -func (p ProofDoubleMarry) Valid(edVerifier *signing.EdVerifier) (types.NodeID, error) { - if p.Proofs[0].ATXID == p.Proofs[1].ATXID { - return types.EmptyNodeID, errors.New("proofs have the same ATX ID") - } - - if err := p.Proofs[0].Valid(edVerifier, p.NodeID); err != nil { + if err := p.Proof1.Valid(malValidator, p.ATX1, p.SmesherID1, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 1 is invalid: %w", err) } - if err := p.Proofs[1].Valid(edVerifier, p.NodeID); err != nil { + if err := p.Proof2.Valid(malValidator, p.ATX2, p.SmesherID2, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) } return p.NodeID, nil } - -type MarryProof struct { - // ATXID is the ID of the ATX being proven. - ATXID types.ATXID - - // MarriageRoot and its proof that it is contained in the ATX. - MarriageRoot types.Hash32 - MarriageProof []types.Hash32 `scale:"max=32"` - - // The signature of the certificate and the proof that the certificate is contained in the MarriageRoot at - // the given index. - CertificateReference types.ATXID - CertificateSignature types.EdSignature - CertificateIndex uint64 - CertificateProof []types.Hash32 `scale:"max=32"` - - // SmesherID is the ID of the smesher that published the ATX. - SmesherID types.NodeID - // Signature is the signature of the ATXID by the smesher. - Signature types.EdSignature -} - -func (p MarryProof) Valid(edVerifier *signing.EdVerifier, nodeID types.NodeID) error { - if !edVerifier.Verify(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) { - return errors.New("invalid ATX signature") - } - - if !edVerifier.Verify(signing.MARRIAGE, nodeID, p.SmesherID.Bytes(), p.CertificateSignature) { - return errors.New("invalid certificate signature") - } - - proof := make([][]byte, len(p.MarriageProof)) - for i, h := range p.MarriageProof { - proof[i] = h.Bytes() - } - ok, err := merkle.ValidatePartialTree( - []uint64{uint64(MarriagesRootIndex)}, - [][]byte{p.MarriageRoot.Bytes()}, - proof, - p.ATXID.Bytes(), - atxTreeHash, - ) - if err != nil { - return fmt.Errorf("validate marriage proof: %w", err) - } - if !ok { - return errors.New("invalid marriage proof") - } - - mc := MarriageCertificate{ - ReferenceAtx: p.CertificateReference, - Signature: p.CertificateSignature, - } - - certProof := make([][]byte, len(p.CertificateProof)) - for i, h := range p.CertificateProof { - certProof[i] = h.Bytes() - } - ok, err = merkle.ValidatePartialTree( - []uint64{p.CertificateIndex}, - [][]byte{mc.Root()}, - certProof, - p.MarriageRoot.Bytes(), - atxTreeHash, - ) - if err != nil { - return fmt.Errorf("validate certificate proof: %w", err) - } - if !ok { - return errors.New("invalid certificate proof") - } - return nil -} diff --git a/activation/wire/malfeasance_double_marry_scale.go b/activation/wire/malfeasance_double_marry_scale.go index 03f70c95fa..d7f3855020 100644 --- a/activation/wire/malfeasance_double_marry_scale.go +++ b/activation/wire/malfeasance_double_marry_scale.go @@ -5,7 +5,6 @@ package wire import ( "github.com/spacemeshos/go-scale" - "github.com/spacemeshos/go-spacemesh/common/types" ) func (t *ProofDoubleMarry) EncodeScale(enc *scale.Encoder) (total int, err error) { @@ -17,92 +16,56 @@ func (t *ProofDoubleMarry) EncodeScale(enc *scale.Encoder) (total int, err error total += n } { - n, err := scale.EncodeStructArray(enc, t.Proofs[:]) - if err != nil { - return total, err - } - total += n - } - return total, nil -} - -func (t *ProofDoubleMarry) DecodeScale(dec *scale.Decoder) (total int, err error) { - { - n, err := scale.DecodeByteArray(dec, t.NodeID[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.DecodeStructArray(dec, t.Proofs[:]) - if err != nil { - return total, err - } - total += n - } - return total, nil -} - -func (t *MarryProof) EncodeScale(enc *scale.Encoder) (total int, err error) { - { - n, err := scale.EncodeByteArray(enc, t.ATXID[:]) + n, err := scale.EncodeByteArray(enc, t.ATX1[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.MarriageRoot[:]) + n, err := scale.EncodeByteArray(enc, t.SmesherID1[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageProof, 32) + n, err := scale.EncodeByteArray(enc, t.Signature1[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.CertificateReference[:]) + n, err := t.Proof1.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.CertificateSignature[:]) + n, err := scale.EncodeByteArray(enc, t.ATX2[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeCompact64(enc, uint64(t.CertificateIndex)) + n, err := scale.EncodeByteArray(enc, t.SmesherID2[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeStructSliceWithLimit(enc, t.CertificateProof, 32) + n, err := scale.EncodeByteArray(enc, t.Signature2[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.SmesherID[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeByteArray(enc, t.Signature[:]) + n, err := t.Proof2.EncodeScale(enc) if err != nil { return total, err } @@ -111,68 +74,65 @@ func (t *MarryProof) EncodeScale(enc *scale.Encoder) (total int, err error) { return total, nil } -func (t *MarryProof) DecodeScale(dec *scale.Decoder) (total int, err error) { +func (t *ProofDoubleMarry) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.ATXID[:]) + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.MarriageRoot[:]) + n, err := scale.DecodeByteArray(dec, t.ATX1[:]) if err != nil { return total, err } total += n } { - field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + n, err := scale.DecodeByteArray(dec, t.SmesherID1[:]) if err != nil { return total, err } total += n - t.MarriageProof = field } { - n, err := scale.DecodeByteArray(dec, t.CertificateReference[:]) + n, err := scale.DecodeByteArray(dec, t.Signature1[:]) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.CertificateSignature[:]) + n, err := t.Proof1.DecodeScale(dec) if err != nil { return total, err } total += n } { - field, n, err := scale.DecodeCompact64(dec) + n, err := scale.DecodeByteArray(dec, t.ATX2[:]) if err != nil { return total, err } total += n - t.CertificateIndex = uint64(field) } { - field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + n, err := scale.DecodeByteArray(dec, t.SmesherID2[:]) if err != nil { return total, err } total += n - t.CertificateProof = field } { - n, err := scale.DecodeByteArray(dec, t.SmesherID[:]) + n, err := scale.DecodeByteArray(dec, t.Signature2[:]) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.Signature[:]) + n, err := t.Proof2.DecodeScale(dec) if err != nil { return total, err } diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index 33bc5ef400..f9f686503a 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -1,11 +1,13 @@ package wire import ( + "context" "fmt" "slices" "testing" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" @@ -20,6 +22,8 @@ func Test_DoubleMarryProof(t *testing.T) { otherSig, err := signing.NewEdSigner() require.NoError(t, err) + edVerifier := signing.NewEdVerifier() + t.Run("valid", func(t *testing.T) { db := statesql.InMemoryTest(t) otherAtx := &types.ActivationTx{} @@ -43,8 +47,14 @@ func Test_DoubleMarryProof(t *testing.T) { require.NoError(t, err) require.NotNil(t, proof) - verifier := signing.NewEdVerifier() - id, err := proof.Valid(verifier) + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + id, err := proof.Valid(context.Background(), verifier) require.NoError(t, err) require.Equal(t, otherSig.NodeID(), id) }) @@ -86,18 +96,14 @@ func Test_DoubleMarryProof(t *testing.T) { // manually construct an invalid proof proof = &ProofDoubleMarry{ - Proofs: [2]MarryProof{ - { - ATXID: atx1.ID(), - }, - { - ATXID: atx1.ID(), - }, - }, + ATX1: atx1.ID(), + ATX2: atx1.ID(), } - verifier := signing.NewEdVerifier() - id, err := proof.Valid(verifier) + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + + id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "same ATX ID") require.Equal(t, types.EmptyNodeID, id) }) @@ -129,22 +135,35 @@ func Test_DoubleMarryProof(t *testing.T) { proof := &ProofDoubleMarry{ NodeID: otherSig.NodeID(), - Proofs: [2]MarryProof{ - proof1, proof2, - }, + + ATX1: atx1.ID(), + SmesherID1: atx1.SmesherID, + Signature1: atx1.Signature, + Proof1: proof1, + + ATX2: atx2.ID(), + SmesherID2: atx2.SmesherID, + Signature2: atx2.Signature, + Proof2: proof2, } - verifier := signing.NewEdVerifier() - proof.Proofs[0].MarriageProof = slices.Clone(proof1.MarriageProof) - proof.Proofs[0].MarriageProof[0] = types.RandomHash() - id, err := proof.Valid(verifier) + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof.Proof1.MarriageCertificatesProof = slices.Clone(proof1.MarriageCertificatesProof) + proof.Proof1.MarriageCertificatesProof[0] = types.RandomHash() + id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 1 is invalid: invalid marriage proof") require.Equal(t, types.EmptyNodeID, id) - proof.Proofs[0].MarriageProof[0] = proof1.MarriageProof[0] - proof.Proofs[1].MarriageProof = slices.Clone(proof2.MarriageProof) - proof.Proofs[1].MarriageProof[0] = types.RandomHash() - id, err = proof.Valid(verifier) + proof.Proof1.MarriageCertificatesProof[0] = proof1.MarriageCertificatesProof[0] + proof.Proof2.MarriageCertificatesProof = slices.Clone(proof2.MarriageCertificatesProof) + proof.Proof2.MarriageCertificatesProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 2 is invalid: invalid marriage proof") require.Equal(t, types.EmptyNodeID, id) }) @@ -176,22 +195,35 @@ func Test_DoubleMarryProof(t *testing.T) { proof := &ProofDoubleMarry{ NodeID: otherSig.NodeID(), - Proofs: [2]MarryProof{ - proof1, proof2, - }, + + ATX1: atx1.ID(), + SmesherID1: atx1.SmesherID, + Signature1: atx1.Signature, + Proof1: proof1, + + ATX2: atx2.ID(), + SmesherID2: atx2.SmesherID, + Signature2: atx2.Signature, + Proof2: proof2, } - verifier := signing.NewEdVerifier() - proof.Proofs[0].CertificateProof = slices.Clone(proof1.CertificateProof) - proof.Proofs[0].CertificateProof[0] = types.RandomHash() - id, err := proof.Valid(verifier) + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof.Proof1.CertificateProof = slices.Clone(proof1.CertificateProof) + proof.Proof1.CertificateProof[0] = types.RandomHash() + id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate proof") require.Equal(t, types.EmptyNodeID, id) - proof.Proofs[0].CertificateProof[0] = proof1.CertificateProof[0] - proof.Proofs[1].CertificateProof = slices.Clone(proof2.CertificateProof) - proof.Proofs[1].CertificateProof[0] = types.RandomHash() - id, err = proof.Valid(verifier) + proof.Proof1.CertificateProof[0] = proof1.CertificateProof[0] + proof.Proof2.CertificateProof = slices.Clone(proof2.CertificateProof) + proof.Proof2.CertificateProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate proof") require.Equal(t, types.EmptyNodeID, id) }) @@ -218,17 +250,22 @@ func Test_DoubleMarryProof(t *testing.T) { proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) require.NoError(t, err) - verifier := signing.NewEdVerifier() + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() - proof.Proofs[0].Signature = types.RandomEdSignature() - id, err := proof.Valid(verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + proof.Signature1 = types.RandomEdSignature() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid signature for ATX1") require.Equal(t, types.EmptyNodeID, id) - proof.Proofs[0].Signature = atx1.Signature - proof.Proofs[1].Signature = types.RandomEdSignature() - id, err = proof.Valid(verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + proof.Signature1 = atx1.Signature + proof.Signature2 = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid signature for ATX2") require.Equal(t, types.EmptyNodeID, id) }) @@ -254,16 +291,21 @@ func Test_DoubleMarryProof(t *testing.T) { proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) require.NoError(t, err) - verifier := signing.NewEdVerifier() + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() - proof.Proofs[0].CertificateSignature = types.RandomEdSignature() - id, err := proof.Valid(verifier) + proof.Proof1.Certificate.Signature = types.RandomEdSignature() + id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate signature") require.Equal(t, types.EmptyNodeID, id) - proof.Proofs[0].CertificateSignature = atx1.Marriages[1].Signature - proof.Proofs[1].CertificateSignature = types.RandomEdSignature() - id, err = proof.Valid(verifier) + proof.Proof1.Certificate.Signature = atx1.Marriages[1].Signature + proof.Proof2.Certificate.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate signature") require.Equal(t, types.EmptyNodeID, id) }) diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go new file mode 100644 index 0000000000..076b678bcb --- /dev/null +++ b/activation/wire/malfeasance_invalid_post.go @@ -0,0 +1,324 @@ +package wire + +import ( + "context" + "errors" + "fmt" + "slices" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate scalegen + +// ProofInvalidPost is a proof that a merged ATX with an invalid Post was published by a smesher. +// +// We are proofing the following: +// 1. The ATX has a valid signature. +// 2. If NodeID is different from SmesherID, we prove that NodeID and SmesherID are married. +// 3. The commitment ATX of NodeID used for the invalid PoST based on their initial ATX. +// 4. The provided Post is invalid for the given NodeID. +type ProofInvalidPost struct { + // ATXID is the ID of the ATX containing the invalid PoST. + ATXID types.ATXID + // SmesherID is the ID of the smesher that published the ATX. + SmesherID types.NodeID + // Signature is the signature of the ATXID by the smesher. + Signature types.EdSignature + + // NodeID is the node ID that created the invalid PoST. + NodeID types.NodeID + + // MarriageProof is the proof that NodeID and SmesherID are married. It is nil if NodeID == SmesherID. + MarriageProof *MarriageProof + // CommitmentProof is the proof for the commitment ATX of the smesher. Generated from the initial ATX of NodeID. + CommitmentProof CommitmentProof + // InvalidPostProof is the proof for the invalid PoST of the ATX. It contains the PoST and the merkle proofs to + // verify the PoST. + InvalidPostProof InvalidPostProof +} + +var _ Proof = &ProofInvalidPost{} + +func NewInvalidPostProof( + db sql.Executor, + atx, initialATX *ActivationTxV2, + nodeID types.NodeID, + nipostIndex int, + invalidPostIndex uint32, +) (*ProofInvalidPost, error) { + if atx.SmesherID != nodeID && atx.MarriageATX == nil { + return nil, errors.New("ATX is not a merged ATX, but NodeID is different from SmesherID") + } + + if nipostIndex < 0 || nipostIndex >= len(atx.NIPosts) { + return nil, errors.New("invalid NIPoST index") + } + + postIndex := 0 + var marriageProof *MarriageProof + if atx.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof = &proof + postIndex = slices.IndexFunc(atx.NIPosts[nipostIndex].Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + + commitmentProof, err := createCommitmentProof(initialATX, nodeID) + if err != nil { + return nil, fmt.Errorf("commitment proof: %w", err) + } + invalidPostProof, err := createInvalidPostProof(atx, nipostIndex, postIndex, invalidPostIndex) + if err != nil { + return nil, fmt.Errorf("invalid post proof: %w", err) + } + + return &ProofInvalidPost{ + ATXID: atx.ID(), + SmesherID: atx.SmesherID, + Signature: atx.Signature, + + NodeID: nodeID, + + MarriageProof: marriageProof, + + CommitmentProof: commitmentProof, + InvalidPostProof: invalidPostProof, + }, nil +} + +func (p ProofInvalidPost) Valid(ctx context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { + if !malValidator.Signature(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) { + return types.EmptyNodeID, errors.New("invalid signature") + } + + if p.NodeID != p.SmesherID && p.MarriageProof == nil { + return types.EmptyNodeID, errors.New("missing marriage proof") + } + + var marriageIndex *uint32 + if p.MarriageProof != nil { + if err := p.MarriageProof.Valid(malValidator, p.ATXID, p.NodeID, p.SmesherID); err != nil { + return types.EmptyNodeID, fmt.Errorf("invalid marriage proof: %w", err) + } + marriageIndex = &p.MarriageProof.NodeIDMarryProof.CertificateIndex + } + + if err := p.CommitmentProof.Valid(malValidator, p.NodeID); err != nil { + return types.EmptyNodeID, fmt.Errorf("invalid commitment proof: %w", err) + } + + if err := p.InvalidPostProof.Valid( + ctx, + malValidator, + p.ATXID, + p.NodeID, + p.CommitmentProof.CommitmentATX, + marriageIndex, + ); err != nil { + return types.EmptyNodeID, fmt.Errorf("invalid invalid post proof: %w", err) + } + + return p.NodeID, nil +} + +// CommitmentProof is a proof for the commitment ATX of a smesher. It is generated from the initial ATX. +type CommitmentProof struct { + // InitialATXID is the ID of the initial ATX of the smesher. + InitialATXID types.ATXID + + // InitialPostRoot and its proof that it is contained in the InitialATX. + InitialPostRoot InitialPostRoot + InitialPostProof InitialPostRootProof `scale:"max=32"` + + // CommitmentATX and its proof that it is contained in the InitialPostRoot. + CommitmentATX types.ATXID + CommitmentATXProof CommitmentATXProof `scale:"max=32"` + + // Signature is the signature of the ATXID by the smesher. + Signature types.EdSignature +} + +func createCommitmentProof(initialAtx *ActivationTxV2, nodeID types.NodeID) (CommitmentProof, error) { + if initialAtx.SmesherID != nodeID { + return CommitmentProof{}, errors.New("node ID does not match smesher ID of initial ATX") + } + if initialAtx.Initial == nil { + return CommitmentProof{}, errors.New("initial ATX does not contain initial PoST") + } + + return CommitmentProof{ + InitialATXID: initialAtx.ID(), + + InitialPostRoot: initialAtx.Initial.Root(), + InitialPostProof: initialAtx.InitialPostRootProof(), + + CommitmentATX: initialAtx.Initial.CommitmentATX, + CommitmentATXProof: initialAtx.Initial.CommitmentATXProof(), + + Signature: initialAtx.Signature, + }, nil +} + +func (p CommitmentProof) Valid(malValidator MalfeasanceValidator, nodeID types.NodeID) error { + if !malValidator.Signature(signing.ATX, nodeID, p.InitialATXID.Bytes(), p.Signature) { + return errors.New("invalid signature") + } + + if types.Hash32(p.InitialPostRoot) == types.EmptyHash32 { + return errors.New("invalid empty initial PoST root") // initial PoST root is empty for non-initial ATXs + } + + if !p.InitialPostProof.Valid(p.InitialATXID, p.InitialPostRoot) { + return errors.New("invalid initial PoST proof") + } + if !p.CommitmentATXProof.Valid(p.InitialPostRoot, p.CommitmentATX) { + return errors.New("invalid commitment ATX proof") + } + + return nil +} + +// InvalidPostProof is a proof for an invalid PoST in an ATX. It contains the PoST and the merkle proofs to verify the +// PoST. +type InvalidPostProof struct { + // NIPostsRoot and its proof that it is contained in the ATX. + NIPostsRoot NIPostsRoot + NIPostsRootProof NIPostsRootProof `scale:"max=32"` + + // NIPostRoot and its proof that it is contained at the given index in the NIPostsRoot. + NIPostRoot NIPostRoot + NIPostRootProof NIPostRootProof `scale:"max=32"` + NIPostIndex uint16 + + // Challenge and its proof that it is contained in the NIPostRoot. + Challenge types.Hash32 + ChallengeProof ChallengeProof `scale:"max=32"` + + // SubPostsRoot and its proof that it is contained in the NIPostRoot. + SubPostsRoot SubPostsRoot + SubPostsRootProof SubPostsRootProof `scale:"max=32"` + + // SubPostRoot and its proof that is contained at the given index in the SubPostsRoot. + SubPostRoot SubPostRoot + SubPostRootProof SubPostRootProof `scale:"max=32"` + SubPostRootIndex uint16 + + // MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from MarryProof) is contained in the + // SubPostRoot. + MarriageIndexProof MarriageIndexProof `scale:"max=32"` + + // Post is the invalid PoST and its proof that it is contained in the SubPostRoot. + Post PostV1 + PostProof PostRootProof `scale:"max=32"` + + // NumUnits and its proof that it is contained in the SubPostRoot. + NumUnits uint32 + NumUnitsProof NumUnitsProof `scale:"max=32"` + + // InvalidPostIndex is the index of the leaf that was identified to be invalid. + InvalidPostIndex uint32 +} + +func createInvalidPostProof( + atx *ActivationTxV2, + nipostIndex, + postIndex int, + invalidPostIndex uint32, +) (InvalidPostProof, error) { + if nipostIndex < 0 || nipostIndex >= len(atx.NIPosts) { + return InvalidPostProof{}, errors.New("invalid NIPoST index") + } + if postIndex < 0 || postIndex >= len(atx.NIPosts[nipostIndex].Posts) { + return InvalidPostProof{}, errors.New("invalid PoST index") + } + + return InvalidPostProof{ + NIPostsRoot: atx.NIPosts.Root(atx.PreviousATXs), + NIPostsRootProof: atx.NIPostsRootProof(), + + NIPostRoot: atx.NIPosts[nipostIndex].Root(atx.PreviousATXs), + NIPostRootProof: atx.NIPosts.Proof(int(nipostIndex), atx.PreviousATXs), + NIPostIndex: uint16(nipostIndex), + + Challenge: atx.NIPosts[nipostIndex].Challenge, + ChallengeProof: atx.NIPosts[nipostIndex].ChallengeProof(atx.PreviousATXs), + + SubPostsRoot: atx.NIPosts[nipostIndex].Posts.Root(atx.PreviousATXs), + SubPostsRootProof: atx.NIPosts[nipostIndex].PostsRootProof(atx.PreviousATXs), + + SubPostRoot: atx.NIPosts[nipostIndex].Posts[postIndex].Root(atx.PreviousATXs), + SubPostRootProof: atx.NIPosts[nipostIndex].Posts.Proof(postIndex, atx.PreviousATXs), + SubPostRootIndex: uint16(postIndex), + + MarriageIndexProof: atx.NIPosts[nipostIndex].Posts[postIndex].MarriageIndexProof(atx.PreviousATXs), + + Post: atx.NIPosts[nipostIndex].Posts[postIndex].Post, + PostProof: atx.NIPosts[nipostIndex].Posts[postIndex].PostProof(atx.PreviousATXs), + + NumUnits: atx.NIPosts[nipostIndex].Posts[postIndex].NumUnits, + NumUnitsProof: atx.NIPosts[nipostIndex].Posts[postIndex].NumUnitsProof(atx.PreviousATXs), + + InvalidPostIndex: invalidPostIndex, + }, nil +} + +// Valid returns no error if the proof is valid. It verifies that the signature is valid, that the merkle proofs are +// and that the provided post is invalid. +func (p InvalidPostProof) Valid( + ctx context.Context, + malValidator MalfeasanceValidator, + atxID types.ATXID, + nodeID types.NodeID, + commitmentATX types.ATXID, + marriageIndex *uint32, +) error { + if !p.NIPostsRootProof.Valid(atxID, p.NIPostsRoot) { + return errors.New("invalid NIPosts root proof") + } + if !p.NIPostRootProof.Valid(p.NIPostsRoot, int(p.NIPostIndex), p.NIPostRoot) { + return errors.New("invalid NIPoST root proof") + } + if !p.ChallengeProof.Valid(p.NIPostRoot, p.Challenge) { + return errors.New("invalid challenge proof") + } + if !p.SubPostsRootProof.Valid(p.NIPostRoot, p.SubPostsRoot) { + return errors.New("invalid sub PoSTs root proof") + } + if !p.SubPostRootProof.Valid(p.SubPostsRoot, int(p.SubPostRootIndex), p.SubPostRoot) { + return errors.New("invalid sub PoST root proof") + } + if marriageIndex != nil { + if !p.MarriageIndexProof.Valid(p.SubPostRoot, *marriageIndex) { + return errors.New("invalid marriage index proof") + } + } + if !p.PostProof.Valid(p.SubPostRoot, p.Post.Root()) { + return errors.New("invalid PoST proof") + } + if !p.NumUnitsProof.Valid(p.SubPostRoot, p.NumUnits) { + return errors.New("invalid num units proof") + } + + if err := malValidator.PostIndex( + ctx, + nodeID, + commitmentATX, + PostFromWireV1(&p.Post), + p.Challenge.Bytes(), + p.NumUnits, + int(p.InvalidPostIndex), + ); err != nil { + return nil + } + return errors.New("PoST is valid") +} diff --git a/activation/wire/malfeasance_invalid_post_scale.go b/activation/wire/malfeasance_invalid_post_scale.go new file mode 100644 index 0000000000..64f2f630cc --- /dev/null +++ b/activation/wire/malfeasance_invalid_post_scale.go @@ -0,0 +1,482 @@ +// Code generated by github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT. + +// nolint +package wire + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *ProofInvalidPost) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeOption(enc, t.MarriageProof) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.CommitmentProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.InvalidPostProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPost) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeOption[MarriageProof](dec) + if err != nil { + return total, err + } + total += n + t.MarriageProof = field + } + { + n, err := t.CommitmentProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.InvalidPostProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *CommitmentProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.InitialATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.InitialPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.InitialPostProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.CommitmentATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.CommitmentATXProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *CommitmentProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.InitialATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.InitialPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.InitialPostProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.CommitmentATX[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.CommitmentATXProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPostProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.NIPostIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Challenge[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.ChallengeProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.SubPostRootIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageIndexProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Post.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.PostProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.NumUnits)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NumUnitsProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.InvalidPostIndex)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPostProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.NIPostIndex = uint16(field) + } + { + n, err := scale.DecodeByteArray(dec, t.Challenge[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.ChallengeProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.SubPostRootIndex = uint16(field) + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageIndexProof = field + } + { + n, err := t.Post.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.PostProof = field + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.NumUnits = uint32(field) + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NumUnitsProof = field + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.InvalidPostIndex = uint32(field) + } + return total, nil +} diff --git a/activation/wire/malfeasance_invalid_post_test.go b/activation/wire/malfeasance_invalid_post_test.go new file mode 100644 index 0000000000..9144420297 --- /dev/null +++ b/activation/wire/malfeasance_invalid_post_test.go @@ -0,0 +1,632 @@ +package wire + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_InvalidPostProof(t *testing.T) { + // sig is the identity that creates the invalid PoST + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // pubSig is the identity that publishes the merged ATX with the invalid PoST + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // marrySig is the identity that publishes the marriage ATX + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + newSoloATXv2 := func( + db sql.Executor, + nipostChallenge types.Hash32, + post PostV1, + numUnits uint32, + ) (*ActivationTxV2, *ActivationTxV2) { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(wInitialAtx.ID()), + withNIPost( + withNIPostChallenge(nipostChallenge), + withNIPostSubPost(SubPostV2{ + Post: post, + NumUnits: numUnits, + }), + ), + ) + atx.Sign(sig) + return atx, wInitialAtx + } + + newMergedATXv2 := func( + db sql.Executor, + nipostChallenge types.Hash32, + post PostV1, + numUnits uint32, + ) (*ActivationTxV2, *ActivationTxV2) { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + wPubInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wPubInitialAtx.Sign(pubSig) + pubInitialAtx := &types.ActivationTx{} + pubInitialAtx.SetID(wPubInitialAtx.ID()) + pubInitialAtx.SmesherID = pubSig.NodeID() + require.NoError(t, atxs.Add(db, pubInitialAtx, wPubInitialAtx.Blob())) + + marryInitialAtx := types.RandomATXID() + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx.ID(), marrySig.NodeID()), + withMarriageCertificate(pubSig, wPubInitialAtx.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(marryInitialAtx, wInitialAtx.ID(), wPubInitialAtx.ID()), + withMarriageATX(wMarriageAtx.ID()), + withNIPost( + withNIPostChallenge(nipostChallenge), + withNIPostMembershipProof(MerkleProofV2{}), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 0, + PrevATXIndex: 0, + Post: PostV1{}, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + Post: post, + NumUnits: numUnits, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 2, + PrevATXIndex: 2, + Post: PostV1{}, + }), + ), + ) + atx.Sign(pubSig) + return atx, wInitialAtx + } + + t.Run("valid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + initialAtx.Initial.CommitmentATX, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + invalidPostIndex, + ).Return(errors.New("invalid post")) + + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("valid merged atx", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + initialAtx.Initial.CommitmentATX, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + invalidPostIndex, + ).Return(errors.New("invalid post")) + + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("post is valid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + initialAtx.Initial.CommitmentATX, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + invalidPostIndex, + ).Return(nil) + + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid invalid post proof: PoST is valid") + require.Equal(t, types.EmptyNodeID, id) + }) + + t.Run("differing node ID without marriage ATX", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, types.RandomNodeID(), 0, invalidPostIndex) + require.EqualError(t, err, "ATX is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + + proof, err = NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.NoError(t, err) + require.NotNil(t, proof) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof.NodeID = types.RandomNodeID() // invalid node ID + + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + }) + + t.Run("node ID not in marriage ATX", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + nodeID := types.RandomNodeID() + proof, err := NewInvalidPostProof(db, atx, initialAtx, nodeID, 0, invalidPostIndex) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("invalid marriage proof", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, _ := newMergedATXv2(db, nipostChallenge, post, numUnits) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // manually construct an invalid proof + proof, err := createMarriageProof(db, atx, sig.NodeID()) + require.NoError(t, err) + + marriageATX := proof.MarriageATX + proof.MarriageATX = types.RandomATXID() // invalid ATX + err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) + require.ErrorContains(t, err, "invalid marriage ATX proof") + + proof.MarriageATX = marriageATX + proof.MarriageATXProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) + require.ErrorContains(t, err, "invalid marriage ATX proof") + }) + + t.Run("node ID did not include post in merged ATX", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + atx.NIPosts[0].Posts = slices.DeleteFunc(atx.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return cmp.Equal(subPost.Post, post) + }) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.EqualError(t, err, fmt.Sprintf("no PoST from %s in ATX", sig)) + require.Nil(t, proof) + }) + + t.Run("initial ATX is invalid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + initialAtx.SmesherID = types.RandomNodeID() // initial ATX published by different identity + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.ErrorContains(t, err, "node ID does not match smesher ID of initial ATX") + require.Nil(t, proof) + + atx, initialAtx = newMergedATXv2(db, nipostChallenge, post, numUnits) + initialAtx.Initial = nil // not an initial ATX + + proof, err = NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.ErrorContains(t, err, "initial ATX does not contain initial PoST") + require.Nil(t, proof) + }) + + t.Run("invalid nipost index", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 1, invalidPostIndex) // 1 is invalid + require.EqualError(t, err, "invalid NIPoST index") + require.Nil(t, proof) + }) + + t.Run("invalid ATX signature", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + const invalidPostIndex = 7 + proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof.Signature = types.RandomEdSignature() // invalid signature + + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + }) + + t.Run("commitment proof is invalid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + _, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // manually construct an invalid proof + proof, err := createCommitmentProof(initialAtx, sig.NodeID()) + require.NoError(t, err) + + signature := proof.Signature + proof.Signature = types.RandomEdSignature() // invalid signature + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid signature") + proof.Signature = signature + + proof.InitialATXID = types.RandomATXID() // invalid ATX + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid signature") + proof.InitialATXID = initialAtx.ID() + + proofHash := proof.InitialPostProof[0] + proof.InitialPostProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid initial PoST proof") + proof.InitialPostProof[0] = proofHash + + initialPostRoot := proof.InitialPostRoot + proof.InitialPostRoot = InitialPostRoot(types.EmptyHash32) // invalid initial post root + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid empty initial PoST root") + proof.InitialPostRoot = initialPostRoot + + commitmentATX := proof.CommitmentATX + proof.CommitmentATX = types.RandomATXID() // invalid ATX + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid commitment ATX proof") + proof.CommitmentATX = commitmentATX + + proofHash = proof.CommitmentATXProof[0] + proof.CommitmentATXProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(verifier, sig.NodeID()) + require.ErrorContains(t, err, "invalid commitment ATX proof") + proof.CommitmentATXProof[0] = proofHash + }) + + t.Run("solo invalid post proof is not valid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // manually construct an invalid proof + const invalidPostIndex = 7 + proof, err := createInvalidPostProof(atx, 0, 0, invalidPostIndex) + require.NoError(t, err) + require.NotNil(t, proof) + + nipostsRoot := proof.NIPostsRoot + proof.NIPostsRoot = NIPostsRoot(types.RandomHash()) // invalid root + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid NIPosts root proof") + proof.NIPostsRoot = nipostsRoot + + proofHash := proof.NIPostsRootProof[0] + proof.NIPostsRootProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid NIPosts root proof") + proof.NIPostsRootProof[0] = proofHash + + proof.NIPostIndex = 1 // invalid index + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid NIPoST root proof") + proof.NIPostIndex = 0 + + nipostRoot := proof.NIPostRoot + proof.NIPostRoot = NIPostRoot(types.RandomHash()) // invalid root + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid NIPoST root proof") + proof.NIPostRoot = nipostRoot + + proofHash = proof.NIPostRootProof[0] + proof.NIPostRootProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid NIPoST root proof") + proof.NIPostRootProof[0] = proofHash + + challenge := proof.Challenge + proof.Challenge = types.RandomHash() // invalid challenge + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid challenge proof") + proof.Challenge = challenge + + proofHash = proof.ChallengeProof[0] + proof.ChallengeProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid challenge proof") + proof.ChallengeProof[0] = proofHash + + subPostsRoot := proof.SubPostsRoot + proof.SubPostsRoot = SubPostsRoot(types.RandomHash()) // invalid root + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid sub PoSTs root proof") + proof.SubPostsRoot = subPostsRoot + + proofHash = proof.SubPostsRootProof[0] + proof.SubPostsRootProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid sub PoSTs root proof") + proof.SubPostsRootProof[0] = proofHash + + proof.SubPostRootIndex = 1 // invalid index + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid sub PoST root proof") + proof.SubPostRootIndex = 0 + + subPost := proof.SubPostRoot + proof.SubPostRoot = SubPostRoot(types.RandomHash()) // invalid root + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid sub PoST root proof") + proof.SubPostRoot = subPost + + proofHash = proof.SubPostRootProof[0] + proof.SubPostRootProof[0] = types.RandomHash() // invalid proof + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid sub PoST root proof") + proof.SubPostRootProof[0] = proofHash + + proof.Post = PostV1{} // invalid post + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid PoST proof") + proof.Post = post + + proof.NumUnits++ // invalid number of units + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) + require.EqualError(t, err, "invalid num units proof") + proof.NumUnits-- + }) + + t.Run("merged invalid post proof is not valid", func(t *testing.T) { + db := statesql.InMemoryTest(t) + + nipostChallenge := types.RandomHash() + const numUnits = uint32(11) + post := PostV1{ + Nonce: rand.Uint32(), + Indices: types.RandomBytes(11), + Pow: rand.Uint64(), + } + atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // manually construct an invalid proof + marriageIndex := uint32(1) + commitmentAtx := initialAtx.Initial.CommitmentATX + const invalidPostIndex = 7 + proof, err := createInvalidPostProof(atx, 0, 1, invalidPostIndex) + require.NoError(t, err) + require.NotNil(t, proof) + + invalidMarriageIndex := marriageIndex + 1 + + err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), commitmentAtx, &invalidMarriageIndex) + require.EqualError(t, err, "invalid marriage index proof") + }) +} diff --git a/activation/wire/malfeasance_shared.go b/activation/wire/malfeasance_shared.go new file mode 100644 index 0000000000..7486d2060e --- /dev/null +++ b/activation/wire/malfeasance_shared.go @@ -0,0 +1,146 @@ +package wire + +import ( + "context" + "errors" + "fmt" + "slices" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" +) + +//go:generate scalegen + +// MarryProof is a proof that a NodeID is married to another NodeID. +type MarryProof struct { + // MarriageCertificatesRoot and its proof that it is contained in the ATX. + MarriageCertificatesRoot MarriageCertificatesRoot + MarriageCertificatesProof MarriageCertificatesRootProof `scale:"max=32"` + + // The signature of the certificate and the proof that the certificate is contained in the MarriageRoot at + // the given index. + Certificate MarriageCertificate + CertificateProof MarriageCertificateProof `scale:"max=32"` + CertificateIndex uint32 +} + +func createMarryProof(db sql.Executor, atx *ActivationTxV2, nodeID types.NodeID) (MarryProof, error) { + marriageIndex := slices.IndexFunc(atx.Marriages, func(cert MarriageCertificate) bool { + if cert.ReferenceAtx == types.EmptyATXID && atx.SmesherID == nodeID { + // special case of the self signed certificate of the ATX publisher + return true + } + refATX, err := atxs.Get(db, cert.ReferenceAtx) + if err != nil { + return false + } + return refATX.SmesherID == nodeID + }) + if marriageIndex == -1 { + return MarryProof{}, fmt.Errorf("does not contain a marriage certificate signed by %s", nodeID.ShortString()) + } + + proof := MarryProof{ + MarriageCertificatesRoot: atx.Marriages.Root(), + MarriageCertificatesProof: atx.MarriagesRootProof(), + + Certificate: atx.Marriages[marriageIndex], + CertificateProof: atx.Marriages.Proof(marriageIndex), + CertificateIndex: uint32(marriageIndex), + } + return proof, nil +} + +// Valid returns an error if the proof is invalid. It checks that `nodeID` signed a certificate to marry `smesherID` +// and it was included in the ATX with the given `atxID`. +func (p MarryProof) Valid( + malValidator MalfeasanceValidator, + atxID types.ATXID, + smesherID types.NodeID, + nodeID types.NodeID, +) error { + if !malValidator.Signature(signing.MARRIAGE, nodeID, smesherID.Bytes(), p.Certificate.Signature) { + return errors.New("invalid certificate signature") + } + if !p.MarriageCertificatesProof.Valid(atxID, p.MarriageCertificatesRoot) { + return errors.New("invalid marriage proof") + } + if !p.CertificateProof.Valid(p.MarriageCertificatesRoot, int(p.CertificateIndex), p.Certificate) { + return errors.New("invalid certificate proof") + } + return nil +} + +// MarriageProof is a proof for two identities to be married via a marriage ATX. +type MarriageProof struct { + // MarriageATX and its proof that it is contained in the ATX. + MarriageATX types.ATXID + MarriageATXProof MarriageATXProof `scale:"max=32"` + // MarriageATXSmesherID is the ID of the smesher that published the marriage ATX. + MarriageATXSmesherID types.NodeID + + // NodeIDMarryProof is the proof that NodeID married in MarriageATX. + NodeIDMarryProof MarryProof + // SmesherIDMarryProof is the proof that SmesherID married in MarriageATX. + SmesherIDMarryProof MarryProof +} + +func createMarriageProof(db sql.Executor, atx *ActivationTxV2, nodeID types.NodeID) (MarriageProof, error) { + if nodeID == atx.SmesherID { + // we don't need a marriage proof if the node ID is the same as the smesher ID + return MarriageProof{}, errors.New("node ID is the same as smesher ID") + } + + var blob sql.Blob + v, err := atxs.LoadBlob(context.Background(), db, atx.MarriageATX.Bytes(), &blob) + if err != nil { + return MarriageProof{}, fmt.Errorf("get marriage ATX: %w", err) + } + if v != types.AtxV2 { + return MarriageProof{}, errors.New("invalid ATX version for marriage ATX") + } + marriageATX, err := DecodeAtxV2(blob.Bytes) + if err != nil { + return MarriageProof{}, fmt.Errorf("decode marriage ATX: %w", err) + } + + nodeIDmarriageProof, err := createMarryProof(db, marriageATX, nodeID) + if err != nil { + return MarriageProof{}, fmt.Errorf("NodeID marriage proof: %w", err) + } + smesherIDmarriageProof, err := createMarryProof(db, marriageATX, atx.SmesherID) + if err != nil { + return MarriageProof{}, fmt.Errorf("SmesherID marriage proof: %w", err) + } + + proof := MarriageProof{ + MarriageATX: marriageATX.ID(), + MarriageATXProof: atx.MarriageATXProof(), + + MarriageATXSmesherID: marriageATX.SmesherID, + + NodeIDMarryProof: nodeIDmarriageProof, + SmesherIDMarryProof: smesherIDmarriageProof, + } + return proof, nil +} + +func (p MarriageProof) Valid( + malValidator MalfeasanceValidator, + atxID types.ATXID, + nodeID, smesherID types.NodeID, +) error { + if !p.MarriageATXProof.Valid(atxID, p.MarriageATX) { + return errors.New("invalid marriage ATX proof") + } + if err := p.NodeIDMarryProof.Valid(malValidator, p.MarriageATX, p.MarriageATXSmesherID, nodeID); err != nil { + return fmt.Errorf("invalid marriage proof for NodeID: %w", err) + } + if err := p.SmesherIDMarryProof.Valid(malValidator, p.MarriageATX, p.MarriageATXSmesherID, smesherID); err != nil { + return fmt.Errorf("invalid marriage proof for SmesherID: %w", err) + } + return nil +} diff --git a/activation/wire/malfeasance_shared_scale.go b/activation/wire/malfeasance_shared_scale.go new file mode 100644 index 0000000000..026bb4c8a3 --- /dev/null +++ b/activation/wire/malfeasance_shared_scale.go @@ -0,0 +1,169 @@ +// Code generated by github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT. + +// nolint +package wire + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *MarryProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.MarriageCertificatesRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageCertificatesProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Certificate.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.CertificateProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.CertificateIndex)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *MarryProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.MarriageCertificatesRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageCertificatesProof = field + } + { + n, err := t.Certificate.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.CertificateProof = field + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.CertificateIndex = uint32(field) + } + return total, nil +} + +func (t *MarriageProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.MarriageATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageATXProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.MarriageATXSmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.NodeIDMarryProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.SmesherIDMarryProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *MarriageProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.MarriageATX[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageATXProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.MarriageATXSmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.NodeIDMarryProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.SmesherIDMarryProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} diff --git a/activation/wire/mocks.go b/activation/wire/mocks.go new file mode 100644 index 0000000000..ae0fd1be61 --- /dev/null +++ b/activation/wire/mocks.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -typed -package=wire -destination=./mocks.go -source=./interface.go +// + +// Package wire is a generated GoMock package. +package wire + +import ( + context "context" + reflect "reflect" + + types "github.com/spacemeshos/go-spacemesh/common/types" + signing "github.com/spacemeshos/go-spacemesh/signing" + gomock "go.uber.org/mock/gomock" +) + +// MockMalfeasanceValidator is a mock of MalfeasanceValidator interface. +type MockMalfeasanceValidator struct { + ctrl *gomock.Controller + recorder *MockMalfeasanceValidatorMockRecorder + isgomock struct{} +} + +// MockMalfeasanceValidatorMockRecorder is the mock recorder for MockMalfeasanceValidator. +type MockMalfeasanceValidatorMockRecorder struct { + mock *MockMalfeasanceValidator +} + +// NewMockMalfeasanceValidator creates a new mock instance. +func NewMockMalfeasanceValidator(ctrl *gomock.Controller) *MockMalfeasanceValidator { + mock := &MockMalfeasanceValidator{ctrl: ctrl} + mock.recorder = &MockMalfeasanceValidatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMalfeasanceValidator) EXPECT() *MockMalfeasanceValidatorMockRecorder { + return m.recorder +} + +// PostIndex mocks base method. +func (m *MockMalfeasanceValidator) PostIndex(ctx context.Context, smesherID types.NodeID, commitment types.ATXID, post *types.Post, challenge []byte, numUnits uint32, idx int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostIndex", ctx, smesherID, commitment, post, challenge, numUnits, idx) + ret0, _ := ret[0].(error) + return ret0 +} + +// PostIndex indicates an expected call of PostIndex. +func (mr *MockMalfeasanceValidatorMockRecorder) PostIndex(ctx, smesherID, commitment, post, challenge, numUnits, idx any) *MockMalfeasanceValidatorPostIndexCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostIndex", reflect.TypeOf((*MockMalfeasanceValidator)(nil).PostIndex), ctx, smesherID, commitment, post, challenge, numUnits, idx) + return &MockMalfeasanceValidatorPostIndexCall{Call: call} +} + +// MockMalfeasanceValidatorPostIndexCall wrap *gomock.Call +type MockMalfeasanceValidatorPostIndexCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMalfeasanceValidatorPostIndexCall) Return(arg0 error) *MockMalfeasanceValidatorPostIndexCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMalfeasanceValidatorPostIndexCall) Do(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, int) error) *MockMalfeasanceValidatorPostIndexCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMalfeasanceValidatorPostIndexCall) DoAndReturn(f func(context.Context, types.NodeID, types.ATXID, *types.Post, []byte, uint32, int) error) *MockMalfeasanceValidatorPostIndexCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Signature mocks base method. +func (m_2 *MockMalfeasanceValidator) Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "Signature", d, nodeID, m, sig) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Signature indicates an expected call of Signature. +func (mr *MockMalfeasanceValidatorMockRecorder) Signature(d, nodeID, m, sig any) *MockMalfeasanceValidatorSignatureCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signature", reflect.TypeOf((*MockMalfeasanceValidator)(nil).Signature), d, nodeID, m, sig) + return &MockMalfeasanceValidatorSignatureCall{Call: call} +} + +// MockMalfeasanceValidatorSignatureCall wrap *gomock.Call +type MockMalfeasanceValidatorSignatureCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMalfeasanceValidatorSignatureCall) Return(arg0 bool) *MockMalfeasanceValidatorSignatureCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMalfeasanceValidatorSignatureCall) Do(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorSignatureCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMalfeasanceValidatorSignatureCall) DoAndReturn(f func(signing.Domain, types.NodeID, []byte, types.EdSignature) bool) *MockMalfeasanceValidatorSignatureCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/activation/wire/wire_v1.go b/activation/wire/wire_v1.go index 059180c9dc..200ae83fa1 100644 --- a/activation/wire/wire_v1.go +++ b/activation/wire/wire_v1.go @@ -44,23 +44,24 @@ type PostV1 struct { Pow uint64 } -func (p *PostV1) Root() []byte { - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - nonce := make([]byte, 4) - binary.LittleEndian.PutUint32(nonce, p.Nonce) - tree.AddLeaf(nonce) +func (p *PostV1) merkleTree(tree *merkle.Tree) { + var nonce types.Hash32 + binary.LittleEndian.PutUint32(nonce[:], p.Nonce) + tree.AddLeaf(nonce.Bytes()) + + hasher := hash.GetHasher() + defer hash.PutHasher(hasher) + tree.AddLeaf(hasher.Sum(p.Indices)) + + var pow types.Hash32 + binary.LittleEndian.PutUint64(pow[:], p.Pow) + tree.AddLeaf(pow.Bytes()) +} - tree.AddLeaf(p.Indices) +type PostRoot types.Hash32 - pow := make([]byte, 8) - binary.LittleEndian.PutUint64(pow, p.Pow) - tree.AddLeaf(pow) - return tree.Root() +func (p *PostV1) Root() PostRoot { + return PostRoot(createRoot(p.merkleTree)) } type MerkleProofV1 struct { diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index 5990589d30..08a1390f5d 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -4,11 +4,11 @@ import ( "encoding/binary" "github.com/spacemeshos/merkle-tree" - "github.com/zeebo/blake3" "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/hash" "github.com/spacemeshos/go-spacemesh/signing" ) @@ -21,8 +21,8 @@ type ActivationTxV2 struct { // only present in initial ATX Initial *InitialAtxPartsV2 - PreviousATXs []types.ATXID `scale:"max=256"` - NiPosts []NiPostsV2 `scale:"max=4"` + PreviousATXs PrevATXs `scale:"max=256"` + NIPosts NIPosts `scale:"max=4"` // The VRF nonce must be valid for the collected space of all included IDs. VRFNonce uint64 @@ -66,54 +66,77 @@ func DecodeAtxV2(blob []byte) (*ActivationTxV2, error) { return atx, nil } +func (atx *ActivationTxV2) Sign(signer *signing.EdSigner) { + atx.SmesherID = signer.NodeID() + atx.Signature = signer.Sign(signing.ATX, atx.ID().Bytes()) +} + +func (atx *ActivationTxV2) TotalNumUnits() uint32 { + var total uint32 + for _, post := range atx.NIPosts { + for _, subPost := range post.Posts { + total += subPost.NumUnits + } + } + return total +} + +func (atx *ActivationTxV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if atx == nil { + return nil + } + encoder.AddString("ID", atx.ID().String()) + encoder.AddString("Smesher", atx.SmesherID.String()) + encoder.AddUint32("PublishEpoch", atx.PublishEpoch.Uint32()) + encoder.AddString("PositioningATX", atx.PositioningATX.String()) + encoder.AddString("Coinbase", atx.Coinbase.String()) + encoder.AddObject("Initial", atx.Initial) + encoder.AddArray("PreviousATXs", types.ATXIDs(atx.PreviousATXs)) + encoder.AddArray("NiPosts", zapcore.ArrayMarshalerFunc(func(encoder zapcore.ArrayEncoder) error { + for _, nipost := range atx.NIPosts { + encoder.AppendObject(&nipost) + } + return nil + })) + encoder.AddUint64("VRFNonce", atx.VRFNonce) + + encoder.AddArray("Marriages", zapcore.ArrayMarshalerFunc(func(encoder zapcore.ArrayEncoder) error { + for _, marriage := range atx.Marriages { + encoder.AppendObject(&marriage) + } + return nil + })) + if atx.MarriageATX != nil { + encoder.AddString("MarriageATX", atx.MarriageATX.String()) + } + encoder.AddString("Signature", atx.Signature.String()) + return nil +} + func (atx *ActivationTxV2) merkleTree(tree *merkle.Tree) { var publishEpoch types.Hash32 binary.LittleEndian.PutUint32(publishEpoch[:], atx.PublishEpoch.Uint32()) tree.AddLeaf(publishEpoch.Bytes()) tree.AddLeaf(atx.PositioningATX.Bytes()) - tree.AddLeaf(atx.Coinbase.Bytes()) + + var coinbase types.Hash32 + copy(coinbase[:], atx.Coinbase.Bytes()) + tree.AddLeaf(coinbase.Bytes()) if atx.Initial != nil { - tree.AddLeaf(atx.Initial.Root()) + tree.AddLeaf(types.Hash32(atx.Initial.Root()).Bytes()) } else { tree.AddLeaf(types.EmptyHash32.Bytes()) } - prevATXTree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - for _, prevATX := range atx.PreviousATXs { - prevATXTree.AddLeaf(prevATX.Bytes()) - } - for i := len(atx.PreviousATXs); i < 256; i++ { - prevATXTree.AddLeaf(types.EmptyATXID.Bytes()) - } - tree.AddLeaf(prevATXTree.Root()) - - niPostTree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - for _, niPost := range atx.NiPosts { - niPostTree.AddLeaf(niPost.Root(atx.PreviousATXs)) - } - // Add empty niposts up to the max scale limit. - // This must be updated when the max scale limit is changed. - for i := len(atx.NiPosts); i < 4; i++ { - niPostTree.AddLeaf(types.EmptyHash32.Bytes()) - } - tree.AddLeaf(niPostTree.Root()) + tree.AddLeaf(atx.PreviousATXs.Root().Bytes()) + tree.AddLeaf(types.Hash32(atx.NIPosts.Root(atx.PreviousATXs)).Bytes()) var vrfNonce types.Hash32 binary.LittleEndian.PutUint64(vrfNonce[:], atx.VRFNonce) tree.AddLeaf(vrfNonce.Bytes()) - tree.AddLeaf(atx.Marriages.Root()) + tree.AddLeaf(types.Hash32(atx.Marriages.Root()).Bytes()) if atx.MarriageATX != nil { tree.AddLeaf(atx.MarriageATX.Bytes()) @@ -122,100 +145,238 @@ func (atx *ActivationTxV2) merkleTree(tree *merkle.Tree) { } } +func (atx *ActivationTxV2) merkleProof(leafIndex MerkleTreeIndex) []types.Hash32 { + return createProof(uint64(leafIndex), atx.merkleTree) +} + +// ID returns the ATX ID. It is the root of the ATX merkle tree. func (atx *ActivationTxV2) ID() types.ATXID { if atx.id != types.EmptyATXID { return atx.id } - - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - atx.merkleTree(tree) - atx.id = types.ATXID(tree.Root()) + atx.id = types.ATXID(createRoot(atx.merkleTree)) return atx.id } -func (atx *ActivationTxV2) Sign(signer *signing.EdSigner) { - atx.SmesherID = signer.NodeID() - atx.Signature = signer.Sign(signing.ATX, atx.ID().Bytes()) +func (atx *ActivationTxV2) PublishEpochProof() []types.Hash32 { + return atx.merkleProof(PublishEpochIndex) } -func (atx *ActivationTxV2) TotalNumUnits() uint32 { - var total uint32 - for _, post := range atx.NiPosts { - for _, subPost := range post.Posts { - total += subPost.NumUnits - } +func (atx *ActivationTxV2) PositioningATXProof() []types.Hash32 { + return atx.merkleProof(PositioningATXIndex) +} + +func (atx *ActivationTxV2) CoinbaseProof() []types.Hash32 { + return atx.merkleProof(CoinbaseIndex) +} + +func (atx *ActivationTxV2) InitialPostRootProof() InitialPostRootProof { + return atx.merkleProof(InitialPostRootIndex) +} + +type InitialPostRootProof []types.Hash32 + +func (p InitialPostRootProof) Valid(atxID types.ATXID, initialPostRoot InitialPostRoot) bool { + return validateProof(types.Hash32(atxID), types.Hash32(initialPostRoot), p, uint64(InitialPostRootIndex)) +} + +func (atx *ActivationTxV2) PreviousATXsRootProof() []types.Hash32 { + return atx.merkleProof(PreviousATXsRootIndex) +} + +func (atx *ActivationTxV2) NIPostsRootProof() NIPostsRootProof { + return atx.merkleProof(NIPostsRootIndex) +} + +type NIPostsRootProof []types.Hash32 + +func (p NIPostsRootProof) Valid(atxID types.ATXID, niPostsRoot NIPostsRoot) bool { + return validateProof(types.Hash32(atxID), types.Hash32(niPostsRoot), p, uint64(NIPostsRootIndex)) +} + +func (atx *ActivationTxV2) VRFNonceProof() []types.Hash32 { + return atx.merkleProof(VRFNonceIndex) +} + +func (atx *ActivationTxV2) MarriagesRootProof() MarriageCertificatesRootProof { + return atx.merkleProof(MarriagesRootIndex) +} + +type MarriageCertificatesRootProof []types.Hash32 + +func (p MarriageCertificatesRootProof) Valid(atxID types.ATXID, marriagesRoot MarriageCertificatesRoot) bool { + return validateProof(types.Hash32(atxID), types.Hash32(marriagesRoot), p, uint64(MarriagesRootIndex)) +} + +func (atx *ActivationTxV2) MarriageATXProof() MarriageATXProof { + return atx.merkleProof(MarriageATXIndex) +} + +type MarriageATXProof []types.Hash32 + +func (p MarriageATXProof) Valid(atxID, marriageATX types.ATXID) bool { + return validateProof(types.Hash32(atxID), types.Hash32(marriageATX), p, uint64(MarriageATXIndex)) +} + +type InitialAtxPartsV2 struct { + CommitmentATX types.ATXID + Post PostV1 +} + +func (parts *InitialAtxPartsV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if parts == nil { + return nil } - return total + encoder.AddString("CommitmentATX", parts.CommitmentATX.String()) + encoder.AddObject("Post", &parts.Post) + return nil } -type MarriageCertificates []MarriageCertificate +func (parts *InitialAtxPartsV2) merkleTree(tree *merkle.Tree) { + tree.AddLeaf(parts.CommitmentATX.Bytes()) + tree.AddLeaf(types.Hash32(parts.Post.Root()).Bytes()) +} -func (mcs MarriageCertificates) Root() []byte { - marriagesTree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) +func (parts *InitialAtxPartsV2) merkleProof(leafIndex InitialPostTreeIndex) []types.Hash32 { + return createProof(uint64(leafIndex), parts.merkleTree) +} + +type InitialPostRoot types.Hash32 + +func (parts *InitialAtxPartsV2) Root() InitialPostRoot { + return InitialPostRoot(createRoot(parts.merkleTree)) +} + +func (parts *InitialAtxPartsV2) CommitmentATXProof() CommitmentATXProof { + return parts.merkleProof(CommitmentATXIndex) +} + +type CommitmentATXProof []types.Hash32 + +func (p CommitmentATXProof) Valid(initialPostRoot InitialPostRoot, commitmentATX types.ATXID) bool { + return validateProof(types.Hash32(initialPostRoot), types.Hash32(commitmentATX), p, uint64(CommitmentATXIndex)) +} + +func (parts *InitialAtxPartsV2) PostProof() []types.Hash32 { + return parts.merkleProof(InitialPostIndex) +} + +type PrevATXs []types.ATXID + +func (prevATXs PrevATXs) merkleTree(tree *merkle.Tree) { + for _, prevATX := range prevATXs { + tree.AddLeaf(prevATX.Bytes()) + } + for i := len(prevATXs); i < 256; i++ { + tree.AddLeaf(types.EmptyATXID.Bytes()) } - mcs.merkleTree(marriagesTree) - return marriagesTree.Root() } -func (mcs MarriageCertificates) merkleTree(tree *merkle.Tree) { - for _, marriage := range mcs { - tree.AddLeaf(marriage.Root()) +func (prevATXs PrevATXs) Root() types.Hash32 { + return createRoot(prevATXs.merkleTree) +} + +type NIPosts []NIPostV2 + +func (nps NIPosts) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { + for _, niPost := range nps { + tree.AddLeaf(types.Hash32(niPost.Root(prevATXs)).Bytes()) } - for i := len(mcs); i < 256; i++ { + // Add empty NiPoSTs up to the max scale limit. + // This must be updated when the max scale limit is changed. + for i := len(nps); i < 4; i++ { tree.AddLeaf(types.EmptyHash32.Bytes()) } } -type InitialAtxPartsV2 struct { - CommitmentATX types.ATXID - Post PostV1 +type NIPostsRoot types.Hash32 + +func (nps NIPosts) Root(prevATXs []types.ATXID) NIPostsRoot { + return NIPostsRoot(createRoot(func(tree *merkle.Tree) { + nps.merkleTree(tree, prevATXs) + })) } -func (i *InitialAtxPartsV2) Root() []byte { - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) +func (nps NIPosts) Proof(index int, prevATXs []types.ATXID) NIPostRootProof { + if index < 0 || index >= len(nps) { + panic("index out of range") } - tree.AddLeaf(i.CommitmentATX.Bytes()) - tree.AddLeaf(i.Post.Root()) - return tree.Root() + return createProof(uint64(index), func(tree *merkle.Tree) { + nps.merkleTree(tree, prevATXs) + }) } -// MarriageCertificate proves the will of ID to be married with the ID that includes this certificate. -// A marriage allows for publishing a merged ATX, which can contain PoST for all married IDs. -// Any ID from the marriage can publish a merged ATX on behalf of all married IDs. -type MarriageCertificate struct { - // An ATX of the ID that marries. It proves that the ID exists. - // Note: the reference ATX does not need to be from the previous epoch. - // It only needs to prove the existence of the ID. - ReferenceAtx types.ATXID - // Signature over the other ID that this ID marries with - // If Alice marries Bob, then Alice signs Bob's ID - // and Bob includes this certificate in his ATX. - Signature types.EdSignature +type NIPostRootProof []types.Hash32 + +func (p NIPostRootProof) Valid(niPostsRoot NIPostsRoot, index int, nipostRoot NIPostRoot) bool { + return validateProof(types.Hash32(niPostsRoot), types.Hash32(nipostRoot), p, uint64(index)) } -func (mc *MarriageCertificate) Root() []byte { - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) +type NIPostV2 struct { + // Single membership proof for all IDs in `Posts`. + Membership MerkleProofV2 + // The root of the PoET proof, that serves as the challenge for PoSTs. + Challenge types.Hash32 + Posts SubPostsV2 `scale:"max=256"` // support merging up to 256 IDs +} + +func (np *NIPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if np == nil { + return nil } - tree.AddLeaf(mc.ReferenceAtx.Bytes()) - tree.AddLeaf(mc.Signature.Bytes()) - return tree.Root() + // skip membership proof + encoder.AddString("Challenge", np.Challenge.String()) + encoder.AddArray("Posts", zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { + for _, post := range np.Posts { + ae.AppendObject(&post) + } + return nil + })) + return nil +} + +func (np *NIPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { + tree.AddLeaf(np.Membership.Root().Bytes()) + tree.AddLeaf(np.Challenge.Bytes()) + tree.AddLeaf(types.Hash32(np.Posts.Root(prevATXs)).Bytes()) +} + +func (np *NIPostV2) merkleProof(leafIndex NIPostTreeIndex, prevATXs []types.ATXID) []types.Hash32 { + return createProof(uint64(leafIndex), func(tree *merkle.Tree) { + np.merkleTree(tree, prevATXs) + }) +} + +type NIPostRoot types.Hash32 + +func (np *NIPostV2) Root(prevATXs []types.ATXID) NIPostRoot { + return NIPostRoot(createRoot(func(tree *merkle.Tree) { + np.merkleTree(tree, prevATXs) + })) +} + +func (np *NIPostV2) MembershipProof(prevATXs []types.ATXID) []types.Hash32 { + return np.merkleProof(MembershipIndex, prevATXs) +} + +func (np *NIPostV2) ChallengeProof(prevATXs []types.ATXID) []types.Hash32 { + return np.merkleProof(ChallengeIndex, prevATXs) +} + +type ChallengeProof []types.Hash32 + +func (p ChallengeProof) Valid(nipostRoot NIPostRoot, challenge types.Hash32) bool { + return validateProof(types.Hash32(nipostRoot), challenge, p, uint64(ChallengeIndex)) +} + +func (np *NIPostV2) PostsRootProof(prevATXs []types.ATXID) SubPostsRootProof { + return np.merkleProof(PostsRootIndex, prevATXs) +} + +type SubPostsRootProof []types.Hash32 + +func (p SubPostsRootProof) Valid(nipostRoot NIPostRoot, postsRoot SubPostsRoot) bool { + return validateProof(types.Hash32(nipostRoot), types.Hash32(postsRoot), p, uint64(PostsRootIndex)) } // MerkleProofV2 proves membership of multiple challenges in a PoET membership merkle tree. @@ -224,6 +385,54 @@ type MerkleProofV2 struct { Nodes []types.Hash32 `scale:"max=32"` } +func (mp MerkleProofV2) Root() types.Hash32 { + hasher := hash.GetHasher() + defer hash.PutHasher(hasher) + hasher.Write([]byte{0x01}) + for _, node := range mp.Nodes { + hasher.Write(node.Bytes()) + } + return types.Hash32(hasher.Sum(nil)) +} + +type SubPostsV2 []SubPostV2 + +func (sp SubPostsV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { + for _, subPost := range sp { + // if root is nil it will be handled like 0x00...00 + // this will still generate a valid ID for the ATX, + // but syntactical validation will catch the invalid subPost and + // consider the ATX invalid + tree.AddLeaf(types.Hash32(subPost.Root(prevATXs)).Bytes()) + } + for i := len(sp); i < 256; i++ { + tree.AddLeaf(types.EmptyHash32.Bytes()) + } +} + +type SubPostsRoot types.Hash32 + +func (sp SubPostsV2) Root(prevATXs []types.ATXID) SubPostsRoot { + return SubPostsRoot(createRoot(func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATXs) + })) +} + +func (sp SubPostsV2) Proof(index int, prevATXs []types.ATXID) SubPostRootProof { + if index < 0 || index >= len(sp) { + panic("index out of range") + } + return createProof(uint64(index), func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATXs) + }) +} + +type SubPostRootProof []types.Hash32 + +func (p SubPostRootProof) Valid(subPostsRoot SubPostsRoot, index int, subPostRoot SubPostRoot) bool { + return validateProof(types.Hash32(subPostsRoot), types.Hash32(subPostRoot), p, uint64(index)) +} + type SubPostV2 struct { // Index of marriage certificate for this ID in the 'Marriages' slice. Only valid for merged ATXs. // Can be used to extract the nodeID and verify if it is married with the smesher of the ATX. @@ -238,153 +447,226 @@ type SubPostV2 struct { NumUnits uint32 } -func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte { - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) +func (post *SubPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if post == nil { + return nil } - marriageIndex := make([]byte, 4) - binary.LittleEndian.PutUint32(marriageIndex, sp.MarriageIndex) - tree.AddLeaf(marriageIndex) + encoder.AddUint32("MarriageIndex", post.MarriageIndex) + encoder.AddUint32("PrevATXIndex", post.PrevATXIndex) + encoder.AddUint64("MembershipLeafIndex", post.MembershipLeafIndex) + encoder.AddObject("Post", &post.Post) + encoder.AddUint32("NumUnits", post.NumUnits) + return nil +} - if int(sp.PrevATXIndex) >= len(prevATXs) { - return nil // invalid index, root cannot be generated +func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { + var marriageIndex types.Hash32 + binary.LittleEndian.PutUint32(marriageIndex[:], sp.MarriageIndex) + tree.AddLeaf(marriageIndex.Bytes()) + + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + tree.AddLeaf(types.EmptyATXID.Bytes()) + case int(sp.PrevATXIndex) < len(prevATXs): + tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) + default: + // prevATXIndex is out of range, don't fail ATXID generation + // will be detected by syntactical validation + tree.AddLeaf(types.EmptyATXID.Bytes()) } - tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) var leafIndex types.Hash32 binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex) tree.AddLeaf(leafIndex[:]) - tree.AddLeaf(sp.Post.Root()) + tree.AddLeaf(types.Hash32(sp.Post.Root()).Bytes()) - numUnits := make([]byte, 4) - binary.LittleEndian.PutUint32(numUnits, sp.NumUnits) - tree.AddLeaf(numUnits) - return tree.Root() + var numUnits types.Hash32 + binary.LittleEndian.PutUint32(numUnits[:], sp.NumUnits) + tree.AddLeaf(numUnits.Bytes()) } -type NiPostsV2 struct { - // Single membership proof for all IDs in `Posts`. - Membership MerkleProofV2 - // The root of the PoET proof, that serves as the challenge for PoSTs. - Challenge types.Hash32 - Posts []SubPostV2 `scale:"max=256"` // support merging up to 256 IDs +func (sp *SubPostV2) merkleProof(leafIndex SubPostTreeIndex, prevATXs []types.ATXID) []types.Hash32 { + return createProof(uint64(leafIndex), func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATXs) + }) } -func (np *NiPostsV2) Root(prevATXs []types.ATXID) []byte { - tree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - tree.AddLeaf(codec.MustEncode(&np.Membership)) - tree.AddLeaf(np.Challenge.Bytes()) +type SubPostRoot types.Hash32 - postsTree, err := merkle.NewTreeBuilder(). - WithHashFunc(atxTreeHash). - Build() - if err != nil { - panic(err) - } - for _, subPost := range np.Posts { - // if root is nil it will be handled like 0x00...00 - // this will still generate a valid ID for the ATX, - // but syntactical validation will catch the invalid subPost and - // consider the ATX invalid - postsTree.AddLeaf(subPost.Root(prevATXs)) +func (sp *SubPostV2) Root(prevATXs []types.ATXID) SubPostRoot { + return SubPostRoot(createRoot(func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATXs) + })) +} + +func (sp *SubPostV2) MarriageIndexProof(prevATXs []types.ATXID) MarriageIndexProof { + return sp.merkleProof(MarriageIndex, prevATXs) +} + +type MarriageIndexProof []types.Hash32 + +func (p MarriageIndexProof) Valid(subPostRoot SubPostRoot, marriageIndex uint32) bool { + var marriageIndexBytes types.Hash32 + binary.LittleEndian.PutUint32(marriageIndexBytes[:], marriageIndex) + return validateProof(types.Hash32(subPostRoot), marriageIndexBytes, p, uint64(MarriageIndex)) +} + +func (sp *SubPostV2) PrevATXIndexProof(prevATXs []types.ATXID) []types.Hash32 { + return sp.merkleProof(PrevATXIndex, prevATXs) +} + +func (sp *SubPostV2) MembershipLeafIndexProof(prevATXs []types.ATXID) []types.Hash32 { + return sp.merkleProof(MembershipLeafIndex, prevATXs) +} + +func (sp *SubPostV2) PostProof(prevATXs []types.ATXID) PostRootProof { + return sp.merkleProof(PostIndex, prevATXs) +} + +type PostRootProof []types.Hash32 + +func (p PostRootProof) Valid(subPostRoot SubPostRoot, postRoot PostRoot) bool { + return validateProof(types.Hash32(subPostRoot), types.Hash32(postRoot), p, uint64(PostIndex)) +} + +func (sp *SubPostV2) NumUnitsProof(prevATXs []types.ATXID) NumUnitsProof { + return sp.merkleProof(NumUnitsIndex, prevATXs) +} + +type NumUnitsProof []types.Hash32 + +func (p NumUnitsProof) Valid(subPostRoot SubPostRoot, numUnits uint32) bool { + var numUnitsBytes types.Hash32 + binary.LittleEndian.PutUint32(numUnitsBytes[:], numUnits) + return validateProof(types.Hash32(subPostRoot), numUnitsBytes, p, uint64(NumUnitsIndex)) +} + +type MarriageCertificates []MarriageCertificate + +func (mcs MarriageCertificates) merkleTree(tree *merkle.Tree) { + for _, marriage := range mcs { + tree.AddLeaf(marriage.Root().Bytes()) } - for i := len(np.Posts); i < 256; i++ { - postsTree.AddLeaf(types.EmptyHash32.Bytes()) + for i := len(mcs); i < 256; i++ { + tree.AddLeaf(types.EmptyHash32.Bytes()) } - tree.AddLeaf(postsTree.Root()) - return tree.Root() } -func atxTreeHash(buf, lChild, rChild []byte) []byte { - hash := blake3.New() - hash.Write([]byte{0x01}) - hash.Write(lChild) - hash.Write(rChild) - return hash.Sum(buf) +type MarriageCertificatesRoot types.Hash32 + +func (mcs MarriageCertificates) Root() MarriageCertificatesRoot { + return MarriageCertificatesRoot(createRoot(mcs.merkleTree)) } -func (atx *ActivationTxV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - if atx == nil { - return nil +func (mcs MarriageCertificates) Proof(index int) MarriageCertificateProof { + if index < 0 || index >= len(mcs) { + panic("index out of range") } - encoder.AddString("ID", atx.ID().String()) - encoder.AddString("Smesher", atx.SmesherID.String()) - encoder.AddUint32("PublishEpoch", atx.PublishEpoch.Uint32()) - encoder.AddString("PositioningATX", atx.PositioningATX.String()) - encoder.AddString("Coinbase", atx.Coinbase.String()) - encoder.AddObject("Initial", atx.Initial) - encoder.AddArray("PreviousATXs", types.ATXIDs(atx.PreviousATXs)) - encoder.AddArray("NiPosts", zapcore.ArrayMarshalerFunc(func(encoder zapcore.ArrayEncoder) error { - for _, nipost := range atx.NiPosts { - encoder.AppendObject(&nipost) - } - return nil - })) - encoder.AddUint64("VRFNonce", atx.VRFNonce) + return createProof(uint64(index), mcs.merkleTree) +} - encoder.AddArray("Marriages", zapcore.ArrayMarshalerFunc(func(encoder zapcore.ArrayEncoder) error { - for _, marriage := range atx.Marriages { - encoder.AppendObject(&marriage) - } - return nil - })) - if atx.MarriageATX != nil { - encoder.AddString("MarriageATX", atx.MarriageATX.String()) - } - encoder.AddString("Signature", atx.Signature.String()) - return nil +type MarriageCertificateProof []types.Hash32 + +func (p MarriageCertificateProof) Valid(marriageRoot MarriageCertificatesRoot, index int, mc MarriageCertificate) bool { + return validateProof(types.Hash32(marriageRoot), types.Hash32(mc.Root()), p, uint64(index)) +} + +// MarriageCertificate proves the will of ID to be married with the ID that includes this certificate. +// A marriage allows for publishing a merged ATX, which can contain PoST for all married IDs. +// Any ID from the marriage can publish a merged ATX on behalf of all married IDs. +type MarriageCertificate struct { + // An ATX of the NodeID that marries. It proves that the NodeID exists. + // Note: the reference ATX does not need to be from the previous epoch. + // It only needs to prove the existence of the Identity. + ReferenceAtx types.ATXID + // Signature over the other ID that this ID marries with + // If Alice marries Bob, then Alice signs Bob's ID + // and Bob includes this certificate in his ATX. + Signature types.EdSignature } -func (marriage *MarriageCertificate) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - if marriage == nil { +func (mc *MarriageCertificate) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + if mc == nil { return nil } - encoder.AddString("ReferenceATX", marriage.ReferenceAtx.String()) - encoder.AddString("Signature", marriage.Signature.String()) + encoder.AddString("ReferenceATX", mc.ReferenceAtx.String()) + encoder.AddString("Signature", mc.Signature.String()) return nil } -func (parts *InitialAtxPartsV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - if parts == nil { - return nil +func (mc *MarriageCertificate) merkleTree(tree *merkle.Tree) { + tree.AddLeaf(mc.ReferenceAtx.Bytes()) + tree.AddLeaf(mc.Signature.Bytes()) +} + +func (mc *MarriageCertificate) merkleProof(leafIndex MarriageCertificateIndex) []types.Hash32 { + return createProof(uint64(leafIndex), mc.merkleTree) +} + +func (mc *MarriageCertificate) Root() types.Hash32 { + return createRoot(mc.merkleTree) +} + +func (mc *MarriageCertificate) ReferenceATXProof() []types.Hash32 { + return mc.merkleProof(ReferenceATXIndex) +} + +func (mc *MarriageCertificate) SignatureProof() []types.Hash32 { + return mc.merkleProof(SignatureIndex) +} + +func atxTreeHash(buf, lChild, rChild []byte) []byte { + hasher := hash.GetHasher() + defer hash.PutHasher(hasher) + hasher.Write([]byte{0x01}) + hasher.Write(lChild) + hasher.Write(rChild) + return hasher.Sum(buf) +} + +func createRoot(addLeaves func(tree *merkle.Tree)) types.Hash32 { + tree, err := merkle.NewTreeBuilder(). + WithHashFunc(atxTreeHash). + Build() + if err != nil { + panic(err) } - encoder.AddString("CommitmentATX", parts.CommitmentATX.String()) - encoder.AddObject("Post", &parts.Post) - return nil + addLeaves(tree) + return types.Hash32(tree.Root()) } -func (post *SubPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - if post == nil { - return nil +func createProof(leafIndex uint64, addLeaves func(tree *merkle.Tree)) []types.Hash32 { + tree, err := merkle.NewTreeBuilder(). + WithLeavesToProve(map[uint64]bool{uint64(leafIndex): true}). + WithHashFunc(atxTreeHash). + Build() + if err != nil { + panic(err) } - encoder.AddUint32("MarriageIndex", post.MarriageIndex) - encoder.AddUint32("PrevATXIndex", post.PrevATXIndex) - encoder.AddUint64("MembershipLeafIndex", post.MembershipLeafIndex) - encoder.AddObject("Post", &post.Post) - encoder.AddUint32("NumUnits", post.NumUnits) - return nil + addLeaves(tree) + proof := tree.Proof() + proofHashes := make([]types.Hash32, len(proof)) + for i, p := range proof { + proofHashes[i] = types.Hash32(p) + } + return proofHashes } -func (posts *NiPostsV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - if posts == nil { - return nil +func validateProof(root, leaf types.Hash32, proof []types.Hash32, leafIndex uint64) bool { + proofBytes := make([][]byte, len(proof)) + for i, h := range proof { + proofBytes[i] = h.Bytes() + } + ok, err := merkle.ValidatePartialTree( + []uint64{leafIndex}, + [][]byte{leaf.Bytes()}, + proofBytes, + root.Bytes(), + atxTreeHash, + ) + if err != nil { + panic(err) } - // skip membership proof - encoder.AddString("Challenge", posts.Challenge.String()) - encoder.AddArray("Posts", zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { - for _, post := range posts.Posts { - ae.AppendObject(&post) - } - return nil - })) - return nil + return ok } diff --git a/activation/wire/wire_v2_scale.go b/activation/wire/wire_v2_scale.go index e2140c75ed..46d4acb63f 100644 --- a/activation/wire/wire_v2_scale.go +++ b/activation/wire/wire_v2_scale.go @@ -45,7 +45,7 @@ func (t *ActivationTxV2) EncodeScale(enc *scale.Encoder) (total int, err error) total += n } { - n, err := scale.EncodeStructSliceWithLimit(enc, t.NiPosts, 4) + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPosts, 4) if err != nil { return total, err } @@ -129,12 +129,12 @@ func (t *ActivationTxV2) DecodeScale(dec *scale.Decoder) (total int, err error) t.PreviousATXs = field } { - field, n, err := scale.DecodeStructSliceWithLimit[NiPostsV2](dec, 4) + field, n, err := scale.DecodeStructSliceWithLimit[NIPostV2](dec, 4) if err != nil { return total, err } total += n - t.NiPosts = field + t.NIPosts = field } { field, n, err := scale.DecodeCompact64(dec) @@ -213,16 +213,23 @@ func (t *InitialAtxPartsV2) DecodeScale(dec *scale.Decoder) (total int, err erro return total, nil } -func (t *MarriageCertificate) EncodeScale(enc *scale.Encoder) (total int, err error) { +func (t *NIPostV2) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.ReferenceAtx[:]) + n, err := t.Membership.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.Signature[:]) + n, err := scale.EncodeByteArray(enc, t.Challenge[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.Posts, 256) if err != nil { return total, err } @@ -231,20 +238,28 @@ func (t *MarriageCertificate) EncodeScale(enc *scale.Encoder) (total int, err er return total, nil } -func (t *MarriageCertificate) DecodeScale(dec *scale.Decoder) (total int, err error) { +func (t *NIPostV2) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.ReferenceAtx[:]) + n, err := t.Membership.DecodeScale(dec) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.Signature[:]) + n, err := scale.DecodeByteArray(dec, t.Challenge[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[SubPostV2](dec, 256) if err != nil { return total, err } total += n + t.Posts = field } return total, nil } @@ -354,23 +369,16 @@ func (t *SubPostV2) DecodeScale(dec *scale.Decoder) (total int, err error) { return total, nil } -func (t *NiPostsV2) EncodeScale(enc *scale.Encoder) (total int, err error) { - { - n, err := t.Membership.EncodeScale(enc) - if err != nil { - return total, err - } - total += n - } +func (t *MarriageCertificate) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.Challenge[:]) + n, err := scale.EncodeByteArray(enc, t.ReferenceAtx[:]) if err != nil { return total, err } total += n } { - n, err := scale.EncodeStructSliceWithLimit(enc, t.Posts, 256) + n, err := scale.EncodeByteArray(enc, t.Signature[:]) if err != nil { return total, err } @@ -379,28 +387,20 @@ func (t *NiPostsV2) EncodeScale(enc *scale.Encoder) (total int, err error) { return total, nil } -func (t *NiPostsV2) DecodeScale(dec *scale.Decoder) (total int, err error) { - { - n, err := t.Membership.DecodeScale(dec) - if err != nil { - return total, err - } - total += n - } +func (t *MarriageCertificate) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.Challenge[:]) + n, err := scale.DecodeByteArray(dec, t.ReferenceAtx[:]) if err != nil { return total, err } total += n } { - field, n, err := scale.DecodeStructSliceWithLimit[SubPostV2](dec, 256) + n, err := scale.DecodeByteArray(dec, t.Signature[:]) if err != nil { return total, err } total += n - t.Posts = field } return total, nil } diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index 40aa02f223..188e19a130 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -25,12 +25,70 @@ func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPubli } } +func withMarriageATX(id types.ATXID) testAtxV2Opt { + return func(atx *ActivationTxV2) { + atx.MarriageATX = &id + } +} + +func withInitial(commitAtx types.ATXID, post PostV1) testAtxV2Opt { + return func(atx *ActivationTxV2) { + atx.Initial = &InitialAtxPartsV2{ + CommitmentATX: commitAtx, + Post: post, + } + } +} + +func withPreviousATXs(atxs ...types.ATXID) testAtxV2Opt { + return func(atx *ActivationTxV2) { + atx.PreviousATXs = atxs + } +} + +func withNIPost(opts ...testNIPostV2Opt) testAtxV2Opt { + return func(atx *ActivationTxV2) { + nipost := &NIPostV2{} + for _, opt := range opts { + opt(nipost) + } + atx.NIPosts = append(atx.NIPosts, *nipost) + } +} + +type testNIPostV2Opt func(*NIPostV2) + +func withNIPostChallenge(challenge types.Hash32) testNIPostV2Opt { + return func(nipost *NIPostV2) { + nipost.Challenge = challenge + } +} + +func withNIPostMembershipProof(proof MerkleProofV2) testNIPostV2Opt { + return func(nipost *NIPostV2) { + nipost.Membership = proof + } +} + +func withNIPostSubPost(subPost SubPostV2) testNIPostV2Opt { + return func(nipost *NIPostV2) { + nipost.Posts = append(nipost.Posts, subPost) + } +} + func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { atx := &ActivationTxV2{ PublishEpoch: rand.N(types.EpochID(255)), PositioningATX: types.RandomATXID(), - PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), - NiPosts: []NiPostsV2{ + } + for _, opt := range opts { + opt(atx) + } + if atx.PreviousATXs == nil { + atx.PreviousATXs = make([]types.ATXID, 1+rand.IntN(255)) + } + if atx.NIPosts == nil { + atx.NIPosts = []NIPostV2{ { Membership: MerkleProofV2{ Nodes: make([]types.Hash32, 32), @@ -48,10 +106,7 @@ func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { }, }, }, - }, - } - for _, opt := range opts { - opt(atx) + } } return atx } @@ -78,7 +133,7 @@ func Benchmark_ATXv2ID_WorstScenario(b *testing.B) { PublishEpoch: 0, PositioningATX: types.RandomATXID(), PreviousATXs: make([]types.ATXID, 256), - NiPosts: []NiPostsV2{ + NIPosts: []NIPostV2{ { Membership: MerkleProofV2{ Nodes: make([]types.Hash32, 32), @@ -95,15 +150,15 @@ func Benchmark_ATXv2ID_WorstScenario(b *testing.B) { }, }, } - for i := range atx.NiPosts[0].Posts { - atx.NiPosts[0].Posts[i].Post = PostV1{ + for i := range atx.NIPosts[0].Posts { + atx.NIPosts[0].Posts[i].Post = PostV1{ Nonce: 0, Indices: make([]byte, 800), Pow: 0, } } - for i := range atx.NiPosts[1].Posts { - atx.NiPosts[1].Posts[i].Post = PostV1{ + for i := range atx.NIPosts[1].Posts { + atx.NIPosts[1].Posts[i].Post = PostV1{ Nonce: 0, Indices: make([]byte, 800), Pow: 0, @@ -134,13 +189,13 @@ func Test_ATXv2_SupportUpTo4Niposts(t *testing.T) { f.Fuzz(atx) for i := range 4 { t.Run(fmt.Sprintf("supports %d poet", i), func(t *testing.T) { - atx.NiPosts = make([]NiPostsV2, i) + atx.NIPosts = make([]NIPostV2, i) _, err := codec.Encode(atx) require.NoError(t, err) }) } t.Run("doesn't support > 5 niposts", func(t *testing.T) { - atx.NiPosts = make([]NiPostsV2, 5) + atx.NIPosts = make([]NIPostV2, 5) _, err := codec.Encode(atx) require.Error(t, err) }) diff --git a/api/grpcserver/v2alpha1/transaction.go b/api/grpcserver/v2alpha1/transaction.go index 54f0ccb811..7ceed2c543 100644 --- a/api/grpcserver/v2alpha1/transaction.go +++ b/api/grpcserver/v2alpha1/transaction.go @@ -331,7 +331,7 @@ func toTx(tx *types.MeshTransaction, result *types.TransactionResult, Message: result.Message, GasConsumed: result.Gas, Fee: result.Fee, - Block: result.Block[:], + Block: result.Block.Bytes(), Layer: result.Layer.Uint32(), } if len(result.Addresses) > 0 { diff --git a/checkpoint/util.go b/checkpoint/util.go index f6f5332d7d..df61601b0b 100644 --- a/checkpoint/util.go +++ b/checkpoint/util.go @@ -186,8 +186,8 @@ func poetProofRefs(ctx context.Context, db sql.Executor, id types.ATXID) ([]type if err := codec.Decode(blob.Bytes, &atx); err != nil { return nil, fmt.Errorf("decoding ATX blob: %w", err) } - refs := make([]types.PoetProofRef, len(atx.NiPosts)) - for i, post := range atx.NiPosts { + refs := make([]types.PoetProofRef, len(atx.NIPosts)) + for i, post := range atx.NIPosts { refs[i] = types.PoetProofRef(post.Challenge) } return refs, nil diff --git a/genvm/core/context.go b/genvm/core/context.go index 35cedb6db7..e5f46a1e41 100644 --- a/genvm/core/context.go +++ b/genvm/core/context.go @@ -10,7 +10,7 @@ import ( ) // Context serves 2 purposes: -// - maintains changes to the system state, that will be applied only after succeful execution +// - maintains changes to the system state, that will be applied only after successful execution // - accumulates set of reusable objects and data. type Context struct { Registry HandlerRegistry @@ -36,7 +36,7 @@ type Context struct { consumed uint64 // fee is in coins units fee uint64 - // an amount transfrered to other accounts + // an amount transferred to other accounts transferred uint64 touched []Address diff --git a/go.mod b/go.mod index 93c8c90831..03135e2dbf 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/spacemeshos/economics v0.1.4 github.com/spacemeshos/fixed v0.1.2 github.com/spacemeshos/go-scale v1.2.1 - github.com/spacemeshos/merkle-tree v0.2.4 + github.com/spacemeshos/merkle-tree v0.2.5 github.com/spacemeshos/poet v0.10.4 github.com/spacemeshos/post v0.12.10 github.com/spf13/afero v1.11.0 diff --git a/go.sum b/go.sum index 4eca5bbcfc..21153d9a53 100644 --- a/go.sum +++ b/go.sum @@ -637,8 +637,8 @@ github.com/spacemeshos/fixed v0.1.2 h1:pENQ8pXFAqin3f15ZLoOVVeSgcmcFJ0IFdFm4+9u4 github.com/spacemeshos/fixed v0.1.2/go.mod h1:OekUZD7FA9Ji8H/WEf5VuGYxPB+mWfXjbUI7I3qcT48= github.com/spacemeshos/go-scale v1.2.1 h1:+IJ6KmFl9tF1Om8B1NvEwilGpBG1ebr4Se8A0Fe4puE= github.com/spacemeshos/go-scale v1.2.1/go.mod h1:fpO6tCoKdUmvF6o9zkUtq2erSOH5t4ik02Zwdm31qOs= -github.com/spacemeshos/merkle-tree v0.2.4 h1:kA7uRGadeyULOFlBxsWRwN0v1U2B4PD4OluR1l3d4nE= -github.com/spacemeshos/merkle-tree v0.2.4/go.mod h1:yTJd262m3pjnw8mH0j5vt3w0J+2mpM6g0xyc5Pr2zIw= +github.com/spacemeshos/merkle-tree v0.2.5 h1:4iWiW4SvDEBGYRUvFUjArHeTHjvOa52JQ/iLW6wBzUs= +github.com/spacemeshos/merkle-tree v0.2.5/go.mod h1:lxMuC/C2qhN6wdH6iSXW0HM8FS6fnKnyLWjCAKsCtr8= github.com/spacemeshos/poet v0.10.4 h1:MHGG1dhMVwy5DdlsmwdRLDQTTqlPA21lSQB2PVd8MSk= github.com/spacemeshos/poet v0.10.4/go.mod h1:hz21GMyHb9h29CqNhVeKxCD5dxZdQK27nAqLpT46gjE= github.com/spacemeshos/post v0.12.10 h1:S4THKvy/uGdNzoZkTI5qqIo2H8/W4xktKtYzxKsYNVU= diff --git a/signing/signer.go b/signing/signer.go index c07f8589a1..ccd02fcc90 100644 --- a/signing/signer.go +++ b/signing/signer.go @@ -43,6 +43,8 @@ func (d Domain) String() string { return "HARE" case POET: return "POET" + case MARRIAGE: + return "MARRIAGE" case BEACON_FIRST_MSG: return "BEACON_FIRST_MSG" case BEACON_FOLLOWUP_MSG: diff --git a/systest/Makefile b/systest/Makefile index 79c8e808c8..3add184f17 100644 --- a/systest/Makefile +++ b/systest/Makefile @@ -10,7 +10,7 @@ poet_image ?= $(org)/poet:v0.10.3 post_service_image ?= $(org)/post-service:v0.7.13 post_init_image ?= $(org)/postcli:v0.12.5 smesher_image ?= $(org)/go-spacemesh-dev:$(version_info) -old_smesher_image ?= $(org)/go-spacemesh-dev:v1.7.4 +old_smesher_image ?= $(org)/go-spacemesh-dev:e46c154 # Update this when new version is released bs_image ?= $(org)/go-spacemesh-dev-bs:$(version_info) test_id ?= systest-$(version_info)