diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index 790104bf..4b560a19 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -265,6 +265,18 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers . return s.explicitTx, nil } +func (s *sessionWithContext) ExecuteRead(ctx context.Context, + work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) { + + return s.runRetriable(ctx, db.ReadMode, work, configurers...) +} + +func (s *sessionWithContext) ExecuteWrite(ctx context.Context, + work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) { + + return s.runRetriable(ctx, db.WriteMode, work, configurers...) +} + func (s *sessionWithContext) runRetriable( ctx context.Context, mode db.AccessMode, @@ -308,58 +320,11 @@ func (s *sessionWithContext) runRetriable( }, } for state.Continue() { - // Establish new connection - conn, err := s.getConnection(ctx, mode) - if err != nil { - state.OnFailure(conn, err, false) + if tryAgain, result := s.executeTransactionFunction(ctx, mode, config, &state, work); tryAgain { continue + } else { + return result, nil } - - // Begin transaction - txHandle, err := conn.TxBegin(ctx, - db.TxConfig{ - Mode: mode, - Bookmarks: s.bookmarks, - Timeout: config.Timeout, - Meta: config.Metadata, - ImpersonatedUser: s.impersonatedUser, - }) - if err != nil { - state.OnFailure(conn, err, false) - s.pool.Return(ctx, conn) - continue - } - - // Construct a transaction like thing for client to execute stuff on - // and invoke the client work function. - tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle} - x, err := work(&tx) - // Evaluate the returned error from all the work for retryable, this means - // that client can mess up the error handling. - if err != nil { - // If the client returns a client specific error that means that - // client wants to rollback. We don't do an explicit rollback here - // but instead rely on the pool invoking reset on the connection, - // that will do an implicit rollback. - state.OnFailure(conn, err, false) - s.pool.Return(ctx, conn) - continue - } - - // Commit transaction - err = conn.TxCommit(ctx, txHandle) - if err != nil { - state.OnFailure(conn, err, true) - s.pool.Return(ctx, conn) - continue - } - - // Collect bookmark and return connection to pool - s.retrieveBookmarks(conn) - s.pool.Return(ctx, conn) - - // All well - return x, nil } // When retries has occurred wrap the error, the last error is always added but @@ -378,16 +343,54 @@ func (s *sessionWithContext) runRetriable( return nil, err } -func (s *sessionWithContext) ExecuteRead(ctx context.Context, - work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) { +func (s *sessionWithContext) executeTransactionFunction( + ctx context.Context, + mode db.AccessMode, + config TransactionConfig, + state *retry.State, + work ManagedTransactionWork) (bool, any) { - return s.runRetriable(ctx, db.ReadMode, work, configurers...) -} + conn, err := s.getConnection(ctx, mode) + if err != nil { + state.OnFailure(conn, err, false) + return true, nil + } -func (s *sessionWithContext) ExecuteWrite(ctx context.Context, - work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) { + // handle transaction function panic as well + defer s.pool.Return(ctx, conn) - return s.runRetriable(ctx, db.WriteMode, work, configurers...) + txHandle, err := conn.TxBegin(ctx, + db.TxConfig{ + Mode: mode, + Bookmarks: s.bookmarks, + Timeout: config.Timeout, + Meta: config.Metadata, + ImpersonatedUser: s.impersonatedUser, + }) + if err != nil { + state.OnFailure(conn, err, false) + return true, nil + } + + tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle} + x, err := work(&tx) + if err != nil { + // If the client returns a client specific error that means that + // client wants to rollback. We don't do an explicit rollback here + // but instead rely on the pool invoking reset on the connection, + // that will do an implicit rollback. + state.OnFailure(conn, err, false) + return true, nil + } + + err = conn.TxCommit(ctx, txHandle) + if err != nil { + state.OnFailure(conn, err, true) + return true, nil + } + + s.retrieveBookmarks(conn) + return false, x } func (s *sessionWithContext) getServers(ctx context.Context, mode db.AccessMode) ([]string, error) { diff --git a/neo4j/session_with_context_test.go b/neo4j/session_with_context_test.go index 61b45761..57e0a7e2 100644 --- a/neo4j/session_with_context_test.go +++ b/neo4j/session_with_context_test.go @@ -35,6 +35,9 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" ) +type transactionFunc func(context.Context, ManagedTransactionWork, ...func(*TransactionConfig)) (any, error) +type transactionFuncApi func(session SessionWithContext) transactionFunc + func TestSession(st *testing.T) { logger := log.Console{Errors: true, Infos: true, Warns: true, Debugs: true} boltLogger := log.ConsoleBoltLogger{} @@ -71,7 +74,7 @@ func TestSession(st *testing.T) { tokenExpiredErr := &db.Neo4jError{Code: "Neo.ClientError.Security.TokenExpired", Msg: "oopsie whoopsie"} - st.Run("Retry mechanism", func(rt *testing.T) { + st.Run("Transaction Functions", func(rt *testing.T) { // Checks that retries occur on database error and that it stops retrying after a certain // amount of time and that connections are returned to pool upon failure. rt.Run("Consistent transient error", func(t *testing.T) { @@ -163,6 +166,33 @@ func TestSession(st *testing.T) { AssertStringEqual(t, mydb, conn.DatabaseName) AssertIntEqual(t, numDefaultDbLookups, 1) }) + + transactionFunctions := map[string]transactionFuncApi{ + "read tx func": func(s SessionWithContext) transactionFunc { return s.ExecuteRead }, + "write tx func": func(s SessionWithContext) transactionFunc { return s.ExecuteWrite }, + } + + for name, txFuncApi := range transactionFunctions { + rt.Run(fmt.Sprintf("Implicitly rolls back when a %s panics without retry", name), func(t *testing.T) { + _, pool, sess := createSessionFromConfig(SessionConfig{}) + pool.BorrowConn = &ConnFake{Alive: true} + poolReturnCalled := 0 + pool.ReturnHook = func() { + poolReturnCalled++ + } + panicBubblesUp := false + func() { + defer func() { + panicBubblesUp = recover() != nil + }() + _, _ = txFuncApi(sess)(context.Background(), func(tx ManagedTransaction) (interface{}, error) { + panic("oopsie") + }) + }() + AssertIntEqual(t, poolReturnCalled, 1) + AssertTrue(t, panicBubblesUp) + }) + } }) st.Run("Bookmarking", func(bt *testing.T) {