Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "credentials/alts: defer ALTS stream creation until handshake …" #6179

Merged
merged 3 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,16 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
// and server options (server options struct does not exist now. When
// caller can provide endpoints, it should be created.

// altsHandshaker is used to complete a ALTS handshaking between client and
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// a virtual connection to the ALTS handshaker service.
clientConn *grpc.ClientConn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
Expand All @@ -154,39 +156,33 @@ type altsHandshaker struct {
side core.Side
}

// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewClientHandshaker creates a core.Handshaker that performs a client-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
clientOpts: opts,
side: core.ClientSide,
}, nil
}

// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewServerHandshaker creates a core.Handshaker that performs a server-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
serverOpts: opts,
side: core.ServerSide,
}, nil
}

// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
Expand Down Expand Up @@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return conn, authInfo, nil
}

// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
Expand Down Expand Up @@ -371,5 +387,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
// Close terminates the Handshaker. It should be called when the caller obtains
// the secure connection.
func (h *altsHandshaker) Close() {
h.stream.CloseSend()
if h.stream != nil {
h.stream.CloseSend()
}
}
66 changes: 66 additions & 0 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand Down Expand Up @@ -283,3 +285,67 @@ func (s) TestPeerNotResponding(t *testing.T) {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

func (s) TestNewClientHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ClientHandshakerOptions{}
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: opts,
serverOpts: nil,
side: core.ClientSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
}
hs.Close()
}

func (s) TestNewServerHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ServerHandshakerOptions{}
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: nil,
serverOpts: opts,
side: core.ServerSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
}
hs.Close()
}