diff --git a/http/http_helper.go b/http/http_helper.go index ffd594447..46f511f46 100644 --- a/http/http_helper.go +++ b/http/http_helper.go @@ -1,13 +1,13 @@ package http_helper import ( - "errors" "time" "net/http" "io/ioutil" "strings" - "strconv" "log" + "github.com/gruntwork-io/terratest/util" + "fmt" ) func HttpGet(url string, logger *log.Logger) (int, string, error) { @@ -29,28 +29,20 @@ func HttpGet(url string, logger *log.Logger) (int, string, error) { } func HttpGetWithRetry(url string, expectedBody string, retries int, sleepBetweenRetries time.Duration, logger *log.Logger) error { - for i := 0; i < retries; i++ { + _, err := util.DoWithRetry(fmt.Sprintf("HTTP GET to URL %s", url), retries, sleepBetweenRetries, logger, func() (string, error) { status, body, err := HttpGet(url, logger) - if err == nil && status == 200 { - logger.Println("Got 200 OK from URL", url) - if body == expectedBody { - logger.Println("Got expected body from URL", url, ":", body) - return nil - } else { - logger.Println("Did not get expected body from URL", url, ". Expected:", expectedBody, ". Got:", body, ".") - } - } - if err != nil { - logger.Println("Got an error after making an HTTP get to URL", url, ":", err) + return "", err } else if status != 200 { - logger.Println("Got a non-200 response from URL", url, ":", status) + return "", fmt.Errorf("Expected a 200 response but got %d", status) + } else if body != expectedBody { + return "", fmt.Errorf("Got a 200 response, but did not get expected body. Expected: %s. Got: %s.", expectedBody, body) + } else { + logger.Printf("Got 200 a response from URL %s and expected body: %s\n", url, body) + return body, nil } + }) - logger.Println("Will retry in", sleepBetweenRetries) - time.Sleep(sleepBetweenRetries) - } - - return errors.New("Did not get a 200 OK from URL " + url + " after " + strconv.Itoa(retries) + " retries.") + return err } diff --git a/ssh/session.go b/ssh/session.go new file mode 100644 index 000000000..cea223f97 --- /dev/null +++ b/ssh/session.go @@ -0,0 +1,94 @@ +package ssh + +import ( + "golang.org/x/crypto/ssh" + "fmt" + "io" + "net" + "reflect" + "log" +) + +type SshConnectionOptions struct { + Username string + Address string + Port int + AuthMethods []ssh.AuthMethod + Command string + JumpHost *SshConnectionOptions +} + +func (options *SshConnectionOptions) ConnectionString() string { + return fmt.Sprintf("%s:%d", options.Address, options.Port) +} + +// A container object for all resources created by an SSH session. The reason we need this is so that we can do a +// single defer in a top-level method that calls the Cleanup method to go through and ensure all of these resources are +// released and cleaned up. +type SshSession struct { + Options *SshConnectionOptions + Client *ssh.Client + Session *ssh.Session + JumpHost *JumpHostSession +} + +func (sshSession *SshSession) Cleanup(logger *log.Logger) { + if sshSession == nil { + return + } + + + // Closing the session may result in an EOF error if it's already closed (e.g. due to hitting CTRL + D), so + // don't report those errors, as there is nothing actually wrong in that case. + Close(sshSession.Session, logger, io.EOF.Error()) + Close(sshSession.Client, logger) + sshSession.JumpHost.Cleanup(logger) +} + +type JumpHostSession struct { + JumpHostClient *ssh.Client + HostVirtualConnection net.Conn + HostConnection ssh.Conn +} + +func (jumpHost *JumpHostSession) Cleanup(logger *log.Logger) { + if jumpHost == nil { + return + } + + // Closing a connection may result in an EOF error if it's already closed (e.g. due to hitting CTRL + D), so + // don't report those errors, as there is nothing actually wrong in that case. + Close(jumpHost.HostConnection, logger, io.EOF.Error()) + Close(jumpHost.HostVirtualConnection, logger, io.EOF.Error()) + Close(jumpHost.JumpHostClient, logger) +} + +type Closeable interface { + Close() error +} + +func Close(closeable Closeable, logger *log.Logger, ignoreErrors ...string) { + if interfaceIsNil(closeable) { + return + } + + if err := closeable.Close(); err != nil && !contains(ignoreErrors, err.Error()) { + logger.Printf("Error closing %s: %s", closeable, err.Error()) + } +} + +func contains(haystack []string, needle string) bool { + for _, hay := range haystack { + if hay == needle { + return true + } + } + return false +} + +// Go is a shitty language. Checking an interface directly against nil does not work, and if you don't know the exact +// types the interface may be ahead of time, the only way to know if you're dealing with nil is to use reflection. +// http://stackoverflow.com/questions/13476349/check-for-nil-and-nil-interface-in-go +func interfaceIsNil(i interface{}) bool { + return i == nil || reflect.ValueOf(i).IsNil() +} diff --git a/ssh/ssh.go b/ssh/ssh.go index a405998f4..d14603479 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -3,13 +3,7 @@ package ssh import ( "github.com/gruntwork-io/terratest" "log" - "github.com/gruntwork-io/terratest/shell" - "errors" - "strconv" - "io/ioutil" - "os" - "fmt" - "github.com/gruntwork-io/terratest/util" + "golang.org/x/crypto/ssh" ) type Host struct { @@ -26,90 +20,157 @@ func CheckSshConnection(host Host, logger *log.Logger) error { // Check that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr. func CheckSshCommand(host Host, command string, logger *log.Logger) (string, error) { - keyPairWithUniqueName := createKeyPairCopyWithUniqueName(*host.SshKeyPair) - - defer cleanupKeyPairFile(keyPairWithUniqueName, logger) - writeKeyPairFile(keyPairWithUniqueName, logger) - - output, sshErr := shell.RunCommandAndGetOutput(shell.Command{Command: "ssh", Args: []string{"-i", keyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", host.SshUserName + "@" + host.Hostname, command}}, logger) - - exitCode, err := shell.GetExitCodeForRunCommandError(sshErr) - + authMethods, err := createAuthMethodsForHost(host) if err != nil { - return output, err + return "", err + } + + hostOptions := SshConnectionOptions{ + Username: host.SshUserName, + Address: host.Hostname, + Port: 22, + Command: command, + AuthMethods: authMethods, } - if exitCode != 0 { - return output, errors.New("SSH exited with a non-zero exit code: " + strconv.Itoa(exitCode)) + sshSession := &SshSession{ + Options: &hostOptions, + JumpHost: &JumpHostSession{}, } - return output, nil + defer sshSession.Cleanup(logger) + + return runSshCommand(sshSession) } // CheckPrivateSshConnection attempts to connect to privateHost (which is not addressable from the Internet) via a separate // publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns its output. // It is useful for checking that it's possible to SSH from a Bastion Host to a private instance. func CheckPrivateSshConnection(publicHost Host, privateHost Host, command string, logger *log.Logger) (string, error) { - publicKeyPairWithUniqueName := createKeyPairCopyWithUniqueName(*publicHost.SshKeyPair) - privateKeyPairWithUniqueName := createKeyPairCopyWithUniqueName(*privateHost.SshKeyPair) - - defer cleanupKeyPairFile(publicKeyPairWithUniqueName, logger) - writeKeyPairFile(publicKeyPairWithUniqueName, logger) + jumpHostAuthMethods, err := createAuthMethodsForHost(publicHost) + if err != nil { + return "", err + } - defer cleanupKeyPairFile(privateKeyPairWithUniqueName, logger) - writeKeyPairFile(privateKeyPairWithUniqueName, logger) + jumpHostOptions := SshConnectionOptions{ + Username: publicHost.SshUserName, + Address: publicHost.Hostname, + Port: 22, + AuthMethods: jumpHostAuthMethods, + } - // We need the SSH key to be available when we SSH from the Bastion Host to the Private Host. - // We cannot guarantee ssh-agent will be in the test environment, so we use scp to copy the key to the bastion host file system. - // Start by setting permissions on the key to 0600. These permissions (read/write for file owner only) are required by ssh to access the key. - chmodErr := shell.RunCommand(shell.Command{Command: "chmod", Args: []string{"0600", privateKeyPairWithUniqueName.Name}}, logger) - exitCode, err := shell.GetExitCodeForRunCommandError(chmodErr) + hostAuthMethods, err := createAuthMethodsForHost(privateHost) if err != nil { return "", err } - if exitCode != 0 { - return "", errors.New("Attempt to set permissions on local key file exited with a non-zero exit code: " + strconv.Itoa(exitCode)) + + hostOptions := SshConnectionOptions{ + Username: privateHost.SshUserName, + Address: privateHost.Hostname, + Port: 22, + Command: command, + AuthMethods: hostAuthMethods, + JumpHost: &jumpHostOptions, + } + + sshSession := &SshSession{ + Options: &hostOptions, + JumpHost: &JumpHostSession{}, + } + + defer sshSession.Cleanup(logger) + + return runSshCommand(sshSession) +} + +func runSshCommand(sshSession *SshSession) (string, error) { + if err := setupSshClient(sshSession); err != nil { + return "", err + } + + if err := setupSshSession(sshSession); err != nil { + return "", err } - // Upload the key to the bastion host - sshErr := shell.RunCommand(shell.Command{Command: "scp", Args: []string{"-p", "-i", publicKeyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", privateKeyPairWithUniqueName.Name, publicHost.SshUserName + "@" + publicHost.Hostname + ":key.pem"}}, logger) - exitCode, err = shell.GetExitCodeForRunCommandError(sshErr) + bytes, err := sshSession.Session.Output(sshSession.Options.Command) if err != nil { return "", err } - if exitCode != 0 { - return "", errors.New("Attempt to SSH and write key file exited with a non-zero exit code: " + strconv.Itoa(exitCode)) + + return string(bytes), nil +} + +func setupSshClient(sshSession *SshSession) error { + if sshSession.Options.JumpHost == nil { + return fillSshClientForHost(sshSession) + } else { + return fillSshClientForJumpHost(sshSession) } +} + +func fillSshClientForHost(sshSession *SshSession) error { + client, err := createSshClient(sshSession.Options) - // Now connect directly to the privateHost - output, sshErr := shell.RunCommandAndGetOutput(shell.Command{Command: "ssh", Args: []string{"-i", publicKeyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", publicHost.SshUserName + "@" + publicHost.Hostname, "ssh -i key.pem -o StrictHostKeyChecking=no", privateHost.SshUserName + "@" + privateHost.Hostname, command}}, logger) - exitCode, err = shell.GetExitCodeForRunCommandError(sshErr) if err != nil { - return output, err + return err } - if exitCode != 0 { - return output, errors.New("Attempt to SSH to private host exited with a non-zero exit code: " + strconv.Itoa(exitCode)) + + sshSession.Client = client + return nil +} + +func fillSshClientForJumpHost(sshSession *SshSession) error { + jumpHostClient, err := createSshClient(sshSession.Options.JumpHost) + if err != nil { + return err + } + sshSession.JumpHost.JumpHostClient = jumpHostClient + + hostVirtualConn, err := jumpHostClient.Dial("tcp", sshSession.Options.ConnectionString()) + if err != nil { + return err } + sshSession.JumpHost.HostVirtualConnection = hostVirtualConn - return output, nil + hostConn, hostIncomingChannels, hostIncomingRequests, err := ssh.NewClientConn(hostVirtualConn, sshSession.Options.ConnectionString(), createSshClientConfig(sshSession.Options)) + if err != nil { + return err + } + sshSession.JumpHost.HostConnection = hostConn + + sshSession.Client = ssh.NewClient(hostConn, hostIncomingChannels, hostIncomingRequests) + return nil +} + +func setupSshSession(sshSession *SshSession) error { + session, err := sshSession.Client.NewSession() + if err != nil { + return err + } + + sshSession.Session = session + return nil } -func writeKeyPairFile(keyPair terratest.Ec2Keypair, logger *log.Logger) error { - logger.Println("Creating test-time Key Pair file", keyPair.Name) - return ioutil.WriteFile(keyPair.Name, []byte(keyPair.PrivateKey), 0400) +func createSshClient(options *SshConnectionOptions) (*ssh.Client, error) { + sshClientConfig := createSshClientConfig(options) + return ssh.Dial("tcp", options.ConnectionString(), sshClientConfig) } -func cleanupKeyPairFile(keyPair terratest.Ec2Keypair, logger *log.Logger) error { - logger.Println("Cleaning up test-time Key Pair file", keyPair.Name) - return os.Remove(keyPair.Name) +func createSshClientConfig(hostOptions *SshConnectionOptions) *ssh.ClientConfig { + clientConfig := &ssh.ClientConfig{ + User: hostOptions.Username, + Auth: hostOptions.AuthMethods, + } + clientConfig.SetDefaults() + return clientConfig } -// Testing SSH connectivity involves writing and deleting Key Pair files on disk. Since there might be multiple SSH -// checks happening in parallel, we use this function to give the Key Pair file a unique name, and thereby avoid the -// files overwriting each other. -func createKeyPairCopyWithUniqueName(keyPair terratest.Ec2Keypair) terratest.Ec2Keypair { - // This automatically creates a shallow copy in Go - keyPairWithUniqueName := keyPair - keyPairWithUniqueName.Name = fmt.Sprintf("%s-%s", keyPairWithUniqueName.Name, util.UniqueId()) - return keyPairWithUniqueName -} \ No newline at end of file +func createAuthMethodsForHost(host Host) ([]ssh.AuthMethod, error) { + signer, err := ssh.ParsePrivateKey([]byte(host.SshKeyPair.PrivateKey)) + if err != nil { + return []ssh.AuthMethod{}, err + } + + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil +} diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go new file mode 100644 index 000000000..0736aed15 --- /dev/null +++ b/ssh/ssh_test.go @@ -0,0 +1,147 @@ +package ssh + +import ( + "testing" + "github.com/gruntwork-io/terratest" + "fmt" + terralog "github.com/gruntwork-io/terratest/log" + "strings" + "log" + "github.com/gruntwork-io/terratest/util" + "time" +) + +const TERRAFORM_OUTPUT_PUBLIC_IP = "example_public_ip" +const TERRAFORM_OUTPUT_PRIVATE_IP = "example_private_ip" +const EXPECTED_TEXT_FROM_SSH = "Hello World" + +func TestSsh(t *testing.T) { + t.Parallel() + + randomResourceCollection := createBaseRandomResourceCollection(t) + terratestOptions := createTerratestOptions("TestSsh", "../test-fixtures/ssh-test", randomResourceCollection, t) + defer terratest.Destroy(terratestOptions, randomResourceCollection) + + logger := terralog.NewLogger(terratestOptions.TestName) + + if _, err := terratest.Apply(terratestOptions); err != nil { + t.Fatalf("Failed to apply templates: %s\n", err.Error()) + } + + if err := testSshToPublicHost(terratestOptions, randomResourceCollection, logger); err != nil { + t.Fatalf("Failed to SSH to public host: %s\n", err.Error()) + } + + if err := testSshToPrivateHost(terratestOptions, randomResourceCollection, logger); err != nil { + t.Fatalf("Failed to SSH to private host: %s\n", err.Error()) + } +} + +// As of 6/9/16, these AWS regions do not support t2.nano instances +var REGIONS_WITHOUT_T2_NANO = []string{ + "ap-southeast-2", +} + +func createBaseRandomResourceCollection(t *testing.T) *terratest.RandomResourceCollection { + resourceCollectionOptions := terratest.NewRandomResourceCollectionOptions() + resourceCollectionOptions.ForbiddenRegions = REGIONS_WITHOUT_T2_NANO + + randomResourceCollection, err := terratest.CreateRandomResourceCollection(resourceCollectionOptions) + if err != nil { + t.Fatalf("Failed to create random resource collection: %s\n", err.Error()) + } + + return randomResourceCollection +} + +func createTerratestOptions(testName string, templatePath string, randomResourceCollection *terratest.RandomResourceCollection, t *testing.T) *terratest.TerratestOptions { + terratestOptions := terratest.NewTerratestOptions() + + terratestOptions.UniqueId = randomResourceCollection.UniqueId + terratestOptions.TemplatePath = templatePath + terratestOptions.TestName = testName + + vpc, err := randomResourceCollection.GetDefaultVpc() + if err != nil { + t.Fatalf("Failed to get default VPC: %s\n", err.Error()) + } + + terratestOptions.Vars = map[string]string { + "aws_region": randomResourceCollection.AwsRegion, + "ami": randomResourceCollection.AmiId, + "keypair_name": randomResourceCollection.KeyPair.Name, + "vpc_id": vpc.Id, + "name_prefix": fmt.Sprintf("ssh-test-%s", randomResourceCollection.UniqueId), + } + + return terratestOptions +} + +func testSshToPublicHost(terratestOptions *terratest.TerratestOptions, resourceCollection *terratest.RandomResourceCollection, logger *log.Logger) error { + ip, err := terratest.Output(terratestOptions, TERRAFORM_OUTPUT_PUBLIC_IP) + if err != nil { + return err + } + + host := Host { + Hostname: ip, + SshUserName: "ubuntu", + SshKeyPair: resourceCollection.KeyPair, + } + + _, err = util.DoWithRetry(fmt.Sprintf("SSH to %s", TERRAFORM_OUTPUT_PUBLIC_IP), 10, 30 * time.Second, logger, func() (string, error) { + output, err := CheckSshCommand(host, fmt.Sprintf("echo '%s'", EXPECTED_TEXT_FROM_SSH), logger) + + if err != nil { + return "", err + } + if ! strings.Contains(output, EXPECTED_TEXT_FROM_SSH) { + return "", fmt.Errorf("Expected output to contain '%s' but got %s", EXPECTED_TEXT_FROM_SSH, output) + } + + logger.Printf("Got expected output after SSHing to %s: %s", TERRAFORM_OUTPUT_PUBLIC_IP, EXPECTED_TEXT_FROM_SSH) + return output, nil + }) + + return err +} + +func testSshToPrivateHost(terratestOptions *terratest.TerratestOptions, resourceCollection *terratest.RandomResourceCollection, logger *log.Logger) error { + publicIp, err := terratest.Output(terratestOptions, TERRAFORM_OUTPUT_PUBLIC_IP) + if err != nil { + return err + } + + privateIp, err := terratest.Output(terratestOptions, TERRAFORM_OUTPUT_PRIVATE_IP) + if err != nil { + return err + } + + publicHost := Host { + Hostname: publicIp, + SshUserName: "ubuntu", + SshKeyPair: resourceCollection.KeyPair, + } + + privateHost := Host { + Hostname: privateIp, + SshUserName: "ubuntu", + SshKeyPair: resourceCollection.KeyPair, + } + + _, err = util.DoWithRetry(fmt.Sprintf("SSH to %s via %s", TERRAFORM_OUTPUT_PRIVATE_IP, TERRAFORM_OUTPUT_PUBLIC_IP), 10, 30 * time.Second, logger, func() (string, error) { + output, err := CheckPrivateSshConnection(publicHost, privateHost, fmt.Sprintf("echo '%s'", EXPECTED_TEXT_FROM_SSH), logger) + + if err != nil { + return "", err + } + if ! strings.Contains(output, EXPECTED_TEXT_FROM_SSH) { + return "", fmt.Errorf("Expected output to contain '%s' but got %s", EXPECTED_TEXT_FROM_SSH, output) + } + + logger.Printf("Got expected output after SSHing to %s via %s: %s", TERRAFORM_OUTPUT_PRIVATE_IP, TERRAFORM_OUTPUT_PUBLIC_IP, EXPECTED_TEXT_FROM_SSH) + return output, nil + }) + + return err +} \ No newline at end of file diff --git a/test-fixtures/ssh-test/main.tf b/test-fixtures/ssh-test/main.tf new file mode 100644 index 000000000..ecd757cf4 --- /dev/null +++ b/test-fixtures/ssh-test/main.tf @@ -0,0 +1,71 @@ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# CREATE TWO EC2 INSTANCES FOR TESTING SSH CONNECTIVITY +# These templates deploy two EC2 instances, one with a public IP and one with only a private IP. These can be used to +# test SSH connectivity. +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# --------------------------------------------------------------------------------------------------------------------- +# CONFIGURE OUR AWS CONNECTION +# --------------------------------------------------------------------------------------------------------------------- + +provider "aws" { + region = "${var.aws_region}" +} + +# --------------------------------------------------------------------------------------------------------------------- +# CREATE AN AWS INSTANCE WITH A PUBLIC IP +# --------------------------------------------------------------------------------------------------------------------- + +resource "aws_instance" "example_public" { + ami = "${var.ami}" + instance_type = "t2.nano" + key_name = "${var.keypair_name}" + vpc_security_group_ids = ["${aws_security_group.example.id}"] + associate_public_ip_address = true + tags { + Name = "${var.name_prefix}-public" + } +} + +# --------------------------------------------------------------------------------------------------------------------- +# CREATE AN AWS INSTANCE WITH ONLY A PRIVATE IP +# --------------------------------------------------------------------------------------------------------------------- + +resource "aws_instance" "example_private" { + ami = "${var.ami}" + instance_type = "t2.nano" + key_name = "${var.keypair_name}" + vpc_security_group_ids = ["${aws_security_group.example.id}"] + associate_public_ip_address = false + tags { + Name = "${var.name_prefix}-private" + } +} + +# --------------------------------------------------------------------------------------------------------------------- +# CREATE A SECURITY GROUP THAT ALLOWS SSH ACCESS TO THE EC2 INSTANCES +# --------------------------------------------------------------------------------------------------------------------- + +resource "aws_security_group" "example" { + vpc_id = "${var.vpc_id}" + + # Outbound Everything + egress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } + + # Inbound SSH from anywhere + ingress { + from_port = 22 + to_port = 22 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + tags { + Name = "${var.name_prefix}-example" + } +} \ No newline at end of file diff --git a/test-fixtures/ssh-test/outputs.tf b/test-fixtures/ssh-test/outputs.tf new file mode 100644 index 000000000..8e942edc4 --- /dev/null +++ b/test-fixtures/ssh-test/outputs.tf @@ -0,0 +1,7 @@ +output "example_public_ip" { + value = "${aws_instance.example_public.public_ip}" +} + +output "example_private_ip" { + value = "${aws_instance.example_private.private_ip}" +} \ No newline at end of file diff --git a/test-fixtures/ssh-test/vars.tf b/test-fixtures/ssh-test/vars.tf new file mode 100644 index 000000000..104013eb8 --- /dev/null +++ b/test-fixtures/ssh-test/vars.tf @@ -0,0 +1,40 @@ +# --------------------------------------------------------------------------------------------------------------------- +# ENVIRONMENT VARIABLES +# Define these secrets as environment variables +# --------------------------------------------------------------------------------------------------------------------- + +# AWS_ACCESS_KEY_ID +# AWS_SECRET_ACCESS_KEY + +# --------------------------------------------------------------------------------------------------------------------- +# MODULE PARAMETERS +# These variables are expected to be passed in by the operator +# --------------------------------------------------------------------------------------------------------------------- + +variable "aws_region" { + description = "The AWS region in which all resources will be created" +} + +variable "ami" { + description = "The ID of the AMI to run on each instance in this example" + # Ubuntu Server 14.04 LTS (HVM), SSD Volume Type in us-east-1 + default = "ami-fce3c696" +} + +variable "keypair_name" { + description = "The name of the Key Pair that can be used to SSH to each instance in this example" +} + +variable "vpc_id" { + description = "The ID of the VPC in which to run these instances" +} + +# --------------------------------------------------------------------------------------------------------------------- +# DEFINE CONSTANTS +# Generally, these values won't need to be changed. +# --------------------------------------------------------------------------------------------------------------------- + +variable "name_prefix" { + description = "The prefix to use for the names of all resources in these templates" + default = "ssh-test" +} diff --git a/util/retry.go b/util/retry.go new file mode 100644 index 000000000..4c02b15d9 --- /dev/null +++ b/util/retry.go @@ -0,0 +1,62 @@ +package util + +import ( + "fmt" + "time" + "github.com/gruntwork-io/terratest/parallel" + "log" +) + +// Run the specified action and wait up to the specified timeout for it to complete. Return the output of the action if +// it completes on time or an error otherwise. +func DoWithTimeout(actionDescription string, timeout time.Duration, action func() (string, error)) (string, error) { + resultChannel := make(chan parallel.TestResult, 1) + + go func() { + out, err := action() + resultChannel <- parallel.TestResult{Description: actionDescription, Value: out, Err: err} + }() + + select { + case result := <-resultChannel: + return result.Value, result.Err + case <-time.After(timeout): + return "", TimeoutExceeded{Description: actionDescription, Timeout: timeout} + } +} + +// Run the specified action. If it returns a value, return that value. If it returns an error, sleep for +// sleepBetweenRetries and try again, up to a maximum of maxRetries retries. +func DoWithRetry(actionDescription string, maxRetries int, sleepBetweenRetries time.Duration, logger *log.Logger, action func() (string, error)) (string, error) { + for i := 0; i < maxRetries; i++ { + output, err := action() + if err == nil { + return output, nil + } + + logger.Printf("%s returned an error: %s. Sleeping for %s and will try again.", actionDescription, err.Error(), sleepBetweenRetries) + time.Sleep(sleepBetweenRetries) + } + + return "", MaxRetriesExceeded{Description: actionDescription, MaxRetries: maxRetries} +} + +// Custom error types + +type TimeoutExceeded struct { + Description string + Timeout time.Duration +} + +func (err TimeoutExceeded) Error() string { + return fmt.Sprintf("'%s' did not complete before timeout of %s", err.Description, err.Timeout) +} + +type MaxRetriesExceeded struct { + Description string + MaxRetries int +} + +func (err MaxRetriesExceeded) Error() string { + return fmt.Sprintf("'%s' unsuccessful after %d retries", err.Description, err.MaxRetries) +} \ No newline at end of file diff --git a/util/retry_test.go b/util/retry_test.go new file mode 100644 index 000000000..c78cd17fa --- /dev/null +++ b/util/retry_test.go @@ -0,0 +1,114 @@ +package util + +import ( + "testing" + "time" + terralog "github.com/gruntwork-io/terratest/log" + "fmt" +) + +func TestDoWithRetry(t *testing.T) { + t.Parallel() + + expectedOutput := "expected" + expectedError := fmt.Errorf("expected error") + + actionAlwaysReturnsExpected := func() (string, error) { return expectedOutput, nil } + actionAlwaysReturnsError := func() (string, error) { return "", expectedError } + + createActionThatReturnsExpectedAfterFiveRetries := func() func() (string, error) { + count := 0 + return func() (string, error) { + count++ + if count > 5 { + return expectedOutput, nil + } else { + return "", expectedError + } + } + } + + testCases := []struct { + description string + maxRetries int + expectedError error + action func() (string, error) + }{ + {"Return value on first try", 10, nil, actionAlwaysReturnsExpected}, + {"Return error on all retries", 10, MaxRetriesExceeded{Description: "Return error on all retries", MaxRetries: 10}, actionAlwaysReturnsError}, + {"Return value after 5 retries", 10, nil, createActionThatReturnsExpectedAfterFiveRetries()}, + {"Return value after 5 retries, but only do 4 retries", 4, MaxRetriesExceeded{Description: "Return value after 5 retries, but only do 4 retries", MaxRetries: 4}, createActionThatReturnsExpectedAfterFiveRetries()}, + } + + logger := terralog.NewLogger("TestDoWithRetry") + + for _, testCase := range testCases { + actualOutput, err := DoWithRetry(testCase.description, testCase.maxRetries, 1 * time.Millisecond, logger, testCase.action) + if testCase.expectedError != nil { + if err != testCase.expectedError { + t.Fatalf("Expected error '%v' for test case '%s' but got '%v'", testCase.description, testCase.expectedError, err) + } + } else { + if err != nil { + t.Fatalf("Did not expect an error for test case '%s' but got: %s", testCase.description, err.Error()) + } + if actualOutput != expectedOutput { + t.Fatalf("Expected output '%s' but got '%s'", expectedOutput, actualOutput) + } + } + } +} + +func TestDoWithTimeout(t *testing.T) { + t.Parallel() + + expectedOutput := "expected" + expectedError := fmt.Errorf("expected error") + + actionReturnsValueImmediately := func() (string, error) { return expectedOutput, nil } + actionReturnsErrorImmediately := func() (string, error) { return "", expectedError} + + createActionThatReturnsValueAfterDelay := func(delay time.Duration) func() (string, error) { + return func() (string, error) { + time.Sleep(delay) + return expectedOutput, nil + } + } + + createActionThatReturnsErrorAfterDelay := func(delay time.Duration) func() (string, error) { + return func() (string, error) { + time.Sleep(delay) + return "", expectedError + } + } + + testCases := []struct { + description string + timeout time.Duration + expectedError error + action func() (string, error) + }{ + {"Returns value immediately", 5 * time.Second, nil, actionReturnsValueImmediately}, + {"Returns error immediately", 5 * time.Second, expectedError, actionReturnsErrorImmediately}, + {"Returns value after 2 seconds", 5 * time.Second, nil, createActionThatReturnsValueAfterDelay(2 * time.Second)}, + {"Returns error after 2 seconds", 5 * time.Second, expectedError, createActionThatReturnsErrorAfterDelay(2 * time.Second)}, + {"Returns value after timeout exceeded", 5 * time.Second, TimeoutExceeded{Description: "Returns value after timeout exceeded", Timeout: 5 * time.Second}, createActionThatReturnsValueAfterDelay(10 * time.Second)}, + {"Returns error after timeout exceeded", 5 * time.Second, TimeoutExceeded{Description: "Returns error after timeout exceeded", Timeout: 5 * time.Second}, createActionThatReturnsErrorAfterDelay(10 * time.Second)}, + } + + for _, testCase := range testCases { + actualOutput, err := DoWithTimeout(testCase.description, testCase.timeout, testCase.action) + if testCase.expectedError != nil { + if err != testCase.expectedError { + t.Fatalf("Expected error '%v' for test case '%s' but got '%v'", testCase.description, testCase.expectedError, err) + } + } else { + if err != nil { + t.Fatalf("Did not expect an error for test case '%s' but got: %s", testCase.description, err.Error()) + } + if actualOutput != expectedOutput { + t.Fatalf("Expected output '%s' but got '%s'", expectedOutput, actualOutput) + } + } + } +}