Skip to content

Commit

Permalink
Feat: Add tracing data to prelogin and login7 packets (#228)
Browse files Browse the repository at this point in the history
* add traceid field to prelogin

* add clientid and pid to login7

* add logging of conn id to tdsSession

* fix go mod

* fix test

* update min Go to 118

Fixes #226
  • Loading branch information
shueybubbles authored Dec 5, 2024
1 parent 2521238 commit dad23d2
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 236 deletions.
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ environment:
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 117
GOVERSION: 118
COLUMNENCRYPTION:
APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
RACE: -race -cpu 4
Expand Down
2 changes: 1 addition & 1 deletion bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)

func (b *Bulk) dlogf(ctx context.Context, format string, v ...interface{}) {
if b.Debug {
b.cn.sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf(format, v...))
b.cn.sess.LogF(ctx, msdsn.LogDebug, format, v...)
}
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
module github.com/microsoft/go-mssqldb

go 1.17
go 1.18

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9
github.com/golang-sql/sqlexp v0.1.0
github.com/google/uuid v1.6.0
github.com/jcmturner/gokrb5/v8 v8.4.4
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.24.0
Expand All @@ -21,7 +22,6 @@ require (
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v5 v5.2.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
Expand Down
90 changes: 0 additions & 90 deletions go.sum

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strings"
"time"
"unicode"

"github.com/google/uuid"
)

type (
Expand Down Expand Up @@ -44,6 +46,8 @@ const (
LogTransaction Log = 32
LogDebug Log = 64
LogRetries Log = 128
// LogSessionIDs tells the session logger to include activity id and connection id
LogSessionIDs Log = 0x8000
)

const (
Expand Down Expand Up @@ -79,6 +83,7 @@ const (
DialTimeout = "dial timeout"
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
NoTraceID = "notraceid"
)

type Config struct {
Expand Down Expand Up @@ -131,6 +136,11 @@ type Config struct {
ColumnEncryption bool
// Attempt to connect to all IPs in parallel when MultiSubnetFailover is true
MultiSubnetFailover bool
// guid to set as Activity Id in the prelogin packet. Defaults to a new value for each Config.
ActivityID []byte
// When true, no connection id or trace id value is sent in the prelogin packet.
// Some cloud servers may block connections that lack such values.
NoTraceID bool
}

func readDERFile(filename string) ([]byte, error) {
Expand Down Expand Up @@ -285,6 +295,10 @@ func Parse(dsn string) (Config, error) {
Protocols: []string{},
}

activityid, uerr := uuid.NewRandom()
if uerr == nil {
p.ActivityID = activityid[:]
}
var params map[string]string
var err error

Expand Down Expand Up @@ -504,6 +518,13 @@ func Parse(dsn string) (Config, error) {
// Defaulting to true to prevent breaking change although other client libraries default to false
p.MultiSubnetFailover = true
}
nti, ok := params[NoTraceID]
if ok {
notraceid, err := strconv.ParseBool(nti)
if err == nil {
p.NoTraceID = notraceid
}
}
return p, nil
}

Expand Down
6 changes: 4 additions & 2 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ func TestValidConnectionString(t *testing.T) {
{"disableretry=1", func(p Config) bool { return p.DisableRetry }},
{"disableretry=0", func(p Config) bool { return !p.DisableRetry }},
{"", func(p Config) bool { return p.DisableRetry == disableRetryDefault }},
{"MultiSubnetFailover=true", func(p Config) bool { return p.MultiSubnetFailover }},
{"MultiSubnetFailover=true;NoTraceID=true", func(p Config) bool { return p.MultiSubnetFailover && p.NoTraceID }},
{"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }},

// those are supported currently, but maybe should not be
{"someparam", func(p Config) bool { return true }},
{";;=;", func(p Config) bool { return true }},
Expand Down Expand Up @@ -226,6 +225,9 @@ func TestConnParseRoundTripFixed(t *testing.T) {
if err != nil {
t.Fatal("Params after roundtrip are not valid", err)
}
t.Log("params.URL " + params.URL().String())
params.ActivityID = nil
rtParams.ActivityID = nil
if !reflect.DeepEqual(params, rtParams) {
t.Fatal("Parameters do not match after roundtrip", params, rtParams)
}
Expand Down
44 changes: 12 additions & 32 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,7 @@ func (c *Conn) checkBadConn(ctx context.Context, err error, mayRetry bool) error
}

if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
if c.sess.logFlags&logRetries != 0 {
c.sess.logger.Log(ctx, msdsn.LogRetries, err.Error())
}
c.sess.Log(ctx, msdsn.LogRetries, err.Error)
return newRetryableError(err)
}

Expand Down Expand Up @@ -324,9 +322,7 @@ func (c *Conn) sendCommitRequest() error {
reset := c.resetSession
c.resetSession = false
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send CommitXact with %v", err))
}
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send CommitXact with %v", err)
c.connectionGood = false
return fmt.Errorf("faild to send CommitXact: %v", err)
}
Expand All @@ -351,9 +347,7 @@ func (c *Conn) sendRollbackRequest() error {
reset := c.resetSession
c.resetSession = false
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(c.transactionCtx, msdsn.LogErrors, fmt.Sprintf("Failed to send RollbackXact with %v", err))
}
c.sess.LogF(c.transactionCtx, msdsn.LogErrors, "Failed to send RollbackXact with %v", err)
c.connectionGood = false
return fmt.Errorf("failed to send RollbackXact: %v", err)
}
Expand Down Expand Up @@ -388,9 +382,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
reset := c.resetSession
c.resetSession = false
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send BeginXact with %v", err))
}
c.sess.LogF(ctx, msdsn.LogErrors, "Failed to send BeginXact with %v", err)
c.connectionGood = false
return fmt.Errorf("failed to send BeginXact: %v", err)
}
Expand Down Expand Up @@ -524,15 +516,13 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
conn := s.c

