Skip to content

Commit

Permalink
add serialisation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mihir20 committed Sep 5, 2024
1 parent 3973b4b commit 9d6f4b6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 12 deletions.
15 changes: 15 additions & 0 deletions encrypt/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package encrypt

import (
"fmt"
"strings"
)

// EncryptionAlgorithm is the interface that wraps the encryption algorithm method.
Expand Down Expand Up @@ -108,3 +109,17 @@ func (e *Encrypter) Decrypt(src []byte, key string) ([]byte, error) {
}
return nil, fmt.Errorf("no decryption method available")
}

// SerializeSettings converts the EncryptionAlgorithm and EncryptionLevel to a string.
func SerializeSettings(algo EncryptionAlgorithm, level EncryptionLevel) string {
return fmt.Sprintf("%s:%s", algo.String(), level.String())
}

// DeserializeSettings converts a string to EncryptionAlgorithm and EncryptionLevel.
func DeserializeSettings(settings string) (EncryptionAlgorithm, EncryptionLevel, error) {
parts := strings.Split(settings, ":")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid settings format")
}
return NewSettings(parts[0], parts[1])
}
70 changes: 58 additions & 12 deletions encrypt/encrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.`)

func generateRandomString(n int) (string, error) {
bytes := make([]byte, n)
if _, err := rand.Read(bytes); err != nil {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", err
}
return string(bytes), nil
return string(b), nil
}

func Test_EncryptDecrypt(t *testing.T) {
Expand All @@ -38,27 +38,73 @@ func Test_EncryptDecrypt(t *testing.T) {
for _, tt := range tests {
t.Run(tt.algo.String()+"_"+tt.level.String(), func(t *testing.T) {
encrypter, err := New(tt.algo, tt.level)
if err != nil {
t.Fatalf("New() error = %v", err)
}
require.NoError(t, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(t, err)

plaintext := loremIpsumDolor
ciphertext, err := encrypter.Encrypt(plaintext, key)
if err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
require.NoError(t, err)

decrypted, err := encrypter.Decrypt(ciphertext, key)
if err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
require.NoError(t, err)

if !bytes.Equal(decrypted, plaintext) {
t.Errorf("Decrypted data = %v, want %v", decrypted, plaintext)
}
})
}
}

func Test_SerializeSettings(t *testing.T) {
tests := []struct {
algo EncryptionAlgorithm
level EncryptionLevel
expect string
}{
{EncryptionAlgoAESCFB, EncryptionLevelAES128, "cfb:aes-128"},
{EncryptionAlgoAESCFB, EncryptionLevelAES192, "cfb:aes-192"},
{EncryptionAlgoAESCFB, EncryptionLevelAES256, "cfb:aes-256"},
{EncryptionAlgoAESGCM, EncryptionLevelAES128, "gcm:aes-128"},
{EncryptionAlgoAESGCM, EncryptionLevelAES192, "gcm:aes-192"},
{EncryptionAlgoAESGCM, EncryptionLevelAES256, "gcm:aes-256"},
}

for _, tt := range tests {
t.Run(tt.expect, func(t *testing.T) {
result := SerializeSettings(tt.algo, tt.level)
require.Equal(t, tt.expect, result)
})
}
}

func Test_DeserializeSettings(t *testing.T) {
tests := []struct {
input string
algo EncryptionAlgorithm
level EncryptionLevel
hasErr bool
}{
{"cfb:aes-128", EncryptionAlgoAESCFB, EncryptionLevelAES128, false},
{"cfb:aes-192", EncryptionAlgoAESCFB, EncryptionLevelAES192, false},
{"cfb:aes-256", EncryptionAlgoAESCFB, EncryptionLevelAES256, false},
{"gcm:aes-128", EncryptionAlgoAESGCM, EncryptionLevelAES128, false},
{"gcm:aes-192", EncryptionAlgoAESGCM, EncryptionLevelAES192, false},
{"gcm:aes-256", EncryptionAlgoAESGCM, EncryptionLevelAES256, false},
{"invalid:settings", 0, 0, true},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
algo, level, err := DeserializeSettings(tt.input)
if tt.hasErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.algo, algo)
require.Equal(t, tt.level, level)
}
})
}
}

0 comments on commit 9d6f4b6

Please sign in to comment.