Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update IAM join method to use aws-sdk-go-v2 #47044

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions integration/ec2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -61,16 +60,16 @@ func newSilentLogger() utils.Logger {
return logger
}

func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
func newNodeConfig(t *testing.T, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV3
config.SetToken(tokenName)
config.JoinMethod = joinMethod
config.SSH.Enabled = true
config.SSH.Addr.Addr = helpers.NewListener(t, service.ListenerNodeSSH, &config.FileDescriptors)
config.Auth.Enabled = false
config.Proxy.Enabled = false
config.DataDir = t.TempDir()
config.SetAuthServerAddress(authAddr)
config.Log = newSilentLogger()
config.CircuitBreakerConfig = breaker.NoopBreakerConfig()
config.InstanceMetadataClient = cloudimds.NewDisabledIMDSClient()
Expand All @@ -79,7 +78,7 @@ func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinM

func newProxyConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *servicecfg.Config {
config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV2
config.Version = defaults.TeleportConfigVersionV3
config.SetToken(tokenName)
config.JoinMethod = joinMethod
config.SSH.Enabled = false
Expand Down Expand Up @@ -109,6 +108,7 @@ func newAuthConfig(t *testing.T, clock clockwork.Clock) *servicecfg.Config {
}

config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV3
config.DataDir = t.TempDir()
config.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &config.FileDescriptors)
config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
Expand Down Expand Up @@ -140,13 +140,11 @@ func getIID(ctx context.Context, t *testing.T) imds.InstanceIdentityDocument {
return output.InstanceIdentityDocument
}

func getCallerIdentity(t *testing.T) *sts.GetCallerIdentityOutput {
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
func getCallerIdentity(ctx context.Context, t *testing.T) *sts.GetCallerIdentityOutput {
cfg, err := config.LoadDefaultConfig(ctx)
require.NoError(t, err)
stsService := sts.New(sess)
output, err := stsService.GetCallerIdentity(nil /*input*/)
stsClient := sts.NewFromConfig(cfg)
output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
require.NoError(t, err)
return output
}
Expand Down Expand Up @@ -201,7 +199,8 @@ func TestEC2NodeJoin(t *testing.T) {
require.Empty(t, nodes)

// create and start the node
nodeConfig := newNodeConfig(t, authConfig.Auth.ListenAddr, tokenName, types.JoinMethodEC2)
nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodEC2)
nodeConfig.SetAuthServerAddress(authConfig.Auth.ListenAddr)
nodeSvc, err := service.NewTeleport(nodeConfig)
require.NoError(t, err)
require.NoError(t, nodeSvc.Start())
Expand All @@ -214,7 +213,7 @@ func TestEC2NodeJoin(t *testing.T) {
require.Eventually(t, func() bool {
nodes, _ := authServer.GetNodes(ctx, apidefaults.Namespace)
return len(nodes) > 0
}, time.Minute, time.Second, "waiting for node to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster")
}

// TestIAMNodeJoin is an integration test which asserts that the IAM method for
Expand All @@ -225,6 +224,7 @@ func TestIAMNodeJoin(t *testing.T) {
if os.Getenv("TELEPORT_TEST_EC2") == "" {
t.Skipf("Skipping TestIAMNodeJoin because TELEPORT_TEST_EC2 is not set")
}
ctx := context.Background()

// create and start the auth server
authConfig := newAuthConfig(t, nil /*clock*/)
Expand All @@ -236,7 +236,7 @@ func TestIAMNodeJoin(t *testing.T) {
authServer := authSvc.GetAuthServer()

// fetch the caller identity to find the AWS account and create the token
id := getCallerIdentity(t)
id := getCallerIdentity(ctx, t)

tokenName := "test_token"
token, err := types.NewProvisionTokenFromSpec(
Expand All @@ -253,7 +253,7 @@ func TestIAMNodeJoin(t *testing.T) {
})
require.NoError(t, err)

err = authServer.UpsertToken(context.Background(), token)
err = authServer.UpsertToken(ctx, token)
require.NoError(t, err)

// sanity check there are no proxies to start with
Expand All @@ -274,31 +274,32 @@ func TestIAMNodeJoin(t *testing.T) {
proxies, err := authServer.GetProxies()
assert.NoError(t, err)
assert.NotEmpty(t, proxies)
}, time.Minute, time.Second, "waiting for proxy to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for proxy to join cluster")
// InsecureDevMode needed for node to trust proxy
wasInsecureDevMode := lib.IsInsecureDevMode()
t.Cleanup(func() { lib.SetInsecureDevMode(wasInsecureDevMode) })
lib.SetInsecureDevMode(true)

// sanity check there are no nodes to start with
nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace)
nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
require.NoError(t, err)
require.Empty(t, nodes)

// create and start a node, with use the IAM method to join in IoT mode by
// create and start a node, will use the IAM method to join in IoT mode by
// connecting to the proxy
nodeConfig := newNodeConfig(t, proxyConfig.Proxy.WebAddr, tokenName, types.JoinMethodIAM)
nodeConfig := newNodeConfig(t, tokenName, types.JoinMethodIAM)
nodeConfig.ProxyServer = proxyConfig.Proxy.WebAddr
nodeSvc, err := service.NewTeleport(nodeConfig)
require.NoError(t, err)
require.NoError(t, nodeSvc.Start())
t.Cleanup(func() { require.NoError(t, nodeSvc.Close()) })

// the node should eventually join the cluster and heartbeat
require.EventuallyWithT(t, func(t *assert.CollectT) {
nodes, err := authServer.GetNodes(context.Background(), apidefaults.Namespace)
nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
assert.NoError(t, err)
assert.NotEmpty(t, nodes)
}, time.Minute, time.Second, "waiting for node to join cluster")
}, 10*time.Second, 50*time.Millisecond, "waiting for node to join cluster")
}

