Skip to content

Commit

Permalink
credentials/alts: fix defer in TestDial (#7301)
Browse files Browse the repository at this point in the history
  • Loading branch information
aranjans authored Jun 13, 2024
1 parent e37c6e8 commit 24a6b48
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 23 deletions.
4 changes: 1 addition & 3 deletions credentials/alts/internal/handshaker/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ var (
// to a corresponding connection to a hypervisor handshaker service
// instance.
hsConnMap = make(map[string]*grpc.ClientConn)
// hsDialer will be reassigned in tests.
hsDialer = grpc.Dial
)

// Dial dials the handshake service in the hypervisor. If a connection has
Expand All @@ -50,7 +48,7 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) {
// Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed.
var err error
hsConn, err = hsDialer(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
hsConn, err = grpc.Dial(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
Expand Down
37 changes: 17 additions & 20 deletions credentials/alts/internal/handshaker/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,33 @@ package service
import (
"testing"

grpc "google.golang.org/grpc"
"google.golang.org/grpc/internal/grpctest"
)

type s struct {
grpctest.Tester
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

const (
testAddress1 = "some_address_1"
testAddress2 = "some_address_2"
)

func TestDial(t *testing.T) {
defer func() func() {
temp := hsDialer
hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
return func() {
hsDialer = temp
}
}()

// TestDial verifies the behaviour of alts handshake when there are multiple Dials.
// If a connection has already been established, this function returns it.
// Otherwise, a new connection is created.
func (s) TestDial(t *testing.T) {
// First call to Dial, it should create a connection to the server running
// at the given address.
conn1, err := Dial(testAddress1)
if err != nil {
t.Fatalf("first call to Dial(%v) failed: %v", testAddress1, err)
}
if conn1 == nil {
t.Fatalf("first call to Dial(%v)=(nil, _), want not nil", testAddress1)
}
defer conn1.Close()
if got, want := hsConnMap[testAddress1], conn1; got != want {
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want)
}
Expand All @@ -58,22 +57,20 @@ func TestDial(t *testing.T) {
if err != nil {
t.Fatalf("second call to Dial(%v) failed: %v", testAddress1, err)
}
defer conn2.Close()
if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial(%v)=(%v, _), want (%v,. _)", testAddress1, got, want)
}
if got, want := hsConnMap[testAddress1], conn1; got != want {
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want)
}

// Third call to Dial using a different address should create a new
// connection.
// Third call to Dial using a different address should create a new connection.
conn3, err := Dial(testAddress2)
if err != nil {
t.Fatalf("third call to Dial(%v) failed: %v", testAddress2, err)
}
if conn3 == nil {
t.Fatalf("third call to Dial(%v)=(nil, _), want not nil", testAddress2)
}
defer conn3.Close()
if got, want := hsConnMap[testAddress2], conn3; got != want {
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress2, got, want)
}
Expand Down

0 comments on commit 24a6b48

Please sign in to comment.