Skip to content

Commit

Permalink
fix PerRPCCredentials w/RequireTransportSecurity and security levels
Browse files Browse the repository at this point in the history
  • Loading branch information
yihuazhang committed Nov 3, 2020
1 parent 15ae9fc commit 21ae702
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 26 deletions.
4 changes: 2 additions & 2 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type CommonAuthInfo struct {
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
return c
}

Expand Down Expand Up @@ -231,7 +231,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
Expand Down
12 changes: 6 additions & 6 deletions credentials/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials {
type insecureTC struct{}

func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) Info() credentials.ProtocolInfo {
Expand All @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error {
return nil
}

// Info contains the auth information for an insecure connection.
// info contains the auth information for an insecure connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "insecure"
}
12 changes: 6 additions & 6 deletions credentials/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ import (
"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "local"
}

Expand Down Expand Up @@ -79,15 +79,15 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
Expand Down
18 changes: 14 additions & 4 deletions credentials/local/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,20 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
var clientSecLevel, serverSecLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if info, ok := clientAuthInfo.(internalInfo); ok {
clientSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo().")
}
if info, ok := (serverHandleResult.authInfo).(internalInfo); ok {
serverSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo().")
}
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
Expand Down
19 changes: 17 additions & 2 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,22 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
for _, cd := range perRPCCreds {
if cd.RequireTransportSecurity() {
if ci, ok := authInfo.(internalInfo); ok {
secLevel := ci.GetCommonAuthInfo().SecurityLevel
if secLevel != credentials.Invalid && secLevel < credentials.PrivacyAndIntegrity {
return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
}
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
Expand Down Expand Up @@ -557,8 +570,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
if callCreds.RequireTransportSecurity() {
if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
Expand Down
97 changes: 93 additions & 4 deletions test/insecure_creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package test
import (
"context"
"net"
"strings"
"testing"
"time"

Expand All @@ -30,11 +31,23 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"

testpb "google.golang.org/grpc/test/grpc_testing"
)

const defaultTestTimeout = 5 * time.Second

// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level.
type testLegacyPerRPCCredentials struct{}

func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}

func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool {
return true
}

// TestInsecureCreds tests the use of insecure creds on the server and client
// side, and verifies that expect security level and auth info are returned.
// Also verifies that this credential can interop with existing `WithInsecure`
Expand Down Expand Up @@ -73,11 +86,16 @@ func (s) TestInsecureCreds(t *testing.T) {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
// Check security level.
info := pr.AuthInfo.(insecure.Info)
if at := info.AuthType(); at != "insecure" {
return nil, status.Errorf(codes.Unauthenticated, "Wrong AuthType: got %q, want insecure", at)
var secLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if secLevel := info.CommonAuthInfo.SecurityLevel; secLevel != credentials.NoSecurity {
if info, ok := pr.AuthInfo.(internalInfo); ok {
secLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
}
if secLevel != credentials.NoSecurity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity)
}
return &testpb.Empty{}, nil
Expand Down Expand Up @@ -122,3 +140,74 @@ func (s) TestInsecureCreds(t *testing.T) {
})
}
}

func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) {
tests := []struct {
desc string
perRPCCredsViaDialOptions bool
perRPCCredsViaCallOptions bool
wantErr string
}{
{
desc: "send PerRPCCredentials via DialOptions",
perRPCCredsViaDialOptions: true,
perRPCCredsViaCallOptions: false,
wantErr: "context deadline exceeded",
},
{
desc: "send PerRPCCredentials via CallOptions",
perRPCCredsViaDialOptions: false,
perRPCCredsViaCallOptions: true,
wantErr: "transport: cannot send secure credentials on an insecure connection",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

sOpts := []grpc.ServerOption{}
sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials()))
s := grpc.NewServer(sOpts...)
defer s.Stop()

testpb.RegisterTestServiceServer(s, ss)

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err)
}

go s.Serve(lis)

addr := lis.Addr().String()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cOpts := []grpc.DialOption{grpc.WithBlock()}
cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))

if test.perRPCCredsViaDialOptions {
cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{}))
if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr)
}
}

if test.perRPCCredsViaCallOptions {
cc, err := grpc.DialContext(ctx, addr, cOpts...)
if err != nil {
t.Fatalf("grpc.Dial(%q) failed: %v", addr, err)
}
defer cc.Close()

c := testpb.NewTestServiceClient(cc)
if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr)
}
}
})
}
}
11 changes: 9 additions & 2 deletions test/local_creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,16 @@ func testLocalCredsE2ESucceed(network, address string) error {
if !ok {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
var secLevel credentials.SecurityLevel
if info, ok := (pr.AuthInfo).(internalInfo); ok {
secLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
}
// Check security level
info := pr.AuthInfo.(local.Info)
secLevel := info.CommonAuthInfo.SecurityLevel
switch network {
case "unix":
if secLevel != credentials.PrivacyAndIntegrity {
Expand Down

0 comments on commit 21ae702

Please sign in to comment.