From 4467dd0f1ee5330e376eabe567d4f7beb84bcaa9 Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Thu, 21 Nov 2024 10:28:46 +0200 Subject: [PATCH] support for encrypted config Signed-off-by: Avi Deitcher --- go.mod | 2 +- go.sum | 6 ++ pkg/config/process.go | 146 ++++++++++++++++++++++++++++++++++- pkg/config/process_test.go | 154 +++++++++++++++++++++++++++++++++++++ 4 files changed, 306 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index f4d14de..30a2f54 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( ) require ( - github.com/databacker/api/go/api v0.0.0-20241128084006-ed33dc044eaa + github.com/databacker/api/go/api v0.0.0-20241202154620-01b0380f21cb github.com/google/go-cmp v0.6.0 go.opentelemetry.io/otel v1.31.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 diff --git a/go.sum b/go.sum index bb66f66..f2f6b49 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,12 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfc github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/databacker/api/go/api v0.0.0-20241128084006-ed33dc044eaa h1:cDI48+AG1mPMdvgGWz/SLpNKhzDiGZwfSSww9VvvWuI= github.com/databacker/api/go/api v0.0.0-20241128084006-ed33dc044eaa/go.mod h1:bQhbl71Lk1ATni0H+u249hjoQ8ShAdVNcNjnw6z+SbE= +github.com/databacker/api/go/api v0.0.0-20241201124314-f86f0bf46c54 h1:NirzpOdczBCCwBlmfdYLcroaFYIdR2bJkeDjwJKaxQE= +github.com/databacker/api/go/api v0.0.0-20241201124314-f86f0bf46c54/go.mod h1:bQhbl71Lk1ATni0H+u249hjoQ8ShAdVNcNjnw6z+SbE= +github.com/databacker/api/go/api v0.0.0-20241201140600-cb4443d89ac3 h1:RhB+NKRnj6T+2mbmvy1me2zglATss01byA9Nar+0pbk= +github.com/databacker/api/go/api v0.0.0-20241201140600-cb4443d89ac3/go.mod h1:bQhbl71Lk1ATni0H+u249hjoQ8ShAdVNcNjnw6z+SbE= +github.com/databacker/api/go/api v0.0.0-20241202154620-01b0380f21cb h1:9PthuA+o1wBZuTkNc2LLXQfI5+Myy+ok8nD3bQzd7DA= +github.com/databacker/api/go/api v0.0.0-20241202154620-01b0380f21cb/go.mod h1:bQhbl71Lk1ATni0H+u249hjoQ8ShAdVNcNjnw6z+SbE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/config/process.go b/pkg/config/process.go index e5048ba..7afb750 100644 --- a/pkg/config/process.go +++ b/pkg/config/process.go @@ -1,11 +1,20 @@ package config import ( + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ed25519" + "crypto/sha256" + "encoding/base64" "errors" "fmt" "io" "github.com/databacker/api/go/api" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/nacl/box" "gopkg.in/yaml.v3" "github.com/databacker/mysql-backup/pkg/remote" @@ -15,7 +24,10 @@ import ( // If the configuration is of type remote, it will retrieve the remote configuration. // Continues to process remotes until it gets a final valid ConfigSpec or fails. func ProcessConfig(r io.Reader) (actualConfig *api.ConfigSpec, err error) { - var conf api.Config + var ( + conf api.Config + credentials []string + ) decoder := yaml.NewDecoder(r) if err := decoder.Decode(&conf); err != nil { return nil, fmt.Errorf("fatal error reading config file: %w", err) @@ -60,6 +72,20 @@ func ProcessConfig(r io.Reader) (actualConfig *api.ConfigSpec, err error) { return nil, fmt.Errorf("error parsing remote config: %w", err) } conf = remoteConfig + // save encryption key for later + if spec.Credentials != nil { + credentials = append(credentials, *spec.Credentials) + } + case api.Encrypted: + var spec api.EncryptedSpec + if err := yaml.Unmarshal(specBytes, &spec); err != nil { + return nil, fmt.Errorf("parsed yaml had kind encrypted, but spec invalid") + } + // now try to decrypt it + conf, err = decryptConfig(spec, credentials) + if err != nil { + return nil, fmt.Errorf("error decrypting config: %w", err) + } default: return nil, fmt.Errorf("unknown config type: %s", conf.Kind) } @@ -91,3 +117,121 @@ func getRemoteConfig(spec api.RemoteSpec) (conf api.Config, err error) { return baseConf, nil } + +// decryptConfig decrypt an EncryptedSpec given an EncryptedSpec and a list of credentials. +// Returns the decrypted Config struct. +func decryptConfig(spec api.EncryptedSpec, credentials []string) (api.Config, error) { + var plainConfig api.Config + if spec.Algorithm == nil { + return plainConfig, errors.New("empty algorithm") + } + if spec.RecipientPublicKey == nil { + return plainConfig, errors.New("empty recipient public key") + } + if spec.SenderPublicKey == nil { + return plainConfig, errors.New("empty sender public key") + } + if spec.Data == nil { + return plainConfig, errors.New("empty data") + } + // make sure we have the key matching the public key + var ( + privateKey *ecdh.PrivateKey + curve = ecdh.X25519() + ) + + for _, cred := range credentials { + // get our curve25519 private key + keyBytes, err := base64.StdEncoding.DecodeString(cred) + if err != nil { + return plainConfig, fmt.Errorf("error decoding credentials: %w", err) + } + if len(keyBytes) != ed25519.SeedSize { + return plainConfig, fmt.Errorf("invalid key size %d, must be %d", len(keyBytes), ed25519.SeedSize) + } + candidatePrivateKey, err := curve.NewPrivateKey(keyBytes) + if err != nil { + return plainConfig, fmt.Errorf("error creating private key: %w", err) + } + // get the public key from the private key + candidatePublicKey := candidatePrivateKey.PublicKey() + // check if the public key matches the one we have, if so, break + pubKeyBase64 := base64.StdEncoding.EncodeToString(candidatePublicKey.Bytes()) + if pubKeyBase64 == *spec.RecipientPublicKey { + privateKey = candidatePrivateKey + break + } + } + // if we didn't find a matching key, return an error + if privateKey == nil { + return plainConfig, fmt.Errorf("no private key found that matches public key %s", *spec.RecipientPublicKey) + } + senderPublicKeyBytes, err := base64.StdEncoding.DecodeString(*spec.SenderPublicKey) + if err != nil { + return plainConfig, fmt.Errorf("failed to decode sender public key: %w", err) + } + + // Derive the shared secret using the sender's public key and receiver's private key + var senderPublicKey, receiverPrivateKey, sharedSecret [32]byte + copy(senderPublicKey[:], senderPublicKeyBytes) + copy(receiverPrivateKey[:], privateKey.Bytes()) // Use the seed to get the private scalar + box.Precompute(&sharedSecret, &senderPublicKey, &receiverPrivateKey) + + // Derive a symmetric key using HKDF with the shared secret + hkdfReader := hkdf.New(sha256.New, sharedSecret[:], nil, []byte(api.SymmetricKey)) + var symmetricKeySize int + switch *spec.Algorithm { + case api.AesGcm256: + symmetricKeySize = 32 + case api.Chacha20Poly1305: + symmetricKeySize = 32 + default: + return plainConfig, fmt.Errorf("unsupported algorithm: %s", *spec.Algorithm) + } + symmetricKey := make([]byte, symmetricKeySize) + if _, err := hkdfReader.Read(symmetricKey); err != nil { + return plainConfig, fmt.Errorf("failed to derive symmetric key: %w", err) + } + + var ( + plaintext []byte + aead cipher.AEAD + ) + encryptedData, err := base64.StdEncoding.DecodeString(*spec.Data) + if err != nil { + return plainConfig, fmt.Errorf("failed to decode encrypted data: %w", err) + } + switch *spec.Algorithm { + case api.AesGcm256: + // Decrypt with AES-GCM + block, err := aes.NewCipher(symmetricKey) + if err != nil { + return plainConfig, fmt.Errorf("failed to initialize AES cipher: %w", err) + } + aead, err = cipher.NewGCM(block) + if err != nil { + return plainConfig, fmt.Errorf("failed to initialize AES-GCM: %w", err) + } + case api.Chacha20Poly1305: + // Decrypt with ChaCha20Poly1305 + aead, err = chacha20poly1305.New(symmetricKey) + if err != nil { + return plainConfig, fmt.Errorf("failed to initialize ChaCha20Poly1305: %w", err) + } + default: + return plainConfig, fmt.Errorf("unsupported algorithm: %s", *spec.Algorithm) + } + if len(encryptedData) < aead.NonceSize() { + return plainConfig, errors.New("invalid encrypted data length") + } + dataNonce := encryptedData[:aead.NonceSize()] + ciphertext := encryptedData[aead.NonceSize():] + plaintext, err = aead.Open(nil, dataNonce, ciphertext, nil) + if err != nil { + return plainConfig, fmt.Errorf("failed to decrypt data: %w", err) + } + if err := yaml.Unmarshal(plaintext, &plainConfig); err != nil { + return plainConfig, fmt.Errorf("parsed yaml had kind remote, but spec invalid") + } + return plainConfig, nil +} diff --git a/pkg/config/process_test.go b/pkg/config/process_test.go index 4d20c53..60f5188 100644 --- a/pkg/config/process_test.go +++ b/pkg/config/process_test.go @@ -2,13 +2,23 @@ package config import ( "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ed25519" + cryptorand "crypto/rand" + "crypto/sha256" "encoding/base64" + "errors" + "io" "net/http" "os" "strings" "testing" utiltest "github.com/databacker/mysql-backup/pkg/internal/test" + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/nacl/box" "gopkg.in/yaml.v3" "github.com/databacker/api/go/api" @@ -88,3 +98,147 @@ func TestGetRemoteConfig(t *testing.T) { } } + +func TestDecryptConfig(t *testing.T) { + configFile := "./testdata/config.yml" + content, err := os.ReadFile(configFile) + if err != nil { + t.Fatalf("failed to read config file: %v", err) + } + var validConfig api.Config + if err := yaml.Unmarshal(content, &validConfig); err != nil { + t.Fatalf("failed to unmarshal config: %v", err) + } + + senderCurve := ecdh.X25519() + senderPrivateKey, err := senderCurve.GenerateKey(cryptorand.Reader) + if err != nil { + t.Fatalf("failed to generate sender random seed: %v", err) + } + senderPublicKey := senderPrivateKey.PublicKey() + senderPublicKeyBytes := senderPublicKey.Bytes() + + recipientCurve := ecdh.X25519() + recipientPrivateKey, err := recipientCurve.GenerateKey(cryptorand.Reader) + if err != nil { + t.Fatalf("failed to generate recipient random seed: %v", err) + } + recipientPublicKey := recipientPrivateKey.PublicKey() + recipientPublicKeyBytes := recipientPublicKey.Bytes() + + var recipientPublicKeyArray, senderPrivateKeyArray [32]byte + copy(recipientPublicKeyArray[:], recipientPublicKeyBytes) + copy(senderPrivateKeyArray[:], senderPrivateKey.Bytes()) + + senderPublicKeyB64 := base64.StdEncoding.EncodeToString(senderPublicKeyBytes) + + recipientPublicKeyB64 := base64.StdEncoding.EncodeToString(recipientPublicKeyBytes) + + // compute the shared secret using the sender's private key and the recipient's public key + var sharedSecret [32]byte + box.Precompute(&sharedSecret, &recipientPublicKeyArray, &senderPrivateKeyArray) + + // Derive the symmetric key using HKDF with the shared secret + hkdfReader := hkdf.New(sha256.New, sharedSecret[:], nil, []byte(api.SymmetricKey)) + symmetricKey := make([]byte, 32) // AES-GCM requires 32 bytes + if _, err := hkdfReader.Read(symmetricKey); err != nil { + t.Fatalf("failed to derive symmetric key: %v", err) + } + + // Create AES cipher block + block, err := aes.NewCipher(symmetricKey) + if err != nil { + t.Fatalf("failed to create AES cipher") + } + // Create GCM instance + aesGCM, err := cipher.NewGCM(block) + if err != nil { + t.Fatalf("failed to create AES-GCM") + } + + // Generate a random nonce + nonce := make([]byte, aesGCM.NonceSize()) + _, err = cryptorand.Read(nonce) + if err != nil { + t.Fatalf("failed to generate nonce") + } + + // Encrypt the plaintext + ciphertext := aesGCM.Seal(nil, nonce, content, nil) + + // Embed the nonce in the ciphertext + fullCiphertext := append(nonce, ciphertext...) + + algo := api.AesGcm256 + data := base64.StdEncoding.EncodeToString(fullCiphertext) + + // this is a valid spec, we want to be able to change fields + // without modifying the original, so we have a utility function after + validSpec := api.EncryptedSpec{ + Algorithm: &algo, + Data: &data, + RecipientPublicKey: &recipientPublicKeyB64, + SenderPublicKey: &senderPublicKeyB64, + } + + // copy a spec, changing specific fields + copyModifySpec := func(opts ...func(*api.EncryptedSpec)) api.EncryptedSpec { + copy := validSpec + for _, opt := range opts { + opt(©) + } + return copy + } + + unusedSeed := make([]byte, ed25519.SeedSize) + if _, err := io.ReadFull(cryptorand.Reader, unusedSeed); err != nil { + t.Fatalf("failed to generate sender random seed: %v", err) + } + + // recipient private key credentials + recipientCreds := []string{base64.StdEncoding.EncodeToString(recipientPrivateKey.Bytes())} + unusedCreds := []string{base64.StdEncoding.EncodeToString(unusedSeed)} + + tests := []struct { + name string + inSpec api.EncryptedSpec + credentials []string + config api.Config + err error + }{ + {"no algorithm", copyModifySpec(func(s *api.EncryptedSpec) { s.Algorithm = nil }), recipientCreds, api.Config{}, errors.New("empty algorithm")}, + {"no data", copyModifySpec(func(s *api.EncryptedSpec) { s.Data = nil }), recipientCreds, api.Config{}, errors.New("empty data")}, + {"bad base64 data", copyModifySpec(func(s *api.EncryptedSpec) { data := "abcdef"; s.Data = &data }), recipientCreds, api.Config{}, errors.New("failed to decode encrypted data: illegal base64 data")}, + {"short encrypted data", copyModifySpec(func(s *api.EncryptedSpec) { + data := base64.StdEncoding.EncodeToString([]byte("abcdef")) + s.Data = &data + }), recipientCreds, api.Config{}, errors.New("invalid encrypted data length")}, + {"invalid encrypted data", copyModifySpec(func(s *api.EncryptedSpec) { + bad := nonce + bad = append(bad, 1, 2, 3, 4) + data := base64.StdEncoding.EncodeToString(bad) + s.Data = &data + }), recipientCreds, api.Config{}, errors.New("failed to decrypt data: cipher: message authentication failed")}, + {"empty credentials", validSpec, nil, api.Config{}, errors.New("no private key found that matches public key")}, + {"unmatched credentials", validSpec, unusedCreds, api.Config{}, errors.New("no private key found that matches public key")}, + {"success with just one credential", validSpec, recipientCreds, validConfig, nil}, + {"success with multiple credentials", validSpec, append(recipientCreds, unusedCreds...), validConfig, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf, err := decryptConfig(tt.inSpec, tt.credentials) + switch { + case err == nil && tt.err != nil: + t.Fatalf("expected error: %v", tt.err) + case err != nil && tt.err == nil: + t.Fatalf("unexpected error: %v", err) + case err != nil && tt.err != nil && !strings.HasPrefix(err.Error(), tt.err.Error()): + t.Fatalf("mismatched error: %v", err) + } + diff := cmp.Diff(tt.config, conf) + if diff != "" { + t.Fatalf("mismatched config: %s", diff) + } + }) + } +}