Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Account ID and random VPC in RandomResourceCollection #8

Merged
merged 3 commits into from
Mar 28, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions aws/account.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/aws/session"
"strings"
"errors"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Order of imports.


// Get the Account ID for the currently logged in IAM User.
func GetAccountId() (string, error) {
svc := iam.New(session.New())
user, err := svc.GetUser(&iam.GetUserInput{})
if err != nil {
return "", err
}

return extractAccountIdFromArn(*user.User.Arn)
}

// An IAM arn is of the format arn:aws:iam::123456789012:user/test. The account id is the number after arn:aws:iam::,
// so we split on a colon and return the 5th item.
func extractAccountIdFromArn(arn string) (string, error) {
arnParts := strings.Split(arn, ":")

if len(arnParts) < 5 {
return "", errors.New("Unrecognized format for IAM ARN: " + arn)
}

return arnParts[4], nil
}
29 changes: 29 additions & 0 deletions aws/account_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package aws

import "testing"

func TestExtractAccountIdFromValidArn(t *testing.T) {
t.Parallel()

expectedAccountId := "123456789012"
arn := "arn:aws:iam::" + expectedAccountId + ":user/test"

actualAccountId, err := extractAccountIdFromArn(arn)
if err != nil {
t.Fatalf("Unexpected error while extracting account id from arn %s: %s", arn, err)
}

if actualAccountId != expectedAccountId {
t.Fatalf("Did not get expected account id. Expected: %s. Actual: %s.", expectedAccountId, actualAccountId)
}
}

func TestExtractAccountIdFromInvalidArn(t *testing.T) {
t.Parallel()

_, err := extractAccountIdFromArn("invalidArn")
if err == nil {
t.Fatalf("Expected an error when extracting an account id from an invalid ARN, but got nil")
}
}

75 changes: 75 additions & 0 deletions aws/vpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/aws"
"errors"
"github.com/gruntwork-io/terratest/util"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

order of imports.


var VpcIdFilterName = "vpc-id"

type Vpc struct {
Id string // The ID of the VPC
Name string // The name of the VPC
SubnetIds []string // A list of subnet ids in the VPC
}

