diff --git a/aws/vpc.go b/aws/vpc.go index f95930a4c..6e2489f17 100644 --- a/aws/vpc.go +++ b/aws/vpc.go @@ -1,10 +1,11 @@ package aws + import ( "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws" "errors" - "github.com/gruntwork-io/terratest/util" + "strconv" ) var VpcIdFilterName = "vpc-id" @@ -15,25 +16,28 @@ type Vpc struct { Subnets []ec2.Subnet // A list of subnets in the VPC } -func GetRandomVpc(awsRegion string) (Vpc, error) { +var IS_DEFAULT_FILTER_NAME = "isDefault" +var IS_DEFAULT_FILTER_VALUE = "true" + +func GetDefaultVpc(awsRegion string) (Vpc, error) { vpc := Vpc{} svc := ec2.New(session.New(), aws.NewConfig().WithRegion(awsRegion)) - vpcs, err := svc.DescribeVpcs(&ec2.DescribeVpcsInput{}) + defaultVpcFilter := ec2.Filter{Name: &IS_DEFAULT_FILTER_NAME, Values: []*string{&IS_DEFAULT_FILTER_VALUE}} + vpcs, err := svc.DescribeVpcs(&ec2.DescribeVpcsInput{Filters: []*ec2.Filter{&defaultVpcFilter}}) if err != nil { return vpc, err } numVpcs := len(vpcs.Vpcs) - if numVpcs == 0 { - return vpc, errors.New("No VPCs found in region " + awsRegion) + if numVpcs != 1 { + return vpc, errors.New("Expected to find one default VPC in region " + awsRegion + " but found " + strconv.Itoa(numVpcs)) } - randomIndex := util.Random(0, numVpcs) - randomVpc := vpcs.Vpcs[randomIndex] + defaultVpc := vpcs.Vpcs[0] - vpc.Id = *randomVpc.VpcId - vpc.Name = FindVpcName(randomVpc) + vpc.Id = *defaultVpc.VpcId + vpc.Name = FindVpcName(defaultVpc) vpc.Subnets, err = GetSubnetsForVpc(vpc.Id, awsRegion) return vpc, err diff --git a/rand_resources.go b/rand_resources.go index e07cb59ca..fb0dede1c 100644 --- a/rand_resources.go +++ b/rand_resources.go @@ -100,6 +100,6 @@ func (r *RandomResourceCollection) GetRandomPrivateCidrBlock(prefix int) string return util.GetRandomPrivateCidrBlock(prefix) } -func (r *RandomResourceCollection) GetRandomVpc() (aws.Vpc, error) { - return aws.GetRandomVpc(r.AwsRegion) +func (r *RandomResourceCollection) GetDefaultVpc() (aws.Vpc, error) { + return aws.GetDefaultVpc(r.AwsRegion) } diff --git a/rand_resources_test.go b/rand_resources_test.go index d174cef14..14ee00ca3 100644 --- a/rand_resources_test.go +++ b/rand_resources_test.go @@ -121,7 +121,7 @@ func TestAllParametersSet(t *testing.T) { } } -func TestGetRandomVpc(t *testing.T) { +func TestGetDefaultVpc(t *testing.T) { t.Parallel() ro := NewRandomResourceCollectionOptions() @@ -130,7 +130,7 @@ func TestGetRandomVpc(t *testing.T) { t.Fatalf("Failed to create RandomResourceCollection: %s", err.Error()) } - vpc, err := rand.GetRandomVpc() + vpc, err := rand.GetDefaultVpc() if err != nil { t.Fatalf("Failed to get random VPC: %s", err.Error()) }