Skip to content

Commit

Permalink
Add ComputeShuffledIndex algorithm. (#6267)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelneuder authored Dec 10, 2022
1 parent b5a7faa commit f512c88
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
52 changes: 52 additions & 0 deletions cmd/erigon-cl/core/transition/beacon_state_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package transition

import (
"crypto/sha256"
"encoding/binary"
"fmt"
)

const SHUFFLE_ROUND_COUNT = uint8(90)

func ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) (uint64, error) {
if ind >= ind_count {
return 0, fmt.Errorf("index=%d must be less than the index count=%d", ind, ind_count)
}

for i := uint8(0); i < SHUFFLE_ROUND_COUNT; i++ {
// Construct first hash input.
input := append(seed[:], i)
hash := sha256.New()
hash.Write(input)

// Read hash value.
hashValue := binary.LittleEndian.Uint64(hash.Sum(nil)[:8])

// Caclulate pivot and flip.
pivot := hashValue % ind_count
flip := (pivot + ind_count - ind) % ind_count

// No uint64 max function in go standard library.
position := ind
if flip > ind {
position = flip
}

// Construct the second hash input.
positionByteArray := make([]byte, 4)
binary.LittleEndian.PutUint32(positionByteArray, uint32(position>>8))
input2 := append(seed[:], i)
input2 = append(input2, positionByteArray...)

hash.Reset()
hash.Write(input2)
// Read hash value.
source := hash.Sum(nil)
byteVal := source[(position%256)/8]
bitVal := (byteVal >> (position % 8)) % 2
if bitVal == 1 {
ind = flip
}
}
return ind, nil
}
36 changes: 36 additions & 0 deletions cmd/erigon-cl/core/transition/beacon_state_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package transition

import (
"testing"
)

func TestComputeShuffledIndex(t *testing.T) {
testCases := []struct {
description string
startInds []uint64
expectedInds []uint64
seed [32]byte
}{
{
description: "success",
startInds: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedInds: []uint64{0, 9, 8, 4, 6, 7, 3, 1, 2, 5},
seed: [32]byte{1, 128, 12},
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
for i, val := range tc.startInds {
got, err := ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed)
// Non-failure case.
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != tc.expectedInds[i] {
t.Errorf("unexpected result: got %d, want %d", got, tc.expectedInds[i])
}
}
})
}
}

0 comments on commit f512c88

Please sign in to comment.