Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deserialize responses on the handling thread-pool #91367

Merged
merged 6 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/91367.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 91367
summary: Deserialize responses on the handling thread-pool
area: Network
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public InboundMessage finishAggregation() throws IOException {
checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
}
if (isShortCircuited()) {
aggregated.close();
aggregated.decRef();
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
Expand All @@ -130,7 +130,7 @@ public InboundMessage finishAggregation() throws IOException {
} finally {
resetCurrentAggregation();
if (success == false) {
aggregated.close();
aggregated.decRef();
}
}
}
Expand Down
110 changes: 71 additions & 39 deletions server/src/main/java/org/elasticsearch/transport/InboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.elasticsearch.common.network.HandlingTimeTracker;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;

Expand Down Expand Up @@ -133,34 +135,17 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
}
// ignore if its null, the service logs it
if (responseHandler != null) {
final StreamInput streamInput;
if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) {
streamInput = namedWriteableStream(message.openOrGetStreamInput());
final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, responseHandler);
handlerResponseError(streamInput, message, responseHandler);
} else {
handleResponse(remoteAddress, streamInput, responseHandler);
}
// Check the entire message has been read
final int nextByte = streamInput.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
final IllegalStateException exception = new IllegalStateException(
"Message not fully read (response) for requestId ["
+ requestId
+ "], handler ["
+ responseHandler
+ "], error ["
+ header.isError()
+ "]; resetting"
);
assert ignoreDeserializationErrors : exception;
throw exception;
handleResponse(remoteAddress, streamInput, responseHandler, message);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler);
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
}
}
}
Expand Down Expand Up @@ -189,6 +174,26 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
}
}

private void verifyResponseReadFully(Header header, TransportResponseHandler<?> responseHandler, StreamInput streamInput)
throws IOException {
// Check the entire message has been read
final int nextByte = streamInput.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
final IllegalStateException exception = new IllegalStateException(
"Message not fully read (response) for requestId ["
+ header.getRequestId()
+ "], handler ["
+ responseHandler
+ "], error ["
+ header.isError()
+ "]; resetting"
);
assert ignoreDeserializationErrors : exception;
throw exception;
}
}

private <T extends TransportRequest> void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException {
final String action = header.getActionName();
final long requestId = header.getRequestId();
Expand Down Expand Up @@ -335,10 +340,49 @@ private static void sendErrorResponse(String actionName, TransportChannel transp
private <T extends TransportResponse> void handleResponse(
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler
final TransportResponseHandler<T> handler,
final InboundMessage inboundMessage
) {
final String executor = handler.executor();
if (ThreadPool.Names.SAME.equals(executor)) {
// no need to provide a buffer release here, we never escape the buffer when handling directly
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
} else {
inboundMessage.incRef();
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
@Override
protected void doRun() {
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I don't totally understand why we need to pass the releaseBuffer mechanism into the method here.

onAfter already handles the release. I'm not totally clear why it matters if the doHandleResponse method clearly releases the thing. It's already being released in onAfter no matter what.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation here was to release the buffer asap and not needlessly hold on to it until the handler is is done with the deserialized message. The onAfter was just put in place as a final fail-safe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact we should always release it in doHandleResponse, a response handler should never be rejected (see assertions in ForkingResponseHandlerRunnable) and there's no chance an exception could prevent it either.

That said, I'm 👍 on paranoid leak prevention.

}

@Override
public void onAfter() {
Releasables.closeExpectNoException(releaseBuffer);
}
});
}
}

/**
*
* @param handler response handler
* @param remoteAddress remote address that the message was sent from
* @param stream bytes stream for reading the message
* @param header message header
* @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
* @param <T> response message type
*/
private <T extends TransportResponse> void doHandleResponse(
TransportResponseHandler<T> handler,
InetSocketAddress remoteAddress,
final StreamInput stream,
final Header header,
Releasable releaseResponseBuffer
) {
final T response;
try {
try (releaseResponseBuffer) {
response = handler.read(stream);
response.remoteAddress(remoteAddress);
} catch (Exception e) {
Expand All @@ -348,24 +392,11 @@ private <T extends TransportResponse> void handleResponse(
);
logger.warn(() -> "Failed to deserialize response from [" + remoteAddress + "]", serializationException);
assert ignoreDeserializationErrors : e;
handleException(handler, serializationException);
doHandleException(handler, serializationException);
return;
}
final String executor = handler.executor();
if (ThreadPool.Names.SAME.equals(executor)) {
doHandleResponse(handler, response);
} else {
threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
@Override
protected void doRun() {
doHandleResponse(handler, response);
}
});
}
}

private static <T extends TransportResponse> void doHandleResponse(TransportResponseHandler<T> handler, T response) {
try {
verifyResponseReadFully(header, handler, stream);
handler.handleResponse(response);
} catch (Exception e) {
doHandleException(handler, new ResponseHandlerFailureTransportException(e));
Expand All @@ -374,10 +405,11 @@ private static <T extends TransportResponse> void doHandleResponse(TransportResp
}
}

private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
Exception error;
try {
error = stream.readException();
verifyResponseReadFully(message.getHeader(), handler, stream);
} catch (Exception e) {
error = new TransportSerializationException(
"Failed to deserialize exception response from stream for handler [" + handler + "]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;

import java.io.IOException;
import java.util.Objects;

public class InboundMessage implements Releasable {
public class InboundMessage extends AbstractRefCounted {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we assert that ref count is greater than 0 when openOrGetStreamInput is called?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ done


private final Header header;
private final ReleasableBytesReference content;
Expand Down Expand Up @@ -82,6 +83,7 @@ public Releasable takeBreakerReleaseControl() {

public StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
assert hasReferences();
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setVersion(header.getVersion());
Expand All @@ -90,17 +92,17 @@ public StreamInput openOrGetStreamInput() throws IOException {
}

@Override
public void close() {
public String toString() {
return "InboundMessage{" + header + "}";
}

@Override
protected void closeInternal() {
try {
IOUtils.close(streamInput, content, breakerRelease);
} catch (Exception e) {
assert false : e;
throw new ElasticsearchException(e);
}
}

@Override
public String toString() {
return "InboundMessage{" + header + "}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,12 @@ private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) t
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (InboundMessage aggregated = aggregator.finishAggregation()) {
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
} finally {
aggregated.decRef();
}
} else {
assert aggregator.isAggregating();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public void testInboundAggregation() throws IOException {
for (ReleasableBytesReference reference : references) {
assertTrue(reference.hasReferences());
}
aggregated.close();
aggregated.decRef();
for (ReleasableBytesReference reference : references) {
assertFalse(reference.hasReferences());
}
Expand Down