Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testing: Avoid using context.Background #3949

Merged
merged 7 commits into from
Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 47 additions & 35 deletions balancer/grpclb/grpclb_test.go

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions balancer_switching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
err error
)
connected := false
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if err = cc.Invoke(ctx, "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if connected {
// connected is set to false if peer is not server[0]. So if
// connected is true here, this is the second time we saw
Expand All @@ -100,9 +102,10 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
if !connected {
return fmt.Errorf("pickfirst is not in effect after 5 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port)
}

// The following RPCs should all succeed with the first server.
for i := 0; i < 3; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
err = cc.Invoke(ctx, "/foo/bar", &req, &reply)
if errorDesc(err) != servers[0].port {
return fmt.Errorf("index %d: want peer %v, got peer %v", i, servers[0].port, err)
}
Expand All @@ -117,14 +120,16 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
err error
)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Make sure connections to all servers are up.
for i := 0; i < 2; i++ {
// Do this check twice, otherwise the first RPC's transport may still be
// picked by the closing pickfirst balancer, and the test becomes flaky.
for _, s := range servers {
var up bool
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == s.port {
if err = cc.Invoke(ctx, "/foo/bar", &req, &reply); errorDesc(err) == s.port {
up = true
break
}
Expand All @@ -138,7 +143,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {

serverCount := len(servers)
for i := 0; i < 3*serverCount; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
err = cc.Invoke(ctx, "/foo/bar", &req, &reply)
if errorDesc(err) != servers[i%serverCount].port {
return fmt.Errorf("index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err)
}
Expand Down
6 changes: 4 additions & 2 deletions benchmark/primitives/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"time"
)

const defaultTestTimeout = 10 * time.Second

func BenchmarkCancelContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -72,7 +74,7 @@ func BenchmarkCancelContextChannelGotErr(b *testing.B) {
}

func BenchmarkTimerContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
Expand All @@ -92,7 +94,7 @@ func BenchmarkTimerContextErrGotErr(b *testing.B) {
}

func BenchmarkTimerContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
Expand Down
26 changes: 19 additions & 7 deletions call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ var (
canceled = 0
)

const defaultTestTimeout = 10 * time.Second

type testCodec struct {
}

Expand Down Expand Up @@ -237,7 +239,8 @@ func (s) TestUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -305,7 +308,8 @@ func (s) TestChainUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -346,7 +350,8 @@ func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -407,7 +412,8 @@ func (s) TestChainStreamClientInterceptor(t *testing.T) {
server.stop()
}()

ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
if err != nil {
Expand All @@ -418,7 +424,9 @@ func (s) TestChainStreamClientInterceptor(t *testing.T) {
func (s) TestInvoke(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
Expand All @@ -429,7 +437,9 @@ func (s) TestInvokeLargeErr(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "hello"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
Expand All @@ -445,7 +455,9 @@ func (s) TestInvokeErrorSpecialChars(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "weird error"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
Expand Down
4 changes: 3 additions & 1 deletion channelz/service/service_sktopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ func (s) TestGetSocketOptions(t *testing.T) {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
defer channelz.RemoveEntry(ids[i])
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
resp, _ := svr.GetSocket(ctx, &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
Expand Down
38 changes: 27 additions & 11 deletions channelz/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ var protoToSocketOpt protoToSocketOptFunc
// TODO: Go1.7 is no longer supported - does this need a change?
var emptyTime time.Time

const defaultTestTimeout = 10 * time.Second

type dummyChannel struct {
state connectivity.State
target string
Expand Down Expand Up @@ -327,7 +329,9 @@ func (s) TestGetTopChannels(t *testing.T) {
defer channelz.RemoveEntry(id)
}
s := newCZServer()
resp, _ := s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := s.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
Expand All @@ -340,7 +344,7 @@ func (s) TestGetTopChannels(t *testing.T) {
id := channelz.RegisterChannel(tcs[0], 0, "")
defer channelz.RemoveEntry(id)
}
resp, _ = s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
resp, _ = s.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand Down Expand Up @@ -374,7 +378,9 @@ func (s) TestGetServers(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetServers(ctx, &channelzpb.GetServersRequest{StartServerId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
Expand All @@ -387,7 +393,7 @@ func (s) TestGetServers(t *testing.T) {
id := channelz.RegisterServer(ss[0], "")
defer channelz.RemoveEntry(id)
}
resp, _ = svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
resp, _ = svr.GetServers(ctx, &channelzpb.GetServersRequest{StartServerId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand All @@ -407,7 +413,9 @@ func (s) TestGetServerSockets(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
Expand All @@ -424,7 +432,7 @@ func (s) TestGetServerSockets(t *testing.T) {
id := channelz.RegisterNormalSocket(&dummySocket{}, svrID, "")
defer channelz.RemoveEntry(id)
}
resp, _ = svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
resp, _ = svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand All @@ -446,9 +454,11 @@ func (s) TestGetServerSocketsNonZeroStartID(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Make GetServerSockets with startID = ids[1]+1, so socket-1 won't be
// included in the response.
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: ids[1] + 1})
resp, _ := svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: ids[1] + 1})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
Expand Down Expand Up @@ -512,7 +522,9 @@ func (s) TestGetChannel(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[0]})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetChannel(ctx, &channelzpb.GetChannelRequest{ChannelId: ids[0]})
metrics := resp.GetChannel()
subChans := metrics.GetSubchannelRef()
if len(subChans) != 1 || subChans[0].GetName() != refNames[2] || subChans[0].GetSubchannelId() != ids[2] {
Expand Down Expand Up @@ -552,7 +564,7 @@ func (s) TestGetChannel(t *testing.T) {
}
}
}
resp, _ = svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[1]})
resp, _ = svr.GetChannel(ctx, &channelzpb.GetChannelRequest{ChannelId: ids[1]})
metrics = resp.GetChannel()
nestedChans = metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[3] || nestedChans[0].GetChannelId() != ids[3] {
Expand Down Expand Up @@ -598,7 +610,9 @@ func (s) TestGetSubChannel(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetSubchannel(context.Background(), &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetSubchannel(ctx, &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
metrics := resp.GetSubchannel()
want := map[int64]string{
ids[2]: refNames[2],
Expand Down Expand Up @@ -719,8 +733,10 @@ func (s) TestGetSocket(t *testing.T) {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
defer channelz.RemoveEntry(ids[i])
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
resp, _ := svr.GetSocket(ctx, &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
Expand Down
19 changes: 16 additions & 3 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ var (
}
)

const defaultTestTimeout = 10 * time.Second

// testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
type testRPCStream struct {
grpc.ClientStream
Expand Down Expand Up @@ -133,6 +135,10 @@ func (s) TestClientHandshake(t *testing.T) {
} {
errc := make(chan error)
stat.Reset()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
Expand All @@ -155,7 +161,7 @@ func (s) TestClientHandshake(t *testing.T) {
side: core.ClientSide,
}
go func() {
_, context, err := chs.ClientHandshake(context.Background())
_, context, err := chs.ClientHandshake(ctx)
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
Expand Down Expand Up @@ -188,6 +194,10 @@ func (s) TestServerHandshake(t *testing.T) {
} {
errc := make(chan error)
stat.Reset()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
Expand All @@ -207,7 +217,7 @@ func (s) TestServerHandshake(t *testing.T) {
side: core.ServerSide,
}
go func() {
_, context, err := shs.ServerHandshake(context.Background())
_, context, err := shs.ServerHandshake(ctx)
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
Expand Down Expand Up @@ -258,7 +268,10 @@ func (s) TestPeerNotResponding(t *testing.T) {
},
side: core.ClientSide,
}
_, context, err := chs.ClientHandshake(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, context, err := chs.ClientHandshake(ctx)
chs.Close()
if context != nil {
t.Error("expected non-nil ALTS context")
Expand Down
9 changes: 7 additions & 2 deletions credentials/alts/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"os"
"strings"
"testing"
"time"

"google.golang.org/grpc/codes"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand All @@ -37,6 +38,8 @@ const (
testServiceAccount1 = "service_account1"
testServiceAccount2 = "service_account2"
testServiceAccount3 = "service_account3"

defaultTestTimeout = 10 * time.Second
)

func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() {
Expand Down Expand Up @@ -101,7 +104,8 @@ func (s) TestIsRunningOnGCPNoProductNameFile(t *testing.T) {
}

func (s) TestAuthInfoFromContext(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
Expand Down Expand Up @@ -158,7 +162,8 @@ func (s) TestAuthInfoFromPeer(t *testing.T) {
}

func (s) TestClientAuthorizationCheck(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
Expand Down
Loading