diff --git a/picker_wrapper.go b/picker_wrapper.go index f9625496c403..45baa2ae13da 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -120,6 +120,14 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. bp.mu.Unlock() select { case <-ctx.Done(): + if connectionErr := bp.connectionError(); connectionErr != nil { + switch ctx.Err() { + case context.DeadlineExceeded: + return nil, nil, status.Errorf(codes.DeadlineExceeded, "latest connection error: %v", connectionErr) + case context.Canceled: + return nil, nil, status.Errorf(codes.Canceled, "latest connection error: %v", connectionErr) + } + } return nil, nil, ctx.Err() case <-ch: } diff --git a/test/end2end_test.go b/test/end2end_test.go index 6d81b2d22f15..8513d6653fb3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -6642,6 +6642,29 @@ func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) { te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) } +func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"}) + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} + dctx, dcancel := context.WithTimeout(context.Background(), 10*time.Second) + defer dcancel() + cc, err := grpc.DialContext(dctx, te.srvAddr, opts...) + if err != nil { + t.Fatalf("Dial(_) = %v, want %v", err, nil) + } + defer cc.Close() + + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + return + } + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) +} + func (s) TestRPCTimeout(t *testing.T) { for _, e := range listTestEnv() { testRPCTimeout(t, e)