Skip to content

Commit

Permalink
Fix concurrent auth token map modification (#558)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephen Cathcart <stephen.cathcart@neotechnology.com>
  • Loading branch information
robsdedude and StephenCathcart authored Dec 15, 2023
1 parent fdbce9f commit 29fad8a
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 14 deletions.
166 changes: 166 additions & 0 deletions neo4j/internal/bolt/bolt5_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/neo4j/neo4j-go-driver/v5/neo4j/notifications"
"io"
"reflect"
"strings"
"sync"
"testing"
"time"
Expand All @@ -33,6 +34,22 @@ import (
. "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil"
)

type recordingBoltLogger struct {
clientMessages []string
serverMessages []string
}

func (r *recordingBoltLogger) LogClientMessage(context string, msg string, args ...any) {
fmtString := fmt.Sprintf("[%s]", context) + msg
r.clientMessages = append(r.clientMessages, fmt.Sprintf(fmtString, args...))

}

func (r *recordingBoltLogger) LogServerMessage(context string, msg string, args ...any) {
fmtString := fmt.Sprintf("[%s]", context) + msg
r.serverMessages = append(r.serverMessages, fmt.Sprintf(fmtString, args...))
}

// bolt5.Connect is tested through Connect, no need to test it here
func TestBolt5(outer *testing.T) {
// Test streams
Expand Down Expand Up @@ -1706,6 +1723,155 @@ func TestBolt5(outer *testing.T) {
AssertIntEqual(t, int(summary2.TFirst), 20)
})

outer.Run("redacts credentials 5.0", func(t *testing.T) {
runs := 100
ctx := context.Background()
authToken := auth.Manager.(iauth.Token)
expectedPrincipal := authToken.Tokens["principal"].(string)
expectedCredentials := authToken.Tokens["credentials"].(string)

var wg sync.WaitGroup
wg.Add(runs)
for i := 0; i < runs; i++ {
go func() {
tcpConn, srv, cleanup := setupBolt5Pipe(t)
defer cleanup()
go func() {
srv.waitForHandshake()
srv.acceptVersion(5, 0)
hello := srv.waitForHello()
principal, exists := hello["principal"]
if !exists {
t.Error("Missing principal in hello")
}
if principal != expectedPrincipal {
t.Errorf("Expected principal %s but got %s", expectedPrincipal, principal)
}
credentials, exists := hello["credentials"]
if !exists {
t.Error("Missing credentials in hello")
}
if credentials != expectedCredentials {
t.Errorf("Expected credentials %s but got %s", expectedCredentials, credentials)
}

srv.acceptHello()
}()

boltLogger := recordingBoltLogger{}

c, err := Connect(
context.Background(),
"serverName",
tcpConn,
auth,
"007",
nil,
noopErrorListener{},
logger,
&boltLogger,
idb.NotificationConfig{},
)
if err != nil {
t.Error(err)
}
defer c.Close(ctx)

bolt := c.(*bolt5)
assertBoltState(t, bolt5Ready, bolt)

AssertAny(t, boltLogger.clientMessages, func(logMsg string) bool {
if strings.Contains(logMsg, "HELLO") {
AssertStringContain(t, logMsg, "credentials")
AssertStringNotContain(t, logMsg, expectedCredentials)
return true
}
return false
})

wg.Done()
}()
}

wg.Wait()
})

outer.Run("redacts credentials 5.1", func(t *testing.T) {
runs := 100
ctx := context.Background()
authToken := auth.Manager.(iauth.Token)
expectedPrincipal := authToken.Tokens["principal"].(string)
expectedCredentials := authToken.Tokens["credentials"].(string)

var wg sync.WaitGroup
wg.Add(runs)
for i := 0; i < runs; i++ {
go func() {
tcpConn, srv, cleanup := setupBolt5Pipe(t)
defer cleanup()
go func() {
srv.waitForHandshake()
srv.acceptVersion(5, 1)
srv.waitForHelloWithoutAuthToken()
srv.acceptHello()
logon := srv.waitForLogon()
srv.acceptLogon()
principal, exists := logon["principal"]
if !exists {
t.Error("Missing principal in logon")
}
if principal != expectedPrincipal {
t.Errorf("Expected principal %s but got %s", expectedPrincipal, principal)
}
credentials, exists := logon["credentials"]
if !exists {
t.Error("Missing credentials in logon")
}
if credentials != expectedCredentials {
t.Errorf("Expected credentials %s but got %s", expectedCredentials, credentials)
}

srv.acceptLogon()
}()

boltLogger := recordingBoltLogger{}

c, err := Connect(
context.Background(),
"serverName",
tcpConn,
auth,
"007",
nil,
noopErrorListener{},
logger,
&boltLogger,
idb.NotificationConfig{},
)
if err != nil {
t.Error(err)
}
defer c.Close(ctx)

bolt := c.(*bolt5)
assertBoltState(t, bolt5Ready, bolt)

AssertAny(t, boltLogger.clientMessages, func(logMsg string) bool {
if strings.Contains(logMsg, "LOGON") {
AssertStringContain(t, logMsg, "credentials")
AssertStringNotContain(t, logMsg, expectedCredentials)
return true
}
return false
})

wg.Done()
}()
}

wg.Wait()
})

type txTimeoutTestCase struct {
description string
input time.Duration
Expand Down
29 changes: 15 additions & 14 deletions neo4j/internal/bolt/bolt_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,27 @@ import (

type loggableDictionary map[string]any

func (d loggableDictionary) String() string {
if credentials, ok := d["credentials"]; ok {
d["credentials"] = "<redacted>"
defer func() {
d["credentials"] = credentials
}()
func copyAndSanitizeDictionary[T any | string](in map[string]T) map[string]T {
out := make(map[string]T, len(in))
for k, v := range in {
if k == "credentials" {
var redacted any = "<redacted>"
out[k] = redacted.(T)
} else {
out[k] = v
}
}
return serializeTrace(d)
return out
}

func (d loggableDictionary) String() string {
return serializeTrace(copyAndSanitizeDictionary(d))
}

type loggableStringDictionary map[string]string

func (sd loggableStringDictionary) String() string {
if credentials, ok := sd["credentials"]; ok {
sd["credentials"] = "<redacted>"
defer func() {
sd["credentials"] = credentials
}()
}
return serializeTrace(sd)
return serializeTrace(copyAndSanitizeDictionary(sd))
}

type loggableList []any
Expand Down
26 changes: 26 additions & 0 deletions neo4j/internal/testutil/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,32 @@ func AssertStringContain(t *testing.T, s, sub string) {
}
}

func AssertStringNotContain(t *testing.T, s, sub string) {
t.Helper()
if strings.Contains(s, sub) {
t.Errorf("Expected %s to not contain %s", s, sub)
}
}

func AssertAny[T any](t *testing.T, slice []T, predicate func(T) bool) {
t.Helper()
for _, e := range slice {
if predicate(e) {
return
}
}
t.Errorf("Expected slice to contain element matching predicate")
}

func AssertAll[T any](t *testing.T, slice []T, predicate func(T) bool) {
t.Helper()
for _, e := range slice {
if !predicate(e) {
t.Errorf("Expected slice to contain only elements matching predicate")
}
}
}

func AssertMapHasKey[K comparable](t *testing.T, m map[K]any, key K) {
t.Helper()
if _, ok := m[key]; !ok {
Expand Down

0 comments on commit 29fad8a

Please sign in to comment.