// no need to check number of parameters here, it is checked by database/sql
if conn.sess.logFlags&logSQL != 0 {
conn.sess.logger.Log(ctx, msdsn.LogSQL, s.query)
}
conn.sess.LogS(ctx, msdsn.LogSQL, s.query)
if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
for i := 0; i < len(args); i++ {
if len(args[i].Name) > 0 {
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@%s\t%v", args[i].Name, args[i].Value))
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@%s\t%v", args[i].Name, args[i].Value)
} else {
s.c.sess.logger.Log(ctx, msdsn.LogParams, fmt.Sprintf("\t@p%d\t%v", i+1, args[i].Value))
s.c.sess.LogF(ctx, msdsn.LogParams, "\t@p%d\t%v", i+1, args[i].Value)
}
}
}
Expand All @@ -542,9 +532,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
isProc := isProc(s.query)
if len(args) == 0 && !isProc {
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send SqlBatch with %v", err))
}
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send SqlBatch with %v", err)
conn.connectionGood = false
return fmt.Errorf("failed to send SQL Batch: %v", err)
}
Expand All @@ -567,9 +555,7 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
params[1] = makeStrParam(strings.Join(decls, ","))
}
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Failed to send Rpc with %v", err))
}
conn.sess.LogF(ctx, msdsn.LogErrors, "Failed to send Rpc with %v", err)
conn.connectionGood = false
return fmt.Errorf("failed to send RPC: %v", err)
}
Expand Down Expand Up @@ -1298,9 +1284,7 @@ func (rc *Rowsq) Columns() (res []string) {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Columns() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Columns() token type:%v", reflect.TypeOf(tok))
if tok == nil {
return []string{}
} else {
Expand All @@ -1327,9 +1311,7 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
}
for {
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Next() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "Next() token type:%v", reflect.TypeOf(tok))
if err == nil {
if tok == nil {
return io.EOF
Expand Down Expand Up @@ -1391,9 +1373,7 @@ func (rc *Rowsq) NextResultSet() error {
scan:
for {
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
}
rc.reader.sess.LogF(rc.reader.ctx, msdsn.LogDebug, "NextResultSet() token type:%v", reflect.TypeOf(tok))

if err != nil {
return err
Expand Down
100 changes: 100 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package mssql

import (
"context"
"fmt"

"github.com/google/uuid"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/msdsn"
)

func newSession(outbuf *tdsBuffer, logger ContextLogger, p msdsn.Config) *tdsSession {
sess := &tdsSession{
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
}
_ = sess.activityid.Scan(p.ActivityID)
// generating a guid has a small chance of failure. Make a best effort
connid, cerr := uuid.NewRandom()
if cerr == nil {
_ = sess.connid.Scan(connid[:])
}

return sess
}

func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
instance_buf := []byte(p.Instance)
instance_buf = append(instance_buf, 0) // zero terminate instance name

var encrypt byte
switch p.Encryption {
default:
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
case msdsn.EncryptionDisabled:
encrypt = encryptNotSup
case msdsn.EncryptionRequired:
encrypt = encryptOn
case msdsn.EncryptionOff:
encrypt = encryptOff
case msdsn.EncryptionStrict:
encrypt = encryptStrict
}
v := getDriverVersion(driverVersion)
fields := map[uint8][]byte{
// 4 bytes for version and 2 bytes for minor version
preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0},
preloginENCRYPTION: {encrypt},
preloginINSTOPT: instance_buf,
preloginTHREADID: {0, 0, 0, 0},
preloginMARS: {0}, // MARS disabled
}

if !p.NoTraceID {
traceID := make([]byte, 36) // 16 byte connection id + 16 byte activity id + 4 byte sequence number
connid, _ := s.connid.Value()
activityid, _ := s.activityid.Value()
_ = copy(traceID[:16], connid.([]byte))
_ = copy(traceID[16:32], activityid.([]byte))
fields[preloginTRACEID] = traceID
if (s.logFlags)&logDebug != 0 {
msg := fmt.Sprintf("Creating prelogin packet with connection id '%s' and activity id '%s'", s.connid, s.activityid)
s.logger.Log(ctx, msdsn.LogDebug, msg)
}
}
if fe.FedAuthLibrary != FedAuthLibraryReserved {
fields[preloginFEDAUTHREQUIRED] = []byte{1}
}

return fields
}

type logFunc func() string

func (s *tdsSession) logPrefix() string {
if s.logFlags&uint64(msdsn.LogSessionIDs) != 0 {
return fmt.Sprintf("aid:%v cid:%v - ", s.activityid, s.connid)
}
return ""
}

func (s *tdsSession) LogS(ctx context.Context, category msdsn.Log, msg string) {
s.Log(ctx, category, func() string { return msg })
}

// Log checks that the session logFlags includes the category before evaluating the logFunc and emitting the trace
func (s *tdsSession) Log(ctx context.Context, category msdsn.Log, logFunc logFunc) {
if s.logFlags&uint64(category) != 0 {
s.logger.Log(ctx, category, s.logPrefix()+logFunc())
}
}

// LogF checks that the session logFlags includes the category before calling fmt.Sprintf and emitting the trace
func (s *tdsSession) LogF(ctx context.Context, category msdsn.Log, format string, a ...any) {
if s.logFlags&uint64(category) != 0 {
s.logger.Log(ctx, category, s.logPrefix()+fmt.Sprintf(format, a...))
}
}
Loading

0 comments on commit dad23d2

Please sign in to comment.