Skip to content

Commit

Permalink
Add TDS8 support
Browse files Browse the repository at this point in the history
This commit fixes #136.  Adds support for TDS8. TDS8 connection
can now be used by specifying encrypt=strict. TrustServerCertificate=true
will not come into effect when encrypt is set to 'strict'.
  • Loading branch information
apoorvdeshmukh authored Aug 31, 2023
2 parents 77786ad + 8ee6a31 commit 2c72afb
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 59 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* Added `ActiveDirectoryAzCli` and `ActiveDirectoryDeviceCode` authentication types to `azuread` package
* Always Encrypted encryption and decryption with 2 hour key cache (#116)
* 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers
* TDS8 can now be used for connections by setting encrypt="strict"

## 1.5.0

Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ Other supported formats are listed below.
* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts.
* `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout.
* `encrypt`
* `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16).
* `disable` - Data send between client and server is not encrypted.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux))

Expand Down Expand Up @@ -56,7 +57,7 @@ Other supported formats are listed below.
* `TrustServerCertificate`
* false - Server certificate is checked. Default is false if encrypt is specified.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
Expand Down Expand Up @@ -470,6 +471,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host
* Always Encrypted
- `MSSQL_CERTIFICATE_STORE` provider on Windows
- `pfx` provider on Linux and Windows

## Tests

`go test` is used for testing. A running instance of MSSQL server is required.
Expand Down
30 changes: 27 additions & 3 deletions azuread/azuread_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ package azuread
import (
"bufio"
"database/sql"
"encoding/hex"
"io"
"os"
"testing"

mssql "github.com/microsoft/go-mssqldb"
"github.com/stretchr/testify/assert"
)

