diff --git a/aws/account.go b/aws/account.go new file mode 100644 index 000000000..7ecc05766 --- /dev/null +++ b/aws/account.go @@ -0,0 +1,30 @@ +package aws +import ( + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/aws/session" + "strings" + "errors" +) + +// Get the Account ID for the currently logged in IAM User. +func GetAccountId() (string, error) { + svc := iam.New(session.New()) + user, err := svc.GetUser(&iam.GetUserInput{}) + if err != nil { + return "", err + } + + return extractAccountIdFromArn(*user.User.Arn) +} + +// An IAM arn is of the format arn:aws:iam::123456789012:user/test. The account id is the number after arn:aws:iam::, +// so we split on a colon and return the 5th item. +func extractAccountIdFromArn(arn string) (string, error) { + arnParts := strings.Split(arn, ":") + + if len(arnParts) < 5 { + return "", errors.New("Unrecognized format for IAM ARN: " + arn) + } + + return arnParts[4], nil +} diff --git a/aws/account_test.go b/aws/account_test.go new file mode 100644 index 000000000..c4a348c92 --- /dev/null +++ b/aws/account_test.go @@ -0,0 +1,29 @@ +package aws + +import "testing" + +func TestExtractAccountIdFromValidArn(t *testing.T) { + t.Parallel() + + expectedAccountId := "123456789012" + arn := "arn:aws:iam::" + expectedAccountId + ":user/test" + + actualAccountId, err := extractAccountIdFromArn(arn) + if err != nil { + t.Fatalf("Unexpected error while extracting account id from arn %s: %s", arn, err) + } + + if actualAccountId != expectedAccountId { + t.Fatalf("Did not get expected account id. Expected: %s. Actual: %s.", expectedAccountId, actualAccountId) + } +} + +func TestExtractAccountIdFromInvalidArn(t *testing.T) { + t.Parallel() + + _, err := extractAccountIdFromArn("invalidArn") + if err == nil { + t.Fatalf("Expected an error when extracting an account id from an invalid ARN, but got nil") + } +} + diff --git a/aws/vpc.go b/aws/vpc.go new file mode 100644 index 000000000..1c8b996aa --- /dev/null +++ b/aws/vpc.go @@ -0,0 +1,75 @@ +package aws +import ( + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/aws" + "errors" + "github.com/gruntwork-io/terratest/util" +) + +var VpcIdFilterName = "vpc-id" + +type Vpc struct { + Id string // The ID of the VPC + Name string // The name of the VPC + SubnetIds []string // A list of subnet ids in the VPC +} + +func GetRandomVpc(awsRegion string) (Vpc, error) { + vpc := Vpc{} + + svc := ec2.New(session.New(), aws.NewConfig().WithRegion(awsRegion)) + vpcs, err := svc.DescribeVpcs(&ec2.DescribeVpcsInput{}) + if err != nil { + return vpc, err + } + + numVpcs := len(vpcs.Vpcs) + if numVpcs == 0 { + return vpc, errors.New("No VPCs found in region " + awsRegion) + } + + randomIndex := util.Random(0, numVpcs) + randomVpc := vpcs.Vpcs[randomIndex] + + vpc.Id = *randomVpc.VpcId + vpc.Name = FindVpcName(randomVpc) + + vpc.SubnetIds, err = GetSubnetIdsForVpc(vpc.Id, awsRegion) + if err != nil { + return vpc, err + } + + return vpc, nil +} + +func FindVpcName(vpc *ec2.Vpc) string { + for _, tag := range vpc.Tags { + if *tag.Key == "Name" { + return *tag.Value + } + } + + if *vpc.IsDefault { + return "Default" + } + + return "" +} + +func GetSubnetIdsForVpc(vpcId string, awsRegion string) ([]string, error) { + subnets := []string{} + + svc := ec2.New(session.New(), aws.NewConfig().WithRegion(awsRegion)) + + vpcIdFilter := ec2.Filter{Name: &VpcIdFilterName, Values: []*string{&vpcId}} + subnetOutput, err := svc.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIdFilter}}) + if err != nil { + return subnets, err + } + + for _, subnet := range subnetOutput.Subnets { + subnets = append(subnets, *subnet.SubnetId) + } + return subnets, nil +} \ No newline at end of file diff --git a/rand_resources.go b/rand_resources.go index 0e5c7c05c..fe9d5c41a 100644 --- a/rand_resources.go +++ b/rand_resources.go @@ -3,10 +3,9 @@ package terratest import ( "fmt" - "github.com/gruntwork-io/terratest/aws" "github.com/gruntwork-io/terratest/util" -"strings" + "strings" ) // A RandomResourceCollection is simply a typed holder for random resources we need as we do a Terraform run. @@ -15,6 +14,7 @@ type RandomResourceCollection struct { AwsRegion string // The AWS Region KeyPair *Ec2Keypair // The EC2 KeyPair created in AWS AmiId string // A random AMI ID valid for the AwsRegion + AccountId string // The AWS account ID } // Represents an EC2 KeyPair created in AWS @@ -62,6 +62,11 @@ func CreateRandomResourceCollection(ro *RandomResourceCollectionOpts) (*RandomRe r.KeyPair = ec2KeyPair + r.AccountId, err = aws.GetAccountId() + if err != nil { + return r, fmt.Errorf("Failed to get AWS Account Id: %s\n", err.Error()) + } + return r, nil } @@ -92,4 +97,8 @@ func (r *RandomResourceCollection) FetchAwsAvailabilityZonesAsString() string { func (r *RandomResourceCollection) GetRandomPrivateCidrBlock(prefix int) string { return util.GetRandomPrivateCidrBlock(prefix) -} \ No newline at end of file +} + +func (r *RandomResourceCollection) GetRandomVpc() (aws.Vpc, error) { + return aws.GetRandomVpc(r.AwsRegion) +} diff --git a/rand_resources_test.go b/rand_resources_test.go index ec0b02876..6fe44f148 100644 --- a/rand_resources_test.go +++ b/rand_resources_test.go @@ -89,4 +89,61 @@ func TestGetRandomPrivateCidrBlock(t *testing.T) { if actualPrefix != expPrefix { t.Fatalf("Expected: %s, but received: %s", expPrefix, actualPrefix) } +} + +func TestAllParametersSet(t *testing.T) { + t.Parallel() + + ro := NewRandomResourceCollectionOptions() + rand, err := CreateRandomResourceCollection(ro) + if err != nil { + t.Fatalf("Failed to create RandomResourceCollection: %s", err.Error()) + } + + if len(rand.AccountId) == 0 { + t.Fatalf("CreateRandomResourceCollection has an empty AccountId: %s", rand) + } + + if len(rand.AmiId) == 0 { + t.Fatalf("CreateRandomResourceCollection has an empty AMI ID: %s", rand) + } + + if len(rand.AwsRegion) == 0 { + t.Fatalf("CreateRandomResourceCollection has an empty region: %s", rand) + } + + if len(rand.UniqueId) == 0 { + t.Fatalf("CreateRandomResourceCollection has an empty Unique Id: %s", rand) + } + + if rand.KeyPair == nil { + t.Fatalf("CreateRandomResourceCollection has a nil Key Pair: %s", rand) + } +} + +func TestGetRandomVpc(t *testing.T) { + t.Parallel() + + ro := NewRandomResourceCollectionOptions() + rand, err := CreateRandomResourceCollection(ro) + if err != nil { + t.Fatalf("Failed to create RandomResourceCollection: %s", err.Error()) + } + + vpc, err := rand.GetRandomVpc() + if err != nil { + t.Fatalf("Failed to get random VPC: %s", err.Error()) + } + + if vpc.Id == "" { + t.Fatalf("GetRandomVpc returned a VPC without an ID: %s", vpc) + } + + if vpc.Name == "" { + t.Fatalf("GetRandomVpc returned a VPC without a name: %s", vpc) + } + + if len(vpc.SubnetIds) == 0 { + t.Fatalf("GetRandomVpc returned a VPC with no subnets: %s", vpc) + } } \ No newline at end of file