Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Machine ID: Provide existing identity when rejoining as part of renewals #43732

Merged
merged 9 commits into from
Jul 23, 2024
52 changes: 52 additions & 0 deletions lib/auth/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ type RegisterParams struct {
// ID is identity ID
ID state.IdentityID
// AuthServers is a list of auth servers to dial
// Ignored if AuthClient is provided.
AuthServers []utils.NetAddr
// ProxyServer is a proxy server to dial
// Ignored if AuthClient is provided.
ProxyServer utils.NetAddr
// AdditionalPrincipals is a list of additional principals to dial
AdditionalPrincipals []string
Expand All @@ -86,12 +88,16 @@ type RegisterParams struct {
// PublicSSHKey is a server's public SSH key to sign
PublicSSHKey []byte
// CipherSuites is a list of cipher suites to use for TLS client connection
// Ignored if AuthClient is provided.
CipherSuites []uint16
// CAPins are the SKPI hashes of the CAs used to verify the Auth Server.
// Ignored if AuthClient is provided.
CAPins []string
// CAPath is the path to the CA file.
// Ignored if AuthClient is provided.
CAPath string
// GetHostCredentials is a client that can fetch host credentials.
// Ignored if AuthClient is provided.
GetHostCredentials HostCredentials
// Clock specifies the time provider. Will be used to override the time anchor
// for TLS certificate verification.
Expand All @@ -105,8 +111,10 @@ type RegisterParams struct {
// AzureParams is the parameters specific to the azure join method.
AzureParams AzureParams
// CircuitBreakerConfig defines how the circuit breaker should behave.
// Ignored if AuthClient is provided.
CircuitBreakerConfig breaker.Config
// FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested.
// Ignored if AuthClient is provided.
FIPS bool
// IDToken is a token retrieved from a workload identity provider for
// certain join types e.g GitHub, Google.
Expand All @@ -116,7 +124,13 @@ type RegisterParams struct {
// It should not be specified for non-bot registrations.
Expires *time.Time
// Insecure trusts the certificates from the Auth Server or Proxy during registration without verification.
// Ignored if AuthClient is provided.
Insecure bool
// AuthClient allows an existing client with a connection to the auth
// server to be used for the registration process. If specified, then the
// Register method will not attempt to dial, and many other parameters
// may be ignored.
AuthClient AuthJoinClient
}

func (r *RegisterParams) checkAndSetDefaults() error {
Expand All @@ -132,6 +146,11 @@ func (r *RegisterParams) checkAndSetDefaults() error {
}

func (r *RegisterParams) verifyAuthOrProxyAddress() error {
// If AuthClient is provided we do not need addresses to dial with.
if r.AuthClient != nil {
return nil
}

haveAuthServers := len(r.AuthServers) > 0
haveProxyServer := !r.ProxyServer.IsEmpty()

Expand Down Expand Up @@ -210,6 +229,19 @@ func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, e
}
}

// If an explicit AuthClient has been provided, we want to go straight to
// using that rather than trying both proxy and auth dialing.
if params.AuthClient != nil {
log.Info("Attempting registration with existing auth client.")
certs, err := registerThroughAuthClient(ctx, token, params, params.AuthClient)
if err != nil {
log.WithError(err).Error("Registration with existing auth client failed.")
return nil, trace.Wrap(err)
}
log.Info("Successfully registered with existing auth client.")
return certs, nil
}

type registerMethod struct {
call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error)
desc string
Expand Down Expand Up @@ -372,6 +404,26 @@ func registerThroughAuth(
}
defer client.Close()

certs, err = registerThroughAuthClient(ctx, token, params, client)
if err != nil {
return nil, trace.Wrap(err)
}
return certs, nil
}

// AuthJoinClient is a client that allows access to the Auth Servers join
// service and RegisterUsingToken method for the purposes of joining.
type AuthJoinClient interface {
joinServiceClient
RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
}

func registerThroughAuthClient(
ctx context.Context,
token string,
params RegisterParams,
client AuthJoinClient,
) (certs *proto.Certs, err error) {
switch params.JoinMethod {
// IAM and Azure methods use unique gRPC endpoints
case types.JoinMethodIAM:
Expand Down
48 changes: 48 additions & 0 deletions lib/auth/join/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/client/proto"
headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -279,3 +282,48 @@ func newBotToken(t *testing.T, tokenName, botName string, role types.SystemRole,
require.NoError(t, err, "could not create bot token")
return token
}

type authJoinClientMock struct {
AuthJoinClient
registerUsingToken func(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
}

func (a *authJoinClientMock) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
return a.registerUsingToken(ctx, req)
}

// TestRegisterWithAuthClient is a unit test to validate joining using a
// auth client supplied via RegisterParams
func TestRegisterWithAuthClient(t *testing.T) {
ctx := context.Background()
expectedCerts := &proto.Certs{
SSH: []byte("ssh-cert"),
}
expectedToken := "test-token"
expectedRole := types.RoleBot
called := false
m := &authJoinClientMock{
registerUsingToken: func(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
assert.Empty(t, cmp.Diff(
req,
&types.RegisterUsingTokenRequest{
Token: expectedToken,
Role: expectedRole,
},
))
called = true
return expectedCerts, nil
},
}

gotCerts, gotErr := Register(ctx, RegisterParams{
Token: expectedToken,
ID: state.IdentityID{
Role: expectedRole,
},
AuthClient: m,
})
require.NoError(t, gotErr)
assert.True(t, called)
assert.Equal(t, expectedCerts, gotCerts)
}
5 changes: 5 additions & 0 deletions lib/tbot/bot/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ type Destination interface {
// as YAML including the type header.
MarshalYAML() (interface{}, error)

// IsPersistent indicates whether this destination is persistent.
// This is true for most production destinations, but will be false for
// Nop or Memory destinations.
IsPersistent() bool

// Stringer so that Destination's implements fmt.Stringer which allows for
// better logging.
fmt.Stringer
Expand Down
6 changes: 0 additions & 6 deletions lib/tbot/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,6 @@ func (conf *OnboardingConfig) HasToken() bool {
return conf.TokenValue != ""
}

// RenewableJoinMethod indicates that certificate renewal should be used with
// this join method rather than rejoining each time.
func (conf *OnboardingConfig) RenewableJoinMethod() bool {
return conf.JoinMethod == types.JoinMethodToken
}

// SetToken stores the value for --token or auth_token in the config
//
// In the case of the token value pointing to a file, this allows us to
Expand Down
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_directory.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,7 @@ func (dm *DestinationDirectory) MarshalYAML() (interface{}, error) {
type raw DestinationDirectory
return withTypeHeader((*raw)(dm), DestinationDirectoryType)
}

func (dd *DestinationDirectory) IsPersistent() bool {
return true
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_kubernetes_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,7 @@ func (dks *DestinationKubernetesSecret) MarshalYAML() (interface{}, error) {
type raw DestinationKubernetesSecret
return withTypeHeader((*raw)(dks), DestinationKubernetesSecretType)
}

func (dks *DestinationKubernetesSecret) IsPersistent() bool {
return true
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,7 @@ func (dm *DestinationMemory) MarshalYAML() (interface{}, error) {
type raw DestinationMemory
return withTypeHeader((*raw)(dm), DestinationMemoryType)
}

func (dm *DestinationMemory) IsPersistent() bool {
return false
}
4 changes: 4 additions & 0 deletions lib/tbot/config/destination_nop.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ func (dm *DestinationNop) MarshalYAML() (interface{}, error) {
type raw DestinationNop
return withTypeHeader((*raw)(dm), DestinationNopType)
}

func (dm *DestinationNop) IsPersistent() bool {
return false
}
31 changes: 23 additions & 8 deletions lib/tbot/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ type Identity struct {
// ClusterName is a name of host's cluster determined from the
// x509 certificate.
ClusterName string
// TLSIdentity is the parsed TLS identity based on the X509 certificate.
TLSIdentity *tlsca.Identity
}

// LoadIdentityParams contains parameters beyond proto.Certs needed to load a
Expand Down Expand Up @@ -171,7 +173,7 @@ func ReadIdentityFromStore(params *LoadIdentityParams, certs *proto.Certs) (*Ide
return nil, trace.Wrap(err, "parsing ssh identity")
}

clusterName, x509Cert, tlsCert, tlsCAPool, err := ParseTLSIdentity(
clusterName, tlsIdent, x509Cert, tlsCert, tlsCAPool, err := ParseTLSIdentity(
params.PrivateKeyBytes, certs.TLS, certs.TLSCACerts,
)
if err != nil {
Expand All @@ -196,41 +198,54 @@ func ReadIdentityFromStore(params *LoadIdentityParams, certs *proto.Certs) (*Ide
X509Cert: x509Cert,
TLSCert: tlsCert,
TLSCAPool: tlsCAPool,
TLSIdentity: tlsIdent,
}, nil
}

// ParseTLSIdentity reads TLS identity from key pair
func ParseTLSIdentity(
keyBytes []byte, certBytes []byte, caCertsBytes [][]byte,
) (clusterName string, x509Cert *x509.Certificate, tlsCert *tls.Certificate, certPool *x509.CertPool, err error) {
) (
clusterName string,
tlsIdentity *tlsca.Identity,
x509Cert *x509.Certificate,
tlsCert *tls.Certificate,
certPool *x509.CertPool,
err error,
) {
x509Cert, err = tlsca.ParseCertificatePEM(certBytes)
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parsing certificate")
return "", nil, nil, nil, nil, trace.Wrap(err, "parsing certificate")
}

if len(x509Cert.Issuer.Organization) == 0 {
return "", nil, nil, nil, trace.BadParameter("certificate missing CA organization")
return "", nil, nil, nil, nil, trace.BadParameter("certificate missing CA organization")
}
clusterName = x509Cert.Issuer.Organization[0]
if clusterName == "" {
return "", nil, nil, nil, trace.BadParameter("certificate missing cluster name")
return "", nil, nil, nil, nil, trace.BadParameter("certificate missing cluster name")
}

certPool = x509.NewCertPool()
for j := range caCertsBytes {
parsedCert, err := tlsca.ParseCertificatePEM(caCertsBytes[j])
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parsing CA certificate")
return "", nil, nil, nil, nil, trace.Wrap(err, "parsing CA certificate")
}
certPool.AddCert(parsedCert)
}

cert, err := keys.X509KeyPair(certBytes, keyBytes)
if err != nil {
return "", nil, nil, nil, trace.Wrap(err, "parse private key")
return "", nil, nil, nil, nil, trace.Wrap(err, "parse private key")
}

return clusterName, x509Cert, &cert, certPool, nil
tlsIdent, err := tlsca.FromSubject(x509Cert.Subject, x509Cert.NotAfter)
if err != nil {
return "", nil, nil, nil, nil, trace.Wrap(err, "parse tls identity")
}

return clusterName, tlsIdent, x509Cert, &cert, certPool, nil
}

// parseSSHIdentity reads identity from initialized keypair
Expand Down
Loading
Loading