diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 44454fbf7f1..66c3479bfc0 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -47,13 +47,19 @@ import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; import io.grpc.internal.StreamListener; +import io.grpc.internal.StreamListener.MessageProducer; import io.grpc.protobuf.lite.ProtoLiteUtils; import io.grpc.stub.ServerCalls; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -96,8 +102,6 @@ public final class BinderClientTransportTest { private final TestTransportListener transportListener = new TestTransportListener(); private final TestStreamListener streamListener = new TestStreamListener(); - private int serverCallsCompleted; - @Before public void setUp() throws Exception { ServerCallHandler callHandler = @@ -105,17 +109,15 @@ public void setUp() throws Exception { (req, respObserver) -> { respObserver.onNext(req); respObserver.onCompleted(); - serverCallsCompleted += 1; }); ServerCallHandler streamingCallHandler = - ServerCalls.asyncUnaryCall( + ServerCalls.asyncServerStreamingCall( (req, respObserver) -> { for (int i = 0; i < 100; i++) { respObserver.onNext(req); } respObserver.onCompleted(); - serverCallsCompleted += 1; }); ServerServiceDefinition serviceDef = @@ -191,9 +193,7 @@ public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception stream.halfClose(); stream.request(3); - streamListener.awaitMessages(); - streamListener.messageProducer.next(); - streamListener.messageProducer.next(); + streamListener.readAndDiscardMessages(2); // Without the fix, this loops forever. stream.request(2); @@ -231,17 +231,12 @@ public void testBadTransactionStreamThroughput_b163053382() throws Exception { stream.halfClose(); stream.request(1000); - // Wait until we receive the first message. - streamListener.awaitMessages(); - // Wait until the server actually provides all messages and completes the call. - awaitServerCallsCompleted(1); - - // Now we should be able to receive all messages on a single message producer. - assertThat(streamListener.drainMessages()).isEqualTo(100); + // We should eventually see all messages despite receiving no more transactions from the server. + streamListener.readAndDiscardMessages(100); } @Test - public void testMessageProducerClosedAfterStream_b169313545() { + public void testMessageProducerClosedAfterStream_b169313545() throws Exception { transport = new BinderClientTransportBuilder().build(); startAndAwaitReady(transport, transportListener); ClientStream stream = @@ -278,16 +273,6 @@ public void testNewStreamBeforeTransportReadyFails() throws InterruptedException transportListener.awaitReady(); } - private synchronized void awaitServerCallsCompleted(int calls) { - while (serverCallsCompleted < calls) { - try { - wait(100); - } catch (InterruptedException inte) { - throw new AssertionError("Interrupted waiting for servercalls"); - } - } - } - private static void startAndAwaitReady( BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) { transport.start(transportListener).run(); @@ -295,7 +280,9 @@ private static void startAndAwaitReady( } private static final class TestTransportListener implements ManagedClientTransport.Listener { - public boolean ready; + @GuardedBy("this") + private boolean ready; + public boolean inUse; @Nullable public Status shutdownStatus; public boolean terminated; @@ -313,13 +300,13 @@ public void transportTerminated() { @Override public synchronized void transportReady() { ready = true; - notify(); + notifyAll(); } public synchronized void awaitReady() { while (!ready) { try { - wait(100); + wait(); } catch (InterruptedException inte) { throw new AssertionError("Interrupted waiting for ready"); } @@ -334,22 +321,45 @@ public void transportInUse(boolean inUse) { private static final class TestStreamListener implements ClientStreamListener { - public StreamListener.MessageProducer messageProducer; public boolean ready; public Metadata headers; - @Nullable public Status closedStatus; + + @GuardedBy("this") + private final Deque messageProducers = new ArrayDeque<>(); + + @GuardedBy("this") + @Nullable + private Status closedStatus; @Override - public void messagesAvailable(StreamListener.MessageProducer messageProducer) { - this.messageProducer = messageProducer; + public synchronized void messagesAvailable(StreamListener.MessageProducer messageProducer) { + messageProducers.add(messageProducer); + notifyAll(); } - public synchronized void awaitMessages() { - while (messageProducer == null) { - try { - wait(100); - } catch (InterruptedException inte) { - throw new AssertionError("Interrupted waiting for messages"); + /** Blocks until at least one MessageProducer has been provided for reading. */ + public synchronized void awaitMessages() throws InterruptedException { + while (messageProducers.isEmpty()) { + wait(); + } + } + + /** Blocks until {@code n} messages can be produced (and discarded). */ + public synchronized void readAndDiscardMessages(int n) + throws InterruptedException, IOException { + while (n > 0) { + while (closedStatus == null && messageProducers.isEmpty()) { + wait(); + } + if (closedStatus != null) { + throw closedStatus.withDescription("premature close").asRuntimeException(); + } + try (InputStream message = messageProducers.peek().next()) { + if (message == null) { + messageProducers.remove(); + continue; + } + n -= 1; } } } @@ -357,7 +367,7 @@ public synchronized void awaitMessages() { public synchronized Status awaitClose() { while (closedStatus == null) { try { - wait(100); + wait(); } catch (InterruptedException inte) { throw new AssertionError("Interrupted waiting for close"); } @@ -365,10 +375,17 @@ public synchronized Status awaitClose() { return closedStatus; } - public int drainMessages() { + /** Discards any messages available on the stream without reading them. Does not block. */ + public synchronized int drainMessages() throws IOException { int n = 0; - while (messageProducer.next() != null) { - n += 1; + while (!messageProducers.isEmpty()) { + try (InputStream message = messageProducers.peek().next()) { + if (message == null) { + messageProducers.remove(); + continue; + } + n += 1; + } } return n; } @@ -384,8 +401,9 @@ public void headersRead(Metadata headers) { } @Override - public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + public synchronized void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { this.closedStatus = status; + notifyAll(); } }