Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add encryption support #627

Merged
merged 12 commits into from
Sep 11, 2024
66 changes: 66 additions & 0 deletions encrypt/aes_gcm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package encrypt

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
)

type encryptionAESGCM struct {
level int
}

func (e *encryptionAESGCM) Encrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonce := make([]byte, aesGCM.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}

ciphertext := aesGCM.Seal(nonce, nonce, src, nil)
return ciphertext, nil
}

func (e *encryptionAESGCM) Decrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonceSize := aesGCM.NonceSize()
if len(src) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}

nonce, ciphertext := src[:nonceSize], src[nonceSize:]
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}

return plaintext, nil
}
60 changes: 60 additions & 0 deletions encrypt/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package encrypt

import (
"testing"

"github.com/stretchr/testify/require"
)

/*
BenchmarkEncryptDecrypt
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128-12 803842 1444 ns/op 1616 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192-12 805350 1443 ns/op 1744 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256-12 744871 1516 ns/op 1872 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128-12 1900 614516 ns/op 4204053 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192-12 1755 672776 ns/op 4204180 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256-12 1624 723403 ns/op 4204308 B/op 13 allocs/op
*/
func BenchmarkEncryptDecrypt(b *testing.B) {
tests := []struct {
payload []byte
name string
algo EncryptionAlgorithm
level EncryptionLevel
}{
{[]byte("small payload"), "SMALL_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{[]byte("small payload"), "SMALL_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{[]byte("small payload"), "SMALL_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
}

for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) {
b.ReportAllocs()
encrypter, err := New(tt.algo, tt.level)
require.NoError(b, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(b, err)

plaintext := tt.payload

b.ResetTimer()
for i := 0; i < b.N; i++ {
ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(b, err)

_, err = encrypter.Decrypt(ciphertext, key)
require.NoError(b, err)
}
})
}
}
102 changes: 102 additions & 0 deletions encrypt/encrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package encrypt

import (
"fmt"
"strings"
)

// EncryptionAlgorithm is the interface that wraps the encryption algorithm method.
type EncryptionAlgorithm int

func (e EncryptionAlgorithm) String() string {
switch e {
case EncryptionAlgoAESGCM:
return "aes-gcm"
default:
return ""
}
}

// EncryptionLevel is the interface that wraps the encryption level method.
type EncryptionLevel int

func (e EncryptionLevel) String() string {
switch e {
case EncryptionLevelAES128, EncryptionLevelAES192, EncryptionLevelAES256:
return fmt.Sprintf("%d", e)
default:
return ""
}
}

func NewSettings(algo, level string) (EncryptionAlgorithm, EncryptionLevel, error) {
switch algo {
case "aes-gcm":
switch level {
case "128":
return EncryptionAlgoAESGCM, EncryptionLevelAES128, nil
case "192":
return EncryptionAlgoAESGCM, EncryptionLevelAES192, nil
case "256":
return EncryptionAlgoAESGCM, EncryptionLevelAES256, nil
default:
return 0, 0, fmt.Errorf("unknown encryption level for %s: %s", algo, level)
}
default:
return 0, 0, fmt.Errorf("unknown encryption algorithm: %s", algo)
}
}

var (
EncryptionAlgoAESGCM = EncryptionAlgorithm(1)
EncryptionLevelAES128 = EncryptionLevel(128)
EncryptionLevelAES192 = EncryptionLevel(192)
EncryptionLevelAES256 = EncryptionLevel(256)
)

func New(algo EncryptionAlgorithm, level EncryptionLevel) (*Encryptor, error) {
var err error
algo, level, err = NewSettings(algo.String(), level.String())
if err != nil {
return nil, err
}

switch algo {
case EncryptionAlgoAESGCM:
return &Encryptor{encryptionAESGCM: &encryptionAESGCM{level: int(level)}}, nil
default:
return nil, fmt.Errorf("unknown encryption algorithm: %d", algo)
}
}

type Encryptor struct {
*encryptionAESGCM
}

func (e *Encryptor) Encrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Encrypt(src, key)
}
return nil, fmt.Errorf("no encryption method available")
}

func (e *Encryptor) Decrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Decrypt(src, key)
}
return nil, fmt.Errorf("no decryption method available")
}

// SerializeSettings converts the EncryptionAlgorithm and EncryptionLevel to a string.
func SerializeSettings(algo EncryptionAlgorithm, level EncryptionLevel) string {
return fmt.Sprintf("%s:%s", algo.String(), level.String())
}

