Skip to content

Commit

Permalink
Introduce BEGIN message pipelining in ExecutableQuery
Browse files Browse the repository at this point in the history
This is a communication optimization.
  • Loading branch information
injectives committed Aug 22, 2023
1 parent c3b10e3 commit 0c92a88
Show file tree
Hide file tree
Showing 22 changed files with 183 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.util.Map;
import java.util.stream.Collector;
import org.neo4j.driver.AccessMode;
import org.neo4j.driver.Driver;
import org.neo4j.driver.ExecutableQuery;
import org.neo4j.driver.Query;
Expand All @@ -30,6 +31,7 @@
import org.neo4j.driver.RoutingControl;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.TransactionCallback;
import org.neo4j.driver.TransactionConfig;

public class InternalExecutableQuery implements ExecutableQuery {
private final Driver driver;
Expand Down Expand Up @@ -67,7 +69,7 @@ public <A, R, T> T execute(Collector<Record, A, R> recordCollector, ResultFinish
var supplier = recordCollector.supplier();
var accumulator = recordCollector.accumulator();
var finisher = recordCollector.finisher();
try (var session = driver.session(sessionConfigBuilder.build())) {
try (var session = (InternalSession) driver.session(sessionConfigBuilder.build())) {
TransactionCallback<T> txCallback = tx -> {
var result = tx.run(query);
var container = supplier.get();
Expand All @@ -78,9 +80,8 @@ public <A, R, T> T execute(Collector<Record, A, R> recordCollector, ResultFinish
var summary = result.consume();
return resultFinisher.finish(result.keys(), finishedValue, summary);
};
return config.routing().equals(RoutingControl.READ)
? session.executeRead(txCallback)
: session.executeWrite(txCallback);
var accessMode = config.routing().equals(RoutingControl.WRITE) ? AccessMode.WRITE : AccessMode.READ;
return session.execute(accessMode, txCallback, TransactionConfig.empty(), false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ public <T> T readTransaction(TransactionWork<T> work) {
@Override
@Deprecated
public <T> T readTransaction(TransactionWork<T> work, TransactionConfig config) {
return transaction(AccessMode.READ, work, config);
return transaction(AccessMode.READ, work, config, true);
}

@Override
public <T> T executeRead(TransactionCallback<T> callback, TransactionConfig config) {
return readTransaction(tx -> callback.execute(new DelegatingTransactionContext(tx)), config);
return execute(AccessMode.READ, callback, config, true);
}

@Override
Expand All @@ -124,12 +124,12 @@ public <T> T writeTransaction(TransactionWork<T> work) {
@Override
@Deprecated
public <T> T writeTransaction(TransactionWork<T> work, TransactionConfig config) {
return transaction(AccessMode.WRITE, work, config);
return transaction(AccessMode.WRITE, work, config, true);
}

@Override
public <T> T executeWrite(TransactionCallback<T> callback, TransactionConfig config) {
return writeTransaction(tx -> callback.execute(new DelegatingTransactionContext(tx)), config);
return execute(AccessMode.WRITE, callback, config, true);
}

@Override
Expand All @@ -151,14 +151,21 @@ public void reset() {
() -> terminateConnectionOnThreadInterrupt("Thread interrupted while resetting the session"));
}

<T> T execute(AccessMode accessMode, TransactionCallback<T> callback, TransactionConfig config, boolean flush) {
return transaction(accessMode, tx -> callback.execute(new DelegatingTransactionContext(tx)), config, flush);
}

private <T> T transaction(
AccessMode mode, @SuppressWarnings("deprecation") TransactionWork<T> work, TransactionConfig config) {
AccessMode mode,
@SuppressWarnings("deprecation") TransactionWork<T> work,
TransactionConfig config,
boolean flush) {
// use different code path compared to async so that work is executed in the caller thread
// caller thread will also be the one who sleeps between retries;
// it is unsafe to execute retries in the event loop threads because this can cause a deadlock
// event loop thread will bock and wait for itself to read some data
return session.retryLogic().retry(() -> {
try (var tx = beginTransaction(mode, config)) {
try (var tx = beginTransaction(mode, config, flush)) {

var result = work.execute(tx);
if (result instanceof Result) {
Expand All @@ -175,9 +182,9 @@ private <T> T transaction(
});
}

private Transaction beginTransaction(AccessMode mode, TransactionConfig config) {
private Transaction beginTransaction(AccessMode mode, TransactionConfig config, boolean flush) {
var tx = Futures.blockingGet(
session.beginTransactionAsync(mode, config),
session.beginTransactionAsync(mode, config, null, flush),
() -> terminateConnectionOnThreadInterrupt("Thread interrupted while starting a transaction"));
return new InternalTransaction(tx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,19 @@ public CompletionStage<RxResultCursor> runRx(
}

public CompletionStage<UnmanagedTransaction> beginTransactionAsync(TransactionConfig config) {
return beginTransactionAsync(mode, config, null);
return beginTransactionAsync(mode, config, null, true);
}

public CompletionStage<UnmanagedTransaction> beginTransactionAsync(TransactionConfig config, String txType) {
return this.beginTransactionAsync(mode, config, txType);
return this.beginTransactionAsync(mode, config, txType, true);
}

public CompletionStage<UnmanagedTransaction> beginTransactionAsync(AccessMode mode, TransactionConfig config) {
return beginTransactionAsync(mode, config, null);
return beginTransactionAsync(mode, config, null, true);
}

public CompletionStage<UnmanagedTransaction> beginTransactionAsync(
AccessMode mode, TransactionConfig config, String txType) {
AccessMode mode, TransactionConfig config, String txType, boolean flush) {
ensureSessionIsOpen();

// create a chain that acquires connection and starts a transaction
Expand All @@ -150,7 +150,7 @@ public CompletionStage<UnmanagedTransaction> beginTransactionAsync(
.thenCompose(connection -> {
var tx = new UnmanagedTransaction(
connection, this::handleNewBookmark, fetchSize, notificationConfig, logging);
return tx.beginAsync(determineBookmarks(true), config, txType);
return tx.beginAsync(determineBookmarks(true), config, txType, flush);
});

// update the reference to the only known transaction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private enum State {
private Throwable causeOfTermination;
private CompletionStage<Void> terminationStage;
private final NotificationConfig notificationConfig;
private final CompletableFuture<UnmanagedTransaction> beginFuture = new CompletableFuture<>();
private final Logging logging;

public UnmanagedTransaction(
Expand Down Expand Up @@ -128,9 +129,10 @@ protected UnmanagedTransaction(
connection.bindTerminationAwareStateLockingExecutor(this);
}

// flush = false is only supported for async mode with a single subsequent run
public CompletionStage<UnmanagedTransaction> beginAsync(
Set<Bookmark> initialBookmarks, TransactionConfig config, String txType) {
return protocol.beginTransaction(connection, initialBookmarks, config, txType, notificationConfig, logging)
Set<Bookmark> initialBookmarks, TransactionConfig config, String txType, boolean flush) {
protocol.beginTransaction(connection, initialBookmarks, config, txType, notificationConfig, logging, flush)
.handle((ignore, beginError) -> {
if (beginError != null) {
if (beginError instanceof AuthorizationExpiredException) {
Expand All @@ -143,7 +145,9 @@ public CompletionStage<UnmanagedTransaction> beginAsync(
throw asCompletionException(beginError);
}
return this;
});
})
.whenComplete(futureCompletingConsumer(beginFuture));
return flush ? beginFuture : CompletableFuture.completedFuture(this);
}

public CompletionStage<Void> closeAsync() {
Expand All @@ -167,9 +171,9 @@ public CompletionStage<ResultCursor> runAsync(Query query) {
var cursorStage = protocol.runInUnmanagedTransaction(connection, query, this, fetchSize)
.asyncResult();
resultCursors.add(cursorStage);
return cursorStage
return beginFuture.thenCompose(ignored -> cursorStage
.thenCompose(AsyncResultCursor::mapSuccessfulRunCompletionAsync)
.thenApply(Function.identity());
.thenApply(Function.identity()));
}

public CompletionStage<RxResultCursor> runRx(Query query) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@
package org.neo4j.driver.internal.handlers;

import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionId;
import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionReadTimeout;
import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setServerAgent;
import static org.neo4j.driver.internal.util.MetadataExtractor.extractServer;

import io.netty.channel.Channel;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import org.neo4j.driver.Value;
import org.neo4j.driver.internal.spi.ResponseHandler;

public class HelloV51ResponseHandler implements ResponseHandler {
private static final String CONNECTION_ID_METADATA_KEY = "connection_id";
public static final String CONFIGURATION_HINTS_KEY = "hints";
public static final String CONNECTION_RECEIVE_TIMEOUT_SECONDS_KEY = "connection.recv_timeout_seconds";

private final Channel channel;
private final CompletableFuture<Void> helloFuture;
Expand All @@ -48,6 +53,8 @@ public void onSuccess(Map<String, Value> metadata) {
var connectionId = extractConnectionId(metadata);
setConnectionId(channel, connectionId);

processConfigurationHints(metadata);

helloFuture.complete(null);
} catch (Throwable error) {
onFailure(error);
Expand All @@ -65,6 +72,16 @@ public void onRecord(Value[] fields) {
throw new UnsupportedOperationException();
}

private void processConfigurationHints(Map<String, Value> metadata) {
var configurationHints = metadata.get(CONFIGURATION_HINTS_KEY);
if (configurationHints != null) {
getFromSupplierOrEmptyOnException(() -> configurationHints
.get(CONNECTION_RECEIVE_TIMEOUT_SECONDS_KEY)
.asLong())
.ifPresent(timeout -> setConnectionReadTimeout(channel, timeout));
}
}

private static String extractConnectionId(Map<String, Value> metadata) {
var value = metadata.get(CONNECTION_ID_METADATA_KEY);
if (value == null || value.isNull()) {
Expand All @@ -73,4 +90,12 @@ private static String extractConnectionId(Map<String, Value> metadata) {
}
return value.asString();
}

private static <T> Optional<T> getFromSupplierOrEmptyOnException(Supplier<T> supplier) {
try {
return Optional.of(supplier.get());
} catch (Exception e) {
return Optional.empty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void initializeChannel(
* @param txType the Kernel transaction type
* @param notificationConfig the notification configuration
* @param logging the driver logging
* @param flush defines whether to flush the message to the connection
* @return a completion stage completed when transaction is started or completed exceptionally when there was a failure.
*/
CompletionStage<Void> beginTransaction(
Expand All @@ -103,7 +104,8 @@ CompletionStage<Void> beginTransaction(
TransactionConfig config,
String txType,
NotificationConfig notificationConfig,
Logging logging);
Logging logging,
boolean flush);

/**
* Commit the unmanaged transaction.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ public CompletionStage<Void> beginTransaction(
TransactionConfig config,
String txType,
NotificationConfig notificationConfig,
Logging logging) {
Logging logging,
boolean flush) {
var exception = verifyNotificationConfigSupported(notificationConfig);
if (exception != null) {
return CompletableFuture.failedStage(exception);
Expand All @@ -158,7 +159,12 @@ public CompletionStage<Void> beginTransaction(
txType,
notificationConfig,
logging);
connection.writeAndFlush(beginMessage, new BeginTxResponseHandler(beginTxFuture));
var handler = new BeginTxResponseHandler(beginTxFuture);
if (flush) {
connection.writeAndFlush(beginMessage, handler);
} else {
connection.write(beginMessage, handler);
}
return beginTxFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.mock;
Expand All @@ -29,13 +30,13 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collector;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.neo4j.driver.AccessMode;
import org.neo4j.driver.BookmarkManager;
import org.neo4j.driver.Driver;
import org.neo4j.driver.ExecutableQuery;
Expand All @@ -44,9 +45,9 @@
import org.neo4j.driver.Record;
import org.neo4j.driver.Result;
import org.neo4j.driver.RoutingControl;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.TransactionCallback;
import org.neo4j.driver.TransactionConfig;
import org.neo4j.driver.TransactionContext;
import org.neo4j.driver.summary.ResultSummary;

Expand Down Expand Up @@ -126,15 +127,15 @@ void shouldExecuteAndReturnResult(RoutingControl routingControl) {
var driver = mock(Driver.class);
var bookmarkManager = mock(BookmarkManager.class);
given(driver.executableQueryBookmarkManager()).willReturn(bookmarkManager);
var session = mock(Session.class);
var session = mock(InternalSession.class);
given(driver.session(any(SessionConfig.class))).willReturn(session);
var txContext = mock(TransactionContext.class);
BiFunction<Session, TransactionCallback<Object>, Object> executeMethod =
routingControl.equals(RoutingControl.READ) ? Session::executeRead : Session::executeWrite;
given(executeMethod.apply(session, any())).willAnswer(answer -> {
TransactionCallback<?> txCallback = answer.getArgument(0);
return txCallback.execute(txContext);
});
var accessMode = routingControl.equals(RoutingControl.WRITE) ? AccessMode.WRITE : AccessMode.READ;
given(session.execute(eq(accessMode), any(), eq(TransactionConfig.empty()), eq(false)))
.willAnswer(answer -> {
TransactionCallback<?> txCallback = answer.getArgument(1);
return txCallback.execute(txContext);
});
var result = mock(Result.class);
given(txContext.run(any(Query.class))).willReturn(result);
var keys = List.of("key");
Expand Down Expand Up @@ -180,7 +181,7 @@ var record = mock(Record.class);
.withBookmarkManager(bookmarkManager)
.build();
assertEquals(expectedSessionConfig, sessionConfig);
executeMethod.apply(then(session).should(), any(TransactionCallback.class));
then(session).should().execute(eq(accessMode), any(), eq(TransactionConfig.empty()), eq(false));
then(txContext).should().run(query.withParameters(params));
then(result).should(times(2)).hasNext();
then(result).should().next();
Expand Down
Loading

0 comments on commit 0c92a88

Please sign in to comment.