Skip to content

Commit

Permalink
advancedTLS: Rename get root certs related pieces (#7207)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtcooke94 authored May 8, 2024
1 parent f591e3b commit c76f686
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 100 deletions.
111 changes: 72 additions & 39 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,52 @@ type PostHandshakeVerificationFunc func(params *HandshakeVerificationInfo) (*Pos
// Deprecated: use PostHandshakeVerificationFunc instead.
type CustomVerificationFunc = PostHandshakeVerificationFunc

// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
type GetRootCAsParams struct {
RawConn net.Conn
// ConnectionInfo contains the parameters available to users when
// implementing GetRootCertificates.
type ConnectionInfo struct {
// RawConn is the raw net.Conn representing a connection.
RawConn net.Conn
// RawCerts is the byte representation of the presented peer cert chain.
RawCerts [][]byte
}

// GetRootCAsResults contains the results of GetRootCAs.
// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
//
// Deprecated: use ConnectionInfo instead.
type GetRootCAsParams = ConnectionInfo

// RootCertificates is the result of GetRootCertificates.
// If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
type GetRootCAsResults struct {
type RootCertificates struct {
// TrustCerts is the pool of trusted certificates.
TrustCerts *x509.CertPool
}

// GetRootCAsResults contains the results of GetRootCAs.
// If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
//
// Deprecated: use RootCertificates instead.
type GetRootCAsResults = RootCertificates

// RootCertificateOptions contains options to obtain root trust certificates
// for both the client and the server.
// At most one option could be set. If none of them are set, we
// use the system default trust certificates.
type RootCertificateOptions struct {
// If RootCertificates is set, it will be used every time when verifying
// the peer certificates, without performing root certificate reloading.
RootCertificates *x509.CertPool
// If RootCACerts is set, it will be used every time when verifying
// the peer certificates, without performing root certificate reloading.
//
// Deprecated: use RootCertificates instead.
RootCACerts *x509.CertPool
// If GetRootCertificates is set, it will be invoked to obtain root certs for
// every new connection.
GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error)
GetRootCertificates func(params *ConnectionInfo) (*RootCertificates, error)
// If RootProvider is set, we will use the root certs from the Provider's
// KeyMaterial() call in the new connections. The Provider must have initial
// credentials if specified. Otherwise, KeyMaterial() will block forever.
Expand Down Expand Up @@ -277,6 +298,12 @@ func (o *Options) clientConfig() (*tls.Config, error) {
if o.MaxTLSVersion == 0 {
o.MaxTLSVersion = o.MaxVersion
}
// TODO(gtcooke94) RootCACerts is deprecated, eventually remove this block.
// This will ensure that users still explicitly setting RootCACerts will get
// the setting int the right place.
if o.RootOptions.RootCACerts != nil {
o.RootOptions.RootCertificates = o.RootOptions.RootCACerts
}
if o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
}
Expand Down Expand Up @@ -312,19 +339,19 @@ func (o *Options) clientConfig() (*tls.Config, error) {
}
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.RootCAs = o.RootOptions.RootCACerts
case o.RootOptions.RootCertificates != nil:
config.RootCAs = o.RootOptions.RootCertificates
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ClientHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
o.RootOptions.GetRootCertificates = func(*ConnectionInfo) (*RootCertificates, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
return &RootCertificates{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
Expand Down Expand Up @@ -381,6 +408,12 @@ func (o *Options) serverConfig() (*tls.Config, error) {
if o.MaxTLSVersion == 0 {
o.MaxTLSVersion = o.MaxVersion
}
// TODO(gtcooke94) RootCACerts is deprecated, eventually remove this block.
// This will ensure that users still explicitly setting RootCACerts will get
// the setting int the right place.
if o.RootOptions.RootCACerts != nil {
o.RootOptions.RootCertificates = o.RootOptions.RootCACerts
}
if o.RequireClientCert && o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
Expand Down Expand Up @@ -420,19 +453,19 @@ func (o *Options) serverConfig() (*tls.Config, error) {
}
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.ClientCAs = o.RootOptions.RootCACerts
case o.RootOptions.RootCertificates != nil:
config.ClientCAs = o.RootOptions.RootCertificates
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ServerHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
o.RootOptions.GetRootCertificates = func(*ConnectionInfo) (*RootCertificates, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
return &RootCertificates{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
Expand Down Expand Up @@ -477,12 +510,12 @@ func (o *Options) serverConfig() (*tls.Config, error) {
// advancedTLSCreds is the credentials required for authenticating a connection
// using TLS.
type advancedTLSCreds struct {
config *tls.Config
verifyFunc PostHandshakeVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
revocationOptions *RevocationOptions
verificationType VerificationType
config *tls.Config
verifyFunc PostHandshakeVerificationFunc
getRootCertificates func(params *ConnectionInfo) (*RootCertificates, error)
isClient bool
revocationOptions *RevocationOptions
verificationType VerificationType
}

func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
Expand Down Expand Up @@ -548,10 +581,10 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti

func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
return &advancedTLSCreds{
config: credinternal.CloneTLSConfig(c.config),
verifyFunc: c.verifyFunc,
getRootCAs: c.getRootCAs,
isClient: c.isClient,
config: credinternal.CloneTLSConfig(c.config),
verifyFunc: c.verifyFunc,
getRootCertificates: c.getRootCertificates,
isClient: c.isClient,
}
}

Expand Down Expand Up @@ -588,8 +621,8 @@ func buildVerifyFunc(c *advancedTLSCreds,
rootCAs = c.config.ClientCAs
}
// Reload root CA certs.
if rootCAs == nil && c.getRootCAs != nil {
results, err := c.getRootCAs(&GetRootCAsParams{
if rootCAs == nil && c.getRootCertificates != nil {
results, err := c.getRootCertificates(&ConnectionInfo{
RawConn: rawConn,
RawCerts: rawCerts,
})
Expand Down Expand Up @@ -661,12 +694,12 @@ func NewClientCreds(o *Options) (credentials.TransportCredentials, error) {
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
config: conf,
isClient: true,
getRootCertificates: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand All @@ -680,12 +713,12 @@ func NewServerCreds(o *Options) (credentials.TransportCredentials, error) {
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
config: conf,
isClient: false,
getRootCertificates: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand Down
30 changes: 15 additions & 15 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ func (s) TestEnd2End(t *testing.T) {
clientCert []tls.Certificate
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientGetRoot func(params *ConnectionInfo) (*RootCertificates, error)
clientVerifyFunc PostHandshakeVerificationFunc
clientVerificationType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverGetRoot func(params *ConnectionInfo) (*RootCertificates, error)
serverVerifyFunc PostHandshakeVerificationFunc
serverVerificationType VerificationType
}{
Expand Down Expand Up @@ -180,12 +180,12 @@ func (s) TestEnd2End(t *testing.T) {
},
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
serverGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0, 1:
return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
return &RootCertificates{TrustCerts: cs.ServerTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
return &RootCertificates{TrustCerts: cs.ServerTrust2}, nil
}
},
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand All @@ -208,12 +208,12 @@ func (s) TestEnd2End(t *testing.T) {
{
desc: "test the reloading feature for server identity callback and client trust callback",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
clientGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0, 1:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand Down Expand Up @@ -250,12 +250,12 @@ func (s) TestEnd2End(t *testing.T) {
{
desc: "test client custom verification",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
clientGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand Down Expand Up @@ -342,7 +342,7 @@ func (s) TestEnd2End(t *testing.T) {
GetIdentityCertificatesForServer: test.serverGetCert,
},
RootOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
RootCertificates: test.serverRoot,
GetRootCertificates: test.serverGetRoot,
},
RequireClientCert: true,
Expand Down Expand Up @@ -370,7 +370,7 @@ func (s) TestEnd2End(t *testing.T) {
},
AdditionalPeerVerification: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
RootCertificates: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
},
VerificationType: test.clientVerificationType,
Expand Down Expand Up @@ -787,7 +787,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
go s.Serve(lis)
clientOptions := &Options{
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
RootCertificates: test.clientRoot,
},
VerificationType: test.clientVerificationType,
}
Expand Down Expand Up @@ -927,7 +927,7 @@ func (s) TestTLSVersions(t *testing.T) {
go s.Serve(lis)
clientOptions := &Options{
RootOptions: RootCertificateOptions{
RootCACerts: cs.ClientTrust1,
RootCertificates: cs.ClientTrust1,
},
VerificationType: CertAndHostVerification,
MinTLSVersion: test.clientMinVersion,
Expand Down
Loading

0 comments on commit c76f686

Please sign in to comment.