From 5b18f56e9dbe7135ca51f44cedf26a96ecb39d5b Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 19 Sep 2023 14:10:58 -0500 Subject: [PATCH 1/6] Add context parameter to key provider interface --- aecmk/akv/keyprovider.go | 165 ++++++++++++++++++++-------------- aecmk/akv/keyprovider_test.go | 15 ++-- aecmk/error.go | 39 ++++++++ aecmk/keyprovider.go | 17 ++-- 4 files changed, 159 insertions(+), 77 deletions(-) create mode 100644 aecmk/error.go diff --git a/aecmk/akv/keyprovider.go b/aecmk/akv/keyprovider.go index cecd40c0..b90565e0 100644 --- a/aecmk/akv/keyprovider.go +++ b/aecmk/akv/keyprovider.go @@ -63,101 +63,120 @@ func init() { // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. -func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { +func (p *Provider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) { decryptedKey = nil - keyData := p.getKeyData(masterKeyPath) - if keyData == nil { + keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Decryption) + if err != nil { return } keySize := keyData.publicKey.Size() cekv := ae.LoadCEKV(encryptedCek) if cekv.Version != 1 { - panic(fmt.Errorf("Invalid version byte in encrypted key")) + return nil, aecmk.NewError(aecmk.Decryption, "Invalid version byte in encrypted key", nil) } if keySize != len(cekv.Ciphertext) { - panic(fmt.Errorf("Encrypted key has wrong ciphertext length")) + return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key has wrong ciphertext length", nil) } if keySize != len(cekv.SignedHash) { - panic(fmt.Errorf("Encrypted key signature length mismatch")) + return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key signature length mismatch", nil) } if !cekv.VerifySignature(keyData.publicKey) { - panic(fmt.Errorf("Invalid signature hash")) + return nil, aecmk.NewError(aecmk.Decryption, "Invalid signature hash", nil) } - client := p.getAKVClient(keyData.endpoint) - algorithm := getAlgorithm(encryptionAlgorithm) + client, err := p.getAKVClient(aecmk.Decryption, keyData.endpoint) + if err != nil { + return + } + algorithm, err := getAlgorithm(aecmk.Decryption, encryptionAlgorithm) + if err != nil { + return + } parameters := azkeys.KeyOperationParameters{ Algorithm: &algorithm, Value: cekv.Ciphertext, } - r, err := client.UnwrapKey(context.Background(), keyData.name, keyData.version, parameters, nil) - if err != nil { - panic(fmt.Errorf("Unable to decrypt key %s: %w", masterKeyPath, err)) + r, e := client.UnwrapKey(ctx, keyData.name, keyData.version, parameters, nil) + if e != nil { + err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Unable to decrypt key %s", masterKeyPath), e) + } else { + decryptedKey = r.Result } - decryptedKey = r.Result return } // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. -func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { - keyData := p.getKeyData(masterKeyPath) - // just validate the algorith - _ = getAlgorithm(encryptionAlgorithm) +func (p *Provider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) (buf []byte, err error) { + keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Encryption) + if err != nil { + return + } + _, err = getAlgorithm(aecmk.Encryption, encryptionAlgorithm) + if err != nil { + return + } keySize := keyData.publicKey.Size() enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder() // Start with version byte == 1 - buf := []byte{byte(1)} + tmp := []byte{byte(1)} // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature // version keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) if err != nil { - panic(fmt.Errorf("Unable to serialize key path %w", err)) + err = aecmk.NewError(aecmk.Encryption, "Unable to serialize key path", err) + return } k := uint16(len(keyPathBytes)) // keyPathLength - buf = append(buf, byte(k), byte(k>>8)) + tmp = append(tmp, byte(k), byte(k>>8)) cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, keyData.publicKey, cek, []byte{}) if err != nil { - panic(fmt.Errorf("Unable to encrypt data %w", err)) + err = aecmk.NewError(aecmk.Encryption, "Unable to encrypt data", err) + return } l := uint16(len(cipherText)) // ciphertextLength - buf = append(buf, byte(l), byte(l>>8)) + tmp = append(tmp, byte(l), byte(l>>8)) // keypath - buf = append(buf, keyPathBytes...) + tmp = append(tmp, keyPathBytes...) // ciphertext - buf = append(buf, cipherText...) - hash := sha256.Sum256(buf) - client := p.getAKVClient(keyData.endpoint) + tmp = append(tmp, cipherText...) + hash := sha256.Sum256(tmp) + client, err := p.getAKVClient(aecmk.Encryption, keyData.endpoint) + if err != nil { + return + } signAlgorithm := azkeys.SignatureAlgorithmRS256 parameters := azkeys.SignParameters{ Algorithm: &signAlgorithm, Value: hash[:], } - r, err := client.Sign(context.Background(), keyData.name, keyData.version, parameters, nil) + r, err := client.Sign(ctx, keyData.name, keyData.version, parameters, nil) if err != nil { - panic(err) + err = aecmk.NewError(aecmk.Encryption, "AKV failed to sign data", err) + return } if len(r.Result) != keySize { - panic("Signature length doesn't match certificate key size") + err = aecmk.NewError(aecmk.Encryption, "Signature length doesn't match certificate key size", nil) + } else { + // signature + buf = append(tmp, r.Result...) } - // signature - buf = append(buf, r.Result...) - return buf + return } // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. -func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { - return nil +func (p *Provider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) { + return nil, nil } // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key // with the specified key path and the specified enclave behavior. Return nil if not supported. -func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { - return nil +func (p *Provider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) { + return nil, nil } // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. @@ -167,51 +186,60 @@ func (p *Provider) KeyLifetime() *time.Duration { return nil } -func getAlgorithm(encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm) { +func getAlgorithm(op aecmk.Operation, encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm, err error) { // support both RSA_OAEP and RSA-OAEP if strings.EqualFold(encryptionAlgorithm, aecmk.KeyEncryptionAlgorithm) { encryptionAlgorithm = string(azkeys.EncryptionAlgorithmRSAOAEP) } if !strings.EqualFold(encryptionAlgorithm, string(azkeys.EncryptionAlgorithmRSAOAEP)) { - panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm)) + err = aecmk.NewError(op, fmt.Sprintf("Unsupported encryption algorithm %s", encryptionAlgorithm), nil) + } else { + algorithm = azkeys.EncryptionAlgorithmRSAOAEP } - return azkeys.EncryptionAlgorithmRSAOAEP + return } // masterKeyPath is a full URL. The AKV client requires it broken down into endpoint, name, and version // The URL has format '{endpoint}/{host}/keys/{name}/[{version}/]' -func (p *Provider) getKeyData(masterKeyPath string) *keyData { +func (p *Provider) getKeyData(ctx context.Context, masterKeyPath string, op aecmk.Operation) (k *keyData, err error) { endpoint, keypath, allowed := p.allowedPathAndEndpoint(masterKeyPath) if !(allowed) { - return nil + err = aecmk.KeyPathNotAllowed(masterKeyPath, op) + return } - k := &keyData{ + k = &keyData{ endpoint: endpoint, name: keypath[0], } if len(keypath) > 1 { k.version = keypath[1] } - client := p.getAKVClient(endpoint) - r, err := client.GetKey(context.Background(), k.name, k.version, nil) + client, err := p.getAKVClient(op, endpoint) + if err != nil { + return + } + r, err := client.GetKey(ctx, k.name, k.version, nil) if err != nil { - panic(fmt.Errorf("Unable to get key from AKV %w", err)) + err = aecmk.NewError(op, "Unable to get key from AKV. Name:"+masterKeyPath, err) } if r.Key.Kty == nil || (*r.Key.Kty != azkeys.KeyTypeRSA && *r.Key.Kty != azkeys.KeyTypeRSAHSM) { - panic(fmt.Errorf("Key type not supported for Always Encrypted")) + err = aecmk.NewError(op, "Key type not supported for Always Encrypted", nil) } - k.publicKey = &rsa.PublicKey{ - N: new(big.Int).SetBytes(r.Key.N), - E: int(new(big.Int).SetBytes(r.Key.E).Int64()), + if err == nil { + k.publicKey = &rsa.PublicKey{ + N: new(big.Int).SetBytes(r.Key.N), + E: int(new(big.Int).SetBytes(r.Key.E).Int64()), + } } - return k + return } func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string, keypath []string, allowed bool) { allowed = len(p.AllowedLocations) == 0 url, err := url.Parse(masterKeyPath) if err != nil { - panic(fmt.Errorf("Invalid URL for master key path %s: %w", masterKeyPath, err)) + allowed = false + return } if !allowed { @@ -226,7 +254,8 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string if allowed { pathParts := strings.Split(strings.TrimLeft(url.Path, "/"), "/") if len(pathParts) < 2 || len(pathParts) > 3 || pathParts[0] != "keys" { - panic(fmt.Errorf("Invalid URL for master key path %s", masterKeyPath)) + allowed = false + return } keypath = pathParts[1:] url.Path = "" @@ -237,28 +266,34 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string return } -func (p *Provider) getAKVClient(endpoint string) (client *azkeys.Client) { - client, err := azkeys.NewClient(endpoint, p.getCredential(endpoint), nil) +func (p *Provider) getAKVClient(op aecmk.Operation, endpoint string) (client *azkeys.Client, err error) { + credential, err := p.getCredential(op, endpoint) + if err == nil { + client, err = azkeys.NewClient(endpoint, credential, nil) + } if err != nil { - panic(fmt.Errorf("Unable to create AKV client %w", err)) + err = aecmk.NewError(op, "Unable to create AKV client", err) } return } -func (p *Provider) getCredential(endpoint string) azcore.TokenCredential { +func (p *Provider) getCredential(op aecmk.Operation, endpoint string) (credential azcore.TokenCredential, err error) { if len(p.credentials) == 0 { - credential, err := azidentity.NewDefaultAzureCredential(nil) + credential, err = azidentity.NewDefaultAzureCredential(nil) if err != nil { - panic(fmt.Errorf("Unable to create a default credential: %w", err)) + err = aecmk.NewError(op, "Unable to create a default credential", err) + } else { + p.credentials[wildcard] = credential } - p.credentials[wildcard] = credential - return credential + return } - if credential, ok := p.credentials[endpoint]; ok { - return credential + var ok bool + if credential, ok = p.credentials[endpoint]; ok { + return } - if credential, ok := p.credentials[wildcard]; ok { - return credential + if credential, ok = p.credentials[wildcard]; ok { + return } - panic(fmt.Errorf("No credential available for AKV path %s", endpoint)) + err = aecmk.NewError(op, fmt.Sprintf("No credential available for AKV path %s", endpoint), nil) + return } diff --git a/aecmk/akv/keyprovider_test.go b/aecmk/akv/keyprovider_test.go index f16f826a..0abcc021 100644 --- a/aecmk/akv/keyprovider_test.go +++ b/aecmk/akv/keyprovider_test.go @@ -4,6 +4,7 @@ package akv import ( + "context" "crypto/rand" "net/url" "testing" @@ -26,9 +27,13 @@ func TestEncryptDecryptRoundTrip(t *testing.T) { plainKey := make([]byte, 32) _, _ = rand.Read(plainKey) t.Log("Plainkey:", plainKey) - encryptedKey := p.EncryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, plainKey) - t.Log("Encryptedkey:", encryptedKey) - assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey") - decryptedKey := p.DecryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey) - assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey) + encryptedKey, err := p.EncryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, plainKey) + if assert.NoError(t, err, "EncryptColumnEncryptionKey") { + t.Log("Encryptedkey:", encryptedKey) + assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey") + decryptedKey, err := p.DecryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey) + if assert.NoError(t, err, "DecryptColumnEncryptionKey") { + assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey) + } + } } diff --git a/aecmk/error.go b/aecmk/error.go new file mode 100644 index 00000000..a623b0e8 --- /dev/null +++ b/aecmk/error.go @@ -0,0 +1,39 @@ +package aecmk + +import "fmt" + +// Operation specifies the action that returned an error +type Operation int + +const ( + Decryption Operation = iota + Encryption + Validation +) + +// Error is the type of all errors returned by key encryption providers +type Error struct { + Operation Operation + err error + msg string +} + +func (e *Error) Error() string { + return e.msg +} + +func (e *Error) Unwrap() error { + return e.err +} + +func NewError(operation Operation, msg string, err error) error { + return &Error{ + Operation: operation, + msg: msg, + err: err, + } +} + +func KeyPathNotAllowed(path string, operation Operation) error { + return NewError(operation, fmt.Sprintf("Key path not allowed: %s", path), nil) +} diff --git a/aecmk/keyprovider.go b/aecmk/keyprovider.go index 7cdcb82c..6572f0da 100644 --- a/aecmk/keyprovider.go +++ b/aecmk/keyprovider.go @@ -1,6 +1,7 @@ package aecmk import ( + "context" "fmt" "sync" "time" @@ -36,7 +37,7 @@ func NewCekProvider(provider ColumnEncryptionKeyProvider) *CekProvider { return &CekProvider{Provider: provider, decryptedKeys: make(cekCache), mutex: sync.Mutex{}} } -func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) { +func (cp *CekProvider) GetDecryptedKey(ctx context.Context, keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) { cp.mutex.Lock() ev, cachedKey := cp.decryptedKeys[keyPath] if cachedKey { @@ -53,9 +54,9 @@ func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (d // but there'd be high value in having a queue of waiters for decrypting a key stored in the cloud. cp.mutex.Unlock() if !cachedKey { - decryptedKey = cp.Provider.DecryptColumnEncryptionKey(keyPath, KeyEncryptionAlgorithm, encryptedBytes) + decryptedKey, err = cp.Provider.DecryptColumnEncryptionKey(ctx, keyPath, KeyEncryptionAlgorithm, encryptedBytes) } - if !cachedKey { + if err == nil && !cachedKey { duration := cp.Provider.KeyLifetime() if duration == nil { duration = &ColumnEncryptionKeyLifetime @@ -78,22 +79,23 @@ var globalCekProviderFactoryMap = ColumnEncryptionKeyProviderMap{} type ColumnEncryptionKeyProvider interface { // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. - DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte + DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) ([]byte, error) // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. - EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte + EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. - SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte + SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key // with the specified key path and the specified enclave behavior. Return nil if not supported. - VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool + VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. // If it returns zero, the keys will not be cached. KeyLifetime() *time.Duration } +// RegisterCekProvider adds the named provider to the global provider list func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error { _, ok := globalCekProviderFactoryMap[name] if ok { @@ -103,6 +105,7 @@ func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) erro return nil } +// GetGlobalCekProviders enumerates all globally registered providers func GetGlobalCekProviders() (providers ColumnEncryptionKeyProviderMap) { providers = make(ColumnEncryptionKeyProviderMap) for i, p := range globalCekProviderFactoryMap { From cf6659125802f2742545abd899038ec313ab679c Mon Sep 17 00:00:00 2001 From: davidshi Date: Mon, 25 Sep 2023 12:12:45 -0500 Subject: [PATCH 2/6] update error handling for AE key providers --- CHANGELOG.md | 13 ++ aecmk/localcert/keyprovider.go | 117 +++++++++++------- .../keyprovider_go117_windows_test.go | 11 +- aecmk/localcert/keyprovider_windows.go | 13 +- alwaysencrypted_test.go | 81 +++++++++++- encrypt.go | 8 +- internal/certs/certs_windows.go | 15 +-- tds_go110_test.go | 10 +- token.go | 51 +++++--- version.go | 2 +- 10 files changed, 230 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba3dceae..44a14eb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,17 @@ # Changelog +## 1.7.0 + +### Changed + +* Changed always encrypted key provider error handling not to panic on failure + +### Features + +* Support DER certificates for server authentication (#152) + +### Bug fixes + +* Improved speed of CharsetToUTF8 (#154) ## 1.6.0 diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 24c1bcae..5704197a 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -4,6 +4,7 @@ package localcert import ( + "context" "crypto" "crypto/rand" "crypto/rsa" @@ -11,7 +12,7 @@ import ( "crypto/sha256" "crypto/x509" "fmt" - "io/ioutil" + "io" "os" "strconv" "strings" @@ -61,52 +62,65 @@ func init() { // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. -func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { +func (p *Provider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) { decryptedKey = nil - pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) - if !allowed { + err = validateEncryptionAlgorithm(aecmk.Encryption, encryptionAlgorithm) + if err != nil { + return + } + err = validateKeyPathLength(aecmk.Encryption, masterKeyPath) + if err != nil { + return + } + pk, cert, err := p.tryLoadCertificate(aecmk.Decryption, masterKeyPath) + if err != nil { return } cekv := ae.LoadCEKV(encryptedCek) if !cekv.Verify(cert) { - panic(fmt.Errorf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw)))) + err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw))), nil) } - decryptedKey, err := cekv.Decrypt(pk.(*rsa.PrivateKey)) + decryptedKey, err = cekv.Decrypt(pk.(*rsa.PrivateKey)) if err != nil { - panic(err) + err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Decryption failed using %s", masterKeyPath), err) } return } -func (p *Provider) tryLoadCertificate(masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, allowed bool) { - allowed = len(p.AllowedLocations) == 0 +func (p *Provider) tryLoadCertificate(op aecmk.Operation, masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, err error) { + allowed := len(p.AllowedLocations) == 0 if !allowed { loop: for _, l := range p.AllowedLocations { - if l == masterKeyPath { + if strings.HasPrefix(masterKeyPath, l) { allowed = true break loop } } } if !allowed { + err = aecmk.KeyPathNotAllowed(masterKeyPath, op) return } switch p.name { case PfxKeyProviderName: - privateKey, cert = p.loadLocalCertificate(masterKeyPath) + privateKey, cert, err = p.loadLocalCertificate(masterKeyPath) case aecmk.CertificateStoreKeyProvider: - privateKey, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) + privateKey, cert, err = p.loadWindowsCertStoreCertificate(masterKeyPath) + } + if err != nil { + err = aecmk.NewError(op, "Unable to load certificate", err) } return } -func (p *Provider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - if f, err := os.Open(path); err == nil { - pfxBytes, err := ioutil.ReadAll(f) - if err != nil { - panic(invalidCertificatePath(path, err)) +func (p *Provider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { + if f, e := os.Open(path); e == nil { + pfxBytes, er := io.ReadAll(f) + if er != nil { + err = invalidCertificatePath(path, er) + return } pwd, ok := p.passwords[path] if !ok { @@ -116,75 +130,82 @@ func (p *Provider) loadLocalCertificate(path string) (privateKey interface{}, ce } } privateKey, cert, err = pkcs.Decode(pfxBytes, pwd) - if err != nil { - panic(err) - } } else { - panic(invalidCertificatePath(path, err)) + err = invalidCertificatePath(path, err) } return } // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. -func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { +func (p *Provider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) (buf []byte, err error) { - validateEncryptionAlgorithm(encryptionAlgorithm) - validateKeyPathLength(masterKeyPath) - pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) - if !allowed { - panic(fmt.Errorf("Key path not allowed for use in column key encryption")) + err = validateEncryptionAlgorithm(aecmk.Encryption, encryptionAlgorithm) + if err != nil { + return + } + err = validateKeyPathLength(aecmk.Encryption, masterKeyPath) + if err != nil { + return + } + pk, cert, err := p.tryLoadCertificate(aecmk.Encryption, masterKeyPath) + if err != nil { + return nil, err } publicKey := cert.PublicKey.(*rsa.PublicKey) keySizeInBytes := publicKey.Size() enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder() // Start with version byte == 1 - buf := []byte{byte(1)} + tmp := []byte{byte(1)} // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature // version keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) if err != nil { - panic(fmt.Errorf("Unable to serialize key path %w", err)) + err = aecmk.NewError(aecmk.Encryption, fmt.Sprintf("Unable to serialize key path"), err) + return } k := uint16(len(keyPathBytes)) // keyPathLength - buf = append(buf, byte(k), byte(k>>8)) + tmp = append(tmp, byte(k), byte(k>>8)) cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, cek, []byte{}) if err != nil { - panic(fmt.Errorf("Unable to encrypt data %w", err)) + err = aecmk.NewError(aecmk.Encryption, "Unable to encrypt data", err) + return } l := uint16(len(cipherText)) // ciphertextLength - buf = append(buf, byte(l), byte(l>>8)) + tmp = append(tmp, byte(l), byte(l>>8)) // keypath - buf = append(buf, keyPathBytes...) + tmp = append(tmp, keyPathBytes...) // ciphertext - buf = append(buf, cipherText...) - hash := sha256.Sum256(buf) + tmp = append(tmp, cipherText...) + hash := sha256.Sum256(tmp) // signature is the signed hash of the current buf sig, err := rsa.SignPKCS1v15(rand.Reader, pk.(*rsa.PrivateKey), crypto.SHA256, hash[:]) if err != nil { - panic(err) + err = aecmk.NewError(aecmk.Encryption, "Unable to sign encrypted data", err) + return } if len(sig) != keySizeInBytes { - panic("Signature length doesn't match certificate key size") + err = aecmk.NewError(aecmk.Encryption, "Signature length doesn't match certificate key size", nil) + } else { + buf = append(tmp, sig...) } - buf = append(buf, sig...) - return buf + return } // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. -func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { - return nil +func (p *Provider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) { + return nil, nil } // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key // with the specified key path and the specified enclave behavior. Return nil if not supported. -func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { - return nil +func (p *Provider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) { + return nil, nil } // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. @@ -194,16 +215,18 @@ func (p *Provider) KeyLifetime() *time.Duration { return nil } -func validateEncryptionAlgorithm(encryptionAlgorithm string) { +func validateEncryptionAlgorithm(op aecmk.Operation, encryptionAlgorithm string) error { if !strings.EqualFold(encryptionAlgorithm, "RSA_OAEP") { - panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm)) + return aecmk.NewError(op, fmt.Sprintf("Unsupported encryption algorithm %s", encryptionAlgorithm), nil) } + return nil } -func validateKeyPathLength(keyPath string) { +func validateKeyPathLength(op aecmk.Operation, keyPath string) error { if len(keyPath) > 32767 { - panic(fmt.Errorf("Key path is too long. %d", len(keyPath))) + return aecmk.NewError(op, fmt.Sprintf("Key path is too long. %d", len(keyPath)), nil) } + return nil } // InvalidCertificatePathError indicates the provided path could not be used to load a certificate diff --git a/aecmk/localcert/keyprovider_go117_windows_test.go b/aecmk/localcert/keyprovider_go117_windows_test.go index 9212d4b8..0da6b0eb 100644 --- a/aecmk/localcert/keyprovider_go117_windows_test.go +++ b/aecmk/localcert/keyprovider_go117_windows_test.go @@ -4,12 +4,14 @@ package localcert import ( + "context" "crypto/rsa" "strings" "testing" "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/certs" + "github.com/stretchr/testify/assert" ) func TestLoadWindowsCertStoreCertificate(t *testing.T) { @@ -19,7 +21,8 @@ func TestLoadWindowsCertStoreCertificate(t *testing.T) { } defer certs.DeleteMasterKeyCert(thumbprint) provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) - pk, cert := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) + pk, cert, err := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) + assert.NoError(t, err, "loadWindowsCertStoreCertificate") switch z := pk.(type) { case *rsa.PrivateKey: @@ -41,8 +44,10 @@ func TestEncryptDecryptEncryptionKeyRoundTrip(t *testing.T) { bytesToEncrypt := []byte{1, 2, 3} keyPath := "CurrentUser/My/" + thumbprint provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) - encryptedBytes := provider.EncryptColumnEncryptionKey(keyPath, "RSA_OAEP", bytesToEncrypt) - decryptedBytes := provider.DecryptColumnEncryptionKey(keyPath, "RSA_OAEP", encryptedBytes) + encryptedBytes, err := provider.EncryptColumnEncryptionKey(context.Background(), keyPath, "RSA_OAEP", bytesToEncrypt) + assert.NoError(t, err, "Encrypt") + decryptedBytes, err := provider.DecryptColumnEncryptionKey(context.Background(), keyPath, "RSA_OAEP", encryptedBytes) + assert.NoError(t, err, "Decrypt") if len(decryptedBytes) != 3 || decryptedBytes[0] != 1 || decryptedBytes[1] != 2 || decryptedBytes[2] != 3 { t.Fatalf("Encrypt/Decrypt did not roundtrip. encryptedBytes:%v, decryptedBytes: %v", encryptedBytes, decryptedBytes) } diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go index 25c7fa20..ad52a125 100644 --- a/aecmk/localcert/keyprovider_windows.go +++ b/aecmk/localcert/keyprovider_windows.go @@ -23,12 +23,13 @@ func init() { } } -func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { privateKey = nil cert = nil pathParts := strings.Split(path, `/`) if len(pathParts) != 3 { - panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) + err = invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments")) + return } var storeId uint32 @@ -38,18 +39,20 @@ func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey inte case "currentuser": storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER default: - panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) + err = invalidCertificatePath(path, fmt.Errorf("Unknown certificate store")) + return } system, err := windows.UTF16PtrFromString(pathParts[1]) if err != nil { - panic(err) + err = invalidCertificatePath(path, err) + return } h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, 0, storeId, uintptr(unsafe.Pointer(system))) if err != nil { - panic(err) + return } defer windows.CertCloseStore(h, 0) signature := thumbprintToByteArray(pathParts[2]) diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 05260c43..76d8a5bf 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "crypto/rand" "database/sql" "fmt" @@ -68,6 +69,8 @@ func TestAlwaysEncryptedE2E(t *testing.T) { // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, } for _, test := range providerTests { + // turn off key caching + aecmk.ColumnEncryptionKeyLifetime = 0 t.Run(test.Name(), func(t *testing.T) { conn, _ := open(t) defer conn.Close() @@ -86,9 +89,10 @@ func TestAlwaysEncryptedE2E(t *testing.T) { tableName := fmt.Sprintf("mssqlAe%d", r.Int64()) keyBytes := make([]byte, 32) _, _ = rand.Read(keyBytes) - encryptedCek := test.GetProvider(t).EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, keyBytes) + encryptedCek, err := test.GetProvider(t).EncryptColumnEncryptionKey(context.Background(), certPath, KeyEncryptionAlgorithm, keyBytes) + assert.NoError(t, err, "Encrypt") createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) - _, err := conn.Exec(createCek) + _, err = conn.Exec(createCek) assert.NoError(t, err, "Unable to create CEK") defer func() { _, err := conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) @@ -178,10 +182,41 @@ func TestAlwaysEncryptedE2E(t *testing.T) { _ = rows.Next() err = rows.Err() assert.NoError(t, err, "rows.Err() has non-nil values") + testProviderErrorHandling(t, test.Name(), test.GetProvider(t), sel.String(), insert.String(), insertArgs) }) } } +func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnEncryptionKeyProvider, sel string, insert string, insertArgs []interface{}) { + t.Helper() + testProvider := &testKeyProvider{fallback: provider} + connector, _ := getTestConnector(t) + connector.RegisterCekProvider(name, testProvider) + conn := sql.OpenDB(connector) + defer conn.Close() + testProvider.decrypt = func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) { + return nil, context.DeadlineExceeded + } + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Hour)) + defer cancel() + rows, err := conn.QueryContext(ctx, sel) + if assert.NoError(t, err, "Exec should return no error") { + if rows.Next() { + assert.Fail(t, "rows.Next should have failed") + } + assert.ErrorIs(t, rows.Err(), context.DeadlineExceeded) + } + + var notAllowed error + testProvider.decrypt = func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) { + notAllowed = aecmk.KeyPathNotAllowed(masterKeyPath, aecmk.Decryption) + return nil, notAllowed + } + _, err = conn.Exec(insert, insertArgs...) + assert.ErrorIs(t, err, notAllowed, "Insert should fail with key path not allowed") + +} + func comparisonValueFromObject(object interface{}) string { switch v := object.(type) { case []byte: @@ -221,3 +256,45 @@ const ( COLUMN_ENCRYPTION_KEY = [%s]) )` ) + +// Parameterized implementation of a key provider +type testKeyProvider struct { + encrypt func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) + decrypt func(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) ([]byte, error) + lifetime *time.Duration + fallback aecmk.ColumnEncryptionKeyProvider +} + +func (p *testKeyProvider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) { + if p.decrypt != nil { + return p.decrypt(ctx, masterKeyPath, encryptionAlgorithm, encryptedCek) + } + return p.fallback.DecryptColumnEncryptionKey(ctx, masterKeyPath, encryptionAlgorithm, encryptedCek) +} + +func (p *testKeyProvider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) ([]byte, error) { + if p.encrypt != nil { + return p.encrypt(ctx, masterKeyPath, encryptionAlgorithm, cek) + } + return p.fallback.EncryptColumnEncryptionKey(ctx, masterKeyPath, encryptionAlgorithm, cek) +} + +func (p *testKeyProvider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) { + return nil, nil +} + +// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key +// with the specified key path and the specified enclave behavior. Return nil if not supported. +func (p *testKeyProvider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) { + return nil, nil +} + +// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. +// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. +// If it returns zero, the keys will not be cached. +func (p *testKeyProvider) KeyLifetime() *time.Duration { + if p.lifetime != nil { + return p.lifetime + } + return p.fallback.KeyLifetime() +} diff --git a/encrypt.go b/encrypt.go index 0e04837a..a395d88a 100644 --- a/encrypt.go +++ b/encrypt.go @@ -80,7 +80,7 @@ func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArg if len(cekInfo) == 0 { return args, nil } - err = s.decryptCek(cekInfo) + err = s.decryptCek(ctx, cekInfo) if err != nil { return } @@ -141,13 +141,13 @@ func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters return } -func (s *Stmt) decryptCek(cekInfo []*cekData) error { +func (s *Stmt) decryptCek(ctx context.Context, cekInfo []*cekData) error { for _, info := range cekInfo { kp, ok := s.c.sess.aeSettings.keyProviders[info.cmkStoreName] if !ok { return fmt.Errorf("No provider found for key store %s", info.cmkStoreName) } - dk, err := kp.GetDecryptedKey(info.cmkPath, info.encryptedValue) + dk, err := kp.GetDecryptedKey(ctx, info.cmkPath, info.encryptedValue) if err != nil { return err } @@ -285,7 +285,7 @@ func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, p if qerr != io.EOF { err = qerr } else { - err = fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption") + badStreamPanic(fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption")) } } return diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go index 5577dd1a..37a3e042 100644 --- a/internal/certs/certs_windows.go +++ b/internal/certs/certs_windows.go @@ -15,9 +15,8 @@ import ( "golang.org/x/sys/windows" ) -func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface{}, *x509.Certificate) { +func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (pk interface{}, cert *x509.Certificate, err error) { var certContext *windows.CertContext - var err error cryptoAPIBlob := windows.CryptHashBlob{ Size: uint32(len(hash)), Data: &hash[0], @@ -32,15 +31,11 @@ func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface nil) if err != nil { - - panic(fmt.Errorf("Unable to find certificate by signature hash. %s", err.Error())) - } - pk, cert, err := certContextToX509(certContext) - if err != nil { - panic(err) + err = fmt.Errorf("Unable to find certificate by signature hash. %w", err) + return } - - return pk, cert + pk, cert, err = certContextToX509(certContext) + return } func certContextToX509(ctx *windows.CertContext) (pk interface{}, cert *x509.Certificate, err error) { diff --git a/tds_go110_test.go b/tds_go110_test.go index c36f29a9..3b6bb716 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package mssql @@ -8,6 +9,12 @@ import ( ) func open(t testing.TB) (*sql.DB, *testLogger) { + connector, logger := getTestConnector(t) + conn := sql.OpenDB(connector) + return conn, logger +} + +func getTestConnector(t testing.TB) (*Connector, *testLogger) { tl := testLogger{t: t} SetLogger(&tl) connector, err := NewConnector(makeConnStr(t).String()) @@ -15,6 +22,5 @@ func open(t testing.TB) (*sql.DB, *testLogger) { t.Error("Open connection failed:", err.Error()) return nil, &tl } - conn := sql.OpenDB(connector) - return conn, &tl + return connector, nil } diff --git a/token.go b/token.go index 323ddbd7..bf8fffc3 100644 --- a/token.go +++ b/token.go @@ -11,6 +11,7 @@ import ( "strconv" "github.com/golang-sql/sqlexp" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" @@ -694,7 +695,7 @@ func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { if cekTable != nil { if int(ordinal) > len(cekTable.entries)-1 { - panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) + badStreamPanicf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries)) } entry = &cekTable.entries[ordinal] } @@ -732,7 +733,7 @@ func readCekTableEntry(r *tdsBuffer) cekTableEntry { var cekMdVersion = make([]byte, 8) _, err := r.Read(cekMdVersion) if err != nil { - panic("unable to read cekMdVersion") + badStreamPanicf("unable to read cekMdVersion") } cekValueCount := uint(r.byte()) @@ -784,7 +785,7 @@ func readCekTableEntry(r *tdsBuffer) cekTableEntry { } // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { +func parseRow(ctx context.Context, r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) error { for i, column := range columns { columnContent := column.ti.Reader(&column.ti, r, nil) if columnContent == nil { @@ -793,13 +794,17 @@ func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interfa } if column.isEncrypted() { - buffer := decryptColumn(column, s, columnContent) + buffer, err := decryptColumn(ctx, column, s, columnContent) + if err != nil { + return err + } // Decrypt - row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta) + row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, buffer, column.cryptoMeta) } else { row[i] = columnContent } } + return nil } type RWCBuffer struct { @@ -818,7 +823,7 @@ func (R RWCBuffer) Close() error { return nil } -func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { +func decryptColumn(ctx context.Context, column columnStruct, s *tdsSession, columnContent interface{}) (*tdsBuffer, error) { encType := encryption.From(column.cryptoMeta.encType) cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) { @@ -827,17 +832,18 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName] if !ok { - panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) + // The app hasn't installed the key provider it needs + panic(aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName), nil)) } - cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey) + cek, err := cekProvider.GetDecryptedKey(ctx, cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey) if err != nil { - panic(err) + return nil, err } k := keys.NewAeadAes256CbcHmac256(cek) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) d, err := alg.Decrypt(columnContent.([]byte)) if err != nil { - panic(err) + return nil, aecmk.NewError(aecmk.Decryption, "Unable to decrypt key using AES256", err) } // Decrypt returns a minimum of 8 bytes so truncate to the actual data size @@ -853,11 +859,11 @@ func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{} column.cryptoMeta.typeInfo.Buffer = d buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} - return buffer + return &buffer, nil } // http://msdn.microsoft.com/en-us/library/dd304783.aspx -func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { +func parseNbcRow(ctx context.Context, r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) error { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) @@ -868,14 +874,17 @@ func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []inte } columnContent := col.ti.Reader(&col.ti, r, nil) if col.isEncrypted() { - buffer := decryptColumn(col, s, columnContent) + buffer, err := decryptColumn(ctx, col, s, columnContent) + if err != nil { + return err + } // Decrypt - row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta) + row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, buffer, col.cryptoMeta) } else { row[i] = columnContent } - } + return nil } // http://msdn.microsoft.com/en-us/library/dd304156.aspx @@ -1079,11 +1088,19 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS case tokenRow: row := make([]interface{}, len(columns)) - parseRow(sess.buf, sess, columns, row) + err = parseRow(ctx, sess.buf, sess, columns, row) + if err != nil { + ch <- err + return + } ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) - parseNbcRow(sess.buf, sess, columns, row) + err = parseNbcRow(ctx, sess.buf, sess, columns, row) + if err != nil { + ch <- err + return + } ch <- row case tokenEnvChange: processEnvChg(ctx, sess) diff --git a/version.go b/version.go index 256e9b4e..161b4dde 100644 --- a/version.go +++ b/version.go @@ -4,7 +4,7 @@ import "fmt" // Update this variable with the release tag before pushing the tag // This value is written to the prelogin and login7 packets during a new connection -const driverVersion = "v1.6.0" +const driverVersion = "v1.7.0" func getDriverVersion(ver string) uint32 { var majorVersion uint32 From f1db5022c0536d40f1be7750527671218e23cadc Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 26 Sep 2023 17:15:24 -0500 Subject: [PATCH 3/6] fix build --- aecmk/localcert/keyprovider_darwin.go | 4 ++-- aecmk/localcert/keyprovider_linux.go | 4 ++-- alwaysencrypted_test.go | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go index c3a7564a..2b7c2f0a 100644 --- a/aecmk/localcert/keyprovider_darwin.go +++ b/aecmk/localcert/keyprovider_darwin.go @@ -8,7 +8,7 @@ import ( "fmt" ) -func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - panic(fmt.Errorf("Windows cert store not supported on this OS")) +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { + err = fmt.Errorf("Windows cert store not supported on this OS") return } diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go index c3a7564a..2b7c2f0a 100644 --- a/aecmk/localcert/keyprovider_linux.go +++ b/aecmk/localcert/keyprovider_linux.go @@ -8,7 +8,7 @@ import ( "fmt" ) -func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { - panic(fmt.Errorf("Windows cert store not supported on this OS")) +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate, err error) { + err = fmt.Errorf("Windows cert store not supported on this OS") return } diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 76d8a5bf..43378f67 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -200,6 +200,7 @@ func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnE ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Hour)) defer cancel() rows, err := conn.QueryContext(ctx, sel) + defer rows.Close() if assert.NoError(t, err, "Exec should return no error") { if rows.Next() { assert.Fail(t, "rows.Next should have failed") From b119f0e23e7db209a8932b069c14df0394dcf59f Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 26 Sep 2023 17:20:45 -0500 Subject: [PATCH 4/6] fix error handling --- aecmk/localcert/keyprovider.go | 3 ++- alwaysencrypted_test.go | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go index 5704197a..f95aa9ad 100644 --- a/aecmk/localcert/keyprovider.go +++ b/aecmk/localcert/keyprovider.go @@ -79,6 +79,7 @@ func (p *Provider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath cekv := ae.LoadCEKV(encryptedCek) if !cekv.Verify(cert) { err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw))), nil) + return } decryptedKey, err = cekv.Decrypt(pk.(*rsa.PrivateKey)) @@ -161,7 +162,7 @@ func (p *Provider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath // version keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) if err != nil { - err = aecmk.NewError(aecmk.Encryption, fmt.Sprintf("Unable to serialize key path"), err) + err = aecmk.NewError(aecmk.Encryption, "Unable to serialize key path", err) return } k := uint16(len(keyPathBytes)) diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index 43378f67..eadd7330 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -200,7 +200,9 @@ func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnE ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Hour)) defer cancel() rows, err := conn.QueryContext(ctx, sel) - defer rows.Close() + if err != nil { + defer rows.Close() + } if assert.NoError(t, err, "Exec should return no error") { if rows.Next() { assert.Fail(t, "rows.Next should have failed") From 6e3d0a91d2a53c3ad6cae8e06c973b8641c5ca00 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 10 Oct 2023 16:06:19 -0500 Subject: [PATCH 5/6] fix test logger function --- tds_go110_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tds_go110_test.go b/tds_go110_test.go index 3b6bb716..76ecfc66 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -22,5 +22,5 @@ func getTestConnector(t testing.TB) (*Connector, *testLogger) { t.Error("Open connection failed:", err.Error()) return nil, &tl } - return connector, nil + return connector, &tl } From 3309cf9f468cb7ad8c74b9c9061a7c1ef1ecc664 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 10 Oct 2023 17:05:10 -0500 Subject: [PATCH 6/6] close Rows --- alwaysencrypted_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/alwaysencrypted_test.go b/alwaysencrypted_test.go index eadd7330..c78d638f 100644 --- a/alwaysencrypted_test.go +++ b/alwaysencrypted_test.go @@ -200,9 +200,8 @@ func testProviderErrorHandling(t *testing.T, name string, provider aecmk.ColumnE ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Hour)) defer cancel() rows, err := conn.QueryContext(ctx, sel) - if err != nil { - defer rows.Close() - } + defer rows.Close() + if assert.NoError(t, err, "Exec should return no error") { if rows.Next() { assert.Fail(t, "rows.Next should have failed")