Skip to content

Commit

Permalink
EIP-7549: Attestation packing (#14238)
Browse files Browse the repository at this point in the history
* EIP-7549: Attestation packing

* new files

* change var name

* test fixes

* enhance comment

* unit test for Deneb state
  • Loading branch information
rkapka authored Jul 19, 2024
1 parent 57ffc12 commit 49055ac
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 9 deletions.
6 changes: 3 additions & 3 deletions beacon-chain/operations/attestations/prepare_forkchoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s *Service) batchForkChoiceAtts(ctx context.Context) error {
atts := append(s.cfg.Pool.AggregatedAttestations(), s.cfg.Pool.BlockAttestations()...)
atts = append(atts, s.cfg.Pool.ForkchoiceAttestations()...)

attsByVerAndDataRoot := make(map[attestation.Id][]ethpb.Att, len(atts))
attsById := make(map[attestation.Id][]ethpb.Att, len(atts))

// Consolidate attestations by aggregating them by similar data root.
for _, att := range atts {
Expand All @@ -83,10 +83,10 @@ func (s *Service) batchForkChoiceAtts(ctx context.Context) error {
if err != nil {
return errors.Wrap(err, "could not create attestation ID")
}
attsByVerAndDataRoot[id] = append(attsByVerAndDataRoot[id], att)
attsById[id] = append(attsById[id], att)
}

for _, atts := range attsByVerAndDataRoot {
for _, atts := range attsById {
if err := s.aggregateAndSaveForkChoiceAtts(atts); err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions beacon-chain/rpc/prysm/v1alpha1/validator/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ go_library(
"proposer.go",
"proposer_altair.go",
"proposer_attestations.go",
"proposer_attestations_electra.go",
"proposer_bellatrix.go",
"proposer_builder.go",
"proposer_capella.go",
Expand Down Expand Up @@ -147,6 +148,7 @@ common_deps = [
"//consensus-types/primitives:go_default_library",
"//container/trie:go_default_library",
"//crypto/bls:go_default_library",
"//crypto/bls/blst:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//proto/engine/v1:go_default_library",
Expand Down Expand Up @@ -186,6 +188,7 @@ go_test(
"duties_test.go",
"exit_test.go",
"proposer_altair_test.go",
"proposer_attestations_electra_test.go",
"proposer_attestations_test.go",
"proposer_bellatrix_test.go",
"proposer_builder_test.go",
Expand Down
32 changes: 27 additions & 5 deletions beacon-chain/rpc/prysm/v1alpha1/validator/proposer_attestations.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ func (vs *Server) packAttestations(ctx context.Context, latestState state.Beacon
}
atts = append(atts, uAtts...)

// Checking the state's version here will give the wrong result if the last slot of Deneb is missed.
// The head state will still be in Deneb while we are trying to build an Electra block.
postElectra := slots.ToEpoch(blkSlot) >= params.BeaconConfig().ElectraForkEpoch

versionAtts := make([]ethpb.Att, 0, len(atts))
Expand All @@ -66,23 +68,43 @@ func (vs *Server) packAttestations(ctx context.Context, latestState state.Beacon
return nil, err
}

attsByDataRoot := make(map[attestation.Id][]ethpb.Att, len(versionAtts))
attsById := make(map[attestation.Id][]ethpb.Att, len(versionAtts))
for _, att := range versionAtts {
id, err := attestation.NewId(att, attestation.Data)
if err != nil {
return nil, errors.Wrap(err, "could not create attestation ID")
}
attsByDataRoot[id] = append(attsByDataRoot[id], att)
attsById[id] = append(attsById[id], att)
}

attsForInclusion := proposerAtts(make([]ethpb.Att, 0))
for _, as := range attsByDataRoot {
for id, as := range attsById {
as, err := attaggregation.Aggregate(as)
if err != nil {
return nil, err
}
attsForInclusion = append(attsForInclusion, as...)
attsById[id] = as
}

var attsForInclusion proposerAtts
if postElectra {
// TODO: hack for Electra devnet-1, take only one aggregate per ID
// (which essentially means one aggregate for an attestation_data+committee combination
topAggregates := make([]ethpb.Att, 0)
for _, v := range attsById {
topAggregates = append(topAggregates, v[0])
}

attsForInclusion, err = computeOnChainAggregate(topAggregates)
if err != nil {
return nil, err
}
} else {
attsForInclusion = make([]ethpb.Att, 0)
for _, as := range attsById {
attsForInclusion = append(attsForInclusion, as...)
}
}

deduped, err := attsForInclusion.dedup()
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package validator

import (
"slices"

"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/v5/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/v5/consensus-types/primitives"
"github.com/prysmaticlabs/prysm/v5/crypto/bls"
ethpb "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1"
)

// computeOnChainAggregate constructs a final aggregate form a list of network aggregates with equal attestation data.
// It assumes that each network aggregate has exactly one committee bit set.
//
// Spec definition:
//
// def compute_on_chain_aggregate(network_aggregates: Sequence[Attestation]) -> Attestation:
// aggregates = sorted(network_aggregates, key=lambda a: get_committee_indices(a.committee_bits)[0])
//
// data = aggregates[0].data
// aggregation_bits = Bitlist[MAX_VALIDATORS_PER_COMMITTEE * MAX_COMMITTEES_PER_SLOT]()
// for a in aggregates:
// for b in a.aggregation_bits:
// aggregation_bits.append(b)
//
// signature = bls.Aggregate([a.signature for a in aggregates])
//
// committee_indices = [get_committee_indices(a.committee_bits)[0] for a in aggregates]
// committee_flags = [(index in committee_indices) for index in range(0, MAX_COMMITTEES_PER_SLOT)]
// committee_bits = Bitvector[MAX_COMMITTEES_PER_SLOT](committee_flags)
//
// return Attestation(
// aggregation_bits=aggregation_bits,
// data=data,
// committee_bits=committee_bits,
// signature=signature,
// )
func computeOnChainAggregate(aggregates []ethpb.Att) ([]ethpb.Att, error) {
aggsByDataRoot := make(map[[32]byte][]ethpb.Att)
for _, agg := range aggregates {
key, err := agg.GetData().HashTreeRoot()
if err != nil {
return nil, err
}
existing, ok := aggsByDataRoot[key]
if ok {
aggsByDataRoot[key] = append(existing, agg)
} else {
aggsByDataRoot[key] = []ethpb.Att{agg}
}
}

result := make([]ethpb.Att, 0)

for _, aggs := range aggsByDataRoot {
slices.SortFunc(aggs, func(a, b ethpb.Att) int {
return a.CommitteeBitsVal().BitIndices()[0] - b.CommitteeBitsVal().BitIndices()[0]
})

sigs := make([]bls.Signature, len(aggs))
committeeIndices := make([]primitives.CommitteeIndex, len(aggs))
aggBitsIndices := make([]uint64, 0)
aggBitsOffset := uint64(0)
var err error
for i, a := range aggs {
for _, bi := range a.GetAggregationBits().BitIndices() {
aggBitsIndices = append(aggBitsIndices, uint64(bi)+aggBitsOffset)
}
sigs[i], err = bls.SignatureFromBytes(a.GetSignature())
if err != nil {
return nil, err
}
committeeIndices[i] = helpers.CommitteeIndices(a.CommitteeBitsVal())[0]

aggBitsOffset += a.GetAggregationBits().Len()
}

aggregationBits := bitfield.NewBitlist(aggBitsOffset)
for _, bi := range aggBitsIndices {
aggregationBits.SetBitAt(bi, true)
}

cb := primitives.NewAttestationCommitteeBits()
att := &ethpb.AttestationElectra{
AggregationBits: aggregationBits,
Data: aggs[0].GetData(),
CommitteeBits: cb,
Signature: bls.AggregateSignatures(sigs).Marshal(),
}
for _, ci := range committeeIndices {
att.CommitteeBits.SetBitAt(uint64(ci), true)
}
result = append(result, att)
}

return result, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package validator

import (
"reflect"
"testing"

"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/v5/config/params"
"github.com/prysmaticlabs/prysm/v5/consensus-types/primitives"
"github.com/prysmaticlabs/prysm/v5/crypto/bls/blst"
"github.com/prysmaticlabs/prysm/v5/encoding/bytesutil"
ethpb "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/v5/testing/assert"
"github.com/prysmaticlabs/prysm/v5/testing/require"
)

func Test_computeOnChainAggregate(t *testing.T) {
params.SetupTestConfigCleanup(t)
cfg := params.MainnetConfig().Copy()
cfg.MaxCommitteesPerSlot = 64
params.OverrideBeaconConfig(cfg)

key, err := blst.RandKey()
require.NoError(t, err)
sig := key.Sign([]byte{'X'})

data1 := &ethpb.AttestationData{
Slot: 123,
CommitteeIndex: 123,
BeaconBlockRoot: bytesutil.PadTo([]byte("root"), 32),
Source: &ethpb.Checkpoint{
Epoch: 123,
Root: bytesutil.PadTo([]byte("root"), 32),
},
Target: &ethpb.Checkpoint{
Epoch: 123,
Root: bytesutil.PadTo([]byte("root"), 32),
},
}
data2 := &ethpb.AttestationData{
Slot: 456,
CommitteeIndex: 456,
BeaconBlockRoot: bytesutil.PadTo([]byte("root"), 32),
Source: &ethpb.Checkpoint{
Epoch: 456,
Root: bytesutil.PadTo([]byte("root"), 32),
},
Target: &ethpb.Checkpoint{
Epoch: 456,
Root: bytesutil.PadTo([]byte("root"), 32),
},
}

t.Run("single aggregate", func(t *testing.T) {
cb := primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
att := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00011111},
Data: data1,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
result, err := computeOnChainAggregate([]ethpb.Att{att})
require.NoError(t, err)
require.Equal(t, 1, len(result))
assert.DeepEqual(t, att.AggregationBits, result[0].GetAggregationBits())
assert.DeepEqual(t, att.Data, result[0].GetData())
assert.DeepEqual(t, att.CommitteeBits, result[0].CommitteeBitsVal())
})
t.Run("all aggregates for one root", func(t *testing.T) {
cb := primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
att1 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00010011}, // aggregation bits 0,1
Data: data1,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(1, true)
att2 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00010011}, // aggregation bits 0,1
Data: data1,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
result, err := computeOnChainAggregate([]ethpb.Att{att1, att2})
require.NoError(t, err)
require.Equal(t, 1, len(result))
assert.DeepEqual(t, bitfield.Bitlist{0b00110011, 0b00000001}, result[0].GetAggregationBits())
assert.DeepEqual(t, data1, result[0].GetData())
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
cb.SetBitAt(1, true)
assert.DeepEqual(t, cb, result[0].CommitteeBitsVal())
})
t.Run("aggregates for multiple roots", func(t *testing.T) {
cb := primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
att1 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00010011}, // aggregation bits 0,1
Data: data1,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(1, true)
att2 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00010011}, // aggregation bits 0,1
Data: data1,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
att3 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00011001}, // aggregation bits 0,3
Data: data2,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(1, true)
att4 := &ethpb.AttestationElectra{
AggregationBits: bitfield.Bitlist{0b00010010}, // aggregation bits 1
Data: data2,
CommitteeBits: cb,
Signature: sig.Marshal(),
}
result, err := computeOnChainAggregate([]ethpb.Att{att1, att2, att3, att4})
require.NoError(t, err)
require.Equal(t, 2, len(result))
cb = primitives.NewAttestationCommitteeBits()
cb.SetBitAt(0, true)
cb.SetBitAt(1, true)

expectedAggBits := bitfield.Bitlist{0b00110011, 0b00000001}
expectedData := data1
found := false
for _, a := range result {
if reflect.DeepEqual(expectedAggBits, a.GetAggregationBits()) && reflect.DeepEqual(expectedData, a.GetData()) && reflect.DeepEqual(cb, a.CommitteeBitsVal()) {
found = true
break
}
}
if !found {
t.Error("Expected aggregate not found")
}

expectedAggBits = bitfield.Bitlist{0b00101001, 0b00000001}
expectedData = data2
found = false
for _, a := range result {
if reflect.DeepEqual(expectedAggBits, a.GetAggregationBits()) && reflect.DeepEqual(expectedData, a.GetData()) && reflect.DeepEqual(cb, a.CommitteeBitsVal()) {
found = true
break
}
}
if !found {
t.Error("Expected aggregate not found")
}
})
}
Loading

0 comments on commit 49055ac

Please sign in to comment.