diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 4ef743bf96d..a4ebfa52d63 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -455,10 +455,10 @@ private void closeListener( if (!listenerClosed) { listenerClosed = true; statsTraceCtx.streamClosed(status); - listener().closed(status, rpcProgress, trailers); if (getTransportTracer() != null) { getTransportTracer().reportStreamClosed(status.isOk()); } + listener().closed(status, rpcProgress, trailers); } } } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 1cb2a668a45..56c9c9d68d5 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -195,7 +195,10 @@ public void run() { } } if (retryFuture != null) { - retryFuture.cancel(false); + boolean cancelled = retryFuture.cancel(false); + if (cancelled) { + inFlightSubStreams.decrementAndGet(); + } } if (hedgingFuture != null) { hedgingFuture.cancel(false); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index 72ed8bf975b..7a5bba7add1 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -18,7 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertNotNull; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; @@ -78,8 +81,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -103,8 +104,12 @@ public class RetryTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private final FakeClock fakeClock = new FakeClock(); - @Mock - private ClientCall.Listener mockCallListener; + private TestListener testCallListener = new TestListener(); + @SuppressWarnings("unchecked") + private ClientCall.Listener mockCallListener = + mock(ClientCall.Listener.class, delegatesTo(testCallListener)); + private java.util.concurrent.ScheduledFuture activeFuture = null; + private CountDownLatch backoffLatch = new CountDownLatch(1); private final EventLoopGroup group = new DefaultEventLoopGroup() { @SuppressWarnings("FutureReturnValueIgnored") @@ -114,7 +119,7 @@ public ScheduledFuture schedule( if (!command.getClass().getName().contains("RetryBackoffRunnable")) { return super.schedule(command, delay, unit); } - fakeClock.getScheduledExecutorService().schedule( + activeFuture = fakeClock.getScheduledExecutorService().schedule( new Runnable() { @Override public void run() { @@ -244,8 +249,10 @@ private void assertInboundWireSizeRecorded(long length) throws Exception { private void assertRpcStatusRecorded( Status.Code code, long roundtripLatencyMs, long outboundMessages) throws Exception { - MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + MetricsRecord record = clientStatsRecorder.pollRecord(7, SECONDS); + assertNotNull(record); TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertNotNull(statusTag); assertThat(statusTag.asString()).isEqualTo(code.toString()); assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)) .isEqualTo(1); @@ -295,14 +302,16 @@ public void retryUntilBufferLimitExceeded() throws Exception { verify(mockCallListener, never()).onClose(any(Status.class), any(Metadata.class)); // send one more message, should exceed buffer limit call.sendMessage(message); + // let attempt fail + testCallListener.clear(); serverCall.close( Status.UNAVAILABLE.withDescription("2nd attempt failed"), new Metadata()); + fakeClock.forwardTime(1, SECONDS); + activeFuture.get(1, SECONDS); // Make sure the close is done. // no more retry - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); - assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); + testCallListener.verifyDescription("2nd attempt failed", 5000); } @Test @@ -414,9 +423,12 @@ public void streamClosed(Status status) { call.cancel("Cancelled before commit", null); // Let the netty substream listener be closed. streamClosedLatch.countDown(); + assertNotNull("No activeFuture", activeFuture); + fakeClock.forwardTime(1, SECONDS); + activeFuture.get(1, SECONDS); // The call listener is closed. verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); - assertRpcStatusRecorded(Code.CANCELLED, 17_000, 1); + assertRpcStatusRecorded(Code.CANCELLED, 18_000, 1); assertRetryStatsRecorded(1, 0, 0); } @@ -534,4 +546,26 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header assertRpcStatusRecorded(Code.INVALID_ARGUMENT, 0, 0); assertRetryStatsRecorded(0, 1, 0); } + + private static class TestListener extends ClientCall.Listener { + Status status = null; + private CountDownLatch closeLatch = new CountDownLatch(1); + + @Override + public void onClose(Status status, Metadata trailers) { + this.status = status; + closeLatch.countDown(); + } + + void clear() { + status = null; + closeLatch = new CountDownLatch(1); + } + + void verifyDescription(String description, long timeoutMs) throws InterruptedException { + closeLatch.await(timeoutMs, TimeUnit.MILLISECONDS); + assertNotNull(status); + assertThat(status.getDescription()).contains(description); + } + } }