diff --git a/driver/src/main/java/org/neo4j/driver/SessionConfig.java b/driver/src/main/java/org/neo4j/driver/SessionConfig.java index f9f32174ff..9682ffe8f2 100644 --- a/driver/src/main/java/org/neo4j/driver/SessionConfig.java +++ b/driver/src/main/java/org/neo4j/driver/SessionConfig.java @@ -41,6 +41,7 @@ public class SessionConfig private final AccessMode defaultAccessMode; private final String database; private final Optional fetchSize; + private final String impersonatedUser; private SessionConfig( Builder builder ) { @@ -48,6 +49,7 @@ private SessionConfig( Builder builder ) this.defaultAccessMode = builder.defaultAccessMode; this.database = builder.database; this.fetchSize = builder.fetchSize; + this.impersonatedUser = builder.impersonatedUser; } /** @@ -116,6 +118,7 @@ public Optional database() /** * This value if set, overrides the default fetch size set on {@link Config#fetchSize()}. + * * @return an optional value of fetch size. */ public Optional fetchSize() @@ -123,6 +126,16 @@ public Optional fetchSize() return fetchSize; } + /** + * The impersonated user the session is going to use for query execution. + * + * @return an optional value of the impersonated user. + */ + public Optional impersonatedUser() + { + return Optional.ofNullable( impersonatedUser ); + } + @Override public boolean equals( Object o ) { @@ -136,20 +149,20 @@ public boolean equals( Object o ) } SessionConfig that = (SessionConfig) o; return Objects.equals( bookmarks, that.bookmarks ) && defaultAccessMode == that.defaultAccessMode && Objects.equals( database, that.database ) - && Objects.equals( fetchSize, that.fetchSize ); + && Objects.equals( fetchSize, that.fetchSize ) && Objects.equals( impersonatedUser, that.impersonatedUser ); } @Override public int hashCode() { - return Objects.hash( bookmarks, defaultAccessMode, database ); + return Objects.hash( bookmarks, defaultAccessMode, database, impersonatedUser ); } @Override public String toString() { return "SessionParameters{" + "bookmarks=" + bookmarks + ", defaultAccessMode=" + defaultAccessMode + ", database='" + database + '\'' + - ", fetchSize=" + fetchSize + '}'; + ", fetchSize=" + fetchSize + "impersonatedUser=" + impersonatedUser + '}'; } /** @@ -161,6 +174,7 @@ public static class Builder private Iterable bookmarks = null; private AccessMode defaultAccessMode = AccessMode.WRITE; private String database = null; + private String impersonatedUser = null; private Builder() { @@ -268,6 +282,31 @@ public Builder withFetchSize( long size ) return this; } + /** + * Set the impersonated user that the newly created session is going to use for query execution. + *

+ * The principal provided to the driver on creation must have the necessary permissions to impersonate and run queries as the impersonated user. + *

+ * When {@link #withDatabase(String)} is not used, the driver will discover the default database name of the impersonated user on first session usage. + * From that moment, the discovered database name will be used as the default database name for the whole lifetime of the new session. + *

+ * Compatible with 4.4+ only. You MUST have all servers running 4.4 version or above and communicating over Bolt 4.4 or above. + * + * @param impersonatedUser the user to impersonate. Provided value should not be {@code null}. + * @return this builder + */ + public Builder withImpersonatedUser( String impersonatedUser ) + { + requireNonNull( impersonatedUser, "Impersonated user should not be null." ); + if ( impersonatedUser.isEmpty() ) + { + // Empty string is an illegal user. Fail fast on client. + throw new IllegalArgumentException( String.format( "Illegal impersonated user '%s'.", impersonatedUser ) ); + } + this.impersonatedUser = impersonatedUser; + return this; + } + public SessionConfig build() { return new SessionConfig( this ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/DirectConnectionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/DirectConnectionProvider.java index 143abe0b0a..7c6f614fc9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/DirectConnectionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/DirectConnectionProvider.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.async.ConnectionContext; @@ -25,12 +26,13 @@ import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.spi.ConnectionProvider; +import org.neo4j.driver.internal.util.Futures; +import static org.neo4j.driver.internal.async.ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER; import static org.neo4j.driver.internal.messaging.request.MultiDatabaseUtil.supportsMultiDatabase; /** - * Simple {@link ConnectionProvider connection provider} that obtains connections form the given pool only for - * the given address. + * Simple {@link ConnectionProvider connection provider} that obtains connections form the given pool only for the given address. */ public class DirectConnectionProvider implements ConnectionProvider { @@ -46,7 +48,12 @@ public class DirectConnectionProvider implements ConnectionProvider @Override public CompletionStage acquireConnection( ConnectionContext context ) { - return acquireConnection().thenApply( connection -> new DirectConnection( connection, context.databaseName(), context.mode() ) ); + CompletableFuture databaseNameFuture = context.databaseNameFuture(); + databaseNameFuture.complete( DatabaseNameUtil.defaultDatabase() ); + return acquireConnection().thenApply( + connection -> new DirectConnection( connection, + Futures.joinNowOrElseThrow( databaseNameFuture, PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER ), + context.mode(), context.impersonatedUser() ) ); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/ImpersonationUtil.java b/driver/src/main/java/org/neo4j/driver/internal/ImpersonationUtil.java new file mode 100644 index 0000000000..0a944c7001 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/ImpersonationUtil.java @@ -0,0 +1,46 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.neo4j.driver.exceptions.ClientException; +import org.neo4j.driver.internal.messaging.v44.BoltProtocolV44; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.ServerVersion; + +public class ImpersonationUtil +{ + public static final String IMPERSONATION_UNSUPPORTED_ERROR_MESSAGE = + "Detected connection that does not support impersonation, please make sure to have all servers running 4.4 version or above and communicating" + + " over Bolt version 4.4 or above when using impersonation feature"; + + public static Connection ensureImpersonationSupport( Connection connection, String impersonatedUser ) + { + if ( impersonatedUser != null && !supportsImpersonation( connection ) ) + { + throw new ClientException( IMPERSONATION_UNSUPPORTED_ERROR_MESSAGE ); + } + return connection; + } + + private static boolean supportsImpersonation( Connection connection ) + { + return connection.serverVersion().greaterThanOrEqual( ServerVersion.v4_4_0 ) && + connection.protocol().version().compareTo( BoltProtocolV44.VERSION ) >= 0; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/SessionFactoryImpl.java b/driver/src/main/java/org/neo4j/driver/internal/SessionFactoryImpl.java index 2c3a2c2836..8a9e85634e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/SessionFactoryImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/SessionFactoryImpl.java @@ -52,7 +52,8 @@ public NetworkSession newInstance( SessionConfig sessionConfig ) { BookmarkHolder bookmarkHolder = new DefaultBookmarkHolder( InternalBookmark.from( sessionConfig.bookmarks() ) ); return createSession( connectionProvider, retryLogic, parseDatabaseName( sessionConfig ), - sessionConfig.defaultAccessMode(), bookmarkHolder, parseFetchSize( sessionConfig ), logging ); + sessionConfig.defaultAccessMode(), bookmarkHolder, parseFetchSize( sessionConfig ), + sessionConfig.impersonatedUser().orElse( null ), logging ); } private long parseFetchSize( SessionConfig sessionConfig ) @@ -98,10 +99,10 @@ public ConnectionProvider getConnectionProvider() } private NetworkSession createSession( ConnectionProvider connectionProvider, RetryLogic retryLogic, DatabaseName databaseName, AccessMode mode, - BookmarkHolder bookmarkHolder, long fetchSize, Logging logging ) + BookmarkHolder bookmarkHolder, long fetchSize, String impersonatedUser, Logging logging ) { return leakedSessionsLoggingEnabled - ? new LeakLoggingNetworkSession( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, fetchSize, logging ) - : new NetworkSession( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, fetchSize, logging ); + ? new LeakLoggingNetworkSession( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, impersonatedUser, fetchSize, logging ) + : new NetworkSession( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, impersonatedUser, fetchSize, logging ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/ConnectionContext.java b/driver/src/main/java/org/neo4j/driver/internal/async/ConnectionContext.java index 34e50c4243..31efe28b90 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/ConnectionContext.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/ConnectionContext.java @@ -18,6 +18,9 @@ */ package org.neo4j.driver.internal.async; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + import org.neo4j.driver.AccessMode; import org.neo4j.driver.Bookmark; import org.neo4j.driver.internal.DatabaseName; @@ -28,9 +31,13 @@ */ public interface ConnectionContext { - DatabaseName databaseName(); + Supplier PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER = () -> new IllegalStateException( "Pending database name encountered." ); + + CompletableFuture databaseNameFuture(); AccessMode mode(); Bookmark rediscoveryBookmark(); + + String impersonatedUser(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/ImmutableConnectionContext.java b/driver/src/main/java/org/neo4j/driver/internal/async/ImmutableConnectionContext.java index cbbe2d5c8a..ccaa86c389 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/ImmutableConnectionContext.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/ImmutableConnectionContext.java @@ -18,6 +18,8 @@ */ package org.neo4j.driver.internal.async; +import java.util.concurrent.CompletableFuture; + import org.neo4j.driver.AccessMode; import org.neo4j.driver.Bookmark; import org.neo4j.driver.internal.DatabaseName; @@ -35,21 +37,21 @@ public class ImmutableConnectionContext implements ConnectionContext private static final ConnectionContext SINGLE_DB_CONTEXT = new ImmutableConnectionContext( defaultDatabase(), empty(), AccessMode.READ ); private static final ConnectionContext MULTI_DB_CONTEXT = new ImmutableConnectionContext( systemDatabase(), empty(), AccessMode.READ ); - private final DatabaseName databaseName; + private final CompletableFuture databaseNameFuture; private final AccessMode mode; private final Bookmark rediscoveryBookmark; public ImmutableConnectionContext( DatabaseName databaseName, Bookmark bookmark, AccessMode mode ) { - this.databaseName = databaseName; + this.databaseNameFuture = CompletableFuture.completedFuture( databaseName ); this.rediscoveryBookmark = bookmark; this.mode = mode; } @Override - public DatabaseName databaseName() + public CompletableFuture databaseNameFuture() { - return databaseName; + return databaseNameFuture; } @Override @@ -64,10 +66,15 @@ public Bookmark rediscoveryBookmark() return rediscoveryBookmark; } + @Override + public String impersonatedUser() + { + return null; + } + /** - * A simple context is used to test connectivity with a remote server/cluster. - * As long as there is a read only service, the connection shall be established successfully. - * Depending on whether multidb is supported or not, this method returns different context for routing table discovery. + * A simple context is used to test connectivity with a remote server/cluster. As long as there is a read only service, the connection shall be established + * successfully. Depending on whether multidb is supported or not, this method returns different context for routing table discovery. */ public static ConnectionContext simple( boolean supportsMultiDb ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSession.java index 73c999ab2b..952c17e243 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSession.java @@ -33,9 +33,9 @@ public class LeakLoggingNetworkSession extends NetworkSession private final String stackTrace; public LeakLoggingNetworkSession( ConnectionProvider connectionProvider, RetryLogic retryLogic, DatabaseName databaseName, AccessMode mode, - BookmarkHolder bookmarkHolder, long fetchSize, Logging logging ) + BookmarkHolder bookmarkHolder, String impersonatedUser, long fetchSize, Logging logging ) { - super( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, fetchSize, logging ); + super( connectionProvider, retryLogic, databaseName, mode, bookmarkHolder, impersonatedUser, fetchSize, logging ); this.stackTrace = captureStackTrace(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java index 1ff1aabe66..a967a7d990 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.async; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,6 +35,7 @@ import org.neo4j.driver.internal.BookmarkHolder; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.FailableCursor; +import org.neo4j.driver.internal.ImpersonationUtil; import org.neo4j.driver.internal.cursor.AsyncResultCursor; import org.neo4j.driver.internal.cursor.ResultCursorFactory; import org.neo4j.driver.internal.cursor.RxResultCursor; @@ -63,14 +65,17 @@ public class NetworkSession private final AtomicBoolean open = new AtomicBoolean( true ); public NetworkSession( ConnectionProvider connectionProvider, RetryLogic retryLogic, DatabaseName databaseName, AccessMode mode, - BookmarkHolder bookmarkHolder, long fetchSize, Logging logging ) + BookmarkHolder bookmarkHolder, String impersonatedUser, long fetchSize, Logging logging ) { this.connectionProvider = connectionProvider; this.mode = mode; this.retryLogic = retryLogic; this.log = new PrefixedLogger( "[" + hashCode() + "]", logging.getLog( getClass() ) ); this.bookmarkHolder = bookmarkHolder; - this.connectionContext = new NetworkSessionConnectionContext( databaseName, bookmarkHolder.getBookmark() ); + CompletableFuture databaseNameFuture = databaseName.databaseName() + .map( ignored -> CompletableFuture.completedFuture( databaseName ) ) + .orElse( new CompletableFuture<>() ); + this.connectionContext = new NetworkSessionConnectionContext( databaseNameFuture, bookmarkHolder.getBookmark(), impersonatedUser ); this.fetchSize = fetchSize; } @@ -104,6 +109,7 @@ public CompletionStage beginTransactionAsync( AccessMode m // create a chain that acquires connection and starts a transaction CompletionStage newTransactionStage = ensureNoOpenTxBeforeStartingTx() .thenCompose( ignore -> acquireConnection( mode ) ) + .thenApply( connection -> ImpersonationUtil.ensureImpersonationSupport( connection, connection.impersonatedUser() ) ) .thenCompose( connection -> { UnmanagedTransaction tx = new UnmanagedTransaction( connection, bookmarkHolder, fetchSize ); @@ -227,6 +233,7 @@ private CompletionStage buildResultCursorFactory( Query que return ensureNoOpenTxBeforeRunningQuery() .thenCompose( ignore -> acquireConnection( mode ) ) + .thenApply( connection -> ImpersonationUtil.ensureImpersonationSupport( connection, connection.impersonatedUser() ) ) .thenCompose( connection -> { @@ -350,18 +357,20 @@ private void ensureSessionIsOpen() */ private static class NetworkSessionConnectionContext implements ConnectionContext { - private final DatabaseName databaseName; + private final CompletableFuture databaseNameFuture; private AccessMode mode; // This bookmark is only used for rediscovery. // It has to be the initial bookmark given at the creation of the session. // As only that bookmark could carry extra system bookmarks private final Bookmark rediscoveryBookmark; + private final String impersonatedUser; - private NetworkSessionConnectionContext( DatabaseName databaseName, Bookmark bookmark ) + private NetworkSessionConnectionContext( CompletableFuture databaseNameFuture, Bookmark bookmark, String impersonatedUser ) { - this.databaseName = databaseName; + this.databaseNameFuture = databaseNameFuture; this.rediscoveryBookmark = bookmark; + this.impersonatedUser = impersonatedUser; } private ConnectionContext contextWithMode( AccessMode mode ) @@ -371,9 +380,9 @@ private ConnectionContext contextWithMode( AccessMode mode ) } @Override - public DatabaseName databaseName() + public CompletableFuture databaseNameFuture() { - return databaseName; + return databaseNameFuture; } @Override @@ -387,6 +396,12 @@ public Bookmark rediscoveryBookmark() { return rediscoveryBookmark; } + + @Override + public String impersonatedUser() + { + return impersonatedUser; + } } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java index 1d74213de7..ff0f663e55 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java @@ -38,12 +38,14 @@ public class DirectConnection implements Connection private final Connection delegate; private final AccessMode mode; private final DatabaseName databaseName; + private final String impersonatedUser; - public DirectConnection( Connection delegate, DatabaseName databaseName, AccessMode mode ) + public DirectConnection( Connection delegate, DatabaseName databaseName, AccessMode mode, String impersonatedUser ) { this.delegate = delegate; this.mode = mode; this.databaseName = databaseName; + this.impersonatedUser = impersonatedUser; } public Connection connection() @@ -147,6 +149,12 @@ public DatabaseName databaseName() return this.databaseName; } + @Override + public String impersonatedUser() + { + return impersonatedUser; + } + @Override public void flush() { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java index 3799224611..9735a6382e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java @@ -40,12 +40,14 @@ public class RoutingConnection implements Connection private final AccessMode accessMode; private final RoutingErrorHandler errorHandler; private final DatabaseName databaseName; + private final String impersonatedUser; - public RoutingConnection( Connection delegate, DatabaseName databaseName, AccessMode accessMode, RoutingErrorHandler errorHandler ) + public RoutingConnection( Connection delegate, DatabaseName databaseName, AccessMode accessMode, String impersonatedUser, RoutingErrorHandler errorHandler ) { this.delegate = delegate; this.databaseName = databaseName; this.accessMode = accessMode; + this.impersonatedUser = impersonatedUser; this.errorHandler = errorHandler; } @@ -151,6 +153,11 @@ public DatabaseName databaseName() return this.databaseName; } + @Override + public String impersonatedUser() + { + return impersonatedUser; + } private RoutingResponseHandler newRoutingResponseHandler( ResponseHandler handler ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java index b1ef41bbb3..3438516ecf 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java @@ -52,8 +52,6 @@ public class ConnectionPoolImpl implements ConnectionPool { - public static final String CONNECTION_POOL_CLOSED_ERROR_MESSAGE = "Pool closed"; - private final ChannelConnector connector; private final Bootstrap bootstrap; private final NettyChannelTracker nettyChannelTracker; diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java index 7bddb70c21..4083fff4c8 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java @@ -36,13 +36,15 @@ public final class ClusterComposition private final Set writers; private final Set routers; private final long expirationTimestamp; + private final String databaseName; - private ClusterComposition( long expirationTimestamp ) + private ClusterComposition( long expirationTimestamp, String databaseName ) { this.readers = new LinkedHashSet<>(); this.writers = new LinkedHashSet<>(); this.routers = new LinkedHashSet<>(); this.expirationTimestamp = expirationTimestamp; + this.databaseName = databaseName; } /** @@ -52,9 +54,10 @@ public ClusterComposition( long expirationTimestamp, Set readers, Set writers, - Set routers ) + Set routers, + String databaseName ) { - this( expirationTimestamp ); + this( expirationTimestamp, databaseName ); this.readers.addAll( readers ); this.writers.addAll( writers ); this.routers.addAll( routers ); @@ -90,6 +93,11 @@ public long expirationTimestamp() return this.expirationTimestamp; } + public String databaseName() + { + return databaseName; + } + @Override public boolean equals( Object o ) { @@ -103,6 +111,7 @@ public boolean equals( Object o ) } ClusterComposition that = (ClusterComposition) o; return expirationTimestamp == that.expirationTimestamp && + Objects.equals( databaseName, that.databaseName ) && Objects.equals( readers, that.readers ) && Objects.equals( writers, that.writers ) && Objects.equals( routers, that.routers ); @@ -111,7 +120,7 @@ public boolean equals( Object o ) @Override public int hashCode() { - return Objects.hash( readers, writers, routers, expirationTimestamp ); + return Objects.hash( readers, writers, routers, expirationTimestamp, databaseName ); } @Override @@ -122,6 +131,7 @@ public String toString() ", writers=" + writers + ", routers=" + routers + ", expirationTimestamp=" + expirationTimestamp + + ", databaseName=" + databaseName + '}'; } @@ -132,16 +142,12 @@ public static ClusterComposition parse( Record record, long now ) return null; } - final ClusterComposition result = new ClusterComposition( expirationTimestamp( now, record ) ); - record.get( "servers" ).asList( new Function() + final ClusterComposition result = new ClusterComposition( expirationTimestamp( now, record ), record.get( "db" ).asString( null ) ); + record.get( "servers" ).asList( (Function) value -> { - @Override - public Void apply( Value value ) - { - result.servers( value.get( "role" ).asString() ) - .addAll( value.get( "addresses" ).asList( OF_BoltServerAddress ) ); - return null; - } + result.servers( value.get( "role" ).asString() ) + .addAll( value.get( "addresses" ).asList( OF_BoltServerAddress ) ); + return null; } ); return result; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java index 997858d8e7..e71f800cd5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java @@ -26,5 +26,5 @@ public interface ClusterCompositionProvider { - CompletionStage getClusterComposition( Connection connection, DatabaseName databaseName, Bookmark bookmark ); + CompletionStage getClusterComposition( Connection connection, DatabaseName databaseName, Bookmark bookmark, String impersonatedUser ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java index 3604a5ffc9..432c17e811 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java @@ -152,10 +152,16 @@ public boolean preferInitialRouter() return preferInitialRouter; } + @Override + public long expirationTimestamp() + { + return expirationTimestamp; + } + @Override public synchronized String toString() { return format( "Ttl %s, currentTime %s, routers %s, writers %s, readers %s, database '%s'", - expirationTimestamp, clock.millis(), routers, writers, readers, databaseName.description() ); + expirationTimestamp, clock.millis(), routers, writers, readers, databaseName.description() ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunner.java index d32b9d987f..43a054a227 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunner.java @@ -67,6 +67,6 @@ Query procedureQuery(ServerVersion serverVersion, DatabaseName databaseName ) @Override DirectConnection connection( Connection connection ) { - return new DirectConnection( connection, systemDatabase(), AccessMode.READ ); + return new DirectConnection( connection, systemDatabase(), AccessMode.READ, null ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java index 553efb7511..abd5ddc784 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java @@ -36,12 +36,14 @@ public interface Rediscovery *

* Implementation must be thread safe to be called with distinct routing tables concurrently. The routing table instance may be modified. * - * @param routingTable the routing table for cluster composition lookup - * @param connectionPool the connection pool for connection acquisition - * @param bookmark the bookmark that is presented to the server + * @param routingTable the routing table for cluster composition lookup + * @param connectionPool the connection pool for connection acquisition + * @param bookmark the bookmark that is presented to the server + * @param impersonatedUser the impersonated user for cluster composition lookup, should be {@code null} for non-impersonated requests * @return cluster composition lookup result */ - CompletionStage lookupClusterComposition( RoutingTable routingTable, ConnectionPool connectionPool, Bookmark bookmark ); + CompletionStage lookupClusterComposition( RoutingTable routingTable, ConnectionPool connectionPool, Bookmark bookmark, + String impersonatedUser ); List resolve() throws UnknownHostException; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RediscoveryImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RediscoveryImpl.java index 8013c560c2..06985c0379 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RediscoveryImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RediscoveryImpl.java @@ -40,6 +40,7 @@ import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DomainNameResolver; +import org.neo4j.driver.internal.ImpersonationUtil; import org.neo4j.driver.internal.ResolvedBoltServerAddress; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Futures; @@ -91,20 +92,20 @@ public RediscoveryImpl( BoltServerAddress initialRouter, RoutingSettings setting */ @Override public CompletionStage lookupClusterComposition( RoutingTable routingTable, ConnectionPool connectionPool, - Bookmark bookmark ) + Bookmark bookmark, String impersonatedUser ) { CompletableFuture result = new CompletableFuture<>(); // if we failed discovery, we will chain all errors into this one. ServiceUnavailableException baseError = new ServiceUnavailableException( String.format( NO_ROUTERS_AVAILABLE, routingTable.database().description() ) ); - lookupClusterComposition( routingTable, connectionPool, 0, 0, result, bookmark, baseError ); + lookupClusterComposition( routingTable, connectionPool, 0, 0, result, bookmark, impersonatedUser, baseError ); return result; } - private void lookupClusterComposition( RoutingTable routingTable, ConnectionPool pool, - int failures, long previousDelay, CompletableFuture result, Bookmark bookmark, + private void lookupClusterComposition( RoutingTable routingTable, ConnectionPool pool, int failures, long previousDelay, + CompletableFuture result, Bookmark bookmark, String impersonatedUser, Throwable baseError ) { - lookup( routingTable, pool, bookmark, baseError ) + lookup( routingTable, pool, bookmark, impersonatedUser, baseError ) .whenComplete( ( compositionLookupResult, completionError ) -> { @@ -130,7 +131,8 @@ else if ( compositionLookupResult != null ) long nextDelay = Math.max( settings.retryTimeoutDelay(), previousDelay * 2 ); log.info( "Unable to fetch new routing table, will try again in " + nextDelay + "ms" ); eventExecutorGroup.next().schedule( - () -> lookupClusterComposition( routingTable, pool, newFailures, nextDelay, result, bookmark, baseError ), + () -> lookupClusterComposition( routingTable, pool, newFailures, nextDelay, result, bookmark, impersonatedUser, + baseError ), nextDelay, TimeUnit.MILLISECONDS ); } @@ -139,27 +141,28 @@ else if ( compositionLookupResult != null ) } private CompletionStage lookup( RoutingTable routingTable, ConnectionPool connectionPool, Bookmark bookmark, - Throwable baseError ) + String impersonatedUser, Throwable baseError ) { CompletionStage compositionStage; if ( routingTable.preferInitialRouter() ) { - compositionStage = lookupOnInitialRouterThenOnKnownRouters( routingTable, connectionPool, bookmark, baseError ); + compositionStage = lookupOnInitialRouterThenOnKnownRouters( routingTable, connectionPool, bookmark, impersonatedUser, baseError ); } else { - compositionStage = lookupOnKnownRoutersThenOnInitialRouter( routingTable, connectionPool, bookmark, baseError ); + compositionStage = lookupOnKnownRoutersThenOnInitialRouter( routingTable, connectionPool, bookmark, impersonatedUser, baseError ); } return compositionStage; } private CompletionStage lookupOnKnownRoutersThenOnInitialRouter( RoutingTable routingTable, ConnectionPool connectionPool, - Bookmark bookmark, Throwable baseError ) + Bookmark bookmark, String impersonatedUser, + Throwable baseError ) { Set seenServers = new HashSet<>(); - return lookupOnKnownRouters( routingTable, connectionPool, seenServers, bookmark, baseError ) + return lookupOnKnownRouters( routingTable, connectionPool, seenServers, bookmark, impersonatedUser, baseError ) .thenCompose( compositionLookupResult -> { @@ -168,19 +171,16 @@ private CompletionStage lookupOnKnownRoutersThen return completedFuture( compositionLookupResult ); } - return lookupOnInitialRouter( - routingTable, connectionPool, - seenServers, bookmark, - baseError ); + return lookupOnInitialRouter( routingTable, connectionPool, seenServers, bookmark, impersonatedUser, baseError ); } ); } - private CompletionStage lookupOnInitialRouterThenOnKnownRouters( RoutingTable routingTable, - ConnectionPool connectionPool, Bookmark bookmark, + private CompletionStage lookupOnInitialRouterThenOnKnownRouters( RoutingTable routingTable, ConnectionPool connectionPool, + Bookmark bookmark, String impersonatedUser, Throwable baseError ) { Set seenServers = emptySet(); - return lookupOnInitialRouter( routingTable, connectionPool, seenServers, bookmark, baseError ) + return lookupOnInitialRouter( routingTable, connectionPool, seenServers, bookmark, impersonatedUser, baseError ) .thenCompose( compositionLookupResult -> { @@ -189,16 +189,13 @@ private CompletionStage lookupOnInitialRouterThe return completedFuture( compositionLookupResult ); } - return lookupOnKnownRouters( - routingTable, connectionPool, - new HashSet<>(), bookmark, - baseError ); + return lookupOnKnownRouters( routingTable, connectionPool, new HashSet<>(), bookmark, impersonatedUser, baseError ); } ); } private CompletionStage lookupOnKnownRouters( RoutingTable routingTable, ConnectionPool connectionPool, Set seenServers, Bookmark bookmark, - Throwable baseError ) + String impersonatedUser, Throwable baseError ) { BoltServerAddress[] addresses = routingTable.routers().toArray(); @@ -215,7 +212,7 @@ private CompletionStage lookupOnKnownRouters( Ro } else { - return lookupOnRouter( address, true, routingTable, connectionPool, seenServers, bookmark, baseError ); + return lookupOnRouter( address, true, routingTable, connectionPool, seenServers, bookmark, impersonatedUser, baseError ); } } ); } @@ -224,7 +221,7 @@ private CompletionStage lookupOnKnownRouters( Ro private CompletionStage lookupOnInitialRouter( RoutingTable routingTable, ConnectionPool connectionPool, Set seenServers, Bookmark bookmark, - Throwable baseError ) + String impersonatedUser, Throwable baseError ) { List resolvedRouters; try @@ -248,15 +245,15 @@ private CompletionStage lookupOnInitialRouter( R { return completedFuture( composition ); } - return lookupOnRouter( address, false, routingTable, connectionPool, null, bookmark, baseError ); + return lookupOnRouter( address, false, routingTable, connectionPool, null, bookmark, impersonatedUser, baseError ); } ); } return result.thenApply( composition -> composition != null ? new ClusterCompositionLookupResult( composition, resolvedRouterSet ) : null ); } - private CompletionStage lookupOnRouter( BoltServerAddress routerAddress, boolean resolveAddress, - RoutingTable routingTable, ConnectionPool connectionPool, - Set seenServers, Bookmark bookmark, Throwable baseError ) + private CompletionStage lookupOnRouter( BoltServerAddress routerAddress, boolean resolveAddress, RoutingTable routingTable, + ConnectionPool connectionPool, Set seenServers, Bookmark bookmark, + String impersonatedUser, Throwable baseError ) { CompletableFuture addressFuture = CompletableFuture.completedFuture( routerAddress ); @@ -264,7 +261,8 @@ private CompletionStage lookupOnRouter( BoltServerAddress ro .thenApply( address -> resolveAddress ? resolveByDomainNameOrThrowCompletionException( address, routingTable ) : address ) .thenApply( address -> addAndReturn( seenServers, address ) ) .thenCompose( connectionPool::acquire ) - .thenCompose( connection -> provider.getClusterComposition( connection, routingTable.database(), bookmark ) ) + .thenApply( connection -> ImpersonationUtil.ensureImpersonationSupport( connection, impersonatedUser ) ) + .thenCompose( connection -> provider.getClusterComposition( connection, routingTable.database(), bookmark, impersonatedUser ) ) .handle( ( response, error ) -> { Throwable cause = Futures.completionExceptionCause( error ); @@ -282,7 +280,8 @@ private CompletionStage lookupOnRouter( BoltServerAddress ro private ClusterComposition handleRoutingProcedureError( Throwable error, RoutingTable routingTable, BoltServerAddress routerAddress, Throwable baseError ) { - if ( error instanceof SecurityException || error instanceof FatalDiscoveryException ) + if ( error instanceof SecurityException || error instanceof FatalDiscoveryException || + (error instanceof IllegalStateException && ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE.equals( error.getMessage() )) ) { // auth error or routing error happened, terminate the discovery procedure immediately throw new CompletionException( error ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunner.java index 2c4a756284..095944b069 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunner.java @@ -66,12 +66,12 @@ protected RouteMessageRoutingProcedureRunner( RoutingContext routingContext, Sup } @Override - public CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark ) + public CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark, String impersonatedUser ) { CompletableFuture> completableFuture = createCompletableFuture.get(); - DirectConnection directConnection = toDirectConnection( connection, databaseName ); - directConnection.writeAndFlush( new RouteMessage( routingContext, bookmark, databaseName.databaseName().orElse( null ) ), + DirectConnection directConnection = toDirectConnection( connection, databaseName, impersonatedUser ); + directConnection.writeAndFlush( new RouteMessage( routingContext, bookmark, databaseName.databaseName().orElse( null ), impersonatedUser ), new RouteMessageResponseHandler( completableFuture ) ); return completableFuture .thenApply( routingTable -> new RoutingProcedureResponse( getQuery( databaseName ), singletonList( toRecord( routingTable ) ) ) ) @@ -84,9 +84,9 @@ private Record toRecord( Map routingTable ) return new InternalRecord( new ArrayList<>( routingTable.keySet() ), routingTable.values().toArray( new Value[0] ) ); } - private DirectConnection toDirectConnection( Connection connection, DatabaseName databaseName ) + private DirectConnection toDirectConnection( Connection connection, DatabaseName databaseName, String impersonatedUser ) { - return new DirectConnection( connection, databaseName, AccessMode.READ ); + return new DirectConnection( connection, databaseName, AccessMode.READ, impersonatedUser ); } private Query getQuery( DatabaseName databaseName ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java index f254fa2386..84ae9577cb 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java @@ -61,7 +61,8 @@ public RoutingProcedureClusterCompositionProvider( Clock clock, RoutingContext r } @Override - public CompletionStage getClusterComposition( Connection connection, DatabaseName databaseName, Bookmark bookmark ) + public CompletionStage getClusterComposition( Connection connection, DatabaseName databaseName, Bookmark bookmark, + String impersonatedUser ) { RoutingProcedureRunner runner; @@ -78,7 +79,7 @@ else if ( supportsMultiDatabase( connection ) ) runner = singleDatabaseRoutingProcedureRunner; } - return runner.run( connection, databaseName, bookmark ) + return runner.run( connection, databaseName, bookmark, impersonatedUser ) .thenApply( this::processRoutingResponse ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java index 91925e8b7c..c91be6fb7a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java @@ -32,10 +32,11 @@ public interface RoutingProcedureRunner /** * Run the calls to the server * - * @param connection The connection which will be used to call the server - * @param databaseName The database name - * @param bookmark The bookmark used to query the routing information + * @param connection The connection which will be used to call the server + * @param databaseName The database name + * @param bookmark The bookmark used to query the routing information + * @param impersonatedUser The impersonated user, should be {@code null} for non-impersonated requests * @return The routing table */ - CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark ); + CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark, String impersonatedUser ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java index 7fa7000bda..dbe58f22ee 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java @@ -49,4 +49,6 @@ public interface RoutingTable void replaceRouterIfPresent( BoltServerAddress oldRouter, BoltServerAddress newRouter ); boolean preferInitialRouter(); + + long expirationTimestamp(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandler.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandler.java index 5e3274a41a..f9b01206f4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandler.java @@ -33,5 +33,7 @@ public interface RoutingTableHandler extends RoutingErrorHandler CompletionStage ensureRoutingTable( ConnectionContext context ); + CompletionStage updateRoutingTable( ClusterCompositionLookupResult compositionLookupResult ); + RoutingTable routingTable(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerImpl.java index 62af10e3a9..5385a5cdda 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerImpl.java @@ -72,6 +72,7 @@ public void onWriteFailure( BoltServerAddress address ) routingTable.forgetWriter( address ); } + @Override public synchronized CompletionStage ensureRoutingTable( ConnectionContext context ) { if ( refreshRoutingTableFuture != null ) @@ -87,19 +88,19 @@ else if ( routingTable.isStaleFor( context.mode() ) ) CompletableFuture resultFuture = new CompletableFuture<>(); refreshRoutingTableFuture = resultFuture; - rediscovery.lookupClusterComposition( routingTable, connectionPool, context.rediscoveryBookmark() ) - .whenComplete( ( composition, completionError ) -> - { - Throwable error = Futures.completionExceptionCause( completionError ); - if ( error != null ) - { - clusterCompositionLookupFailed( error ); - } - else - { - freshClusterCompositionFetched( composition ); - } - } ); + rediscovery.lookupClusterComposition( routingTable, connectionPool, context.rediscoveryBookmark(), null ) + .whenComplete( ( composition, completionError ) -> + { + Throwable error = Futures.completionExceptionCause( completionError ); + if ( error != null ) + { + clusterCompositionLookupFailed( error ); + } + else + { + freshClusterCompositionFetched( composition ); + } + } ); return resultFuture; } @@ -110,6 +111,27 @@ else if ( routingTable.isStaleFor( context.mode() ) ) } } + @Override + public synchronized CompletionStage updateRoutingTable( ClusterCompositionLookupResult compositionLookupResult ) + { + if ( refreshRoutingTableFuture != null ) + { + // refresh is already happening concurrently, just use its result + return refreshRoutingTableFuture; + } + else + { + if ( compositionLookupResult.getClusterComposition().expirationTimestamp() < routingTable.expirationTimestamp() ) + { + return completedFuture( routingTable ); + } + CompletableFuture resultFuture = new CompletableFuture<>(); + refreshRoutingTableFuture = resultFuture; + freshClusterCompositionFetched( compositionLookupResult ); + return resultFuture; + } + } + private synchronized void freshClusterCompositionFetched( ClusterCompositionLookupResult compositionLookupResult ) { try diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImpl.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImpl.java index 45cf335306..06544a0a05 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImpl.java @@ -18,44 +18,150 @@ */ package org.neo4j.driver.internal.cluster; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; +import org.neo4j.driver.internal.DatabaseNameUtil; import org.neo4j.driver.internal.async.ConnectionContext; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.Futures; + +import static org.neo4j.driver.internal.async.ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER; public class RoutingTableRegistryImpl implements RoutingTableRegistry { private final ConcurrentMap routingTableHandlers; + private final Map> principalToDatabaseNameStage; private final RoutingTableHandlerFactory factory; private final Logger log; + private final Clock clock; + private final ConnectionPool connectionPool; + private final Rediscovery rediscovery; public RoutingTableRegistryImpl( ConnectionPool connectionPool, Rediscovery rediscovery, Clock clock, Logging logging, long routingTablePurgeDelayMs ) { - this( new ConcurrentHashMap<>(), new RoutingTableHandlerFactory( connectionPool, rediscovery, clock, logging, routingTablePurgeDelayMs ), logging ); + this( new ConcurrentHashMap<>(), new RoutingTableHandlerFactory( connectionPool, rediscovery, clock, logging, routingTablePurgeDelayMs ), clock, + connectionPool, rediscovery, logging ); } - RoutingTableRegistryImpl( ConcurrentMap routingTableHandlers, RoutingTableHandlerFactory factory, Logging logging ) + RoutingTableRegistryImpl( ConcurrentMap routingTableHandlers, RoutingTableHandlerFactory factory, Clock clock, + ConnectionPool connectionPool, Rediscovery rediscovery, Logging logging ) { this.factory = factory; this.routingTableHandlers = routingTableHandlers; + this.principalToDatabaseNameStage = new HashMap<>(); + this.clock = clock; + this.connectionPool = connectionPool; + this.rediscovery = rediscovery; this.log = logging.getLog( getClass() ); } @Override public CompletionStage ensureRoutingTable( ConnectionContext context ) { - RoutingTableHandler handler = getOrCreate( context.databaseName() ); - return handler.ensureRoutingTable( context ).thenApply( ignored -> handler ); + return ensureDatabaseNameIsCompleted( context ) + .thenCompose( ctxAndHandler -> + { + ConnectionContext completedContext = ctxAndHandler.getContext(); + RoutingTableHandler handler = ctxAndHandler.getHandler() != null + ? ctxAndHandler.getHandler() + : getOrCreate( Futures.joinNowOrElseThrow( completedContext.databaseNameFuture(), + PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER ) ); + return handler.ensureRoutingTable( completedContext ) + .thenApply( ignored -> handler ); + } ); + } + + private CompletionStage ensureDatabaseNameIsCompleted( ConnectionContext context ) + { + CompletionStage contextAndHandlerStage; + CompletableFuture contextDatabaseNameFuture = context.databaseNameFuture(); + + if ( contextDatabaseNameFuture.isDone() ) + { + contextAndHandlerStage = CompletableFuture.completedFuture( new ConnectionContextAndHandler( context, null ) ); + } + else + { + synchronized ( this ) + { + if ( contextDatabaseNameFuture.isDone() ) + { + contextAndHandlerStage = CompletableFuture.completedFuture( new ConnectionContextAndHandler( context, null ) ); + } + else + { + String impersonatedUser = context.impersonatedUser(); + Principal principal = new Principal( impersonatedUser ); + CompletionStage databaseNameStage = principalToDatabaseNameStage.get( principal ); + AtomicReference handlerRef = new AtomicReference<>(); + + if ( databaseNameStage == null ) + { + CompletableFuture databaseNameFuture = new CompletableFuture<>(); + principalToDatabaseNameStage.put( principal, databaseNameFuture ); + databaseNameStage = databaseNameFuture; + + ClusterRoutingTable routingTable = new ClusterRoutingTable( DatabaseNameUtil.defaultDatabase(), clock ); + rediscovery.lookupClusterComposition( routingTable, connectionPool, context.rediscoveryBookmark(), impersonatedUser ) + .thenCompose( + compositionLookupResult -> + { + DatabaseName databaseName = + DatabaseNameUtil.database( compositionLookupResult.getClusterComposition().databaseName() ); + RoutingTableHandler handler = getOrCreate( databaseName ); + handlerRef.set( handler ); + return handler.updateRoutingTable( compositionLookupResult ) + .thenApply( ignored -> databaseName ); + } ) + .whenComplete( ( databaseName, throwable ) -> + { + synchronized ( this ) + { + principalToDatabaseNameStage.remove( principal ); + } + } ) + .whenComplete( ( databaseName, throwable ) -> + { + if ( throwable != null ) + { + databaseNameFuture.completeExceptionally( throwable ); + } + else + { + databaseNameFuture.complete( databaseName ); + } + } ); + } + + contextAndHandlerStage = databaseNameStage.thenApply( + databaseName -> + { + synchronized ( this ) + { + contextDatabaseNameFuture.complete( databaseName ); + } + return new ConnectionContextAndHandler( context, handlerRef.get() ); + } ); + } + } + } + + return contextAndHandlerStage; } @Override @@ -140,4 +246,57 @@ RoutingTableHandler newInstance( DatabaseName databaseName, RoutingTableRegistry return new RoutingTableHandlerImpl( routingTable, rediscovery, connectionPool, allTables, logging, routingTablePurgeDelayMs ); } } + + private static class Principal + { + private final String id; + + private Principal( String id ) + { + this.id = id; + } + + @Override + public boolean equals( Object o ) + { + if ( this == o ) + { + return true; + } + if ( o == null || getClass() != o.getClass() ) + { + return false; + } + Principal principal = (Principal) o; + return Objects.equals( id, principal.id ); + } + + @Override + public int hashCode() + { + return Objects.hash( id ); + } + } + + private static class ConnectionContextAndHandler + { + private final ConnectionContext context; + private final RoutingTableHandler handler; + + private ConnectionContextAndHandler( ConnectionContext context, RoutingTableHandler handler ) + { + this.context = context; + this.handler = handler; + } + + public ConnectionContext getContext() + { + return context; + } + + public RoutingTableHandler getHandler() + { + return handler; + } + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunner.java index 543075a87f..3be6f7fc3b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunner.java @@ -58,7 +58,7 @@ public SingleDatabaseRoutingProcedureRunner( RoutingContext context ) } @Override - public CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark ) + public CompletionStage run( Connection connection, DatabaseName databaseName, Bookmark bookmark, String impersonatedUser ) { DirectConnection delegate = connection( connection ); Query procedure = procedureQuery( connection.serverVersion(), databaseName ); @@ -70,7 +70,7 @@ public CompletionStage run( Connection connection, Dat DirectConnection connection( Connection connection ) { - return new DirectConnection( connection, defaultDatabase(), AccessMode.WRITE ); + return new DirectConnection( connection, defaultDatabase(), AccessMode.WRITE, null ); } Query procedureQuery(ServerVersion serverVersion, DatabaseName databaseName ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java index e8e05cca77..2b980e666d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java @@ -53,6 +53,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static org.neo4j.driver.internal.async.ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER; import static org.neo4j.driver.internal.async.ImmutableConnectionContext.simple; import static org.neo4j.driver.internal.messaging.request.MultiDatabaseUtil.supportsMultiDatabase; import static org.neo4j.driver.internal.util.Futures.completedWithNull; @@ -104,8 +105,11 @@ private LoadBalancer( ConnectionPool connectionPool, Rediscovery rediscovery, Ro public CompletionStage acquireConnection( ConnectionContext context ) { return routingTables.ensureRoutingTable( context ) - .thenCompose( handler -> acquire( context.mode(), handler.routingTable() ) - .thenApply( connection -> new RoutingConnection( connection, context.databaseName(), context.mode(), handler ) ) ); + .thenCompose( handler -> acquire( context.mode(), handler.routingTable() ) + .thenApply( connection -> new RoutingConnection( connection, + Futures.joinNowOrElseThrow( context.databaseNameFuture(), + PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER ), + context.mode(), context.impersonatedUser(), handler ) ) ); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/encode/RouteV44MessageEncoder.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/encode/RouteV44MessageEncoder.java new file mode 100644 index 0000000000..e15d5511c0 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/encode/RouteV44MessageEncoder.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.messaging.encode; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import org.neo4j.driver.Value; +import org.neo4j.driver.internal.messaging.Message; +import org.neo4j.driver.internal.messaging.MessageEncoder; +import org.neo4j.driver.internal.messaging.ValuePacker; +import org.neo4j.driver.internal.messaging.request.RouteMessage; + +import static org.neo4j.driver.Values.value; +import static org.neo4j.driver.internal.util.Preconditions.checkArgument; + +/** + * Encodes the ROUTE message to the stream + */ +public class RouteV44MessageEncoder implements MessageEncoder +{ + @Override + public void encode( Message message, ValuePacker packer ) throws IOException + { + checkArgument( message, RouteMessage.class ); + RouteMessage routeMessage = (RouteMessage) message; + packer.packStructHeader( 3, message.signature() ); + packer.pack( routeMessage.getRoutingContext() ); + packer.pack( routeMessage.getBookmark().isPresent() ? value( routeMessage.getBookmark().get().values() ) : value( Collections.emptyList() ) ); + + Map params; + if ( routeMessage.getImpersonatedUser() != null && routeMessage.getDatabaseName() == null ) + { + params = Collections.singletonMap( "imp_user", value( routeMessage.getImpersonatedUser() ) ); + } + else if ( routeMessage.getDatabaseName() != null ) + { + params = Collections.singletonMap( "db", value( routeMessage.getDatabaseName() ) ); + } + else + { + params = Collections.emptyMap(); + } + packer.pack( params ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/BeginMessage.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/BeginMessage.java index 5a18b3d152..0e82cd61da 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/BeginMessage.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/BeginMessage.java @@ -34,14 +34,15 @@ public class BeginMessage extends MessageWithMetadata { public static final byte SIGNATURE = 0x11; - public BeginMessage( Bookmark bookmark, TransactionConfig config, DatabaseName databaseName, AccessMode mode ) + public BeginMessage( Bookmark bookmark, TransactionConfig config, DatabaseName databaseName, AccessMode mode, String impersonatedUser ) { - this( bookmark, config.timeout(), config.metadata(), mode, databaseName ); + this( bookmark, config.timeout(), config.metadata(), mode, databaseName, impersonatedUser ); } - public BeginMessage( Bookmark bookmark, Duration txTimeout, Map txMetadata, AccessMode mode, DatabaseName databaseName ) + public BeginMessage( Bookmark bookmark, Duration txTimeout, Map txMetadata, AccessMode mode, DatabaseName databaseName, + String impersonatedUser ) { - super( buildMetadata( txTimeout, txMetadata, databaseName, mode, bookmark ) ); + super( buildMetadata( txTimeout, txMetadata, databaseName, mode, bookmark, impersonatedUser ) ); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RouteMessage.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RouteMessage.java index ef5404a705..440f3c81f3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RouteMessage.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RouteMessage.java @@ -40,19 +40,22 @@ public class RouteMessage implements Message private final Map routingContext; private final Bookmark bookmark; private final String databaseName; + private final String impersonatedUser; /** * Constructor * - * @param routingContext The routing context used to define the routing table. Multi-datacenter deployments is one of its use cases. - * @param bookmark The bookmark used when getting the routing table. - * @param databaseName The name of the database to get the routing table for. + * @param routingContext The routing context used to define the routing table. Multi-datacenter deployments is one of its use cases. + * @param bookmark The bookmark used when getting the routing table. + * @param databaseName The name of the database to get the routing table for. + * @param impersonatedUser The name of the impersonated user to get the routing table for, should be {@code null} for non-impersonated requests */ - public RouteMessage( Map routingContext, Bookmark bookmark, String databaseName ) + public RouteMessage( Map routingContext, Bookmark bookmark, String databaseName, String impersonatedUser ) { this.routingContext = unmodifiableMap( routingContext ); this.bookmark = bookmark; this.databaseName = databaseName; + this.impersonatedUser = impersonatedUser; } public Map getRoutingContext() @@ -70,6 +73,11 @@ public String getDatabaseName() return databaseName; } + public String getImpersonatedUser() + { + return impersonatedUser; + } + @Override public byte signature() { @@ -79,7 +87,7 @@ public byte signature() @Override public String toString() { - return String.format( "ROUTE %s %s %s", routingContext, bookmark, databaseName ); + return String.format( "ROUTE %s %s %s %s", routingContext, bookmark, databaseName, impersonatedUser ); } @Override @@ -95,12 +103,13 @@ public boolean equals( Object o ) } RouteMessage that = (RouteMessage) o; return routingContext.equals( that.routingContext ) && - Objects.equals( databaseName, that.databaseName ); + Objects.equals( databaseName, that.databaseName ) && + Objects.equals( impersonatedUser, that.impersonatedUser ); } @Override public int hashCode() { - return Objects.hash( routingContext, databaseName ); + return Objects.hash( routingContext, databaseName, impersonatedUser ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RunWithMetadataMessage.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RunWithMetadataMessage.java index d0722c2b37..91c6712634 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RunWithMetadataMessage.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/RunWithMetadataMessage.java @@ -40,20 +40,20 @@ public class RunWithMetadataMessage extends MessageWithMetadata private final String query; private final Map parameters; - public static RunWithMetadataMessage autoCommitTxRunMessage(Query query, TransactionConfig config, DatabaseName databaseName, AccessMode mode, - Bookmark bookmark ) + public static RunWithMetadataMessage autoCommitTxRunMessage( Query query, TransactionConfig config, DatabaseName databaseName, AccessMode mode, + Bookmark bookmark, String impersonatedUser ) { - return autoCommitTxRunMessage(query, config.timeout(), config.metadata(), databaseName, mode, bookmark ); + return autoCommitTxRunMessage( query, config.timeout(), config.metadata(), databaseName, mode, bookmark, impersonatedUser ); } - public static RunWithMetadataMessage autoCommitTxRunMessage(Query query, Duration txTimeout, Map txMetadata, DatabaseName databaseName, - AccessMode mode, Bookmark bookmark ) + public static RunWithMetadataMessage autoCommitTxRunMessage( Query query, Duration txTimeout, Map txMetadata, DatabaseName databaseName, + AccessMode mode, Bookmark bookmark, String impersonatedUser ) { - Map metadata = buildMetadata( txTimeout, txMetadata, databaseName, mode, bookmark ); + Map metadata = buildMetadata( txTimeout, txMetadata, databaseName, mode, bookmark, impersonatedUser ); return new RunWithMetadataMessage( query.text(), query.parameters().asMap( ofValue() ), metadata ); } - public static RunWithMetadataMessage unmanagedTxRunMessage(Query query) + public static RunWithMetadataMessage unmanagedTxRunMessage( Query query ) { return new RunWithMetadataMessage( query.text(), query.parameters().asMap( ofValue() ), emptyMap() ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilder.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilder.java index c46e569dd2..7a4ba34b9c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilder.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilder.java @@ -39,21 +39,25 @@ public class TransactionMetadataBuilder private static final String TX_METADATA_METADATA_KEY = "tx_metadata"; private static final String MODE_KEY = "mode"; private static final String MODE_READ_VALUE = "r"; + private static final String IMPERSONATED_USER_KEY = "imp_user"; - public static Map buildMetadata( Duration txTimeout, Map txMetadata, AccessMode mode, Bookmark bookmark ) + public static Map buildMetadata( Duration txTimeout, Map txMetadata, AccessMode mode, Bookmark bookmark, + String impersonatedUser ) { - return buildMetadata( txTimeout, txMetadata, defaultDatabase(), mode, bookmark ); + return buildMetadata( txTimeout, txMetadata, defaultDatabase(), mode, bookmark, impersonatedUser ); } - public static Map buildMetadata( Duration txTimeout, Map txMetadata, DatabaseName databaseName, AccessMode mode, Bookmark bookmark ) + public static Map buildMetadata( Duration txTimeout, Map txMetadata, DatabaseName databaseName, AccessMode mode, + Bookmark bookmark, String impersonatedUser ) { boolean bookmarksPresent = bookmark != null && !bookmark.isEmpty(); boolean txTimeoutPresent = txTimeout != null; boolean txMetadataPresent = txMetadata != null && !txMetadata.isEmpty(); boolean accessModePresent = mode == AccessMode.READ; boolean databaseNamePresent = databaseName.databaseName().isPresent(); + boolean impersonatedUserPresent = impersonatedUser != null; - if ( !bookmarksPresent && !txTimeoutPresent && !txMetadataPresent && !accessModePresent && !databaseNamePresent ) + if ( !bookmarksPresent && !txTimeoutPresent && !txMetadataPresent && !accessModePresent && !databaseNamePresent && !impersonatedUserPresent ) { return emptyMap(); } @@ -72,10 +76,14 @@ public static Map buildMetadata( Duration txTimeout, Map result.put( DATABASE_NAME_KEY, value( name ) ) ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java index 92ea2ff272..1f86459c7a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java @@ -122,7 +122,7 @@ public CompletionStage beginTransaction( Connection connection, Bookmark b } CompletableFuture beginTxFuture = new CompletableFuture<>(); - BeginMessage beginMessage = new BeginMessage( bookmark, config, connection.databaseName(), connection.mode() ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, connection.databaseName(), connection.mode(), connection.impersonatedUser() ); connection.writeAndFlush( beginMessage, new BeginTxResponseHandler( beginTxFuture ) ); return beginTxFuture; } @@ -149,7 +149,8 @@ public ResultCursorFactory runInAutoCommitTransaction( Connection connection, Qu { verifyDatabaseNameBeforeTransaction( connection.databaseName() ); RunWithMetadataMessage runMessage = - autoCommitTxRunMessage(query, config, connection.databaseName(), connection.mode(), bookmarkHolder.getBookmark() ); + autoCommitTxRunMessage( query, config, connection.databaseName(), connection.mode(), bookmarkHolder.getBookmark(), + connection.impersonatedUser() ); return buildResultCursorFactory( connection, query, bookmarkHolder, null, runMessage, fetchSize ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44.java index 090dc4559d..4731da297b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44.java @@ -31,7 +31,7 @@ import org.neo4j.driver.internal.messaging.encode.PullMessageEncoder; import org.neo4j.driver.internal.messaging.encode.ResetMessageEncoder; import org.neo4j.driver.internal.messaging.encode.RollbackMessageEncoder; -import org.neo4j.driver.internal.messaging.encode.RouteMessageEncoder; +import org.neo4j.driver.internal.messaging.encode.RouteV44MessageEncoder; import org.neo4j.driver.internal.messaging.encode.RunWithMetadataMessageEncoder; import org.neo4j.driver.internal.messaging.request.BeginMessage; import org.neo4j.driver.internal.messaging.request.CommitMessage; @@ -62,7 +62,7 @@ private static Map buildEncoders() result.put( HelloMessage.SIGNATURE, new HelloMessageEncoder() ); result.put( GoodbyeMessage.SIGNATURE, new GoodbyeMessageEncoder() ); result.put( RunWithMetadataMessage.SIGNATURE, new RunWithMetadataMessageEncoder() ); - result.put( RouteMessage.SIGNATURE, new RouteMessageEncoder() ); + result.put( RouteMessage.SIGNATURE, new RouteV44MessageEncoder() ); result.put( DiscardMessage.SIGNATURE, new DiscardMessageEncoder() ); result.put( PullMessage.SIGNATURE, new PullMessageEncoder() ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java index 13c2946937..f6d676c0e7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java @@ -69,5 +69,10 @@ default DatabaseName databaseName() throw new UnsupportedOperationException( format( "%s does not support database name.", getClass() ) ); } + default String impersonatedUser() + { + throw new UnsupportedOperationException( format( "%s does not support impersonated user.", getClass() ) ); + } + void flush(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java b/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java index 7be5cec12c..0bd91db8fe 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java @@ -25,6 +25,8 @@ public interface ConnectionPool { + String CONNECTION_POOL_CLOSED_ERROR_MESSAGE = "Pool closed"; + CompletionStage acquire( BoltServerAddress address ); void retainAll( Set addressesToRetain ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java index 24ed13c879..6d2d318fd5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java @@ -27,6 +27,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Supplier; import org.neo4j.driver.internal.async.connection.EventLoopGroupFactory; @@ -158,11 +159,23 @@ public static T getNow( CompletionStage stage ) return stage.toCompletableFuture().getNow( null ); } + public static T joinNowOrElseThrow( CompletableFuture future, Supplier exceptionSupplier ) + { + if ( future.isDone() ) + { + return future.join(); + } + else + { + throw exceptionSupplier.get(); + } + } + /** * Helper method to extract cause of a {@link CompletionException}. *

- * When using {@link CompletionStage#whenComplete(BiConsumer)} and {@link CompletionStage#handle(BiFunction)} - * propagated exceptions might get wrapped in a {@link CompletionException}. + * When using {@link CompletionStage#whenComplete(BiConsumer)} and {@link CompletionStage#handle(BiFunction)} propagated exceptions might get wrapped in a + * {@link CompletionException}. * * @param error the exception to get cause for. * @return cause of the given exception if it is a {@link CompletionException}, given exception otherwise. diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java index 4c9ba7fcd9..5232d4e76f 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java @@ -28,6 +28,8 @@ import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41; import org.neo4j.driver.internal.messaging.v42.BoltProtocolV42; +import org.neo4j.driver.internal.messaging.v43.BoltProtocolV43; +import org.neo4j.driver.internal.messaging.v44.BoltProtocolV44; import static java.lang.Integer.compare; @@ -35,6 +37,8 @@ public class ServerVersion { public static final String NEO4J_PRODUCT = "Neo4j"; + public static final ServerVersion v4_4_0 = new ServerVersion( NEO4J_PRODUCT, 4, 4, 0 ); + public static final ServerVersion v4_3_0 = new ServerVersion( NEO4J_PRODUCT, 4, 3, 0 ); public static final ServerVersion v4_2_0 = new ServerVersion( NEO4J_PRODUCT, 4, 2, 0 ); public static final ServerVersion v4_1_0 = new ServerVersion( NEO4J_PRODUCT, 4, 1, 0 ); public static final ServerVersion v4_0_0 = new ServerVersion( NEO4J_PRODUCT, 4, 0, 0 ); @@ -194,6 +198,14 @@ else if ( BoltProtocolV41.VERSION.equals( protocolVersion ) ) { return ServerVersion.v4_2_0; } + else if ( BoltProtocolV43.VERSION.equals( protocolVersion ) ) + { + return ServerVersion.v4_3_0; + } + else if ( BoltProtocolV44.VERSION.equals( protocolVersion ) ) + { + return ServerVersion.v4_4_0; + } return ServerVersion.vInDev; } diff --git a/driver/src/test/java/org/neo4j/driver/ParametersTest.java b/driver/src/test/java/org/neo4j/driver/ParametersTest.java index ca1005ecc3..9358e7dbbc 100644 --- a/driver/src/test/java/org/neo4j/driver/ParametersTest.java +++ b/driver/src/test/java/org/neo4j/driver/ParametersTest.java @@ -110,7 +110,8 @@ private Session mockedSession() ConnectionProvider provider = mock( ConnectionProvider.class ); RetryLogic retryLogic = mock( RetryLogic.class ); NetworkSession session = - new NetworkSession( provider, retryLogic, defaultDatabase(), AccessMode.WRITE, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE, DEV_NULL_LOGGING ); + new NetworkSession( provider, retryLogic, defaultDatabase(), AccessMode.WRITE, new DefaultBookmarkHolder(), null, UNLIMITED_FETCH_SIZE, + DEV_NULL_LOGGING ); return new InternalSession( session ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/DirectConnectionProviderTest.java b/driver/src/test/java/org/neo4j/driver/internal/DirectConnectionProviderTest.java index 245c93b322..b746554e42 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/DirectConnectionProviderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/DirectConnectionProviderTest.java @@ -22,11 +22,13 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.InOrder; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import org.neo4j.driver.AccessMode; +import org.neo4j.driver.internal.async.ConnectionContext; import org.neo4j.driver.internal.async.connection.DirectConnection; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; @@ -36,7 +38,9 @@ import static org.hamcrest.junit.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.AccessMode.READ; @@ -131,14 +135,36 @@ void shouldObtainDatabaseNameOnConnection( String databaseName ) throws Throwabl assertEquals( databaseName, acquired.databaseName().description() ); } + @ParameterizedTest + @ValueSource( booleans = {true, false} ) + void ensuresCompletedDatabaseNameBeforeAccessingValue( boolean completed ) + { + BoltServerAddress address = BoltServerAddress.LOCAL_DEFAULT; + ConnectionPool pool = poolMock( address, mock( Connection.class ) ); + DirectConnectionProvider provider = new DirectConnectionProvider( address, pool ); + ConnectionContext context = mock( ConnectionContext.class ); + CompletableFuture databaseNameFuture = + spy( completed ? CompletableFuture.completedFuture( DatabaseNameUtil.systemDatabase() ) : new CompletableFuture<>() ); + when( context.databaseNameFuture() ).thenReturn( databaseNameFuture ); + when( context.mode() ).thenReturn( WRITE ); + + await( provider.acquireConnection( context ) ); + + InOrder inOrder = inOrder( context, databaseNameFuture ); + inOrder.verify( context ).databaseNameFuture(); + inOrder.verify( databaseNameFuture ).complete( DatabaseNameUtil.defaultDatabase() ); + inOrder.verify( databaseNameFuture ).isDone(); + inOrder.verify( databaseNameFuture ).join(); + } + @SuppressWarnings( "unchecked" ) private static ConnectionPool poolMock( BoltServerAddress address, Connection connection, - Connection... otherConnections ) + Connection... otherConnections ) { ConnectionPool pool = mock( ConnectionPool.class ); CompletableFuture[] otherConnectionFutures = Stream.of( otherConnections ) - .map( CompletableFuture::completedFuture ) - .toArray( CompletableFuture[]::new ); + .map( CompletableFuture::completedFuture ) + .toArray( CompletableFuture[]::new ); when( pool.acquire( address ) ).thenReturn( completedFuture( connection ), otherConnectionFutures ); return pool; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java index 1d9106621b..d0af1ffe8d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/LeakLoggingNetworkSessionTest.java @@ -98,7 +98,7 @@ private static void finalize( NetworkSession session ) throws Exception private static LeakLoggingNetworkSession newSession( Logging logging, boolean openConnection ) { return new LeakLoggingNetworkSession( connectionProviderMock( openConnection ), new FixedRetryLogic( 0 ), defaultDatabase(), READ, - new DefaultBookmarkHolder(), FetchSizeUtil.UNLIMITED_FETCH_SIZE, logging ); + new DefaultBookmarkHolder(), null, FetchSizeUtil.UNLIMITED_FETCH_SIZE, logging ); } private static ConnectionProvider connectionProviderMock( boolean openConnection ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java index fee7ea90f0..3892160fa8 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java @@ -210,7 +210,7 @@ void shouldDelegateProtocol() @EnumSource( AccessMode.class ) void shouldReturnModeFromConstructor( AccessMode mode ) { - DirectConnection connection = new DirectConnection( mock( Connection.class ), defaultDatabase(), mode ); + DirectConnection connection = new DirectConnection( mock( Connection.class ), defaultDatabase(), mode, null ); assertEquals( mode, connection.mode() ); } @@ -226,6 +226,6 @@ void shouldReturnConnection() private static DirectConnection newConnection( Connection connection ) { - return new DirectConnection( connection, defaultDatabase(), READ ); + return new DirectConnection( connection, defaultDatabase(), READ, null ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DirectConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DirectConnectionTest.java index 3e843dad75..4997cbdcb7 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DirectConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DirectConnectionTest.java @@ -36,7 +36,7 @@ void shouldReturnServerAgent() { // given Connection connection = mock( Connection.class ); - DirectConnection directConnection = new DirectConnection( connection, defaultDatabase(), READ ); + DirectConnection directConnection = new DirectConnection( connection, defaultDatabase(), READ, null ); String agent = "Neo4j/4.2.5"; given( connection.serverAgent() ).willReturn( agent ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java index aac0b28b32..2aa59c69c0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java @@ -71,7 +71,7 @@ void shouldReturnServerAgent() // given Connection connection = mock( Connection.class ); RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); - RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, errorHandler ); + RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, null, errorHandler ); String agent = "Neo4j/4.2.5"; given( connection.serverAgent() ).willReturn( agent ); @@ -87,7 +87,7 @@ private static void testHandlersWrappingWithSingleMessage( boolean flush ) { Connection connection = mock( Connection.class ); RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); - RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, errorHandler ); + RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, null, errorHandler ); if ( flush ) { @@ -116,7 +116,7 @@ private static void testHandlersWrappingWithMultipleMessages( boolean flush ) { Connection connection = mock( Connection.class ); RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); - RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, errorHandler ); + RoutingConnection routingConnection = new RoutingConnection( connection, defaultDatabase(), READ, null, errorHandler ); if ( flush ) { diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/AbstractRoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/AbstractRoutingProcedureRunnerTest.java index 6a88db4ab4..0fe12d5165 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/AbstractRoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/AbstractRoutingProcedureRunnerTest.java @@ -48,7 +48,7 @@ void shouldReturnFailedResponseOnClientException() ClientException error = new ClientException( "Hi" ); SingleDatabaseRoutingProcedureRunner runner = singleDatabaseRoutingProcedureRunner( RoutingContext.EMPTY, failedFuture( error ) ); - RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty(), null ) ); assertFalse( response.isSuccess() ); assertEquals( error, response.error() ); @@ -60,7 +60,7 @@ void shouldReturnFailedStageOnError() Exception error = new Exception( "Hi" ); SingleDatabaseRoutingProcedureRunner runner = singleDatabaseRoutingProcedureRunner( RoutingContext.EMPTY, failedFuture( error ) ); - Exception e = assertThrows( Exception.class, () -> await( runner.run( connection(), defaultDatabase(), empty() ) ) ); + Exception e = assertThrows( Exception.class, () -> await( runner.run( connection(), defaultDatabase(), empty(), null ) ) ); assertEquals( error, e ); } @@ -70,7 +70,7 @@ void shouldReleaseConnectionOnSuccess() SingleDatabaseRoutingProcedureRunner runner = singleDatabaseRoutingProcedureRunner( RoutingContext.EMPTY ); Connection connection = connection(); - RoutingProcedureResponse response = await( runner.run( connection, defaultDatabase(), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection, defaultDatabase(), empty(), null ) ); assertTrue( response.isSuccess() ); verify( connection ).release(); @@ -84,7 +84,7 @@ void shouldPropagateReleaseError() RuntimeException releaseError = new RuntimeException( "Release failed" ); Connection connection = connection( failedFuture( releaseError ) ); - RuntimeException e = assertThrows( RuntimeException.class, () -> await( runner.run( connection, defaultDatabase(), empty() ) ) ); + RuntimeException e = assertThrows( RuntimeException.class, () -> await( runner.run( connection, defaultDatabase(), empty(), null ) ) ); assertEquals( releaseError, e ); verify( connection ).release(); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionTest.java index a89d86a763..7fe9dc13b9 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionTest.java @@ -222,7 +222,7 @@ void parsePreservesOrderOfRouters() private static ClusterComposition newComposition( long expirationTimestamp, Set readers, Set writers, Set routers ) { - return new ClusterComposition( expirationTimestamp, readers, writers, routers ); + return new ClusterComposition( expirationTimestamp, readers, writers, routers, null ); } private static Set addresses( BoltServerAddress... elements ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunnerTest.java index 4684a41e19..7abbe76252 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/MultiDatabasesRoutingProcedureRunnerTest.java @@ -60,7 +60,7 @@ class MultiDatabasesRoutingProcedureRunnerTest extends AbstractRoutingProcedureR void shouldCallGetRoutingTableWithEmptyMapOnSystemDatabaseForDatabase( String db ) { TestRoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY ); - RoutingProcedureResponse response = await( runner.run( connection(), database( db ), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection(), database( db ), empty(), null ) ); assertTrue( response.isSuccess() ); assertEquals( 1, response.records().size() ); @@ -81,7 +81,7 @@ void shouldCallGetRoutingTableWithParamOnSystemDatabaseForDatabase( String db ) RoutingContext context = new RoutingContext( uri ); TestRoutingProcedureRunner runner = new TestRoutingProcedureRunner( context ); - RoutingProcedureResponse response = await( runner.run( connection(), database( db ), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection(), database( db ), empty(), null ) ); assertTrue( response.isSuccess() ); assertEquals( 1, response.records().size() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java index 63eb032fd5..cd1e760849 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java @@ -83,8 +83,8 @@ class RediscoveryTest @Test void shouldUseFirstRouterInTable() { - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( B, C ), asOrderedSet( C, D ), asOrderedSet( B ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( B, C ), asOrderedSet( C, D ), asOrderedSet( B ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( B, expectedComposition ); // first -> valid cluster composition @@ -93,7 +93,7 @@ void shouldUseFirstRouterInTable() Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) ); RoutingTable table = routingTableMock( B ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table, never() ).forget( B ); @@ -102,8 +102,8 @@ void shouldUseFirstRouterInTable() @Test void shouldSkipFailingRouters() { - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( A, B, C ), asOrderedSet( B, C, D ), asOrderedSet( A, B ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( A, B, C ), asOrderedSet( B, C, D ), asOrderedSet( A, B ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( A, new RuntimeException( "Hi!" ) ); // first -> non-fatal failure @@ -114,7 +114,7 @@ void shouldSkipFailingRouters() Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) ); RoutingTable table = routingTableMock( A, B, C ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table ).forget( A ); @@ -137,7 +137,7 @@ void shouldFailImmediatelyOnAuthError() RoutingTable table = routingTableMock( A, B, C ); AuthenticationException error = assertThrows( AuthenticationException.class, - () -> await( rediscovery.lookupClusterComposition( table, pool, empty() ) ) ); + () -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) ); assertEquals( authError, error ); verify( table ).forget( A ); } @@ -146,8 +146,8 @@ void shouldFailImmediatelyOnAuthError() void shouldFallbackToInitialRouterWhenKnownRoutersFail() { BoltServerAddress initialRouter = A; - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( C, B, A ), asOrderedSet( A, B ), asOrderedSet( D, E ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( C, B, A ), asOrderedSet( A, B ), asOrderedSet( D, E ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); // first -> non-fatal failure @@ -159,7 +159,7 @@ void shouldFallbackToInitialRouterWhenKnownRoutersFail() Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); RoutingTable table = routingTableMock( B, C ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table ).forget( B ); @@ -169,8 +169,8 @@ void shouldFallbackToInitialRouterWhenKnownRoutersFail() @Test void shouldFailImmediatelyWhenClusterCompositionProviderReturnsFailure() { - ClusterComposition validComposition = new ClusterComposition( 42, - asOrderedSet( A ), asOrderedSet( B ), asOrderedSet( C ) ); + ClusterComposition validComposition = + new ClusterComposition( 42, asOrderedSet( A ), asOrderedSet( B ), asOrderedSet( C ), null ); ProtocolException protocolError = new ProtocolException( "Wrong record!" ); Map responsesByAddress = new HashMap<>(); @@ -186,7 +186,7 @@ void shouldFailImmediatelyWhenClusterCompositionProviderReturnsFailure() RoutingTable table = routingTableMock( B, C ); // When - ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( validComposition, composition ); ArgumentCaptor warningMessageCaptor = ArgumentCaptor.forClass( String.class ); @@ -204,8 +204,8 @@ void shouldFailImmediatelyWhenClusterCompositionProviderReturnsFailure() void shouldResolveInitialRouterAddress() { BoltServerAddress initialRouter = A; - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); // first -> non-fatal failure @@ -219,7 +219,7 @@ void shouldResolveInitialRouterAddress() Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); RoutingTable table = routingTableMock( B, C ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table ).forget( B ); @@ -230,8 +230,8 @@ void shouldResolveInitialRouterAddress() @Test void shouldResolveInitialRouterAddressUsingCustomResolver() { - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( A, B, C ), asOrderedSet( A, B, C ), asOrderedSet( B, E ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( A, B, C ), asOrderedSet( A, B, C ), asOrderedSet( B, E ), null ); ServerAddressResolver resolver = address -> { @@ -248,7 +248,7 @@ void shouldResolveInitialRouterAddressUsingCustomResolver() Rediscovery rediscovery = newRediscovery( A, compositionProvider, resolver ); RoutingTable table = routingTableMock( B, C ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table ).forget( B ); @@ -258,8 +258,8 @@ void shouldResolveInitialRouterAddressUsingCustomResolver() @Test void shouldPropagateFailureWhenResolverFails() { - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ), null ); Map responsesByAddress = singletonMap( A, expectedComposition ); ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); @@ -271,7 +271,7 @@ void shouldPropagateFailureWhenResolverFails() Rediscovery rediscovery = newRediscovery( A, compositionProvider, resolver ); RoutingTable table = routingTableMock(); - RuntimeException error = assertThrows( RuntimeException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty() ) ) ); + RuntimeException error = assertThrows( RuntimeException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) ); assertEquals( "Resolver fails!", error.getMessage() ); verify( resolver ).resolve( A ); @@ -293,7 +293,8 @@ void shouldRecordAllErrorsWhenNoRouterRespond() Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) ); RoutingTable table = routingTableMock( A, B, C ); - ServiceUnavailableException e = assertThrows( ServiceUnavailableException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty() ) ) ); + ServiceUnavailableException e = + assertThrows( ServiceUnavailableException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) ); assertThat( e.getMessage(), containsString( "Could not perform discovery" ) ); assertThat( e.getSuppressed().length, equalTo( 3 ) ); assertThat( e.getSuppressed()[0].getCause(), equalTo( first ) ); @@ -305,10 +306,10 @@ void shouldRecordAllErrorsWhenNoRouterRespond() void shouldUseInitialRouterAfterDiscoveryReturnsNoWriters() { BoltServerAddress initialRouter = A; - ClusterComposition noWritersComposition = new ClusterComposition( 42, - asOrderedSet( D, E ), emptySet(), asOrderedSet( D, E ) ); - ClusterComposition validComposition = new ClusterComposition( 42, - asOrderedSet( B, A ), asOrderedSet( B, A ), asOrderedSet( B, A ) ); + ClusterComposition noWritersComposition = + new ClusterComposition( 42, asOrderedSet( D, E ), emptySet(), asOrderedSet( D, E ), null ); + ClusterComposition validComposition = + new ClusterComposition( 42, asOrderedSet( B, A ), asOrderedSet( B, A ), asOrderedSet( B, A ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( initialRouter, validComposition ); // initial -> valid composition @@ -319,7 +320,7 @@ void shouldUseInitialRouterAfterDiscoveryReturnsNoWriters() RoutingTable table = new ClusterRoutingTable( defaultDatabase(), new FakeClock() ); table.update( noWritersComposition ); - ClusterComposition composition2 = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition composition2 = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( validComposition, composition2 ); } @@ -327,8 +328,8 @@ void shouldUseInitialRouterAfterDiscoveryReturnsNoWriters() void shouldUseInitialRouterToStartWith() { BoltServerAddress initialRouter = A; - ClusterComposition validComposition = new ClusterComposition( 42, - asOrderedSet( A ), asOrderedSet( A ), asOrderedSet( A ) ); + ClusterComposition validComposition = + new ClusterComposition( 42, asOrderedSet( A ), asOrderedSet( A ), asOrderedSet( A ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( initialRouter, validComposition ); // initial -> valid composition @@ -338,7 +339,7 @@ void shouldUseInitialRouterToStartWith() Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); RoutingTable table = routingTableMock( true, B, C, D ); - ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( validComposition, composition ); } @@ -346,8 +347,8 @@ void shouldUseInitialRouterToStartWith() void shouldUseKnownRoutersWhenInitialRouterFails() { BoltServerAddress initialRouter = A; - ClusterComposition validComposition = new ClusterComposition( 42, - asOrderedSet( D, E ), asOrderedSet( E, D ), asOrderedSet( A, B ) ); + ClusterComposition validComposition = + new ClusterComposition( 42, asOrderedSet( D, E ), asOrderedSet( E, D ), asOrderedSet( A, B ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( initialRouter, new ServiceUnavailableException( "Hi" ) ); // initial -> non-fatal error @@ -359,7 +360,7 @@ void shouldUseKnownRoutersWhenInitialRouterFails() Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); RoutingTable table = routingTableMock( true, D, E ); - ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition composition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( validComposition, composition ); verify( table ).forget( initialRouter ); verify( table ).forget( D ); @@ -370,8 +371,8 @@ void shouldRetryConfiguredNumberOfTimesWithDelay() { int maxRoutingFailures = 3; long retryTimeoutDelay = 15; - ClusterComposition expectedComposition = new ClusterComposition( 42, - asOrderedSet( A, C ), asOrderedSet( B, D ), asOrderedSet( A, E ) ); + ClusterComposition expectedComposition = + new ClusterComposition( 42, asOrderedSet( A, C ), asOrderedSet( B, D ), asOrderedSet( A, E ), null ); Map responsesByAddress = new HashMap<>(); responsesByAddress.put( A, new ServiceUnavailableException( "Hi!" ) ); @@ -381,7 +382,7 @@ void shouldRetryConfiguredNumberOfTimesWithDelay() ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); ServerAddressResolver resolver = mock( ServerAddressResolver.class ); when( resolver.resolve( A ) ).thenReturn( asOrderedSet( A ) ) - .thenReturn( asOrderedSet( A ) ) + .thenReturn( asOrderedSet( A ) ) .thenReturn( asOrderedSet( E ) ); ImmediateSchedulingEventExecutor eventExecutor = new ImmediateSchedulingEventExecutor(); @@ -391,7 +392,7 @@ void shouldRetryConfiguredNumberOfTimesWithDelay() DefaultDomainNameResolver.getInstance() ); RoutingTable table = routingTableMock( A, B ); - ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty() ) ).getClusterComposition(); + ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition(); assertEquals( expectedComposition, actualComposition ); verify( table, times( maxRoutingFailures ) ).forget( A ); @@ -419,7 +420,7 @@ void shouldNotLogWhenSingleRetryAttemptFails() RoutingTable table = routingTableMock( A ); ServiceUnavailableException e = - assertThrows( ServiceUnavailableException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty() ) ) ); + assertThrows( ServiceUnavailableException.class, () -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) ); assertThat( e.getMessage(), containsString( "Could not perform discovery" ) ); // rediscovery should not log about retries and should not schedule any retries @@ -464,21 +465,22 @@ private static ClusterCompositionProvider compositionProviderMock( Map responsesByAddress ) { ClusterCompositionProvider provider = mock( ClusterCompositionProvider.class ); - when( provider.getClusterComposition( any( Connection.class ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).then( invocation -> - { - Connection connection = invocation.getArgument( 0 ); - BoltServerAddress address = connection.serverAddress(); - Object response = responsesByAddress.get( address ); - assertNotNull( response ); - if ( response instanceof Throwable ) - { - return failedFuture( (Throwable) response ); - } - else - { - return completedFuture( response ); - } - } ); + when( provider.getClusterComposition( any( Connection.class ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .then( invocation -> + { + Connection connection = invocation.getArgument( 0 ); + BoltServerAddress address = connection.serverAddress(); + Object response = responsesByAddress.get( address ); + assertNotNull( response ); + if ( response instanceof Throwable ) + { + return failedFuture( (Throwable) response ); + } + else + { + return completedFuture( response ); + } + } ); return provider; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunnerTest.java index bc92e3e9c0..484394e615 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RouteMessageRoutingProcedureRunnerTest.java @@ -79,7 +79,7 @@ void shouldRequestRoutingTableForAllValidInputScenarios( RoutingContext routingC CompletableFuture releaseConnectionFuture = CompletableFuture.completedFuture( null ); doReturn( releaseConnectionFuture ).when( connection ).release(); - RoutingProcedureResponse response = TestUtil.await( runner.run( connection, databaseName, null ) ); + RoutingProcedureResponse response = TestUtil.await( runner.run( connection, databaseName, null, null ) ); assertNotNull( response ); assertTrue( response.isSuccess() ); @@ -106,7 +106,7 @@ void shouldReturnFailureWhenSomethingHappensGettingTheRoutingTable() CompletableFuture releaseConnectionFuture = CompletableFuture.completedFuture( null ); doReturn( releaseConnectionFuture ).when( connection ).release(); - RoutingProcedureResponse response = TestUtil.await( runner.run( connection, DatabaseNameUtil.defaultDatabase(), null ) ); + RoutingProcedureResponse response = TestUtil.await( runner.run( connection, DatabaseNameUtil.defaultDatabase(), null, null ) ); assertNotNull( response ); assertFalse( response.isSuccess() ); @@ -126,7 +126,7 @@ private void verifyMessageWasWrittenAndFlushed( Connection connection, Completab .stream() .collect( Collectors.toMap( Map.Entry::getKey, entry -> Values.value( entry.getValue() ) ) ); - verify( connection ).writeAndFlush( eq( new RouteMessage( context, bookmark, databaseName.databaseName().orElse( null ) ) ), + verify( connection ).writeAndFlush( eq( new RouteMessage( context, bookmark, databaseName.databaseName().orElse( null ), null ) ), eq( new RouteMessageResponseHandler( completableFuture ) ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java index f21f3fd83f..d2cc5b3750 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java @@ -72,12 +72,12 @@ void shouldProtocolErrorWhenNoRecord() newClusterCompositionProvider( mockedRunner, connection ); RoutingProcedureResponse noRecordsResponse = newRoutingResponse(); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ) + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) .thenReturn( completedFuture( noRecordsResponse ) ); // When & Then ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "records received '0' is too few or too many." ) ); } @@ -92,11 +92,12 @@ void shouldProtocolErrorWhenMoreThanOneRecord() Record aRecord = new InternalRecord( asList( "key1", "key2" ), new Value[]{new StringValue( "a value" )} ); RoutingProcedureResponse routingResponse = newRoutingResponse( aRecord, aRecord ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "records received '2' is too few or too many." ) ); } @@ -111,11 +112,12 @@ void shouldProtocolErrorWhenUnparsableRecord() Record aRecord = new InternalRecord( asList( "key1", "key2" ), new Value[]{new StringValue( "a value" )} ); RoutingProcedureResponse routingResponse = newRoutingResponse( aRecord ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "unparsable record received." ) ); } @@ -135,12 +137,13 @@ void shouldProtocolErrorWhenNoRouters() serverInfo( "WRITE", "one:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "no router or reader found in response." ) ); } @@ -160,12 +163,13 @@ void routeMessageRoutingProcedureShouldProtocolErrorWhenNoRouters() serverInfo( "WRITE", "one:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "no router or reader found in response." ) ); } @@ -185,12 +189,13 @@ void shouldProtocolErrorWhenNoReaders() serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "no router or reader found in response." ) ); } @@ -210,12 +215,13 @@ void routeMessageRoutingProcedureShouldProtocolErrorWhenNoReaders() serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedFuture( routingResponse ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When ProtocolException error = assertThrows( ProtocolException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( error.getMessage(), containsString( "no router or reader found in response." ) ); } @@ -228,12 +234,12 @@ void shouldPropagateConnectionFailureExceptions() ClusterCompositionProvider provider = newClusterCompositionProvider( mockedRunner, connection ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( failedFuture( - new ServiceUnavailableException( "Connection breaks during cypher execution" ) ) ); + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) + .thenReturn( failedFuture( new ServiceUnavailableException( "Connection breaks during cypher execution" ) ) ); // When & Then ServiceUnavailableException e = assertThrows( ServiceUnavailableException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertThat( e.getMessage(), containsString( "Connection breaks during cypher execution" ) ); } @@ -254,12 +260,12 @@ void shouldReturnSuccessResultWhenNoError() serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ) + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When - ClusterComposition cluster = await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ); + ClusterComposition cluster = await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ); // Then assertEquals( 12345 + 100_000, cluster.expirationTimestamp() ); @@ -285,12 +291,12 @@ void routeMessageRoutingProcedureShouldReturnSuccessResultWhenNoError() serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); RoutingProcedureResponse routingResponse = newRoutingResponse( record ); - when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ) + when( mockedRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) .thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When - ClusterComposition cluster = await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ); + ClusterComposition cluster = await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ); // Then assertEquals( 12345 + 100_000, cluster.expirationTimestamp() ); @@ -306,14 +312,14 @@ void shouldReturnFailureWhenProcedureRunnerFails() Connection connection = mock( Connection.class ); RuntimeException error = new RuntimeException( "hi" ); - when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ) + when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ) .thenReturn( completedFuture( newRoutingResponse( error ) ) ); RoutingProcedureClusterCompositionProvider provider = newClusterCompositionProvider( procedureRunner, connection ); RuntimeException e = assertThrows( RuntimeException.class, - () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty() ) ) ); + () -> await( provider.getClusterComposition( connection, defaultDatabase(), empty(), null ) ) ); assertEquals( error, e ); } @@ -326,10 +332,10 @@ void shouldUseMultiDBProcedureRunnerWhenConnectingWith40Server() throws Throwabl RoutingProcedureClusterCompositionProvider provider = newClusterCompositionProvider( procedureRunner, connection ); - when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedWithNull() ); - provider.getClusterComposition( connection, defaultDatabase(), empty() ); + when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ).thenReturn( completedWithNull() ); + provider.getClusterComposition( connection, defaultDatabase(), empty(), null ); - verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ); + verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ); } @Test @@ -341,10 +347,10 @@ void shouldUseProcedureRunnerWhenConnectingWith35AndPreviousServers() throws Thr RoutingProcedureClusterCompositionProvider provider = newClusterCompositionProvider( procedureRunner, connection ); - when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedWithNull() ); - provider.getClusterComposition( connection, defaultDatabase(), empty() ); + when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ).thenReturn( completedWithNull() ); + provider.getClusterComposition( connection, defaultDatabase(), empty(), null ); - verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ); + verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ); } @Test @@ -356,10 +362,10 @@ void shouldUseRouteMessageProcedureRunnerWhenConnectingWithProtocol43() throws T RoutingProcedureClusterCompositionProvider provider = newClusterCompositionProvider( procedureRunner, connection ); - when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ) ).thenReturn( completedWithNull() ); - provider.getClusterComposition( connection, defaultDatabase(), empty() ); + when( procedureRunner.run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ) ).thenReturn( completedWithNull() ); + provider.getClusterComposition( connection, defaultDatabase(), empty(), null ); - verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ) ); + verify( procedureRunner ).run( eq( connection ), any( DatabaseName.class ), any( InternalBookmark.class ), any() ); } private static Map serverInfo( String role, String... addresses ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerTest.java index 6a56b1f276..709b547434 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableHandlerTest.java @@ -76,7 +76,7 @@ void shouldRemoveAddressFromRoutingTableOnConnectionFailure() { RoutingTable routingTable = new ClusterRoutingTable( defaultDatabase(), new FakeClock() ); routingTable.update( new ClusterComposition( - 42, asOrderedSet( A, B, C ), asOrderedSet( A, C, E ), asOrderedSet( B, D, F ) ) ); + 42, asOrderedSet( A, B, C ), asOrderedSet( A, C, E ), asOrderedSet( B, D, F ), null ) ); RoutingTableHandler handler = newRoutingTableHandler( routingTable, newRediscoveryMock(), newConnectionPoolMock() ); @@ -109,16 +109,16 @@ void acquireShouldUpdateRoutingTableWhenKnownRoutingTableIsStale() Set readers = new LinkedHashSet<>( asList( reader1, reader2 ) ); Set writers = new LinkedHashSet<>( singletonList( writer1 ) ); Set routers = new LinkedHashSet<>( singletonList( router1 ) ); - ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers ); + ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers, null ); Rediscovery rediscovery = mock( RediscoveryImpl.class ); - when( rediscovery.lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any() ) ) + when( rediscovery.lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any(), any() ) ) .thenReturn( completedFuture( new ClusterCompositionLookupResult( clusterComposition ) ) ); RoutingTableHandler handler = newRoutingTableHandler( routingTable, rediscovery, connectionPool ); assertNotNull( await( handler.ensureRoutingTable( simple( false ) ) ) ); - verify( rediscovery ).lookupClusterComposition( eq ( routingTable ) , eq ( connectionPool ), any() ); + verify( rediscovery ).lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any(), any() ); assertArrayEquals( new BoltServerAddress[]{reader1, reader2}, routingTable.readers().toArray() ); assertArrayEquals( new BoltServerAddress[]{writer1}, routingTable.writers().toArray() ); assertArrayEquals( new BoltServerAddress[]{router1}, routingTable.routers().toArray() ); @@ -153,13 +153,13 @@ void shouldRetainAllFetchedAddressesInConnectionPoolAfterFetchingOfRoutingTable( { RoutingTable routingTable = new ClusterRoutingTable( defaultDatabase(), new FakeClock() ); routingTable.update( new ClusterComposition( - 42, asOrderedSet(), asOrderedSet( B, C ), asOrderedSet( D, E ) ) ); + 42, asOrderedSet(), asOrderedSet( B, C ), asOrderedSet( D, E ), null ) ); ConnectionPool connectionPool = newConnectionPoolMock(); Rediscovery rediscovery = newRediscoveryMock(); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( completedFuture( - new ClusterCompositionLookupResult( new ClusterComposition( 42, asOrderedSet( A, B ), asOrderedSet( B, C ), asOrderedSet( A, C ) ) ) ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ).thenReturn( completedFuture( + new ClusterCompositionLookupResult( new ClusterComposition( 42, asOrderedSet( A, B ), asOrderedSet( B, C ), asOrderedSet( A, C ), null ) ) ) ); RoutingTableRegistry registry = new RoutingTableRegistry() { @@ -208,7 +208,7 @@ void shouldRemoveRoutingTableHandlerIfFailedToLookup() throws Throwable RoutingTable routingTable = new ClusterRoutingTable( defaultDatabase(), new FakeClock() ); Rediscovery rediscovery = newRediscoveryMock(); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( Futures.failedFuture( new RuntimeException( "Bang!" ) ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ).thenReturn( Futures.failedFuture( new RuntimeException( "Bang!" ) ) ); ConnectionPool connectionPool = newConnectionPoolMock(); RoutingTableRegistry registry = newRoutingTableRegistryMock(); @@ -235,7 +235,7 @@ private void testRediscoveryWhenStale( AccessMode mode ) assertEquals( routingTable, actual ); verify( routingTable ).isStaleFor( mode ); - verify( rediscovery ).lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any() ); + verify( rediscovery ).lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any(), any() ); } private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode notStaleMode ) @@ -251,7 +251,7 @@ private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode not assertNotNull( await( handler.ensureRoutingTable( contextWithMode( notStaleMode ) ) ) ); verify( routingTable ).isStaleFor( notStaleMode ); - verify( rediscovery, never() ).lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any() ); + verify( rediscovery, never() ).lookupClusterComposition( eq( routingTable ), eq( connectionPool ), any(), any() ); } private static RoutingTable newStaleRoutingTableMock( AccessMode mode ) @@ -277,8 +277,8 @@ private static Rediscovery newRediscoveryMock() { Rediscovery rediscovery = mock( RediscoveryImpl.class ); Set noServers = Collections.emptySet(); - ClusterComposition clusterComposition = new ClusterComposition( 1, noServers, noServers, noServers ); - when( rediscovery.lookupClusterComposition( any( RoutingTable.class ), any( ConnectionPool.class ), any( InternalBookmark.class ) ) ) + ClusterComposition clusterComposition = new ClusterComposition( 1, noServers, noServers, noServers, null ); + when( rediscovery.lookupClusterComposition( any( RoutingTable.class ), any( ConnectionPool.class ), any( InternalBookmark.class ), any() ) ) .thenReturn( completedFuture( new ClusterCompositionLookupResult( clusterComposition ) ) ); return rediscovery; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImplTest.java index a15db7caab..ffd93d309d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingTableRegistryImplTest.java @@ -136,7 +136,7 @@ void shouldReturnFreshRoutingTable( AccessMode mode ) throws Throwable ConcurrentMap map = new ConcurrentHashMap<>(); RoutingTableHandler handler = mockedRoutingTableHandler(); RoutingTableHandlerFactory factory = mockedHandlerFactory( handler ); - RoutingTableRegistryImpl routingTables = new RoutingTableRegistryImpl( map, factory, DEV_NULL_LOGGING ); + RoutingTableRegistryImpl routingTables = new RoutingTableRegistryImpl( map, factory, null, null, null, DEV_NULL_LOGGING ); ImmutableConnectionContext context = new ImmutableConnectionContext( defaultDatabase(), InternalBookmark.empty(), mode ); // When @@ -155,7 +155,7 @@ void shouldReturnServersInAllRoutingTables() throws Throwable map.put( database( "Banana" ), mockedRoutingTableHandler( B, C, D ) ); map.put( database( "Orange" ), mockedRoutingTableHandler( E, F, C ) ); RoutingTableHandlerFactory factory = mockedHandlerFactory(); - RoutingTableRegistryImpl routingTables = new RoutingTableRegistryImpl( map, factory, DEV_NULL_LOGGING ); + RoutingTableRegistryImpl routingTables = new RoutingTableRegistryImpl( map, factory, null, null, null, DEV_NULL_LOGGING ); // When Set servers = routingTables.allServers(); @@ -210,7 +210,7 @@ private RoutingTableHandler mockedRoutingTableHandler( BoltServerAddress... serv private RoutingTableRegistryImpl newRoutingTables( ConcurrentMap handlers, RoutingTableHandlerFactory factory ) { - return new RoutingTableRegistryImpl( handlers, factory, DEV_NULL_LOGGING ); + return new RoutingTableRegistryImpl( handlers, factory, null, null, null, DEV_NULL_LOGGING ); } private RoutingTableHandlerFactory mockedHandlerFactory( RoutingTableHandler handler ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunnerTest.java index 4d7c0a8e1e..7bb13949ee 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/SingleDatabaseRoutingProcedureRunnerTest.java @@ -60,7 +60,7 @@ class SingleDatabaseRoutingProcedureRunnerTest extends AbstractRoutingProcedureR void shouldCallGetRoutingTableWithEmptyMap() { TestRoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY ); - RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty(), null ) ); assertTrue( response.isSuccess() ); assertEquals( 1, response.records().size() ); @@ -80,7 +80,7 @@ void shouldCallGetRoutingTableWithParam() RoutingContext context = new RoutingContext( uri ); TestRoutingProcedureRunner runner = new TestRoutingProcedureRunner( context ); - RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty() ) ); + RoutingProcedureResponse response = await( runner.run( connection(), defaultDatabase(), empty(), null ) ); assertTrue( response.isSuccess() ); assertEquals( 1, response.records().size() ); @@ -99,7 +99,7 @@ void shouldCallGetRoutingTableWithParam() void shouldErrorWhenDatabaseIsNotAbsent( String db ) throws Throwable { TestRoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY ); - assertThrows( FatalDiscoveryException.class, () -> await( runner.run( connection(), database( db ), empty() ) ) ); + assertThrows( FatalDiscoveryException.class, () -> await( runner.run( connection(), database( db ), empty(), null ) ) ); } SingleDatabaseRoutingProcedureRunner singleDatabaseRoutingProcedureRunner( RoutingContext context ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java index 516f2eec4e..ed07a78d8a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java @@ -20,9 +20,11 @@ import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.InOrder; import java.util.Arrays; import java.util.HashSet; @@ -37,6 +39,8 @@ import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.exceptions.SessionExpiredException; import org.neo4j.driver.internal.BoltServerAddress; +import org.neo4j.driver.internal.DatabaseName; +import org.neo4j.driver.internal.DatabaseNameUtil; import org.neo4j.driver.internal.async.ConnectionContext; import org.neo4j.driver.internal.async.connection.RoutingConnection; import org.neo4j.driver.internal.cluster.AddressSet; @@ -68,7 +72,9 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -205,7 +211,7 @@ void shouldTryMultipleServersAfterRediscovery() ConnectionPool connectionPool = newConnectionPoolMockWithFailures( unavailableAddresses ); RoutingTable routingTable = new ClusterRoutingTable( defaultDatabase(), new FakeClock() ); - routingTable.update( new ClusterComposition( -1, new LinkedHashSet<>( Arrays.asList( A, B ) ), emptySet(), emptySet() ) ); + routingTable.update( new ClusterComposition( -1, new LinkedHashSet<>( Arrays.asList( A, B ) ), emptySet(), emptySet(), null ) ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTable ); @@ -363,6 +369,52 @@ void shouldReturnSuccessVerifyConnectivity() throws Throwable verify( routingTables ).ensureRoutingTable( any( ConnectionContext.class ) ); } + @ParameterizedTest + @ValueSource( booleans = {true, false} ) + void expectsCompetedDatabaseNameAfterRoutingTableRegistry( boolean completed ) throws Throwable + { + ConnectionPool connectionPool = newConnectionPoolMock(); + RoutingTable routingTable = mock( RoutingTable.class ); + AddressSet readerAddresses = mock( AddressSet.class ); + AddressSet writerAddresses = mock( AddressSet.class ); + when( readerAddresses.toArray() ).thenReturn( new BoltServerAddress[]{A} ); + when( writerAddresses.toArray() ).thenReturn( new BoltServerAddress[]{B} ); + when( routingTable.readers() ).thenReturn( readerAddresses ); + when( routingTable.writers() ).thenReturn( writerAddresses ); + RoutingTableRegistry routingTables = mock( RoutingTableRegistry.class ); + RoutingTableHandler handler = mock( RoutingTableHandler.class ); + when( handler.routingTable() ).thenReturn( routingTable ); + when( routingTables.ensureRoutingTable( any( ConnectionContext.class ) ) ).thenReturn( CompletableFuture.completedFuture( handler ) ); + Rediscovery rediscovery = mock( Rediscovery.class ); + LoadBalancer loadBalancer = + new LoadBalancer( connectionPool, routingTables, rediscovery, new LeastConnectedLoadBalancingStrategy( connectionPool, DEV_NULL_LOGGING ), + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + ConnectionContext context = mock( ConnectionContext.class ); + CompletableFuture databaseNameFuture = + spy( completed ? CompletableFuture.completedFuture( DatabaseNameUtil.systemDatabase() ) : new CompletableFuture<>() ); + when( context.databaseNameFuture() ).thenReturn( databaseNameFuture ); + when( context.mode() ).thenReturn( WRITE ); + + Executable action = () -> await( loadBalancer.acquireConnection( context ) ); + if ( completed ) + { + action.execute(); + } + else + { + assertThrows( IllegalStateException.class, action, ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER.get().getMessage() ); + } + + InOrder inOrder = inOrder( routingTables, context, databaseNameFuture ); + inOrder.verify( routingTables ).ensureRoutingTable( context ); + inOrder.verify( context ).databaseNameFuture(); + inOrder.verify( databaseNameFuture ).isDone(); + if ( completed ) + { + inOrder.verify( databaseNameFuture ).join(); + } + } + private static ConnectionPool newConnectionPoolMock() { return newConnectionPoolMockWithFailures( emptySet() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java index a2bd57d35d..bd7eb9a2ae 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java @@ -99,7 +99,7 @@ void shouldAddServerToRoutingTableAndConnectionPool() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( clusterComposition( A ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ).thenReturn( clusterComposition( A ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); @@ -119,7 +119,8 @@ void shouldNotAddToRoutingTableWhenFailedWithRoutingError() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( Futures.failedFuture( new FatalDiscoveryException( "No database found" ) ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ) + .thenReturn( Futures.failedFuture( new FatalDiscoveryException( "No database found" ) ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); @@ -138,7 +139,8 @@ void shouldNotAddToRoutingTableWhenFailedWithProtocolError() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( Futures.failedFuture( new ProtocolException( "No database found" ) ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ) + .thenReturn( Futures.failedFuture( new ProtocolException( "No database found" ) ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); @@ -157,7 +159,8 @@ void shouldNotAddToRoutingTableWhenFailedWithSecurityError() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( Futures.failedFuture( new SecurityException( "No database found" ) ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ) + .thenReturn( Futures.failedFuture( new SecurityException( "No database found" ) ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); @@ -176,7 +179,7 @@ void shouldNotRemoveNewlyAddedRoutingTableEvenIfItIsExpired() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( expiredClusterComposition( A ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ).thenReturn( expiredClusterComposition( A ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); @@ -199,14 +202,16 @@ void shouldRemoveExpiredRoutingTableAndServers() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( expiredClusterComposition( A ) ).thenReturn( clusterComposition( B ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ) + .thenReturn( expiredClusterComposition( A ) ) + .thenReturn( clusterComposition( B ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); // When Connection connection = await( loadBalancer.acquireConnection( contextWithDatabase( "neo4j" ) ) ); await( connection.release() ); - await( loadBalancer.acquireConnection( contextWithDatabase( "foo" ) ) ); + await( loadBalancer.acquireConnection( contextWithDatabase( "foo" ) ) ); // Then assertFalse( routingTables.contains( database( "neo4j" ) ) ); @@ -224,12 +229,14 @@ void shouldRemoveExpiredRoutingTableButNotServer() throws Throwable // Given ConnectionPool connectionPool = newConnectionPool(); Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( any(), any(), any() ) ).thenReturn( expiredClusterComposition( A ) ).thenReturn( clusterComposition( B ) ); + when( rediscovery.lookupClusterComposition( any(), any(), any(), any() ) ) + .thenReturn( expiredClusterComposition( A ) ) + .thenReturn( clusterComposition( B ) ); RoutingTableRegistryImpl routingTables = newRoutingTables( connectionPool, rediscovery ); LoadBalancer loadBalancer = newLoadBalancer( connectionPool, routingTables ); // When - await( loadBalancer.acquireConnection( contextWithDatabase("neo4j" ) ) ); + await( loadBalancer.acquireConnection( contextWithDatabase( "neo4j" ) ) ); await( loadBalancer.acquireConnection( contextWithDatabase( "foo" ) ) ); // Then @@ -345,7 +352,7 @@ private CompletableFuture expiredClusterComposit private CompletableFuture clusterComposition( long expireAfterMs, BoltServerAddress... addresses ) { HashSet servers = new HashSet<>( Arrays.asList( addresses ) ); - ClusterComposition composition = new ClusterComposition( clock.millis() + expireAfterMs, servers, servers, servers ); + ClusterComposition composition = new ClusterComposition( clock.millis() + expireAfterMs, servers, servers, servers, null ); return CompletableFuture.completedFuture( new ClusterCompositionLookupResult( composition ) ); } @@ -353,7 +360,7 @@ private class RandomizedRediscovery implements Rediscovery { @Override public CompletionStage lookupClusterComposition( RoutingTable routingTable, ConnectionPool connectionPool, - Bookmark bookmark ) + Bookmark bookmark, String impersonatedUser ) { // when looking up a new routing table, we return a valid random routing table back Set servers = new HashSet<>(); @@ -370,7 +377,7 @@ public CompletionStage lookupClusterComposition( { servers.add( A ); } - ClusterComposition composition = new ClusterComposition( clock.millis() + 1, servers, servers, servers ); + ClusterComposition composition = new ClusterComposition( clock.millis() + 1, servers, servers, servers, null ); return CompletableFuture.completedFuture( new ClusterCompositionLookupResult( composition ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/BeginMessageEncoderTest.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/BeginMessageEncoderTest.java index 6651e8b931..f2c6355412 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/BeginMessageEncoderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/BeginMessageEncoderTest.java @@ -20,12 +20,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.InOrder; import java.time.Duration; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.stream.Stream; import org.neo4j.driver.AccessMode; import org.neo4j.driver.Bookmark; @@ -48,8 +51,8 @@ class BeginMessageEncoderTest private final ValuePacker packer = mock( ValuePacker.class ); @ParameterizedTest - @EnumSource( AccessMode.class ) - void shouldEncodeBeginMessage( AccessMode mode ) throws Exception + @MethodSource( "arguments" ) + void shouldEncodeBeginMessage( AccessMode mode, String impersonatedUser ) throws Exception { Bookmark bookmark = InternalBookmark.parse( "neo4j:bookmark:v1:tx42" ); @@ -59,7 +62,7 @@ void shouldEncodeBeginMessage( AccessMode mode ) throws Exception Duration txTimeout = Duration.ofSeconds( 1 ); - encoder.encode( new BeginMessage( bookmark, txTimeout, txMetadata, mode, defaultDatabase() ), packer ); + encoder.encode( new BeginMessage( bookmark, txTimeout, txMetadata, mode, defaultDatabase(), impersonatedUser ), packer ); InOrder order = inOrder( packer ); order.verify( packer ).packStructHeader( 1, BeginMessage.SIGNATURE ); @@ -72,10 +75,20 @@ void shouldEncodeBeginMessage( AccessMode mode ) throws Exception { expectedMetadata.put( "mode", value( "r" ) ); } + if ( impersonatedUser != null ) + { + expectedMetadata.put( "imp_user", value( impersonatedUser ) ); + } order.verify( packer ).pack( expectedMetadata ); } + private static Stream arguments() + { + return Arrays.stream( AccessMode.values() ) + .flatMap( accessMode -> Stream.of( Arguments.of( accessMode, "user" ), Arguments.of( accessMode, null ) ) ); + } + @Test void shouldFailToEncodeWrongMessage() { diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RouteMessageEncoderTest.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RouteMessageEncoderTest.java index b53ce8d928..76aa007700 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RouteMessageEncoderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RouteMessageEncoderTest.java @@ -54,7 +54,7 @@ void shouldEncodeRouteMessage(String databaseName) throws IOException { Map routingContext = getRoutingContext(); - encoder.encode( new RouteMessage( getRoutingContext(), null, databaseName ), packer ); + encoder.encode( new RouteMessage( getRoutingContext(), null, databaseName, null ), packer ); InOrder inOrder = inOrder( packer ); @@ -72,7 +72,7 @@ void shouldEncodeRouteMessageWithBookmark(String databaseName) throws IOExceptio Map routingContext = getRoutingContext(); Bookmark bookmark = InternalBookmark.parse( "somebookmark" ); - encoder.encode( new RouteMessage( getRoutingContext(), bookmark, databaseName ), packer ); + encoder.encode( new RouteMessage( getRoutingContext(), bookmark, databaseName, null ), packer ); InOrder inOrder = inOrder( packer ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RunWithMetadataMessageEncoderTest.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RunWithMetadataMessageEncoderTest.java index 9a8b262a55..f90da33e62 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RunWithMetadataMessageEncoderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/encode/RunWithMetadataMessageEncoderTest.java @@ -66,7 +66,7 @@ void shouldEncodeRunWithMetadataMessage( AccessMode mode ) throws Exception Duration txTimeout = Duration.ofMillis( 42 ); Query query = new Query( "RETURN $answer", value( params ) ); - encoder.encode( autoCommitTxRunMessage(query, txTimeout, txMetadata, defaultDatabase(), mode, bookmark ), packer ); + encoder.encode( autoCommitTxRunMessage( query, txTimeout, txMetadata, defaultDatabase(), mode, bookmark, null ), packer ); InOrder order = inOrder( packer ); order.verify( packer ).packStructHeader( 3, RunWithMetadataMessage.SIGNATURE ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilderTest.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilderTest.java index c06d7d3810..e9dae116ca 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/request/TransactionMetadataBuilderTest.java @@ -59,7 +59,7 @@ void shouldHaveCorrectMetadata( AccessMode mode ) Duration txTimeout = Duration.ofSeconds( 7 ); - Map metadata = buildMetadata( txTimeout, txMetadata, defaultDatabase(), mode, bookmark ); + Map metadata = buildMetadata( txTimeout, txMetadata, defaultDatabase(), mode, bookmark, null ); Map expectedMetadata = new HashMap<>(); expectedMetadata.put( "bookmarks", value( bookmark.values() ) ); @@ -86,7 +86,7 @@ void shouldHaveCorrectMetadataForDatabaseName( String databaseName ) Duration txTimeout = Duration.ofSeconds( 7 ); - Map metadata = buildMetadata( txTimeout, txMetadata, database( databaseName ), WRITE, bookmark ); + Map metadata = buildMetadata( txTimeout, txMetadata, database( databaseName ), WRITE, bookmark, null ); Map expectedMetadata = new HashMap<>(); expectedMetadata.put( "bookmarks", value( bookmark.values() ) ); @@ -100,7 +100,7 @@ void shouldHaveCorrectMetadataForDatabaseName( String databaseName ) @Test void shouldNotHaveMetadataForDatabaseNameWhenIsNull() { - Map metadata = buildMetadata( null, null, defaultDatabase(), WRITE, null ); + Map metadata = buildMetadata( null, null, defaultDatabase(), WRITE, null, null ); assertTrue( metadata.isEmpty() ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java index d08248a7b9..14f6d3bcb2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java @@ -194,7 +194,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); - verify( connection ).writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + verify( connection ).writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -208,7 +208,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -220,7 +221,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -232,7 +234,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -499,7 +501,7 @@ private static ResponseHandlers verifyRunInvoked( Connection connection, boolean RunWithMetadataMessage expectedMessage; if ( session ) { - expectedMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, defaultDatabase(), mode, bookmark ); + expectedMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, defaultDatabase(), mode, bookmark, null ); } else { diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/MessageWriterV3Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/MessageWriterV3Test.java index c3e14b530b..bcfc4a843a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/MessageWriterV3Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/MessageWriterV3Test.java @@ -93,15 +93,15 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - defaultDatabase() ), + defaultDatabase(), null ), COMMIT, ROLLBACK, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), PULL_ALL, DISCARD_ALL, @@ -109,10 +109,10 @@ protected Stream supportedMessages() // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), + defaultDatabase(), READ, InternalBookmark.empty(), null ), autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), WRITE, InternalBookmark.empty() ), - unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) + defaultDatabase(), WRITE, InternalBookmark.empty(), null ), + unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java index 2040e9b4b0..ac26d0a2f0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java @@ -188,7 +188,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -202,7 +202,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -214,7 +215,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -226,7 +228,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -512,7 +514,7 @@ private ResponseHandler verifyTxRunInvoked( Connection connection ) private ResponseHandler verifySessionRunInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { - RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark ); + RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark, null ); return verifyRunInvoked( connection, runMessage ); } @@ -533,7 +535,7 @@ private ResponseHandler verifyRunInvoked( Connection connection, RunWithMetadata private void verifyBeginInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); - BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode, null ); verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/MessageWriterV4Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/MessageWriterV4Test.java index 7f7f95bd86..df7cdf858c 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/MessageWriterV4Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/MessageWriterV4Test.java @@ -101,25 +101,26 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - database( "foo" ) ), + database( "foo" ), null ), COMMIT, ROLLBACK, RESET, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), database( "foo" ), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), - autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), database( "foo" ), - WRITE, InternalBookmark.empty() ), - unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) + defaultDatabase(), READ, InternalBookmark.empty(), null ), + autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), + database( "foo" ), + WRITE, InternalBookmark.empty(), null ), + unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java index b600967b45..f722df6b78 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java @@ -193,7 +193,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -207,7 +207,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -219,7 +220,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -231,7 +233,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -506,7 +508,7 @@ private ResponseHandler verifyTxRunInvoked( Connection connection ) private ResponseHandler verifySessionRunInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { - RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark ); + RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark, null ); return verifyRunInvoked( connection, runMessage ); } @@ -527,7 +529,7 @@ private ResponseHandler verifyRunInvoked( Connection connection, RunWithMetadata private void verifyBeginInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); - BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode, null ); verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/MessageWriterV41Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/MessageWriterV41Test.java index 1f03b5375d..dc172824cf 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/MessageWriterV41Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/MessageWriterV41Test.java @@ -100,25 +100,25 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - database( "foo" ) ), + database( "foo" ), null ), COMMIT, ROLLBACK, RESET, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), database( "foo" ), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), + defaultDatabase(), READ, InternalBookmark.empty(), null ), autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), database( "foo" ), - WRITE, InternalBookmark.empty() ), + WRITE, InternalBookmark.empty(), null ), unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java index 85d8ce72e2..b163030c50 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java @@ -193,7 +193,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -207,7 +207,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -219,7 +220,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -231,7 +233,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -507,7 +509,7 @@ private ResponseHandler verifyTxRunInvoked( Connection connection ) private ResponseHandler verifySessionRunInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { - RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark ); + RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark, null ); return verifyRunInvoked( connection, runMessage ); } @@ -528,7 +530,7 @@ private ResponseHandler verifyRunInvoked( Connection connection, RunWithMetadata private void verifyBeginInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); - BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode, null ); verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/MessageWriterV42Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/MessageWriterV42Test.java index 927cb77156..4ce24f8fd6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/MessageWriterV42Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/MessageWriterV42Test.java @@ -100,25 +100,25 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - database( "foo" ) ), + database( "foo" ), null ), COMMIT, ROLLBACK, RESET, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), database( "foo" ), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), + defaultDatabase(), READ, InternalBookmark.empty(), null ), autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), database( "foo" ), - WRITE, InternalBookmark.empty() ), + WRITE, InternalBookmark.empty(), null ), unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java index 156eb02430..d8b2ee13c8 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java @@ -192,7 +192,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -206,7 +206,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -218,7 +219,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -230,7 +232,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -507,7 +509,7 @@ private ResponseHandler verifyTxRunInvoked( Connection connection ) private ResponseHandler verifySessionRunInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { - RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark ); + RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark, null ); return verifyRunInvoked( connection, runMessage ); } @@ -528,7 +530,7 @@ private ResponseHandler verifyRunInvoked( Connection connection, RunWithMetadata private void verifyBeginInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); - BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode, null ); verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/MessageWriterV43Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/MessageWriterV43Test.java index 880dcf2a57..c8eceeeecb 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/MessageWriterV43Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/MessageWriterV43Test.java @@ -105,25 +105,25 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - database( "foo" ) ), + database( "foo" ), null ), COMMIT, ROLLBACK, RESET, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), database( "foo" ), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), + defaultDatabase(), READ, InternalBookmark.empty(), null ), autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), database( "foo" ), - WRITE, InternalBookmark.empty() ), + WRITE, InternalBookmark.empty(), null ), unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ), // New 4.3 Messages @@ -147,6 +147,6 @@ private RouteMessage routeMessage() { Map routeContext = new HashMap<>(); routeContext.put( "someContext", Values.value( 124 ) ); - return new RouteMessage( routeContext, InternalBookmark.empty(), "dbName" ); + return new RouteMessage( routeContext, InternalBookmark.empty(), "dbName", null ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/BoltProtocolV44Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/BoltProtocolV44Test.java index aa277c66ca..09cd33c40e 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/BoltProtocolV44Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/BoltProtocolV44Test.java @@ -192,7 +192,7 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -206,7 +206,8 @@ void shouldBeginTransactionWithBookmarks() CompletionStage stage = protocol.beginTransaction( connection, bookmark, TransactionConfig.empty() ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( bookmark, TransactionConfig.empty(), defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -218,7 +219,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); verify( connection ) - .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE, null ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -230,7 +232,7 @@ void shouldBeginTransactionWithBookmarksAndConfig() CompletionStage stage = protocol.beginTransaction( connection, bookmark, txConfig ); - verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); + verify( connection ).writeAndFlush( eq( new BeginMessage( bookmark, txConfig, defaultDatabase(), WRITE, null ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -507,7 +509,7 @@ private ResponseHandler verifyTxRunInvoked( Connection connection ) private ResponseHandler verifySessionRunInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { - RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark ); + RunWithMetadataMessage runMessage = RunWithMetadataMessage.autoCommitTxRunMessage( QUERY, config, databaseName, mode, bookmark, null ); return verifyRunInvoked( connection, runMessage ); } @@ -528,7 +530,7 @@ private ResponseHandler verifyRunInvoked( Connection connection, RunWithMetadata private void verifyBeginInvoked( Connection connection, Bookmark bookmark, TransactionConfig config, AccessMode mode, DatabaseName databaseName ) { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); - BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); + BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode, null ); verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44Test.java index cdb827c3eb..75185916f4 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v44/MessageWriterV44Test.java @@ -105,25 +105,25 @@ protected Stream supportedMessages() new HelloMessage( "MyDriver/1.2.3", ((InternalAuthToken) basic( "neo4j", "neo4j" )).toMap(), Collections.emptyMap() ), GOODBYE, new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), READ, - defaultDatabase() ), + defaultDatabase(), null ), new BeginMessage( InternalBookmark.parse( "neo4j:bookmark:v1:tx123" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), WRITE, - database( "foo" ) ), + database( "foo" ), null ), COMMIT, ROLLBACK, RESET, autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), defaultDatabase(), READ, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), autoCommitTxRunMessage( new Query( "RETURN 1" ), ofSeconds( 5 ), singletonMap( "key", value( 42 ) ), database( "foo" ), WRITE, - InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ) ), + InternalBookmark.parse( "neo4j:bookmark:v1:tx1" ), null ), unmanagedTxRunMessage( new Query( "RETURN 1" ) ), // Bolt V3 messages with struct values autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), - defaultDatabase(), READ, InternalBookmark.empty() ), + defaultDatabase(), READ, InternalBookmark.empty(), null ), autoCommitTxRunMessage( new Query( "RETURN $x", singletonMap( "x", value( ZonedDateTime.now() ) ) ), ofSeconds( 1 ), emptyMap(), database( "foo" ), - WRITE, InternalBookmark.empty() ), + WRITE, InternalBookmark.empty(), null ), unmanagedTxRunMessage( new Query( "RETURN $x", singletonMap( "x", point( 42, 1, 2, 3 ) ) ) ), // New 4.3 Messages @@ -147,6 +147,6 @@ private RouteMessage routeMessage() { Map routeContext = new HashMap<>(); routeContext.put( "someContext", Values.value( 124 ) ); - return new RouteMessage( routeContext, InternalBookmark.empty(), "dbName" ); + return new RouteMessage( routeContext, InternalBookmark.empty(), "dbName", null ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ClusterCompositionUtil.java b/driver/src/test/java/org/neo4j/driver/internal/util/ClusterCompositionUtil.java index 86698f0bc6..aabbfb44bd 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ClusterCompositionUtil.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/ClusterCompositionUtil.java @@ -69,6 +69,6 @@ public static ClusterComposition createClusterComposition( long expirationTimest case 1: routers.addAll( servers[0] ); } - return new ClusterComposition( expirationTimestamp, readers, writers, routers ); + return new ClusterComposition( expirationTimestamp, readers, writers, routers, null ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java b/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java index e4f2976fb4..d0e880d5d0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java @@ -24,6 +24,8 @@ import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41; import org.neo4j.driver.internal.messaging.v42.BoltProtocolV42; +import org.neo4j.driver.internal.messaging.v43.BoltProtocolV43; +import org.neo4j.driver.internal.messaging.v44.BoltProtocolV44; import static java.lang.Integer.MAX_VALUE; import static org.hamcrest.MatcherAssert.assertThat; @@ -73,6 +75,8 @@ void shouldReturnCorrectServerVersionFromBoltProtocolVersion() assertEquals( ServerVersion.v4_0_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV4.VERSION ) ); assertEquals( ServerVersion.v4_1_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV41.VERSION ) ); assertEquals( ServerVersion.v4_2_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV42.VERSION ) ); + assertEquals( ServerVersion.v4_3_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV43.VERSION ) ); + assertEquals( ServerVersion.v4_4_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV44.VERSION ) ); assertEquals( ServerVersion.vInDev, ServerVersion.fromBoltProtocolVersion( new BoltProtocolVersion( MAX_VALUE, MAX_VALUE ) ) ); } } diff --git a/driver/src/test/java/org/neo4j/driver/stress/CausalClusteringIT.java b/driver/src/test/java/org/neo4j/driver/stress/CausalClusteringIT.java index d3a9027bc3..e274c0b062 100644 --- a/driver/src/test/java/org/neo4j/driver/stress/CausalClusteringIT.java +++ b/driver/src/test/java/org/neo4j/driver/stress/CausalClusteringIT.java @@ -26,6 +26,7 @@ import java.net.URI; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -39,6 +40,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import java.util.stream.Collectors; import org.neo4j.driver.AccessMode; import org.neo4j.driver.AuthToken; @@ -68,6 +70,7 @@ import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.internal.util.ThrowingMessageEncoder; import org.neo4j.driver.internal.util.io.ChannelTrackingDriverFactory; +import org.neo4j.driver.net.ServerAddress; import org.neo4j.driver.summary.ResultSummary; import org.neo4j.driver.util.cc.Cluster; import org.neo4j.driver.util.cc.ClusterExtension; @@ -331,8 +334,18 @@ void shouldHandleGracefulLeaderSwitch() throws Exception { Cluster cluster = clusterRule.getCluster(); ClusterMember leader = cluster.leader(); + ServerAddress clusterAddress = ServerAddress.of( "cluster", 7687 ); + URI clusterUri = URI.create( String.format( "neo4j://%s:%d", clusterAddress.host(), clusterAddress.port() ) ); + Set coreAddresses = cluster.cores().stream() + .map( ClusterMember::getBoltAddress ) + .collect( Collectors.toSet() ); - try ( Driver driver = createDriver( leader.getRoutingUri() ) ) + Config config = Config.builder() + .withLogging( none() ) + .withResolver( address -> address.equals( clusterAddress ) ? coreAddresses : Collections.singleton( address ) ) + .build(); + + try ( Driver driver = GraphDatabase.driver( clusterUri, clusterRule.getDefaultAuthToken(), config ) ) { Session session1 = driver.session(); Transaction tx1 = session1.beginTransaction(); @@ -357,7 +370,8 @@ void shouldHandleGracefulLeaderSwitch() throws Exception return session.lastBookmark(); } ); - try ( Session session2 = driver.session( builder().withDefaultAccessMode( AccessMode.READ ).withBookmarks( bookmark ).build() ); + try ( Session session2 = driver.session( + builder().withDefaultAccessMode( AccessMode.READ ).withBookmarks( bookmark ).build() ); Transaction tx2 = session2.beginTransaction() ) { Record record = tx2.run( "MATCH (n:Person) RETURN COUNT(*) AS count" ).next(); @@ -538,20 +552,25 @@ void shouldRespectMaxConnectionPoolSizePerClusterMember() try ( Driver driver = createDriver( leader.getRoutingUri(), config ) ) { - Session writeSession1 = driver.session( builder().withDefaultAccessMode( AccessMode.WRITE ).build() ); + String database = "neo4j"; + Session writeSession1 = + driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.WRITE ).build() ); writeSession1.beginTransaction(); - Session writeSession2 = driver.session( builder().withDefaultAccessMode( AccessMode.WRITE ).build() ); + Session writeSession2 = + driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.WRITE ).build() ); writeSession2.beginTransaction(); // should not be possible to acquire more connections towards leader because limit is 2 - Session writeSession3 = driver.session( builder().withDefaultAccessMode( AccessMode.WRITE ).build() ); + Session writeSession3 = + driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.WRITE ).build() ); ClientException e = assertThrows( ClientException.class, writeSession3::beginTransaction ); assertThat( e, is( connectionAcquisitionTimeoutError( 42 ) ) ); // should be possible to acquire new connection towards read server // it's a different machine, not leader, so different max connection pool size limit applies - Session readSession = driver.session( builder().withDefaultAccessMode( AccessMode.READ ).build() ); + Session readSession = + driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.READ ).build() ); Record record = readSession.readTransaction( tx -> tx.run( "RETURN 1" ).single() ); assertEquals( 1, record.get( 0 ).asInt() ); } @@ -610,7 +629,8 @@ void shouldRediscoverWhenConnectionsToAllCoresBreak() try ( Driver driver = driverFactory.newInstance( leader.getRoutingUri(), clusterRule.getDefaultAuthToken(), RoutingSettings.DEFAULT, RetrySettings.DEFAULT, configWithoutLogging(), SecurityPlanImpl.insecure() ) ) { - try ( Session session = driver.session() ) + String database = "neo4j"; + try ( Session session = driver.session( builder().withDatabase( database ).build() ) ) { createNode( session, "Person", "name", "Vision" ); @@ -626,10 +646,10 @@ RoutingSettings.DEFAULT, RetrySettings.DEFAULT, configWithoutLogging(), Security makeAllChannelsFailToRunQueries( driverFactory, ServerVersion.version( driver ) ); // observe that connection towards writer is broken - try ( Session session = driver.session( builder().withDefaultAccessMode( AccessMode.WRITE ).build() ) ) + try ( Session session = driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.WRITE ).build() ) ) { SessionExpiredException e = assertThrows( SessionExpiredException.class, - () -> runCreateNode( session, "Person", "name", "Vision" ).consume() ); + () -> runCreateNode( session, "Person", "name", "Vision" ).consume() ); assertEquals( "Disconnected", e.getCause().getMessage() ); } @@ -637,7 +657,7 @@ RoutingSettings.DEFAULT, RetrySettings.DEFAULT, configWithoutLogging(), Security int readersCount = cluster.followers().size() + cluster.readReplicas().size(); for ( int i = 0; i < readersCount; i++ ) { - try ( Session session = driver.session( builder().withDefaultAccessMode( AccessMode.READ ).build() ) ) + try ( Session session = driver.session( builder().withDatabase( database ).withDefaultAccessMode( AccessMode.READ ).build() ) ) { runCountNodes( session, "Person", "name", "Vision" ); } @@ -646,7 +666,7 @@ RoutingSettings.DEFAULT, RetrySettings.DEFAULT, configWithoutLogging(), Security } } - try ( Session session = driver.session() ) + try ( Session session = driver.session( builder().withDatabase( database ).build() ) ) { updateNode( session, "Person", "name", "Vision", "Thanos" ); assertEquals( 0, countNodes( session, "Person", "name", "Vision" ) ); diff --git a/driver/src/test/java/org/neo4j/driver/util/TestUtil.java b/driver/src/test/java/org/neo4j/driver/util/TestUtil.java index 1dd3231748..0df604d59c 100644 --- a/driver/src/test/java/org/neo4j/driver/util/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/util/TestUtil.java @@ -316,7 +316,7 @@ public static NetworkSession newSession( ConnectionProvider connectionProvider ) public static NetworkSession newSession( ConnectionProvider connectionProvider, AccessMode mode, RetryLogic retryLogic, Bookmark bookmark ) { - return new NetworkSession( connectionProvider, retryLogic, defaultDatabase(), mode, new DefaultBookmarkHolder( bookmark ), UNLIMITED_FETCH_SIZE, + return new NetworkSession( connectionProvider, retryLogic, defaultDatabase(), mode, new DefaultBookmarkHolder( bookmark ), null, UNLIMITED_FETCH_SIZE, DEV_NULL_LOGGING ); } diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java index bd5d0352ea..3d2ce940d0 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java @@ -36,7 +36,7 @@ import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.exceptions.UntrustedServerException; -import org.neo4j.driver.internal.async.pool.ConnectionPoolImpl; +import org.neo4j.driver.internal.spi.ConnectionPool; public class TestkitRequestProcessorHandler extends ChannelInboundHandlerAdapter { @@ -154,7 +154,7 @@ else if ( isConnectionPoolClosedException( throwable ) || throwable instanceof U private boolean isConnectionPoolClosedException( Throwable throwable ) { return throwable instanceof IllegalStateException && throwable.getMessage() != null && - throwable.getMessage().equals( ConnectionPoolImpl.CONNECTION_POOL_CLOSED_ERROR_MESSAGE ); + throwable.getMessage().equals( ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE ); } private void writeAndFlush( TestkitResponse response ) diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java index bf9544d4ca..a8e5dfe063 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/GetFeatures.java @@ -42,7 +42,9 @@ public class GetFeatures implements TestkitRequest "Temporary:DriverMaxTxRetryTime", "Feature:Auth:Bearer", "Feature:Auth:Kerberos", - "Feature:Auth:Custom" + "Feature:Auth:Custom", + "Feature:Bolt:4.4", + "Feature:Impersonation" ) ); private static final Set SYNC_FEATURES = new HashSet<>( Arrays.asList( diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/NewSession.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/NewSession.java index 020bfb93e3..299e456d9a 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/NewSession.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/NewSession.java @@ -79,6 +79,7 @@ protected TestkitResponse createSessionStateAndResponse( TestkitState testki .ifPresent( builder::withBookmarks ); Optional.ofNullable( data.database ).ifPresent( builder::withDatabase ); + Optional.ofNullable( data.impersonatedUser ).ifPresent( builder::withImpersonatedUser ); if ( data.getFetchSize() != 0 ) { @@ -114,6 +115,7 @@ public static class NewSessionBody private String accessMode; private List bookmarks; private String database; + private String impersonatedUser; private int fetchSize; } }