Skip to content

Commit

Permalink
Change abstraction point for transport protocol
Browse files Browse the repository at this point in the history
The previous implementation had a transport switch point in
InboundPipeline when the bytes were initially pulled off the wire. There
was no implementation for any other protocol as the `canHandleBytes`
method was hardcoded to return true. I believe this is the wrong point
to switch on the protocol. This change makes NativeInboundBytesHandler
protocol agnostic beyond the header. With this change, a complete
message is parsed from the stream of bytes, with the header schema being
unchanged from what exists today. The protocol switch point will now be
at `InboundHandler::inboundMessage`. The header will indicate what
protocol was used to serialize the the non-header bytes of the message
and then invoke the appropriate handler based on that field.

Signed-off-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
andrross committed Aug 28, 2024
1 parent df86871 commit f6d5ce9
Show file tree
Hide file tree
Showing 15 changed files with 352 additions and 366 deletions.
10 changes: 9 additions & 1 deletion server/src/main/java/org/opensearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class Header {

private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";

private final TransportProtocol protocol;
private final int networkMessageSize;
private final Version version;
private final long requestId;
Expand All @@ -64,13 +65,18 @@ public class Header {
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;

Header(int networkMessageSize, long requestId, byte status, Version version) {
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
this.protocol = protocol;
this.networkMessageSize = networkMessageSize;
this.version = version;
this.requestId = requestId;
this.status = status;
}

TransportProtocol getTransportProtocol() {
return protocol;
}

public int getNetworkMessageSize() {
return networkMessageSize;
}
Expand Down Expand Up @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
@Override
public String toString() {
return "Header{"
+ protocol
+ "}{"
+ networkMessageSize
+ "}{"
+ version
Expand Down
137 changes: 126 additions & 11 deletions server/src/main/java/org/opensearch/transport/InboundBytesHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,139 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes, using {@link InboundDecoder} to decode headers
* and {@link InboundAggregator} to assemble complete messages to forward to
* the given message handler to parse the message payload.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

public boolean canHandleBytes(ReleasableBytesReference reference);
private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, InboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (InboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
// exposed for use in tests
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
Version remoteVersion = Version.fromId(streamInput.readInt());
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
if (invalidVersion != null) {
throw invalidVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.Map;
Expand All @@ -56,7 +55,7 @@ public class InboundHandler {

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Expand All @@ -75,7 +74,7 @@ public class InboundHandler {
) {
this.threadPool = threadPool;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
TransportProtocol.NATIVE,
new NativeMessageHandler(
nodeName,
version,
Expand Down Expand Up @@ -107,16 +106,16 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) {
this.slowLogThresholdMs = slowLogThreshold.getMillis();
}

void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception {
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
final long startTime = threadPool.relativeTimeInMillis();
channel.getChannelStats().markAccessed(startTime);
messageReceivedFromPipeline(channel, message, startTime);
}

private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
if (protocolMessageHandler == null) {
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());

Check warning on line 118 in server/src/main/java/org/opensearch/transport/InboundHandler.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundHandler.java#L118

Added line #L118 was not covered by tests
}
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
}
Expand Down
89 changes: 65 additions & 24 deletions server/src/main/java/org/opensearch/transport/InboundMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,77 +32,118 @@

package org.opensearch.transport;

import org.opensearch.common.annotation.DeprecatedApi;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;

/**
* Inbound data as a message
* This api is deprecated, please use {@link org.opensearch.transport.nativeprotocol.NativeInboundMessage} instead.
* @opensearch.api
*/
@DeprecatedApi(since = "2.14.0")
@PublicApi(since = "1.0.0")
public class InboundMessage implements Releasable, ProtocolInboundMessage {

private final NativeInboundMessage nativeInboundMessage;
static final InboundMessage PING = new InboundMessage(null, null, null, true, null);

protected final Header header;
protected final ReleasableBytesReference content;
protected final Exception exception;
protected final boolean isPing;
private Releasable breakerRelease;
private StreamInput streamInput;

public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.nativeInboundMessage = new NativeInboundMessage(header, content, breakerRelease);
this(header, content, null, false, breakerRelease);
}

public InboundMessage(Header header, Exception exception) {
this.nativeInboundMessage = new NativeInboundMessage(header, exception);
this(header, null, exception, false, null);
}

public InboundMessage(Header header, boolean isPing) {
this.nativeInboundMessage = new NativeInboundMessage(header, isPing);
this(header, null, null, isPing, null);
}

Check warning on line 69 in server/src/main/java/org/opensearch/transport/InboundMessage.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundMessage.java#L68-L69

Added lines #L68 - L69 were not covered by tests

private InboundMessage(
Header header,
ReleasableBytesReference content,
Exception exception,
boolean isPing,
Releasable breakerRelease
) {
this.header = header;
this.content = content;
this.exception = exception;
this.isPing = isPing;
this.breakerRelease = breakerRelease;
}

TransportProtocol getTransportProtocol() {
if (isPing) {
return TransportProtocol.NATIVE;
}
return header.getTransportProtocol();
}

public String getProtocol() {
return header.getTransportProtocol().toString();

Check warning on line 93 in server/src/main/java/org/opensearch/transport/InboundMessage.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundMessage.java#L93

Added line #L93 was not covered by tests
}

public Header getHeader() {
return this.nativeInboundMessage.getHeader();
return header;
}

public int getContentLength() {
return this.nativeInboundMessage.getContentLength();
if (content == null) {
return 0;

Check warning on line 102 in server/src/main/java/org/opensearch/transport/InboundMessage.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundMessage.java#L102

Added line #L102 was not covered by tests
} else {
return content.length();
}
}

public Exception getException() {
return this.nativeInboundMessage.getException();
return exception;
}

public boolean isPing() {
return this.nativeInboundMessage.isPing();
return isPing;
}

public boolean isShortCircuit() {
return this.nativeInboundMessage.getException() != null;
return exception != null;
}

public Releasable takeBreakerReleaseControl() {
return this.nativeInboundMessage.takeBreakerReleaseControl();
final Releasable toReturn = breakerRelease;
breakerRelease = null;
if (toReturn != null) {
return toReturn;
} else {
return () -> {};
}
}

public StreamInput openOrGetStreamInput() throws IOException {
return this.nativeInboundMessage.openOrGetStreamInput();
assert isPing == false && content != null;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setVersion(header.getVersion());
}
return streamInput;
}

@Override
public void close() {
this.nativeInboundMessage.close();
IOUtils.closeWhileHandlingException(streamInput);
Releasables.closeWhileHandlingException(content, breakerRelease);
}

@Override
public String toString() {
return this.nativeInboundMessage.toString();
}

@Override
public String getProtocol() {
return this.nativeInboundMessage.getProtocol();
return "InboundMessage{" + header + "}";
}

}
Loading

0 comments on commit f6d5ce9

Please sign in to comment.