func GetRandomVpc(awsRegion string) (Vpc, error) {
vpc := Vpc{}

svc := ec2.New(session.New(), aws.NewConfig().WithRegion(awsRegion))
vpcs, err := svc.DescribeVpcs(&ec2.DescribeVpcsInput{})
if err != nil {
return vpc, err
}

numVpcs := len(vpcs.Vpcs)
if numVpcs == 0 {
return vpc, errors.New("No VPCs found in region " + awsRegion)
}

randomIndex := util.Random(0, numVpcs)
randomVpc := vpcs.Vpcs[randomIndex]

vpc.Id = *randomVpc.VpcId
vpc.Name = FindVpcName(randomVpc)

vpc.SubnetIds, err = GetSubnetIdsForVpc(vpc.Id, awsRegion)
if err != nil {
return vpc, err
}

return vpc, nil
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate why we need this function? Is the idea that you don't want to assume a default VPC exists in each region?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't particularly care which VPC we use, as long as it exists and has at least a couple subnets. I went with a "random" VPC since it seemed consistent with everything else in the RandomResourceCollection, but I suppose the function could be changed to GetDefaultVpc too, assuming it's easy enough to use Go's search filters to find the default VPC in each region. I don't see a very compelling argument in one direction or the other, but let me know if you think changing to a default VPC would offer an advantage.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I just figured out exactly why I shouldn't be using a random VPC. Our tests create and destroy VPCs all the time, so it's possible that one test will accidentally try to use a VPC that another test is trying to destroy. Whoops :)

https://www.pivotaltracker.com/story/show/116758059

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #13

func FindVpcName(vpc *ec2.Vpc) string {
for _, tag := range vpc.Tags {
if *tag.Key == "Name" {
return *tag.Value
}
}

if *vpc.IsDefault {
return "Default"
}

return ""
}

func GetSubnetIdsForVpc(vpcId string, awsRegion string) ([]string, error) {
subnets := []string{}

svc := ec2.New(session.New(), aws.NewConfig().WithRegion(awsRegion))

vpcIdFilter := ec2.Filter{Name: &VpcIdFilterName, Values: []*string{&vpcId}}
subnetOutput, err := svc.DescribeSubnets(&ec2.DescribeSubnetsInput{Filters: []*ec2.Filter{&vpcIdFilter}})
if err != nil {
return subnets, err
}

for _, subnet := range subnetOutput.Subnets {
subnets = append(subnets, *subnet.SubnetId)
}
return subnets, nil
}
15 changes: 12 additions & 3 deletions rand_resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ package terratest

import (
"fmt"

"github.com/gruntwork-io/terratest/aws"
"github.com/gruntwork-io/terratest/util"
"strings"
"strings"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Order of imports.


// A RandomResourceCollection is simply a typed holder for random resources we need as we do a Terraform run.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is from my original comment, but I think a clearer comment would be:

// A RandomResourceCollection contains various resources we need as we do a Terraform run.
// Some of these resources are dynamically generated (e.g. KeyPair) and others are randomly selected (e.g. AwsRegion).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand All @@ -15,6 +14,7 @@ type RandomResourceCollection struct {
AwsRegion string // The AWS Region
KeyPair *Ec2Keypair // The EC2 KeyPair created in AWS
AmiId string // A random AMI ID valid for the AwsRegion
AccountId string // The AWS account ID
}

// Represents an EC2 KeyPair created in AWS
Expand Down Expand Up @@ -62,6 +62,11 @@ func CreateRandomResourceCollection(ro *RandomResourceCollectionOpts) (*RandomRe

r.KeyPair = ec2KeyPair

r.AccountId, err = aws.GetAccountId()
if err != nil {
return r, fmt.Errorf("Failed to get AWS Account Id: %s\n", err.Error())
}

return r, nil
}

Expand Down Expand Up @@ -92,4 +97,8 @@ func (r *RandomResourceCollection) FetchAwsAvailabilityZonesAsString() string {

func (r *RandomResourceCollection) GetRandomPrivateCidrBlock(prefix int) string {
return util.GetRandomPrivateCidrBlock(prefix)
}
}

func (r *RandomResourceCollection) GetRandomVpc() (aws.Vpc, error) {
return aws.GetRandomVpc(r.AwsRegion)
}
57 changes: 57 additions & 0 deletions rand_resources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,61 @@ func TestGetRandomPrivateCidrBlock(t *testing.T) {
if actualPrefix != expPrefix {
t.Fatalf("Expected: %s, but received: %s", expPrefix, actualPrefix)
}
}

func TestAllParametersSet(t *testing.T) {
t.Parallel()

ro := NewRandomResourceCollectionOptions()
rand, err := CreateRandomResourceCollection(ro)
if err != nil {
t.Fatalf("Failed to create RandomResourceCollection: %s", err.Error())
}

if len(rand.AccountId) == 0 {
t.Fatalf("CreateRandomResourceCollection has an empty AccountId: %s", rand)
}

if len(rand.AmiId) == 0 {
t.Fatalf("CreateRandomResourceCollection has an empty AMI ID: %s", rand)
}

if len(rand.AwsRegion) == 0 {
t.Fatalf("CreateRandomResourceCollection has an empty region: %s", rand)
}

if len(rand.UniqueId) == 0 {
t.Fatalf("CreateRandomResourceCollection has an empty Unique Id: %s", rand)
}

if rand.KeyPair == nil {
t.Fatalf("CreateRandomResourceCollection has a nil Key Pair: %s", rand)
}
}

func TestGetRandomVpc(t *testing.T) {
t.Parallel()

ro := NewRandomResourceCollectionOptions()
rand, err := CreateRandomResourceCollection(ro)
if err != nil {
t.Fatalf("Failed to create RandomResourceCollection: %s", err.Error())
}

vpc, err := rand.GetRandomVpc()
if err != nil {
t.Fatalf("Failed to get random VPC: %s", err.Error())
}

if vpc.Id == "" {
t.Fatalf("GetRandomVpc returned a VPC without an ID: %s", vpc)
}

if vpc.Name == "" {
t.Fatalf("GetRandomVpc returned a VPC without a name: %s", vpc)
}

if len(vpc.SubnetIds) == 0 {
t.Fatalf("GetRandomVpc returned a VPC with no subnets: %s", vpc)
}
}