This repository has been archived by the owner on May 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
shuffle_test.go
129 lines (114 loc) · 3.51 KB
/
shuffle_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package eth2_shuffle
import (
"encoding/csv"
"encoding/hex"
"fmt"
"os"
"strconv"
"strings"
"testing"
)
// readEncodedListInput parses a colon-separated list of unsigned decimal
// integers, the format used by the list columns of spec/tests.csv.
//
// An empty input string yields an empty (non-nil) list. requiredLen is the
// number of items the list must contain. lineIndex is used only to make error
// messages point at the offending CSV line. An error is returned when the item
// count differs from requiredLen, or when any item is not a valid unsigned
// 64-bit decimal integer.
func readEncodedListInput(input string, requiredLen int64, lineIndex int) ([]uint64, error) {
	// strings.Split("", ":") returns [""], so the empty list is handled explicitly.
	var itemStrs []string
	if input != "" {
		itemStrs = strings.Split(input, ":")
	}
	if int64(len(itemStrs)) != requiredLen {
		return nil, fmt.Errorf("list length %d does not match list size %d on line %d", len(itemStrs), requiredLen, lineIndex)
	}
	items := make([]uint64, len(itemStrs))
	for i, itemStr := range itemStrs {
		// ParseUint rejects negative values instead of letting them wrap around
		// on conversion to uint64, and accepts the full uint64 range.
		item, err := strconv.ParseUint(itemStr, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("list item %d on line %d cannot be parsed: %w", i, lineIndex, err)
		}
		items[i] = item
	}
	return items, nil
}
// TestAgainstSpec checks PermuteIndex, UnpermuteIndex, ShuffleList and
// UnshuffleList against the reference vectors in spec/tests.csv.
//
// Each CSV record must have at least four fields:
//
//	0: hex-encoded 32-byte seed
//	1: list size
//	2: colon-separated input list
//	3: colon-separated expected shuffled list
//
// Every record becomes its own parallel subtest named after its line index.
func TestAgainstSpec(t *testing.T) {
	f, err := os.Open("spec/tests.csv")
	if err != nil {
		t.Fatalf("cannot open spec vectors: %v", err)
	}
	defer f.Close()

	lines, err := csv.NewReader(f).ReadAll()
	if err != nil {
		t.Fatalf("cannot read spec vectors: %v", err)
	}

	// constant in spec
	rounds := uint8(90)

	for lineIndex, line := range lines {
		if len(line) < 4 {
			t.Fatalf("line %d has %d fields, expected at least 4", lineIndex, len(line))
		}

		parsedSeed, err := hex.DecodeString(line[0])
		if err != nil {
			t.Fatalf("seed on line %d cannot be parsed: %v", lineIndex, err)
		}
		var seed [32]byte
		// copy would silently truncate or zero-pad a seed of the wrong length.
		if len(parsedSeed) != len(seed) {
			t.Fatalf("seed on line %d is %d bytes, expected %d", lineIndex, len(parsedSeed), len(seed))
		}
		copy(seed[:], parsedSeed)

		parsedSize, err := strconv.ParseInt(line[1], 10, 32)
		if err != nil {
			t.Fatalf("list size on line %d cannot be parsed: %v", lineIndex, err)
		}
		// readEncodedListInput also rejects a negative size, because no list
		// length can equal it. That keeps the uint64 conversion below safe.
		shuffleIn, err := readEncodedListInput(line[2], parsedSize, lineIndex)
		if err != nil {
			t.Fatalf("input list: %v", err)
		}
		shuffleOut, err := readEncodedListInput(line[3], parsedSize, lineIndex)
		if err != nil {
			t.Fatalf("expected output list: %v", err)
		}
		listSize := uint64(parsedSize)

		// The closure captures only values declared inside this loop iteration
		// (seed, listSize, shuffleIn, shuffleOut) plus the constant rounds.
		// It never captures the range variables, so running the subtests in
		// parallel is safe on every Go version.
		t.Run(fmt.Sprintf("line_%d", lineIndex), func(st *testing.T) {
			st.Parallel()
			// Each parallel subtest gets its own hash function. It is shared
			// only by the sequential sub-subtests below.
			hashFn := getStandardHashFn()

			st.Run("PermuteIndex", func(it *testing.T) {
				for i := uint64(0); i < listSize; i++ {
					// Shuffling a single index must land the input item at the
					// matching position of the expected output.
					permuted := PermuteIndex(hashFn, rounds, i, listSize, seed)
					if shuffleIn[i] != shuffleOut[permuted] {
						it.Fatalf("index %d permuted to %d: input item %d, output item %d",
							i, permuted, shuffleIn[i], shuffleOut[permuted])
					}
				}
			})

			st.Run("UnpermuteIndex", func(it *testing.T) {
				for i := uint64(0); i < listSize; i++ {
					// Un-shuffling a single index must map each output item
					// back to its input position.
					unpermuted := UnpermuteIndex(hashFn, rounds, i, listSize, seed)
					if shuffleOut[i] != shuffleIn[unpermuted] {
						it.Fatalf("index %d unpermuted to %d: output item %d, input item %d",
							i, unpermuted, shuffleOut[i], shuffleIn[unpermuted])
					}
				}
			})

			st.Run("ShuffleList", func(it *testing.T) {
				// Shuffle a private copy so the shared input stays untouched.
				testInput := make([]uint64, listSize)
				copy(testInput, shuffleIn)
				ShuffleList(hashFn, testInput, rounds, seed)
				for i := uint64(0); i < listSize; i++ {
					if testInput[i] != shuffleOut[i] {
						it.Fatalf("position %d: shuffled item %d, expected %d", i, testInput[i], shuffleOut[i])
					}
				}
			})

			st.Run("UnshuffleList", func(it *testing.T) {
				// Un-shuffle a private copy so the shared expected output stays untouched.
				testInput := make([]uint64, listSize)
				copy(testInput, shuffleOut)
				UnshuffleList(hashFn, testInput, rounds, seed)
				for i := uint64(0); i < listSize; i++ {
					if testInput[i] != shuffleIn[i] {
						it.Fatalf("position %d: un-shuffled item %d, expected %d", i, testInput[i], shuffleIn[i])
					}
				}
			})
		})
	}
}