diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index c262950c..b77c685b 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -308,7 +308,6 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . return nil, errorutil.WrapError(err) } - // Begin transaction beginBookmarks, err := s.getBookmarks(ctx) if err != nil { _ = s.pool.Return(ctx, conn) @@ -337,10 +336,14 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . fetchSize: s.fetchSize, txHandle: txHandle, onClosed: func(tx *explicitTransaction) { - // On transaction closed (rolled back or committed) - bookmarkErr := s.retrieveBookmarks(ctx, conn, beginBookmarks) - poolErr := s.pool.Return(ctx, conn) + if tx.conn == nil { + return + } + // On run failure, transaction closed (rolled back or committed) + bookmarkErr := s.retrieveBookmarks(ctx, tx.conn, beginBookmarks) + poolErr := s.pool.Return(ctx, tx.conn) tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr, poolErr) + tx.conn = nil s.explicitTx = nil }, } diff --git a/neo4j/session_with_context_test.go b/neo4j/session_with_context_test.go index d47de14d..50e05cad 100644 --- a/neo4j/session_with_context_test.go +++ b/neo4j/session_with_context_test.go @@ -655,6 +655,7 @@ func TestSession(outer *testing.T) { sess.Close(context.Background()) wg.Wait() }) + ct.Run("Cleans up router async", func(t *testing.T) { router, _, sess := createSession() wg := sync.WaitGroup{} @@ -665,7 +666,81 @@ func TestSession(outer *testing.T) { sess.Close(context.Background()) wg.Wait() }) + + ct.Run("Does not put back connection twice to the pool", func(inner *testing.T) { + type testCase struct { + name string + completeTx func(context.Context, SessionWithContext, ExplicitTransaction) error + } + cases := []testCase{ + { + name: "session close", + completeTx: func(ctx context.Context, session SessionWithContext, _ ExplicitTransaction) error { + return session.Close(ctx) + }, + }, + { + name: "tx commit", + completeTx: func(ctx context.Context, _ SessionWithContext, transaction ExplicitTransaction) error { + return transaction.Commit(ctx) + }, + }, + { + name: "tx rollback", + completeTx: func(ctx context.Context, _ SessionWithContext, transaction ExplicitTransaction) error { + return transaction.Rollback(ctx) + }, + }, + { + name: "tx close", + completeTx: func(ctx context.Context, _ SessionWithContext, transaction ExplicitTransaction) error { + return transaction.Close(ctx) + }, + }, + } + + for _, test := range cases { + inner.Run(fmt.Sprintf("after %s", test.name), func(t *testing.T) { + _, pool, session := createSession() + conn := &ConnFake{Alive: true, RunTxErr: errors.New("invalid transaction handle")} + poolReturnsCalls := 0 + pool.BorrowConn = conn + pool.ReturnHook = func() { + poolReturnsCalls++ + } + tx, err := session.BeginTransaction(ctx) + + AssertNoError(t, err) + AssertNoError(t, test.completeTx(ctx, session, tx)) + AssertIntEqual(t, poolReturnsCalls, 1) + _, err = tx.Run(ctx, "RETURN 42", nil) + AssertErrorMessageContains(t, err, "cannot use this transaction") + AssertIntEqual(t, poolReturnsCalls, 1) // pool.Return must not be called again + }) + } + }) + + ct.Run("Does not put back connection twice to the pool after second failed run", func(t *testing.T) { + _, pool, session := createSession() + runTxErr := errors.New("oopsie") + conn := &ConnFake{Alive: true, RunTxErr: runTxErr} + poolReturnsCalls := 0 + pool.BorrowConn = conn + pool.ReturnHook = func() { + poolReturnsCalls++ + } + tx, err := session.BeginTransaction(ctx) + + AssertNoError(t, err) + _, err = tx.Run(ctx, "RETURN 42", nil) + AssertDeepEquals(t, err, runTxErr) + AssertIntEqual(t, poolReturnsCalls, 1) + _, err = tx.Run(ctx, "RETURN 42", nil) + AssertErrorMessageContains(t, err, "cannot use this transaction") + AssertIntEqual(t, poolReturnsCalls, 1) // pool.Return must not be called again + }) }) + } func assertTokenExpiredError(t *testing.T, err error) { diff --git a/neo4j/transaction_with_context.go b/neo4j/transaction_with_context.go index bdf93d0d..49371944 100644 --- a/neo4j/transaction_with_context.go +++ b/neo4j/transaction_with_context.go @@ -59,14 +59,15 @@ type explicitTransaction struct { conn db.Connection fetchSize int txHandle db.TxHandle - done bool runFailed bool err error onClosed func(*explicitTransaction) } -func (tx *explicitTransaction) Run(ctx context.Context, cypher string, - params map[string]any) (ResultWithContext, error) { +func (tx *explicitTransaction) Run(ctx context.Context, cypher string, params map[string]any) (ResultWithContext, error) { + if tx.conn == nil { + return nil, transactionAlreadyCompletedError() + } stream, err := tx.conn.RunTx(ctx, tx.txHandle, db.Command{Cypher: cypher, Params: params, FetchSize: tx.fetchSize}) if err != nil { tx.err = err @@ -80,20 +81,19 @@ func (tx *explicitTransaction) Run(ctx context.Context, cypher string, func (tx *explicitTransaction) Commit(ctx context.Context) error { if tx.runFailed { - tx.runFailed, tx.done = false, true + tx.runFailed = false return tx.err } - if tx.done { + if tx.conn == nil { return transactionAlreadyCompletedError() } tx.err = tx.conn.TxCommit(ctx, tx.txHandle) - tx.done = true tx.onClosed(tx) return errorutil.WrapError(tx.err) } func (tx *explicitTransaction) Close(ctx context.Context) error { - if tx.done { + if tx.conn == nil { // repeated calls to Close => NOOP return nil } @@ -102,10 +102,10 @@ func (tx *explicitTransaction) Close(ctx context.Context) error { func (tx *explicitTransaction) Rollback(ctx context.Context) error { if tx.runFailed { - tx.done, tx.runFailed = true, false + tx.runFailed = false return nil } - if tx.done { + if tx.conn == nil { return transactionAlreadyCompletedError() } if !tx.conn.IsAlive() || tx.conn.HasFailed() { @@ -114,7 +114,6 @@ func (tx *explicitTransaction) Rollback(ctx context.Context) error { } else { tx.err = tx.conn.TxRollback(ctx, tx.txHandle) } - tx.done = true tx.onClosed(tx) return errorutil.WrapError(tx.err) } @@ -189,5 +188,5 @@ func (tx *autocommitTransaction) discard(ctx context.Context) { } func transactionAlreadyCompletedError() *UsageError { - return &UsageError{Message: "commit or rollback already called once on this transaction"} + return &UsageError{Message: "cannot use this transaction, because it has been committed or rolled back either because of an error or explicit termination"} }