Skip to content

Commit

Permalink
Handle panic in transaction functions
Browse files Browse the repository at this point in the history
Handle the connection return to the pool in a deferred call
so that panicking transaction function calls are covered as
well.
  • Loading branch information
fbiville authored Apr 12, 2022
1 parent b602bd6 commit 7b2399d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 58 deletions.
117 changes: 60 additions & 57 deletions neo4j/session_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,18 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers .
return s.explicitTx, nil
}

func (s *sessionWithContext) ExecuteRead(ctx context.Context,
work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) {

return s.runRetriable(ctx, db.ReadMode, work, configurers...)
}

func (s *sessionWithContext) ExecuteWrite(ctx context.Context,
work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) {

return s.runRetriable(ctx, db.WriteMode, work, configurers...)
}

func (s *sessionWithContext) runRetriable(
ctx context.Context,
mode db.AccessMode,
Expand Down Expand Up @@ -308,58 +320,11 @@ func (s *sessionWithContext) runRetriable(
},
}
for state.Continue() {
// Establish new connection
conn, err := s.getConnection(ctx, mode)
if err != nil {
state.OnFailure(conn, err, false)
if tryAgain, result := s.executeTransactionFunction(ctx, mode, config, &state, work); tryAgain {
continue
} else {
return result, nil
}

// Begin transaction
txHandle, err := conn.TxBegin(ctx,
db.TxConfig{
Mode: mode,
Bookmarks: s.bookmarks,
Timeout: config.Timeout,
Meta: config.Metadata,
ImpersonatedUser: s.impersonatedUser,
})
if err != nil {
state.OnFailure(conn, err, false)
s.pool.Return(ctx, conn)
continue
}

// Construct a transaction like thing for client to execute stuff on
// and invoke the client work function.
tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle}
x, err := work(&tx)
// Evaluate the returned error from all the work for retryable, this means
// that client can mess up the error handling.
if err != nil {
// If the client returns a client specific error that means that
// client wants to rollback. We don't do an explicit rollback here
// but instead rely on the pool invoking reset on the connection,
// that will do an implicit rollback.
state.OnFailure(conn, err, false)
s.pool.Return(ctx, conn)
continue
}

// Commit transaction
err = conn.TxCommit(ctx, txHandle)
if err != nil {
state.OnFailure(conn, err, true)
s.pool.Return(ctx, conn)
continue
}

// Collect bookmark and return connection to pool
s.retrieveBookmarks(conn)
s.pool.Return(ctx, conn)

// All well
return x, nil
}

// When retries has occurred wrap the error, the last error is always added but
Expand All @@ -378,16 +343,54 @@ func (s *sessionWithContext) runRetriable(
return nil, err
}

func (s *sessionWithContext) ExecuteRead(ctx context.Context,
work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) {
func (s *sessionWithContext) executeTransactionFunction(
ctx context.Context,
mode db.AccessMode,
config TransactionConfig,
state *retry.State,
work ManagedTransactionWork) (bool, any) {

return s.runRetriable(ctx, db.ReadMode, work, configurers...)
}
conn, err := s.getConnection(ctx, mode)
if err != nil {
state.OnFailure(conn, err, false)
return true, nil
}

func (s *sessionWithContext) ExecuteWrite(ctx context.Context,
work ManagedTransactionWork, configurers ...func(*TransactionConfig)) (interface{}, error) {
// handle transaction function panic as well
defer s.pool.Return(ctx, conn)

return s.runRetriable(ctx, db.WriteMode, work, configurers...)
txHandle, err := conn.TxBegin(ctx,
db.TxConfig{
Mode: mode,
Bookmarks: s.bookmarks,
Timeout: config.Timeout,
Meta: config.Metadata,
ImpersonatedUser: s.impersonatedUser,
})
if err != nil {
state.OnFailure(conn, err, false)
return true, nil
}

tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle}
x, err := work(&tx)
if err != nil {
// If the client returns a client specific error that means that
// client wants to rollback. We don't do an explicit rollback here
// but instead rely on the pool invoking reset on the connection,
// that will do an implicit rollback.
state.OnFailure(conn, err, false)
return true, nil
}

err = conn.TxCommit(ctx, txHandle)
if err != nil {
state.OnFailure(conn, err, true)
return true, nil
}

s.retrieveBookmarks(conn)
return false, x
}

func (s *sessionWithContext) getServers(ctx context.Context, mode db.AccessMode) ([]string, error) {
Expand Down
32 changes: 31 additions & 1 deletion neo4j/session_with_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import (
"github.com/neo4j/neo4j-go-driver/v5/neo4j/log"
)

type transactionFunc func(context.Context, ManagedTransactionWork, ...func(*TransactionConfig)) (any, error)
type transactionFuncApi func(session SessionWithContext) transactionFunc

func TestSession(st *testing.T) {
logger := log.Console{Errors: true, Infos: true, Warns: true, Debugs: true}
boltLogger := log.ConsoleBoltLogger{}
Expand Down Expand Up @@ -71,7 +74,7 @@ func TestSession(st *testing.T) {

tokenExpiredErr := &db.Neo4jError{Code: "Neo.ClientError.Security.TokenExpired", Msg: "oopsie whoopsie"}

st.Run("Retry mechanism", func(rt *testing.T) {
st.Run("Transaction Functions", func(rt *testing.T) {
// Checks that retries occur on database error and that it stops retrying after a certain
// amount of time and that connections are returned to pool upon failure.
rt.Run("Consistent transient error", func(t *testing.T) {
Expand Down Expand Up @@ -163,6 +166,33 @@ func TestSession(st *testing.T) {
AssertStringEqual(t, mydb, conn.DatabaseName)
AssertIntEqual(t, numDefaultDbLookups, 1)
})

transactionFunctions := map[string]transactionFuncApi{
"read tx func": func(s SessionWithContext) transactionFunc { return s.ExecuteRead },
"write tx func": func(s SessionWithContext) transactionFunc { return s.ExecuteWrite },
}

for name, txFuncApi := range transactionFunctions {
rt.Run(fmt.Sprintf("Implicitly rolls back when a %s panics without retry", name), func(t *testing.T) {
_, pool, sess := createSessionFromConfig(SessionConfig{})
pool.BorrowConn = &ConnFake{Alive: true}
poolReturnCalled := 0
pool.ReturnHook = func() {
poolReturnCalled++
}
panicBubblesUp := false
func() {
defer func() {
panicBubblesUp = recover() != nil
}()
_, _ = txFuncApi(sess)(context.Background(), func(tx ManagedTransaction) (interface{}, error) {
panic("oopsie")
})
}()
AssertIntEqual(t, poolReturnCalled, 1)
AssertTrue(t, panicBubblesUp)
})
}
})

st.Run("Bookmarking", func(bt *testing.T) {
Expand Down

0 comments on commit 7b2399d

Please sign in to comment.