From f633b8d894238631fc67ce1e34b6662f1c64ac73 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 25 Sep 2024 10:40:44 -0700 Subject: [PATCH 1/3] chore: update IAM join method to use aws-sdk-go-v2 --- integration/ec2_test.go | 43 +++---- integration/teleterm_test.go | 6 +- lib/auth/join/iam/endpoints.go | 23 ++-- lib/auth/join/iam/iam.go | 210 +++++++++++++++++++-------------- lib/auth/join/iam/iam_test.go | 149 +++++++++++++++++++++++ lib/auth/join/join.go | 15 ++- lib/auth/join_iam.go | 24 ++-- lib/auth/register.go | 1 - 8 files changed, 324 insertions(+), 147 deletions(-) create mode 100644 lib/auth/join/iam/iam_test.go diff --git a/integration/ec2_test.go b/integration/ec2_test.go index fb85d25eb91e5..d75ce0014d9c6 100644 --- a/integration/ec2_test.go +++ b/integration/ec2_test.go @@ -28,8 +28,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" @@ -61,8 +60,9 @@ func newSilentLogger() utils.Logger { return logger } -func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config { +func newNodeConfig(t *testing.T, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config { config := servicecfg.MakeDefaultConfig() + config.Version = defaults.TeleportConfigVersionV3 config.SetToken(tokenName) config.JoinMethod = joinMethod config.SSH.Enabled = true @@ -70,7 +70,6 @@ func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinM config.Auth.Enabled = false config.Proxy.Enabled = false config.DataDir = t.TempDir() - config.SetAuthServerAddress(authAddr) config.Log = newSilentLogger() config.CircuitBreakerConfig = breaker.NoopBreakerConfig() config.InstanceMetadataClient = cloudimds.NewDisabledIMDSClient() @@ -79,7 +78,7 @@ func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinM func newProxyConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config { config := servicecfg.MakeDefaultConfig() - config.Version = defaults.TeleportConfigVersionV2 + config.Version = defaults.TeleportConfigVersionV3 config.SetToken(tokenName) config.JoinMethod = joinMethod config.SSH.Enabled = false @@ -109,6 +108,7 @@ func newAuthConfig(t *testing.T, clock clockwork.Clock) *servicecfg.Config { } config := servicecfg.MakeDefaultConfig() + config.Version = defaults.TeleportConfigVersionV3 config.DataDir = t.TempDir() config.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &config.FileDescriptors) config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ @@ -140,13 +140,11 @@ func getIID(ctx context.Context, t *testing.T) imds.InstanceIdentityDocument { return output.InstanceIdentityDocument } -func getCallerIdentity(t *testing.T) *sts.GetCallerIdentityOutput { - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - }) +func getCallerIdentity(ctx context.Context, t *testing.T) *sts.GetCallerIdentityOutput { + cfg, err := config.LoadDefaultConfig(ctx) require.NoError(t, err) - stsService := sts.New(sess) - output, err := stsService.GetCallerIdentity(nil /*input*/) + stsClient := sts.NewFromConfig(cfg) + output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) require.NoError(t, err) return output } @@ -201,7 +199,8 @@ func TestEC2NodeJoin(t *testing.T) { require.Empty(t, nodes) // create and start the node - nodeConfig := newNodeConfig(t, authConfig.Auth.ListenAddr, tokenName, types.JoinMethodEC2) + nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodEC2) + nodeConfig.SetAuthServerAddress(authConfig.Auth.ListenAddr) nodeSvc, err := service.NewTeleport(nodeConfig) require.NoError(t, err) require.NoError(t, nodeSvc.Start()) @@ -214,7 +213,7 @@ func TestEC2NodeJoin(t *testing.T) { require.Eventually(t, func() bool { nodes, _ := authServer.GetNodes(ctx, apidefaults.Namespace) return len(nodes) > 0 - }, time.Minute, time.Second, "waiting for node to join cluster") + }, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster") } // TestIAMNodeJoin is an integration test which asserts that the IAM method for @@ -225,6 +224,7 @@ func TestIAMNodeJoin(t *testing.T) { if os.Getenv("TELEPORT_TEST_EC2") == "" { t.Skipf("Skipping TestIAMNodeJoin because TELEPORT_TEST_EC2 is not set") } + ctx := context.Background() // create and start the auth server authConfig := newAuthConfig(t, nil /*clock*/) @@ -236,7 +236,7 @@ func TestIAMNodeJoin(t *testing.T) { authServer := authSvc.GetAuthServer() // fetch the caller identity to find the AWS account and create the token - id := getCallerIdentity(t) + id := getCallerIdentity(ctx, t) tokenName := "test_token" token, err := types.NewProvisionTokenFromSpec( @@ -253,7 +253,7 @@ func TestIAMNodeJoin(t *testing.T) { }) require.NoError(t, err) - err = authServer.UpsertToken(context.Background(), token) + err = authServer.UpsertToken(ctx, token) require.NoError(t, err) // sanity check there are no proxies to start with @@ -274,20 +274,21 @@ func TestIAMNodeJoin(t *testing.T) { proxies, err := authServer.GetProxies() assert.NoError(t, err) assert.NotEmpty(t, proxies) - }, time.Minute, time.Second, "waiting for proxy to join cluster") + }, 10*time.Second, 50*time.Millisecond, "waiting for proxy to join cluster") // InsecureDevMode needed for node to trust proxy wasInsecureDevMode := lib.IsInsecureDevMode() t.Cleanup(func() { lib.SetInsecureDevMode(wasInsecureDevMode) }) lib.SetInsecureDevMode(true) // sanity check there are no nodes to start with - nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace) + nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace) require.NoError(t, err) require.Empty(t, nodes) - // create and start a node, with use the IAM method to join in IoT mode by + // create and start a node, will use the IAM method to join in IoT mode by // connecting to the proxy - nodeConfig := newNodeConfig(t, proxyConfig.Proxy.WebAddr, tokenName, types.JoinMethodIAM) + nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodIAM) + nodeConfig.ProxyServer = proxyConfig.Proxy.WebAddr nodeSvc, err := service.NewTeleport(nodeConfig) require.NoError(t, err) require.NoError(t, nodeSvc.Start()) @@ -295,10 +296,10 @@ func TestIAMNodeJoin(t *testing.T) { // the node should eventually join the cluster and heartbeat require.EventuallyWithT(t, func(t *assert.CollectT) { - nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace) + nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace) assert.NoError(t, err) assert.NotEmpty(t, nodes) - }, time.Minute, time.Second, "waiting for node to join cluster") + }, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster") } type mockIMDSClient struct { diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index bdd3cacb6ad6a..17ec75121f614 100644 --- a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -957,7 +957,8 @@ func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.Database }() // Start the new node. - nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken) + nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken) + nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr) nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data") nodeConfig.Log = libutils.NewLoggerForTests() nodeSvc, err := service.NewTeleport(nodeConfig) @@ -1031,7 +1032,8 @@ func testDeleteConnectMyComputerNode(t *testing.T, pack *dbhelpers.DatabasePack) require.NoError(t, err) // Start the new node. - nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken) + nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken) + nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr) nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data") nodeConfig.Log = libutils.NewLoggerForTests() nodeSvc, err := service.NewTeleport(nodeConfig) diff --git a/lib/auth/join/iam/endpoints.go b/lib/auth/join/iam/endpoints.go index 3434ce30228b1..bcaad06b0f1e9 100644 --- a/lib/auth/join/iam/endpoints.go +++ b/lib/auth/join/iam/endpoints.go @@ -19,10 +19,13 @@ package iam import "sync" var ( - // ValidSTSEndpoints holds a sorted list of all known valid public endpoints for - // the AWS STS service. You can generate this list by running - // $ go run github.com/nklaassen/sts-endpoints@latest --go-list - // Update aws-sdk-go in that package to learn about new endpoints. + // ValidSTSEndpoints returns a sorted list of all known valid public endpoints for + // the AWS STS service. + // + // TODO(nklaassen): find a better way to validate STS endpoints or generate + // this list and get notified when it needs to be updated. The original + // solution was https://github.com/nklaassen/sts-endpoints which is based on + // aws-sdk-go v1 which no longer gets updates for new regions. ValidSTSEndpoints = sync.OnceValue(func() []string { return []string{ "sts-fips.us-east-1.amazonaws.com", @@ -69,15 +72,7 @@ var ( } }) - GlobalSTSEndpoints = sync.OnceValue(func() []string { - return []string{ - "sts.amazonaws.com", - // This is not a real endpoint, but the SDK will select it if - // AWS_USE_FIPS_ENDPOINT is set and a region is not. - "sts-fips.aws-global.amazonaws.com", - } - }) - + // FIPSSTSEndpoints returns the set of known valid FIPS AWS STS endpoints. FIPSSTSEndpoints = sync.OnceValue(func() []string { return []string{ "sts-fips.us-east-1.amazonaws.com", @@ -89,3 +84,5 @@ var ( } }) ) + +const fipsSTSEndpointUSEast1 = "sts-fips.us-east-1.amazonaws.com" diff --git a/lib/auth/join/iam/iam.go b/lib/auth/join/iam/iam.go index 241d4ec7800c3..54413e372a960 100644 --- a/lib/auth/join/iam/iam.go +++ b/lib/auth/join/iam/iam.go @@ -19,14 +19,16 @@ package iam import ( "bytes" "context" + "errors" + "io" "log/slog" + "net/http" "slices" - "strings" - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithyendpoints "github.com/aws/smithy-go/endpoints" "github.com/gravitational/trace" cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws" @@ -38,39 +40,85 @@ const ( challengeHeaderKey = "x-teleport-challenge" ) -type stsIdentityRequestConfig struct { - regionalEndpointOption endpoints.STSRegionalEndpoint - fipsEndpointOption endpoints.FIPSEndpointState +type stsIdentityRequestOptions struct { + fipsEndpointOption aws.FIPSEndpointState + imdsClient imdsClient } -type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig) +type stsIdentityRequestOption func(cfg *stsIdentityRequestOptions) -func WithRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption { - return func(cfg *stsIdentityRequestConfig) { - if useRegionalEndpoint { - cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint +// WithFIPSEndpoint is a functional option to use a FIPS STS endpoint. In non-US +// regions, this will use the us-east-1 FIPS endpoint. +func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption { + return func(opts *stsIdentityRequestOptions) { + if useFIPS { + opts.fipsEndpointOption = aws.FIPSEndpointStateEnabled } else { - cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint + opts.fipsEndpointOption = aws.FIPSEndpointStateUnset } } } -func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption { - return func(cfg *stsIdentityRequestConfig) { - if useFIPS { - cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled - } else { - cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled +// WithIMDSClient is a functional option to use a custom IMDS client. +func WithIMDSClient(clt imdsClient) stsIdentityRequestOption { + return func(opts *stsIdentityRequestOptions) { + opts.imdsClient = clt + } +} + +type imdsClient interface { + // IsAvailable should return true if the IMDSv2 is available, and false + // otherwise. + IsAvailable(context.Context) bool + // GetRegion should return the local region as reported by the IMDSv2. + GetRegion(context.Context) (string, error) +} + +// CreateSignedSTSIdentityRequest is called on the client side and returns an +// sts:GetCallerIdentity request signed with the local AWS credentials +func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) { + var options stsIdentityRequestOptions + for _, opt := range opts { + opt(&options) + } + + awsConfig, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, trace.Wrap(err, "loading default AWS config") + } + + var signedRequest bytes.Buffer + stsClient := sts.NewFromConfig(awsConfig, + sts.WithEndpointResolverV2(newCustomResolver(challenge, &options)), + func(stsOpts *sts.Options) { + stsOpts.EndpointOptions.UseFIPSEndpoint = options.fipsEndpointOption + // Use a fake HTTP client to record the request. + stsOpts.HTTPClient = &httpRequestRecorder{&signedRequest} + // httpRequestRecorder intentionally records the request and returns + // an error, don't retry. + stsOpts.RetryMaxAttempts = 1 + }) + + if _, err = stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}); !errors.Is(err, errRequestRecorded) { + if err == nil { + return nil, trace.Errorf("expected to get errRequestedRecorded, got (this is a bug)") } + return nil, trace.Wrap(err, "building signed sts:GetCallerIdentity request") } + + return signedRequest.Bytes(), nil } // getEC2LocalRegion returns the AWS region this EC2 instance is running in, or // a NotFound error if the EC2 IMDS is unavailable. -func getEC2LocalRegion(ctx context.Context) (string, error) { - imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx) - if err != nil { - return "", trace.Wrap(err) +func getEC2LocalRegion(ctx context.Context, opts *stsIdentityRequestOptions) (string, error) { + imdsClient := opts.imdsClient + if imdsClient == nil { + var err error + imdsClient, err = cloudaws.NewInstanceMetadataClient(ctx) + if err != nil { + return "", trace.Wrap(err) + } } if !imdsClient.IsAvailable(ctx) { @@ -81,81 +129,63 @@ func getEC2LocalRegion(ctx context.Context) (string, error) { return region, trace.Wrap(err) } -func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) { - awsConfig := awssdk.Config{ - UseFIPSEndpoint: cfg.fipsEndpointOption, - STSRegionalEndpoint: cfg.regionalEndpointOption, - } - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: awsConfig, - }) - if err != nil { - return nil, trace.Wrap(err) +type customResolver struct { + defaultResolver sts.EndpointResolverV2 + challenge string + opts *stsIdentityRequestOptions +} + +func newCustomResolver(challenge string, opts *stsIdentityRequestOptions) *customResolver { + return &customResolver{ + defaultResolver: sts.NewDefaultEndpointResolverV2(), + challenge: challenge, + opts: opts, } +} - stsClient := sts.New(sess) - - if slices.Contains(GlobalSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { - // If the caller wants to use the regional endpoint but it was not resolved - // from the environment, attempt to find the region from the EC2 IMDS - if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint { - region, err := getEC2LocalRegion(ctx) - if err != nil { - return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS") - } - stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region)) - } else { - const msg = "Attempting to use the global STS endpoint for the IAM join method. " + - "This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " + - "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2." - slog.InfoContext(ctx, msg) +// ResolveEndpoint implements [sts.EndpointResolverV2]. +func (r customResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (smithyendpoints.Endpoint, error) { + if aws.ToString(params.Region) == "" { + // If we don't have a region from the environment here this will fail to + // resolve. We can try to get the local region from IMDSv2 if running on EC2. + region, err := getEC2LocalRegion(ctx, r.opts) + switch { + case trace.IsNotFound(err): + params.Region = aws.String("aws-global") + params.UseGlobalEndpoint = aws.Bool(true) + case err != nil: + return smithyendpoints.Endpoint{}, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS") + default: + params.Region = aws.String(region) } } - - if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled && - !slices.Contains(ValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { - // The AWS SDK will generate invalid endpoints when attempting to - // resolve the FIPS endpoint for a region that does not have one. - // In this case, try to use the FIPS endpoint in us-east-1. This should - // work for all regions in the standard partition. In GovCloud, we should - // not hit this because all regional endpoints support FIPS. In China or - // other partitions, this will fail, and FIPS mode will not be supported. - const msg = "AWS SDK resolved invalid FIPS STS endpoint. " + - "Attempting to use the FIPS STS endpoint for us-east-1." - slog.InfoContext(ctx, msg, "resolved", stsClient.Endpoint) - stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1")) + endpoint, err := r.defaultResolver.ResolveEndpoint(ctx, params) + if err != nil { + return smithyendpoints.Endpoint{}, trace.Wrap(err) } - - return stsClient, nil + if aws.ToBool(params.UseFIPS) && !slices.Contains(FIPSSTSEndpoints(), endpoint.URI.Host) { + // The default resolver will return non-existent endpoints if FIPS was + // requested in regions outside the USA. Use the FIPS endpoint in + // us-east-1 instead. + slog.InfoContext(ctx, "The AWS SDK resolved an invalid FIPS STS endpoint, attempting to use the us-east-1 FIPS STS endpoint instead. This will fail in non-default AWS partitions.", "resolved", endpoint.URI.Host) + endpoint.URI.Host = fipsSTSEndpointUSEast1 + } + // Add challenge as a header to be signed. + endpoint.Headers.Add(challengeHeaderKey, r.challenge) + // Request JSON for simpler parsing. + endpoint.Headers.Add("Accept", "application/json") + return endpoint, nil } -// CreateSignedSTSIdentityRequest is called on the client side and returns an -// sts:GetCallerIdentity request signed with the local AWS credentials -func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) { - cfg := &stsIdentityRequestConfig{} - for _, opt := range opts { - opt(cfg) - } +type httpRequestRecorder struct { + w io.Writer +} - stsClient, err := newSTSClient(ctx, cfg) - if err != nil { - return nil, trace.Wrap(err) - } +var errRequestRecorded = errors.New("request recorded") - req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - // set challenge header - req.HTTPRequest.Header.Set(challengeHeaderKey, challenge) - // request json for simpler parsing - req.HTTPRequest.Header.Set("Accept", "application/json") - // sign the request, including headers - if err := req.Sign(); err != nil { - return nil, trace.Wrap(err) - } - // write the signed HTTP request to a buffer - var signedRequest bytes.Buffer - if err := req.HTTPRequest.Write(&signedRequest); err != nil { +func (r *httpRequestRecorder) Do(req *http.Request) (*http.Response, error) { + if err := req.Write(r.w); err != nil { return nil, trace.Wrap(err) } - return signedRequest.Bytes(), nil + return nil, errRequestRecorded } diff --git a/lib/auth/join/iam/iam_test.go b/lib/auth/join/iam/iam_test.go new file mode 100644 index 0000000000000..f27735c9baaa9 --- /dev/null +++ b/lib/auth/join/iam/iam_test.go @@ -0,0 +1,149 @@ +package iam_test + +import ( + "bufio" + "bytes" + "context" + "net/http" + "os" + "testing" + + "github.com/gravitational/teleport/lib/auth/join/iam" + "github.com/gravitational/teleport/lib/utils/aws" + "github.com/stretchr/testify/require" +) + +func TestCreateSignedSTSIdentityRequest(t *testing.T) { + ctx := context.Background() + + t.Setenv("AWS_ACCESS_KEY_ID", "FAKE_KEY_ID") + t.Setenv("AWS_SECRET_ACCESS_KEY", "FAKE_KEY") + t.Setenv("AWS_SESSION_TOKEN", "FAKE_SESSION_TOKEN") + + const challenge = "asdf12345" + + for desc, tc := range map[string]struct { + envRegion string + imdsRegion string + fips bool + expectEndpoint string + expectError string + }{ + "no region": { + expectEndpoint: "sts.amazonaws.com", + }, + "no region fips": { + fips: true, + expectEndpoint: "sts-fips.us-east-1.amazonaws.com", + }, + "us-west-2": { + envRegion: "us-west-2", + expectEndpoint: "sts.us-west-2.amazonaws.com", + }, + "us-west-2 with region from imdsv2": { + imdsRegion: "us-west-2", + expectEndpoint: "sts.us-west-2.amazonaws.com", + }, + "us-west-2 fips": { + envRegion: "us-west-2", + fips: true, + expectEndpoint: "sts-fips.us-west-2.amazonaws.com", + }, + "us-west-2 fips with region from imdsv2": { + imdsRegion: "us-west-2", + fips: true, + expectEndpoint: "sts-fips.us-west-2.amazonaws.com", + }, + "eu-central-1": { + envRegion: "eu-central-1", + expectEndpoint: "sts.eu-central-1.amazonaws.com", + }, + "eu-central-1 fips": { + envRegion: "eu-central-1", + fips: true, + // All non-US regions have no FIPS endpoint and use the FIPS + // endpoint in us-east-1. + expectEndpoint: "sts-fips.us-east-1.amazonaws.com", + }, + "ap-southeast-1": { + envRegion: "ap-southeast-1", + expectEndpoint: "sts.ap-southeast-1.amazonaws.com", + }, + "ap-southeast-1 fips": { + envRegion: "ap-southeast-1", + fips: true, + // All non-US regions have no FIPS endpoint and try to use the FIPS + // endpoint in us-east-1, but this will fail if the AWS credentials + // were issued by the AWS China partition because they will not be + // recognized by STS in the default partition. It will fail when + // Auth sends the request to AWS, but this unit test only exercizes + // the client-side request generation. + expectEndpoint: "sts-fips.us-east-1.amazonaws.com", + }, + "govcloud": { + envRegion: "us-gov-east-1", + expectEndpoint: "sts.us-gov-east-1.amazonaws.com", + }, + "govcloud fips": { + envRegion: "us-gov-east-1", + fips: true, + // All govcloud endpoints are FIPS. + expectEndpoint: "sts.us-gov-east-1.amazonaws.com", + }, + } { + t.Run(desc, func(t *testing.T) { + if len(tc.envRegion) > 0 { + t.Setenv("AWS_REGION", tc.envRegion) + } else { + // There's no t.Unsetenv so do this manually. + prev := os.Getenv("AWS_REGION") + os.Unsetenv("AWS_REGION") + t.Cleanup(func() { os.Setenv("AWS_REGION", prev) }) + } + + imdsClient := &fakeIMDSClient{} + if tc.imdsRegion != "" { + imdsClient = &fakeIMDSClient{ + available: true, + region: tc.imdsRegion, + } + } + + // Create the signed sts:GetCallerIdentity request, which is a full + // HTTP request with a body serialized into a byte slice. + req, err := iam.CreateSignedSTSIdentityRequest(ctx, challenge, + iam.WithFIPSEndpoint(tc.fips), + iam.WithIMDSClient(imdsClient)) + if tc.expectError != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.expectError) + return + } + require.NoError(t, err) + + // Parse the serialized HTTP request to check the endpoint and + // parameters were correctly included by the AWS SDK. + httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(req))) + require.NoError(t, err) + require.Equal(t, tc.expectEndpoint, httpReq.Host) + authHeader := httpReq.Header.Get(aws.AuthorizationHeader) + sigV4, err := aws.ParseSigV4(authHeader) + require.NoError(t, err) + require.Contains(t, sigV4.SignedHeaders, "x-teleport-challenge") + require.Equal(t, challenge, httpReq.Header.Get("x-teleport-challenge")) + }) + } +} + +type fakeIMDSClient struct { + available bool + region string +} + +func (c *fakeIMDSClient) IsAvailable(_ context.Context) bool { + return c.available +} + +func (c *fakeIMDSClient) GetRegion(_ context.Context) (string, error) { + return c.region, nil +} diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index 5ca2d19dc7470..038a8a693283f 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -426,12 +426,12 @@ func registerThroughAuth( client, err = insecureRegisterClient(params) } if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "building auth client") } defer client.Close() result, err = registerThroughAuthClient(ctx, token, params, client) - return result, trace.Wrap(err) + return result, trace.Wrap(err, "registering through auth client") } // AuthJoinClient is a client that allows access to the Auth Servers join @@ -450,7 +450,7 @@ func registerThroughAuthClient( ) (result *RegisterResult, err error) { hostKeys, err := generateHostKeysForAuth(ctx, client) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "generating host keys") } var certs *proto.Certs @@ -468,7 +468,7 @@ func registerThroughAuthClient( certs, err = client.RegisterUsingToken(ctx, registerUsingTokenRequestForParams(token, hostKeys, params)) } if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "registering with %s method", params.JoinMethod) } return &RegisterResult{ Certs: certs, @@ -506,7 +506,7 @@ func insecureRegisterClient(params RegisterParams) (*authclient.Client, error) { CircuitBreakerConfig: params.CircuitBreakerConfig, }) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "creating insecure auth client") } return client, nil @@ -665,10 +665,9 @@ func registerUsingIAMMethod( // create the signed sts:GetCallerIdentity request and include the challenge signedRequest, err := iam.CreateSignedSTSIdentityRequest(ctx, challenge, iam.WithFIPSEndpoint(params.FIPS), - iam.WithRegionalEndpoint(true), ) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "creating signed sts:GetCallerIdentity request") } // send the register request including the challenge response @@ -679,7 +678,7 @@ func registerUsingIAMMethod( }) if err != nil { log.WithError(err).Infof("Failed to register %s using regional STS endpoint", params.ID.Role) - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "registering via IAM method streaming RPC") } log.Infof("Successfully registered %s with IAM method using regional STS endpoint", params.ID.Role) diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index 6668a3449b8b5..a43fbe70fd920 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -264,7 +264,7 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro tokenName := req.RegisterUsingTokenRequest.Token provisionToken, err := a.GetToken(ctx, tokenName) if err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "getting token") } if provisionToken.GetJoinMethod() != types.JoinMethodIAM { return trace.AccessDenied("this token does not support the IAM join method") @@ -273,25 +273,25 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro // parse the incoming http request to the sts:GetCallerIdentity endpoint identityRequest, err := parseSTSRequest(req.StsIdentityRequest) if err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "parsing STS request") } // validate that the host, method, and headers are correct and the expected // challenge is included in the signed portion of the request if err := validateSTSIdentityRequest(identityRequest, challenge, cfg); err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "validating STS request") } // send the signed request to the public AWS API and get the node identity // from the response identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, identityRequest) if err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "executing STS request") } // check that the node identity matches an allow rule for this token if err := checkIAMAllowRules(identity, provisionToken.GetName(), provisionToken.GetAllowRules()); err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "checking allow rules") } return nil @@ -357,34 +357,34 @@ func (a *Server) RegisterUsingIAMMethod( challenge, err := generateIAMChallenge() if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "generating IAM challenge") } req, err := challengeResponse(challenge) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "getting challenge response") } joinRequest = req.RegisterUsingTokenRequest if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "validating request parameters") } // perform common token checks provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "completing common token checks") } // check that the GetCallerIdentity request is valid and matches the token if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "checking iam request") } if req.RegisterUsingTokenRequest.Role == types.RoleBot { certs, err := a.generateCertsBot(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) - return certs, trace.Wrap(err) + return certs, trace.Wrap(err, "generating bot certs") } certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) - return certs, trace.Wrap(err) + return certs, trace.Wrap(err, "generating certs") } diff --git a/lib/auth/register.go b/lib/auth/register.go index 565d155ecb6c0..43120e2c16428 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -36,7 +36,6 @@ import ( // within the same process as the Auth Server and as such, does not need to // use provisioning tokens. func LocalRegister(id state.IdentityID, authServer *Server, additionalPrincipals, dnsNames []string, remoteAddr string, systemRoles []types.SystemRole) (*state.Identity, error) { - // TODO(nklaassen): split SSH and TLS keys for host identities. key, err := cryptosuites.GenerateKey(context.Background(), cryptosuites.GetCurrentSuiteFromAuthPreference(authServer), cryptosuites.HostIdentity) if err != nil { return nil, trace.Wrap(err) From bdb5a268c7d8ba0aff363a7dcffd2be32e302331 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Tue, 1 Oct 2024 12:42:58 -0700 Subject: [PATCH 2/3] use constant Co-authored-by: Zac Bergquist --- lib/auth/join/iam/endpoints.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/join/iam/endpoints.go b/lib/auth/join/iam/endpoints.go index bcaad06b0f1e9..de03ca95e597f 100644 --- a/lib/auth/join/iam/endpoints.go +++ b/lib/auth/join/iam/endpoints.go @@ -75,7 +75,7 @@ var ( // FIPSSTSEndpoints returns the set of known valid FIPS AWS STS endpoints. FIPSSTSEndpoints = sync.OnceValue(func() []string { return []string{ - "sts-fips.us-east-1.amazonaws.com", + fipsSTSEndpointUSEast1, "sts-fips.us-east-2.amazonaws.com", "sts-fips.us-west-1.amazonaws.com", "sts-fips.us-west-2.amazonaws.com", From ca662d284d55aa54e67e9b2ca8a7fec9a9760046 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Tue, 1 Oct 2024 16:14:51 -0700 Subject: [PATCH 3/3] fix lint --- lib/auth/join/iam/iam_test.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/lib/auth/join/iam/iam_test.go b/lib/auth/join/iam/iam_test.go index f27735c9baaa9..ca9dfa6ae17bd 100644 --- a/lib/auth/join/iam/iam_test.go +++ b/lib/auth/join/iam/iam_test.go @@ -1,3 +1,19 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + package iam_test import ( @@ -8,9 +24,10 @@ import ( "os" "testing" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/lib/auth/join/iam" "github.com/gravitational/teleport/lib/utils/aws" - "github.com/stretchr/testify/require" ) func TestCreateSignedSTSIdentityRequest(t *testing.T) { @@ -76,7 +93,7 @@ func TestCreateSignedSTSIdentityRequest(t *testing.T) { // endpoint in us-east-1, but this will fail if the AWS credentials // were issued by the AWS China partition because they will not be // recognized by STS in the default partition. It will fail when - // Auth sends the request to AWS, but this unit test only exercizes + // Auth sends the request to AWS, but this unit test only exercises // the client-side request generation. expectEndpoint: "sts-fips.us-east-1.amazonaws.com", },