Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add encryption support #627

Merged
merged 12 commits into from
Sep 11, 2024
66 changes: 66 additions & 0 deletions encrypt/aes_gcm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package encrypt

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
)

type encryptionAESGCM struct {
level int
}

func (e *encryptionAESGCM) Encrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonce := make([]byte, aesGCM.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}

ciphertext := aesGCM.Seal(nonce, nonce, src, nil)
return ciphertext, nil
}

func (e *encryptionAESGCM) Decrypt(src []byte, key string) ([]byte, error) {
if len(key) != e.level/8 {
return nil, fmt.Errorf("key length must be %d bytes", e.level/8)
}

block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}

aesGCM, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}

nonceSize := aesGCM.NonceSize()
if len(src) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}

nonce, ciphertext := src[:nonceSize], src[nonceSize:]
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}

return plaintext, nil
}
60 changes: 60 additions & 0 deletions encrypt/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package encrypt

import (
"testing"

"github.com/stretchr/testify/require"
)

/*
BenchmarkEncryptDecrypt
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES128-12 803842 1444 ns/op 1616 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES192-12 805350 1443 ns/op 1744 B/op 13 allocs/op
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256
BenchmarkEncryptDecrypt/SMALL_AESGCM_AES256-12 744871 1516 ns/op 1872 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES128-12 1900 614516 ns/op 4204053 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES192-12 1755 672776 ns/op 4204180 B/op 13 allocs/op
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256
BenchmarkEncryptDecrypt/LARGE_AESGCM_AES256-12 1624 723403 ns/op 4204308 B/op 13 allocs/op
*/
func BenchmarkEncryptDecrypt(b *testing.B) {
tests := []struct {
payload []byte
name string
algo EncryptionAlgorithm
level EncryptionLevel
}{
{[]byte("small payload"), "SMALL_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{[]byte("small payload"), "SMALL_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{[]byte("small payload"), "SMALL_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES128", EncryptionAlgoAESGCM, EncryptionLevelAES128},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES192", EncryptionAlgoAESGCM, EncryptionLevelAES192},
{make([]byte, 2*1024*1024), "LARGE_AESGCM_AES256", EncryptionAlgoAESGCM, EncryptionLevelAES256},
}

for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) {
b.ReportAllocs()
encrypter, err := New(tt.algo, tt.level)
require.NoError(b, err)

key, err := generateRandomString(int(tt.level / 8))
require.NoError(b, err)

plaintext := tt.payload

b.ResetTimer()
for i := 0; i < b.N; i++ {
ciphertext, err := encrypter.Encrypt(plaintext, key)
require.NoError(b, err)

_, err = encrypter.Decrypt(ciphertext, key)
require.NoError(b, err)
}
})
}
}
102 changes: 102 additions & 0 deletions encrypt/encrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package encrypt

import (
"fmt"
"strings"
)

// EncryptionAlgorithm is the interface that wraps the encryption algorithm method.
type EncryptionAlgorithm int

func (e EncryptionAlgorithm) String() string {
switch e {
case EncryptionAlgoAESGCM:
return "aes-gcm"
default:
return ""
}
}

// EncryptionLevel is the interface that wraps the encryption level method.
type EncryptionLevel int

func (e EncryptionLevel) String() string {
switch e {
case EncryptionLevelAES128, EncryptionLevelAES192, EncryptionLevelAES256:
return fmt.Sprintf("%d", e)
default:
return ""
}
}

func NewSettings(algo, level string) (EncryptionAlgorithm, EncryptionLevel, error) {
switch algo {
case "aes-gcm":
switch level {
case "128":
return EncryptionAlgoAESGCM, EncryptionLevelAES128, nil
case "192":
return EncryptionAlgoAESGCM, EncryptionLevelAES192, nil
case "256":
return EncryptionAlgoAESGCM, EncryptionLevelAES256, nil
default:
return 0, 0, fmt.Errorf("unknown encryption level for %s: %s", algo, level)
}
default:
return 0, 0, fmt.Errorf("unknown encryption algorithm: %s", algo)
}
}

var (
EncryptionAlgoAESGCM = EncryptionAlgorithm(1)
EncryptionLevelAES128 = EncryptionLevel(128)
EncryptionLevelAES192 = EncryptionLevel(192)
EncryptionLevelAES256 = EncryptionLevel(256)
)

func New(algo EncryptionAlgorithm, level EncryptionLevel) (*Encryptor, error) {
var err error
algo, level, err = NewSettings(algo.String(), level.String())
if err != nil {
return nil, err
}

switch algo {
case EncryptionAlgoAESGCM:
return &Encryptor{encryptionAESGCM: &encryptionAESGCM{level: int(level)}}, nil
default:
return nil, fmt.Errorf("unknown encryption algorithm: %d", algo)
}
}

type Encryptor struct {
*encryptionAESGCM
}

func (e *Encryptor) Encrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Encrypt(src, key)
}
return nil, fmt.Errorf("no encryption method available")
}

func (e *Encryptor) Decrypt(src []byte, key string) ([]byte, error) {
if e.encryptionAESGCM != nil {
return e.encryptionAESGCM.Decrypt(src, key)
}
return nil, fmt.Errorf("no decryption method available")
}

// SerializeSettings converts the EncryptionAlgorithm and EncryptionLevel to a string.
func SerializeSettings(algo EncryptionAlgorithm, level EncryptionLevel) string {
return fmt.Sprintf("%s:%s", algo.String(), level.String())
}

// DeserializeSettings converts a string to EncryptionAlgorithm and EncryptionLevel.
func DeserializeSettings(settings string) (EncryptionAlgorithm, EncryptionLevel, error) {
parts := strings.Split(settings, ":")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid settings format")
}
return NewSettings(parts[0], parts[1])
}
Loading
Loading