Skip to content

Commit

Permalink
Machine ID: Provide existing identity when rejoining as part of renew…
Browse files Browse the repository at this point in the history
…als (#43732)

* Allow existing bot identity to be used for token based renewals

* Add test for register with auth client

* Fix tests

* Fix test

* Fix partial assignment to interface

* Improve loggging

* Fix spelling

* Fix missing import
  • Loading branch information
strideynet authored Jul 23, 2024
1 parent 55e4524 commit 73793e7
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 141 deletions.
52 changes: 52 additions & 0 deletions lib/auth/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ type RegisterParams struct {
// ID is identity ID
ID state.IdentityID
// AuthServers is a list of auth servers to dial
// Ignored if AuthClient is provided.
AuthServers []utils.NetAddr
// ProxyServer is a proxy server to dial
// Ignored if AuthClient is provided.
ProxyServer utils.NetAddr
// AdditionalPrincipals is a list of additional principals to dial
AdditionalPrincipals []string
Expand All @@ -86,12 +88,16 @@ type RegisterParams struct {
// PublicSSHKey is a server's public SSH key to sign
PublicSSHKey []byte
// CipherSuites is a list of cipher suites to use for TLS client connection
// Ignored if AuthClient is provided.
CipherSuites []uint16
// CAPins are the SKPI hashes of the CAs used to verify the Auth Server.
// Ignored if AuthClient is provided.
CAPins []string
// CAPath is the path to the CA file.
// Ignored if AuthClient is provided.
CAPath string
// GetHostCredentials is a client that can fetch host credentials.
// Ignored if AuthClient is provided.
GetHostCredentials HostCredentials
// Clock specifies the time provider. Will be used to override the time anchor
// for TLS certificate verification.
Expand All @@ -105,8 +111,10 @@ type RegisterParams struct {
// AzureParams is the parameters specific to the azure join method.
AzureParams AzureParams
// CircuitBreakerConfig defines how the circuit breaker should behave.
// Ignored if AuthClient is provided.
CircuitBreakerConfig breaker.Config
// FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested.
// Ignored if AuthClient is provided.
FIPS bool
// IDToken is a token retrieved from a workload identity provider for
// certain join types e.g GitHub, Google.
Expand All @@ -116,7 +124,13 @@ type RegisterParams struct {
// It should not be specified for non-bot registrations.
Expires *time.Time
// Insecure trusts the certificates from the Auth Server or Proxy during registration without verification.
// Ignored if AuthClient is provided.
Insecure bool
// AuthClient allows an existing client with a connection to the auth
// server to be used for the registration process. If specified, then the
// Register method will not attempt to dial, and many other parameters
// may be ignored.
AuthClient AuthJoinClient
}

func (r *RegisterParams) checkAndSetDefaults() error {
Expand All @@ -132,6 +146,11 @@ func (r *RegisterParams) checkAndSetDefaults() error {
}

func (r *RegisterParams) verifyAuthOrProxyAddress() error {
// If AuthClient is provided we do not need addresses to dial with.
if r.AuthClient != nil {
return nil
}

haveAuthServers := len(r.AuthServers) > 0
haveProxyServer := !r.ProxyServer.IsEmpty()

Expand Down Expand Up @@ -210,6 +229,19 @@ func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, e
}
}

// If an explicit AuthClient has been provided, we want to go straight to
// using that rather than trying both proxy and auth dialing.
if params.AuthClient != nil {
log.Info("Attempting registration with existing auth client.")
certs, err := registerThroughAuthClient(ctx, token, params, params.AuthClient)
if err != nil {
log.WithError(err).Error("Registration with existing auth client failed.")
return nil, trace.Wrap(err)
}
log.Info("Successfully registered with existing auth client.")
return certs, nil
}

type registerMethod struct {
call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error)
desc string
Expand Down Expand Up @@ -372,6 +404,26 @@ func registerThroughAuth(
}
defer client.Close()

certs, err = registerThroughAuthClient(ctx, token, params, client)
if err != nil {
return nil, trace.Wrap(err)
}
return certs, nil
}

// AuthJoinClient is a client that allows access to the Auth Servers join
// service and RegisterUsingToken method for the purposes of joining.
type AuthJoinClient interface {
joinServiceClient
RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
}

func registerThroughAuthClient(
ctx context.Context,
token string,
params RegisterParams,
client AuthJoinClient,
) (certs *proto.Certs, err error) {
switch params.JoinMethod {
// IAM and Azure methods use unique gRPC endpoints
case types.JoinMethodIAM:
Expand Down
48 changes: 48 additions & 0 deletions lib/auth/join/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/client/proto"
headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -279,3 +282,48 @@ func newBotToken(t *testing.T, tokenName, botName string, role types.SystemRole,
require.NoError(t, err, "could not create bot token")
return token
}

type authJoinClientMock struct {
AuthJoinClient
registerUsingToken func(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
}

func (a *authJoinClientMock) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
return a.registerUsingToken(ctx, req)
}

// TestRegisterWithAuthClient is a unit test to validate joining using a
// auth client supplied via RegisterParams
func TestRegisterWithAuthClient(t *testing.T) {
ctx := context.Background()
expectedCerts := &proto.Certs{
SSH: []byte("ssh-cert"),
}
expectedToken := "test-token"
expectedRole := types.RoleBot
called := false
m := &authJoinClientMock{
registerUsingToken: func(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
assert.Empty(t, cmp.Diff(
req,
&types.RegisterUsingTokenRequest{
Token: expectedToken,
Role: expectedRole,
},
))
called = true
return expectedCerts, nil
},
}

gotCerts, gotErr := Register(ctx, RegisterParams{
Token: expectedToken,
ID: state.IdentityID{
Role: expectedRole,
},
AuthClient: m,
})
require.NoError(t, gotErr)
assert.True(t, called)
assert.Equal(t, expectedCerts, gotCerts)
}
5 changes: 5 additions & 0 deletions lib/tbot/bot/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ type Destination interface {
// as YAML including the type header.
MarshalYAML() (interface{}, error)

// IsPersistent indicates whether this destination is persistent.
// This is true for most production destinations, but will be false for
// Nop or Memory destinations.
IsPersistent() bool

// Stringer so that Destination's implements fmt.Stringer which allows for
// better logging.
fmt.Stringer
Expand Down
6 changes: 0 additions & 6 deletions lib/tbot/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,6 @@ func (conf *OnboardingConfig) HasToken() bool {
return conf.TokenValue != ""
}

// RenewableJoinMethod indicates that certificate renewal should be used with
// this join method rather than rejoining each time.
func (conf *OnboardingConfig) RenewableJoinMethod() bool {
return conf.JoinMethod == types.JoinMethodToken
}

// SetToken stores the value for --token or auth_token in the config
//
// In the case of the token value pointing to a file, this allows us to
Expand Down
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_directory.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,7 @@ func (dm *DestinationDirectory) MarshalYAML() (interface{}, error) {
type raw DestinationDirectory
return withTypeHeader((*raw)(dm), DestinationDirectoryType)
}

func (dd *DestinationDirectory) IsPersistent() bool {
return true
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_kubernetes_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,7 @@ func (dks *DestinationKubernetesSecret) MarshalYAML() (interface{}, error) {
type raw DestinationKubernetesSecret
return withTypeHeader((*raw)(dks), DestinationKubernetesSecretType)
}

func (dks *DestinationKubernetesSecret) IsPersistent() bool {
return true
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,7 @@ func (dm *DestinationMemory) MarshalYAML() (interface{}, error) {
type raw DestinationMemory
return withTypeHeader((*raw)(dm), DestinationMemoryType)
}

func (dm *DestinationMemory) IsPersistent() bool {
return false
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_nop.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (dm *DestinationNop) MarshalYAML() (interface{}, error) {
type raw DestinationNop
return withTypeHeader((*raw)(dm), DestinationNopType)
}

func (dm *DestinationNop) IsPersistent() bool {
return false
}
31 changes: 23 additions & 8 deletions lib/tbot/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ type Identity struct {
// ClusterName is a name of host's cluster determined from the
// x509 certificate.
ClusterName string
// TLSIdentity is the parsed TLS identity based on the X509 certificate.
TLSIdentity *tlsca.Identity
}

// LoadIdentityParams contains parameters beyond proto.Certs needed to load a
Expand Down Expand Up @@ -171,7 +173,7 @@ func ReadIdentityFromStore(params *LoadIdentityParams, certs *proto.Certs) (*Ide
return nil, trace.Wrap(err, "parsing ssh identity")
}

clusterName, x509Cert, tlsCert, tlsCAPool, err := ParseTLSIdentity(
clusterName, tlsIdent, x509Cert, tlsCert, tlsCAPool, err := ParseTLSIdentity(
params.PrivateKeyBytes, certs.TLS, certs.TLSCACerts,
)
if err != nil {
Expand All @@ -196,41 +198,54 @@ func ReadIdentityFromStore(params *LoadIdentityParams, certs *proto.Certs) (*Ide
X509Cert: x509Cert,
TLSCert: tlsCert,
TLSCAPool: tlsCAPool,
TLSIdentity: tlsIdent,
}, nil
}

// ParseTLSIdentity reads TLS identity from key pair
func ParseTLSIdentity(
keyBytes []byte, certBytes []byte, caCertsBytes [][]byte,
) (clusterName string, x509Cert *x509.Certificate, tlsCert *tls.Certificate, certPool *x509.CertPool, err error) {
) (
clusterName string,
tlsIdentity *tlsca.Identity,
x509Cert *x509.Certificate,
tlsCert *tls.Certificate,
certPool *x509.CertPool,
err error,
) {
x509Cert, err = tlsca.ParseCertificatePEM(certBytes)
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parsing certificate")
return "", nil, nil, nil, nil, trace.Wrap(err, "parsing certificate")
}

if len(x509Cert.Issuer.Organization) == 0 {
return "", nil, nil, nil, trace.BadParameter("certificate missing CA organization")
return "", nil, nil, nil, nil, trace.BadParameter("certificate missing CA organization")
}
clusterName = x509Cert.Issuer.Organization[0]
if clusterName == "" {
return "", nil, nil, nil, trace.BadParameter("certificate missing cluster name")
return "", nil, nil, nil, nil, trace.BadParameter("certificate missing cluster name")
}

certPool = x509.NewCertPool()
for j := range caCertsBytes {
parsedCert, err := tlsca.ParseCertificatePEM(caCertsBytes[j])
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parsing CA certificate")
return "", nil, nil, nil, nil, trace.Wrap(err, "parsing CA certificate")
}
certPool.AddCert(parsedCert)
}

cert, err := keys.X509KeyPair(certBytes, keyBytes)
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parse private key")
return "", nil, nil, nil, nil, trace.Wrap(err, "parse private key")
}

return clusterName, x509Cert, &cert, certPool, nil
tlsIdent, err := tlsca.FromSubject(x509Cert.Subject, x509Cert.NotAfter)
if err != nil {
return "", nil, nil, nil, nil, trace.Wrap(err, "parse tls identity")
}

return clusterName, tlsIdent, x509Cert, &cert, certPool, nil
}

// parseSSHIdentity reads identity from initialized keypair
Expand Down
Loading

0 comments on commit 73793e7

Please sign in to comment.