Skip to content

Commit

Permalink
Change abstraction point for transport protocol (opensearch-project#1…
Browse files Browse the repository at this point in the history
…5432)

* Revert "Replacing InboundMessage with NativeInboundMessage for deprecation (opensearch-project#13126)"

This reverts commit f5c3ef9.

Signed-off-by: Andrew Ross <andrross@amazon.com>

* Change abstraction point for transport protocol

The previous implementation had a transport switch point in
InboundPipeline when the bytes were initially pulled off the wire. There
was no implementation for any other protocol as the `canHandleBytes`
method was hardcoded to return true. I believe this is the wrong point
to switch on the protocol. This change makes NativeInboundBytesHandler
protocol agnostic beyond the header. With this change, a complete
message is parsed from the stream of bytes, with the header schema being
unchanged from what exists today. The protocol switch point will now be
at `InboundHandler::inboundMessage`. The header will indicate what
protocol was used to serialize the the non-header bytes of the message
and then invoke the appropriate handler based on that field.

Signed-off-by: Andrew Ross <andrross@amazon.com>

---------

Signed-off-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
andrross authored and dk2k committed Oct 17, 2024
1 parent f3a30d5 commit b4ec9e9
Show file tree
Hide file tree
Showing 18 changed files with 467 additions and 402 deletions.
10 changes: 9 additions & 1 deletion server/src/main/java/org/opensearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class Header {

private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";

private final TransportProtocol protocol;
private final int networkMessageSize;
private final Version version;
private final long requestId;
Expand All @@ -64,13 +65,18 @@ public class Header {
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;

Header(int networkMessageSize, long requestId, byte status, Version version) {
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
this.protocol = protocol;
this.networkMessageSize = networkMessageSize;
this.version = version;
this.requestId = requestId;
this.status = status;
}

TransportProtocol getTransportProtocol() {
return protocol;
}

public int getNetworkMessageSize() {
return networkMessageSize;
}
Expand Down Expand Up @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
@Override
public String toString() {
return "Header{"
+ protocol
+ "}{"
+ networkMessageSize
+ "}{"
+ version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) {
}
}

public NativeInboundMessage finishAggregation() throws IOException {
public InboundMessage finishAggregation() throws IOException {
ensureOpen();
final ReleasableBytesReference releasableContent;
if (isFirstContent()) {
Expand All @@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
137 changes: 126 additions & 11 deletions server/src/main/java/org/opensearch/transport/InboundBytesHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,139 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes, using {@link InboundDecoder} to decode headers
* and {@link InboundAggregator} to assemble complete messages to forward to
* the given message handler to parse the message payload.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

public boolean canHandleBytes(ReleasableBytesReference reference);
private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, InboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (InboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
// exposed for use in tests
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
Version remoteVersion = Version.fromId(streamInput.readInt());
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
if (invalidVersion != null) {
throw invalidVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.Map;
Expand All @@ -56,7 +55,7 @@ public class InboundHandler {

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Expand All @@ -75,7 +74,7 @@ public class InboundHandler {
) {
this.threadPool = threadPool;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
TransportProtocol.NATIVE,
new NativeMessageHandler(
nodeName,
version,
Expand Down Expand Up @@ -107,16 +106,16 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) {
this.slowLogThresholdMs = slowLogThreshold.getMillis();
}

void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception {
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
final long startTime = threadPool.relativeTimeInMillis();
channel.getChannelStats().markAccessed(startTime);
messageReceivedFromPipeline(channel, message, startTime);
}

private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
if (protocolMessageHandler == null) {
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());
}
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
}
Expand Down
Loading

0 comments on commit b4ec9e9

Please sign in to comment.