Skip to content

Commit

Permalink
fix(dot/state): store raw authority keys and decode when verifying bl…
Browse files Browse the repository at this point in the history
…ock signature (#3627)

Co-authored-by: Timothy Wu <tim.wu@chainsafe.io>
  • Loading branch information
EclesioMeloJunior and timwu20 authored Dec 7, 2023
1 parent 1126ad4 commit 58f741d
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 241 deletions.
5 changes: 2 additions & 3 deletions dot/digest/digest_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func TestHandler_HandleNextEpochData(t *testing.T) {

handler.handleBlockFinalisation(ctx)

stored, err := handler.epochState.(*state.EpochState).GetEpochData(targetEpoch, nil)
stored, err := handler.epochState.(*state.EpochState).GetEpochDataRaw(targetEpoch, nil)
require.NoError(t, err)

digestValue, err := digest.Value()
Expand All @@ -326,8 +326,7 @@ func TestHandler_HandleNextEpochData(t *testing.T) {
t.Fatal()
}

res, err := act.ToEpochData()
require.NoError(t, err)
res := act.ToEpochDataRaw()
require.Equal(t, res, stored)
}

Expand Down
61 changes: 23 additions & 38 deletions dot/state/epoch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,12 @@ func NewEpochStateFromGenesis(db database.Database, blockState *BlockState,
nextConfigData: make(nextEpochMap[types.NextConfigDataV1]),
}

auths, err := types.BABEAuthorityRawToAuthority(genesisConfig.GenesisAuthorities)
if err != nil {
return nil, err
epochDataRaw := &types.EpochDataRaw{
Authorities: genesisConfig.GenesisAuthorities,
Randomness: genesisConfig.Randomness,
}

err = s.SetEpochData(0, &types.EpochData{
Authorities: auths,
Randomness: genesisConfig.Randomness,
})
err = s.SetEpochDataRaw(0, epochDataRaw)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -235,10 +232,8 @@ func (s *EpochState) GetEpochForBlock(header *types.Header) (uint64, error) {
return 0, errNoPreRuntimeDigest
}

// SetEpochData sets the epoch data for a given epoch
func (s *EpochState) SetEpochData(epoch uint64, info *types.EpochData) error {
raw := info.ToEpochDataRaw()

// SetEpochDataRaw sets the epoch data raw for a given epoch
func (s *EpochState) SetEpochDataRaw(epoch uint64, raw *types.EpochDataRaw) error {
enc, err := scale.Marshal(*raw)
if err != nil {
return err
Expand All @@ -247,17 +242,17 @@ func (s *EpochState) SetEpochData(epoch uint64, info *types.EpochData) error {
return s.db.Put(epochDataKey(epoch), enc)
}

// GetEpochData returns the epoch data for a given epoch persisted in database
// GetEpochDataRaw returns the raw epoch data for a given epoch persisted in database
// otherwise will try to get the data from the in-memory map using the header
// if the header params is nil then it will search only in database
func (s *EpochState) GetEpochData(epoch uint64, header *types.Header) (*types.EpochData, error) {
epochData, err := s.getEpochDataFromDatabase(epoch)
func (s *EpochState) GetEpochDataRaw(epoch uint64, header *types.Header) (*types.EpochDataRaw, error) {
epochDataRaw, err := s.getEpochDataRawFromDatabase(epoch)
if err != nil && !errors.Is(err, database.ErrNotFound) {
return nil, fmt.Errorf("failed to retrieve epoch data from database: %w", err)
}

if epochData != nil {
return epochData, nil
if epochDataRaw != nil {
return epochDataRaw, nil
}

if header == nil {
Expand All @@ -272,38 +267,33 @@ func (s *EpochState) GetEpochData(epoch uint64, header *types.Header) (*types.Ep
return nil, fmt.Errorf("failed to get epoch data from memory: %w", err)
}

epochData, err = inMemoryEpochData.ToEpochData()
if err != nil {
return nil, fmt.Errorf("cannot transform into epoch data: %w", err)
}

return epochData, nil
return inMemoryEpochData.ToEpochDataRaw(), nil
}

// getEpochDataFromDatabase returns the epoch data for a given epoch persisted in database
func (s *EpochState) getEpochDataFromDatabase(epoch uint64) (*types.EpochData, error) {
// getEpochDataRawFromDatabase returns the epoch data for a given epoch persisted in database
func (s *EpochState) getEpochDataRawFromDatabase(epoch uint64) (*types.EpochDataRaw, error) {
enc, err := s.db.Get(epochDataKey(epoch))
if err != nil {
return nil, err
}

raw := &types.EpochDataRaw{}
raw := new(types.EpochDataRaw)
err = scale.Unmarshal(enc, raw)
if err != nil {
return nil, err
return nil, fmt.Errorf("unmarshaling into epoch data raw: %w", err)
}

return raw.ToEpochData()
return raw, nil
}

// GetLatestEpochData returns the EpochData for the current epoch
func (s *EpochState) GetLatestEpochData() (*types.EpochData, error) {
// GetLatestEpochDataRaw returns the EpochData for the current epoch
func (s *EpochState) GetLatestEpochDataRaw() (*types.EpochDataRaw, error) {
curr, err := s.GetCurrentEpoch()
if err != nil {
return nil, err
}

return s.GetEpochData(curr, nil)
return s.GetEpochDataRaw(curr, nil)
}

// SetConfigData sets the BABE config data for a given epoch
Expand Down Expand Up @@ -586,7 +576,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
nextEpoch = finalizedBlockEpoch + 1
}

epochInDatabase, err := s.getEpochDataFromDatabase(nextEpoch)
epochRawInDatabase, err := s.getEpochDataRawFromDatabase(nextEpoch)

// if an error occurs and the error is database.ErrNotFound we ignore
// since this error is what we will handle in the next lines
Expand All @@ -595,7 +585,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
}

// epoch data already defined we don't need to lookup in the map
if epochInDatabase != nil {
if epochRawInDatabase != nil {
return nil
}

Expand All @@ -604,12 +594,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
return fmt.Errorf("cannot find next epoch data: %w", err)
}

ed, err := finalizedNextEpochData.ToEpochData()
if err != nil {
return fmt.Errorf("cannot transform epoch data: %w", err)
}

err = s.SetEpochData(nextEpoch, ed)
err = s.SetEpochDataRaw(nextEpoch, finalizedNextEpochData.ToEpochDataRaw())
if err != nil {
return fmt.Errorf("cannot set epoch data: %w", err)
}
Expand Down
32 changes: 13 additions & 19 deletions dot/state/epoch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,46 +58,42 @@ func TestEpochState_EpochData(t *testing.T) {
keyring, err := keystore.NewSr25519Keyring()
require.NoError(t, err)

auth := types.Authority{
Key: keyring.Alice().Public().(*sr25519.PublicKey),
auth := types.AuthorityRaw{
Key: keyring.Alice().Public().(*sr25519.PublicKey).AsBytes(),
Weight: 1,
}

info := &types.EpochData{
Authorities: []types.Authority{auth},
info := &types.EpochDataRaw{
Authorities: []types.AuthorityRaw{auth},
Randomness: [32]byte{77},
}

err = s.SetEpochData(1, info)
err = s.SetEpochDataRaw(1, info)
require.NoError(t, err)
res, err := s.GetEpochData(1, nil)
res, err := s.GetEpochDataRaw(1, nil)
require.NoError(t, err)
require.Equal(t, info.Randomness, res.Randomness)

for i, auth := range res.Authorities {
expected, err := info.Authorities[i].Encode()
require.NoError(t, err)
res, err := auth.Encode()
require.NoError(t, err)
require.Equal(t, expected, res)
require.Equal(t, info.Authorities[i], auth)
}
}

func TestEpochState_GetStartSlotForEpoch(t *testing.T) {
s := newEpochStateFromGenesis(t)

info := &types.EpochData{
info := &types.EpochDataRaw{
Randomness: [32]byte{77},
}

err := s.SetEpochData(2, info)
err := s.SetEpochDataRaw(2, info)
require.NoError(t, err)

info = &types.EpochData{
info = &types.EpochDataRaw{
Randomness: [32]byte{77},
}

err = s.SetEpochData(3, info)
err = s.SetEpochDataRaw(3, info)
require.NoError(t, err)

start, err := s.GetStartSlotForEpoch(0)
Expand Down Expand Up @@ -405,10 +401,8 @@ func TestStoreAndFinalizeBabeNextEpochData(t *testing.T) {
} else {
require.NoError(t, err)

expected, err := expectedNextEpochData.ToEpochData()
require.NoError(t, err)

gotNextEpochData, err := epochState.GetEpochData(tt.finalizeEpoch, nil)
expected := expectedNextEpochData.ToEpochDataRaw()
gotNextEpochData, err := epochState.GetEpochDataRaw(tt.finalizeEpoch, nil)
require.NoError(t, err)

require.Equal(t, expected, gotNextEpochData)
Expand Down
13 changes: 4 additions & 9 deletions dot/types/consensus_digest.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,11 @@ func (d NextEpochData) String() string { //skipcq: GO-W1029
}

// ToEpochData returns the NextEpochData as EpochData
func (d *NextEpochData) ToEpochData() (*EpochData, error) { //skipcq: GO-W1029
auths, err := BABEAuthorityRawToAuthority(d.Authorities)
if err != nil {
return nil, err
}

return &EpochData{
Authorities: auths,
func (d *NextEpochData) ToEpochDataRaw() *EpochDataRaw {
return &EpochDataRaw{
Authorities: d.Authorities,
Randomness: d.Randomness,
}, nil
}
}

// BABEOnDisabled represents a GRANDPA authority being disabled
Expand Down
12 changes: 4 additions & 8 deletions lib/babe/babe.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,28 +258,24 @@ func (b *Service) Stop() error {
}

// Authorities returns the current BABE authorities
func (b *Service) Authorities() []types.Authority {
auths := make([]types.Authority, len(b.epochHandler.epochData.authorities))
for i, auth := range b.epochHandler.epochData.authorities {
auths[i] = *auth.DeepCopy()
}
return auths
func (b *Service) AuthoritiesRaw() []types.AuthorityRaw {
return b.epochHandler.epochData.authorities
}

// IsStopped returns true if the service is stopped (ie not producing blocks)
func (b *Service) IsStopped() bool {
return b.ctx.Err() != nil
}

func (b *Service) getAuthorityIndex(Authorities []types.Authority) (uint32, error) {
func (b *Service) getAuthorityIndex(Authorities []types.AuthorityRaw) (uint32, error) {
if !b.authority {
return 0, ErrNotAuthority
}

pub := b.keypair.Public()

for i, auth := range Authorities {
if bytes.Equal(pub.Encode(), auth.Key.Encode()) {
if bytes.Equal(pub.Encode(), auth.Key[:]) {
return uint32(i), nil
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/babe/babe_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ func TestService_GetAuthorityIndex(t *testing.T) {
pubA := kpA.Public().(*sr25519.PublicKey)
pubB := kpB.Public().(*sr25519.PublicKey)

authData := []types.Authority{
{Key: pubA, Weight: 1},
{Key: pubB, Weight: 1},
authData := []types.AuthorityRaw{
{Key: pubA.AsBytes(), Weight: 1},
{Key: pubB.AsBytes(), Weight: 1},
}

bs := &Service{
Expand Down
6 changes: 3 additions & 3 deletions lib/babe/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func checkPrimaryThreshold(randomness Randomness,

func claimSecondarySlotVRF(randomness Randomness,
slot, epoch uint64,
authorities []types.Authority,
authorities []types.AuthorityRaw,
keypair *sr25519.Keypair,
authorityIndex uint32,
) (*VrfOutputAndProof, error) {
Expand Down Expand Up @@ -123,8 +123,8 @@ func claimSecondarySlotVRF(randomness Randomness,
}, nil
}

func claimSecondarySlotPlain(randomness Randomness, slot uint64, authorities []types.Authority, authorityIndex uint32,
) error {
func claimSecondarySlotPlain(randomness Randomness, slot uint64,
authorities []types.AuthorityRaw, authorityIndex uint32) error {
secondarySlotAuthor, err := getSecondarySlotAuthor(slot, len(authorities), randomness)
if err != nil {
return fmt.Errorf("cannot get secondary slot author: %w", err)
Expand Down
8 changes: 4 additions & 4 deletions lib/babe/epoch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *Service) getEpochData(epoch uint64, bestBlock *types.Header) (*epochDat
return epochData, nil
}

currEpochData, err := b.epochState.GetEpochData(epoch, bestBlock)
currEpochData, err := b.epochState.GetEpochDataRaw(epoch, bestBlock)
if err != nil {
return nil, fmt.Errorf("cannot get epoch data for epoch %d: %w", epoch, err)
}
Expand Down Expand Up @@ -127,13 +127,13 @@ func (b *Service) getEpochData(epoch uint64, bestBlock *types.Header) (*epochDat
func (b *Service) getLatestEpochData() (resEpochData *epochData, error error) {
resEpochData = &epochData{}

epochData, err := b.epochState.GetLatestEpochData()
epochDataRaw, err := b.epochState.GetLatestEpochDataRaw()
if err != nil {
return nil, fmt.Errorf("cannot get latest epoch data: %w", err)
}

resEpochData.randomness = epochData.Randomness
resEpochData.authorities = epochData.Authorities
resEpochData.randomness = epochDataRaw.Randomness
resEpochData.authorities = epochDataRaw.Authorities

configData, err := b.epochState.GetLatestConfigData()
if err != nil {
Expand Down
14 changes: 10 additions & 4 deletions lib/babe/epoch_handler_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ func TestEpochHandler_run_shouldReturnAfterContextCancel(t *testing.T) {
epochData := &epochData{
threshold: scale.MaxUint128,
authorityIndex: authorityIndex,
authorities: []types.Authority{
*types.NewAuthority(aliceKeyPair.Public(), 1),
authorities: []types.AuthorityRaw{
{
Key: [32]byte(aliceKeyPair.Public().Encode()),
Weight: 1,
},
},
}

Expand Down Expand Up @@ -66,8 +69,11 @@ func TestEpochHandler_run(t *testing.T) {
epochData := &epochData{
threshold: scale.MaxUint128,
authorityIndex: authorityIndex,
authorities: []types.Authority{
*types.NewAuthority(aliceKeyPair.Public(), 1),
authorities: []types.AuthorityRaw{
{
Key: [32]byte(aliceKeyPair.Public().Encode()),
Weight: 1,
},
},
}

Expand Down
Loading

0 comments on commit 58f741d

Please sign in to comment.