From 24a6b48bc8185bafb256c774d9b02534cff63893 Mon Sep 17 00:00:00 2001 From: Abhishek Ranjan <159750762+aranjans@users.noreply.github.com> Date: Thu, 13 Jun 2024 22:01:01 +0530 Subject: [PATCH] credentials/alts: fix defer in TestDial (#7301) --- .../internal/handshaker/service/service.go | 4 +- .../handshaker/service/service_test.go | 37 +++++++++---------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/credentials/alts/internal/handshaker/service/service.go b/credentials/alts/internal/handshaker/service/service.go index e1cdafb980cd..b3af03590729 100644 --- a/credentials/alts/internal/handshaker/service/service.go +++ b/credentials/alts/internal/handshaker/service/service.go @@ -34,8 +34,6 @@ var ( // to a corresponding connection to a hypervisor handshaker service // instance. hsConnMap = make(map[string]*grpc.ClientConn) - // hsDialer will be reassigned in tests. - hsDialer = grpc.Dial ) // Dial dials the handshake service in the hypervisor. If a connection has @@ -50,7 +48,7 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) { // Create a new connection to the handshaker service. Note that // this connection stays open until the application is closed. var err error - hsConn, err = hsDialer(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + hsConn, err = grpc.Dial(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err } diff --git a/credentials/alts/internal/handshaker/service/service_test.go b/credentials/alts/internal/handshaker/service/service_test.go index 28b4af757206..74bfb7e9f7df 100644 --- a/credentials/alts/internal/handshaker/service/service_test.go +++ b/credentials/alts/internal/handshaker/service/service_test.go @@ -21,34 +21,33 @@ package service import ( "testing" - grpc "google.golang.org/grpc" + "google.golang.org/grpc/internal/grpctest" ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + const ( testAddress1 = "some_address_1" testAddress2 = "some_address_2" ) -func TestDial(t *testing.T) { - defer func() func() { - temp := hsDialer - hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - return &grpc.ClientConn{}, nil - } - return func() { - hsDialer = temp - } - }() - +// TestDial verifies the behaviour of alts handshake when there are multiple Dials. +// If a connection has already been established, this function returns it. +// Otherwise, a new connection is created. +func (s) TestDial(t *testing.T) { // First call to Dial, it should create a connection to the server running // at the given address. conn1, err := Dial(testAddress1) if err != nil { t.Fatalf("first call to Dial(%v) failed: %v", testAddress1, err) } - if conn1 == nil { - t.Fatalf("first call to Dial(%v)=(nil, _), want not nil", testAddress1) - } + defer conn1.Close() if got, want := hsConnMap[testAddress1], conn1; got != want { t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want) } @@ -58,6 +57,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatalf("second call to Dial(%v) failed: %v", testAddress1, err) } + defer conn2.Close() if got, want := conn2, conn1; got != want { t.Fatalf("second call to Dial(%v)=(%v, _), want (%v,. _)", testAddress1, got, want) } @@ -65,15 +65,12 @@ func TestDial(t *testing.T) { t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want) } - // Third call to Dial using a different address should create a new - // connection. + // Third call to Dial using a different address should create a new connection. conn3, err := Dial(testAddress2) if err != nil { t.Fatalf("third call to Dial(%v) failed: %v", testAddress2, err) } - if conn3 == nil { - t.Fatalf("third call to Dial(%v)=(nil, _), want not nil", testAddress2) - } + defer conn3.Close() if got, want := hsConnMap[testAddress2], conn3; got != want { t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress2, got, want) }