From 514012a5b52921d2173dd2d5af4ba72abf1eb1d8 Mon Sep 17 00:00:00 2001 From: David Shiflet Date: Thu, 10 Aug 2023 10:36:58 -0500 Subject: [PATCH] Start Always Encrypted feature branch (#116) * add core CEK parameters and types * add column encryption featureext * Add parsing of always encrypted tokens * implement local cert key provider * use key providers for decrypt * implement EncryptColumnEncryptionKey for local cert * add cipher data to parameters * copy swisscom code locally * implement Encrypt * don't claim to support enclaves * update readme * fix Scan to use correct data types * make cert store provider go1.17+ * rename files for clarity * update dependencies and min Go version * update reviewdog * remove old SQL versions from PR build --- .github/workflows/pr-validation.yml | 4 +- .github/workflows/reviewdog.yml | 4 +- .pipelines/TestSql2017.yml | 2 +- README.md | 56 ++- aecmk/keyprovider.go | 112 ++++++ aecmk/localcert/keyprovider.go | 240 +++++++++++++ aecmk/localcert/keyprovider_darwin.go | 14 + .../keyprovider_go117_windows_test.go | 49 +++ aecmk/localcert/keyprovider_linux.go | 14 + aecmk/localcert/keyprovider_test.go | 18 + aecmk/localcert/keyprovider_windows.go | 57 +++ alwaysencrypted_windows_test.go | 220 ++++++++++++ appveyor.yml | 46 +-- bulkcopy.go | 4 + columnencryptionkey.go | 40 +++ encrypt.go | 292 +++++++++++++++ encrypt_test.go | 119 ++++++ go.mod | 6 +- go.sum | 2 - internal/certs/certs.go | 55 +++ internal/certs/certs_windows.go | 176 +++++++++ .../mssql-always-encrypted/LICENSE.txt | 20 ++ .../swisscom/mssql-always-encrypted/README.md | 5 + .../aead_aes_256_cbc_hmac_sha256.go | 120 +++++++ .../aead_aes_256_cbc_hmac_sha256_test.go | 37 ++ .../pkg/algorithms/algorithm.go | 6 + .../pkg/alwaysencrypted.go | 79 ++++ .../pkg/alwaysencrypted_test.go | 106 ++++++ .../pkg/crypto/aes_cbc_pkcs5.go | 69 ++++ .../pkg/crypto/utils.go | 12 + .../pkg/encryption/type.go | 37 ++ .../pkg/keys/aead_aes_256_cbc_hmac_256.go | 51 +++ .../mssql-always-encrypted/pkg/keys/key.go | 5 + .../mssql-always-encrypted/pkg/utils/utf16.go | 18 + .../test/always-encrypted.pem | 28 ++ .../test/always-encrypted_pub.pem | 19 + .../mssql-always-encrypted/test/cekv.key | Bin 0 -> 627 bytes .../test/column_value.enc | 2 + .../test/decrypted_key.key | Bin 0 -> 627 bytes msdsn/conn_str.go | 36 +- msdsn/conn_str_test.go | 12 +- mssql.go | 117 ++++-- mssql_go19.go | 22 +- quoter.go | 40 +++ rpc.go | 20 +- tds.go | 99 ++++- tds_test.go | 110 +++++- token.go | 339 ++++++++++++++++-- types.go | 96 +++-- 49 files changed, 2843 insertions(+), 192 deletions(-) create mode 100644 aecmk/keyprovider.go create mode 100644 aecmk/localcert/keyprovider.go create mode 100644 aecmk/localcert/keyprovider_darwin.go create mode 100644 aecmk/localcert/keyprovider_go117_windows_test.go create mode 100644 aecmk/localcert/keyprovider_linux.go create mode 100644 aecmk/localcert/keyprovider_test.go create mode 100644 aecmk/localcert/keyprovider_windows.go create mode 100644 alwaysencrypted_windows_test.go create mode 100644 columnencryptionkey.go create mode 100644 encrypt.go create mode 100644 encrypt_test.go create mode 100644 internal/certs/certs.go create mode 100644 internal/certs/certs_windows.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/README.md create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc create mode 100644 internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key create mode 100644 quoter.go diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 8c3a92cb..7963a660 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -10,8 +10,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ['1.16','1.17', '1.18'] - sqlImage: ['2017-latest','2019-latest'] + go: ['1.19','1.20'] + sqlImage: ['2019-latest','2022-latest'] steps: - uses: actions/checkout@v2 - name: Setup go diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 97b897a4..143ab8b1 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,9 +6,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: golangci-lint - uses: reviewdog/action-golangci-lint@v1 + uses: reviewdog/action-golangci-lint@v2 with: level: warning reporter: github-pr-review diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 046e3e98..9bdb5303 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -8,7 +8,7 @@ variables: steps: - task: GoTool@0 inputs: - version: '1.16.5' + version: '1.19' - task: Go@0 displayName: 'Go: get sources' inputs: diff --git a/README.md b/README.md index f2405ea4..b424dd87 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Install -Requires Go 1.10 or above. +Requires Go 1.16 or above. Install with `go install github.com/microsoft/go-mssqldb@latest`. @@ -63,6 +63,7 @@ Other supported formats are listed below. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. * `protocol` - forces use of a protocol. Make sure the corresponding package is imported. +* `columnencryption` or `column encryption setting` - a boolean value indicating whether Always Encrypted should be enabled on the connection. ### Connection parameters for namedpipe package * `pipe` - If set, no Browser query is made and named pipe used will be `\\\pipe\` @@ -377,8 +378,56 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na // Note: Mismatched data types on table and parameter may cause long running queries ``` +## Using Always Encrypted + +The protocol and cryptography details for AE are [detailed elsewhere](https://learn.microsoft.com/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver16). + +### Enablement + +To enable AE on a connection, set the `ColumnEncryption` value to true on a config or pass `columnencryption=true` in the connection string. + +Decryption and encryption won't succeed, however, without also including a decryption key provider. To avoid code size impacts on non-AE applications, key providers are not included by default. + +Include the local certificate providers: + +```go + import ( + "github.com/microsoft/go-mssqldb/aecmk/localcert" + ) + ``` + +You can also instantiate a key provider directly in code and hand it to a `Connector` instance. + +```go +c := mssql.NewConnectorConfig(myconfig) +c.RegisterCekProvider(providerName, MyProviderType{}) +``` + +### Decryption + +If the correct key provider is included in your application, decryption of encrypted cells happens automatically with no extra server round trips. + +### Encryption + +Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`. + +*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`. +https://github.com/microsoft/go-mssqldb/issues/129 +https://github.com/microsoft/go-mssqldb/issues/130 + + +### Local certificate AE key provider + +Key provider configuration is managed separately without any properties in the connection string. +The `pfx` provider exposes its instance as the variable `PfxKeyProvider`. You can give it passwords for certificates using `SetCertificatePassword(pathToCertificate, path)`. Use an empty string or `"*"` as the path to use the same password for all certificates. + +The `MSSQL_CERTIFICATE_STORE` provider exposes its instance as the variable `WindowsCertificateStoreKeyProvider`. + +Both providers can be constrained to an allowed list of encryption key paths by appending paths to `provider.AllowedLocations`. + ## Important Notes + * [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should not be used with this driver (or SQL Server) due to how the TDS protocol works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) @@ -409,7 +458,9 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na * A `namedpipe` package to support connections using named pipes (np:) on Windows * A `sharedmemory` package to support connections using shared memory (lpc:) on Windows * Dedicated Administrator Connection (DAC) is supported using `admin` protocol - +* Always Encrypted + - `MSSQL_CERTIFICATE_STORE` provider on Windows + - `pfx` provider on Linux and Windows ## Tests `go test` is used for testing. A running instance of MSSQL server is required. @@ -449,6 +500,7 @@ To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. More information: +* Bulk copy does not yet support encrypting column values using Always Encrypted. Tracked in [#127](https://github.com/microsoft/go-mssqldb/issues/127) # Contributing This project is a fork of [https://github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) and welcomes new and previous contributors. For more informaton on contributing to this project, please see [Contributing](./CONTRIBUTING.md). diff --git a/aecmk/keyprovider.go b/aecmk/keyprovider.go new file mode 100644 index 00000000..7cdcb82c --- /dev/null +++ b/aecmk/keyprovider.go @@ -0,0 +1,112 @@ +package aecmk + +import ( + "fmt" + "sync" + "time" +) + +const ( + CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" + CspKeyProvider = "MSSQL_CSP_PROVIDER" + CngKeyProvider = "MSSQL_CNG_STORE" + AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" + JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" + KeyEncryptionAlgorithm = "RSA_OAEP" +) + +// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache. +// The default is 2 hours +var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour + +type cekCacheEntry struct { + Expiry time.Time + Key []byte +} + +type cekCache map[string]cekCacheEntry + +type CekProvider struct { + Provider ColumnEncryptionKeyProvider + decryptedKeys cekCache + mutex sync.Mutex +} + +func NewCekProvider(provider ColumnEncryptionKeyProvider) *CekProvider { + return &CekProvider{Provider: provider, decryptedKeys: make(cekCache), mutex: sync.Mutex{}} +} + +func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) { + cp.mutex.Lock() + ev, cachedKey := cp.decryptedKeys[keyPath] + if cachedKey { + if ev.Expiry.Before(time.Now()) { + delete(cp.decryptedKeys, keyPath) + cachedKey = false + } else { + decryptedKey = ev.Key + } + } + // decrypting a key can take a while, so let multiple callers race + // Key providers can choose to optimize their own concurrency. + // For example - there's probably minimal value in serializing access to a local certificate, + // but there'd be high value in having a queue of waiters for decrypting a key stored in the cloud. + cp.mutex.Unlock() + if !cachedKey { + decryptedKey = cp.Provider.DecryptColumnEncryptionKey(keyPath, KeyEncryptionAlgorithm, encryptedBytes) + } + if !cachedKey { + duration := cp.Provider.KeyLifetime() + if duration == nil { + duration = &ColumnEncryptionKeyLifetime + } + expiry := time.Now().Add(*duration) + cp.mutex.Lock() + cp.decryptedKeys[keyPath] = cekCacheEntry{Expiry: expiry, Key: decryptedKey} + cp.mutex.Unlock() + } + return +} + +// no synchronization on this map. Providers register during init. +type ColumnEncryptionKeyProviderMap map[string]*CekProvider + +var globalCekProviderFactoryMap = ColumnEncryptionKeyProviderMap{} + +// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys. +// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider. +type ColumnEncryptionKeyProvider interface { + // DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. + // The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. + DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte + // EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. + EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte + // SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key + // referenced by the masterKeyPath parameter. The input values used to generate the signature should be the + // specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. + SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte + // VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key + // with the specified key path and the specified enclave behavior. Return nil if not supported. + VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool + // KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. + // If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. + // If it returns zero, the keys will not be cached. + KeyLifetime() *time.Duration +} + +func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error { + _, ok := globalCekProviderFactoryMap[name] + if ok { + return fmt.Errorf("CEK provider %s is already registered", name) + } + globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, decryptedKeys: cekCache{}, mutex: sync.Mutex{}} + return nil +} + +func GetGlobalCekProviders() (providers ColumnEncryptionKeyProviderMap) { + providers = make(ColumnEncryptionKeyProviderMap) + for i, p := range globalCekProviderFactoryMap { + providers[i] = p + } + return +} diff --git a/aecmk/localcert/keyprovider.go b/aecmk/localcert/keyprovider.go new file mode 100644 index 00000000..24c1bcae --- /dev/null +++ b/aecmk/localcert/keyprovider.go @@ -0,0 +1,240 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/aecmk" + ae "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg" + pkcs "golang.org/x/crypto/pkcs12" + "golang.org/x/text/encoding/unicode" +) + +const ( + PfxKeyProviderName = "pfx" + wildcard = "*" +) + +// Provider uses local certificates to decrypt CEKs +// It supports both 'MSSQL_CERTIFICATE_STORE' and 'pfx' key stores. +// MSSQL_CERTIFICATE_STORE key paths are of the form `storename/storepath/thumbprint` and only supported on Windows clients. +// pfx key paths are absolute file system paths that are operating system dependent. +type Provider struct { + // Name identifies which key store the provider supports. + name string + // AllowedLocations constrains which locations the provider will use to find certificates. If empty, all locations are allowed. + // When presented with a key store path not in the allowed list, the data will be returned still encrypted. + AllowedLocations []string + passwords map[string]string +} + +// SetCertificatePassword stores the password associated with the certificate at the given location. +// If location is empty the given password applies to all certificates that have not been explicitly assigned a value. +func (p Provider) SetCertificatePassword(location string, password string) { + if location == "" { + location = wildcard + } + p.passwords[location] = password +} + +var PfxKeyProvider = Provider{name: PfxKeyProviderName, passwords: make(map[string]string), AllowedLocations: make([]string, 0)} + +func init() { + err := aecmk.RegisterCekProvider("pfx", &PfxKeyProvider) + if err != nil { + panic(err) + } +} + +// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key. +// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm. +func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) { + decryptedKey = nil + pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) + if !allowed { + return + } + cekv := ae.LoadCEKV(encryptedCek) + if !cekv.Verify(cert) { + panic(fmt.Errorf("Invalid certificate provided for decryption. Key Store Path: %s. <%s>-<%v>", masterKeyPath, cekv.KeyPath, fmt.Sprintf("%02x", sha1.Sum(cert.Raw)))) + } + + decryptedKey, err := cekv.Decrypt(pk.(*rsa.PrivateKey)) + if err != nil { + panic(err) + } + return +} + +func (p *Provider) tryLoadCertificate(masterKeyPath string) (privateKey interface{}, cert *x509.Certificate, allowed bool) { + allowed = len(p.AllowedLocations) == 0 + if !allowed { + loop: + for _, l := range p.AllowedLocations { + if l == masterKeyPath { + allowed = true + break loop + } + } + } + if !allowed { + return + } + switch p.name { + case PfxKeyProviderName: + privateKey, cert = p.loadLocalCertificate(masterKeyPath) + case aecmk.CertificateStoreKeyProvider: + privateKey, cert = p.loadWindowsCertStoreCertificate(masterKeyPath) + } + return +} + +func (p *Provider) loadLocalCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + if f, err := os.Open(path); err == nil { + pfxBytes, err := ioutil.ReadAll(f) + if err != nil { + panic(invalidCertificatePath(path, err)) + } + pwd, ok := p.passwords[path] + if !ok { + pwd, ok = p.passwords[wildcard] + if !ok { + pwd = "" + } + } + privateKey, cert, err = pkcs.Decode(pfxBytes, pwd) + if err != nil { + panic(err) + } + } else { + panic(invalidCertificatePath(path, err)) + } + return +} + +// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm. +func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte { + + validateEncryptionAlgorithm(encryptionAlgorithm) + validateKeyPathLength(masterKeyPath) + pk, cert, allowed := p.tryLoadCertificate(masterKeyPath) + if !allowed { + panic(fmt.Errorf("Key path not allowed for use in column key encryption")) + } + publicKey := cert.PublicKey.(*rsa.PublicKey) + keySizeInBytes := publicKey.Size() + + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder() + // Start with version byte == 1 + buf := []byte{byte(1)} + // EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature + // version + keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath))) + if err != nil { + panic(fmt.Errorf("Unable to serialize key path %w", err)) + } + k := uint16(len(keyPathBytes)) + // keyPathLength + buf = append(buf, byte(k), byte(k>>8)) + + cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, cek, []byte{}) + if err != nil { + panic(fmt.Errorf("Unable to encrypt data %w", err)) + } + l := uint16(len(cipherText)) + // ciphertextLength + buf = append(buf, byte(l), byte(l>>8)) + // keypath + buf = append(buf, keyPathBytes...) + // ciphertext + buf = append(buf, cipherText...) + hash := sha256.Sum256(buf) + // signature is the signed hash of the current buf + sig, err := rsa.SignPKCS1v15(rand.Reader, pk.(*rsa.PrivateKey), crypto.SHA256, hash[:]) + if err != nil { + panic(err) + } + if len(sig) != keySizeInBytes { + panic("Signature length doesn't match certificate key size") + } + buf = append(buf, sig...) + return buf +} + +// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key +// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the +// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported. +func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte { + return nil +} + +// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key +// with the specified key path and the specified enclave behavior. Return nil if not supported. +func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool { + return nil +} + +// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires. +// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime. +// If it returns zero, the keys will not be cached. +func (p *Provider) KeyLifetime() *time.Duration { + return nil +} + +func validateEncryptionAlgorithm(encryptionAlgorithm string) { + if !strings.EqualFold(encryptionAlgorithm, "RSA_OAEP") { + panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm)) + } +} + +func validateKeyPathLength(keyPath string) { + if len(keyPath) > 32767 { + panic(fmt.Errorf("Key path is too long. %d", len(keyPath))) + } +} + +// InvalidCertificatePathError indicates the provided path could not be used to load a certificate +type InvalidCertificatePathError struct { + path string + innerErr error +} + +func (i *InvalidCertificatePathError) Error() string { + return fmt.Sprintf("Invalid certificate path: %s", i.path) +} + +func (i *InvalidCertificatePathError) Unwrap() error { + return i.innerErr +} + +func invalidCertificatePath(path string, err error) error { + return &InvalidCertificatePathError{path: path, innerErr: err} +} + +func thumbprintToByteArray(thumbprint string) []byte { + if len(thumbprint)%2 != 0 { + panic(fmt.Errorf("Thumbprint must have even length %s", thumbprint)) + } + bytes := make([]byte, len(thumbprint)/2) + for i := range bytes { + b, err := strconv.ParseInt(thumbprint[i*2:(i*2)+2], 16, 32) + if err != nil { + panic(err) + } + bytes[i] = byte(b) + } + return bytes +} diff --git a/aecmk/localcert/keyprovider_darwin.go b/aecmk/localcert/keyprovider_darwin.go new file mode 100644 index 00000000..c3a7564a --- /dev/null +++ b/aecmk/localcert/keyprovider_darwin.go @@ -0,0 +1,14 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "crypto/x509" + "fmt" +) + +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + panic(fmt.Errorf("Windows cert store not supported on this OS")) + return +} diff --git a/aecmk/localcert/keyprovider_go117_windows_test.go b/aecmk/localcert/keyprovider_go117_windows_test.go new file mode 100644 index 00000000..9212d4b8 --- /dev/null +++ b/aecmk/localcert/keyprovider_go117_windows_test.go @@ -0,0 +1,49 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "crypto/rsa" + "strings" + "testing" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/internal/certs" +) + +func TestLoadWindowsCertStoreCertificate(t *testing.T) { + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) + pk, cert := provider.loadWindowsCertStoreCertificate("CurrentUser/My/" + thumbprint) + switch z := pk.(type) { + case *rsa.PrivateKey: + + t.Logf("Got an rsa.PrivateKey with size %d", z.Size()) + default: + t.Fatalf("Unexpected private key type: %v", z) + } + if !strings.HasPrefix(cert.Subject.String(), `CN=gomssqltest-`) { + t.Fatalf("Wrong cert loaded: %s", cert.Subject.String()) + } +} + +func TestEncryptDecryptEncryptionKeyRoundTrip(t *testing.T) { + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + bytesToEncrypt := []byte{1, 2, 3} + keyPath := "CurrentUser/My/" + thumbprint + provider := aecmk.GetGlobalCekProviders()[aecmk.CertificateStoreKeyProvider].Provider.(*Provider) + encryptedBytes := provider.EncryptColumnEncryptionKey(keyPath, "RSA_OAEP", bytesToEncrypt) + decryptedBytes := provider.DecryptColumnEncryptionKey(keyPath, "RSA_OAEP", encryptedBytes) + if len(decryptedBytes) != 3 || decryptedBytes[0] != 1 || decryptedBytes[1] != 2 || decryptedBytes[2] != 3 { + t.Fatalf("Encrypt/Decrypt did not roundtrip. encryptedBytes:%v, decryptedBytes: %v", encryptedBytes, decryptedBytes) + } +} diff --git a/aecmk/localcert/keyprovider_linux.go b/aecmk/localcert/keyprovider_linux.go new file mode 100644 index 00000000..c3a7564a --- /dev/null +++ b/aecmk/localcert/keyprovider_linux.go @@ -0,0 +1,14 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "crypto/x509" + "fmt" +) + +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + panic(fmt.Errorf("Windows cert store not supported on this OS")) + return +} diff --git a/aecmk/localcert/keyprovider_test.go b/aecmk/localcert/keyprovider_test.go new file mode 100644 index 00000000..8b4237cf --- /dev/null +++ b/aecmk/localcert/keyprovider_test.go @@ -0,0 +1,18 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestThumbPrintToSignature(t *testing.T) { + thumbprint := "5e89a107f0ade0aed5f753ecc60378b1bbae3598" + signature := thumbprintToByteArray(thumbprint) + if !bytes.Equal(signature, []byte{0x5e, 0x89, 0xa1, 0x07, 0xf0, 0xad, 0xe0, 0xae, 0xd5, 0xf7, 0x53, 0xec, 0xc6, 0x03, 0x78, 0xb1, 0xbb, 0xae, 0x35, 0x98}) { + t.Fatalf("Incorrect signature bytes for %s. Got: %s", thumbprint, hex.Dump(signature)) + } +} diff --git a/aecmk/localcert/keyprovider_windows.go b/aecmk/localcert/keyprovider_windows.go new file mode 100644 index 00000000..25c7fa20 --- /dev/null +++ b/aecmk/localcert/keyprovider_windows.go @@ -0,0 +1,57 @@ +//go:build go1.17 +// +build go1.17 + +package localcert + +import ( + "crypto/x509" + "fmt" + "strings" + "unsafe" + + "github.com/microsoft/go-mssqldb/aecmk" + "github.com/microsoft/go-mssqldb/internal/certs" + "golang.org/x/sys/windows" +) + +var WindowsCertificateStoreKeyProvider = Provider{name: aecmk.CertificateStoreKeyProvider, passwords: make(map[string]string)} + +func init() { + err := aecmk.RegisterCekProvider(aecmk.CertificateStoreKeyProvider, &WindowsCertificateStoreKeyProvider) + if err != nil { + panic(err) + } +} + +func (p *Provider) loadWindowsCertStoreCertificate(path string) (privateKey interface{}, cert *x509.Certificate) { + privateKey = nil + cert = nil + pathParts := strings.Split(path, `/`) + if len(pathParts) != 3 { + panic(invalidCertificatePath(path, fmt.Errorf("key store path requires 3 segments"))) + } + + var storeId uint32 + switch strings.ToLower(pathParts[0]) { + case "localmachine": + storeId = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE + case "currentuser": + storeId = windows.CERT_SYSTEM_STORE_CURRENT_USER + default: + panic(invalidCertificatePath(path, fmt.Errorf("Unknown certificate store"))) + } + system, err := windows.UTF16PtrFromString(pathParts[1]) + if err != nil { + panic(err) + } + h, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, + windows.PKCS_7_ASN_ENCODING|windows.X509_ASN_ENCODING, + 0, + storeId, uintptr(unsafe.Pointer(system))) + if err != nil { + panic(err) + } + defer windows.CertCloseStore(h, 0) + signature := thumbprintToByteArray(pathParts[2]) + return certs.FindCertBySignatureHash(h, signature) +} diff --git a/alwaysencrypted_windows_test.go b/alwaysencrypted_windows_test.go new file mode 100644 index 00000000..20f95c73 --- /dev/null +++ b/alwaysencrypted_windows_test.go @@ -0,0 +1,220 @@ +//go:build go1.17 +// +build go1.17 + +package mssql + +import ( + "crypto/rand" + "database/sql" + "fmt" + "math/big" + "strings" + "testing" + "time" + + "github.com/golang-sql/civil" + "github.com/microsoft/go-mssqldb/aecmk/localcert" + "github.com/microsoft/go-mssqldb/internal/certs" +) + +// Define phrases for create table for each enryptable data type along with sample data for insertion and validation +type aeColumnInfo struct { + queryPhrase string + sqlDataType string + encType ColumnEncryptionType + sampleValue interface{} +} + +var encryptableColumns = []aeColumnInfo{ + {"int", "INT", ColumnEncryptionDeterministic, int32(1)}, + {"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")}, + {"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)}, + {"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)}, + {"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)}, + // We can't use fractional float/real values due to rounding errors in the round trip + {"real", "REAL", ColumnEncryptionDeterministic, float32(5)}, + {"float", "FLOAT", ColumnEncryptionRandomized, float64(6)}, + {"varbinary(10)", "VARBINARY", ColumnEncryptionDeterministic, []byte{1, 2, 3, 4}}, + // TODO: Varchar support requires proper selection of a collation and conversion + // {"varchar(10) COLLATE Latin1_General_BIN2", "VARCHAR", ColumnEncryptionRandomized, VarChar("varcharval")}, + {"nvarchar(30)", "NVARCHAR", ColumnEncryptionRandomized, "nvarcharval"}, + {"bit", "BIT", ColumnEncryptionDeterministic, true}, + {"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, time.Now()}, + {"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(time.Now())}, + {"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")}, + // TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that. + // {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}}, +} + +func TestAlwaysEncryptedE2E(t *testing.T) { + params := testConnParams(t) + if !params.ColumnEncryption { + t.Skip("Test is not running with column encryption enabled") + } + conn, _ := open(t) + defer conn.Close() + thumbprint, err := certs.ProvisionMasterKeyInCertStore() + if err != nil { + t.Fatal(err) + } + defer certs.DeleteMasterKeyCert(thumbprint) + certPath := fmt.Sprintf(`CurrentUser/My/%s`, thumbprint) + s := fmt.Sprintf(createColumnMasterKey, certPath, certPath) + if _, err := conn.Exec(s); err != nil { + t.Fatalf("Unable to create CMK: %s", err.Error()) + } + defer conn.Exec(fmt.Sprintf(dropColumnMasterKey, certPath)) + r, _ := rand.Int(rand.Reader, big.NewInt(1000)) + cekName := fmt.Sprintf("mssqlCek%d", r.Int64()) + tableName := fmt.Sprintf("mssqlAe%d", r.Int64()) + keyBytes := make([]byte, 32) + _, _ = rand.Read(keyBytes) + encryptedCek := localcert.WindowsCertificateStoreKeyProvider.EncryptColumnEncryptionKey(certPath, KeyEncryptionAlgorithm, keyBytes) + createCek := fmt.Sprintf(createColumnEncryptionKey, cekName, certPath, encryptedCek) + _, err = conn.Exec(createCek) + if err != nil { + t.Fatalf("Unable to create CEK: %s", err.Error()) + } + defer conn.Exec(fmt.Sprintf(dropColumnEncryptionKey, cekName)) + _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) + query := new(strings.Builder) + insert := new(strings.Builder) + sel := new(strings.Builder) + _, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName)) + _, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName)) + _, _ = sel.WriteString("select top(1) ") + insertArgs := make([]interface{}, len(encryptableColumns)+1) + for i, ec := range encryptableColumns { + encType := "RANDOMIZED" + null := "" + _, ok := ec.sampleValue.(sql.NullInt32) + if ok { + null = "NULL" + } + if ec.encType == ColumnEncryptionDeterministic { + encType = "DETERMINISTIC" + } + _, _ = query.WriteString(fmt.Sprintf(`col%d %s ENCRYPTED WITH (ENCRYPTION_TYPE = %s, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) %s, + `, i, ec.queryPhrase, encType, cekName, null)) + + insertArgs[i] = ec.sampleValue + insert.WriteString(fmt.Sprintf("@p%d,", i+1)) + sel.WriteString(fmt.Sprintf("col%d,", i)) + } + _, _ = query.WriteString("unencryptedcolumn nvarchar(100)") + _, _ = query.WriteString(")") + insertArgs[len(encryptableColumns)] = "unencryptedvalue" + insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1)) + sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName)) + _, err = conn.Exec(query.String()) + if err != nil { + t.Fatalf("Failed to create encrypted table %s", err.Error()) + } + defer conn.Exec("DROP TABLE IF EXISTS " + tableName) + _, err = conn.Exec(insert.String(), insertArgs...) + if err != nil { + t.Fatalf("Failed to insert row in encrypted table %s", err.Error()) + } + rows, err := conn.Query(sel.String()) + if err != nil { + t.Fatalf("Unable to query encrypted columns: %v", err.(Error).All) + } + if !rows.Next() { + rows.Close() + t.Fatalf("rows.Next returned false") + } + cols, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("rows.ColumnTypes failed %s", err.Error()) + } + for i := range encryptableColumns { + + if cols[i].DatabaseTypeName() != encryptableColumns[i].sqlDataType { + t.Fatalf("Got wrong type name for col%d. Expected: %s, Got:%s", i, encryptableColumns[i].sqlDataType, cols[i].DatabaseTypeName()) + } + } + + var unencryptedColumnValue string + scanValues := make([]interface{}, len(encryptableColumns)+1) + for v := range scanValues { + if v < len(encryptableColumns) { + scanValues[v] = new(interface{}) + } + } + scanValues[len(encryptableColumns)] = &unencryptedColumnValue + err = rows.Scan(scanValues...) + if err != nil { + rows.Close() + t.Fatalf("rows.Scan failed: %s", err.Error()) + } + + for i := range encryptableColumns { + var strVal string + var expectedStrVal string + if encryptableColumns[i].sampleValue == nil { + expectedStrVal = "NULL" + } else { + expectedStrVal = comparisonValueFromObject(encryptableColumns[i].sampleValue) + } + rawVal := scanValues[i].(*interface{}) + + if rawVal == nil { + strVal = "NULL" + } else { + strVal = comparisonValueFromObject(*rawVal) + } + if expectedStrVal != strVal { + t.Fatalf("Incorrect value for col%d. Expected:%s, Got:%s", i, expectedStrVal, strVal) + } + } + if unencryptedColumnValue != "unencryptedvalue" { + t.Fatalf("Got wrong value for unencrypted column: %s", unencryptedColumnValue) + } + rows.Close() + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err() has non-nil value: %s", err.Error()) + } +} + +func comparisonValueFromObject(object interface{}) string { + switch v := object.(type) { + case []byte: + { + return string(v) + } + case string: + return v + case time.Time: + return civil.DateTimeOf(v).String() + //return v.Format(time.RFC3339) + case fmt.Stringer: + return v.String() + case bool: + if v == true { + return "1" + } + return "0" + default: + return fmt.Sprintf("%v", v) + } +} + +const ( + createColumnMasterKey = `CREATE COLUMN MASTER KEY [%s] WITH (KEY_STORE_PROVIDER_NAME= 'MSSQL_CERTIFICATE_STORE', KEY_PATH='%s')` + dropColumnMasterKey = `DROP COLUMN MASTER KEY [%s]` + createColumnEncryptionKey = `CREATE COLUMN ENCRYPTION KEY [%s] WITH VALUES (COLUMN_MASTER_KEY = [%s], ALGORITHM = 'RSA_OAEP', ENCRYPTED_VALUE = 0x%x )` + dropColumnEncryptionKey = `DROP COLUMN ENCRYPTION KEY [%s]` + createEncryptedTable = `CREATE TABLE %s + (col1 int + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]), + col2 nchar(10) COLLATE Latin1_General_BIN2 + ENCRYPTED WITH (ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256', + COLUMN_ENCRYPTION_KEY = [%s]) + )` +) diff --git a/appveyor.yml b/appveyor.yml index dafa9729..fdeeedf3 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,52 +11,31 @@ environment: SQLUSER: sa SQLPASSWORD: Password12! DATABASE: test - GOVERSION: 113 + GOVERSION: 116 + COLUMNENCRYPTION: + APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 RACE: -race -cpu 4 TAGS: matrix: - - GOVERSION: 110 - SQLINSTANCE: SQL2017 - - GOVERSION: 111 - SQLINSTANCE: SQL2017 - - GOVERSION: 112 - SQLINSTANCE: SQL2017 - SQLINSTANCE: SQL2017 - - SQLINSTANCE: SQL2016 - - SQLINSTANCE: SQL2014 - - SQLINSTANCE: SQL2012SP1 - - SQLINSTANCE: SQL2008R2SP2 - - # Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 114 - SQLINSTANCE: SQL2019 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 115 - SQLINSTANCE: SQL2019 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 115 + - GOVERSION: 117 SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 116 - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 117 - SQLINSTANCE: SQL2017 - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118 + - GOVERSION: 118 SQLINSTANCE: SQL2017 + - GOVERSION: 120 + RACE: + SQLINSTANCE: SQL2019 + COLUMNENCRYPTION: 1 # Cover 32bit and named pipes protocol - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118-x86 + - GOVERSION: 119-x86 SQLINSTANCE: SQL2017 GOARCH: 386 RACE: PROTOCOL: np TAGS: -tags np # Cover SSPI and lpc protocol - - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - GOVERSION: 118 + - GOVERSION: 120 + RACE: SQLINSTANCE: SQL2019 PROTOCOL: lpc TAGS: -tags sm @@ -70,6 +49,7 @@ install: - go get -u github.com/golang-sql/civil - go get -u github.com/golang-sql/sqlexp - go get -u golang.org/x/crypto/md4 + - go get -u golang.org/x/text/encoding/unicode build_script: - go build diff --git a/bulkcopy.go b/bulkcopy.go index 97edb7be..15512a9e 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -250,6 +250,10 @@ func (b *Bulk) createColMetadata() []byte { buf.WriteByte(byte(tokenColMetadata)) // token binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count + // TODO: Write a valid CEK table if any parameters have cekTableEntry values + if b.cn.sess.alwaysEncrypted { + binary.Write(buf, binary.LittleEndian, uint16(0)) + } for i, col := range b.bulkColumns { if b.cn.sess.loginAck.TDSVersion >= verTDS72 { diff --git a/columnencryptionkey.go b/columnencryptionkey.go new file mode 100644 index 00000000..1dd51068 --- /dev/null +++ b/columnencryptionkey.go @@ -0,0 +1,40 @@ +package mssql + +const ( + CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE" + CspKeyProvider = "MSSQL_CSP_PROVIDER" + CngKeyProvider = "MSSQL_CNG_STORE" + AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT" + JavaKeyProvider = "MSSQL_JAVA_KEYSTORE" + KeyEncryptionAlgorithm = "RSA_OAEP" +) + +// cek ==> Column Encryption Key +// Every row of an encrypted table has an associated list of keys used to decrypt its columns +type cekTable struct { + entries []cekTableEntry +} + +type encryptionKeyInfo struct { + encryptedKey []byte + databaseID int + cekID int + cekVersion int + cekMdVersion []byte + keyPath string + keyStoreName string + algorithmName string +} + +type cekTableEntry struct { + databaseID int + keyId int + keyVersion int + mdVersion []byte + valueCount int + cekValues []encryptionKeyInfo +} + +func newCekTable(size uint16) cekTable { + return cekTable{entries: make([]cekTableEntry, size)} +} diff --git a/encrypt.go b/encrypt.go new file mode 100644 index 00000000..0e04837a --- /dev/null +++ b/encrypt.go @@ -0,0 +1,292 @@ +package mssql + +import ( + "context" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "strings" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" +) + +type ColumnEncryptionType int + +var ( + ColumnEncryptionPlainText ColumnEncryptionType = 0 + ColumnEncryptionDeterministic ColumnEncryptionType = 1 + ColumnEncryptionRandomized ColumnEncryptionType = 2 +) + +type cekData struct { + ordinal int + database_id int + id int + version int + metadataVersion []byte + encryptedValue []byte + cmkStoreName string + cmkPath string + algorithm string + //byEnclave bool + //cmkSignature string + decryptedValue []byte +} + +type parameterEncData struct { + ordinal int + name string + algorithm int + encType ColumnEncryptionType + cekOrdinal int + ruleVersion int +} + +type paramMapEntry struct { + cek *cekData + p *parameterEncData +} + +// when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters. +// This function stores the relevant encryption parameters in a copy of the args so they can be +// encrypted just before being sent to the server +func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) { + q := Stmt{c: s.c, + paramCount: s.paramCount, + query: "sp_describe_parameter_encryption", + skipEncryption: true, + } + oldouts := s.c.outs + s.c.clearOuts() + newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args) + if err != nil { + return + } + // TODO: Consider not using recursion. + rows, err := q.queryContext(ctx, newArgs) + if err != nil { + s.c.outs = oldouts + return + } + cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows) + rows.Close() + s.c.outs = oldouts + if err != nil { + return + } + if len(cekInfo) == 0 { + return args, nil + } + err = s.decryptCek(cekInfo) + if err != nil { + return + } + paramMap := make(map[string]paramMapEntry) + for _, p := range paramsInfo { + if p.encType == ColumnEncryptionPlainText { + paramMap[p.name] = paramMapEntry{nil, p} + } else { + paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p} + } + } + encryptedArgs = make([]namedValue, len(args)) + for i, a := range args { + encryptedArgs[i] = a + name := "" + if len(a.Name) > 0 { + name = "@" + a.Name + } else { + name = fmt.Sprintf("@p%d", a.Ordinal) + } + info := paramMap[name] + + if info.p.encType == ColumnEncryptionPlainText || a.Value == nil { + continue + } + + encryptedArgs[i].encrypt = getEncryptor(info) + } + return encryptedArgs, nil +} + +// returns the arguments to sp_describe_parameter_encryption +// sp_describe_parameter_encryption +// [ @tsql = ] N'Transact-SQL_batch' , +// [ @params = ] N'parameters' +// [ ;] +func (s *Stmt) prepareEncryptionQuery(isProc bool, q string, args []namedValue) (newArgs []namedValue, err error) { + newArgs = make([]namedValue, 2) + if isProc { + newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: buildStoredProcedureStatementForColumnEncryption(q, args)} + } else { + newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q} + } + params, err := s.buildParametersForColumnEncryption(args) + if err != nil { + return + } + newArgs[1] = namedValue{Name: "params", Ordinal: 1, Value: params} + return +} + +func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters string, err error) { + _, decls, err := s.makeRPCParams(args, false) + if err != nil { + return + } + parameters = strings.Join(decls, ", ") + return +} + +func (s *Stmt) decryptCek(cekInfo []*cekData) error { + for _, info := range cekInfo { + kp, ok := s.c.sess.aeSettings.keyProviders[info.cmkStoreName] + if !ok { + return fmt.Errorf("No provider found for key store %s", info.cmkStoreName) + } + dk, err := kp.GetDecryptedKey(info.cmkPath, info.encryptedValue) + if err != nil { + return err + } + info.decryptedValue = dk + } + return nil +} + +func getEncryptor(info paramMapEntry) valueEncryptor { + k := keys.NewAeadAes256CbcHmac256(info.cek.decryptedValue) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encryption.From(byte(info.p.encType)), byte(info.cek.version)) + // Metadata to append to an encrypted parameter. Doesn't include original typeinfo + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/619c43b6-9495-4a58-9e49-a4950db245b3 + // ParamCipherInfo = TYPE_INFO + // EncryptionAlgo (byte) + // [AlgoName] (b_varchar) unused, no custom algorithm + // EncryptionType (byte) + // DatabaseId (ulong) + // CekId (ulong) + // CekVersion (ulong) + // CekMDVersion (ulonglong) - really a byte array + // NormVersion (byte) + // algo+ enctype+ dbid+ keyid+ keyver+ normversion + metadataLen := 1 + 1 + 4 + 4 + 4 + 1 + metadataLen += len(info.cek.metadataVersion) + metadata := make([]byte, metadataLen) + offset := 0 + // AEAD_AES_256_CBC_HMAC_SHA256 + metadata[offset] = byte(info.p.algorithm) + offset++ + metadata[offset] = byte(info.p.encType) + offset++ + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.database_id)) + offset += 4 + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.id)) + offset += 4 + binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.version)) + offset += 4 + copy(metadata[offset:], info.cek.metadataVersion) + offset += len(info.cek.metadataVersion) + metadata[offset] = byte(info.p.ruleVersion) + return func(b []byte) ([]byte, []byte, error) { + encryptedData, err := alg.Encrypt(b) + if err != nil { + return nil, nil, err + } + return encryptedData, metadata, nil + } +} + +// Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040 +func buildStoredProcedureStatementForColumnEncryption(sproc string, args []namedValue) string { + b := new(strings.Builder) + _, _ = b.WriteString("EXEC ") + q := TSQLQuoter{} + sproc = q.ID(sproc) + + b.WriteString(sproc) + + // Unlike ADO.Net, go-mssqldb doesn't support ReturnValue named parameters + first := true + for _, a := range args { + if !first { + b.WriteRune(',') + } + first = false + b.WriteRune(' ') + name := a.Name + if len(name) == 0 { + name = fmt.Sprintf("@p%d", a.Ordinal) + } + appendPrefixedParameterName(b, name) + if len(a.Name) > 0 { + b.WriteRune('=') + appendPrefixedParameterName(b, a.Name) + } + if isOutputValue(a.Value) { + b.WriteString(" OUTPUT") + } + } + return b.String() +} + +func appendPrefixedParameterName(b *strings.Builder, p string) { + if len(p) > 0 { + if p[0] != '@' { + b.WriteRune('@') + } + b.WriteString(p) + } +} + +func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []*parameterEncData, err error) { + cekInfo = make([]*cekData, 0) + values := make([]driver.Value, 9) + qerr := rows.Next(values) + for qerr == nil { + cekInfo = append(cekInfo, &cekData{ordinal: int(values[0].(int64)), + database_id: int(values[1].(int64)), + id: int(values[2].(int64)), + version: int(values[3].(int64)), + metadataVersion: values[4].([]byte), + encryptedValue: values[5].([]byte), + cmkStoreName: values[6].(string), + cmkPath: values[7].(string), + algorithm: values[8].(string), + }) + qerr = rows.Next(values) + } + if len(cekInfo) == 0 || qerr != io.EOF { + if qerr != io.EOF { + err = qerr + } + // No encryption needed + return + } + r := rows.(driver.RowsNextResultSet) + err = r.NextResultSet() + if err != nil { + return + } + paramInfo = make([]*parameterEncData, 0) + qerr = rows.Next(values[:6]) + for qerr == nil { + paramInfo = append(paramInfo, ¶meterEncData{ordinal: int(values[0].(int64)), + name: values[1].(string), + algorithm: int(values[2].(int64)), + encType: ColumnEncryptionType(values[3].(int64)), + cekOrdinal: int(values[4].(int64)), + ruleVersion: int(values[5].(int64)), + }) + qerr = rows.Next(values[:6]) + } + if len(paramInfo) == 0 || qerr != io.EOF { + if qerr != io.EOF { + err = qerr + } else { + err = fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption") + } + } + return +} diff --git a/encrypt_test.go b/encrypt_test.go new file mode 100644 index 00000000..6e54a0a9 --- /dev/null +++ b/encrypt_test.go @@ -0,0 +1,119 @@ +package mssql + +import ( + "database/sql" + "strings" + "testing" +) + +func TestBuildQueryParametersForCE(t *testing.T) { + type test struct { + name string + args []namedValue + expectedParams string + expectedError string + } + var outparam string + var tests = []test{ + { + "Single string", + []namedValue{ + {Name: "c1", Value: "somestring"}, + }, + `@c1 nvarchar(10)`, + "", + }, + { + "Input and Output params", + []namedValue{ + {Name: "", Ordinal: 0, Value: VarChar("somestring")}, + {Name: "c1", Value: int64(5)}, + {Name: "pout", Value: sql.Out{Dest: outparam}}, + }, + `@p0 varchar(10), @c1 bigint, @pout nvarchar(max) output`, + "", + }, + } + s := &Stmt{} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual, err := s.buildParametersForColumnEncryption(tc.args) + if len(tc.expectedError) > 0 { + if err == nil || strings.Compare(err.Error(), tc.expectedError) != 0 { + t.Fatalf("buildParametersForColumnEncryption should have failed with %s. Got: %v", tc.expectedError, err) + } + } else if err != nil { + t.Fatalf("buildParametersForColumnEncryption failed with %s", err.Error()) + } + if strings.Compare(tc.expectedParams, actual) != 0 { + t.Fatalf("Incorrect parameters. Expected: %s. Got: %s ", tc.expectedParams, actual) + } + }) + } +} +func TestSprocQueryForCE(t *testing.T) { + type test struct { + name string + proc string + args []namedValue + expected string + } + var out int + tests := []test{ + { + "Empty args", + "m]yproc", + make([]namedValue, 0), + "EXEC [m]]yproc]", + }, + { + "No OUT args", + "myproc", + []namedValue{ + { + "p1", + 0, + 5, + nil, + }, + { + "@p2", + 0, + "val", + nil, + }, + }, + "EXEC [myproc] @p1=@p1, @p2=@p2", + }, + { + "OUT args", + "myproc", + []namedValue{ + { + "pout", + 0, + sql.Out{ + Dest: &out, + In: false, + }, + nil, + }, + { + "pin", + 1, + "in", + nil, + }, + }, + "EXEC [myproc] @pout=@pout OUTPUT, @pin=@pin", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + q := buildStoredProcedureStatementForColumnEncryption(tc.proc, tc.args) + if q != tc.expected { + t.Fatalf("Incorrect query for %s: %s", tc.name, q) + } + }) + } +} diff --git a/go.mod b/go.mod index fa06dc45..4c3ea17a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/microsoft/go-mssqldb -go 1.13 +go 1.16 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 @@ -8,6 +8,8 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 github.com/golang-sql/sqlexp v0.1.0 github.com/jcmturner/gokrb5/v8 v8.4.4 + github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.9.0 + golang.org/x/sys v0.8.0 + golang.org/x/text v0.9.0 ) - diff --git a/go.sum b/go.sum index f37d8343..69dce41d 100644 --- a/go.sum +++ b/go.sum @@ -107,8 +107,6 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= -gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/certs/certs.go b/internal/certs/certs.go new file mode 100644 index 00000000..8646266a --- /dev/null +++ b/internal/certs/certs.go @@ -0,0 +1,55 @@ +package certs + +import ( + "bytes" + "fmt" + "math/big" + "os/exec" + "strings" + + "crypto/rand" +) + +// TODO: Create a Linux equivalent. +const ( + createUserCertScript = `New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 -HashAlgorithm 'SHA256' | select {$_.Thumbprint}` + deleteUserCertScript = `Get-ChildItem Cert:\CurrentUser\My\%s | Remove-Item -DeleteKey` +) + +func ProvisionMasterKeyInCertStore() (thumbprint string, err error) { + x, _ := rand.Int(rand.Reader, big.NewInt(50000)) + subject := fmt.Sprintf(`gomssqltest-%d`, x) + + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(createUserCertScript, subject)) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + cmd.Stdout = buf + if err = cmd.Run(); err != nil { + err = fmt.Errorf("Unable to create cert for encryption: %v", err.Error()) + return + } + out := buf.buf.String() + thumbprint = strings.Trim(out[strings.LastIndex(out, "-")+1:], " \r\n") + return +} + +func DeleteMasterKeyCert(thumbprint string) error { + cmd := exec.Command(`C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`, `/ExecutionPolicy`, `Unrestricted`, fmt.Sprintf(deleteUserCertScript, thumbprint)) + if err := cmd.Run(); err != nil { + return fmt.Errorf("Unable to delete user cert %s. %s", thumbprint, err.Error()) + } + return nil +} + +type memoryBuffer struct { + buf *bytes.Buffer +} + +func (b *memoryBuffer) Write(p []byte) (n int, err error) { + return b.buf.Write(p) +} + +func (b *memoryBuffer) Close() error { + return nil +} + +// C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe /ExecutionPolicy Unrestricted New-SelfSignedCertificate -Subject "%s" -CertStoreLocation Cert:CurrentUser\My -KeyExportPolicy Exportable -Type DocumentEncryptionCert -KeyUsage KeyEncipherment -Keyspec KeyExchange -KeyLength 2048 | select {$_.Thumbprint} diff --git a/internal/certs/certs_windows.go b/internal/certs/certs_windows.go new file mode 100644 index 00000000..5577dd1a --- /dev/null +++ b/internal/certs/certs_windows.go @@ -0,0 +1,176 @@ +//go:build go1.17 +// +build go1.17 + +package certs + +import ( + "crypto/rsa" + "crypto/x509" + "errors" + "fmt" + "math/big" + + "unsafe" + + "golang.org/x/sys/windows" +) + +func FindCertBySignatureHash(storeHandle windows.Handle, hash []byte) (interface{}, *x509.Certificate) { + var certContext *windows.CertContext + var err error + cryptoAPIBlob := windows.CryptHashBlob{ + Size: uint32(len(hash)), + Data: &hash[0], + } + + certContext, err = windows.CertFindCertificateInStore( + storeHandle, + windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, + 0, + windows.CERT_FIND_HASH, + unsafe.Pointer(&cryptoAPIBlob), + nil) + + if err != nil { + + panic(fmt.Errorf("Unable to find certificate by signature hash. %s", err.Error())) + } + pk, cert, err := certContextToX509(certContext) + if err != nil { + panic(err) + } + + return pk, cert +} + +func certContextToX509(ctx *windows.CertContext) (pk interface{}, cert *x509.Certificate, err error) { + // To ensure we don't mess with the cert context's memory, use a copy of it. + src := (*[1 << 20]byte)(unsafe.Pointer(ctx.EncodedCert))[:ctx.Length:ctx.Length] + der := make([]byte, int(ctx.Length)) + copy(der, src) + + cert, err = x509.ParseCertificate(der) + if err != nil { + return + } + var kh windows.Handle + var keySpec uint32 + var freeProvOrKey bool + err = windows.CryptAcquireCertificatePrivateKey(ctx, windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, nil, &kh, &keySpec, &freeProvOrKey) + if err != nil { + return + } + + pkBytes, err := nCryptExportKey(kh, "RSAFULLPRIVATEBLOB") + if freeProvOrKey { + _, _, _ = procNCryptFreeObject.Call(uintptr(kh)) + } + if err != nil { + return + } + + pk, err = unmarshalRSA(pkBytes) + return +} + +var ( + nCrypt = windows.MustLoadDLL("ncrypt.dll") + procNCryptExportKey = nCrypt.MustFindProc("NCryptExportKey") + procNCryptFreeObject = nCrypt.MustFindProc("NCryptFreeObject") +) + +// wide returns a pointer to a uint16 representing the equivalent +// to a Windows LPCWSTR. +func wide(s string) *uint16 { + w, _ := windows.UTF16PtrFromString(s) + return w +} + +func nCryptExportKey(kh windows.Handle, blobType string) ([]byte, error) { + var size uint32 + // When obtaining the size of a public key, most parameters are not required + r, _, err := procNCryptExportKey.Call( + uintptr(kh), + 0, + uintptr(unsafe.Pointer(wide(blobType))), + 0, + 0, + 0, + uintptr(unsafe.Pointer(&size)), + 0) + if !errors.Is(err, windows.Errno(0)) { + return nil, fmt.Errorf("nCryptExportKey returned %w", err) + } + if r != 0 { + return nil, fmt.Errorf("NCryptExportKey returned 0x%X during size check", uint32(r)) + } + + // Place the exported key in buf now that we know the size required + buf := make([]byte, size) + r, _, err = procNCryptExportKey.Call( + uintptr(kh), + 0, + uintptr(unsafe.Pointer(wide(blobType))), + 0, + uintptr(unsafe.Pointer(&buf[0])), + uintptr(size), + uintptr(unsafe.Pointer(&size)), + 0) + if !errors.Is(err, windows.Errno(0)) { + return nil, fmt.Errorf("nCryptExportKey returned %w", err) + } + if r != 0 { + return nil, fmt.Errorf("NCryptExportKey returned 0x%X during export", uint32(r)) + } + return buf, nil +} + +type RSA_HEADER struct { + Magic uint32 + BitLength uint32 + PublicExpSize uint32 + ModulusSize uint32 + Prime1Size uint32 + Prime2Size uint32 +} + +func unmarshalRSA(buf []byte) (*rsa.PrivateKey, error) { + // BCRYPT_RSA_BLOB -- https://learn.microsoft.com/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob + cbHeader := uint32(unsafe.Sizeof(RSA_HEADER{})) + header := (*(*RSA_HEADER)(unsafe.Pointer(&buf[0]))) + buf = buf[cbHeader:] + if header.Magic != 0x33415352 { // "RSA3" BCRYPT_RSAFULLPRIVATE_MAGIC + return nil, fmt.Errorf("invalid header magic %x", header.Magic) + } + + if header.PublicExpSize > 8 { + return nil, fmt.Errorf("unsupported public exponent size (%d bits)", header.PublicExpSize*8) + } + + consumeBigInt := func(size uint32) *big.Int { + b := buf[:size] + buf = buf[size:] + return new(big.Int).SetBytes(b) + } + E := consumeBigInt(header.PublicExpSize) + N := consumeBigInt(header.ModulusSize) + P := consumeBigInt(header.Prime1Size) + Q := consumeBigInt(header.Prime2Size) + Dp := consumeBigInt(header.Prime1Size) + Dq := consumeBigInt(header.Prime2Size) + Qinv := consumeBigInt(header.Prime1Size) + D := consumeBigInt(header.ModulusSize) + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: N, + E: int(E.Int64()), + }, + D: D, + Primes: []*big.Int{P, Q}, + Precomputed: rsa.PrecomputedValues{Dp: Dp, + Dq: Dq, Qinv: Qinv, + }, + } + return pk, nil +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt b/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt new file mode 100644 index 00000000..3ece719c --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/LICENSE.txt @@ -0,0 +1,20 @@ +Copyright (c) 2021 Swisscom (Switzerland) Ltd + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + diff --git a/internal/github.com/swisscom/mssql-always-encrypted/README.md b/internal/github.com/swisscom/mssql-always-encrypted/README.md new file mode 100644 index 00000000..c40de310 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/README.md @@ -0,0 +1,5 @@ +# mssql-always-encrypted + +A library to interact with MSSQL's Always Encrypted features. +This library mostly handles the crpyto part to facilitate +the integration with [go-mssql](https://github.com/denisenkom/go-mssqldb). \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go new file mode 100644 index 00000000..d4267ab3 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256.go @@ -0,0 +1,120 @@ +package algorithms + +import ( + "crypto/rand" + "crypto/subtle" + "fmt" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" +) + +// https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-05 +// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-TDS/%5bMS-TDS%5d.pdf + +var _ Algorithm = &AeadAes256CbcHmac256Algorithm{} + +type AeadAes256CbcHmac256Algorithm struct { + algorithmVersion byte + deterministic bool + blockSizeBytes int + keySizeBytes int + minimumCipherTextLengthBytesNoAuthTag int + minimumCipherTextLengthBytesWithAuthTag int + cek keys.AeadAes256CbcHmac256 + version []byte + versionSize []byte +} + +func NewAeadAes256CbcHmac256Algorithm(key keys.AeadAes256CbcHmac256, encType encryption.Type, algorithmVersion byte) AeadAes256CbcHmac256Algorithm { + const keySizeBytes = 256 / 8 + const blockSizeBytes = 16 + const minimumCipherTextLengthBytesNoAuthTag = 1 + 2*blockSizeBytes + const minimumCipherTextLengthBytesWithAuthTag = minimumCipherTextLengthBytesNoAuthTag + keySizeBytes + + a := AeadAes256CbcHmac256Algorithm{ + algorithmVersion: algorithmVersion, + deterministic: encType.Deterministic, + blockSizeBytes: blockSizeBytes, + keySizeBytes: keySizeBytes, + cek: key, + minimumCipherTextLengthBytesNoAuthTag: minimumCipherTextLengthBytesNoAuthTag, + minimumCipherTextLengthBytesWithAuthTag: minimumCipherTextLengthBytesWithAuthTag, + version: []byte{0x01}, + versionSize: []byte{1}, + } + + a.version[0] = algorithmVersion + return a +} + +func (a *AeadAes256CbcHmac256Algorithm) Encrypt(cleartext []byte) ([]byte, error) { + buf := make([]byte, 0) + var iv []byte + if a.deterministic { + iv = crypto.Sha256Hmac(cleartext, a.cek.IvKey()) + if len(iv) > a.blockSizeBytes { + iv = iv[:a.blockSizeBytes] + } + } else { + iv = make([]byte, a.blockSizeBytes) + _, err := rand.Read(iv) + if err != nil { + panic(err) + } + } + buf = append(buf, a.algorithmVersion) + aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) + ciphertext := aescdbc.Encrypt(cleartext) + authTag := a.prepareAuthTag(iv, ciphertext) + buf = append(buf, authTag...) + buf = append(buf, iv...) + buf = append(buf, ciphertext...) + return buf, nil +} + +func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, error) { + // This algorithm always has the auth tag! + minimumCiphertextLength := a.minimumCipherTextLengthBytesWithAuthTag + + if len(ciphertext) < minimumCiphertextLength { + return nil, fmt.Errorf("invalid ciphertext length: at least %v bytes expected", minimumCiphertextLength) + } + + idx := 0 + if ciphertext[idx] != a.algorithmVersion { + return nil, fmt.Errorf("invalid algorithm version used: %v found but %v expected", ciphertext[idx], + a.algorithmVersion) + } + + idx++ + authTag := ciphertext[idx : idx+a.keySizeBytes] + idx += a.keySizeBytes + + iv := ciphertext[idx : idx+a.blockSizeBytes] + idx += len(iv) + + realCiphertext := ciphertext[idx:] + ourAuthTag := a.prepareAuthTag(iv, realCiphertext) + + // bytes.Compare is subject to timing attacks + if subtle.ConstantTimeCompare(ourAuthTag, authTag) != 1 { + return nil, fmt.Errorf("invalid auth tag") + } + + // decrypt + aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv) + cleartext := aescdbc.Decrypt(realCiphertext) + + return cleartext, nil +} + +func (a *AeadAes256CbcHmac256Algorithm) prepareAuthTag(iv []byte, ciphertext []byte) []byte { + var input = make([]byte, 0) + input = append(input, a.algorithmVersion) + input = append(input, iv...) + input = append(input, ciphertext...) + input = append(input, a.versionSize...) + return crypto.Sha256Hmac(input, a.cek.MacKey()) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go new file mode 100644 index 00000000..ad59b292 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/aead_aes_256_cbc_hmac_sha256_test.go @@ -0,0 +1,37 @@ +package algorithms_test + +import ( + "encoding/hex" + "testing" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" + "github.com/stretchr/testify/assert" +) + +func TestAeadAes256CbcHmac256Algorithm_Decrypt(t *testing.T) { + expectedResult, err := hex.DecodeString("3100320033003400350020002000200020002000") + if err != nil { + t.Fatal(err) + } + + cipherText, err := hex.DecodeString("0181c4b77e1c50583c5e83a20afd4c98ce5acb39a636f00247b3a4d78a8be319c840e6970541a66723583def227eb774b4234cff209443b0209b75309532b527bdf9b2dfb326b4428840532a20460d06d4") + if err != nil { + t.Fatal(err) + } + + rootKey, err := hex.DecodeString("0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead") + if err != nil { + t.Fatal(err) + } + + key := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) + + result, err := alg.Decrypt(cipherText) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, expectedResult, result) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go new file mode 100644 index 00000000..ea1ca6b8 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms/algorithm.go @@ -0,0 +1,6 @@ +package algorithms + +type Algorithm interface { + Encrypt([]byte) ([]byte, error) + Decrypt([]byte) ([]byte, error) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go new file mode 100644 index 00000000..64ca57f6 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted.go @@ -0,0 +1,79 @@ +package alwaysencrypted + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/binary" + "unicode/utf16" +) + +type CEKV struct { + Version int + KeyPath string + Ciphertext []byte + SignedHash []byte + DataToSign []byte + + Key []byte +} + +func (c *CEKV) Verify(cert *x509.Certificate) bool { + sha256Sum := sha256.Sum256(c.DataToSign) + err := rsa.VerifyPKCS1v15(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, sha256Sum[:], c.SignedHash) + + return err == nil +} + +func (c *CEKV) Decrypt(private *rsa.PrivateKey) ([]byte, error) { + decryptedData, decryptErr := rsa.DecryptOAEP(sha1.New(), rand.Reader, private, c.Ciphertext, nil) + if decryptErr != nil { + return nil, decryptErr + } + + return decryptedData, nil +} + +func LoadCEKV(bytes []byte) CEKV { + idx := 0 + version := int(bytes[idx]) + idx++ + + keyPathLengthBytes := bytes[idx : idx+2] + keyPathLength := binary.LittleEndian.Uint16(keyPathLengthBytes) + idx += 2 + + cipherTextLengthBytes := bytes[idx : idx+2] + cipherTextLength := binary.LittleEndian.Uint16(cipherTextLengthBytes) + idx += 2 + + keyPathBytes := bytes[idx : idx+int(keyPathLength)] + idx += int(keyPathLength) + + var uint16Bytes []uint16 + for i := range keyPathBytes { + if i%2 == 0 { + continue + } + uint16Value := binary.LittleEndian.Uint16([]byte{keyPathBytes[i-1], keyPathBytes[i]}) + uint16Bytes = append(uint16Bytes, uint16Value) + } + keyPath := string(utf16.Decode(uint16Bytes)) + + cipherText := bytes[idx : idx+int(cipherTextLength)] + idx += int(cipherTextLength) + + dataToSign := bytes[0:idx] + signedHash := bytes[idx:] + + return CEKV{ + Version: version, + KeyPath: keyPath, + DataToSign: dataToSign, + Ciphertext: cipherText, + SignedHash: signedHash, + } +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go new file mode 100644 index 00000000..11e4c237 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/alwaysencrypted_test.go @@ -0,0 +1,106 @@ +package alwaysencrypted + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" + "github.com/stretchr/testify/assert" + "golang.org/x/text/encoding/unicode" +) + +func TestLoadCEKV(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted_pub.pem") + assert.NoError(t, err) + + certBytes, err := ioutil.ReadAll(certFile) + assert.NoError(t, err) + pemB, _ := pem.Decode(certBytes) + cert, err := x509.ParseCertificate(pemB.Bytes) + assert.NoError(t, err) + + cekvFile, err := os.Open("../test/cekv.key") + assert.NoError(t, err) + cekvBytes, err := ioutil.ReadAll(cekvFile) + assert.NoError(t, err) + cekv := LoadCEKV(cekvBytes) + assert.Equal(t, 1, cekv.Version) + assert.True(t, cekv.Verify(cert)) +} +func TestDecrypt(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted.pem") + assert.NoError(t, err) + + certBytes, err := ioutil.ReadAll(certFile) + assert.NoError(t, err) + pemB, _ := pem.Decode(certBytes) + privKey, err := x509.ParsePKCS8PrivateKey(pemB.Bytes) + assert.NoError(t, err) + + rsaPrivKey := privKey.(*rsa.PrivateKey) + + cekvFile, err := os.Open("../test/cekv.key") + assert.NoError(t, err) + cekvBytes, err := ioutil.ReadAll(cekvFile) + assert.NoError(t, err) + cekv := LoadCEKV(cekvBytes) + rootKey, err := cekv.Decrypt(rsaPrivKey) + assert.NoError(t, err) + assert.Equal(t, "0ff9e45335df3dec7be0649f741e6ea870e9d49d16fe4be7437ce22489f48ead", fmt.Sprintf("%02x", rootKey)) + assert.Equal(t, 1, cekv.Version) + assert.NotNil(t, rootKey) + + columnBytesFile, err := os.Open("../test/column_value.enc") + assert.NoError(t, err) + + columnBytes, err := ioutil.ReadAll(columnBytesFile) + assert.NoError(t, err) + + key := keys.NewAeadAes256CbcHmac256(rootKey) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(key, encryption.Deterministic, 1) + cleartext, err := alg.Decrypt(columnBytes) + assert.NoErrorf(t, err, "Decrypt failed! %v", err) + + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + decoder := enc.NewDecoder() + cleartextUtf8, err := decoder.Bytes(cleartext) + assert.NoError(t, err) + t.Logf("column value: \"%02X\"", cleartextUtf8) + assert.Equal(t, "12345 ", string(cleartextUtf8)) +} +func TestDecryptCEK(t *testing.T) { + certFile, err := os.Open("../test/always-encrypted.pem") + assert.NoError(t, err) + + certFileBytes, err := ioutil.ReadAll(certFile) + assert.NoError(t, err) + + pemBlock, _ := pem.Decode(certFileBytes) + cert, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) + assert.NoError(t, err) + + cekvFile, err := os.Open("../test/cekv.key") + assert.NoError(t, err) + + cekvBytes, err := ioutil.ReadAll(cekvFile) + assert.NoError(t, err) + + cekv := LoadCEKV(cekvBytes) + t.Logf("Cert: %v\n", cert) + + rsaKey := cert.(*rsa.PrivateKey) + + // RSA/ECB/OAEPWithSHA-1AndMGF1Padding + bytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, rsaKey, cekv.Ciphertext, nil) + assert.NoError(t, err) + t.Logf("Key: %02x\n", bytes) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go new file mode 100644 index 00000000..4ea2e5be --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/aes_cbc_pkcs5.go @@ -0,0 +1,69 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" +) + +// Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3 + +type AESCbcPKCS5 struct { + key []byte + iv []byte + block cipher.Block +} + +func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 { + a := AESCbcPKCS5{ + key: key, + iv: iv, + block: nil, + } + a.initCipher() + return a +} + +func (a AESCbcPKCS5) Encrypt(cleartext []byte) (cipherText []byte) { + if a.block == nil { + a.initCipher() + } + + blockMode := cipher.NewCBCEncrypter(a.block, a.iv) + paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize()) + cipherText = make([]byte, len(paddedCleartext)) + blockMode.CryptBlocks(cipherText, paddedCleartext) + return +} + +func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte { + if a.block == nil { + a.initCipher() + } + + blockMode := cipher.NewCBCDecrypter(a.block, a.iv) + var cleartext = make([]byte, len(ciphertext)) + blockMode.CryptBlocks(cleartext, ciphertext) + return PKCS5Trim(cleartext) +} + +func PKCS5Padding(inArr []byte, blockSize int) []byte { + padding := blockSize - len(inArr)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(inArr, padText...) +} + +func PKCS5Trim(inArr []byte) []byte { + padding := inArr[len(inArr)-1] + return inArr[:len(inArr)-int(padding)] +} + +func (a *AESCbcPKCS5) initCipher() { + block, err := aes.NewCipher(a.key) + if err != nil { + panic(fmt.Errorf("unable to create cipher: %v", err)) + } + + a.block = block +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go new file mode 100644 index 00000000..b8f9319f --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto/utils.go @@ -0,0 +1,12 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/sha256" +) + +func Sha256Hmac(input []byte, key []byte) []byte { + sha256Hmac := hmac.New(sha256.New, key) + sha256Hmac.Write(input) + return sha256Hmac.Sum(nil) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go new file mode 100644 index 00000000..a46dc3d7 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption/type.go @@ -0,0 +1,37 @@ +package encryption + +type Type struct { + Deterministic bool + Name string + Value byte +} + +var Plaintext = Type{ + Deterministic: false, + Name: "Plaintext", + Value: 0, +} + +var Deterministic = Type{ + Deterministic: true, + Name: "Deterministic", + Value: 1, +} + +var Randomized = Type{ + Deterministic: false, + Name: "Randomized", + Value: 2, +} + +func From(encType byte) Type { + switch encType { + case 0: + return Plaintext + case 1: + return Deterministic + case 2: + return Randomized + } + return Plaintext +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go new file mode 100644 index 00000000..4c1dba15 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/aead_aes_256_cbc_hmac_256.go @@ -0,0 +1,51 @@ +package keys + +import ( + "fmt" + + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils" +) + +var _ Key = &AeadAes256CbcHmac256{} + +type AeadAes256CbcHmac256 struct { + rootKey []byte + encryptionKey []byte + macKey []byte + ivKey []byte +} + +func NewAeadAes256CbcHmac256(rootKey []byte) AeadAes256CbcHmac256 { + const keySize = 256 + const encryptionKeySaltFormat = "Microsoft SQL Server cell encryption key with encryption algorithm:%v and key length:%v" + const macKeySaltFormat = "Microsoft SQL Server cell MAC key with encryption algorithm:%v and key length:%v" + const ivKeySaltFormat = "Microsoft SQL Server cell IV key with encryption algorithm:%v and key length:%v" + const algorithmName = "AEAD_AES_256_CBC_HMAC_SHA256" + + encryptionKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(encryptionKeySaltFormat, algorithmName, keySize)) + macKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(macKeySaltFormat, algorithmName, keySize)) + ivKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(ivKeySaltFormat, algorithmName, keySize)) + + return AeadAes256CbcHmac256{ + rootKey: rootKey, + encryptionKey: crypto.Sha256Hmac(encryptionKeySalt, rootKey), + macKey: crypto.Sha256Hmac(macKeySalt, rootKey), + ivKey: crypto.Sha256Hmac(ivKeySalt, rootKey)} +} + +func (a AeadAes256CbcHmac256) IvKey() []byte { + return a.ivKey +} + +func (a AeadAes256CbcHmac256) MacKey() []byte { + return a.macKey +} + +func (a AeadAes256CbcHmac256) EncryptionKey() []byte { + return a.encryptionKey +} + +func (a AeadAes256CbcHmac256) RootKey() []byte { + return a.rootKey +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go new file mode 100644 index 00000000..f778e902 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys/key.go @@ -0,0 +1,5 @@ +package keys + +type Key interface { + RootKey() []byte +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go new file mode 100644 index 00000000..4eb13390 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils/utf16.go @@ -0,0 +1,18 @@ +package utils + +import ( + "encoding/binary" + "unicode/utf16" +) + +func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte { + b := make([]byte, 2*len(u)) + for index, value := range u { + binary.LittleEndian.PutUint16(b[index*2:], value) + } + return b +} + +func ProcessUTF16LE(inputString string) []byte { + return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString))) +} diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem new file mode 100644 index 00000000..382ab002 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDFkCQfKsTM0UMw +pRA4t/kWYWTFUMJyEu5i4Zhw4yL0emFUcDdLqmwqrgR2eC1RFb9CU4UnLowVoq6H +vXlYuHnlCQio1QwLio80WoHBezdU/TQBBbpm5D2FxOLhzen2Puby+PQd+ZIBVuQu +I920g5O0wke86RmWmNdu0jaftwNqoIqc/TAqjYNKB2/CwnPnHwsCHJjIhCoSGlCa +WsQZSptSeqLQ87eaVfJypJpxG5FJ+bOXjFdgpXY3XOQoeR+xsXs2AKZ+eKOaSmb+ +Hg+pvMGBCXuSwBIAwPUxlCQSe2dfcXTkF+stadfH6EvVyIvK0G8RZ9N0Ow5vyRaU +95Bxc+2BAgMBAAECggEBAKJmz9qy/J3lc5ccSQ5m5SJpoz20GnNNbproGbjKbiSM +KVARAtN3X31iGRcNySq7dsJeB7niwJLUbSX2MjclRkZpO64Vm9Ys63U85ScYU67Q +iZxBii4kdxJse5jk/OtIX+7hiULOsh/Zvq7TGt/VvWi8v93hvAAY2hcmRHLcLbnK +li9DLnN3dIJoFh3y2OHlFfvFcX04wNmyfv04/FZKliGwrONkTN1YvEclU3XSjdrH +JM2977u+rB216Y1jiIObFceKj573hBAwS+gU2kx7g9Fpq9SvwszxmHMWtJQvJxg+ +7ClBeB8aSu1wSydm/0hfmwFNBH9c4BDVo3P1+K37PQUCgYEA8Lnceo9S4NOog5ri +taSVUqoHjruRU2tqFFi1wni+dw0m99kd5h8p9K0aXwvvjP8cmpK/ultSVZb9NzEz +zA5ZXXxT83QZOmq4FJCl31tjhcA/oidD139dCpe3RQ08ToClJgOuG8obS0hgy9Xt +sa16HgYP4aDerEgXR2fg3TWW1icCgYEA0hkt2FXFTh8L9z3nb/a8TNGBgVlafxcV +d4m1HhDoJ+GF8yscvUq7kn4xG2BHA5GNnUn0hIfrci/A0CXNGVOeUufgOUBKw39V +5Wq26ryElDcQ7CyJ36yH8/zQ4jgUOVo+R+jSO0+L4H1T/vP9F1ARtORb0/Ga5JFq +pxh6Q5VB0BcCgYEAh/2Hd1lGSapolUhHcLP0g0l4kYKWu5h/ydS/gYgymRC+BeAK +yvip/AZaUn1sq6tm3k+urjluztlIXQiXqVwl0fEtf+gDZIPrT/rTKdX36BROHm2u +HqxdxGEm8IRkoDh+k3YawqovNx1BSYWmDOzigtmL2TvG726ecAFX/7+JYZsCgYAf +kHTYyZoI8JUlogFBSvpjOB6Sxk/YRCmPefrh93xJcZJkRBffQHkJuze5ey9wE9AI +z3GS77CpyQ7YtrUnlu50Wi3PrB8PW/QVsYClp4jrk5JRSSe1mQAb4eGn+vDe5PXy +a8IZ8wt6wJl79kAR3o+qc5xwLR4uNMKnNAA6YxQuJQKBgQCIjo++s0i1pxf60CaL +2Mph/sDztdv0nZMPZzN0j2HGGJ21tKi3O+V+VoHHIs2YYjTsFu5Iwc7LONiGN+SF +38ojT7uWyY4Jz+9Sr4uYTJvWLc9G4BCkco3RNowLK8tb6TfewajWXeAzlz/Eafmj +nlUFODdXG+URQ5tpDjdCd6zbpQ== +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem new file mode 100644 index 00000000..b0b4a9e5 --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/always-encrypted_pub.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDKjCCAhKgAwIBAgIQRlupjX13FaVC/c36tbVQxzANBgkqhkiG9w0BAQsFADAn +MSUwIwYDVQQDDBxBbHdheXMgRW5jcnlwdGVkIENlcnRpZmljYXRlMB4XDTIxMDEy +NjE1MDgyMloXDTIyMDEyNjE1MDgyMlowJzElMCMGA1UEAwwcQWx3YXlzIEVuY3J5 +cHRlZCBDZXJ0aWZpY2F0ZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMWQJB8qxMzRQzClEDi3+RZhZMVQwnIS7mLhmHDjIvR6YVRwN0uqbCquBHZ4LVEV +v0JThScujBWiroe9eVi4eeUJCKjVDAuKjzRagcF7N1T9NAEFumbkPYXE4uHN6fY+ +5vL49B35kgFW5C4j3bSDk7TCR7zpGZaY127SNp+3A2qgipz9MCqNg0oHb8LCc+cf +CwIcmMiEKhIaUJpaxBlKm1J6otDzt5pV8nKkmnEbkUn5s5eMV2Cldjdc5Ch5H7Gx +ezYApn54o5pKZv4eD6m8wYEJe5LAEgDA9TGUJBJ7Z19xdOQX6y1p18foS9XIi8rQ +bxFn03Q7Dm/JFpT3kHFz7YECAwEAAaNSMFAwHwYDVR0lBBgwFgYIKwYBBQUIAgIG +CisGAQQBgjcKAwswHQYDVR0OBBYEFNQfS2liOJPsJuonIc0KPF4+CtFIMA4GA1Ud +DwEB/wQEAwIFIDANBgkqhkiG9w0BAQsFAAOCAQEAKMzuAfIv6uGxgx+SGgjDqk2O +oVdRul5xB/QlChdhzTrMwpIdul0+eLo46gqPdj/5kxWhQGNMuns+5/QrSfbaqAUz +ZWFsNAm+bhTBsgy9VSor3QUGedfQV3fP/8aZ/nvgLUe7PegmFBIiSALyjvCdayb5 +UZIxcBGQTmmpqGmL0hnRQwE2JvneOGEAiIIOTObCzgWyKhKuF2DWxinBtzyRlXfD +TV15+7v5kAdrjLevk57NOEshr0IDirD9auI61bqoxJZFyDqkdLZWED69pbCF8Ly5 +zbC8uUnDh3enxgmnUPXU/JZM1dbiPHZBxkUjVOoMYxycr0YgROJk7w5cfjrMYQ== +-----END CERTIFICATE----- diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key b/internal/github.com/swisscom/mssql-always-encrypted/test/cekv.key new file mode 100644 index 0000000000000000000000000000000000000000..d26e9f9eedbdf4a3744f65c4a6b3939ec3ee01ed GIT binary patch literal 627 zcmZR~V_;xRW+-JS0>V^=Jcbe=yBNqSV$f&EWvB#_1`J6+z9oY>g9VV42qX=G{8WY% z1~VYb00<2kOc~OEe2~03Lo!fJDv)mq7BvK`G69Pk1Le&aJkq~zPu_lh`)s*?tmnV( z`Lc)C=4`A(#4&mQ!=In^O@4H$Xr|NkeHLM?&E?Y!g(?F+9lXzH1^G^rhY|n{(pvH2<*LH(I-wN9^{SckSwTSS7yWB+v8MhN$sT|Z}#xRJko zYQ`Q*(e7QF136p{mLD&ebI!bN+r-C@4SasC5pJ9nDsXUd>BBjbq<_8PTbktM#xs9z z(2bW_dWG_~1+l!^=7pR*`>r(!seJdb6LjJCjI>o@T9LMX7vsjADam)1w9mS~=!3!A z^JPL83pj)sy(NWxo}c(@){s1FHPc${3o9iqUfIx?{;bA%&AGm{ZejT!Y-a8@y!*4r zr~HNE&ThsBub5pDPVm$_N~gY0mbOTKU~6;bP~x*Arb&rs!?rEZ=sa6u7yV_mE4!=O zxuuQ!uUvAtW2AmXe`=F^(1a7hAJX_V&+<>PS!nLM`piMbYp(*BCLUCtf2i(%WnlMx JzsBF&WdU&e5_tds literal 0 HcmV?d00001 diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc b/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc new file mode 100644 index 00000000..b3243a4c --- /dev/null +++ b/internal/github.com/swisscom/mssql-always-encrypted/test/column_value.enc @@ -0,0 +1,2 @@ +伳穨PX<^儮 +齃樜Z󈖬G长讑嬨菮鏃A#X="~穞#L 擟 泆02'靳策&碆園S* F  \ No newline at end of file diff --git a/internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key b/internal/github.com/swisscom/mssql-always-encrypted/test/decrypted_key.key new file mode 100644 index 0000000000000000000000000000000000000000..d26e9f9eedbdf4a3744f65c4a6b3939ec3ee01ed GIT binary patch literal 627 zcmZR~V_;xRW+-JS0>V^=Jcbe=yBNqSV$f&EWvB#_1`J6+z9oY>g9VV42qX=G{8WY% z1~VYb00<2kOc~OEe2~03Lo!fJDv)mq7BvK`G69Pk1Le&aJkq~zPu_lh`)s*?tmnV( z`Lc)C=4`A(#4&mQ!=In^O@4H$Xr|NkeHLM?&E?Y!g(?F+9lXzH1^G^rhY|n{(pvH2<*LH(I-wN9^{SckSwTSS7yWB+v8MhN$sT|Z}#xRJko zYQ`Q*(e7QF136p{mLD&ebI!bN+r-C@4SasC5pJ9nDsXUd>BBjbq<_8PTbktM#xs9z z(2bW_dWG_~1+l!^=7pR*`>r(!seJdb6LjJCjI>o@T9LMX7vsjADam)1w9mS~=!3!A z^JPL83pj)sy(NWxo}c(@){s1FHPc${3o9iqUfIx?{;bA%&AGm{ZejT!Y-a8@y!*4r zr~HNE&ThsBub5pDPVm$_N~gY0mbOTKU~6;bP~x*Arb&rs!?rEZ=sa6u7yV_mE4!=O zxuuQ!uUvAtW2AmXe`=F^(1a7hAJX_V&+<>PS!nLM`piMbYp(*BCLUCtf2i(%WnlMx JzsBF&WdU&e5_tds literal 0 HcmV?d00001 diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 4f71453d..ee186e0d 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -88,6 +88,8 @@ type Config struct { ProtocolParameters map[string]interface{} // BrowserMsg is the message identifier to fetch instance data from SQL browser BrowserMessage BrowserMsg + //ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values + ColumnEncryption bool } // Build a tls.Config object from the supplied certificate. @@ -371,6 +373,19 @@ func Parse(dsn string) (Config, error) { return p, err } + if c, ok := params["columnencryption"]; ok { + columnEncryption, err := strconv.ParseBool(c) + if err != nil { + if strings.EqualFold(c, "Enabled") { + columnEncryption = true + } else if strings.EqualFold(c, "Disabled") { + columnEncryption = false + } else { + return p, fmt.Errorf("invalid columnencryption '%v' : %v", columnEncryption, err.Error()) + } + } + p.ColumnEncryption = columnEncryption + } return p, nil } @@ -421,6 +436,9 @@ func (p Config) URL() *url.URL { res.Path = p.Instance } q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64)) + if p.ColumnEncryption { + q.Add("columnencryption", "true") + } if len(q) > 0 { res.RawQuery = q.Encode() } @@ -428,15 +446,17 @@ func (p Config) URL() *url.URL { return &res } +// ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs var adoSynonyms = map[string]string{ - "application name": "app name", - "data source": "server", - "address": "server", - "network address": "server", - "addr": "server", - "user": "user id", - "uid": "user id", - "initial catalog": "database", + "application name": "app name", + "data source": "server", + "address": "server", + "network address": "server", + "addr": "server", + "user": "user id", + "uid": "user id", + "initial catalog": "database", + "column encryption setting": "columnencryption", } func splitConnectionString(dsn string) (res map[string]string) { diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 5fa1a0ed..1e001385 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -54,7 +54,7 @@ func TestValidConnectionString(t *testing.T) { {"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool { return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd" }}, - {"server=.", func(p Config) bool { return p.Host == "localhost" }}, + {"server=.", func(p Config) bool { return p.Host == "localhost" && !p.ColumnEncryption }}, {"server=(local)", func(p Config) bool { return p.Host == "localhost" }}, {"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }}, {"failoverpartner=fopartner;failoverport=2000", func(p Config) bool { return p.FailOverPartner == "fopartner" && p.FailOverPort == 2000 }}, @@ -68,8 +68,8 @@ func TestValidConnectionString(t *testing.T) { {"encrypt=false;tlsmin=1.0", func(p Config) bool { return p.Encryption == EncryptionOff && p.TLSConfig.MinVersion == tls.VersionTLS10 }}, - {"encrypt=true;tlsmin=1.1", func(p Config) bool { - return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 + {"encrypt=true;tlsmin=1.1;column encryption setting=enabled", func(p Config) bool { + return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption }}, {"encrypt=true;tlsmin=1.2", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12 @@ -174,10 +174,10 @@ func TestValidConnectionString(t *testing.T) { return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry }}, {"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool { - return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry + return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry && !p.ColumnEncryption }}, - {"sqlserver://somehost?encrypt=true&tlsmin=1.1", func(p Config) bool { - return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 + {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool { + return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption }}, } for _, ts := range connStrings { diff --git a/mssql.go b/mssql.go index 23a271d2..8870410d 100644 --- a/mssql.go +++ b/mssql.go @@ -17,6 +17,7 @@ import ( "unicode" "github.com/golang-sql/sqlexp" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/internal/querytext" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -69,10 +70,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) { return nil, err } - return &Connector{ - params: params, - driver: d, - }, nil + return newConnector(params, d), nil } func (d *Driver) Open(dsn string) (driver.Conn, error) { @@ -122,10 +120,8 @@ func NewConnector(dsn string) (*Connector, error) { if err != nil { return nil, err } - c := &Connector{ - params: params, - driver: driverInstanceNoProcess, - } + c := newConnector(params, driverInstanceNoProcess) + return c, nil } @@ -146,9 +142,14 @@ func NewConnectorWithAccessTokenProvider(dsn string, tokenProvider func(ctx cont // NewConnectorConfig creates a new Connector for a DSN Config struct. // The returned connector may be used with sql.OpenDB. func NewConnectorConfig(config msdsn.Config) *Connector { + return newConnector(config, driverInstanceNoProcess) +} + +func newConnector(config msdsn.Config, driver *Driver) *Connector { return &Connector{ - params: config, - driver: driverInstanceNoProcess, + params: config, + driver: driver, + keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap), } } @@ -199,6 +200,8 @@ type Connector struct { // // If Dialer is not set, normal net dialers are used. Dialer Dialer + + keyProviders aecmk.ColumnEncryptionKeyProviderMap } type Dialer interface { @@ -219,6 +222,11 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer { return createDialer(p) } +// RegisterCekProvider associates the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten +func (c *Connector) RegisterCekProvider(name string, provider aecmk.ColumnEncryptionKeyProvider) { + c.keyProviders[name] = aecmk.NewCekProvider(provider) +} + type Conn struct { connector *Connector sess *tdsSession @@ -403,7 +411,7 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { if err != nil { return nil, err } - c := &Connector{params: params} + c := newConnector(params, nil) return d.connect(ctx, c, params) } @@ -445,10 +453,11 @@ func (c *Conn) Close() error { } type Stmt struct { - c *Conn - query string - paramCount int - notifSub *queryNotifSub + c *Conn + query string + paramCount int + notifSub *queryNotifSub + skipEncryption bool } type queryNotifSub struct { @@ -472,7 +481,7 @@ func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) if c.processQueryText { query, paramCount = querytext.ParseParams(query) } - return &Stmt{c, query, paramCount, nil}, nil + return &Stmt{c, query, paramCount, nil, false}, nil } func (s *Stmt) Close() error { @@ -654,16 +663,38 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, if isOutputValue(val.Value) { output = outputSuffix } - decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output) + tiDecl := params[i+offset].ti + if val.encrypt != nil { + // Encrypted parameters have a few requirements: + // 1. Copy original typeinfo to a block after the data + // 2. Set the parameter type to varbinary(max) + // 3. Append the crypto metadata bytes + params[i+offset].tiOriginal = params[i+offset].ti + params[i+offset].Flags |= fEncrypted + encryptedBytes, metadata, err := val.encrypt(params[i+offset].buffer) + if err != nil { + return nil, nil, err + } + params[i+offset].cipherInfo = metadata + params[i+offset].ti.TypeId = typeBigVarBin + params[i+offset].buffer = encryptedBytes + params[i+offset].ti.Size = 0 + } + + decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(tiDecl), output) } return params, decls, nil } +// Encrypts the input bytes. Returns the encrypted bytes followed by the encryption metadata to append to the packet. +type valueEncryptor func(bytes []byte) ([]byte, []byte, error) + type namedValue struct { Name string Ordinal int Value driver.Value + encrypt valueEncryptor } func convertOldArgs(args []driver.Value) []namedValue { @@ -677,6 +708,10 @@ func convertOldArgs(args []driver.Value) []namedValue { return list } +func (s *Stmt) doEncryption() bool { + return !s.skipEncryption && s.c.sess.alwaysEncrypted +} + func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { defer s.c.clearOuts() @@ -687,6 +722,12 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver if !s.c.connectionGood { return nil, driver.ErrBadConn } + if s.doEncryption() && len(args) > 0 { + args, err = s.encryptArgs(ctx, args) + } + if err != nil { + return nil, err + } if err = s.sendQuery(ctx, args); err != nil { return nil, s.c.checkBadConn(ctx, err, true) } @@ -754,6 +795,12 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, if !s.c.connectionGood { return nil, driver.ErrBadConn } + if s.doEncryption() && len(args) > 0 { + args, err = s.encryptArgs(ctx, args) + } + if err != nil { + return nil, err + } if err = s.sendQuery(ctx, args); err != nil { return nil, s.c.checkBadConn(ctx, err, true) } @@ -872,7 +919,7 @@ func (rc *Rows) NextResultSet() error { // the value type that can be used to scan types into. For example, the database // column type "bigint" this should return "reflect.TypeOf(int64(0))". func (r *Rows) ColumnTypeScanType(index int) reflect.Type { - return makeGoLangScanType(r.cols[index].ti) + return makeGoLangScanType(r.cols[index].originalTypeInfo()) } // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the @@ -881,7 +928,7 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type { // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", // "TIMESTAMP". func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - return makeGoLangTypeName(r.cols[index].ti) + return makeGoLangTypeName(r.cols[index].originalTypeInfo()) } // RowsColumnTypeLength may be implemented by Rows. It should return the length @@ -897,7 +944,7 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { // int (0, false) // bytea(30) (30, true) func (r *Rows) ColumnTypeLength(index int) (int64, bool) { - return makeGoLangTypeLength(r.cols[index].ti) + return makeGoLangTypeLength(r.cols[index].originalTypeInfo()) } // It should return @@ -908,7 +955,7 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) { // int (0, 0, false) // decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { - return makeGoLangTypePrecisionScale(r.cols[index].ti) + return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo()) } // The nullable value should @@ -974,12 +1021,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.TypeId = typeIntN res.ti.Size = 8 res.buffer = []byte{} - + case byte: + res.ti.TypeId = typeIntN + res.buffer = []byte{val} + res.ti.Size = 1 case float64: res.ti.TypeId = typeFltN res.ti.Size = 8 res.buffer = make([]byte, 8) binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val)) + case float32: + res.ti.TypeId = typeFltN + res.ti.Size = 4 + res.buffer = make([]byte, 4) + binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(val)) case sql.NullFloat64: // only null values should be getting here res.ti.TypeId = typeFltN @@ -1043,7 +1098,7 @@ func (c *Conn) Ping(ctx context.Context) error { if !c.connectionGood { return driver.ErrBadConn } - stmt := &Stmt{c, `select 1;`, 0, nil} + stmt := &Stmt{c, `select 1;`, 0, nil, true} _, err := stmt.ExecContext(ctx, nil) return err } @@ -1108,7 +1163,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } list := make([]namedValue, len(args)) for i, nv := range args { - list[i] = namedValue(nv) + list[i] = namedValueFromDriverNamedValue(nv) } return s.queryContext(ctx, list) } @@ -1121,11 +1176,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } list := make([]namedValue, len(args)) for i, nv := range args { - list[i] = namedValue(nv) + list[i] = namedValueFromDriverNamedValue(nv) } return s.exec(ctx, list) } +func namedValueFromDriverNamedValue(v driver.NamedValue) namedValue { + return namedValue{Name: v.Name, Ordinal: v.Ordinal, Value: v.Value, encrypt: nil} +} + // Rowsq implements the sqlexp messages model for Query and QueryContext // Theory: We could also implement the non-experimental model this way type Rowsq struct { @@ -1316,7 +1375,7 @@ scan: // the value type that can be used to scan types into. For example, the database // column type "bigint" this should return "reflect.TypeOf(int64(0))". func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { - return makeGoLangScanType(r.cols[index].ti) + return makeGoLangScanType(r.cols[index].originalTypeInfo()) } // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the @@ -1325,7 +1384,7 @@ func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", // "TIMESTAMP". func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { - return makeGoLangTypeName(r.cols[index].ti) + return makeGoLangTypeName(r.cols[index].originalTypeInfo()) } // RowsColumnTypeLength may be implemented by Rows. It should return the length @@ -1341,7 +1400,7 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { // int (0, false) // bytea(30) (30, true) func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { - return makeGoLangTypeLength(r.cols[index].ti) + return makeGoLangTypeLength(r.cols[index].originalTypeInfo()) } // It should return @@ -1352,7 +1411,7 @@ func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { // int (0, 0, false) // decimal (math.MaxInt64, math.MaxInt64, true) func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) { - return makeGoLangTypePrecisionScale(r.cols[index].ti) + return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo()) } // The nullable value should diff --git a/mssql_go19.go b/mssql_go19.go index 688b5c5d..b0285eef 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -29,12 +29,18 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th var _ driver.NamedValueChecker = &Conn{} -// VarChar parameter types. +// VarChar is used to encode a string parameter as VarChar instead of a sized NVarChar type VarChar string +// NVarCharMax is used to encode a string parameter as NVarChar(max) instead of a sized NVarChar type NVarCharMax string + +// VarCharMax is used to encode a string parameter as VarChar(max) instead of a sized NVarChar type VarCharMax string +// NChar is used to encode a string parameter as NChar instead of a sized NVarChar +type NChar string + // DateTime1 encodes parameters to original DateTime SQL types. type DateTime1 time.Time @@ -45,12 +51,16 @@ func convertInputParameter(val interface{}) (interface{}, error) { switch v := val.(type) { case int, int16, int32, int64, int8: return val, nil + case byte: + return val, nil case VarChar: return val, nil case NVarCharMax: return val, nil case VarCharMax: return val, nil + case NChar: + return val, nil case DateTime1: return val, nil case DateTimeOffset: @@ -61,8 +71,10 @@ func convertInputParameter(val interface{}) (interface{}, error) { return val, nil case civil.Time: return val, nil - // case *apd.Decimal: - // return nil + // case *apd.Decimal: + // return nil + case float32: + return val, nil default: return driver.DefaultParameterConverter.ConvertValue(v) } @@ -144,6 +156,10 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { res.ti.TypeId = typeNVarChar res.buffer = str2ucs2(string(val)) res.ti.Size = 0 // currently zero forces nvarchar(max) + case NChar: + res.ti.TypeId = typeNChar + res.buffer = str2ucs2(string(val)) + res.ti.Size = len(res.buffer) case DateTime1: t := time.Time(val) res.ti.TypeId = typeDateTimeN diff --git a/quoter.go b/quoter.go new file mode 100644 index 00000000..1f8f4f38 --- /dev/null +++ b/quoter.go @@ -0,0 +1,40 @@ +package mssql + +import ( + "strings" +) + +// TSQLQuoter implements sqlexp.Quoter +type TSQLQuoter struct { +} + +// ID quotes identifiers such as schema, table, or column names. +// This implementation handles multi-part names. +func (TSQLQuoter) ID(name string) string { + return "[" + strings.Replace(name, "]", "]]", -1) + "]" +} + +// Value quotes database values such as string or []byte types as strings +// that are suitable and safe to embed in SQL text. The returned value +// of a string will include all surrounding quotes. +// +// If a value type is not supported it must panic. +func (TSQLQuoter) Value(v interface{}) string { + switch v := v.(type) { + default: + panic("unsupported value") + + case string: + return sqlString(v) + case VarChar: + return sqlString(string(v)) + case VarCharMax: + return sqlString(string(v)) + case NVarCharMax: + return sqlString(string(v)) + } +} + +func sqlString(v string) string { + return "'" + strings.Replace(string(v), "'", "''", -1) + "'" +} diff --git a/rpc.go b/rpc.go index f7d4c00e..8f1ef2b4 100644 --- a/rpc.go +++ b/rpc.go @@ -13,13 +13,16 @@ type procId struct { const ( fByRevValue = 1 fDefaultValue = 2 + fEncrypted = 8 ) type param struct { - Name string - Flags uint8 - ti typeInfo - buffer []byte + Name string + Flags uint8 + ti typeInfo + buffer []byte + tiOriginal typeInfo + cipherInfo []byte } var ( @@ -78,6 +81,15 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, if err != nil { return } + if (param.Flags & fEncrypted) == fEncrypted { + err = writeTypeInfo(buf, ¶m.tiOriginal) + if err != nil { + return + } + if _, err = buf.Write(param.cipherInfo); err != nil { + return + } + } } return buf.FinishPacket() } diff --git a/tds.go b/tds.go index d10f9c6b..891630c6 100644 --- a/tds.go +++ b/tds.go @@ -15,6 +15,7 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/microsoft/go-mssqldb/aecmk" "github.com/microsoft/go-mssqldb/integratedauth" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -157,16 +158,23 @@ const ( ) type tdsSession struct { - buf *tdsBuffer - loginAck loginAckStruct - database string - partner string - columns []columnStruct - tranid uint64 - logFlags uint64 - logger ContextLogger - routedServer string - routedPort uint16 + buf *tdsBuffer + loginAck loginAckStruct + database string + partner string + columns []columnStruct + tranid uint64 + logFlags uint64 + logger ContextLogger + routedServer string + routedPort uint16 + alwaysEncrypted bool + aeSettings *alwaysEncryptedSettings +} + +type alwaysEncryptedSettings struct { + enclaveType string + keyProviders aecmk.ColumnEncryptionKeyProviderMap } const ( @@ -178,10 +186,26 @@ const ( ) type columnStruct struct { - UserType uint32 - Flags uint16 - ColName string - ti typeInfo + UserType uint32 + Flags uint16 + ColName string + ti typeInfo + cryptoMeta *cryptoMetadata +} + +func (c columnStruct) isEncrypted() bool { + return isEncryptedFlag(c.Flags) +} + +func isEncryptedFlag(flags uint16) bool { + return colFlagEncrypted == (flags & colFlagEncrypted) +} + +func (c columnStruct) originalTypeInfo() typeInfo { + if c.isEncrypted() { + return c.cryptoMeta.typeInfo + } + return c.ti } type keySlice []uint8 @@ -1047,6 +1071,9 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont CtlIntName: "go-mssqldb", ClientProgVer: getDriverVersion(driverVersion), } + if p.ColumnEncryption { + _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) + } switch { case fe.FedAuthLibrary == FedAuthLibrarySecurityToken: if uint64(p.LogFlags)&logDebug != 0 { @@ -1061,14 +1088,14 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont return nil, err } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case fe.FedAuthLibrary == FedAuthLibraryADAL: if uint64(p.LogFlags)&logDebug != 0 { logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using ADAL") } - l.FeatureExt.Add(fe) + _ = l.FeatureExt.Add(fe) case auth != nil: if uint64(p.LogFlags)&logDebug != 0 { @@ -1136,11 +1163,15 @@ initiate_connection: outbuf := newTdsBuffer(packetSize, toconn) sess := tdsSession{ - buf: outbuf, - logger: logger, - logFlags: uint64(p.LogFlags), + buf: outbuf, + logger: logger, + logFlags: uint64(p.LogFlags), + aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()}, } + for i, p := range c.keyProviders { + sess.aeSettings.keyProviders[i] = p + } fedAuth := &featureExtFedAuth{ FedAuthLibrary: FedAuthLibraryReserved, } @@ -1288,6 +1319,18 @@ initiate_connection: case loginAckStruct: sess.loginAck = token loginAck = true + case featureExtAck: + for _, v := range token { + switch v := v.(type) { + case colAckStruct: + if v.Version <= 2 && v.Version > 0 { + sess.alwaysEncrypted = true + if len(v.EnclaveType) > 0 { + sess.aeSettings.enclaveType = string(v.EnclaveType) + } + } + } + } case doneStruct: if token.isError() { tokenErr := token.getError() @@ -1317,3 +1360,21 @@ initiate_connection: } return &sess, nil } + +type featureExtColumnEncryption struct { +} + +func (f *featureExtColumnEncryption) featureID() byte { + return featExtCOLUMNENCRYPTION +} + +func (f *featureExtColumnEncryption) toBytes() []byte { + /* + 1 = The client supports column encryption without enclave computations. + 2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations. + 3 = The client SHOULD<26> support column encryption when encrypted data require enclave computations + with the additional ability to cache column encryption keys that are to be sent to the enclave + and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run. + */ + return []byte{0x01} +} diff --git a/tds_test.go b/tds_test.go index 448491b8..daabd714 100644 --- a/tds_test.go +++ b/tds_test.go @@ -35,7 +35,7 @@ func TestConstantsDefined(t *testing.T) { // This test is just here to avoid complaints about unused code. // These constants are part of the spec but not yet used. for _, b := range []byte{ - featExtSESSIONRECOVERY, featExtCOLUMNENCRYPTION, featExtGLOBALTRANSACTIONS, + featExtSESSIONRECOVERY, featExtGLOBALTRANSACTIONS, featExtAZURESQLSUPPORT, featExtDATACLASSIFICATION, featExtUTF8SUPPORT, } { if b == 0 { @@ -122,16 +122,18 @@ func TestSendLoginWithFeatureExt(t *testing.T) { Database: "database", ClientLCID: 0x204, } - login.FeatureExt.Add(&featureExtFedAuth{ + _ = login.FeatureExt.Add(&featureExtFedAuth{ FedAuthLibrary: FedAuthLibrarySecurityToken, FedAuthToken: "fedauthtoken", }) + _ = login.FeatureExt.Add(&featureExtColumnEncryption{}) err := sendLogin(buf, &login) if err != nil { t.Error("sendLogin should succeed") } - ref := []byte{ - 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, + // featureext ordering is non-deterministic + ref1 := []byte{ + 16, 1, 0, 0xe5, 0, 0, 1, 0, 0xdd, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, @@ -144,11 +146,30 @@ func TestSendLoginWithFeatureExt(t *testing.T) { 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, + 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 4, 1, + 0, 0, 0, 1, 255} + ref2 := []byte{ + 16, 1, 0, 0xe5, 0, 0, 1, 0, 0xdd, 0, 0, 0, 4, 0, 0, 116, + 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, + 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, + 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, + 176, 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, + 18, 52, 86, 120, 144, 171, 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, + 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, 0, 100, 0, 101, 0, + 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, 109, 0, + 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, + 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, + 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, + 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, 0, 0, 4, 1, + 0, 0, 0, 1, 2, 29, 0, 0, + 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} out := memBuf.Bytes() - if !bytes.Equal(ref, out) { + if !bytes.Equal(ref1, out) && !bytes.Equal(ref2, out) { t.Log("Expected:") - t.Log(hex.Dump(ref)) + t.Log(hex.Dump(ref1)) + t.Log("Or:") + t.Log(hex.Dump(ref2)) t.Log("Returned:") t.Log(hex.Dump(out)) t.Fatal("input output don't match") @@ -202,6 +223,59 @@ func TestSendSqlBatch(t *testing.T) { } } +func TestLoginWithColumnEncryption(t *testing.T) { + checkConnStr(t) + p, err := msdsn.Parse(makeConnStr(t).String()) + if err != nil { + t.Error("parseConnectParams failed:", err.Error()) + return + } + p.ColumnEncryption = true + tl := testLogger{t: t} + defer tl.StopLogging() + conn, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) + if err != nil { + t.Error("Open connection failed:", err.Error()) + return + } + defer conn.buf.transport.Close() + + headers := []headerStruct{ + {hdrtype: dataStmHdrTransDescr, + data: transDescrHdr{0, 1}.pack()}, + } + err = sendSqlBatch72(conn.buf, "select (@@microsoftversion / 0x1000000) & 0xff AS [VersionMajor]", headers, true) + if err != nil { + t.Error("Sending sql batch failed", err.Error()) + return + } + + reader := startReading(conn, context.Background(), outputs{}) + + err = reader.iterateResponse() + if err != nil { + t.Fatal(err) + } + + if len(reader.lastRow) == 0 { + t.Fatal("expected row but no row set") + } + + switch value := reader.lastRow[0].(type) { + case int64: + if value > 12 { + if !conn.alwaysEncrypted { + t.Fatalf("SQL Version %d should have alwaysEncrypted == true", value) + } + } else if conn.alwaysEncrypted { + t.Fatalf("SQL Version %d should have alwaysEncrypted == false", value) + } + + default: + t.Fatalf("Expected int64 return but got %v", value) + } +} + // returns parsed connection parameters derived from // environment variables func testConnParams(t testing.TB) msdsn.Config { @@ -252,6 +326,9 @@ func GetConnParams() (*msdsn.Config, error) { if os.Getenv("PIPE") != "" { c.Parameters["pipe"] = os.Getenv("PIPE") } + if os.Getenv("COLUMNENCRYPTION") != "" { + c.ColumnEncryption = true + } return c, nil } // try loading connection string from file @@ -912,19 +989,22 @@ func BenchmarkPacketSize(b *testing.B) { b.Run(bm.name, func(b *testing.B) { for i := 0; i < b.N; i++ { p.PacketSize = bm.packetSize - runBatch(b, p) + runBatch(b, "", p) } }) } } -func runBatch(t testing.TB, p msdsn.Config) { +func runBatch(t testing.TB, batch string, p msdsn.Config) int32 { + if len(batch) == 0 { + batch = "select 1" + } tl := testLogger{t: t} defer tl.StopLogging() conn, err := connect(context.Background(), &Connector{params: p}, optionalLogger{loggerAdapter{&tl}}, p) if err != nil { t.Error("Open connection failed:", err.Error()) - return + return 0 } defer conn.buf.transport.Close() @@ -932,10 +1012,10 @@ func runBatch(t testing.TB, p msdsn.Config) { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - err = sendSqlBatch72(conn.buf, "select 1", headers, true) + err = sendSqlBatch72(conn.buf, batch, headers, true) if err != nil { t.Error("Sending sql batch failed", err.Error()) - return + return 0 } reader := startReading(conn, context.Background(), outputs{}) @@ -951,11 +1031,11 @@ func runBatch(t testing.TB, p msdsn.Config) { switch value := reader.lastRow[0].(type) { case int32: - if value != 1 { - t.Error("Invalid value returned, should be 1", value) - return - } + return value + default: + t.Fatalf("expected an int32 return but got %v", value) } + return 0 } func TestGetDriverVersion(t *testing.T) { diff --git a/token.go b/token.go index 76d4e025..323ddbd7 100644 --- a/token.go +++ b/token.go @@ -1,6 +1,7 @@ package mssql import ( + "bytes" "context" "encoding/binary" "fmt" @@ -10,7 +11,11 @@ import ( "strconv" "github.com/golang-sql/sqlexp" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" + "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" "github.com/microsoft/go-mssqldb/msdsn" + "golang.org/x/text/encoding/unicode" ) //go:generate go run golang.org/x/tools/cmd/stringer -type token @@ -92,10 +97,15 @@ const ( fedAuthInfoSPN = 0x02 ) +const ( + cipherAlgCustom = 0x00 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( - colFlagNullable = 1 + colFlagNullable = 1 + colFlagEncrypted = 0x0800 // TODO implement more flags ) @@ -533,7 +543,14 @@ type fedAuthAckStruct struct { Signature []byte } -func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { +type colAckStruct struct { + Version int + EnclaveType string +} + +type featureExtAck map[byte]interface{} + +func parseFeatureExtAck(r *tdsBuffer) featureExtAck { ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { @@ -555,7 +572,21 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { length -= 32 } ack[feature] = fedAuthAck - + case featExtCOLUMNENCRYPTION: + colAck := colAckStruct{Version: int(r.byte())} + length-- + if length > 0 { + // enclave type is sent as utf16 le + enclaveLength := r.byte() * 2 + length-- + enclaveBytes := make([]byte, enclaveLength) + r.ReadFull(enclaveBytes) + // if the enclave type is malformed we'll just ignore it + colAck.EnclaveType, _ = ucs22str(enclaveBytes) + length -= uint32(enclaveLength) + + } + ack[feature] = colAck } // Skip unprocessed bytes @@ -568,34 +599,265 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { } // http://msdn.microsoft.com/en-us/library/dd357363.aspx -func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { +func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) + var cekTable *cekTable + if s.alwaysEncrypted { + // column encryption key list + cekTable = readCekTable(r) + } + for i := range columns { column := &columns[i] - column.UserType = r.uint32() - column.Flags = r.uint16() + baseTi := getBaseTypeInfo(r, true) + typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta) + typeInfo.UserType = baseTi.UserType + typeInfo.Flags = baseTi.Flags + typeInfo.TypeId = baseTi.TypeId + + column.Flags = baseTi.Flags + column.UserType = baseTi.UserType + column.ti = typeInfo + + if column.isEncrypted() && s.alwaysEncrypted { + // Read Crypto Metadata + cryptoMeta := parseCryptoMetadata(r, cekTable) + cryptoMeta.typeInfo.Flags = baseTi.Flags + column.cryptoMeta = &cryptoMeta + } else { + column.cryptoMeta = nil + } - // parsing TYPE_INFO structure - column.ti = readTypeInfo(r) column.ColName = r.BVarChar() } return columns } +func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo { + userType := r.uint32() + flags := uint16(0) + if parseFlags { + flags = r.uint16() + } + tId := r.byte() + + return typeInfo{ + UserType: userType, + Flags: flags, + TypeId: tId} +} + +type cryptoMetadata struct { + entry *cekTableEntry + ordinal uint16 + algorithmId byte + algorithmName *string + encType byte + normRuleVer byte + typeInfo typeInfo +} + +func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { + ordinal := uint16(0) + if cekTable != nil { + ordinal = r.uint16() + } + + typeInfo := getBaseTypeInfo(r, false) + ti := readTypeInfo(r, typeInfo.TypeId, nil) + ti.UserType = typeInfo.UserType + ti.Flags = typeInfo.Flags + ti.TypeId = typeInfo.TypeId + + algorithmId := r.byte() + var algName *string = nil + + if algorithmId == cipherAlgCustom { + // Read the name when a custom algorithm is used + nameLen := int(r.byte()) + var algNameUtf16 = make([]byte, nameLen*2) + r.ReadFull(algNameUtf16) + algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16) + mAlgName := string(algNameBytes) + algName = &mAlgName + } + + encType := r.byte() + normRuleVer := r.byte() + + var entry *cekTableEntry = nil + + if cekTable != nil { + if int(ordinal) > len(cekTable.entries)-1 { + panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) + } + entry = &cekTable.entries[ordinal] + } + + return cryptoMetadata{ + entry: entry, + ordinal: ordinal, + algorithmId: algorithmId, + algorithmName: algName, + encType: encType, + normRuleVer: normRuleVer, + typeInfo: ti, + } +} + +func readCekTable(r *tdsBuffer) *cekTable { + tableSize := r.uint16() + var cekTable *cekTable = nil + + if tableSize != 0 { + mCekTable := newCekTable(tableSize) + for i := uint16(0); i < tableSize; i++ { + mCekTable.entries[i] = readCekTableEntry(r) + } + cekTable = &mCekTable + } + + return cekTable +} + +func readCekTableEntry(r *tdsBuffer) cekTableEntry { + databaseId := r.int32() + cekID := r.int32() + cekVersion := r.int32() + var cekMdVersion = make([]byte, 8) + _, err := r.Read(cekMdVersion) + if err != nil { + panic("unable to read cekMdVersion") + } + + cekValueCount := uint(r.byte()) + // not using ucs22str because we already know the data is utf16 + enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + utf16dec := enc.NewDecoder() + cekValues := make([]encryptionKeyInfo, cekValueCount) + + for i := uint(0); i < cekValueCount; i++ { + encryptedCekLength := r.uint16() + encryptedCek := make([]byte, encryptedCekLength) + r.ReadFull(encryptedCek) + + keyStoreLength := r.byte() + keyStoreNameUtf16 := make([]byte, keyStoreLength*2) + r.ReadFull(keyStoreNameUtf16) + keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16) + + keyPathLength := r.uint16() + keyPathUtf16 := make([]byte, keyPathLength*2) + r.ReadFull(keyPathUtf16) + keyPath, _ := utf16dec.Bytes(keyPathUtf16) + + algLength := r.byte() + algNameUtf16 := make([]byte, algLength*2) + r.ReadFull(algNameUtf16) + algName, _ := utf16dec.Bytes(algNameUtf16) + + cekValues[i] = encryptionKeyInfo{ + encryptedKey: encryptedCek, + databaseID: int(databaseId), + cekID: int(cekID), + cekVersion: int(cekVersion), + cekMdVersion: cekMdVersion, + keyPath: string(keyPath), + keyStoreName: string(keyStoreName), + algorithmName: string(algName), + } + } + + return cekTableEntry{ + databaseID: int(databaseId), + keyId: int(cekID), + keyVersion: int(cekVersion), + mdVersion: cekMdVersion, + valueCount: int(cekValueCount), + cekValues: cekValues, + } +} + // http://msdn.microsoft.com/en-us/library/dd357254.aspx -func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { for i, column := range columns { - row[i] = column.ti.Reader(&column.ti, r) + columnContent := column.ti.Reader(&column.ti, r, nil) + if columnContent == nil { + row[i] = columnContent + continue + } + + if column.isEncrypted() { + buffer := decryptColumn(column, s, columnContent) + // Decrypt + row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta) + } else { + row[i] = columnContent + } } } +type RWCBuffer struct { + buffer *bytes.Reader +} + +func (R RWCBuffer) Read(p []byte) (n int, err error) { + return R.buffer.Read(p) +} + +func (R RWCBuffer) Write(p []byte) (n int, err error) { + return 0, nil +} + +func (R RWCBuffer) Close() error { + return nil +} + +func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { + encType := encryption.From(column.cryptoMeta.encType) + cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] + if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) { + s.logger.Log(context.Background(), msdsn.LogDebug, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) + } + + cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName] + if !ok { + panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) + } + cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey) + if err != nil { + panic(err) + } + k := keys.NewAeadAes256CbcHmac256(cek) + alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) + d, err := alg.Decrypt(columnContent.([]byte)) + if err != nil { + panic(err) + } + + // Decrypt returns a minimum of 8 bytes so truncate to the actual data size + if column.cryptoMeta.typeInfo.Size > 0 && column.cryptoMeta.typeInfo.Size < len(d) { + d = d[:column.cryptoMeta.typeInfo.Size] + } + var newBuff []byte + newBuff = append(newBuff, d...) + + rwc := RWCBuffer{ + buffer: bytes.NewReader(newBuff), + } + + column.cryptoMeta.typeInfo.Buffer = d + buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} + return buffer +} + // http://msdn.microsoft.com/en-us/library/dd304783.aspx -func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { +func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) @@ -604,7 +866,15 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) { row[i] = nil continue } - row[i] = col.ti.Reader(&col.ti, r) + columnContent := col.ti.Reader(&col.ti, r, nil) + if col.isEncrypted() { + buffer := decryptColumn(col, s, columnContent) + // Decrypt + row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta) + } else { + row[i] = columnContent + } + } } @@ -637,7 +907,7 @@ func parseInfo(r *tdsBuffer) (res Error) { } // https://msdn.microsoft.com/en-us/library/dd303881.aspx -func parseReturnValue(r *tdsBuffer) (nv namedValue) { +func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) { /* ParamOrdinal ParamName @@ -648,13 +918,21 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { CryptoMetadata Value */ - r.uint16() - nv.Name = r.BVarChar() - r.byte() - r.uint32() // UserType (uint16 prior to 7.2) - r.uint16() - ti := readTypeInfo(r) - nv.Value = ti.Reader(&ti, r) + _ = r.uint16() // ParamOrdinal + nv.Name = r.BVarChar() // ParamName + _ = r.byte() // Status + + ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo + + var cryptoMetadata *cryptoMetadata = nil + if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted { + cm := parseCryptoMetadata(r, nil) // CryptoMetadata + cryptoMetadata = &cm + } + + ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata) + nv.Value = ti2.Reader(&ti2, r, cryptoMetadata) + return } @@ -664,6 +942,17 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Intercepted panic %v", err)) } + if outs.msgq != nil { + var derr error + switch e := err.(type) { + case error: + derr = e + default: + derr = fmt.Errorf("Unhandled session error %v", e) + } + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: derr}) + + } ch <- err } close(ch) @@ -760,7 +1049,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS ch <- done if done.Status&doneCount != 0 { if sess.logFlags&logRows != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(Rows affected: %d)", done.RowCount)) } if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { @@ -781,7 +1070,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS return } case tokenColMetadata: - columns = parseColMetadata72(sess.buf) + columns = parseColMetadata72(sess.buf, sess) ch <- columns colsReceived = true if outs.msgq != nil { @@ -790,11 +1079,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS case tokenRow: row := make([]interface{}, len(columns)) - parseRow(sess.buf, columns, row) + parseRow(sess.buf, sess, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) - parseNbcRow(sess.buf, columns, row) + parseNbcRow(sess.buf, sess, columns, row) ch <- row case tokenEnvChange: processEnvChg(ctx, sess) @@ -822,7 +1111,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) } case tokenReturnValue: - nv := parseReturnValue(sess.buf) + nv := parseReturnValue(sess.buf, sess) if len(nv.Name) > 0 { name := nv.Name[1:] // Remove the leading "@". if ov, has := outs.params[name]; has { diff --git a/types.go b/types.go index 3b4760e3..24cc4077 100644 --- a/types.go +++ b/types.go @@ -89,6 +89,8 @@ const ( // http://msdn.microsoft.com/en-us/library/dd358284.aspx type typeInfo struct { TypeId uint8 + UserType uint32 + Flags uint16 Size int Scale uint8 Prec uint8 @@ -96,7 +98,7 @@ type typeInfo struct { Collation cp.Collation UdtInfo udtInfo XmlInfo xmlInfo - Reader func(ti *typeInfo, r *tdsBuffer) (res interface{}) + Reader func(ti *typeInfo, r *tdsBuffer, cryptoMeta *cryptoMetadata) (res interface{}) Writer func(w io.Writer, ti typeInfo, buf []byte) (err error) } @@ -119,9 +121,9 @@ type xmlInfo struct { XmlSchemaCollection string } -func readTypeInfo(r *tdsBuffer) (res typeInfo) { - res.TypeId = r.byte() - switch res.TypeId { +func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) { + res.TypeId = typeId + switch typeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8: // those are fixed length types @@ -140,7 +142,7 @@ func readTypeInfo(r *tdsBuffer) (res typeInfo) { res.Reader = readFixedType res.Buffer = make([]byte, res.Size) default: // all others are VARLENTYPE - readVarLen(&res, r) + readVarLen(&res, r, c) } return } @@ -315,7 +317,7 @@ func decodeDateTime(buf []byte) time.Time { 0, 0, secs, ns, time.UTC) } -func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { +func readFixedType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { r.ReadFull(ti.Buffer) buf := ti.Buffer switch ti.TypeId { @@ -349,8 +351,13 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { panic("shoulnd't get here") } -func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { - size := r.byte() +func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + var size byte + if c != nil { + size = byte(r.rsize) + } else { + size = r.byte() + } if size == 0 { return nil } @@ -433,7 +440,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { default: badStreamPanicf("Invalid typeid") } - panic("shoulnd't get here") + panic("shouldn't get here") } func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { @@ -448,8 +455,13 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { - size := r.uint16() +func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + var size uint16 + if c != nil { + size = uint16(r.rsize) + } else { + size = r.uint16() + } if size == 0xffff { return nil } @@ -491,7 +503,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { +func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { // information about this format can be found here: // http://msdn.microsoft.com/en-us/library/dd304783.aspx // and here: @@ -566,7 +578,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) { // reads variant value // http://msdn.microsoft.com/en-us/library/dd303302.aspx -func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { +func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { size := r.int32() if size == 0 { return nil @@ -658,41 +670,47 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx -func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { - size := r.uint64() - var buf *bytes.Buffer - switch size { - case _PLP_NULL: - // null - return nil - case _UNKNOWN_PLP_LEN: - // size unknown - buf = bytes.NewBuffer(make([]byte, 0, 1000)) - default: - buf = bytes.NewBuffer(make([]byte, 0, size)) - } - for { - chunksize := r.uint32() - if chunksize == 0 { - break +func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} { + var bytesToDecode []byte + if c == nil { + size := r.uint64() + var buf *bytes.Buffer + switch size { + case _PLP_NULL: + // null + return nil + case _UNKNOWN_PLP_LEN: + // size unknown + buf = bytes.NewBuffer(make([]byte, 0, 1000)) + default: + buf = bytes.NewBuffer(make([]byte, 0, size)) } - if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { - badStreamPanicf("Reading PLP type failed: %s", err.Error()) + for { + chunksize := r.uint32() + if chunksize == 0 { + break + } + if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil { + badStreamPanicf("Reading PLP type failed: %s", err.Error()) + } } + bytesToDecode = buf.Bytes() + } else { + bytesToDecode = r.rbuf } switch ti.TypeId { case typeXml: - return decodeXml(*ti, buf.Bytes()) + return decodeXml(*ti, bytesToDecode) case typeBigVarChar, typeBigChar, typeText: - return decodeChar(ti.Collation, buf.Bytes()) + return decodeChar(ti.Collation, bytesToDecode) case typeBigVarBin, typeBigBinary, typeImage: - return buf.Bytes() + return bytesToDecode case typeNVarChar, typeNChar, typeNText: - return decodeNChar(buf.Bytes()) + return decodeNChar(bytesToDecode) case typeUdt: - return decodeUdt(*ti, buf.Bytes()) + return decodeUdt(*ti, bytesToDecode) } - panic("shoulnd't get here") + panic("shouldn't get here") } func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { @@ -719,7 +737,7 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { } } -func readVarLen(ti *typeInfo, r *tdsBuffer) { +func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) { switch ti.TypeId { case typeDateN: ti.Size = 3