Skip to content

Commit

Permalink
[core] Use SyncContext for InProcessTransport listener callbacks to a…
Browse files Browse the repository at this point in the history
…void deadlocks (Fixes bug grpc#3084)

Also support unary calls returning null values
  • Loading branch information
larry-safran committed Jun 27, 2022
1 parent 33fbb9d commit 6c82a0f
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 93 deletions.
209 changes: 126 additions & 83 deletions core/src/main/java/io/grpc/inprocess/InProcessTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import io.grpc.SecurityLevel;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.ClientStream;
import io.grpc.internal.ClientStreamListener;
import io.grpc.internal.ClientStreamListener.RpcProgress;
Expand Down Expand Up @@ -407,8 +408,8 @@ private void streamClosed() {

private class InProcessServerStream implements ServerStream {
final StatsTraceContext statsTraceCtx;
@GuardedBy("this")
private ClientStreamListener clientStreamListener;
private final SynchronizationContext syncContext;
@GuardedBy("this")
private int clientRequested;
@GuardedBy("this")
Expand All @@ -426,7 +427,14 @@ private class InProcessServerStream implements ServerStream {

InProcessServerStream(MethodDescriptor<?, ?> method, Metadata headers) {
statsTraceCtx = StatsTraceContext.newServerContext(
serverStreamTracerFactories, method.getFullMethodName(), headers);
serverStreamTracerFactories, method.getFullMethodName(), headers);

syncContext = new SynchronizationContext(new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new RuntimeException(e);
}
});
}

private synchronized void setListener(ClientStreamListener listener) {
Expand All @@ -442,7 +450,7 @@ public void setListener(ServerStreamListener serverStreamListener) {
public void request(int numMessages) {
boolean onReady = clientStream.serverRequested(numMessages);
if (onReady) {
synchronized (this) {
synchronized (this) { // TODO How should this be handled (probably just remove)
if (!closed) {
clientStreamListener.onReady();
}
Expand All @@ -451,57 +459,69 @@ public void request(int numMessages) {
}

// This method is the only reason we have to synchronize field accesses.
// Using a SynchronizationContext to avoid possibility of deadlock in direct executors.
/**
* Client requested more messages.
*
* @return whether onReady should be called on the server
*/
private synchronized boolean clientRequested(int numMessages) {
if (closed) {
return false;
}
boolean previouslyReady = clientRequested > 0;
clientRequested += numMessages;
while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) {
clientRequested--;
clientStreamListener.messagesAvailable(clientReceiveQueue.poll());
}
// Attempt being reentrant-safe
if (closed) {
return false;
}
if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) {
closed = true;
clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers);
clientStream.statsTraceCtx.streamClosed(clientNotifyStatus);
clientStreamListener.closed(
private boolean clientRequested(int numMessages) {
boolean previouslyReady;
boolean nowReady;
synchronized (this) {
if (closed) {
return false;
}

previouslyReady = clientRequested > 0;
clientRequested += numMessages;
while (clientRequested > 0 && !clientReceiveQueue.isEmpty()) {
clientRequested--;
StreamListener.MessageProducer producer = clientReceiveQueue.poll();
syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer));
}

if (clientReceiveQueue.isEmpty() && clientNotifyStatus != null) {
closed = true;
clientStream.statsTraceCtx.clientInboundTrailers(clientNotifyTrailers);
clientStream.statsTraceCtx.streamClosed(clientNotifyStatus);
clientStreamListener.closed(
clientNotifyStatus, RpcProgress.PROCESSED, clientNotifyTrailers);
}

nowReady = clientRequested > 0;
}
boolean nowReady = clientRequested > 0;

syncContext.drain();
return !previouslyReady && nowReady;
}

private void clientCancelled(Status status) {
internalCancel(status);
}

// Using syncContext to avoid possibility of deadlock
@Override
public synchronized void writeMessage(InputStream message) {
if (closed) {
return;
}
statsTraceCtx.outboundMessage(outboundSeqNo);
statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1);
clientStream.statsTraceCtx.inboundMessage(outboundSeqNo);
clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1);
outboundSeqNo++;
StreamListener.MessageProducer producer = new SingleMessageProducer(message);
if (clientRequested > 0) {
clientRequested--;
clientStreamListener.messagesAvailable(producer);
} else {
clientReceiveQueue.add(producer);
public void writeMessage(InputStream message) {
synchronized (this) {
if (closed) {
return;
}
statsTraceCtx.outboundMessage(outboundSeqNo);
statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1);
clientStream.statsTraceCtx.inboundMessage(outboundSeqNo);
clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1);
outboundSeqNo++;
StreamListener.MessageProducer producer = new SingleMessageProducer(message);
if (clientRequested > 0) {
clientRequested--;
syncContext.executeLater(() -> clientStreamListener.messagesAvailable(producer));
} else {
clientReceiveQueue.add(producer);
}
}

syncContext.drain();
}

@Override
Expand Down Expand Up @@ -662,8 +682,8 @@ public int streamId() {
private class InProcessClientStream implements ClientStream {
final StatsTraceContext statsTraceCtx;
final CallOptions callOptions;
@GuardedBy("this")
private ServerStreamListener serverStreamListener;
private final SynchronizationContext syncContext;
@GuardedBy("this")
private int serverRequested;
@GuardedBy("this")
Expand All @@ -681,6 +701,15 @@ private class InProcessClientStream implements ClientStream {
CallOptions callOptions, StatsTraceContext statsTraceContext) {
this.callOptions = callOptions;
statsTraceCtx = statsTraceContext;


syncContext = new SynchronizationContext(new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
throw new RuntimeException(e);
}
});

}

private synchronized void setListener(ServerStreamListener listener) {
Expand All @@ -705,21 +734,29 @@ public void request(int numMessages) {
*
* @return whether onReady should be called on the server
*/
private synchronized boolean serverRequested(int numMessages) {
if (closed) {
return false;
}
boolean previouslyReady = serverRequested > 0;
serverRequested += numMessages;
while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) {
serverRequested--;
serverStreamListener.messagesAvailable(serverReceiveQueue.poll());
}
if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) {
serverNotifyHalfClose = false;
serverStreamListener.halfClosed();
private boolean serverRequested(int numMessages) {
boolean previouslyReady;
boolean nowReady;
synchronized (this) {
if (closed) {
return false;
}
previouslyReady = serverRequested > 0;
serverRequested += numMessages;

while (serverRequested > 0 && !serverReceiveQueue.isEmpty()) {
serverRequested--;
StreamListener.MessageProducer producer = serverReceiveQueue.poll();
syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer));
}

if (serverReceiveQueue.isEmpty() && serverNotifyHalfClose) {
serverNotifyHalfClose = false;
serverStreamListener.halfClosed();
}
nowReady = serverRequested > 0;
}
boolean nowReady = serverRequested > 0;
syncContext.drain();
return !previouslyReady && nowReady;
}

Expand All @@ -728,22 +765,25 @@ private void serverClosed(Status serverListenerStatus, Status serverTracerStatus
}

@Override
public synchronized void writeMessage(InputStream message) {
if (closed) {
return;
}
statsTraceCtx.outboundMessage(outboundSeqNo);
statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1);
serverStream.statsTraceCtx.inboundMessage(outboundSeqNo);
serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1);
outboundSeqNo++;
StreamListener.MessageProducer producer = new SingleMessageProducer(message);
if (serverRequested > 0) {
serverRequested--;
serverStreamListener.messagesAvailable(producer);
} else {
serverReceiveQueue.add(producer);
public void writeMessage(InputStream message) {
synchronized (this) {
if (closed) {
return;
}
statsTraceCtx.outboundMessage(outboundSeqNo);
statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1);
serverStream.statsTraceCtx.inboundMessage(outboundSeqNo);
serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1);
outboundSeqNo++;
StreamListener.MessageProducer producer = new SingleMessageProducer(message);
if (serverRequested > 0) {
serverRequested--;
syncContext.executeLater(() -> serverStreamListener.messagesAvailable(producer));
} else {
serverReceiveQueue.add(producer);
}
}
syncContext.drain();
}

@Override
Expand All @@ -768,26 +808,29 @@ public void cancel(Status reason) {
streamClosed();
}

private synchronized boolean internalCancel(
private boolean internalCancel(
Status serverListenerStatus, Status serverTracerStatus) {
if (closed) {
return false;
}
closed = true;
synchronized(this) {
if (closed) {
return false;
}
closed = true;

StreamListener.MessageProducer producer;
while ((producer = serverReceiveQueue.poll()) != null) {
InputStream message;
while ((message = producer.next()) != null) {
try {
message.close();
} catch (Throwable t) {
log.log(Level.WARNING, "Exception closing stream", t);
StreamListener.MessageProducer producer;
while ((producer = serverReceiveQueue.poll()) != null) {
InputStream message;
while ((message = producer.next()) != null) {
try {
message.close();
} catch (Throwable t) {
log.log(Level.WARNING, "Exception closing stream", t);
}
}
}
serverStream.statsTraceCtx.streamClosed(serverTracerStatus);
syncContext.executeLater(() -> serverStreamListener.closed(serverListenerStatus));
}
serverStream.statsTraceCtx.streamClosed(serverTracerStatus);
serverStreamListener.closed(serverListenerStatus);
syncContext.drain();
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
* Not intended to provide a high code coverage or to test every major usecase.
*
* directExecutor() makes it easier to have deterministic tests.
* However, if your implementation uses another thread and uses streaming it is better to use
* the default executor, to avoid hitting bug #3084.
*
* <p>For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and
* {@link io.grpc.examples.routeguide.RouteGuideServerTest}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
* Not intended to provide a high code coverage or to test every major usecase.
*
* directExecutor() makes it easier to have deterministic tests.
* However, if your implementation uses another thread and uses streaming it is better to use
* the default executor, to avoid hitting bug #3084.
*
* <p>For more unit test examples see {@link io.grpc.examples.routeguide.RouteGuideClientTest} and
* {@link io.grpc.examples.routeguide.RouteGuideServerTest}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
* Not intended to provide a high code coverage or to test every major usecase.
*
* directExecutor() makes it easier to have deterministic tests.
* However, if your implementation uses another thread and uses streaming it is better to use
* the default executor, to avoid hitting bug #3084.
*
* <p>For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and
* {@link io.grpc.examples.helloworld.HelloWorldServerTest}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
* Not intended to provide a high code coverage or to test every major usecase.
*
* directExecutor() makes it easier to have deterministic tests.
* However, if your implementation uses another thread and uses streaming it is better to use
* the default executor, to avoid hitting bug #3084.
*
* <p>For basic unit test examples see {@link io.grpc.examples.helloworld.HelloWorldClientTest} and
* {@link io.grpc.examples.helloworld.HelloWorldServerTest}.
Expand Down
6 changes: 4 additions & 2 deletions stub/src/main/java/io/grpc/stub/ClientCalls.java
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ void onStart() {
private static final class UnaryStreamToFuture<RespT> extends StartableListener<RespT> {
private final GrpcFuture<RespT> responseFuture;
private RespT value;
private boolean isValueReceived = false;

// Non private to avoid synthetic class
UnaryStreamToFuture(GrpcFuture<RespT> responseFuture) {
Expand All @@ -521,17 +522,18 @@ public void onHeaders(Metadata headers) {

@Override
public void onMessage(RespT value) {
if (this.value != null) {
if (this.isValueReceived) {
throw Status.INTERNAL.withDescription("More than one value received for unary call")
.asRuntimeException();
}
this.value = value;
this.isValueReceived = true;
}

@Override
public void onClose(Status status, Metadata trailers) {
if (status.isOk()) {
if (value == null) {
if (!isValueReceived) {
// No value received so mark the future as an error
responseFuture.setException(
Status.INTERNAL.withDescription("No value received for unary call")
Expand Down

0 comments on commit 6c82a0f

Please sign in to comment.