Skip to content

Commit

Permalink
all/tests: unmock ClientCall and ServerCall
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-mastrangelo committed Aug 30, 2016
1 parent 3bf8d94 commit 48c6b3d
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 131 deletions.
53 changes: 47 additions & 6 deletions auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@

package io.grpc.auth;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -102,8 +103,7 @@ public class ClientAuthInterceptorTest {
@Mock
Channel channel;

@Mock
ClientCall<String, Integer> call;
ClientCallRecorder call = new ClientCallRecorder();

ClientAuthInterceptor interceptor;

Expand All @@ -130,7 +130,8 @@ public void testCopyCredentialToHeaders() throws IOException {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(listener, headers);
assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);

Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"token1", "token2"},
Expand All @@ -150,7 +151,8 @@ public void testCredentialsThrows() throws IOException {
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
Mockito.verify(listener).onClose(statusCaptor.capture(), isA(Metadata.class));
Assert.assertNull(headers.getAll(AUTHORIZATION));
Mockito.verify(call, never()).start(listener, headers);
assertNull(call.responseListener);
assertNull(call.headers);
Assert.assertEquals(Status.Code.UNAUTHENTICATED, statusCaptor.getValue().getCode());
Assert.assertNotNull(statusCaptor.getValue().getCause());
}
Expand All @@ -169,7 +171,8 @@ public AccessToken refreshAccessToken() throws IOException {
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(listener, headers);
assertEquals(listener, call.responseListener);
assertEquals(headers, call.headers);
Iterable<String> authorization = headers.getAll(AUTHORIZATION);
Assert.assertArrayEquals(new String[]{"Bearer allyourbase"},
Iterables.toArray(authorization, String.class));
Expand All @@ -191,4 +194,42 @@ public void verifyServiceUri() throws IOException {
verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service"));
interceptedCall.cancel("Cancel for test", null);
}

private static final class ClientCallRecorder extends ClientCall<String, Integer> {
private ClientCall.Listener<Integer> responseListener;
private Metadata headers;
private int numMessages;
private String cancelMessage;
private Throwable cancelCause;
private boolean halfClosed;
private String sentMessage;

@Override
public void start(ClientCall.Listener<Integer> responseListener, Metadata headers) {
this.responseListener = responseListener;
this.headers = headers;
}

@Override
public void request(int numMessages) {
this.numMessages = numMessages;
}

@Override
public void cancel(String message, Throwable cause) {
this.cancelMessage = message;
this.cancelCause = cause;
}

@Override
public void halfClose() {
halfClosed = true;
}

@Override
public void sendMessage(String message) {
sentMessage = message;
}

}
}
112 changes: 77 additions & 35 deletions core/src/test/java/io/grpc/ClientInterceptorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,16 @@

package io.grpc;

import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand All @@ -61,8 +60,6 @@
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -75,8 +72,7 @@ public class ClientInterceptorsTest {
@Mock
private Channel channel;

@Mock
private ClientCall<String, Integer> call;
private BaseClientCall call = new BaseClientCall();

@Mock
private MethodDescriptor<String, Integer> method;
Expand All @@ -89,18 +85,6 @@ public class ClientInterceptorsTest {
when(channel.newCall(
Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
.thenReturn(call);

// Emulate the precondition checks in ChannelImpl.CallImpl
Answer<Void> checkStartCalled = new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
verify(call).start(Mockito.<ClientCall.Listener<Integer>>any(), Mockito.<Metadata>any());
return null;
}
};
doAnswer(checkStartCalled).when(call).request(anyInt());
doAnswer(checkStartCalled).when(call).halfClose();
doAnswer(checkStartCalled).when(call).sendMessage(Mockito.<String>any());
}

@Test(expected = NullPointerException.class)
Expand Down Expand Up @@ -290,11 +274,10 @@ public void start(ClientCall.Listener<RespT> responseListener, Metadata headers)
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
// start() on the intercepted call will eventually reach the call created by the real channel
interceptedCall.start(listener, new Metadata());
ArgumentCaptor<Metadata> captor = ArgumentCaptor.forClass(Metadata.class);
// The headers passed to the real channel call will contain the information inserted by the
// interceptor.
verify(call).start(same(listener), captor.capture());
assertEquals("abcd", captor.getValue().get(credKey));
assertSame(listener, call.listener);
assertEquals("abcd", call.headers.get(credKey));
}

