Skip to content

Commit

Permalink
alts: Forward-fix of ALTS queuing of handshake requests. (#6906)
Browse files Browse the repository at this point in the history
* alts: Forward-fix of ALTS queuing of handshake requests.
  • Loading branch information
matthewstevenson88 authored Jan 11, 2024
1 parent 6ce73bf commit 953d12a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
12 changes: 9 additions & 3 deletions credentials/alts/alts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
)

const (
defaultTestLongTimeout = 10 * time.Second
defaultTestLongTimeout = 60 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
)

Expand Down Expand Up @@ -392,17 +392,23 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
defer cancel()
c := testgrpc.NewTestServiceClient(conn)
success := false
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
if err == nil {
success = true
break
}
if code := status.Code(err); code == codes.Unavailable {
// The server is not ready yet. Try again.
if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
// The server is not ready yet or there were too many concurrent handshakes.
// Try again.
continue
}
t.Fatalf("c.UnaryCall() failed: %v", err)
}
if !success {
t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
}
}

func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
Expand Down
10 changes: 4 additions & 6 deletions credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ var (
// control number of concurrent created (but not closed) handshakes.
clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
// errOutOfBound occurs when the handshake service returns a consumed
// bytes value larger than the buffer that was passed to it originally.
errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
Expand Down Expand Up @@ -156,8 +154,8 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn,
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !clientHandshakes.TryAcquire(1) {
return nil, nil, errDropped
if err := clientHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
}
defer clientHandshakes.Release(1)

Expand Down Expand Up @@ -209,8 +207,8 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !serverHandshakes.TryAcquire(1) {
return nil, nil, errDropped
if err := serverHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
}
defer serverHandshakes.Release(1)

Expand Down
12 changes: 6 additions & 6 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ func (s) TestClientHandshake(t *testing.T) {
}()
}

// Ensure all errors are expected.
// Ensure that there are no errors.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
if err := <-errc; err != nil {
t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err)
}
}

Expand Down Expand Up @@ -250,10 +250,10 @@ func (s) TestServerHandshake(t *testing.T) {
}()
}

// Ensure all errors are expected.
// Ensure that there are no errors.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
if err := <-errc; err != nil {
t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err)
}
}

Expand Down

0 comments on commit 953d12a

Please sign in to comment.