Skip to content

Commit

Permalink
fix thread-unsafe snow3g function and modify snow3g UT (#10)
Browse files Browse the repository at this point in the history
* fix thread-unsafe function and modify UT

* golangci-lint check

* modify according to comments

* modify according to comments
  • Loading branch information
chliu-nems authored Nov 5, 2021
1 parent c66ef3d commit 87ccf71
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 118 deletions.
8 changes: 2 additions & 6 deletions security/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,10 @@ func NEA1(ck [16]byte, countC, bearer, direction uint32, ibs []byte, length uint
k[i] = binary.BigEndian.Uint32(ck[4*(3-i) : 4*(3-i+1)])
}
iv := [4]uint32{(bearer << 27) | (direction << 26), countC, (bearer << 27) | (direction << 26), countC}
snow3g.InitSnow3g(k, iv)

l := (length + 31) / 32
r := length % 32
ks := make([]uint32, l)
snow3g.GenerateKeystream(int(l), ks)
ks := snow3g.GetKeyStream(k, iv, int(l))
// Clear keystream bits which exceed length
if r != 0 {
ks[l-1] &= ^((1 << (32 - r)) - 1)
Expand Down Expand Up @@ -180,9 +178,7 @@ func NIA1(ik [16]byte, countI uint32, bearer byte, direction uint32, msg []byte,
}
iv := [4]uint32{fresh ^ (direction << 15), countI ^ (direction << 31), fresh, countI}
D := ((length + 63) / 64) + 1
z := make([]uint32, 5)
snow3g.InitSnow3g(k, iv)
snow3g.GenerateKeystream(5, z)
z := snow3g.GetKeyStream(k, iv, 5)

P := (uint64(z[0]) << 32) | uint64(z[1])
Q := (uint64(z[2]) << 32) | uint64(z[3])
Expand Down
111 changes: 58 additions & 53 deletions security/snow3g/snow3g.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,11 @@ var sq = [...]byte{
0x56, 0xe1, 0x77, 0xc9, 0x1e, 0x9e, 0x95, 0xa3, 0x90, 0x19, 0xa8, 0x6c, 0x09, 0xd0, 0xf0, 0x86,
}

type Lfsr struct {
s [16]uint32
type snow3g struct {
lfsr [16]uint32
fsm [3]uint32
}

type Fsm struct {
r [3]uint32
}

var (
lfsr Lfsr
fsm Fsm
)

func mulx(V, c byte) byte {
if V&0x80 != 0 {
return (V << 1) ^ c
Expand Down Expand Up @@ -109,68 +101,81 @@ func divAlpha(c byte) uint32 {
return (r0 << 24) | (r1 << 16) | (r2 << 8) | r3
}

func lfsrInitialisationMode(F uint32) {
v := (lfsr.s[0] << 8) ^ mulAlpha(byte(lfsr.s[0]>>24)&0xff) ^ lfsr.s[2] ^ (lfsr.s[11] >> 8) ^
divAlpha(byte(lfsr.s[11]&0xff)) ^ F
func (s *snow3g) lfsrInitializationMode(F uint32) {
v := (s.lfsr[0] << 8) ^ mulAlpha(byte(s.lfsr[0]>>24)&0xff) ^ s.lfsr[2] ^ (s.lfsr[11] >> 8) ^
divAlpha(byte(s.lfsr[11]&0xff)) ^ F
for i := 0; i < 15; i++ {
lfsr.s[i] = lfsr.s[i+1]
s.lfsr[i] = s.lfsr[i+1]
}
lfsr.s[15] = v
s.lfsr[15] = v
}

func lfsrKeystreamMode() {
v := (lfsr.s[0] << 8) ^ mulAlpha(byte(lfsr.s[0]>>24)&0xff) ^ lfsr.s[2] ^ (lfsr.s[11] >> 8) ^
divAlpha(byte(lfsr.s[11]&0xff))
func (s *snow3g) lfsrKeystreamMode() {
v := (s.lfsr[0] << 8) ^ mulAlpha(byte(s.lfsr[0]>>24)&0xff) ^ s.lfsr[2] ^ (s.lfsr[11] >> 8) ^
divAlpha(byte(s.lfsr[11]&0xff))
for i := 0; i < 15; i++ {
lfsr.s[i] = lfsr.s[i+1]
s.lfsr[i] = s.lfsr[i+1]
}
lfsr.s[15] = v
s.lfsr[15] = v
}

func clockFsm(s15, s5 uint32) uint32 {
F := (s15 + fsm.r[0]) ^ fsm.r[1]
r := fsm.r[1] + (fsm.r[2] ^ s5)
fsm.r[2] = s2(fsm.r[1])
fsm.r[1] = s1(fsm.r[0])
fsm.r[0] = r
func (s *snow3g) clockFsm(s15, s5 uint32) uint32 {
F := (s15 + s.fsm[0]) ^ s.fsm[1]
r := s.fsm[1] + (s.fsm[2] ^ s5)
s.fsm[2] = s2(s.fsm[1])
s.fsm[1] = s1(s.fsm[0])
s.fsm[0] = r
return F
}

func InitSnow3g(k, iv [4]uint32) {
lfsr.s[0] = k[0] ^ 0xffffffff
lfsr.s[1] = k[1] ^ 0xffffffff
lfsr.s[2] = k[2] ^ 0xffffffff
lfsr.s[3] = k[3] ^ 0xffffffff
lfsr.s[4] = k[0]
lfsr.s[5] = k[1]
lfsr.s[6] = k[2]
lfsr.s[7] = k[3]
lfsr.s[8] = k[0] ^ 0xffffffff
lfsr.s[9] = k[1] ^ 0xffffffff ^ iv[3]
lfsr.s[10] = k[2] ^ 0xffffffff ^ iv[2]
lfsr.s[11] = k[3] ^ 0xffffffff
lfsr.s[12] = k[0] ^ iv[1]
lfsr.s[13] = k[1]
lfsr.s[14] = k[2]
lfsr.s[15] = k[3] ^ iv[0]
func newSnow3g(k, iv [4]uint32) *snow3g {
s := &snow3g{}

s.lfsr[0] = k[0] ^ 0xffffffff
s.lfsr[1] = k[1] ^ 0xffffffff
s.lfsr[2] = k[2] ^ 0xffffffff
s.lfsr[3] = k[3] ^ 0xffffffff
s.lfsr[4] = k[0]
s.lfsr[5] = k[1]
s.lfsr[6] = k[2]
s.lfsr[7] = k[3]
s.lfsr[8] = k[0] ^ 0xffffffff
s.lfsr[9] = k[1] ^ 0xffffffff ^ iv[3]
s.lfsr[10] = k[2] ^ 0xffffffff ^ iv[2]
s.lfsr[11] = k[3] ^ 0xffffffff
s.lfsr[12] = k[0] ^ iv[1]
s.lfsr[13] = k[1]
s.lfsr[14] = k[2]
s.lfsr[15] = k[3] ^ iv[0]

for i := 0; i < 3; i++ {
fsm.r[i] = 0
s.fsm[i] = 0
}

for i := 0; i < 32; i++ {
F := clockFsm(lfsr.s[15], lfsr.s[5])
lfsrInitialisationMode(F)
F := s.clockFsm(s.lfsr[15], s.lfsr[5])
s.lfsrInitializationMode(F)
}

return s
}

func GenerateKeystream(n int, ks []uint32) {
clockFsm(lfsr.s[15], lfsr.s[5])
lfsrKeystreamMode()
func (s *snow3g) generateKeystream(n int, ks []uint32) {
s.clockFsm(s.lfsr[15], s.lfsr[5])
s.lfsrKeystreamMode()

for i := 0; i < n; i++ {
F := clockFsm(lfsr.s[15], lfsr.s[5])
ks[i] = F ^ lfsr.s[0]
lfsrKeystreamMode()
F := s.clockFsm(s.lfsr[15], s.lfsr[5])
ks[i] = F ^ s.lfsr[0]
s.lfsrKeystreamMode()
}
}

func GetKeyStream(k, iv [4]uint32, n int) []uint32 {
s := newSnow3g(k, iv)

ks := make([]uint32, n)
s.generateKeystream(n, ks)

return ks
}
104 changes: 45 additions & 59 deletions security/snow3g/snow3g_test.go
Original file line number Diff line number Diff line change
@@ -1,70 +1,56 @@
package snow3g_test
package snow3g

import (
"testing"

"github.com/free5gc/nas/security/snow3g"
"github.com/stretchr/testify/require"
)

func Test1(t *testing.T) {
k := [4]uint32{0x2bd6459f, 0x82c5b300, 0x952c4910, 0x4881ff48}
iv := [4]uint32{0xea024714, 0xad5c4d84, 0xdf1f9b25, 0x1c0bf45f}
z := [2]uint32{0xabee9704, 0x7ac31373}

ks := make([]uint32, 2)
snow3g.InitSnow3g(k, iv)
snow3g.GenerateKeystream(2, ks)
for i, k := range ks {
if k != z[i] {
t.Errorf("%#x != %#x\n", k, z[i])
}
func TestSnow3g(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
k [4]uint32
iv [4]uint32
z []uint32
length int
}{
{
name: "TestCase1",
k: [4]uint32{0x2bd6459f, 0x82c5b300, 0x952c4910, 0x4881ff48},
iv: [4]uint32{0xea024714, 0xad5c4d84, 0xdf1f9b25, 0x1c0bf45f},
z: []uint32{0xabee9704, 0x7ac31373},
length: 2,
},
{
name: "TestCase2",
k: [4]uint32{0x8ce33e2c, 0xc3c0b5fc, 0x1f3de8a6, 0xdc66b1f3},
iv: [4]uint32{0xd3c5d592, 0x327fb11c, 0xde551988, 0xceb2f9b7},
z: []uint32{0xeff8a342, 0xf751480f},
length: 2,
},
{
name: "TestCase3",
k: [4]uint32{0x4035c668, 0x0af8c6d1, 0xa8ff8667, 0xb1714013},
iv: [4]uint32{0x62a54098, 0x1ba6f9b7, 0x4592b0e7, 0x8690f71b},
z: []uint32{0xa8c874a9, 0x7ae7c4f8},
length: 2,
},
{
name: "TestCase4",
k: [4]uint32{0x0ded7263, 0x109cf92e, 0x3352255a, 0x140e0f76},
iv: [4]uint32{0x6b68079a, 0x41a7c4c9, 0x1befd79f, 0x7fdcc233},
z: []uint32{0xd712c05c, 0xa937c2a6, 0xeb7eaae3},
length: 3,
},
}
}

func Test2(t *testing.T) {
k := [4]uint32{0x8ce33e2c, 0xc3c0b5fc, 0x1f3de8a6, 0xdc66b1f3}
iv := [4]uint32{0xd3c5d592, 0x327fb11c, 0xde551988, 0xceb2f9b7}
z := [2]uint32{0xeff8a342, 0xf751480f}

ks := make([]uint32, 2)
snow3g.InitSnow3g(k, iv)
snow3g.GenerateKeystream(2, ks)
for i, k := range ks {
if k != z[i] {
t.Errorf("%#x != %#x\n", k, z[i])
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ks := GetKeyStream(tc.k, tc.iv, tc.length)
require.Equal(t, tc.z, ks)
})
}
}

func Test3(t *testing.T) {
k := [4]uint32{0x4035c668, 0x0af8c6d1, 0xa8ff8667, 0xb1714013}
iv := [4]uint32{0x62a54098, 0x1ba6f9b7, 0x4592b0e7, 0x8690f71b}
z := [2]uint32{0xa8c874a9, 0x7ae7c4f8}

ks := make([]uint32, 2)
snow3g.InitSnow3g(k, iv)
snow3g.GenerateKeystream(2, ks)
for i, k := range ks {
if k != z[i] {
t.Errorf("%#x != %#x\n", k, z[i])
}
}
}

func Test4(t *testing.T) {
k := [4]uint32{0x0ded7263, 0x109cf92e, 0x3352255a, 0x140e0f76}
iv := [4]uint32{0x6b68079a, 0x41a7c4c9, 0x1befd79f, 0x7fdcc233}
z := [3]uint32{0xd712c05c, 0xa937c2a6, 0xeb7eaae3}

ks := make([]uint32, 2500)
snow3g.InitSnow3g(k, iv)
snow3g.GenerateKeystream(2500, ks)
for i := 0; i < 3; i++ {
if ks[i] != z[i] {
t.Errorf("%#x != %#x\n", ks[i], z[i])
}
}
if ks[2499] != 0x9c0db3aa {
t.Errorf("%#x != %#x\n", ks[2499], 0x9c0db3aa)
}
}

0 comments on commit 87ccf71

Please sign in to comment.