diff --git a/beacon-chain/core/helpers/BUILD.bazel b/beacon-chain/core/helpers/BUILD.bazel index c33980c260f3..5a30cfac9c0a 100644 --- a/beacon-chain/core/helpers/BUILD.bazel +++ b/beacon-chain/core/helpers/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "block.go", "genesis.go", "metrics.go", + "payload_attestation.go", "randao.go", "rewards_penalties.go", "shuffle.go", @@ -53,6 +54,8 @@ go_test( "attestation_test.go", "beacon_committee_test.go", "block_test.go", + "exports_test.go", + "payload_attestation_test.go", "private_access_fuzz_noop_test.go", # keep "private_access_test.go", "randao_test.go", @@ -83,6 +86,7 @@ go_test( "//testing/assert:go_default_library", "//testing/require:go_default_library", "//testing/util:go_default_library", + "//testing/util/random:go_default_library", "//time:go_default_library", "//time/slots:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", diff --git a/beacon-chain/core/helpers/exports_test.go b/beacon-chain/core/helpers/exports_test.go new file mode 100644 index 000000000000..0ec103bbd6a0 --- /dev/null +++ b/beacon-chain/core/helpers/exports_test.go @@ -0,0 +1,12 @@ +package helpers + +var ( + ErrNilMessage = errNilMessage + ErrNilData = errNilData + ErrNilBeaconBlockRoot = errNilBeaconBlockRoot + ErrNilPayloadAttestation = errNilPayloadAttestation + ErrNilSignature = errNilSignature + ErrNilAggregationBits = errNilAggregationBits + ErrPreEPBSState = errPreEPBSState + ErrCommitteeOverflow = errCommitteeOverflow +) diff --git a/beacon-chain/core/helpers/payload_attestation.go b/beacon-chain/core/helpers/payload_attestation.go new file mode 100644 index 000000000000..f91003cf4af9 --- /dev/null +++ b/beacon-chain/core/helpers/payload_attestation.go @@ -0,0 +1,91 @@ +package helpers + +import ( + "context" + + "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/v5/beacon-chain/state" + fieldparams "github.com/prysmaticlabs/prysm/v5/config/fieldparams" + "github.com/prysmaticlabs/prysm/v5/consensus-types/primitives" + "github.com/prysmaticlabs/prysm/v5/math" + eth "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/v5/runtime/version" + "github.com/prysmaticlabs/prysm/v5/time/slots" +) + +var ( + errNilMessage = errors.New("nil PayloadAttestationMessage") + errNilData = errors.New("nil PayloadAttestationData") + errNilBeaconBlockRoot = errors.New("nil BeaconBlockRoot") + errNilPayloadAttestation = errors.New("nil PayloadAttestation") + errNilSignature = errors.New("nil Signature") + errNilAggregationBits = errors.New("nil AggregationBits") + errPreEPBSState = errors.New("beacon state pre ePBS fork") + errCommitteeOverflow = errors.New("beacon committee of insufficient size") +) + +// ValidateNilPayloadAttestationData checks if any composite field of the +// payload attestation data is nil +func ValidateNilPayloadAttestationData(data *eth.PayloadAttestationData) error { + if data == nil { + return errNilData + } + if data.BeaconBlockRoot == nil { + return errNilBeaconBlockRoot + } + return nil +} + +// ValidateNilPayloadAttestationMessage checks if any composite field of the +// payload attestation message is nil +func ValidateNilPayloadAttestationMessage(att *eth.PayloadAttestationMessage) error { + if att == nil { + return errNilMessage + } + if att.Signature == nil { + return errNilSignature + } + return ValidateNilPayloadAttestationData(att.Data) +} + +// ValidateNilPayloadAttestation checks if any composite field of the +// payload attestation is nil +func ValidateNilPayloadAttestation(att *eth.PayloadAttestation) error { + if att == nil { + return errNilPayloadAttestation + } + if att.AggregationBits == nil { + return errNilAggregationBits + } + if att.Signature == nil { + return errNilSignature + } + return ValidateNilPayloadAttestationData(att.Data) +} + +// GetPayloadTimelinessCommittee returns the PTC for the given slot, computed from the passed state as in the +// spec function `get_ptc`. +func GetPayloadTimelinessCommittee(ctx context.Context, state state.ReadOnlyBeaconState, slot primitives.Slot) (indices []primitives.ValidatorIndex, err error) { + if state.Version() < version.EPBS { + return nil, errPreEPBSState + } + epoch := slots.ToEpoch(slot) + activeCount, err := ActiveValidatorCount(ctx, state, epoch) + if err != nil { + return nil, errors.Wrap(err, "could not compute active validator count") + } + committeesPerSlot := math.LargestPowerOfTwo(math.Min(SlotCommitteeCount(activeCount), fieldparams.PTCSize)) + membersPerCommittee := fieldparams.PTCSize / committeesPerSlot + for i := uint64(0); i <= committeesPerSlot; i++ { + committee, err := BeaconCommitteeFromState(ctx, state, slot, primitives.CommitteeIndex(i)) + if err != nil { + return nil, err + } + if uint64(len(committee)) < membersPerCommittee { + return nil, errCommitteeOverflow + } + start := uint64(len(committee)) - membersPerCommittee + indices = append(indices, committee[start:]...) + } + return +} diff --git a/beacon-chain/core/helpers/payload_attestation_test.go b/beacon-chain/core/helpers/payload_attestation_test.go new file mode 100644 index 000000000000..5b9abda5817b --- /dev/null +++ b/beacon-chain/core/helpers/payload_attestation_test.go @@ -0,0 +1,92 @@ +package helpers_test + +import ( + "context" + "strconv" + "testing" + + "github.com/prysmaticlabs/go-bitfield" + "github.com/prysmaticlabs/prysm/v5/beacon-chain/core/helpers" + state_native "github.com/prysmaticlabs/prysm/v5/beacon-chain/state/state-native" + fieldparams "github.com/prysmaticlabs/prysm/v5/config/fieldparams" + "github.com/prysmaticlabs/prysm/v5/config/params" + "github.com/prysmaticlabs/prysm/v5/math" + eth "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" + ethpb "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/v5/testing/require" + "github.com/prysmaticlabs/prysm/v5/testing/util/random" + "github.com/prysmaticlabs/prysm/v5/time/slots" +) + +func TestValidateNilPayloadAttestation(t *testing.T) { + require.ErrorIs(t, helpers.ErrNilData, helpers.ValidateNilPayloadAttestationData(nil)) + data := ð.PayloadAttestationData{} + require.ErrorIs(t, helpers.ErrNilBeaconBlockRoot, helpers.ValidateNilPayloadAttestationData(data)) + data.BeaconBlockRoot = make([]byte, 32) + require.NoError(t, helpers.ValidateNilPayloadAttestationData(data)) + + require.ErrorIs(t, helpers.ErrNilMessage, helpers.ValidateNilPayloadAttestationMessage(nil)) + message := ð.PayloadAttestationMessage{} + require.ErrorIs(t, helpers.ErrNilSignature, helpers.ValidateNilPayloadAttestationMessage(message)) + message.Signature = make([]byte, 96) + require.ErrorIs(t, helpers.ErrNilData, helpers.ValidateNilPayloadAttestationMessage(message)) + message.Data = data + require.NoError(t, helpers.ValidateNilPayloadAttestationMessage(message)) + + require.ErrorIs(t, helpers.ErrNilPayloadAttestation, helpers.ValidateNilPayloadAttestation(nil)) + att := ð.PayloadAttestation{} + require.ErrorIs(t, helpers.ErrNilAggregationBits, helpers.ValidateNilPayloadAttestation(att)) + att.AggregationBits = bitfield.NewBitvector512() + require.ErrorIs(t, helpers.ErrNilSignature, helpers.ValidateNilPayloadAttestation(att)) + att.Signature = message.Signature + require.ErrorIs(t, helpers.ErrNilData, helpers.ValidateNilPayloadAttestation(att)) + att.Data = data + require.NoError(t, helpers.ValidateNilPayloadAttestation(att)) +} + +func TestGetPayloadTimelinessCommittee(t *testing.T) { + helpers.ClearCache() + + // Create 10 committees + committeeCount := uint64(10) + validatorCount := committeeCount * params.BeaconConfig().TargetCommitteeSize * uint64(params.BeaconConfig().SlotsPerEpoch) + validators := make([]*ethpb.Validator, validatorCount) + + for i := 0; i < len(validators); i++ { + k := make([]byte, 48) + copy(k, strconv.Itoa(i)) + validators[i] = ðpb.Validator{ + PublicKey: k, + WithdrawalCredentials: make([]byte, 32), + ExitEpoch: params.BeaconConfig().FarFutureEpoch, + } + } + + state, err := state_native.InitializeFromProtoEpbs(random.BeaconState(t)) + require.NoError(t, err) + require.NoError(t, state.SetValidators(validators)) + require.NoError(t, state.SetSlot(200)) + + ctx := context.Background() + indices, err := helpers.BeaconCommitteeFromState(ctx, state, state.Slot(), 1) + require.NoError(t, err) + require.Equal(t, 128, len(indices)) + + epoch := slots.ToEpoch(state.Slot()) + activeCount, err := helpers.ActiveValidatorCount(ctx, state, epoch) + require.NoError(t, err) + require.Equal(t, uint64(40960), activeCount) + + computedCommitteeCount := helpers.SlotCommitteeCount(activeCount) + require.Equal(t, committeeCount, computedCommitteeCount) + committeesPerSlot := math.LargestPowerOfTwo(math.Min(committeeCount, fieldparams.PTCSize)) + require.Equal(t, uint64(8), committeesPerSlot) + + ptc, err := helpers.GetPayloadTimelinessCommittee(ctx, state, state.Slot()) + require.NoError(t, err) + + committee1, err := helpers.BeaconCommitteeFromState(ctx, state, state.Slot(), 0) + require.NoError(t, err) + + require.DeepEqual(t, committee1[len(committee1)-64:], ptc[:64]) +} diff --git a/math/math_helper.go b/math/math_helper.go index e4a59e21d6f0..9e9555159cb9 100644 --- a/math/math_helper.go +++ b/math/math_helper.go @@ -116,6 +116,21 @@ func PowerOf2(n uint64) uint64 { return 1 << n } +// LargestPowerOfTwo returns the largest power of 2 that is lower or equal than +// the parameter +func LargestPowerOfTwo(n uint64) uint64 { + if n == 0 { + return 0 + } + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + return n - (n >> 1) +} + // Max returns the larger integer of the two // given ones.This is used over the Max function // in the standard math library because that max function diff --git a/math/math_helper_test.go b/math/math_helper_test.go index 8ee06bbd2ed2..efe2b0f40fd8 100644 --- a/math/math_helper_test.go +++ b/math/math_helper_test.go @@ -549,3 +549,27 @@ func TestAddInt(t *testing.T) { }) } } + +func TestLargestPowerOfTwo(t *testing.T) { + testCases := []struct { + name string + input uint64 + expected uint64 + }{ + {"Zero", 0, 0}, + {"One", 1, 1}, + {"Just below power of two", 14, 8}, + {"Power of two", 16, 16}, + {"Large number", 123456789, 67108864}, + {"Max uint64", 18446744073709551615, 9223372036854775808}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := math.LargestPowerOfTwo(tc.input) + if result != tc.expected { + t.Errorf("For input %d, expected %d but got %d", tc.input, tc.expected, result) + } + }) + } +}