From 29fad8a680b47c5335a9ca436e5f1893c0e563de Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 15 Dec 2023 10:27:41 +0100 Subject: [PATCH] Fix concurrent auth token map modification (#558) Co-authored-by: Stephen Cathcart --- neo4j/internal/bolt/bolt5_test.go | 166 ++++++++++++++++++++++++++++ neo4j/internal/bolt/bolt_logging.go | 29 ++--- neo4j/internal/testutil/asserts.go | 26 +++++ 3 files changed, 207 insertions(+), 14 deletions(-) diff --git a/neo4j/internal/bolt/bolt5_test.go b/neo4j/internal/bolt/bolt5_test.go index bc1c9e3f..9d3a5bcf 100644 --- a/neo4j/internal/bolt/bolt5_test.go +++ b/neo4j/internal/bolt/bolt5_test.go @@ -25,6 +25,7 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/notifications" "io" "reflect" + "strings" "sync" "testing" "time" @@ -33,6 +34,22 @@ import ( . "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" ) +type recordingBoltLogger struct { + clientMessages []string + serverMessages []string +} + +func (r *recordingBoltLogger) LogClientMessage(context string, msg string, args ...any) { + fmtString := fmt.Sprintf("[%s]", context) + msg + r.clientMessages = append(r.clientMessages, fmt.Sprintf(fmtString, args...)) + +} + +func (r *recordingBoltLogger) LogServerMessage(context string, msg string, args ...any) { + fmtString := fmt.Sprintf("[%s]", context) + msg + r.serverMessages = append(r.serverMessages, fmt.Sprintf(fmtString, args...)) +} + // bolt5.Connect is tested through Connect, no need to test it here func TestBolt5(outer *testing.T) { // Test streams @@ -1706,6 +1723,155 @@ func TestBolt5(outer *testing.T) { AssertIntEqual(t, int(summary2.TFirst), 20) }) + outer.Run("redacts credentials 5.0", func(t *testing.T) { + runs := 100 + ctx := context.Background() + authToken := auth.Manager.(iauth.Token) + expectedPrincipal := authToken.Tokens["principal"].(string) + expectedCredentials := authToken.Tokens["credentials"].(string) + + var wg sync.WaitGroup + wg.Add(runs) + for i := 0; i < runs; i++ { + go func() { + tcpConn, srv, cleanup := setupBolt5Pipe(t) + defer cleanup() + go func() { + srv.waitForHandshake() + srv.acceptVersion(5, 0) + hello := srv.waitForHello() + principal, exists := hello["principal"] + if !exists { + t.Error("Missing principal in hello") + } + if principal != expectedPrincipal { + t.Errorf("Expected principal %s but got %s", expectedPrincipal, principal) + } + credentials, exists := hello["credentials"] + if !exists { + t.Error("Missing credentials in hello") + } + if credentials != expectedCredentials { + t.Errorf("Expected credentials %s but got %s", expectedCredentials, credentials) + } + + srv.acceptHello() + }() + + boltLogger := recordingBoltLogger{} + + c, err := Connect( + context.Background(), + "serverName", + tcpConn, + auth, + "007", + nil, + noopErrorListener{}, + logger, + &boltLogger, + idb.NotificationConfig{}, + ) + if err != nil { + t.Error(err) + } + defer c.Close(ctx) + + bolt := c.(*bolt5) + assertBoltState(t, bolt5Ready, bolt) + + AssertAny(t, boltLogger.clientMessages, func(logMsg string) bool { + if strings.Contains(logMsg, "HELLO") { + AssertStringContain(t, logMsg, "credentials") + AssertStringNotContain(t, logMsg, expectedCredentials) + return true + } + return false + }) + + wg.Done() + }() + } + + wg.Wait() + }) + + outer.Run("redacts credentials 5.1", func(t *testing.T) { + runs := 100 + ctx := context.Background() + authToken := auth.Manager.(iauth.Token) + expectedPrincipal := authToken.Tokens["principal"].(string) + expectedCredentials := authToken.Tokens["credentials"].(string) + + var wg sync.WaitGroup + wg.Add(runs) + for i := 0; i < runs; i++ { + go func() { + tcpConn, srv, cleanup := setupBolt5Pipe(t) + defer cleanup() + go func() { + srv.waitForHandshake() + srv.acceptVersion(5, 1) + srv.waitForHelloWithoutAuthToken() + srv.acceptHello() + logon := srv.waitForLogon() + srv.acceptLogon() + principal, exists := logon["principal"] + if !exists { + t.Error("Missing principal in logon") + } + if principal != expectedPrincipal { + t.Errorf("Expected principal %s but got %s", expectedPrincipal, principal) + } + credentials, exists := logon["credentials"] + if !exists { + t.Error("Missing credentials in logon") + } + if credentials != expectedCredentials { + t.Errorf("Expected credentials %s but got %s", expectedCredentials, credentials) + } + + srv.acceptLogon() + }() + + boltLogger := recordingBoltLogger{} + + c, err := Connect( + context.Background(), + "serverName", + tcpConn, + auth, + "007", + nil, + noopErrorListener{}, + logger, + &boltLogger, + idb.NotificationConfig{}, + ) + if err != nil { + t.Error(err) + } + defer c.Close(ctx) + + bolt := c.(*bolt5) + assertBoltState(t, bolt5Ready, bolt) + + AssertAny(t, boltLogger.clientMessages, func(logMsg string) bool { + if strings.Contains(logMsg, "LOGON") { + AssertStringContain(t, logMsg, "credentials") + AssertStringNotContain(t, logMsg, expectedCredentials) + return true + } + return false + }) + + wg.Done() + }() + } + + wg.Wait() + }) + type txTimeoutTestCase struct { description string input time.Duration diff --git a/neo4j/internal/bolt/bolt_logging.go b/neo4j/internal/bolt/bolt_logging.go index aceb7873..701dd220 100644 --- a/neo4j/internal/bolt/bolt_logging.go +++ b/neo4j/internal/bolt/bolt_logging.go @@ -25,26 +25,27 @@ import ( type loggableDictionary map[string]any -func (d loggableDictionary) String() string { - if credentials, ok := d["credentials"]; ok { - d["credentials"] = "" - defer func() { - d["credentials"] = credentials - }() +func copyAndSanitizeDictionary[T any | string](in map[string]T) map[string]T { + out := make(map[string]T, len(in)) + for k, v := range in { + if k == "credentials" { + var redacted any = "" + out[k] = redacted.(T) + } else { + out[k] = v + } } - return serializeTrace(d) + return out +} + +func (d loggableDictionary) String() string { + return serializeTrace(copyAndSanitizeDictionary(d)) } type loggableStringDictionary map[string]string func (sd loggableStringDictionary) String() string { - if credentials, ok := sd["credentials"]; ok { - sd["credentials"] = "" - defer func() { - sd["credentials"] = credentials - }() - } - return serializeTrace(sd) + return serializeTrace(copyAndSanitizeDictionary(sd)) } type loggableList []any diff --git a/neo4j/internal/testutil/asserts.go b/neo4j/internal/testutil/asserts.go index 2b4e0e2d..8dc597f4 100644 --- a/neo4j/internal/testutil/asserts.go +++ b/neo4j/internal/testutil/asserts.go @@ -177,6 +177,32 @@ func AssertStringContain(t *testing.T, s, sub string) { } } +func AssertStringNotContain(t *testing.T, s, sub string) { + t.Helper() + if strings.Contains(s, sub) { + t.Errorf("Expected %s to not contain %s", s, sub) + } +} + +func AssertAny[T any](t *testing.T, slice []T, predicate func(T) bool) { + t.Helper() + for _, e := range slice { + if predicate(e) { + return + } + } + t.Errorf("Expected slice to contain element matching predicate") +} + +func AssertAll[T any](t *testing.T, slice []T, predicate func(T) bool) { + t.Helper() + for _, e := range slice { + if !predicate(e) { + t.Errorf("Expected slice to contain only elements matching predicate") + } + } +} + func AssertMapHasKey[K comparable](t *testing.T, m map[K]any, key K) { t.Helper() if _, ok := m[key]; !ok {