From 5a2db5e72aba83e2d15474ab8a8484049d6b82ae Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov <11927660+injectives@users.noreply.github.com> Date: Tue, 9 Nov 2021 11:08:45 +0000 Subject: [PATCH 1/2] Make UnmanagedTransaction return ongoing tx completion stage (#1057) This update ensures that `UnmanagedTransaction` returns existing on-going tx completion stage when a similar request is made. For instance, if it was requested to be rolled back and then requested to be closed, both invocations should get the same on-going stage. In addition, it should not accept conflicting actions, like committing and rolling back at the same time. In addition, it makes sure that cancellation on reactive transaction function results in rollback. --- .../internal/async/UnmanagedTransaction.java | 300 ++++++++++-------- .../internal/reactive/InternalRxSession.java | 2 +- .../neo4j/driver/internal/util/Futures.java | 15 + .../async/UnmanagedTransactionTest.java | 165 ++++++++++ 4 files changed, 341 insertions(+), 141 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index a46b4d628a..03fdb0e7ca 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -20,9 +20,13 @@ import java.util.Arrays; import java.util.EnumSet; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.BiFunction; +import java.util.function.Function; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; @@ -37,91 +41,57 @@ import org.neo4j.driver.internal.cursor.RxResultCursor; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.internal.util.Futures; +import static org.neo4j.driver.internal.util.Futures.asCompletionException; +import static org.neo4j.driver.internal.util.Futures.combineErrors; import static org.neo4j.driver.internal.util.Futures.completedWithNull; import static org.neo4j.driver.internal.util.Futures.failedFuture; +import static org.neo4j.driver.internal.util.Futures.futureCompletingConsumer; +import static org.neo4j.driver.internal.util.LockUtil.executeWithLock; public class UnmanagedTransaction { private enum State { - /** The transaction is running with no explicit success or failure marked */ + /** + * The transaction is running with no explicit success or failure marked + */ ACTIVE, /** - * This transaction has been terminated either because of explicit {@link Session#reset()} or because of a - * fatal connection error. + * This transaction has been terminated either because of explicit {@link Session#reset()} or because of a fatal connection error. */ TERMINATED, - /** This transaction has successfully committed */ - COMMITTED, - - /** This transaction has been rolled back */ - ROLLED_BACK - } - - /** - * This is a holder so that we can have ony the state volatile in the tx without having to synchronize the whole block. - */ - private static final class StateHolder - { - private static final EnumSet OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED ); - private static final StateHolder ACTIVE_HOLDER = new StateHolder( State.ACTIVE, null ); - private static final StateHolder COMMITTED_HOLDER = new StateHolder( State.COMMITTED, null ); - private static final StateHolder ROLLED_BACK_HOLDER = new StateHolder( State.ROLLED_BACK, null ); - /** - * The actual state. + * This transaction has successfully committed */ - final State value; + COMMITTED, /** - * If this holder contains a state of {@link State#TERMINATED}, this represents the cause if any. + * This transaction has been rolled back */ - final Throwable causeOfTermination; - - static StateHolder of( State value ) - { - switch ( value ) - { - case ACTIVE: - return ACTIVE_HOLDER; - case COMMITTED: - return COMMITTED_HOLDER; - case ROLLED_BACK: - return ROLLED_BACK_HOLDER; - case TERMINATED: - default: - throw new IllegalArgumentException( "Cannot provide a default state holder for state " + value ); - } - } - - static StateHolder terminatedWith( Throwable cause ) - { - return new StateHolder( State.TERMINATED, cause ); - } - - private StateHolder( State value, Throwable causeOfTermination ) - { - this.value = value; - this.causeOfTermination = causeOfTermination; - } - - boolean isOpen() - { - return OPEN_STATES.contains( this.value ); - } + ROLLED_BACK } + protected static final String CANT_COMMIT_COMMITTED_MSG = "Can't commit, transaction has been committed"; + protected static final String CANT_ROLLBACK_COMMITTED_MSG = "Can't rollback, transaction has been committed"; + protected static final String CANT_COMMIT_ROLLED_BACK_MSG = "Can't commit, transaction has been rolled back"; + protected static final String CANT_ROLLBACK_ROLLED_BACK_MSG = "Can't rollback, transaction has been rolled back"; + protected static final String CANT_COMMIT_ROLLING_BACK_MSG = "Can't commit, transaction has been requested to be rolled back"; + protected static final String CANT_ROLLBACK_COMMITTING_MSG = "Can't rollback, transaction has been requested to be committed"; + private static final EnumSet OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED ); + private final Connection connection; private final BoltProtocol protocol; private final BookmarkHolder bookmarkHolder; private final ResultCursorsHolder resultCursors; private final long fetchSize; - - private volatile StateHolder state = StateHolder.of( State.ACTIVE ); + private final Lock lock = new ReentrantLock(); + private State state = State.ACTIVE; + private CompletableFuture commitFuture; + private CompletableFuture rollbackFuture; + private Throwable causeOfTermination; public UnmanagedTransaction( Connection connection, BookmarkHolder bookmarkHolder, long fetchSize ) { @@ -156,7 +126,7 @@ else if ( beginError instanceof ConnectionReadTimeoutException ) { connection.release(); } - throw Futures.asCompletionException( beginError ); + throw asCompletionException( beginError ); } return this; } ); @@ -164,50 +134,17 @@ else if ( beginError instanceof ConnectionReadTimeoutException ) public CompletionStage closeAsync() { - if ( isOpen() ) - { - return rollbackAsync(); - } - else - { - return completedWithNull(); - } + return closeAsync( false, true ); } public CompletionStage commitAsync() { - if ( state.value == State.COMMITTED ) - { - return failedFuture( new ClientException( "Can't commit, transaction has been committed" ) ); - } - else if ( state.value == State.ROLLED_BACK ) - { - return failedFuture( new ClientException( "Can't commit, transaction has been rolled back" ) ); - } - else - { - return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doCommitAsync( error ).handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> handleTransactionCompletion( true, error ) ); - } + return closeAsync( true, false ); } public CompletionStage rollbackAsync() { - if ( state.value == State.COMMITTED ) - { - return failedFuture( new ClientException( "Can't rollback, transaction has been committed" ) ); - } - else if ( state.value == State.ROLLED_BACK ) - { - return failedFuture( new ClientException( "Can't rollback, transaction has been rolled back" ) ); - } - else - { - return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> handleTransactionCompletion( false, error ) ); - } + return closeAsync( false, false ); } public CompletionStage runAsync( Query query ) @@ -219,7 +156,7 @@ public CompletionStage runAsync( Query query ) return cursorStage.thenCompose( AsyncResultCursor::mapSuccessfulRunCompletionAsync ).thenApply( cursor -> cursor ); } - public CompletionStage runRx(Query query) + public CompletionStage runRx( Query query ) { ensureCanRunQueries(); CompletionStage cursorStage = @@ -230,22 +167,26 @@ public CompletionStage runRx(Query query) public boolean isOpen() { - return state.isOpen(); + return OPEN_STATES.contains( executeWithLock( lock, () -> state ) ); } public void markTerminated( Throwable cause ) { - if ( state.value == State.TERMINATED ) + executeWithLock( lock, () -> { - if ( state.causeOfTermination != null ) + if ( state == State.TERMINATED ) { - addSuppressedWhenNotCaptured( state.causeOfTermination, cause ); + if ( causeOfTermination != null ) + { + addSuppressedWhenNotCaptured( causeOfTermination, cause ); + } } - } - else - { - state = StateHolder.terminatedWith( cause ); - } + else + { + state = State.TERMINATED; + causeOfTermination = cause; + } + } ); } private void addSuppressedWhenNotCaptured( Throwable currentCause, Throwable newCause ) @@ -267,46 +208,46 @@ public Connection connection() private void ensureCanRunQueries() { - if ( state.value == State.COMMITTED ) - { - throw new ClientException( "Cannot run more queries in this transaction, it has been committed" ); - } - else if ( state.value == State.ROLLED_BACK ) + executeWithLock( lock, () -> { - throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" ); - } - else if ( state.value == State.TERMINATED ) - { - throw new ClientException( "Cannot run more queries in this transaction, " + - "it has either experienced an fatal error or was explicitly terminated", state.causeOfTermination ); - } + if ( state == State.COMMITTED ) + { + throw new ClientException( "Cannot run more queries in this transaction, it has been committed" ); + } + else if ( state == State.ROLLED_BACK ) + { + throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" ); + } + else if ( state == State.TERMINATED ) + { + throw new ClientException( "Cannot run more queries in this transaction, " + + "it has either experienced an fatal error or was explicitly terminated", causeOfTermination ); + } + } ); } private CompletionStage doCommitAsync( Throwable cursorFailure ) { - if ( state.value == State.TERMINATED ) - { - return failedFuture( new ClientException( "Transaction can't be committed. " + - "It has been rolled back either because of an error or explicit termination", - cursorFailure != state.causeOfTermination ? state.causeOfTermination : null ) ); - } - return protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark ); + ClientException exception = executeWithLock( + lock, () -> state == State.TERMINATED + ? new ClientException( "Transaction can't be committed. " + + "It has been rolled back either because of an error or explicit termination", + cursorFailure != causeOfTermination ? causeOfTermination : null ) + : null + ); + return exception != null ? failedFuture( exception ) : protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark ); } private CompletionStage doRollbackAsync() { - if ( state.value == State.TERMINATED ) - { - return completedWithNull(); - } - return protocol.rollbackTransaction( connection ); + return executeWithLock( lock, () -> state ) == State.TERMINATED ? completedWithNull() : protocol.rollbackTransaction( connection ); } private static BiFunction handleCommitOrRollback( Throwable cursorFailure ) { return ( ignore, commitOrRollbackError ) -> { - CompletionException combinedError = Futures.combineErrors( cursorFailure, commitOrRollbackError ); + CompletionException combinedError = combineErrors( cursorFailure, commitOrRollbackError ); if ( combinedError != null ) { throw combinedError; @@ -315,17 +256,19 @@ private static BiFunction handleCommitOrRollback( Throwable }; } - private void handleTransactionCompletion( boolean commitOnSuccess, Throwable throwable ) + private void handleTransactionCompletion( boolean commitAttempt, Throwable throwable ) { - if ( commitOnSuccess && throwable == null ) - { - state = StateHolder.of( State.COMMITTED ); - } - else + executeWithLock( lock, () -> { - state = StateHolder.of( State.ROLLED_BACK ); - } - + if ( commitAttempt && throwable == null ) + { + state = State.COMMITTED; + } + else + { + state = State.ROLLED_BACK; + } + } ); if ( throwable instanceof AuthorizationExpiredException ) { connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION ); @@ -339,4 +282,81 @@ else if ( throwable instanceof ConnectionReadTimeoutException ) connection.release(); // release in background } } + + private CompletionStage closeAsync( boolean commit, boolean completeWithNullIfNotOpen ) + { + CompletionStage stage = executeWithLock( lock, () -> + { + CompletionStage resultStage = null; + if ( completeWithNullIfNotOpen && !isOpen() ) + { + resultStage = completedWithNull(); + } + else if ( state == State.COMMITTED ) + { + resultStage = failedFuture( new ClientException( commit ? CANT_COMMIT_COMMITTED_MSG : CANT_ROLLBACK_COMMITTED_MSG ) ); + } + else if ( state == State.ROLLED_BACK ) + { + resultStage = failedFuture( new ClientException( commit ? CANT_COMMIT_ROLLED_BACK_MSG : CANT_ROLLBACK_ROLLED_BACK_MSG ) ); + } + else + { + if ( commit ) + { + if ( rollbackFuture != null ) + { + resultStage = failedFuture( new ClientException( CANT_COMMIT_ROLLING_BACK_MSG ) ); + } + else if ( commitFuture != null ) + { + resultStage = commitFuture; + } + else + { + commitFuture = new CompletableFuture<>(); + } + } + else + { + if ( commitFuture != null ) + { + resultStage = failedFuture( new ClientException( CANT_ROLLBACK_COMMITTING_MSG ) ); + } + else if ( rollbackFuture != null ) + { + resultStage = rollbackFuture; + } + else + { + rollbackFuture = new CompletableFuture<>(); + } + } + } + return resultStage; + } ); + + if ( stage == null ) + { + CompletableFuture targetFuture; + Function> targetAction; + if ( commit ) + { + targetFuture = commitFuture; + targetAction = throwable -> doCommitAsync( throwable ).handle( handleCommitOrRollback( throwable ) ); + } + else + { + targetFuture = rollbackFuture; + targetAction = throwable -> doRollbackAsync().handle( handleCommitOrRollback( throwable ) ); + } + resultCursors.retrieveNotConsumedError() + .thenCompose( targetAction ) + .whenComplete( ( ignored, throwable ) -> handleTransactionCompletion( commit, throwable ) ) + .whenComplete( futureCompletingConsumer( targetFuture ) ); + stage = targetFuture; + } + + return stage; + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java index 110892c026..41f70d0a1d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java @@ -130,7 +130,7 @@ public Publisher writeTransaction( RxTransactionWork Publisher runTransaction( AccessMode mode, RxTransactionWork> work, TransactionConfig config ) { Flux repeatableWork = Flux.usingWhen( beginTransaction( mode, config ), work::execute, - InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), null ); + InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), InternalRxTransaction::close ); return session.retryLogic().retryRx( repeatableWork ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java index 24ed13c879..56b714df1c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java @@ -256,6 +256,21 @@ public static CompletableFuture onErrorContinue( CompletableFuture fut } ); } + public static BiConsumer futureCompletingConsumer( CompletableFuture future ) + { + return ( value, throwable ) -> + { + if ( throwable != null ) + { + future.completeExceptionally( throwable ); + } + else + { + future.complete( value ); + } + }; + } + private static class CompletionResult { T value; diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index b8909e176f..6a29c338bf 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -19,9 +19,17 @@ package org.neo4j.driver.internal.async; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.InOrder; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Stream; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; @@ -32,6 +40,7 @@ import org.neo4j.driver.internal.DefaultBookmarkHolder; import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.InternalBookmark; +import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -40,16 +49,21 @@ import static java.util.concurrent.CompletableFuture.completedFuture; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.neo4j.driver.internal.handlers.pulln.FetchSizeUtil.UNLIMITED_FETCH_SIZE; import static org.neo4j.driver.util.TestUtil.assertNoCircularReferences; @@ -311,6 +325,127 @@ void shouldReleaseConnectionOnConnectionReadTimeoutExceptionFailure() verify( connection, never() ).release(); } + private static Stream similarTransactionCompletingActionArgs() + { + return Stream.of( + Arguments.of( true, "commit", "commit" ), + + Arguments.of( false, "rollback", "rollback" ), + Arguments.of( false, "rollback", "close" ), + + Arguments.of( false, "close", "rollback" ), + Arguments.of( false, "close", "close" ) + ); + } + + @ParameterizedTest + @MethodSource( "similarTransactionCompletingActionArgs" ) + void shouldReturnExistingStageOnSimilarCompletingAction( boolean protocolCommit, String initialAction, String similarAction ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ).willReturn( new CompletableFuture<>() ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage initialStage = mapTransactionAction( initialAction, tx ).get(); + CompletionStage similarStage = mapTransactionAction( similarAction, tx ).get(); + + assertSame( initialStage, similarStage ); + if ( protocolCommit ) + { + then( protocol ).should( times( 1 ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( 1 ) ).rollbackTransaction( connection ); + } + } + + private static Stream conflictingTransactionCompletingActionArgs() + { + return Stream.of( + Arguments.of( true, true, "commit", "commit", UnmanagedTransaction.CANT_COMMIT_COMMITTED_MSG ), + Arguments.of( true, true, "commit", "rollback", UnmanagedTransaction.CANT_ROLLBACK_COMMITTED_MSG ), + Arguments.of( true, false, "commit", "rollback", UnmanagedTransaction.CANT_ROLLBACK_COMMITTING_MSG ), + Arguments.of( true, false, "commit", "close", UnmanagedTransaction.CANT_ROLLBACK_COMMITTING_MSG ), + + Arguments.of( false, true, "rollback", "rollback", UnmanagedTransaction.CANT_ROLLBACK_ROLLED_BACK_MSG ), + Arguments.of( false, true, "rollback", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLED_BACK_MSG ), + Arguments.of( false, false, "rollback", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLING_BACK_MSG ), + + Arguments.of( false, true, "close", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLED_BACK_MSG ), + Arguments.of( false, true, "close", "rollback", UnmanagedTransaction.CANT_ROLLBACK_ROLLED_BACK_MSG ), + Arguments.of( false, false, "close", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLING_BACK_MSG ) + ); + } + + @ParameterizedTest + @MethodSource( "conflictingTransactionCompletingActionArgs" ) + void shouldReturnFailingStageOnConflictingCompletingAction( boolean protocolCommit, boolean protocolActionCompleted, String initialAction, + String conflictingAction, String expectedErrorMsg ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ) + .willReturn( protocolActionCompleted ? completedFuture( null ) : new CompletableFuture<>() ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage originalActionStage = mapTransactionAction( initialAction, tx ).get(); + CompletionStage conflictingActionStage = mapTransactionAction( conflictingAction, tx ).get(); + + assertNotNull( originalActionStage ); + if ( protocolCommit ) + { + then( protocol ).should( times( 1 ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( 1 ) ).rollbackTransaction( connection ); + } + assertTrue( conflictingActionStage.toCompletableFuture().isCompletedExceptionally() ); + Throwable throwable = assertThrows( ExecutionException.class, () -> conflictingActionStage.toCompletableFuture().get() ).getCause(); + assertTrue( throwable instanceof ClientException ); + assertEquals( expectedErrorMsg, throwable.getMessage() ); + } + + private static Stream closingNotActionTransactionArgs() + { + return Stream.of( + Arguments.of( true, 1, "commit" ), + Arguments.of( false, 1, "rollback" ), + Arguments.of( false, 0, "terminate" ) + ); + } + + @ParameterizedTest + @MethodSource( "closingNotActionTransactionArgs" ) + void shouldReturnCompletedWithNullStageOnClosingNotActiveTransaction( boolean protocolCommit, int expectedProtocolInvocations, String originalAction ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ) + .willReturn( completedFuture( null ) ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage originalActionStage = mapTransactionAction( originalAction, tx ).get(); + CompletionStage closeStage = tx.closeAsync(); + + assertTrue( originalActionStage.toCompletableFuture().isDone() ); + assertFalse( originalActionStage.toCompletableFuture().isCompletedExceptionally() ); + if ( protocolCommit ) + { + then( protocol ).should( times( expectedProtocolInvocations ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( expectedProtocolInvocations ) ).rollbackTransaction( connection ); + } + assertNull( closeStage.toCompletableFuture().join() ); + } + private static UnmanagedTransaction beginTx( Connection connection ) { return beginTx( connection, InternalBookmark.empty() ); @@ -346,4 +481,34 @@ private ResultCursorsHolder mockResultCursorWith( ClientException clientExceptio resultCursorsHolder.add( completedFuture( cursor ) ); return resultCursorsHolder; } + + private Supplier> mapTransactionAction( String actionName, UnmanagedTransaction tx ) + { + Supplier> action; + if ( "commit".equals( actionName ) ) + { + action = tx::commitAsync; + } + else if ( "rollback".equals( actionName ) ) + { + action = tx::rollbackAsync; + } + else if ( "terminate".equals( actionName ) ) + { + action = () -> + { + tx.markTerminated( mock( Throwable.class ) ); + return completedFuture( null ); + }; + } + else if ( "close".equals( actionName ) ) + { + action = tx::closeAsync; + } + else + { + throw new RuntimeException( String.format( "Unknown completing action type '%s'", actionName ) ); + } + return action; + } } From f6b2e907566745db11889995cc5613dbe05c01b0 Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov <11927660+injectives@users.noreply.github.com> Date: Wed, 10 Nov 2021 14:16:15 +0000 Subject: [PATCH 2/2] Call close with the appropriate flag to commit or rollback on UnmanagedTransaction where possible to avoid double state acquisition (#1065) * Call close with the appropriate flag to commit or rollback on UnmanagedTransaction where possible to avoid double state acquisition Calling `close` instead of separate `isOpen` and `commitAsync` requires less lock acquisitions and is safer. * Update tests --- .../internal/async/InternalAsyncSession.java | 54 ++++++++---------- .../internal/async/UnmanagedTransaction.java | 7 ++- .../internal/reactive/InternalRxSession.java | 2 +- .../reactive/InternalRxTransaction.java | 9 ++- .../async/UnmanagedTransactionTest.java | 16 ++++-- .../reactive/InternalRxSessionTest.java | 56 +++++++++---------- .../reactive/InternalRxTransactionTest.java | 28 ++-------- 7 files changed, 77 insertions(+), 95 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java index 83f0f7a862..efc291933b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java @@ -146,7 +146,7 @@ private void executeWork(CompletableFuture resultFuture, UnmanagedTransac Throwable error = Futures.completionExceptionCause( completionError ); if ( error != null ) { - rollbackTxAfterFailedTransactionWork( tx, resultFuture, error ); + closeTxAfterFailedTransactionWork( tx, resultFuture, error ); } else { @@ -174,43 +174,33 @@ private CompletionStage safeExecuteWork(UnmanagedTransaction tx, AsyncTra } } - private void rollbackTxAfterFailedTransactionWork(UnmanagedTransaction tx, CompletableFuture resultFuture, Throwable error ) + private void closeTxAfterFailedTransactionWork( UnmanagedTransaction tx, CompletableFuture resultFuture, Throwable error ) { - if ( tx.isOpen() ) - { - tx.rollbackAsync().whenComplete( ( ignore, rollbackError ) -> { - if ( rollbackError != null ) + tx.closeAsync().whenComplete( + ( ignored, rollbackError ) -> { - error.addSuppressed( rollbackError ); - } - resultFuture.completeExceptionally( error ); - } ); - } - else - { - resultFuture.completeExceptionally( error ); - } + if ( rollbackError != null ) + { + error.addSuppressed( rollbackError ); + } + resultFuture.completeExceptionally( error ); + } ); } private void closeTxAfterSucceededTransactionWork(UnmanagedTransaction tx, CompletableFuture resultFuture, T result ) { - if ( tx.isOpen() ) - { - tx.commitAsync().whenComplete( ( ignore, completionError ) -> { - Throwable commitError = Futures.completionExceptionCause( completionError ); - if ( commitError != null ) + tx.closeAsync( true ).whenComplete( + ( ignored, completionError ) -> { - resultFuture.completeExceptionally( commitError ); - } - else - { - resultFuture.complete( result ); - } - } ); - } - else - { - resultFuture.complete( result ); - } + Throwable commitError = Futures.completionExceptionCause( completionError ); + if ( commitError != null ) + { + resultFuture.completeExceptionally( commitError ); + } + else + { + resultFuture.complete( result ); + } + } ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index 03fdb0e7ca..fbd7a985c7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -134,7 +134,12 @@ else if ( beginError instanceof ConnectionReadTimeoutException ) public CompletionStage closeAsync() { - return closeAsync( false, true ); + return closeAsync( false ); + } + + public CompletionStage closeAsync( boolean commit ) + { + return closeAsync( commit, true ); } public CompletionStage commitAsync() diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java index 41f70d0a1d..222b64562d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java @@ -130,7 +130,7 @@ public Publisher writeTransaction( RxTransactionWork Publisher runTransaction( AccessMode mode, RxTransactionWork> work, TransactionConfig config ) { Flux repeatableWork = Flux.usingWhen( beginTransaction( mode, config ), work::execute, - InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), InternalRxTransaction::close ); + tx -> tx.close( true ), ( tx, error ) -> tx.close(), InternalRxTransaction::close ); return session.retryLogic().retryRx( repeatableWork ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java index c1a9267336..b4212ae963 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java @@ -30,7 +30,6 @@ import org.neo4j.driver.reactive.RxTransaction; import static org.neo4j.driver.internal.reactive.RxUtils.createEmptyPublisher; -import static org.neo4j.driver.internal.util.Futures.completedWithNull; public class InternalRxTransaction extends AbstractRxQueryRunner implements RxTransaction { @@ -77,13 +76,13 @@ public Publisher rollback() return createEmptyPublisher( tx::rollbackAsync ); } - Publisher commitIfOpen() + Publisher close() { - return createEmptyPublisher( () -> tx.isOpen() ? tx.commitAsync() : completedWithNull() ); + return close( false ); } - Publisher close() + Publisher close( boolean commit ) { - return createEmptyPublisher( tx::closeAsync ); + return createEmptyPublisher( () -> tx.closeAsync( commit ) ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index 6a29c338bf..f639565a66 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -413,15 +413,21 @@ void shouldReturnFailingStageOnConflictingCompletingAction( boolean protocolComm private static Stream closingNotActionTransactionArgs() { return Stream.of( - Arguments.of( true, 1, "commit" ), - Arguments.of( false, 1, "rollback" ), - Arguments.of( false, 0, "terminate" ) + Arguments.of( true, 1, "commit", null ), + Arguments.of( false, 1, "rollback", null ), + Arguments.of( false, 0, "terminate", null ), + Arguments.of( true, 1, "commit", true ), + Arguments.of( false, 1, "rollback", true ), + Arguments.of( true, 1, "commit", false ), + Arguments.of( false, 1, "rollback", false ), + Arguments.of( false, 0, "terminate", false ) ); } @ParameterizedTest @MethodSource( "closingNotActionTransactionArgs" ) - void shouldReturnCompletedWithNullStageOnClosingNotActiveTransaction( boolean protocolCommit, int expectedProtocolInvocations, String originalAction ) + void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommittingAborted( + boolean protocolCommit, int expectedProtocolInvocations, String originalAction, Boolean commitOnClose ) { Connection connection = mock( Connection.class ); BoltProtocol protocol = mock( BoltProtocol.class ); @@ -431,7 +437,7 @@ void shouldReturnCompletedWithNullStageOnClosingNotActiveTransaction( boolean pr UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); CompletionStage originalActionStage = mapTransactionAction( originalAction, tx ).get(); - CompletionStage closeStage = tx.closeAsync(); + CompletionStage closeStage = commitOnClose != null ? tx.closeAsync( commitOnClose ) : tx.closeAsync(); assertTrue( originalActionStage.toCompletableFuture().isDone() ); assertFalse( originalActionStage.toCompletableFuture().isCompletedExceptionally() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java index 3f320231ab..2ca1ea21fe 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java @@ -199,9 +199,7 @@ void shouldDelegateRunTx( Function> runTx ) throws T // Given NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( 1 ) ); @@ -213,7 +211,7 @@ void shouldDelegateRunTx( Function> runTx ) throws T // Then verify( session ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx ).commitAsync(); + verify( tx ).closeAsync( true ); } @Test @@ -223,25 +221,24 @@ void shouldRetryOnError() throws Throwable int retryCount = 2; NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( retryCount ) ); InternalRxSession rxSession = new InternalRxSession( session ); // When - Publisher strings = rxSession.readTransaction( t -> - Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ) ); + Publisher strings = rxSession.readTransaction( + t -> + Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ) ); StepVerifier.create( Flux.from( strings ) ) - // we lost the "a"s too as the user only see the last failure - .expectError( RuntimeException.class ) - .verify(); + // we lost the "a"s too as the user only see the last failure + .expectError( RuntimeException.class ) + .verify(); // Then verify( session, times( retryCount + 1 ) ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx, times( retryCount + 1 ) ).closeAsync(); + verify( tx, times( retryCount + 1 ) ).closeAsync( false ); } @Test @@ -251,9 +248,8 @@ void shouldObtainResultIfRetrySucceed() throws Throwable int retryCount = 2; NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( retryCount ) ); @@ -261,23 +257,25 @@ void shouldObtainResultIfRetrySucceed() throws Throwable // When AtomicInteger count = new AtomicInteger(); - Publisher strings = rxSession.readTransaction( t -> { - // we fail for the first few retries, and then success on the last run. - if ( count.getAndIncrement() == retryCount ) - { - return Flux.just( "a" ); - } - else - { - return Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ); - } - } ); + Publisher strings = rxSession.readTransaction( + t -> + { + // we fail for the first few retries, and then success on the last run. + if ( count.getAndIncrement() == retryCount ) + { + return Flux.just( "a" ); + } + else + { + return Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ); + } + } ); StepVerifier.create( Flux.from( strings ) ).expectNext( "a" ).verifyComplete(); // Then verify( session, times( retryCount + 1 ) ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx, times( retryCount ) ).closeAsync(); - verify( tx ).commitAsync(); + verify( tx, times( retryCount ) ).closeAsync( false ); + verify( tx ).closeAsync( true ); } @Test diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java index 1accde96db..5a9f0bb4b6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java @@ -48,7 +48,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.Values.parameters; @@ -140,43 +139,28 @@ void shouldMarkTxIfFailedToRun( Function runReturnOne ) } @Test - void shouldCommitWhenOpen() + void shouldDelegateConditionalClose() { UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( Futures.completedWithNull() ); - - InternalRxTransaction rxTx = new InternalRxTransaction( tx ); - Publisher publisher = rxTx.commitIfOpen(); - StepVerifier.create( publisher ).verifyComplete(); - - verify( tx ).commitAsync(); - } - - @Test - void shouldNotCommitWhenNotOpen() - { - UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( false ); - when( tx.commitAsync() ).thenReturn( Futures.completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( Futures.completedWithNull() ); InternalRxTransaction rxTx = new InternalRxTransaction( tx ); - Publisher publisher = rxTx.commitIfOpen(); + Publisher publisher = rxTx.close( true ); StepVerifier.create( publisher ).verifyComplete(); - verify( tx, never() ).commitAsync(); + verify( tx ).closeAsync( true ); } @Test void shouldDelegateClose() { UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.closeAsync() ).thenReturn( Futures.completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( Futures.completedWithNull() ); InternalRxTransaction rxTx = new InternalRxTransaction( tx ); Publisher publisher = rxTx.close(); StepVerifier.create( publisher ).verifyComplete(); - verify( tx ).closeAsync(); + verify( tx ).closeAsync( false ); } }