Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add tracing data to prelogin and login7 packets #228

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ environment:
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 117
GOVERSION: 118
COLUMNENCRYPTION:
APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
RACE: -race -cpu 4
Expand Down
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)

func (b *Bulk) dlogf(ctx context.Context, format string, v ...interface{}) {
if b.Debug {
b.cn.sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf(format, v...))
b.cn.sess.LogF(ctx, msdsn.LogDebug, format, v...)
}
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module github.com/microsoft/go-mssqldb

go 1.17
go 1.18

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9
github.com/golang-sql/sqlexp v0.1.0
github.com/google/uuid v1.6.0
github.com/jcmturner/gokrb5/v8 v8.4.4
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.24.0
Expand All @@ -21,7 +22,6 @@ require (
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
Expand Down
90 changes: 0 additions & 90 deletions go.sum

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strings"
"time"
"unicode"

"github.com/google/uuid"
)

type (
Expand Down Expand Up @@ -44,6 +46,8 @@ const (
LogTransaction Log = 32
LogDebug Log = 64
LogRetries Log = 128
// LogSessionIDs tells the session logger to include activity id and connection id
LogSessionIDs Log = 0x8000
)

const (
Expand Down Expand Up @@ -79,6 +83,7 @@ const (
DialTimeout = "dial timeout"
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
NoTraceID = "notraceid"
)

type Config struct {
Expand Down Expand Up @@ -131,6 +136,11 @@ type Config struct {
ColumnEncryption bool
// Attempt to connect to all IPs in parallel when MultiSubnetFailover is true
MultiSubnetFailover bool
// guid to set as Activity Id in the prelogin packet. Defaults to a new value for each Config.
ActivityID []byte
// When true, no connection id or trace id value is sent in the prelogin packet.
// Some cloud servers may block connections that lack such values.
NoTraceID bool
}

func readDERFile(filename string) ([]byte, error) {
Expand Down Expand Up @@ -285,6 +295,10 @@ func Parse(dsn string) (Config, error) {
Protocols: []string{},
}

activityid, uerr := uuid.NewRandom()
if uerr == nil {
p.ActivityID = activityid[:]
}
var params map[string]string
var err error

Expand Down Expand Up @@ -504,6 +518,13 @@ func Parse(dsn string) (Config, error) {
// Defaulting to true to prevent breaking change although other client libraries default to false
p.MultiSubnetFailover = true
}
nti, ok := params[NoTraceID]
if ok {
notraceid, err := strconv.ParseBool(nti)
if err == nil {
p.NoTraceID = notraceid
}
}
return p, nil
}

Expand Down
6 changes: 4 additions & 2 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ func TestValidConnectionString(t *testing.T) {
{"disableretry=1", func(p Config) bool { return p.DisableRetry }},
{"disableretry=0", func(p Config) bool { return !p.DisableRetry }},
{"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }},
{"MultiSubnetFailover=true", func(p Config) bool { return p.MultiSubnetFailover }},
{"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }},
{"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }},

// those are supported currently, but maybe should not be
{"someparam", func(p Config) bool { return true }},
{";;=;", func(p Config) bool { return true }},
Expand Down Expand Up @@ -226,6 +225,9 @@ func TestConnParseRoundTripFixed(t *testing.T) {
if err != nil {
t.Fatal("Params after roundtrip are not valid", err)
}
t.Log("params.URL " + params.URL().String())
params.ActivityID = nil
rtParams.ActivityID = nil
if !reflect.DeepEqual(params, rtParams) {
t.Fatal("Parameters do not match after roundtrip", params, rtParams)
}
Expand Down
44 changes: 12 additions & 32 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,7 @@ func (c *Conn) checkBadConn(ctx context.Context, err error, mayRetry bool) error
}

if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
if c.sess.logFlags&logRetries != 0 {
c.sess.logger.Log(ctx, msdsn.LogRetries, err.Error())
}
c.sess.Log(ctx, msdsn.LogRetries, err.Error)
return newRetryableError(err)
}

Expand Down Expand Up @@ -324,9 +322,7 @@ func (c *Conn) sendCommitRequest() error {
reset := c.resetSession
c.resetSession = false
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send CommitXact with %v", err))
}
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send CommitXact with %v", err)
c.connectionGood = false
return fmt.Errorf("faild to send CommitXact: %v", err)
}
Expand All @@ -351,9 +347,7 @@ func (c *Conn) sendRollbackRequest() error {
reset := c.resetSession
c.resetSession = false
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send RollbackXact with %v", err))
}
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send RollbackXact with %v", err)
c.connectionGood = false
return fmt.Errorf("failed to send RollbackXact: %v", err)
}
Expand Down Expand Up @@ -388,9 +382,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
reset := c.resetSession
c.resetSession = false
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send BeginXact with %v", err))
}
c.sess.LogF(ctx, msdsn.LogErrors, "Failed to send BeginXact with %v", err)
c.connectionGood = false
return fmt.Errorf("failed to send BeginXact: %v", err)
}
Expand Down Expand Up @@ -524,15 +516,13 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
conn := s.c

// no need to check number of parameters here, it is checked by database/sql
if conn.sess.logFlags&logSQL != 0 {
conn.sess.logger.Log(ctx, msdsn.LogSQL, s.query)
}
conn.sess.LogS(ctx, msdsn.LogSQL, s.query)
if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
for i := 0; i < len(args); i++ {
if len(args[i].Name) > 0 {
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@%s\t%v", args[i].Name, args[i].Value))
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@%s\t%v", args[i].Name, args[i].Value)
} else {
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@p%d\t%v", i+1, args[i].Value))
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@p%d\t%v", i+1, args[i].Value)
}
}
}
Expand All @@ -542,9 +532,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
isProc := isProc(s.query)
if len(args) == 0 && !isProc {
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send SqlBatch with %v", err))
}
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send SqlBatch with %v", err)
conn.connectionGood = false
return fmt.Errorf("failed to send SQL Batch: %v", err)
}
Expand All @@ -567,9 +555,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send Rpc with %v", err))
}
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err)
conn.connectionGood = false
return fmt.Errorf("failed to send RPC: %v", err)
}
Expand Down Expand Up @@ -1298,9 +1284,7 @@ func (rc *Rowsq) Columns() (res []string) {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Columns() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Columns() token type:%v", reflect.TypeOf(tok))
if tok == nil {
return []string{}
} else {
Expand All @@ -1327,9 +1311,7 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
}
for {
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Next() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Next() token type:%v", reflect.TypeOf(tok))
if err == nil {
if tok == nil {
return io.EOF
Expand Down Expand Up @@ -1391,9 +1373,7 @@ func (rc *Rowsq) NextResultSet() error {
scan:
for {
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "NextResultSet() token type:%v", reflect.TypeOf(tok))

if err != nil {
return err
Expand Down
100 changes: 100 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package mssql

import (
"context"
"fmt"

"github.com/google/uuid"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/msdsn"
)

func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSession {
sess := &tdsSession{
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
}
_ = sess.activityid.Scan(p.ActivityID)
// generating a guid has a small chance of failure. Make a best effort
connid, cerr := uuid.NewRandom()
if cerr == nil {
_ = sess.connid.Scan(connid[:])
}

return sess
}

func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
instance_buf := []byte(p.Instance)
instance_buf = append(instance_buf, 0) // zero terminate instance name

var encrypt byte
switch p.Encryption {
default:
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
case msdsn.EncryptionDisabled:
encrypt = encryptNotSup
case msdsn.EncryptionRequired:
encrypt = encryptOn
case msdsn.EncryptionOff:
encrypt = encryptOff
case msdsn.EncryptionStrict:
encrypt = encryptStrict
}
v := getDriverVersion(driverVersion)
fields := map[uint8][]byte{
// 4 bytes for version and 2 bytes for minor version
preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0},
preloginENCRYPTION: {encrypt},
preloginINSTOPT: instance_buf,
preloginTHREADID: {0, 0, 0, 0},
preloginMARS: {0}, // MARS disabled
}

if !p.NoTraceID {
traceID := make([]byte, 36) // 16 byte connection id + 16 byte activity id + 4 byte sequence number
connid, _ := s.connid.Value()
activityid, _ := s.activityid.Value()
_ = copy(traceID[:16], connid.([]byte))
_ = copy(traceID[16:32], activityid.([]byte))
fields[preloginTRACEID] = traceID
if (s.logFlags)&logDebug != 0 {
msg := fmt.Sprintf("Creating prelogin packet with connection id '%s' and activity id '%s'", s.connid, s.activityid)
s.logger.Log(ctx, msdsn.LogDebug, msg)
}
}
if fe.FedAuthLibrary != FedAuthLibraryReserved {
fields[preloginFEDAUTHREQUIRED] = []byte{1}
}

return fields
}

type logFunc func() string

func (s *tdsSession) logPrefix() string {
if s.logFlags&uint64(msdsn.LogSessionIDs) != 0 {
return fmt.Sprintf("aid:%v cid:%v - ", s.activityid, s.connid)
}
return ""
}

func (s *tdsSession) LogS(ctx context.Context, category msdsn.Log, msg string) {
s.Log(ctx, category, func() string { return msg })
}

// Log checks that the session logFlags includes the category before evaluating the logFunc and emitting the trace
func (s *tdsSession) Log(ctx context.Context, category msdsn.Log, logFunc logFunc) {
if s.logFlags&uint64(category) != 0 {
s.logger.Log(ctx, category, s.logPrefix()+logFunc())
}
}

// LogF checks that the session logFlags includes the category before calling fmt.Sprintf and emitting the trace
func (s *tdsSession) LogF(ctx context.Context, category msdsn.Log, format string, a ...any) {
if s.logFlags&uint64(category) != 0 {
s.logger.Log(ctx, category, s.logPrefix()+fmt.Sprintf(format, a...))
}
}
Loading
Loading