func TestAzureSqlAuth(t *testing.T) {
mssqlConfig := testConnParams(t)
mssqlConfig := testConnParams(t, "")

conn, err := newConnectorConfig(mssqlConfig)
if err != nil {
Expand All @@ -35,9 +37,31 @@ func TestAzureSqlAuth(t *testing.T) {

}

func TestTDS8ConnWithAzureSqlAuth(t *testing.T) {
mssqlConfig := testConnParams(t, ";encrypt=strict;TrustServerCertificate=false;tlsmin=1.2")
conn, err := newConnectorConfig(mssqlConfig)
if err != nil {
t.Fatalf("Unable to get a connector: %v", err)
}
db := sql.OpenDB(conn)
row := db.QueryRow("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID")
if err != nil {
t.Fatal("Prepare failed:", err.Error())
}
var protocolName string
var tdsver []byte
var clientAddress string
err = row.Scan(&protocolName, &tdsver, &clientAddress)
if err != nil {
t.Fatal("Scan failed:", err.Error())
}
assert.Equal(t, "TSQL", protocolName, "Protocol name does not match")
assert.Equal(t, "08000000", hex.EncodeToString(tdsver))
}

// returns parsed connection parameters derived from
// environment variables
func testConnParams(t testing.TB) *azureFedAuthConfig {
func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig {
dsn := os.Getenv("AZURESERVER_DSN")
const logFlags = 127
if dsn == "" {
Expand All @@ -54,7 +78,7 @@ func testConnParams(t testing.TB) *azureFedAuthConfig {
if dsn == "" {
t.Skip("no azure database connection string. set AZURESERVER_DSN environment variable or create .azureconnstr file")
}
config, err := parse(dsn)
config, err := parse(dsn + dsnParams)
if err != nil {
t.Skip("error parsing connection string ")
}
Expand Down
6 changes: 5 additions & 1 deletion internal/akvkeys/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ func CreateRSAKey(client *azkeys.Client) (name string, err error) {
Kty: &kt,
KeySize: &ks,
}
i, _ := rand.Int(rand.Reader, big.NewInt(1000))

i, _ := rand.Int(rand.Reader, big.NewInt(1000000))
name = fmt.Sprintf("go-mssqlkey%d", i)
_, err = client.CreateKey(context.TODO(), name, rsaKeyParams, nil)
if err != nil {
_, err = client.RecoverDeletedKey(context.TODO(), name, &azkeys.RecoverDeletedKeyOptions{})
}
return
}

Expand Down
73 changes: 54 additions & 19 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@ type (
BrowserMsg byte
)

const (
DsnTypeURL = 1
DsnTypeOdbc = 2
DsnTypeAdo = 3
)

const (
EncryptionOff = 0
EncryptionRequired = 1
EncryptionDisabled = 3
EncryptionStrict = 4
)

const (
Expand Down Expand Up @@ -162,17 +169,19 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
var encryption Encryption = EncryptionOff
encrypt, ok := params[Encrypt]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
encrypt = strings.ToLower(encrypt)
switch encrypt {
case "mandatory", "yes", "1", "t", "true":
encryption = EncryptionRequired
case "disable":
encryption = EncryptionDisabled
} else {
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "invalid encrypt '%s': %s"
return encryption, nil, fmt.Errorf(f, encrypt, err.Error())
}
if e {
encryption = EncryptionRequired
}
case "strict":
encryption = EncryptionStrict
case "optional", "no", "0", "f", "false":
encryption = EncryptionOff
default:
f := "invalid encrypt '%s'"
return encryption, nil, fmt.Errorf(f, encrypt)
}
} else {
trustServerCert = true
Expand All @@ -189,6 +198,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
certificate := params[Certificate]
if encryption != EncryptionDisabled {
tlsMin := params[TLSMin]
if encrypt == "strict" {
trustServerCert = false
}
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
Expand All @@ -200,28 +212,51 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e

var skipSetup = errors.New("skip setting up TLS")

func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
Protocols: []string{},
func getDsnType(dsn string) int {
if strings.HasPrefix(dsn, "sqlserver://") {
return DsnTypeURL
}
if strings.HasPrefix(dsn, "odbc:") {
return DsnTypeOdbc
}
return DsnTypeAdo
}

func getDsnParams(dsn string) (map[string]string, error) {

var params map[string]string
var err error
if strings.HasPrefix(dsn, "odbc:") {

switch getDsnType(dsn) {
case DsnTypeOdbc:
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
return params, err
}
} else if strings.HasPrefix(dsn, "sqlserver://") {
case DsnTypeURL:
params, err = splitConnectionStringURL(dsn)
if err != nil {
return p, err
return params, err
}
} else {
default:
params = splitConnectionString(dsn)
}
return params, nil
}

func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
Protocols: []string{},
}

var params map[string]string
var err error

params, err = getDsnParams(dsn)
if err != nil {
return p, err
}
p.Parameters = params

strlog, ok := params[LogParam]
Expand Down
6 changes: 6 additions & 0 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func TestValidConnectionString(t *testing.T) {
{"encrypt=disable", func(p Config) bool { return p.Encryption == EncryptionDisabled }},
{"encrypt=disable;tlsmin=1.1", func(p Config) bool { return p.Encryption == EncryptionDisabled && p.TLSConfig == nil }},
{"encrypt=true", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }},
{"encrypt=mandatory", func(p Config) bool { return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0 }},
{"encrypt=true;tlsmin=1.0", func(p Config) bool {
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS10
}},
Expand All @@ -74,10 +75,15 @@ func TestValidConnectionString(t *testing.T) {
{"encrypt=true;tlsmin=1.2", func(p Config) bool {
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS12
}},
{"encrypt=true;tlsmin=1.3", func(p Config) bool {
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS13
}},
{"encrypt=true;tlsmin=1.4", func(p Config) bool {
return p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == 0
}},
{"encrypt=false", func(p Config) bool { return p.Encryption == EncryptionOff }},
{"encrypt=optional", func(p Config) bool { return p.Encryption == EncryptionOff }},
{"encrypt=strict", func(p Config) bool { return p.Encryption == EncryptionStrict }},
{"connection timeout=3;dial timeout=4;keepalive=5", func(p Config) bool {
return p.ConnTimeout == 3*time.Second && p.DialTimeout == 4*time.Second && p.KeepAlive == 5*time.Second
}},
Expand Down
Loading

0 comments on commit 2c72afb

Please sign in to comment.