type mockIMDSClient struct {
Expand Down
6 changes: 4 additions & 2 deletions integration/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,8 @@ func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.Database
}()

// Start the new node.
nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken)
nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken)
nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr)
nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data")
nodeConfig.Log = libutils.NewLoggerForTests()
nodeSvc, err := service.NewTeleport(nodeConfig)
Expand Down Expand Up @@ -1031,7 +1032,8 @@ func testDeleteConnectMyComputerNode(t *testing.T, pack *dbhelpers.DatabasePack)
require.NoError(t, err)

// Start the new node.
nodeConfig := newNodeConfig(t, pack.Root.Cluster.Config.Auth.ListenAddr, "token", types.JoinMethodToken)
nodeConfig := newNodeConfig(t, "token", types.JoinMethodToken)
nodeConfig.SetAuthServerAddress(pack.Root.Cluster.Config.Auth.ListenAddr)
nodeConfig.DataDir = filepath.Join(agentsDir, profileName, "data")
nodeConfig.Log = libutils.NewLoggerForTests()
nodeSvc, err := service.NewTeleport(nodeConfig)
Expand Down
23 changes: 10 additions & 13 deletions lib/auth/join/iam/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package iam
import "sync"

var (
// ValidSTSEndpoints holds a sorted list of all known valid public endpoints for
// the AWS STS service. You can generate this list by running
// $ go run github.com/nklaassen/sts-endpoints@latest --go-list
// Update aws-sdk-go in that package to learn about new endpoints.
// ValidSTSEndpoints returns a sorted list of all known valid public endpoints for
// the AWS STS service.
//
// TODO(nklaassen): find a better way to validate STS endpoints or generate
// this list and get notified when it needs to be updated. The original
// solution was https://github.com/nklaassen/sts-endpoints which is based on
// aws-sdk-go v1 which no longer gets updates for new regions.
ValidSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
Expand Down Expand Up @@ -69,15 +72,7 @@ var (
}
})

GlobalSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts.amazonaws.com",
// This is not a real endpoint, but the SDK will select it if
// AWS_USE_FIPS_ENDPOINT is set and a region is not.
"sts-fips.aws-global.amazonaws.com",
}
})

// FIPSSTSEndpoints returns the set of known valid FIPS AWS STS endpoints.
FIPSSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
nklaassen marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -89,3 +84,5 @@ var (
}
})
)

const fipsSTSEndpointUSEast1 = "sts-fips.us-east-1.amazonaws.com"
Loading
Loading