// DeserializeSettings converts a string to EncryptionAlgorithm and EncryptionLevel.
func DeserializeSettings(settings string) (EncryptionAlgorithm, EncryptionLevel, error) {
parts := strings.Split(settings, ":")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid settings format")
}
return NewSettings(parts[0], parts[1])
}
154 changes: 154 additions & 0 deletions encrypt/encrypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package encrypt

import (
"bytes"
"crypto/rand"
"testing"

"github.com/stretchr/testify/require"
)

var loremIpsumDolor = []byte(`Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.`)

func generateRandomString(n int) (string, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", err
}
return string(b), nil
}

func Test_EncryptDecrypt(t *testing.T) {
tests := []struct {
algo EncryptionAlgorithm
level EncryptionLevel
}{
{EncryptionAlgoAESGCM, EncryptionLevelAES128},
{EncryptionAlgoAESGCM, EncryptionLevelAES192},
{EncryptionAlgoAESGCM, EncryptionLevelAES256},
}

for _, tt := range tests {
t.Run(tt.algo.String()+"_"+tt.level.String(), func(t *testing.T) {
encrypter, err := New(tt.algo, tt.level)
require.NoError(t, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(t, err)

plaintext := loremIpsumDolor
_, err = encrypter.Encrypt(plaintext, key[:len(key)-1])
require.Error(t, err)

ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(t, err)

decrypted, err := encrypter.Decrypt(ciphertext, key)
require.NoError(t, err)

if !bytes.Equal(decrypted, plaintext) {
t.Errorf("Decrypted data = %v, want %v", decrypted, plaintext)
}
})
t.Run("integrity check", func(t *testing.T) {
for _, tt := range tests {
t.Run(tt.algo.String()+"_"+tt.level.String(), func(t *testing.T) {
encrypter, err := New(tt.algo, tt.level)
require.NoError(t, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(t, err)

plaintext := loremIpsumDolor
_, err = encrypter.Encrypt(plaintext, key[:len(key)-1])
require.Error(t, err)

ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(t, err)

t.Log("to test integrity check properties manipulate the ciphertext")
ciphertext[0] = ciphertext[0] + 1
decrypted, err := encrypter.Decrypt(ciphertext, key)

require.Error(t, err, "decryption should fail, instead got %q", decrypted)
require.Empty(t, decrypted)
})
}
})
t.Run("invalid key decryption", func(t *testing.T) {
for _, tt := range tests {
t.Run(tt.algo.String()+"_"+tt.level.String(), func(t *testing.T) {
encrypter, err := New(tt.algo, tt.level)
require.NoError(t, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(t, err)

plaintext := loremIpsumDolor
_, err = encrypter.Encrypt(plaintext, key[:len(key)-1])
require.Error(t, err)

ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(t, err)

anotherKey, err := generateRandomString(int(tt.level / 8))
require.NoError(t, err)

t.Log("use a different key to decrypt the ciphertext")
decrypted, err := encrypter.Decrypt(ciphertext, anotherKey)
require.Error(t, err, "decryption should fail, instead got %q", decrypted)
require.Empty(t, decrypted)
})
}
})
}
}

func Test_SerializeSettings(t *testing.T) {
tests := []struct {
algo EncryptionAlgorithm
level EncryptionLevel
expect string
}{
{EncryptionAlgoAESGCM, EncryptionLevelAES128, "aes-gcm:128"},
{EncryptionAlgoAESGCM, EncryptionLevelAES192, "aes-gcm:192"},
{EncryptionAlgoAESGCM, EncryptionLevelAES256, "aes-gcm:256"},
}

for _, tt := range tests {
t.Run(tt.expect, func(t *testing.T) {
result := SerializeSettings(tt.algo, tt.level)
require.Equal(t, tt.expect, result)
})
}
}

func Test_DeserializeSettings(t *testing.T) {
tests := []struct {
input string
algo EncryptionAlgorithm
level EncryptionLevel
hasErr bool
}{
{"aes-gcm:128", EncryptionAlgoAESGCM, EncryptionLevelAES128, false},
{"aes-gcm:192", EncryptionAlgoAESGCM, EncryptionLevelAES192, false},
{"aes-gcm:256", EncryptionAlgoAESGCM, EncryptionLevelAES256, false},
{"invalid:settings", 0, 0, true},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
algo, level, err := DeserializeSettings(tt.input)
if tt.hasErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.algo, algo)
require.Equal(t, tt.level, level)
}
})
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/rudderlabs/rudder-go-kit

go 1.22.5
go 1.22.7

replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.2

Expand Down
Loading