@Test
Expand Down Expand Up @@ -327,12 +310,11 @@ public void onHeaders(Metadata headers) {
ClientCall<String, Integer> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
interceptedCall.start(listener, new Metadata());
// Capture the underlying call listener that will receive headers from the transport.
ArgumentCaptor<ClientCall.Listener<Integer>> captor = ArgumentCaptor.forClass(null);
verify(call).start(captor.capture(), Mockito.<Metadata>any());

Metadata inboundHeaders = new Metadata();
// Simulate that a headers arrives on the underlying call listener.
captor.getValue().onHeaders(inboundHeaders);
assertEquals(Arrays.asList(inboundHeaders), examinedHeaders);
call.listener.onHeaders(inboundHeaders);
assertThat(examinedHeaders).contains(inboundHeaders);
}

@Test
Expand All @@ -354,13 +336,14 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
ClientCall.Listener<Integer> listener = mock(ClientCall.Listener.class);
Metadata headers = new Metadata();
interceptedCall.start(listener, headers);
verify(call).start(same(listener), same(headers));
assertSame(listener, call.listener);
assertSame(headers, call.headers);
interceptedCall.sendMessage("request");
verify(call).sendMessage(eq("request"));
assertThat(call.messages).containsExactly("request");
interceptedCall.halfClose();
verify(call).halfClose();
assertTrue(call.halfClosed);
interceptedCall.request(1);
verify(call).request(1);
assertThat(call.requests).containsExactly(1);
}

@Test
Expand Down Expand Up @@ -392,7 +375,7 @@ protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadat
interceptedCall.sendMessage("request");
interceptedCall.halfClose();
interceptedCall.request(1);
verifyNoMoreInteractions(call);
call.done = true;
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(listener).onClose(captor.capture(), any(Metadata.class));
assertSame(error, captor.getValue().getCause());
Expand All @@ -406,7 +389,6 @@ protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadat
noop.halfClose();
noop.sendMessage(null);
assertFalse(noop.isReady());
verifyNoMoreInteractions(call);
}

@Test
Expand All @@ -432,12 +414,12 @@ public void customOptionAccessible() {
CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
ClientInterceptor interceptor = spy(new NoopInterceptor());

Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

assertSame(call, intercepted.newCall(method, callOptions));
verify(channel).newCall(same(method), same(callOptions));

verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
assertSame("value", passedOptions.getValue().getOption(customOption));
}
Expand All @@ -449,4 +431,64 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT
return next.newCall(method, callOptions);
}
}

private static class BaseClientCall extends ClientCall<String, Integer> {
private boolean started;
private boolean done;
private ClientCall.Listener<Integer> listener;
private Metadata headers;
private List<Integer> requests = new ArrayList<Integer>();
private List<String> messages = new ArrayList<String>();
private boolean halfClosed;
private Throwable cancelCause;
private String cancelMessage;

@Override
public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
checkNotDone();
started = true;
this.listener = listener;
this.headers = headers;
}

@Override
public void request(int numMessages) {
checkNotDone();
checkStarted();
requests.add(numMessages);
}

@Override
public void cancel(String message, Throwable cause) {
checkNotDone();
this.cancelMessage = message;
this.cancelCause = cause;
}

@Override
public void halfClose() {
checkNotDone();
checkStarted();
this.halfClosed = true;
}

@Override
public void sendMessage(String message) {
checkNotDone();
checkStarted();
messages.add(message);
}

private void checkNotDone() {
if (done) {
throw new IllegalStateException("no more methods should be called");
}
}

private void checkStarted() {
if (!started) {
throw new IllegalStateException("should have called start");
}
}
}
}
26 changes: 25 additions & 1 deletion core/src/test/java/io/grpc/ContextsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static org.mockito.Mockito.mock;

import io.grpc.internal.FakeClock;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand All @@ -66,7 +67,30 @@ public class ContextsTest {
@SuppressWarnings("unchecked")
private MethodDescriptor<Object, Object> method = mock(MethodDescriptor.class);
@SuppressWarnings("unchecked")
private ServerCall<Object, Object> call = mock(ServerCall.class);
private ServerCall<Object, Object> call = new ServerCall<Object, Object>() {

@Override
public void request(int numMessages) {}

@Override
public void sendHeaders(Metadata headers) {}

@Override
public void sendMessage(Object message) {}

@Override
public void close(Status status, Metadata trailers) {}

@Override
public boolean isCancelled() {
return false;
}

@Override
public MethodDescriptor<Object, Object> getMethodDescriptor() {
return null;
}
};
private Metadata headers = new Metadata();

@Test
Expand Down
Loading

0 comments on commit 48c6b3d

Please sign in to comment.