Skip to content

Commit

Permalink
Fix transaction termination handling (#536)
Browse files Browse the repository at this point in the history
Transaction termination handling fix

---------

Co-authored-by: Robsdedude <dev@rouvenbauer.de>
Co-authored-by: Rouven Bauer <rouven.bauer@neo4j.com>
  • Loading branch information
3 people authored Nov 3, 2023
1 parent 88f48ad commit 4d8aa8f
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 68 deletions.
4 changes: 4 additions & 0 deletions neo4j/driver_with_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,10 @@ func (f *fakeResult) legacy() Result {
panic("implement me")
}

func (f *fakeResult) errorHandler(error) {
panic("implement me")
}

type fakeSummary struct {
resultAvailableAfter time.Duration
}
Expand Down
23 changes: 22 additions & 1 deletion neo4j/result_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ type ResultWithContext interface {
IsOpen() bool
buffer(ctx context.Context)
legacy() Result
errorHandler(err error)
}

const consumedResultError = "result cursor is not available anymore"

const resultFailedError = "result failed due to invalid transaction"

type resultWithContext struct {
conn idb.Connection
streamHandle idb.StreamHandle
Expand All @@ -72,15 +75,24 @@ type resultWithContext struct {
peekedRecord *Record
peekedSummary *db.Summary
peeked bool
txState *transactionState
afterConsumptionHook func()
}

func newResultWithContext(connection idb.Connection, stream idb.StreamHandle, cypher string, params map[string]any, afterConsumptionHook func()) ResultWithContext {
func newResultWithContext(
connection idb.Connection,
stream idb.StreamHandle,
cypher string,
params map[string]any,
txState *transactionState,
afterConsumptionHook func(),
) ResultWithContext {
return &resultWithContext{
conn: connection,
streamHandle: stream,
cypher: cypher,
params: params,
txState: txState,
afterConsumptionHook: afterConsumptionHook,
}
}
Expand Down Expand Up @@ -234,6 +246,9 @@ func (r *resultWithContext) advance(ctx context.Context) {
r.peeked = false
} else {
r.record, r.summary, r.err = r.conn.Next(ctx, r.streamHandle)
if r.err != nil {
r.txState.onError(r.err)
}
}
}

Expand Down Expand Up @@ -262,3 +277,9 @@ func (r *resultWithContext) callAfterConsumptionHook() {
r.afterConsumptionHook()
r.afterConsumptionHook = nil
}

func (r *resultWithContext) errorHandler(error) {
if r.err == nil {
r.err = &UsageError{Message: resultFailedError}
}
}
37 changes: 19 additions & 18 deletions neo4j/result_with_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestResult(outer *testing.T) {
// Initialization
outer.Run("Initialization", func(t *testing.T) {
conn := &ConnFake{}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
rec := res.Record()
if rec != nil {
t.Errorf("Should be no record")
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestResult(outer *testing.T) {
for _, c := range iterCases {
outer.Run(fmt.Sprintf("Next %s", c.name), func(t *testing.T) {
conn := &ConnFake{Nexts: c.stream}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
for i, call := range c.rounds {
gotNext := res.Next(context.Background())
if gotNext != call.expectNext {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestResult(outer *testing.T) {
var nextSecond *Record
conn := &ConnFake{Nexts: []Next{{Record: recs[0]}}}

result := newResultWithContext(conn, streamHandle, cypher, params, nil)
result := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)

AssertTrue(t, result.PeekRecord(ctx, &peekedFirst))
AssertTrue(t, result.PeekRecord(ctx, &peekedSecond))
Expand All @@ -172,7 +172,7 @@ func TestResult(outer *testing.T) {
inner.Run("peeks single record", func(t *testing.T) {
conn := &ConnFake{Nexts: []Next{{Record: record1}}}

result := newResultWithContext(conn, streamHandle, cypher, params, nil)
result := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)

AssertTrue(t, result.Peek(ctx))
AssertDeepEquals(t, record1, result.Record())
Expand All @@ -187,7 +187,7 @@ func TestResult(outer *testing.T) {
inner.Run("peeks once and fetches subsequent records", func(t *testing.T) {
conn := &ConnFake{Nexts: []Next{{Record: record1}, {Record: record2}}}

result := newResultWithContext(conn, streamHandle, cypher, params, nil)
result := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)

AssertTrue(t, result.Peek(ctx))
AssertDeepEquals(t, record1, result.Record())
Expand All @@ -208,7 +208,7 @@ func TestResult(outer *testing.T) {
ConsumeErr: nil,
Nexts: []Next{{Record: recs[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
// Get one record to make sure that Record() is cleared
res.Next(ctx)
AssertNotNil(t, res.Record())
Expand All @@ -226,7 +226,7 @@ func TestResult(outer *testing.T) {
ConsumeErr: errs[0],
Nexts: []Next{{Record: recs[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
// Get one record to make sure that Record() is cleared
res.Next(ctx)
AssertNotNil(t, res.Record())
Expand All @@ -243,7 +243,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Record: recs[0]}, {Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
rec, err := res.Single(ctx)
AssertNotNil(t, rec)
AssertNoError(t, err)
Expand All @@ -256,7 +256,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
rec, err := res.Single(ctx)
AssertNil(t, rec)
assertUsageError(t, err)
Expand All @@ -275,7 +275,7 @@ func TestResult(outer *testing.T) {
},
ConsumeSum: sums[0],
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
rec, err := res.Single(ctx)
AssertNil(t, rec)
assertUsageError(t, err)
Expand All @@ -296,7 +296,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Err: errs[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
rec, err := res.Single(ctx)
AssertNil(t, rec)
AssertError(t, err)
Expand All @@ -310,7 +310,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Record: recs[0]}, {Record: recs[1]}, {Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
coll, err := res.Collect(ctx)
AssertNoError(t, err)
AssertLen(t, coll, 2)
Expand All @@ -325,7 +325,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Record: recs[0]}, {Record: recs[1]}, {Record: recs[2]}, {Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
res.Next(ctx)
AssertNotNil(t, res.Record())
coll, err := res.Collect(ctx)
Expand All @@ -342,7 +342,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
coll, err := res.Collect(ctx)
AssertNoError(t, err)
AssertLen(t, coll, 0)
Expand All @@ -354,7 +354,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Summary: sums[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
res.Next(ctx)
AssertNil(t, res.Record())
coll, err := res.Collect(ctx)
Expand All @@ -368,7 +368,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Err: errs[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
coll, err := res.Collect(ctx)
AssertError(t, err)
AssertLen(t, coll, 0)
Expand All @@ -380,7 +380,7 @@ func TestResult(outer *testing.T) {
conn := &ConnFake{
Nexts: []Next{{Record: recs[0]}, {Err: errs[0]}},
}
res := newResultWithContext(conn, streamHandle, cypher, params, nil)
res := newResultWithContext(conn, streamHandle, cypher, params, &transactionState{}, nil)
coll, err := res.Collect(ctx)
AssertError(t, err)
AssertLen(t, coll, 0)
Expand Down Expand Up @@ -469,7 +469,7 @@ func TestResult(outer *testing.T) {

for _, testCase := range testCases {
inner.Run(testCase.scenario, func(t *testing.T) {
result := &resultWithContext{summary: &db.Summary{}}
result := &resultWithContext{summary: &db.Summary{}, txState: &transactionState{}}

err := testCase.callback(t, result)

Expand Down Expand Up @@ -534,6 +534,7 @@ func TestResult(outer *testing.T) {
conn: &ConnFake{
Nexts: []Next{{Record: record1}, {Summary: sums[0]}},
},
txState: &transactionState{},
afterConsumptionHook: func() {
count++
}}
Expand Down
35 changes: 21 additions & 14 deletions neo4j/session_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,22 +336,29 @@ func (s *sessionWithContext) BeginTransaction(ctx context.Context, configurers .
}

// Create transaction wrapper
s.explicitTx = &explicitTransaction{
txState := &transactionState{}
tx := &explicitTransaction{
conn: conn,
fetchSize: s.fetchSize,
txHandle: txHandle,
onClosed: func(tx *explicitTransaction) {
if tx.conn == nil {
return
}
// On run failure, transaction closed (rolled back or committed)
bookmarkErr := s.retrieveBookmarks(ctx, tx.conn, beginBookmarks)
s.pool.Return(ctx, tx.conn)
tx.err = errorutil.CombineAllErrors(tx.err, bookmarkErr)
tx.conn = nil
s.explicitTx = nil
},
txState: txState,
}

onClose := func() {
if tx.conn == nil {
return
}
// On run failure, transaction closed (rolled back or committed)
bookmarkErr := s.retrieveBookmarks(ctx, tx.conn, beginBookmarks)
s.pool.Return(ctx, tx.conn)
tx.txState.err = errorutil.CombineAllErrors(tx.txState.err, bookmarkErr)
tx.conn = nil
s.explicitTx = nil
}
tx.onClosed = onClose
txState.resultErrorHandlers = append(txState.resultErrorHandlers, func(error) { onClose() })

s.explicitTx = tx

return s.explicitTx, nil
}
Expand Down Expand Up @@ -477,7 +484,7 @@ func (s *sessionWithContext) executeTransactionFunction(
return false, nil
}

tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle}
tx := managedTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle, txState: &transactionState{}}
x, err := work(&tx)
if err != nil {
// If the client returns a client specific error that means that
Expand Down Expand Up @@ -641,7 +648,7 @@ func (s *sessionWithContext) Run(ctx context.Context,

s.autocommitTx = &autocommitTransaction{
conn: conn,
res: newResultWithContext(conn, stream, cypher, params, func() {
res: newResultWithContext(conn, stream, cypher, params, &transactionState{}, func() {
if err := s.retrieveBookmarks(ctx, conn, runBookmarks); err != nil {
s.log.Warnf(log.Session, s.logId, "could not retrieve bookmarks after result consumption: %s\n"+
"the result of the initiating auto-commit transaction may not be visible to subsequent operations", err.Error())
Expand Down
Loading

0 comments on commit 4d8aa8f

Please sign in to comment.