diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java index 3cf7cd254..0e02bfacc 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithAuthTokenManager.java @@ -16,271 +16,26 @@ */ package org.neo4j.driver.internal.async; -import java.time.Duration; -import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CompletionStage; import org.neo4j.driver.AuthTokenManager; -import org.neo4j.driver.Value; import org.neo4j.driver.exceptions.SecurityException; import org.neo4j.driver.exceptions.SecurityRetryableException; -import org.neo4j.driver.internal.bolt.api.AccessMode; -import org.neo4j.driver.internal.bolt.api.AuthData; import org.neo4j.driver.internal.bolt.api.BoltConnection; -import org.neo4j.driver.internal.bolt.api.BoltConnectionState; -import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion; -import org.neo4j.driver.internal.bolt.api.BoltServerAddress; -import org.neo4j.driver.internal.bolt.api.DatabaseName; -import org.neo4j.driver.internal.bolt.api.NotificationConfig; import org.neo4j.driver.internal.bolt.api.ResponseHandler; -import org.neo4j.driver.internal.bolt.api.TelemetryApi; -import org.neo4j.driver.internal.bolt.api.TransactionType; -import org.neo4j.driver.internal.bolt.api.summary.BeginSummary; -import org.neo4j.driver.internal.bolt.api.summary.CommitSummary; -import org.neo4j.driver.internal.bolt.api.summary.DiscardSummary; -import org.neo4j.driver.internal.bolt.api.summary.LogoffSummary; -import org.neo4j.driver.internal.bolt.api.summary.LogonSummary; -import org.neo4j.driver.internal.bolt.api.summary.PullSummary; -import org.neo4j.driver.internal.bolt.api.summary.ResetSummary; -import org.neo4j.driver.internal.bolt.api.summary.RollbackSummary; -import org.neo4j.driver.internal.bolt.api.summary.RouteSummary; -import org.neo4j.driver.internal.bolt.api.summary.RunSummary; -import org.neo4j.driver.internal.bolt.api.summary.TelemetrySummary; import org.neo4j.driver.internal.security.InternalAuthToken; -public class BoltConnectionWithAuthTokenManager implements BoltConnection { - private final BoltConnection delegate; +final class BoltConnectionWithAuthTokenManager extends DelegatingBoltConnection { private final AuthTokenManager authTokenManager; public BoltConnectionWithAuthTokenManager(BoltConnection delegate, AuthTokenManager authTokenManager) { - this.delegate = Objects.requireNonNull(delegate); + super(delegate); this.authTokenManager = Objects.requireNonNull(authTokenManager); } - @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return delegate.route(databaseName, impersonatedUser, bookmarks).thenApply(ignored -> this); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return delegate.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - txMetadata, - txType, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return delegate.runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - parameters, - txTimeout, - txMetadata, - notificationConfig) - .thenApply(ignored -> this); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return delegate.run(query, parameters).thenApply(ignored -> this); - } - - @Override - public CompletionStage pull(long qid, long request) { - return delegate.pull(qid, request).thenApply(ignored -> this); - } - - @Override - public CompletionStage discard(long qid, long number) { - return delegate.discard(qid, number).thenApply(ignored -> this); - } - - @Override - public CompletionStage commit() { - return delegate.commit().thenApply(ignored -> this); - } - - @Override - public CompletionStage rollback() { - return delegate.rollback().thenApply(ignored -> this); - } - - @Override - public CompletionStage reset() { - return delegate.reset().thenApply(ignored -> this); - } - - @Override - public CompletionStage logoff() { - return delegate.logoff().thenApply(ignored -> this); - } - - @Override - public CompletionStage logon(Map authMap) { - return delegate.logon(authMap).thenApply(ignored -> this); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return delegate.telemetry(telemetryApi).thenApply(ignored -> this); - } - - @Override - public CompletionStage clear() { - return delegate.clear(); - } - @Override public CompletionStage flush(ResponseHandler handler) { - return delegate.flush(new ResponseHandler() { - - @Override - public void onError(Throwable throwable) { - handler.onError(mapSecurityError(throwable)); - } - - @Override - public void onBeginSummary(BeginSummary summary) { - handler.onBeginSummary(summary); - } - - @Override - public void onRunSummary(RunSummary summary) { - handler.onRunSummary(summary); - } - - @Override - public void onRecord(Value[] fields) { - handler.onRecord(fields); - } - - @Override - public void onPullSummary(PullSummary summary) { - handler.onPullSummary(summary); - } - - @Override - public void onDiscardSummary(DiscardSummary summary) { - handler.onDiscardSummary(summary); - } - - @Override - public void onCommitSummary(CommitSummary summary) { - handler.onCommitSummary(summary); - } - - @Override - public void onRollbackSummary(RollbackSummary summary) { - handler.onRollbackSummary(summary); - } - - @Override - public void onResetSummary(ResetSummary summary) { - handler.onResetSummary(summary); - } - - @Override - public void onRouteSummary(RouteSummary summary) { - handler.onRouteSummary(summary); - } - - @Override - public void onLogoffSummary(LogoffSummary summary) { - handler.onLogoffSummary(summary); - } - - @Override - public void onLogonSummary(LogonSummary summary) { - handler.onLogonSummary(summary); - } - - @Override - public void onTelemetrySummary(TelemetrySummary summary) { - handler.onTelemetrySummary(summary); - } - - @Override - public void onIgnored() { - handler.onIgnored(); - } - - @Override - public void onComplete() { - handler.onComplete(); - } - }); - } - - @Override - public CompletionStage forceClose(String reason) { - return delegate.forceClose(reason); - } - - @Override - public CompletionStage close() { - return delegate.close(); - } - - @Override - public BoltConnectionState state() { - return delegate.state(); - } - - @Override - public CompletionStage authData() { - return delegate.authData(); - } - - @Override - public String serverAgent() { - return delegate.serverAgent(); - } - - @Override - public BoltServerAddress serverAddress() { - return delegate.serverAddress(); - } - - @Override - public BoltProtocolVersion protocolVersion() { - return delegate.protocolVersion(); - } - - @Override - public boolean telemetrySupported() { - return delegate.telemetrySupported(); + return delegate.flush(new ErrorMappingResponseHandler(handler, this::mapSecurityError)); } private Throwable mapSecurityError(Throwable throwable) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithCloseTracking.java b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithCloseTracking.java new file mode 100644 index 000000000..837125e29 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/BoltConnectionWithCloseTracking.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicBoolean; +import org.neo4j.driver.internal.bolt.api.BoltConnection; + +final class BoltConnectionWithCloseTracking extends DelegatingBoltConnection { + private final AtomicBoolean open = new AtomicBoolean(true); + + BoltConnectionWithCloseTracking(BoltConnection delegate) { + super(delegate); + } + + @Override + public CompletionStage close() { + open.set(false); + return delegate.close(); + } + + public boolean isOpen() { + return open.get(); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java new file mode 100644 index 000000000..c1e1136a5 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingBoltConnection.java @@ -0,0 +1,197 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletionStage; +import org.neo4j.driver.Value; +import org.neo4j.driver.internal.bolt.api.AccessMode; +import org.neo4j.driver.internal.bolt.api.AuthData; +import org.neo4j.driver.internal.bolt.api.BoltConnection; +import org.neo4j.driver.internal.bolt.api.BoltConnectionState; +import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion; +import org.neo4j.driver.internal.bolt.api.BoltServerAddress; +import org.neo4j.driver.internal.bolt.api.DatabaseName; +import org.neo4j.driver.internal.bolt.api.NotificationConfig; +import org.neo4j.driver.internal.bolt.api.ResponseHandler; +import org.neo4j.driver.internal.bolt.api.TelemetryApi; +import org.neo4j.driver.internal.bolt.api.TransactionType; + +public abstract class DelegatingBoltConnection implements BoltConnection { + protected final BoltConnection delegate; + + protected DelegatingBoltConnection(BoltConnection delegate) { + this.delegate = Objects.requireNonNull(delegate); + } + + @Override + public CompletionStage onLoop() { + return delegate.onLoop().thenApply(ignored -> this); + } + + @Override + public CompletionStage route( + DatabaseName databaseName, String impersonatedUser, Set bookmarks) { + return delegate.route(databaseName, impersonatedUser, bookmarks).thenApply(ignored -> this); + } + + @Override + public CompletionStage beginTransaction( + DatabaseName databaseName, + AccessMode accessMode, + String impersonatedUser, + Set bookmarks, + TransactionType transactionType, + Duration txTimeout, + Map txMetadata, + String txType, + NotificationConfig notificationConfig) { + return delegate.beginTransaction( + databaseName, + accessMode, + impersonatedUser, + bookmarks, + transactionType, + txTimeout, + txMetadata, + txType, + notificationConfig) + .thenApply(ignored -> this); + } + + @Override + public CompletionStage runInAutoCommitTransaction( + DatabaseName databaseName, + AccessMode accessMode, + String impersonatedUser, + Set bookmarks, + String query, + Map parameters, + Duration txTimeout, + Map txMetadata, + NotificationConfig notificationConfig) { + return delegate.runInAutoCommitTransaction( + databaseName, + accessMode, + impersonatedUser, + bookmarks, + query, + parameters, + txTimeout, + txMetadata, + notificationConfig) + .thenApply(ignored -> this); + } + + @Override + public CompletionStage run(String query, Map parameters) { + return delegate.run(query, parameters).thenApply(ignored -> this); + } + + @Override + public CompletionStage pull(long qid, long request) { + return delegate.pull(qid, request).thenApply(ignored -> this); + } + + @Override + public CompletionStage discard(long qid, long number) { + return delegate.discard(qid, number).thenApply(ignored -> this); + } + + @Override + public CompletionStage commit() { + return delegate.commit().thenApply(ignored -> this); + } + + @Override + public CompletionStage rollback() { + return delegate.rollback().thenApply(ignored -> this); + } + + @Override + public CompletionStage reset() { + return delegate.reset().thenApply(ignored -> this); + } + + @Override + public CompletionStage logoff() { + return delegate.logoff().thenApply(ignored -> this); + } + + @Override + public CompletionStage logon(Map authMap) { + return delegate.logon(authMap).thenApply(ignored -> this); + } + + @Override + public CompletionStage telemetry(TelemetryApi telemetryApi) { + return delegate.telemetry(telemetryApi).thenApply(ignored -> this); + } + + @Override + public CompletionStage clear() { + return delegate.clear().thenApply(ignored -> this); + } + + @Override + public CompletionStage flush(ResponseHandler handler) { + return delegate.flush(handler); + } + + @Override + public CompletionStage forceClose(String reason) { + return delegate.forceClose(reason); + } + + @Override + public CompletionStage close() { + return delegate.close(); + } + + @Override + public BoltConnectionState state() { + return delegate.state(); + } + + @Override + public CompletionStage authData() { + return delegate.authData(); + } + + @Override + public String serverAgent() { + return delegate.serverAgent(); + } + + @Override + public BoltServerAddress serverAddress() { + return delegate.serverAddress(); + } + + @Override + public BoltProtocolVersion protocolVersion() { + return delegate.protocolVersion(); + } + + @Override + public boolean telemetrySupported() { + return delegate.telemetrySupported(); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingResponseHandler.java new file mode 100644 index 000000000..f1a0ce940 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/DelegatingResponseHandler.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.Objects; +import org.neo4j.driver.Value; +import org.neo4j.driver.internal.bolt.api.ResponseHandler; +import org.neo4j.driver.internal.bolt.api.summary.BeginSummary; +import org.neo4j.driver.internal.bolt.api.summary.CommitSummary; +import org.neo4j.driver.internal.bolt.api.summary.DiscardSummary; +import org.neo4j.driver.internal.bolt.api.summary.LogoffSummary; +import org.neo4j.driver.internal.bolt.api.summary.LogonSummary; +import org.neo4j.driver.internal.bolt.api.summary.PullSummary; +import org.neo4j.driver.internal.bolt.api.summary.ResetSummary; +import org.neo4j.driver.internal.bolt.api.summary.RollbackSummary; +import org.neo4j.driver.internal.bolt.api.summary.RouteSummary; +import org.neo4j.driver.internal.bolt.api.summary.RunSummary; +import org.neo4j.driver.internal.bolt.api.summary.TelemetrySummary; + +abstract class DelegatingResponseHandler implements ResponseHandler { + protected final ResponseHandler delegate; + + DelegatingResponseHandler(ResponseHandler delegate) { + this.delegate = Objects.requireNonNull(delegate); + } + + @Override + public void onError(Throwable throwable) { + delegate.onError(throwable); + } + + @Override + public void onBeginSummary(BeginSummary summary) { + delegate.onBeginSummary(summary); + } + + @Override + public void onRunSummary(RunSummary summary) { + delegate.onRunSummary(summary); + } + + @Override + public void onRecord(Value[] fields) { + delegate.onRecord(fields); + } + + @Override + public void onPullSummary(PullSummary summary) { + delegate.onPullSummary(summary); + } + + @Override + public void onDiscardSummary(DiscardSummary summary) { + delegate.onDiscardSummary(summary); + } + + @Override + public void onCommitSummary(CommitSummary summary) { + delegate.onCommitSummary(summary); + } + + @Override + public void onRollbackSummary(RollbackSummary summary) { + delegate.onRollbackSummary(summary); + } + + @Override + public void onResetSummary(ResetSummary summary) { + delegate.onResetSummary(summary); + } + + @Override + public void onRouteSummary(RouteSummary summary) { + delegate.onRouteSummary(summary); + } + + @Override + public void onLogoffSummary(LogoffSummary summary) { + delegate.onLogoffSummary(summary); + } + + @Override + public void onLogonSummary(LogonSummary summary) { + delegate.onLogonSummary(summary); + } + + @Override + public void onTelemetrySummary(TelemetrySummary summary) { + delegate.onTelemetrySummary(summary); + } + + @Override + public void onIgnored() { + delegate.onIgnored(); + } + + @Override + public void onComplete() { + delegate.onComplete(); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/ErrorMappingResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/async/ErrorMappingResponseHandler.java new file mode 100644 index 000000000..ce65e5064 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/ErrorMappingResponseHandler.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.Objects; +import java.util.function.Function; +import org.neo4j.driver.internal.bolt.api.ResponseHandler; + +final class ErrorMappingResponseHandler extends DelegatingResponseHandler { + private final Function errorMapper; + + ErrorMappingResponseHandler(ResponseHandler delegate, Function errorMapper) { + super(delegate); + this.errorMapper = Objects.requireNonNull(errorMapper); + } + + @Override + public void onError(Throwable throwable) { + delegate.onError(errorMapper.apply(throwable)); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java index 03cf5d41d..09db8fbc0 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java @@ -20,7 +20,6 @@ import static org.neo4j.driver.internal.util.Futures.completedWithNull; import static org.neo4j.driver.internal.util.Futures.completionExceptionCause; -import java.time.Duration; import java.util.Collections; import java.util.HashSet; import java.util.Map; @@ -56,19 +55,15 @@ import org.neo4j.driver.internal.DatabaseBookmark; import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.NotificationConfigMapper; -import org.neo4j.driver.internal.bolt.api.AuthData; import org.neo4j.driver.internal.bolt.api.BoltConnection; import org.neo4j.driver.internal.bolt.api.BoltConnectionProvider; -import org.neo4j.driver.internal.bolt.api.BoltConnectionState; import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion; -import org.neo4j.driver.internal.bolt.api.BoltServerAddress; import org.neo4j.driver.internal.bolt.api.DatabaseName; import org.neo4j.driver.internal.bolt.api.DatabaseNameUtil; import org.neo4j.driver.internal.bolt.api.GqlStatusError; import org.neo4j.driver.internal.bolt.api.NotificationConfig; import org.neo4j.driver.internal.bolt.api.ResponseHandler; import org.neo4j.driver.internal.bolt.api.TelemetryApi; -import org.neo4j.driver.internal.bolt.api.TransactionType; import org.neo4j.driver.internal.bolt.api.exception.MinVersionAcquisitionException; import org.neo4j.driver.internal.bolt.api.summary.RunSummary; import org.neo4j.driver.internal.cursor.DisposableResultCursorImpl; @@ -153,7 +148,7 @@ public CompletionStage runAsync(Query query, TransactionConfig con var apiTelemetryWork = new ApiTelemetryWork(TelemetryApi.AUTO_COMMIT_TRANSACTION); apiTelemetryWork.setEnabled(!telemetryDisabled); var resultCursor = new ResultCursorImpl( - connection, query, fetchSize, null, this::handleNewBookmark, true, () -> null, null, null); + connection, query, fetchSize, this::handleNewBookmark, true, null, null); var cursorStage = apiTelemetryWork .pipelineTelemetryIfEnabled(connection) .thenCompose(conn -> conn.runInAutoCommitTransaction( @@ -320,7 +315,7 @@ public CompletionStage resetAsync() { }) .thenCompose(ignore -> connectionStage) .thenCompose(connection -> { - if (connection != null && !connection.closed.get()) { + if (connection != null && connection.isOpen()) { var future = new CompletableFuture(); return connection .reset() @@ -400,7 +395,7 @@ protected CompletionStage currentConnectionIsOpen() { && // no acquisition error connection != null && // some connection has actually been acquired - !connection.closed.get()); // and it's still open + connection.isOpen()); // and it's still open } private org.neo4j.driver.internal.bolt.api.AccessMode asBoltAccessMode(AccessMode mode) { @@ -607,163 +602,6 @@ private void assertDatabaseNameFutureIsDone() { } } - private static class BoltConnectionWithCloseTracking implements BoltConnection { - private final BoltConnection connection; - private final AtomicBoolean closed = new AtomicBoolean(false); - - private BoltConnectionWithCloseTracking(BoltConnection connection) { - this.connection = connection; - } - - @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return connection.route(databaseName, impersonatedUser, bookmarks); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - org.neo4j.driver.internal.bolt.api.AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return connection.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - txMetadata, - txType, - notificationConfig); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - org.neo4j.driver.internal.bolt.api.AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return connection.runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - parameters, - txTimeout, - txMetadata, - notificationConfig); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return connection.run(query, parameters); - } - - @Override - public CompletionStage pull(long qid, long request) { - return connection.pull(qid, request); - } - - @Override - public CompletionStage discard(long qid, long number) { - return connection.discard(qid, number); - } - - @Override - public CompletionStage commit() { - return connection.commit(); - } - - @Override - public CompletionStage rollback() { - return connection.rollback(); - } - - @Override - public CompletionStage reset() { - return connection.reset(); - } - - @Override - public CompletionStage logoff() { - return connection.logoff(); - } - - @Override - public CompletionStage logon(Map authMap) { - return connection.logon(authMap); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return connection.telemetry(telemetryApi); - } - - @Override - public CompletionStage clear() { - return connection.clear(); - } - - @Override - public CompletionStage flush(ResponseHandler handler) { - return connection.flush(handler); - } - - @Override - public CompletionStage forceClose(String reason) { - return connection.forceClose(reason); - } - - @Override - public CompletionStage close() { - closed.set(true); - return connection.close(); - } - - @Override - public BoltConnectionState state() { - return connection.state(); - } - - @Override - public CompletionStage authData() { - return connection.authData(); - } - - @Override - public String serverAgent() { - return connection.serverAgent(); - } - - @Override - public BoltServerAddress serverAddress() { - return connection.serverAddress(); - } - - @Override - public BoltProtocolVersion protocolVersion() { - return connection.protocolVersion(); - } - - @Override - public boolean telemetrySupported() { - return connection.telemetrySupported(); - } - } - /** * The {@link NetworkSessionConnectionContext#mode} can be mutable for a session connection context */ @@ -833,21 +671,15 @@ public RunRxResponseHandler( this.runFailed = runFailed; } - @SuppressWarnings("DuplicatedCode") @Override public void onError(Throwable throwable) { - if (throwable instanceof CompletionException) { - throwable = throwable.getCause(); - } + throwable = Futures.completionExceptionCause(throwable); if (error == null) { error = throwable; } else { if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { // higher order error has occurred - throwable.addSuppressed(error); error = throwable; - } else { - error.addSuppressed(throwable); } } } @@ -868,16 +700,8 @@ public void onComplete() { if (error != null) { runFailed.set(true); } - cursorFuture.complete(new RxResultCursorImpl( - connection, - query, - runSummary, - error, - bookmarkConsumer, - (ignored) -> {}, - true, - () -> null, - logging)); + cursorFuture.complete( + new RxResultCursorImpl(connection, query, runSummary, error, bookmarkConsumer, true, logging)); } else { var message = ignoredCount > 0 ? "Run exchange contains ignored messages." diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java index f41813677..20d1556c4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareBoltConnection.java @@ -16,208 +16,113 @@ */ package org.neo4j.driver.internal.async; -import java.time.Duration; -import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; -import org.neo4j.driver.Value; -import org.neo4j.driver.internal.bolt.api.AccessMode; -import org.neo4j.driver.internal.bolt.api.AuthData; +import java.util.function.Consumer; +import org.neo4j.driver.Logger; +import org.neo4j.driver.Logging; import org.neo4j.driver.internal.bolt.api.BoltConnection; -import org.neo4j.driver.internal.bolt.api.BoltConnectionState; -import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion; -import org.neo4j.driver.internal.bolt.api.BoltServerAddress; -import org.neo4j.driver.internal.bolt.api.DatabaseName; -import org.neo4j.driver.internal.bolt.api.NotificationConfig; import org.neo4j.driver.internal.bolt.api.ResponseHandler; -import org.neo4j.driver.internal.bolt.api.TelemetryApi; -import org.neo4j.driver.internal.bolt.api.TransactionType; +import org.neo4j.driver.internal.util.Futures; -public class TerminationAwareBoltConnection implements BoltConnection { - private final BoltConnection delegate; +final class TerminationAwareBoltConnection extends DelegatingBoltConnection { + private final Logging logging; + private final Logger log; private final TerminationAwareStateLockingExecutor executor; - - public TerminationAwareBoltConnection(BoltConnection delegate, TerminationAwareStateLockingExecutor executor) { - this.delegate = Objects.requireNonNull(delegate); + private final Consumer throwableConsumer; + + public TerminationAwareBoltConnection( + Logging logging, + BoltConnection delegate, + TerminationAwareStateLockingExecutor executor, + Consumer throwableConsumer) { + super(delegate); + this.logging = Objects.requireNonNull(logging); + this.log = logging.getLog(getClass()); this.executor = Objects.requireNonNull(executor); + this.throwableConsumer = Objects.requireNonNull(throwableConsumer); } public CompletionStage clearAndReset() { var future = new CompletableFuture(); var thisVal = this; - delegate.clear() - .thenCompose(BoltConnection::reset) - .thenCompose(connection -> connection.flush(new ResponseHandler() { - @Override - public void onError(Throwable throwable) { - future.completeExceptionally(throwable); - } - @Override - public void onComplete() { - future.complete(thisVal); - } - })) - .whenComplete((result, throwable) -> { + delegate.onLoop() + .thenCompose(connection -> executor.execute(ignored -> connection + .clear() + .thenCompose(BoltConnection::reset) + .thenCompose(conn -> conn.flush(new ResponseHandler() { + Throwable throwable = null; + + @Override + public void onError(Throwable throwable) { + log.error("Unexpected error occurred while resetting connection", throwable); + throwableConsumer.accept(throwable); + this.throwable = throwable; + } + + @Override + public void onComplete() { + if (throwable != null) { + future.completeExceptionally(throwable); + } else { + future.complete(thisVal); + } + } + })))) + .whenComplete((ignored, throwable) -> { if (throwable != null) { + throwableConsumer.accept(throwable); future.completeExceptionally(throwable); } }); - return future; - } - - @Override - public boolean telemetrySupported() { - return delegate.telemetrySupported(); - } - - @Override - public BoltProtocolVersion protocolVersion() { - return delegate.protocolVersion(); - } - - @Override - public BoltServerAddress serverAddress() { - return delegate.serverAddress(); - } - - @Override - public String serverAgent() { - return delegate.serverAgent(); - } - @Override - public CompletionStage authData() { - return delegate.authData(); - } - - @Override - public BoltConnectionState state() { - return delegate.state(); - } - - @Override - public CompletionStage close() { - return delegate.close(); - } - - @Override - public CompletionStage forceClose(String reason) { - return delegate.forceClose(reason); + return future; } @Override public CompletionStage flush(ResponseHandler handler) { - return executor.execute(causeOfTermination -> { - if (causeOfTermination == null) { - return delegate.flush(handler); - } else { - return CompletableFuture.failedStage(causeOfTermination); - } - }); - } - - @Override - public CompletionStage telemetry(TelemetryApi telemetryApi) { - return delegate.telemetry(telemetryApi); - } - - @Override - public CompletionStage clear() { - return delegate.clear(); - } - - @Override - public CompletionStage logon(Map authMap) { - return delegate.logon(authMap); - } - - @Override - public CompletionStage logoff() { - return delegate.logoff(); - } - - @Override - public CompletionStage reset() { - return delegate.reset(); - } - - @Override - public CompletionStage rollback() { - return delegate.rollback(); - } - - @Override - public CompletionStage commit() { - return delegate.commit(); - } - - @Override - public CompletionStage discard(long qid, long number) { - return delegate.discard(qid, number); - } - - @Override - public CompletionStage pull(long qid, long request) { - return delegate.pull(qid, request); - } - - @Override - public CompletionStage run(String query, Map parameters) { - return delegate.run(query, parameters); - } - - @Override - public CompletionStage runInAutoCommitTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - String query, - Map parameters, - Duration txTimeout, - Map txMetadata, - NotificationConfig notificationConfig) { - return delegate.runInAutoCommitTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - query, - parameters, - txTimeout, - txMetadata, - notificationConfig); - } - - @Override - public CompletionStage beginTransaction( - DatabaseName databaseName, - AccessMode accessMode, - String impersonatedUser, - Set bookmarks, - TransactionType transactionType, - Duration txTimeout, - Map txMetadata, - String txType, - NotificationConfig notificationConfig) { - return delegate.beginTransaction( - databaseName, - accessMode, - impersonatedUser, - bookmarks, - transactionType, - txTimeout, - txMetadata, - txType, - notificationConfig); - } - - @Override - public CompletionStage route( - DatabaseName databaseName, String impersonatedUser, Set bookmarks) { - return delegate.route(databaseName, impersonatedUser, bookmarks); + return delegate.onLoop() + .thenCompose(connection -> executor.execute(causeOfTermination -> { + if (causeOfTermination == null) { + log.trace("This connection is active, will flush"); + var terminationAwareResponseHandler = + new TerminationAwareResponseHandler(logging, handler, executor, throwableConsumer); + return delegate.flush(terminationAwareResponseHandler).handle((ignored, flushThrowable) -> { + flushThrowable = Futures.completionExceptionCause(flushThrowable); + if (flushThrowable != null) { + if (log.isTraceEnabled()) { + log.error("The flush has failed", flushThrowable); + } + var flushThrowableRef = flushThrowable; + flushThrowable = executor.execute(existingThrowable -> { + if (existingThrowable != null) { + log.trace( + "The flush has failed, but there is an existing %s", existingThrowable); + return existingThrowable; + } else { + throwableConsumer.accept(flushThrowableRef); + return flushThrowableRef; + } + }); + // rethrow + if (flushThrowable instanceof RuntimeException runtimeException) { + throw runtimeException; + } else { + throw new CompletionException(flushThrowable); + } + } else { + return ignored; + } + }); + } else { + // there is an existing error + return connection + .clear() + .thenCompose(ignored -> CompletableFuture.failedStage(causeOfTermination)); + } + })); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareResponseHandler.java new file mode 100644 index 000000000..2e6de3b31 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareResponseHandler.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [https://neo4j.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Function; +import org.neo4j.driver.Logger; +import org.neo4j.driver.Logging; +import org.neo4j.driver.internal.bolt.api.ResponseHandler; +import org.neo4j.driver.internal.util.Futures; + +final class TerminationAwareResponseHandler extends DelegatingResponseHandler { + private final Logger log; + private final TerminationAwareStateLockingExecutor executor; + private final Consumer throwableConsumer; + + TerminationAwareResponseHandler( + Logging logging, + ResponseHandler delegate, + TerminationAwareStateLockingExecutor executor, + Consumer throwableConsumer) { + super(delegate); + this.log = logging.getLog(getClass()); + this.executor = Objects.requireNonNull(executor); + this.throwableConsumer = Objects.requireNonNull(throwableConsumer); + } + + @Override + public void onError(Throwable throwable) { + throwableConsumer.accept(Futures.completionExceptionCause(throwable)); + super.onError(throwable); + } + + @Override + public void onComplete() { + var throwable = executor.execute(Function.identity()); + if (throwable != null) { + log.trace( + "Reporting an existing %s error to delegate", + throwable.getClass().getCanonicalName()); + delegate.onError(throwable); + } + log.trace("Completing delegate"); + delegate.onComplete(); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index 17e57072f..8dceb8565 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -27,14 +27,12 @@ import java.util.EnumSet; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Collectors; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Logging; @@ -155,7 +153,7 @@ protected UnmanagedTransaction( ApiTelemetryWork apiTelemetryWork, Logging logging) { this.logging = logging; - this.connection = new TerminationAwareBoltConnection(connection, this); + this.connection = new TerminationAwareBoltConnection(logging, connection, this, this::markTerminated); this.databaseName = databaseName; this.accessMode = accessMode; this.impersonatedUser = impersonatedUser; @@ -186,15 +184,14 @@ public CompletionStage beginAsync( notificationConfig)) .thenCompose(connection -> { if (flush) { - var responseHandler = new BeginResponseHandler( - apiTelemetryWork, () -> executeWithLock(lock, () -> causeOfTermination)); + var responseHandler = new BeginResponseHandler(apiTelemetryWork); connection .flush(responseHandler) .thenCompose(ignored -> responseHandler.summaryFuture) .whenComplete((summary, throwable) -> { if (throwable != null) { connection.close().whenComplete((ignored, closeThrowable) -> { - if (closeThrowable != null) { + if (closeThrowable != null && throwable != closeThrowable) { throwable.addSuppressed(closeThrowable); } beginFuture.completeExceptionally(throwable); @@ -230,15 +227,7 @@ public CompletionStage runAsync(Query query) { ensureCanRunQueries(); var parameters = query.parameters().asMap(Values::value); var resultCursor = new ResultCursorImpl( - connection, - query, - fetchSize, - this::markTerminated, - (bookmark) -> {}, - false, - () -> executeWithLock(lock, () -> causeOfTermination), - beginFuture, - apiTelemetryWork); + connection, query, fetchSize, (bookmark) -> {}, false, beginFuture, apiTelemetryWork); var flushStage = connection .run(query.text(), parameters) .thenCompose(ignored -> connection.pull(-1, fetchSize)) @@ -255,15 +244,7 @@ public CompletionStage runAsync(Query query) { public CompletionStage runRx(Query query) { ensureCanRunQueries(); var parameters = query.parameters().asMap(Values::value); - var responseHandler = new RunRxResponseHandler( - logging, - apiTelemetryWork, - () -> executeWithLock(lock, () -> causeOfTermination), - this::markTerminated, - beginFuture, - this, - connection, - query); + var responseHandler = new RunRxResponseHandler(logging, apiTelemetryWork, beginFuture, connection, query); var flushStage = connection.run(query.text(), parameters).thenCompose(ignored2 -> connection.flush(responseHandler)); return beginFuture.thenCompose(ignored -> { @@ -278,15 +259,16 @@ public boolean isOpen() { } public void markTerminated(Throwable cause) { + var throwable = Futures.completionExceptionCause(cause); executeWithLock(lock, () -> { if (state == State.TERMINATED) { - if (cause != null) { - addSuppressedWhenNotCaptured(causeOfTermination, cause); + if (throwable != null) { + addSuppressedWhenNotCaptured(causeOfTermination, throwable); } } else { state = State.TERMINATED; - causeOfTermination = cause != null - ? cause + causeOfTermination = throwable != null + ? throwable : new TransactionTerminatedException( GqlStatusError.UNKNOWN.getStatus(), GqlStatusError.UNKNOWN.getStatusDescription(EXPLICITLY_TERMINATED_MSG), @@ -429,19 +411,16 @@ private CompletionStage doCommitAsync(Throwable cursorFailure) { .ifPresent(bookmarkConsumer); commitSummary.complete(summary); } else { - throwable = executeWithLock(lock, () -> causeOfTermination); - if (throwable == null) { - var message = summaries.ignored() > 0 - ? "Commit exchange contains ignored messages" - : "Unexpected state during commit"; - throwable = new ClientException( - GqlStatusError.UNKNOWN.getStatus(), - GqlStatusError.UNKNOWN.getStatusDescription(message), - "N/A", - message, - GqlStatusError.DIAGNOSTIC_RECORD, - null); - } + var message = summaries.ignored() > 0 + ? "Commit exchange contains ignored messages" + : "Unexpected state during commit"; + throwable = new ClientException( + GqlStatusError.UNKNOWN.getStatus(), + GqlStatusError.UNKNOWN.getStatusDescription(message), + "N/A", + message, + GqlStatusError.DIAGNOSTIC_RECORD, + null); commitSummary.completeExceptionally(throwable); } } @@ -468,19 +447,16 @@ private CompletionStage doRollbackAsync() { if (summary != null) { rollbackFuture.complete(null); } else { - throwable = executeWithLock(lock, () -> causeOfTermination); - if (throwable == null) { - var message = summaries.ignored() > 0 - ? "Rollback exchange contains ignored messages" - : "Unexpected state during rollback"; - throwable = new ClientException( - GqlStatusError.UNKNOWN.getStatus(), - GqlStatusError.UNKNOWN.getStatusDescription(message), - "N/A", - message, - GqlStatusError.DIAGNOSTIC_RECORD, - null); - } + var message = summaries.ignored() > 0 + ? "Rollback exchange contains ignored messages" + : "Unexpected state during rollback"; + throwable = new ClientException( + GqlStatusError.UNKNOWN.getStatus(), + GqlStatusError.UNKNOWN.getStatusDescription(message), + "N/A", + message, + GqlStatusError.DIAGNOSTIC_RECORD, + null); rollbackFuture.completeExceptionally(throwable); } } @@ -603,31 +579,23 @@ private CompletionStage closeAsync(boolean commit, boolean completeWithNul private static class BeginResponseHandler implements ResponseHandler { final CompletableFuture summaryFuture = new CompletableFuture<>(); private final ApiTelemetryWork apiTelemetryWork; - private final Supplier termSupplier; private Throwable error; private BeginSummary beginSummary; private int ignoredCount; - private BeginResponseHandler(ApiTelemetryWork apiTelemetryWork, Supplier termSupplier) { + private BeginResponseHandler(ApiTelemetryWork apiTelemetryWork) { this.apiTelemetryWork = apiTelemetryWork; - this.termSupplier = termSupplier; } - @SuppressWarnings("DuplicatedCode") @Override public void onError(Throwable throwable) { - if (throwable instanceof CompletionException) { - throwable = throwable.getCause(); - } + throwable = Futures.completionExceptionCause(throwable); if (error == null) { error = throwable; } else { if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { // higher order error has occurred - throwable.addSuppressed(error); error = throwable; - } else { - error.addSuppressed(throwable); } } } @@ -655,19 +623,16 @@ public void onComplete() { if (beginSummary != null) { summaryFuture.complete(null); } else { - var throwable = termSupplier.get(); - if (throwable == null) { - var message = ignoredCount > 0 - ? "Begin exchange contains ignored messages" - : "Unexpected state during begin"; - throwable = new ClientException( - GqlStatusError.UNKNOWN.getStatus(), - GqlStatusError.UNKNOWN.getStatusDescription(message), - "N/A", - message, - GqlStatusError.DIAGNOSTIC_RECORD, - null); - } + var message = ignoredCount > 0 + ? "Begin exchange contains ignored messages" + : "Unexpected state during begin"; + var throwable = new ClientException( + GqlStatusError.UNKNOWN.getStatus(), + GqlStatusError.UNKNOWN.getStatusDescription(message), + "N/A", + message, + GqlStatusError.DIAGNOSTIC_RECORD, + null); summaryFuture.completeExceptionally(throwable); } } @@ -678,10 +643,7 @@ private static class RunRxResponseHandler implements ResponseHandler { final CompletableFuture cursorFuture = new CompletableFuture<>(); private final Logging logging; private final ApiTelemetryWork apiTelemetryWork; - private final Supplier termSupplier; - private final Consumer markTerminated; private final CompletableFuture beginFuture; - private final UnmanagedTransaction transaction; private final BoltConnection connection; private final Query query; private Throwable error; @@ -691,37 +653,25 @@ private static class RunRxResponseHandler implements ResponseHandler { private RunRxResponseHandler( Logging logging, ApiTelemetryWork apiTelemetryWork, - Supplier termSupplier, - Consumer markTerminated, CompletableFuture beginFuture, - UnmanagedTransaction transaction, BoltConnection connection, Query query) { this.logging = logging; this.apiTelemetryWork = apiTelemetryWork; - this.termSupplier = termSupplier; - this.markTerminated = markTerminated; this.beginFuture = beginFuture; - this.transaction = transaction; this.connection = connection; this.query = query; } - @SuppressWarnings("DuplicatedCode") @Override public void onError(Throwable throwable) { - if (throwable instanceof CompletionException) { - throwable = throwable.getCause(); - } + throwable = Futures.completionExceptionCause(throwable); if (error == null) { error = throwable; } else { if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { // higher order error has occurred - throwable.addSuppressed(error); error = throwable; - } else { - error.addSuppressed(throwable); } } } @@ -744,47 +694,24 @@ public void onIgnored() { @Override public void onComplete() { if (error != null) { - if (beginFuture.completeExceptionally(error)) { - markTerminated.accept(error); - } else { - markTerminated.accept(error); - cursorFuture.complete(new RxResultCursorImpl( - connection, - query, - null, - error, - bookmark -> {}, - transaction::markTerminated, - false, - termSupplier, - logging)); + if (!beginFuture.completeExceptionally(error)) { + cursorFuture.complete( + new RxResultCursorImpl(connection, query, null, error, bookmark -> {}, false, logging)); } } else { if (runSummary != null) { cursorFuture.complete(new RxResultCursorImpl( - connection, - query, - runSummary, - null, - bookmark -> {}, - transaction::markTerminated, - false, - termSupplier, - logging)); + connection, query, runSummary, null, bookmark -> {}, false, logging)); } else { - var throwable = termSupplier.get(); - if (throwable == null) { - var message = ignoredCount > 0 - ? "Run exchange contains ignored messages" - : "Unexpected state during run"; - throwable = new ClientException( - GqlStatusError.UNKNOWN.getStatus(), - GqlStatusError.UNKNOWN.getStatusDescription(message), - "N/A", - message, - GqlStatusError.DIAGNOSTIC_RECORD, - null); - } + var message = + ignoredCount > 0 ? "Run exchange contains ignored messages" : "Unexpected state during run"; + var throwable = new ClientException( + GqlStatusError.UNKNOWN.getStatus(), + GqlStatusError.UNKNOWN.getStatusDescription(message), + "N/A", + message, + GqlStatusError.DIAGNOSTIC_RECORD, + null); if (!beginFuture.completeExceptionally(throwable)) { cursorFuture.completeExceptionally(throwable); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BasicResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BasicResponseHandler.java index 8e77d0344..7160ea63e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BasicResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BasicResponseHandler.java @@ -57,7 +57,6 @@ public CompletionStage summaries() { return summariesFuture; } - @SuppressWarnings("DuplicatedCode") @Override public void onError(Throwable throwable) { if (throwable instanceof CompletionException) { @@ -68,10 +67,7 @@ public void onError(Throwable throwable) { } else { if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { // higher order error has occurred - throwable.addSuppressed(error); error = throwable; - } else { - error.addSuppressed(throwable); } } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BoltConnection.java index fa99eba36..03c595953 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/api/BoltConnection.java @@ -23,6 +23,8 @@ import org.neo4j.driver.Value; public interface BoltConnection { + CompletionStage onLoop(); + CompletionStage route(DatabaseName databaseName, String impersonatedUser, Set bookmarks); CompletionStage beginTransaction( diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/basicimpl/BoltConnectionImpl.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/basicimpl/BoltConnectionImpl.java index cfa99a2c1..15052e09f 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/basicimpl/BoltConnectionImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/basicimpl/BoltConnectionImpl.java @@ -114,6 +114,11 @@ public BoltConnectionImpl( this.log = this.logging.getLog(getClass()); } + @Override + public CompletionStage onLoop() { + return executeInEventLoop(() -> {}).thenApply(ignored -> this); + } + @Override public CompletionStage route( DatabaseName databaseName, String impersonatedUser, Set bookmarks) { @@ -499,14 +504,14 @@ public boolean telemetrySupported() { } private CompletionStage executeInEventLoop(Runnable runnable) { - var executeStage = new CompletableFuture(); + var executeFuture = new CompletableFuture(); Runnable stageCompletingRunnable = () -> { try { runnable.run(); } catch (Throwable throwable) { - executeStage.completeExceptionally(throwable); + executeFuture.completeExceptionally(throwable); } - executeStage.complete(null); + executeFuture.complete(null); }; if (eventLoop.inEventLoop()) { stageCompletingRunnable.run(); @@ -514,10 +519,10 @@ private CompletionStage executeInEventLoop(Runnable runnable) { try { eventLoop.execute(stageCompletingRunnable); } catch (Throwable throwable) { - executeStage.completeExceptionally(throwable); + executeFuture.completeExceptionally(throwable); } } - return executeStage; + return executeFuture; } private void updateState(Throwable throwable) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/pooledimpl/PooledBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/pooledimpl/PooledBoltConnection.java index 5395759fb..0b8884f05 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/pooledimpl/PooledBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/pooledimpl/PooledBoltConnection.java @@ -67,6 +67,11 @@ public PooledBoltConnection( this.purgeRunnable = Objects.requireNonNull(purgeRunnable); } + @Override + public CompletionStage onLoop() { + return delegate.onLoop(); + } + @Override public CompletionStage route( DatabaseName databaseName, String impersonatedUser, Set bookmarks) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java index b080c1de7..1c33772a1 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java @@ -71,6 +71,11 @@ public RoutedBoltConnection( this.provider = Objects.requireNonNull(provider); } + @Override + public CompletionStage onLoop() { + return delegate.onLoop(); + } + @Override public CompletionStage route( DatabaseName databaseName, String impersonatedUser, Set bookmarks) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java index 01bbde97a..a4eb94af7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cursor/ResultCursorImpl.java @@ -28,7 +28,6 @@ import java.util.concurrent.CompletionStage; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; import org.neo4j.driver.Record; @@ -69,9 +68,7 @@ public class ResultCursorImpl extends AbstractRecordStateResponseHandler private final Queue records = new ArrayDeque<>(); private final Query query; private final long fetchSize; - private final Consumer throwableConsumer; private final Consumer bookmarkConsumer; - private final Supplier termSupplier; private final boolean closeOnSummary; private final boolean legacyNotifications; private final CompletableFuture resultCursorFuture = new CompletableFuture<>(); @@ -104,10 +101,8 @@ public ResultCursorImpl( BoltConnection boltConnection, Query query, long fetchSize, - Consumer throwableConsumer, Consumer bookmarkConsumer, boolean closeOnSummary, - Supplier termSupplier, CompletableFuture beginFuture, ApiTelemetryWork apiTelemetryWork) { this.boltConnection = Objects.requireNonNull(boltConnection); @@ -115,11 +110,9 @@ public ResultCursorImpl( updateRecordState(RecordState.REQUESTED); this.query = Objects.requireNonNull(query); this.fetchSize = fetchSize; - this.throwableConsumer = throwableConsumer; this.bookmarkConsumer = Objects.requireNonNull(bookmarkConsumer); this.closeOnSummary = closeOnSummary; this.state = State.STREAMING; - this.termSupplier = termSupplier; this.beginFuture = beginFuture; this.apiTelemetryWork = apiTelemetryWork; } @@ -149,36 +142,28 @@ public synchronized CompletionStage consumeAsync() { CompletionStage summaryFt = switch (state) { case READY -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - summaryFuture = new CompletableFuture<>(); - var future = summaryFuture; - state = State.DISCARDING; - boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - CompletableFuture summaryFuture; - if (error != null) { - synchronized (this) { - state = State.FAILED; - errorExposed = true; - summaryFuture = this.summaryFuture; - this.summaryFuture = null; - apiCallInProgress = false; - } - summaryFuture.completeExceptionally(error); + apiCallInProgress = true; + summaryFuture = new CompletableFuture<>(); + var future = summaryFuture; + state = State.DISCARDING; + boltConnection + .discard(runSummary.queryId(), -1) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + CompletableFuture summaryFuture; + if (error != null) { + synchronized (this) { + state = State.FAILED; + errorExposed = true; + summaryFuture = this.summaryFuture; + this.summaryFuture = null; + apiCallInProgress = false; } - }); - yield future; - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + summaryFuture.completeExceptionally(error); + } + }); + yield future; } case STREAMING -> { apiCallInProgress = true; @@ -236,37 +221,29 @@ var record = records.poll(); // buffer is empty return switch (state) { case READY -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - recordFuture = new CompletableFuture<>(); - var result = recordFuture; - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - CompletableFuture recordFuture; - if (error != null) { - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordFuture = this.recordFuture; - this.recordFuture = null; - apiCallInProgress = false; - } - recordFuture.completeExceptionally(error); + apiCallInProgress = true; + recordFuture = new CompletableFuture<>(); + var result = recordFuture; + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), fetchSize) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + CompletableFuture recordFuture; + if (error != null) { + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordFuture = this.recordFuture; + this.recordFuture = null; + apiCallInProgress = false; } - }); - yield result; - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + recordFuture.completeExceptionally(error); + } + }); + yield result; } case STREAMING -> { apiCallInProgress = true; @@ -309,37 +286,29 @@ var record = records.peek(); // buffer is empty return switch (state) { case READY -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - peekFuture = new CompletableFuture<>(); - var future = peekFuture; - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - if (error != null) { - CompletableFuture peekFuture; - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordFuture = this.peekFuture; - this.peekFuture = null; - apiCallInProgress = false; - } - recordFuture.completeExceptionally(error); + apiCallInProgress = true; + peekFuture = new CompletableFuture<>(); + var future = peekFuture; + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), fetchSize) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + if (error != null) { + CompletableFuture peekFuture; + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordFuture = this.peekFuture; + this.peekFuture = null; + apiCallInProgress = false; } - }); - yield future; - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + recordFuture.completeExceptionally(error); + } + }); + yield future; } case STREAMING -> { apiCallInProgress = true; @@ -385,54 +354,46 @@ public synchronized CompletionStage singleAsync() { return switch (state) { case READY -> { if (records.isEmpty()) { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - recordFuture = new CompletableFuture<>(); - secondRecordFuture = new CompletableFuture<>(); - var singleFuture = recordFuture.thenCompose(firstRecord -> { - if (firstRecord == null) { + apiCallInProgress = true; + recordFuture = new CompletableFuture<>(); + secondRecordFuture = new CompletableFuture<>(); + var singleFuture = recordFuture.thenCompose(firstRecord -> { + if (firstRecord == null) { + throw new NoSuchRecordException( + "Cannot retrieve a single record, because this result is empty."); + } + return secondRecordFuture.thenApply(secondRecord -> { + if (secondRecord) { throw new NoSuchRecordException( - "Cannot retrieve a single record, because this result is empty."); + "Expected a result with a single record, but this result contains at least one more. Ensure your query returns only one record."); } - return secondRecordFuture.thenApply(secondRecord -> { - if (secondRecord) { - throw new NoSuchRecordException( - "Expected a result with a single record, but this result contains at least one more. Ensure your query returns only one record."); - } - return firstRecord; - }); + return firstRecord; }); - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - if (error != null) { - CompletableFuture recordFuture; - CompletableFuture secondRecordFuture; - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordFuture = this.recordFuture; - this.recordFuture = null; - secondRecordFuture = this.secondRecordFuture; - this.secondRecordFuture = null; - apiCallInProgress = false; - } - recordFuture.completeExceptionally(error); - secondRecordFuture.completeExceptionally(error); + }); + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), fetchSize) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + if (error != null) { + CompletableFuture recordFuture; + CompletableFuture secondRecordFuture; + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordFuture = this.recordFuture; + this.recordFuture = null; + secondRecordFuture = this.secondRecordFuture; + this.secondRecordFuture = null; + apiCallInProgress = false; } - }); - yield singleFuture; - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + recordFuture.completeExceptionally(error); + secondRecordFuture.completeExceptionally(error); + } + }); + yield singleFuture; } else { // records is not empty and the state is READY, meaning the result is not exhausted yield CompletableFuture.failedStage( @@ -535,37 +496,29 @@ public synchronized CompletionStage> listAsync() { } return switch (state) { case READY -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - recordsFuture = new CompletableFuture<>(); - var future = recordsFuture; - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - CompletableFuture> recordsFuture; - if (error != null) { - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordsFuture = this.recordsFuture; - this.recordsFuture = null; - apiCallInProgress = false; - } - recordsFuture.completeExceptionally(error); + apiCallInProgress = true; + recordsFuture = new CompletableFuture<>(); + var future = recordsFuture; + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), -1) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + CompletableFuture> recordsFuture; + if (error != null) { + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordsFuture = this.recordsFuture; + this.recordsFuture = null; + apiCallInProgress = false; } - }); - yield future; - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + recordsFuture.completeExceptionally(error); + } + }); + yield future; } case STREAMING -> { apiCallInProgress = true; @@ -676,6 +629,7 @@ var record = new InternalRecord(runSummary.keys(), fields); } } + @SuppressWarnings("DuplicatedCode") @Override public synchronized void onError(Throwable throwable) { throwable = Futures.completionExceptionCause(throwable); @@ -685,23 +639,16 @@ public synchronized void onError(Throwable throwable) { if (throwable == IGNORED_ERROR) { return; } - if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { + if (error == IGNORED_ERROR || (error instanceof Neo4jException && !(throwable instanceof Neo4jException))) { // higher order error has occurred - throwable.addSuppressed(error); error = throwable; - } else { - error.addSuppressed(throwable); } } } @Override public void onIgnored() { - var throwable = termSupplier.get(); - if (throwable == null) { - throwable = IGNORED_ERROR; - } - onError(throwable); + onError(IGNORED_ERROR); } @SuppressWarnings("DuplicatedCode") @@ -811,67 +758,47 @@ public void onPullSummary(PullSummary summary) { CompletableFuture secondRecordFuture = null; synchronized (this) { if (this.peekFuture != null) { - var term = termSupplier.get(); - if (term == null) { - // peek is pending, keep streaming - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - if (error != null) { - CompletableFuture peekFuture; - synchronized (this) { - state = State.FAILED; - errorExposed = true; - peekFuture = this.peekFuture; - this.peekFuture = null; - apiCallInProgress = false; - } - peekFuture.completeExceptionally(error); + // peek is pending, keep streaming + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), fetchSize) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + if (error != null) { + CompletableFuture peekFuture; + synchronized (this) { + state = State.FAILED; + errorExposed = true; + peekFuture = this.peekFuture; + this.peekFuture = null; + apiCallInProgress = false; } - }); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - var peekFuture = this.peekFuture; - this.peekFuture = null; - peekFuture.completeExceptionally(error); - } + peekFuture.completeExceptionally(error); + } + }); } else if (this.recordFuture != null) { - var term = termSupplier.get(); - if (term == null) { - // next is pending, keep streaming - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), fetchSize) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - if (error != null) { - CompletableFuture recordFuture; - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordFuture = this.recordFuture; - this.recordFuture = null; - apiCallInProgress = false; - } - recordFuture.completeExceptionally(error); + // next is pending, keep streaming + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), fetchSize) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + if (error != null) { + CompletableFuture recordFuture; + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordFuture = this.recordFuture; + this.recordFuture = null; + apiCallInProgress = false; } - }); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - var recordFuture = this.recordFuture; - this.recordFuture = null; - recordFuture.completeExceptionally(error); - } + recordFuture.completeExceptionally(error); + } + }); } else { secondRecordFuture = this.secondRecordFuture; this.secondRecordFuture = null; @@ -882,66 +809,46 @@ public void onPullSummary(PullSummary summary) { state = State.READY; } else { if (this.recordsFuture != null) { - var term = termSupplier.get(); - if (term == null) { - // list is pending, stream all - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - if (error != null) { - CompletableFuture> recordsFuture; - synchronized (this) { - state = State.FAILED; - errorExposed = true; - recordsFuture = this.recordsFuture; - this.recordsFuture = null; - apiCallInProgress = false; - } - recordsFuture.completeExceptionally(error); + // list is pending, stream all + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), -1) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + if (error != null) { + CompletableFuture> recordsFuture; + synchronized (this) { + state = State.FAILED; + errorExposed = true; + recordsFuture = this.recordsFuture; + this.recordsFuture = null; + apiCallInProgress = false; } - }); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - var recordsFuture = this.recordsFuture; - this.recordsFuture = null; - recordsFuture.completeExceptionally(error); - } + recordsFuture.completeExceptionally(error); + } + }); } else if (this.summaryFuture != null) { - var term = termSupplier.get(); - if (term == null) { - // consume is pending, discard all - state = State.DISCARDING; - boltConnection - .discard(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - CompletableFuture summaryFuture; - if (error != null) { - synchronized (this) { - state = State.FAILED; - errorExposed = true; - summaryFuture = this.summaryFuture; - this.summaryFuture = null; - apiCallInProgress = false; - } - summaryFuture.completeExceptionally(error); + // consume is pending, discard all + state = State.DISCARDING; + boltConnection + .discard(runSummary.queryId(), -1) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + CompletableFuture summaryFuture; + if (error != null) { + synchronized (this) { + state = State.FAILED; + errorExposed = true; + summaryFuture = this.summaryFuture; + this.summaryFuture = null; + apiCallInProgress = false; } - }); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - var summaryFuture = this.recordsFuture; - this.summaryFuture = null; - summaryFuture.completeExceptionally(error); - } + summaryFuture.completeExceptionally(error); + } + }); } else { state = State.READY; } @@ -1057,9 +964,6 @@ public void onPullSummary(PullSummary summary) { } } - if (throwableConsumer != null && error != null) { - throwableConsumer.accept(error); - } if (closeOnSummary) { var errorSnapshot = error; var recordFutureSnapshot = recordFuture; @@ -1140,12 +1044,9 @@ public void onComplete() { if (beginFuture != null) { if (!beginFuture.isDone()) { // not exposed yet, fail - if (throwableConsumer != null) { - throwableConsumer.accept(throwable); - } if (closeOnSummary) { boltConnection.close().whenComplete((ignored, closeThrowable) -> { - if (closeThrowable != null) { + if (closeThrowable != null && throwable != closeThrowable) { throwable.addSuppressed(closeThrowable); } beginFuture.completeExceptionally(throwable); @@ -1173,12 +1074,9 @@ public void onComplete() { if (!resultCursorFuture.isDone()) { // not exposed yet, fail - if (throwableConsumer != null) { - throwableConsumer.accept(throwable); - } if (closeOnSummary) { finisher = () -> boltConnection.close().whenComplete((ignored, closeThrowable) -> { - if (closeThrowable != null) { + if (closeThrowable != null && throwable != closeThrowable) { throwable.addSuppressed(closeThrowable); } resultCursorFuture.completeExceptionally(throwable); @@ -1225,16 +1123,13 @@ public void onComplete() { } } } - if (throwableConsumer != null) { - throwableConsumer.accept(throwable); - } var recordFutureSnapshot = recordFuture; var secondRecordFutureSnapshot = secondRecordFuture; var recordsFutureSnapshot = recordsFuture; var summaryFutureSnapshot = summaryFuture; if (closeOnSummary) { finisher = () -> boltConnection.close().whenComplete((ignored, closeThrowable) -> { - if (closeThrowable != null) { + if (closeThrowable != null && throwable != closeThrowable) { throwable.addSuppressed(closeThrowable); } if (peekFuture != null) { @@ -1300,65 +1195,41 @@ public CompletionStage pullAllFailureAsync() { } return switch (state) { case READY -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - summaryFuture = new CompletableFuture<>(); - state = State.STREAMING; - updateRecordState(RecordState.NO_RECORD); - boltConnection - .pull(runSummary.queryId(), -1) - .thenCompose(conn -> conn.flush(this)) - .whenComplete((ignored, throwable) -> { - var error = Futures.completionExceptionCause(throwable); - CompletableFuture summaryFuture; - if (error != null) { - synchronized (this) { - state = State.FAILED; - errorExposed = true; - summaryFuture = this.summaryFuture; - this.summaryFuture = null; - apiCallInProgress = false; - } - summaryFuture.completeExceptionally(error); + apiCallInProgress = true; + summaryFuture = new CompletableFuture<>(); + state = State.STREAMING; + updateRecordState(RecordState.NO_RECORD); + boltConnection + .pull(runSummary.queryId(), -1) + .thenCompose(conn -> conn.flush(this)) + .whenComplete((ignored, throwable) -> { + var error = Futures.completionExceptionCause(throwable); + CompletableFuture summaryFuture; + if (error != null) { + synchronized (this) { + state = State.FAILED; + errorExposed = true; + summaryFuture = this.summaryFuture; + this.summaryFuture = null; + apiCallInProgress = false; } - }); - yield summaryFuture.handle((ignored, throwable) -> throwable); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + summaryFuture.completeExceptionally(error); + } + }); + yield summaryFuture.handle((ignored, throwable) -> throwable); } case STREAMING -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - // no pending request should be in place - recordsFuture = new CompletableFuture<>(); - keepRecords = true; - yield recordsFuture.handle((ignored, throwable) -> throwable); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + apiCallInProgress = true; + // no pending request should be in place + recordsFuture = new CompletableFuture<>(); + keepRecords = true; + yield recordsFuture.handle((ignored, throwable) -> throwable); } case DISCARDING -> { - var term = termSupplier.get(); - if (term == null) { - apiCallInProgress = true; - // no pending request should be in place - summaryFuture = new CompletableFuture<>(); - yield summaryFuture.handle((ignored, throwable) -> throwable); - } else { - this.error = term; - this.state = State.FAILED; - this.errorExposed = true; - yield CompletableFuture.failedStage(error); - } + apiCallInProgress = true; + // no pending request should be in place + summaryFuture = new CompletableFuture<>(); + yield summaryFuture.handle((ignored, throwable) -> throwable); } case FAILED -> stageExposingError(null).handle((ignored, throwable) -> throwable); case SUCCEEDED -> CompletableFuture.completedStage(null); diff --git a/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java index 2474c3f40..9061562bb 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cursor/RxResultCursorImpl.java @@ -17,6 +17,7 @@ package org.neo4j.driver.internal.cursor; import static org.neo4j.driver.internal.types.InternalTypeSystem.TYPE_SYSTEM; +import static org.neo4j.driver.internal.util.ErrorUtil.newResultConsumedError; import java.util.Collections; import java.util.List; @@ -27,7 +28,6 @@ import java.util.concurrent.CompletionStage; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Supplier; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; @@ -48,6 +48,7 @@ import org.neo4j.driver.internal.bolt.api.summary.RunSummary; import org.neo4j.driver.internal.util.Futures; import org.neo4j.driver.internal.util.MetadataExtractor; +import org.neo4j.driver.summary.GqlStatusObject; import org.neo4j.driver.summary.ResultSummary; public class RxResultCursorImpl extends AbstractRecordStateResponseHandler implements RxResultCursor, ResponseHandler { @@ -77,36 +78,26 @@ public long resultAvailableAfter() { return -1; } }; - private final Logger log; private final BoltConnection boltConnection; private final Query query; private final RunSummary runSummary; private final Throwable runError; private final Consumer bookmarkConsumer; - private final Consumer throwableConsumer; - private final Supplier interruptSupplier; private final boolean closeOnSummary; - private final CompletableFuture summaryFuture = new CompletableFuture<>(); private final CompletableFuture consumedFuture = new CompletableFuture<>(); private final boolean legacyNotifications; - - private State state; + private State state = State.READY; private boolean discardPending; private boolean runErrorExposed; private boolean summaryExposed; - // subscription private BiConsumer recordConsumer; private long outstandingDemand; - private boolean recordConsumerFinished; - private boolean recordConsumerHadRequests; - private PullSummary pullSummary; private DiscardSummary discardSummary; private Throwable error; - private boolean interrupted; private enum State { READY, @@ -122,36 +113,22 @@ public RxResultCursorImpl( RunSummary runSummary, Throwable runError, Consumer bookmarkConsumer, - Consumer throwableConsumer, boolean closeOnSummary, - Supplier interruptSupplier, Logging logging) { this.boltConnection = boltConnection; this.legacyNotifications = new BoltProtocolVersion(5, 5).compareTo(boltConnection.protocolVersion()) > 0; this.query = query; - if (runError == null) { - this.runSummary = runSummary; - this.state = State.READY; - } else { - this.runSummary = EMPTY_RUN_SUMMARY; - this.state = State.FAILED; - this.summaryFuture.completeExceptionally(runError); - } + this.runSummary = runError == null ? runSummary : EMPTY_RUN_SUMMARY; this.runError = runError; this.bookmarkConsumer = bookmarkConsumer; this.closeOnSummary = closeOnSummary; - this.throwableConsumer = throwableConsumer; - this.interruptSupplier = interruptSupplier; this.log = logging.getLog(getClass()); - - var runErrorName = runError == null ? "null" : runError.getClass().getCanonicalName(); - log.trace("[%d] New instance (runError=%s)", hashCode(), runErrorName); + log.trace("[%d] New instance (runError=%s)", hashCode(), throwableName(runError)); } @Override public synchronized Throwable getRunError() { - var name = runError == null ? "null" : runError.getClass().getCanonicalName(); - log.trace("[%d] Run error explicitly retrieved (value=%s)", hashCode(), name); + log.trace("[%d] Run error explicitly retrieved (value=%s)", hashCode(), throwableName(runError)); runErrorExposed = true; return runError; } @@ -167,40 +144,24 @@ public CompletionStage consumed() { } @Override - public synchronized boolean isDone() { - return switch (state) { - case DISCARDING, STREAMING, READY -> false; - case FAILED -> runError == null || runErrorExposed; - case SUCCEEDED -> true; - }; + public boolean isDone() { + return summaryFuture.isDone(); } @Override public void installRecordConsumer(BiConsumer recordConsumer) { Objects.requireNonNull(recordConsumer); + if (summaryExposed) { + throw newResultConsumedError(); + } var runnable = NOOP_RUNNABLE; synchronized (this) { if (this.recordConsumer == null) { - this.recordConsumer = (record, throwable) -> { - var recordHash = record == null ? "null" : record.hashCode(); - var throwableName = - throwable == null ? "null" : throwable.getClass().getCanonicalName(); - try { - recordConsumer.accept(record, throwable); - log.trace( - "[%d] Record consumer notified with (record=%s, throwable=%s)", - hashCode(), recordHash, throwableName); - } catch (Throwable unexpectedThrowable) { - log.error( - String.format( - "[%d] Record consumer threw an error when notified with (record=%s, throwable=%s), this will be ignored", - hashCode(), recordHash, throwableName), - unexpectedThrowable); - } - }; + this.recordConsumer = safeRecordConsumer(recordConsumer); log.trace("[%d] Record consumer installed", hashCode()); - if (runError != null && !runErrorExposed) { - runnable = setupRecordConsumerErrorNotificationRunnable(runError, true); + if (runError != null) { + handleError(runError); + runnable = this::onComplete; } } else { log.warn("[%d] Only one record consumer is supported, this request will be ignored", hashCode()); @@ -214,42 +175,30 @@ public void request(long n) { if (n > 0) { var runnable = NOOP_RUNNABLE; synchronized (this) { - if (recordConsumerFinished) { - log.trace( - "[%d] Tried requesting more records after record consumer is finished, this request will be ignored", - hashCode()); - return; - } - recordConsumerHadRequests = true; updateRecordState(RecordState.NO_RECORD); log.trace("[%d] %d records requested in %s state", hashCode(), n, state); switch (state) { - case READY -> runnable = executeIfNotInterrupted(() -> { + case READY -> { var request = appendDemand(n); state = State.STREAMING; - return () -> boltConnection + runnable = () -> boltConnection .pull(runSummary.queryId(), request) .thenCompose(conn -> conn.flush(this)) .whenComplete((ignored, throwable) -> { throwable = Futures.completionExceptionCause(throwable); if (throwable != null) { - handleError(throwable, false); + handleError(throwable); onComplete(); } }); - }); + } case STREAMING -> appendDemand(n); - case FAILED -> runnable = runError != null - ? setupRecordConsumerErrorNotificationRunnable(runError, true) - : error != null - ? setupRecordConsumerErrorNotificationRunnable(error, false) - : NOOP_RUNNABLE; - case DISCARDING, SUCCEEDED -> {} + case FAILED, DISCARDING, SUCCEEDED -> {} } } runnable.run(); } else { - log.warn("[%d] %d records requested, negative amounts will be ignored", hashCode(), n); + log.warn("[%d] %d records requested, negative amounts are ignored", hashCode(), n); } } @@ -259,7 +208,7 @@ public void cancel() { synchronized (this) { log.trace("[%d] Cancellation requested in %s state", hashCode(), state); switch (state) { - case READY -> runnable = executeIfNotInterrupted(this::setupDiscardRunnable); + case READY -> runnable = setupDiscardRunnable(); case STREAMING -> discardPending = true; case DISCARDING, FAILED, SUCCEEDED -> {} } @@ -278,7 +227,14 @@ public CompletionStage summaryAsync() { summaryExposed = true; switch (state) { case SUCCEEDED, FAILED, DISCARDING -> {} - case READY -> runnable = executeIfNotInterrupted(this::setupDiscardRunnable); + case READY -> { + if (runError != null && recordConsumer == null) { + handleError(runError); + runnable = this::onComplete; + } else { + runnable = setupDiscardRunnable(); + } + } case STREAMING -> discardPending = true; } } @@ -337,15 +293,9 @@ public void onComplete() { Runnable runnable; synchronized (this) { log.trace("[%d] onComplete", hashCode()); - var throwable = interruptSupplier.get(); - if (throwable != null) { - handleError(throwable, true); - } else { - throwable = error; - } - if (throwable != null) { - runnable = setupCompletionRunnableWithError(throwable); + if (error != null) { + runnable = setupCompletionRunnableWithError(error); } else if (pullSummary != null) { runnable = setupCompletionRunnableWithPullSummary(); } else if (discardSummary != null) { @@ -359,20 +309,14 @@ public void onComplete() { @Override public synchronized void onError(Throwable throwable) { - if (log.isTraceEnabled()) { - log.error(String.format("[%d] onError", hashCode()), throwable); - } - handleError(throwable, false); + log.trace("[%d] onError", hashCode()); + handleError(throwable); } @Override public synchronized void onIgnored() { log.trace("[%d] onIgnored", hashCode()); - var throwable = interruptSupplier.get(); - if (throwable == null) { - throwable = IGNORED_ERROR; - } - onError(throwable); + handleError(IGNORED_ERROR); } @Override @@ -411,7 +355,12 @@ public synchronized CompletionStage discardAllFailureAsync() { @Override public synchronized CompletionStage pullAllFailureAsync() { log.trace("[%d] Pull all failure requested", hashCode()); - if (recordConsumer != null && !isDone()) { + var unfinishedState = + switch (state) { + case READY, STREAMING, DISCARDING -> true; + case FAILED, SUCCEEDED -> false; + }; + if (recordConsumer != null && unfinishedState) { return CompletableFuture.completedFuture( new TransactionNestingException( "You cannot run another query or begin a new transaction in the same session before you've fully consumed the previous run result.")); @@ -453,49 +402,14 @@ private synchronized Runnable setupDiscardRunnable() { .whenComplete((ignored, throwable) -> { throwable = Futures.completionExceptionCause(throwable); if (throwable != null) { - handleError(throwable, false); + handleError(throwable); onComplete(); } }); } - private synchronized Runnable executeIfNotInterrupted(Supplier runnableSupplier) { - var runnable = NOOP_RUNNABLE; - var throwable = interruptSupplier.get(); - if (throwable == null) { - runnable = runnableSupplier.get(); - } else { - log.trace("[%d] Interrupt signal detected upon handling request", hashCode()); - handleError(throwable, true); - runnable = this::onComplete; - } - return runnable; - } - - private synchronized Runnable setupRecordConsumerErrorNotificationRunnable(Throwable throwable, boolean runError) { - Runnable runnable; - if (recordConsumer != null) { - if (!recordConsumerFinished) { - if (runError) { - this.runErrorExposed = true; - } - recordConsumerFinished = true; - var recordConsumerRef = recordConsumer; - recordConsumer = NOOP_CONSUMER; - runnable = () -> recordConsumerRef.accept(null, throwable); - } else { - runnable = () -> - log.trace("[%d] Record consumer will not be notified as it has been finished", hashCode()); - } - } else { - runnable = () -> - log.trace("[%d] Record consumer will not be notified as it has not been installed", hashCode()); - } - return runnable; - } - private synchronized Runnable setupCompletionRunnableWithPullSummary() { - log.trace("[%d] Setting up completion with pull summary", hashCode()); + log.trace("[%d] Setting up completion with pull summary (hasMore=%b)", hashCode(), pullSummary.hasMore()); var runnable = NOOP_RUNNABLE; if (pullSummary.hasMore()) { pullSummary = null; @@ -508,7 +422,7 @@ private synchronized Runnable setupCompletionRunnableWithPullSummary() { .whenComplete((ignored, flushThrowable) -> { var error = Futures.completionExceptionCause(flushThrowable); if (error != null) { - handleError(error, false); + handleError(error); onComplete(); } }); @@ -522,7 +436,7 @@ private synchronized Runnable setupCompletionRunnableWithPullSummary() { .whenComplete((ignored, flushThrowable) -> { var error = Futures.completionExceptionCause(flushThrowable); if (error != null) { - handleError(error, false); + handleError(error); onComplete(); } }); @@ -541,67 +455,31 @@ private synchronized Runnable setupCompletionRunnableWithSummaryMetadata(Map { - bookmarkOpt.ifPresent(bookmarkConsumer); - var closeStage = closeOnSummary ? boltConnection.close() : CompletableFuture.completedStage(null); - closeStage.whenComplete((ignored, closeThrowable) -> { - var error = Futures.completionExceptionCause(closeThrowable); - if (error != null) { - if (log.isTraceEnabled()) { - log.error( - String.format( - "[%d] Failed to close connection while publishing summary", hashCode()), - error); - } - } - if (recordConsumerFinished) { - log.trace("[%d] Won't publish summary because recordConsumer is finished", hashCode()); - } else { - if (recordConsumerRef != null) { - if (recordConsumerHadRequests) { - recordConsumerRef.accept(null, null); - } else { - log.trace( - "[%d] Record consumer will not be notified as it had no requests", hashCode()); - } - } else { - log.trace( - "[%d] Record consumer will not be notified as it has not been installed", - hashCode()); - } - } - completeSummaryFuture(resultSummaryRef, null); - }); - }; + bookmarkOpt.ifPresent(bookmarkConsumer); + var completeRunnable = setupSummaryAndRecordCompletionRunnable(resultSummary, null); + runnable = () -> closeBoltConnection(completeRunnable); } else { runnable = this::onComplete; } return runnable; } - private ResultSummary resultSummary(Map metadata) { + private ResultSummary resultSummary(Map metadata, GqlStatusObject gqlStatusObject) { return METADATA_EXTRACTOR.extractSummary( query, boltConnection, runSummary.resultAvailableAfter(), metadata, legacyNotifications, - generateGqlStatusObject(runSummary.keys())); + gqlStatusObject); } @SuppressWarnings("DuplicatedCode") @@ -618,81 +496,74 @@ private static Optional databaseBookmark(Map me } private synchronized Runnable setupCompletionRunnableWithError(Throwable throwable) { - log.trace( - "[%d] Setting up completion with error %s", - hashCode(), throwable.getClass().getCanonicalName()); - var recordConsumerPresent = this.recordConsumer != null; - var recordConsumerFinished = this.recordConsumerFinished; - var recordConsumerErrorNotificationRunnable = setupRecordConsumerErrorNotificationRunnable(throwable, false); - var interrupted = this.interrupted; - return () -> { - ResultSummary summary = null; - try { - summary = resultSummary(Collections.emptyMap()); - } catch (Throwable summaryThrowable) { - if (!interrupted) { - throwable.addSuppressed(summaryThrowable); - } - } - - if (summary != null && recordConsumerPresent && !recordConsumerFinished) { - var summaryRef = summary; - closeBoltConnection(throwable, interrupted, () -> { - // notify recordConsumer when possible - recordConsumerErrorNotificationRunnable.run(); - completeSummaryFuture(summaryRef, null); - }); - } else { - closeBoltConnection(throwable, interrupted, () -> completeSummaryFuture(null, throwable)); - } - }; + log.trace("[%d] Setting up completion with error %s", hashCode(), throwableName(throwable)); + ResultSummary summary = null; + try { + summary = resultSummary(Collections.emptyMap(), null); + } catch (Throwable summaryThrowable) { + log.error(String.format("[%d] Failed to parse summary", hashCode()), summaryThrowable); + } + var completeRunnable = setupSummaryAndRecordCompletionRunnable(summary, throwable); + return () -> closeBoltConnection(completeRunnable); } - private void closeBoltConnection(Throwable throwable, boolean interrupted, Runnable runnable) { + private void closeBoltConnection(Runnable runnable) { var closeStage = closeOnSummary ? boltConnection.close() : CompletableFuture.completedStage(null); closeStage.whenComplete((ignored, closeThrowable) -> { - var error = Futures.completionExceptionCause(closeThrowable); - if (!interrupted) { - if (error != null) { - throwable.addSuppressed(error); - } - throwableConsumer.accept(throwable); + if (log.isTraceEnabled() && closeThrowable != null) { + log.error( + String.format("[%d] Failed to close connection", hashCode()), + Futures.completionExceptionCause(closeThrowable)); } runnable.run(); }); } - private synchronized void handleError(Throwable throwable, boolean interrupted) { + @SuppressWarnings("DuplicatedCode") + private synchronized void handleError(Throwable throwable) { + if (log.isTraceEnabled()) { + log.error(String.format("[%d] handleError", hashCode()), throwable); + } state = State.FAILED; throwable = Futures.completionExceptionCause(throwable); if (error == null) { error = throwable; - this.interrupted = interrupted; } else { - if (!this.interrupted) { - if (throwable == IGNORED_ERROR) { - return; - } - if (interrupted) { - error = throwable; - this.interrupted = true; + if (throwable == IGNORED_ERROR) { + return; + } + if (error == IGNORED_ERROR || (error instanceof Neo4jException && !(throwable instanceof Neo4jException))) { + error = throwable; + } + } + } + + private synchronized Runnable setupSummaryAndRecordCompletionRunnable(ResultSummary summary, Throwable throwable) { + var recordConsumerRef = recordConsumer; + this.recordConsumer = NOOP_CONSUMER; + + return () -> { + if (throwable != null) { + if (recordConsumerRef != null && recordConsumerRef != NOOP_CONSUMER) { + completeSummaryFuture(summary, null); + recordConsumerRef.accept(null, throwable); } else { - if (error instanceof Neo4jException && !(throwable instanceof Neo4jException)) { - // higher order error has occurred - if (error != IGNORED_ERROR) { - throwable.addSuppressed(error); - } - error = throwable; - } else { - error.addSuppressed(throwable); - } + completeSummaryFuture(null, throwable); + } + } else { + completeSummaryFuture(summary, null); + if (recordConsumerRef != null) { + recordConsumerRef.accept(null, null); } } - } + }; } private void completeSummaryFuture(ResultSummary summary, Throwable throwable) { throwable = Futures.completionExceptionCause(throwable); + log.trace( + "[%d] Completing summary future (summary=%s, throwable=%s)", + hashCode(), hash(summary), throwableName(throwable)); if (throwable != null) { consumedFuture.completeExceptionally(throwable); summaryFuture.completeExceptionally(throwable); @@ -701,4 +572,29 @@ private void completeSummaryFuture(ResultSummary summary, Throwable throwable) { summaryFuture.complete(summary); } } + + private BiConsumer safeRecordConsumer(BiConsumer recordConsumer) { + return (record, throwable) -> { + try { + recordConsumer.accept(record, throwable); + log.trace( + "[%d] Record consumer notified with (record=%s, throwable=%s)", + hashCode(), hash(record), throwableName(throwable)); + } catch (Throwable unexpectedThrowable) { + log.error( + String.format( + "[%d] Record consumer threw an error when notified with (record=%s, throwable=%s), this will be ignored", + hashCode(), hash(record), throwableName(throwable)), + unexpectedThrowable); + } + }; + } + + private String hash(Object object) { + return object == null ? "null" : String.valueOf(object.hashCode()); + } + + private String throwableName(Throwable throwable) { + return throwable == null ? "null" : throwable.getClass().getCanonicalName(); + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/InternalResultTest.java b/driver/src/test/java/org/neo4j/driver/internal/InternalResultTest.java index c2fef1fc8..d379b41e2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/InternalResultTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/InternalResultTest.java @@ -345,8 +345,7 @@ private Result createResult(int numberOfRecords) { when(connection.protocolVersion()).thenReturn(new BoltProtocolVersion(4, 3)); when(connection.serverAgent()).thenReturn("Neo4j/4.2.5"); - var resultCursor = new ResultCursorImpl( - connection, query, -1, ignored -> {}, ignored -> {}, false, () -> null, null, null); + var resultCursor = new ResultCursorImpl(connection, query, -1, ignored -> {}, false, null, null); var runSummary = mock(RunSummary.class); given(runSummary.keys()).willReturn(asList("k1", "k2")); resultCursor.onRunSummary(runSummary); diff --git a/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java index 737786e6f..93ed98b4d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/InternalTransactionTest.java @@ -68,6 +68,7 @@ class InternalTransactionTest { void setUp() { connection = connectionMock(new BoltProtocolVersion(4, 0)); var connectionProvider = mock(BoltConnectionProvider.class); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedFuture(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java index cd506618b..04fa888e7 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncSessionTest.java @@ -97,6 +97,7 @@ class InternalAsyncSessionTest { @BeforeEach void setUp() { connection = connectionMock(new BoltProtocolVersion(4, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.close()).willReturn(completedFuture(null)); connectionProvider = mock(BoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any())) diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java index e92f0605d..90edd0867 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java @@ -73,6 +73,7 @@ class InternalAsyncTransactionTest { @BeforeEach void setUp() { connection = connectionMock(BoltProtocolV4.INSTANCE.version()); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); var connectionProvider = mock(BoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willAnswer((Answer>) invocation -> { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java index 6b461c343..13adc087d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java @@ -93,6 +93,7 @@ void logsMessageWithStacktraceDuringFinalizationIfLeaked(TestInfo testInfo) thro var log = mock(Logger.class); when(logging.getLog(any(Class.class))).thenReturn(log); var connection = TestUtil.connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); setupConnectionAnswers(connection, List.of(handler -> { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java index d9e9e1e8b..57b5af52e 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java @@ -88,6 +88,7 @@ class NetworkSessionTest { @BeforeEach void setUp() { connection = connectionMock(new BoltProtocolVersion(5, 4)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.close()).willReturn(completedFuture(null)); connectionProvider = mock(BoltConnectionProvider.class); given(connectionProvider.connect(any(), any(), any(), any(), any(), any(), any(), any(), any())) @@ -308,6 +309,7 @@ void updatesBookmarkWhenTxIsClosed() { @Test void releasesConnectionWhenTxIsClosed() { + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); given(connection.run(any(), any())).willAnswer((Answer>) @@ -529,9 +531,11 @@ void shouldRunAfterBeginTxFailureOnBookmark() { void shouldBeginTxAfterBeginTxFailureOnBookmark() { var error = new RuntimeException("Hi"); var connection1 = connectionMock(new BoltProtocolVersion(5, 0)); + given(connection1.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection1.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.failedStage(error)); var connection2 = connectionMock(new BoltProtocolVersion(5, 0)); + given(connection2.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection2.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection2)); setupConnectionAnswers(connection2, List.of(handler -> { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index 482f9aaf4..0f996e7f5 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -82,6 +82,7 @@ class UnmanagedTransactionTest { void shouldFlushOnRunAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(5, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); @@ -111,6 +112,7 @@ void shouldFlushOnRunAsync() { void shouldFlushOnRunRx() { // Given var connection = connectionMock(new BoltProtocolVersion(5, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); @@ -139,6 +141,7 @@ void shouldFlushOnRunRx() { void shouldRollbackOnImplicitFailure() { // Given var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); @@ -169,6 +172,7 @@ void shouldRollbackOnImplicitFailure() { @Test void shouldBeginTransaction() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -185,6 +189,7 @@ void shouldBeginTransaction() { @Test void shouldBeOpenAfterConstruction() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -200,6 +205,7 @@ void shouldBeOpenAfterConstruction() { @Test void shouldBeClosedWhenMarkedAsTerminated() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -216,6 +222,7 @@ void shouldBeClosedWhenMarkedAsTerminated() { @Test void shouldBeClosedWhenMarkedTerminatedAndClosed() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(completedFuture(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -235,6 +242,7 @@ void shouldBeClosedWhenMarkedTerminatedAndClosed() { void shouldReleaseConnectionWhenBeginFails() { var error = new RuntimeException("Wrong bookmark!"); var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -266,6 +274,7 @@ void shouldReleaseConnectionWhenBeginFails() { @Test void shouldNotReleaseConnectionWhenBeginSucceeds() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -423,6 +432,7 @@ void shouldReleaseConnectionWhenTerminatedAndRolledBack() { @Test void shouldReleaseConnectionWhenClose() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { handler.onRollbackSummary(mock(RollbackSummary.class)); @@ -450,6 +460,7 @@ void shouldReleaseConnectionWhenClose() { void shouldReleaseConnectionOnConnectionAuthorizationExpiredExceptionFailure() { var exception = new AuthorizationExpiredException("code", "message"); var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -481,6 +492,7 @@ void shouldReleaseConnectionOnConnectionAuthorizationExpiredExceptionFailure() { @Test void shouldReleaseConnectionOnConnectionReadTimeoutExceptionFailure() { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -523,6 +535,7 @@ private static Stream similarTransactionCompletingActionArgs() { void shouldReturnExistingStageOnSimilarCompletingAction( boolean protocolCommit, String initialAction, String similarAction) { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); given(connection.flush(any())).willReturn(CompletableFuture.completedStage(null)); @@ -573,6 +586,7 @@ void shouldReturnFailingStageOnConflictingCompletingAction( String conflictingAction, String expectedErrorMsg) { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); if (protocolActionCompleted) { @@ -636,6 +650,7 @@ private static Stream closingNotActionTransactionArgs() { void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommittingAborted( boolean protocolCommit, int expectedProtocolInvocations, String originalAction, Boolean commitOnClose) { var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); given(connection.rollback()).willReturn(CompletableFuture.completedStage(connection)); setupConnectionAnswers(connection, List.of(handler -> { @@ -677,6 +692,7 @@ void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommitt void shouldTerminateOnTerminateAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); given(connection.clear()).willReturn(CompletableFuture.completedStage(connection)); @@ -706,6 +722,7 @@ void shouldTerminateOnTerminateAsync() { void shouldServeTheSameStageOnTerminateAsync() { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); given(connection.clear()).willReturn(CompletableFuture.completedStage(connection)); @@ -735,6 +752,7 @@ void shouldServeTheSameStageOnTerminateAsync() { void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, InterruptedException { // Given var connection = connectionMock(new BoltProtocolVersion(4, 0)); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); given(connection.run(any(), any())).willReturn(CompletableFuture.completedStage(connection)); @@ -771,6 +789,7 @@ void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, I void shouldThrowOnRunningNewQueriesWhenTransactionIsClosing(TransactionClosingTestParams testParams) { // Given var connection = connectionMock(); + given(connection.onLoop()).willReturn(CompletableFuture.completedStage(connection)); given(connection.beginTransaction(any(), any(), any(), any(), any(), any(), any(), any(), any())) .willReturn(CompletableFuture.completedStage(connection)); given(connection.commit()).willReturn(CompletableFuture.completedStage(connection)); diff --git a/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java index 4b96dce37..b765d82a0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cursor/ResultCursorImplTest.java @@ -77,16 +77,7 @@ class ResultCursorImplTest { void beforeEach() { openMocks(this); given(connection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 5)); - cursor = new ResultCursorImpl( - connection, - query, - fetchSize, - throwableConsumer, - bookmarkConsumer, - closeOnSummary, - termSupplier, - null, - null); + cursor = new ResultCursorImpl(connection, query, fetchSize, bookmarkConsumer, closeOnSummary, null, null); cursor.onRunSummary(runSummary); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java index b74bd87c9..edc710c3e 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cursor/RxResultCursorImplTest.java @@ -17,17 +17,21 @@ package org.neo4j.driver.internal.cursor; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; import static org.mockito.MockitoAnnotations.openMocks; import java.util.List; +import java.util.concurrent.CompletionException; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Supplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.neo4j.driver.Logging; import org.neo4j.driver.Query; @@ -35,6 +39,7 @@ import org.neo4j.driver.internal.DatabaseBookmark; import org.neo4j.driver.internal.bolt.api.BoltConnection; import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion; +import org.neo4j.driver.internal.bolt.api.BoltServerAddress; import org.neo4j.driver.internal.bolt.api.summary.RunSummary; class RxResultCursorImplTest { @@ -50,12 +55,6 @@ class RxResultCursorImplTest { @Mock Consumer bookmarkConsumer; - @Mock - Consumer throwableConsumer; - - @Mock - Supplier termSupplier; - @BeforeEach @SuppressWarnings("resource") void beforeEach() { @@ -63,54 +62,44 @@ void beforeEach() { given(connection.protocolVersion()).willReturn(new BoltProtocolVersion(5, 5)); } - @Test - void shouldNotifyRecordConsumerOfRunError() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldNotifyRecordConsumerOfRunError(boolean getRunError) { // given var runError = mock(Throwable.class); - var cursor = new RxResultCursorImpl( - connection, - query, - null, - runError, - bookmarkConsumer, - throwableConsumer, - false, - termSupplier, - Logging.none()); + given(connection.serverAddress()).willReturn(new BoltServerAddress("localhost")); + var cursor = new RxResultCursorImpl(connection, query, null, runError, bookmarkConsumer, false, Logging.none()); + if (getRunError) { + assertEquals(runError, cursor.getRunError()); + } @SuppressWarnings("unchecked") BiConsumer recordConsumer = mock(BiConsumer.class); - cursor.installRecordConsumer(recordConsumer); // when - cursor.request(1); + cursor.installRecordConsumer(recordConsumer); // then then(recordConsumer).should().accept(null, runError); + assertNotNull(cursor.summaryAsync().toCompletableFuture().join()); } - @Test - void shouldNotNotifyRecordConsumerOfRunErrorWhenRunErrorIsRequested() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldReturnSummaryWithRunError(boolean getRunError) { // given var runError = mock(Throwable.class); - var cursor = new RxResultCursorImpl( - connection, - query, - runSummary, - runError, - bookmarkConsumer, - throwableConsumer, - false, - termSupplier, - Logging.none()); - @SuppressWarnings("unchecked") - BiConsumer recordConsumer = mock(BiConsumer.class); - assertEquals(runError, cursor.getRunError()); + given(connection.serverAddress()).willReturn(new BoltServerAddress("localhost")); + var cursor = new RxResultCursorImpl(connection, query, null, runError, bookmarkConsumer, false, Logging.none()); + if (getRunError) { + assertEquals(runError, cursor.getRunError()); + } // when - cursor.installRecordConsumer(recordConsumer); + var summary = cursor.summaryAsync().toCompletableFuture(); // then - then(recordConsumer).shouldHaveNoInteractions(); + assertEquals( + runError, assertThrows(CompletionException.class, summary::join).getCause()); } @Test @@ -118,16 +107,8 @@ void shouldReturnKeys() { // given var keys = List.of("a", "b"); given(runSummary.keys()).willReturn(keys); - var cursor = new RxResultCursorImpl( - connection, - query, - runSummary, - null, - bookmarkConsumer, - throwableConsumer, - false, - termSupplier, - Logging.none()); + var cursor = + new RxResultCursorImpl(connection, query, runSummary, null, bookmarkConsumer, false, Logging.none()); // when & then assertEquals(keys, cursor.keys()); diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java index a5cdd4368..ccfd541b9 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxResultTest.java @@ -257,15 +257,7 @@ private InternalRxResult newRxResult(BoltConnection boltConnection) { private InternalRxResult newRxResult(BoltConnection boltConnection, RunSummary runSummary) { RxResultCursor cursor = new RxResultCursorImpl( - boltConnection, - mock(), - runSummary, - null, - databaseBookmark -> {}, - throwable -> {}, - false, - () -> null, - Logging.none()); + boltConnection, mock(), runSummary, null, databaseBookmark -> {}, false, Logging.none()); return newRxResult(cursor); }