Skip to content

Commit

Permalink
Merge pull request #24 from gruntwork-io/ssh
Browse files Browse the repository at this point in the history
Use Go libs for SSH. Add retry helpers.
  • Loading branch information
brikis98 authored Jul 6, 2016
2 parents be37015 + 8dd3eb5 commit 658be89
Show file tree
Hide file tree
Showing 9 changed files with 669 additions and 81 deletions.
32 changes: 12 additions & 20 deletions http/http_helper.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package http_helper

import (
"errors"
"time"
"net/http"
"io/ioutil"
"strings"
"strconv"
"log"
"github.com/gruntwork-io/terratest/util"
"fmt"
)

func HttpGet(url string, logger *log.Logger) (int, string, error) {
Expand All @@ -29,28 +29,20 @@ func HttpGet(url string, logger *log.Logger) (int, string, error) {
}

func HttpGetWithRetry(url string, expectedBody string, retries int, sleepBetweenRetries time.Duration, logger *log.Logger) error {
for i := 0; i < retries; i++ {
_, err := util.DoWithRetry(fmt.Sprintf("HTTP GET to URL %s", url), retries, sleepBetweenRetries, logger, func() (string, error) {
status, body, err := HttpGet(url, logger)

if err == nil && status == 200 {
logger.Println("Got 200 OK from URL", url)
if body == expectedBody {
logger.Println("Got expected body from URL", url, ":", body)
return nil
} else {
logger.Println("Did not get expected body from URL", url, ". Expected:", expectedBody, ". Got:", body, ".")
}
}

if err != nil {
logger.Println("Got an error after making an HTTP get to URL", url, ":", err)
return "", err
} else if status != 200 {
logger.Println("Got a non-200 response from URL", url, ":", status)
return "", fmt.Errorf("Expected a 200 response but got %d", status)
} else if body != expectedBody {
return "", fmt.Errorf("Got a 200 response, but did not get expected body. Expected: %s. Got: %s.", expectedBody, body)
} else {
logger.Printf("Got 200 a response from URL %s and expected body: %s\n", url, body)
return body, nil
}
})

logger.Println("Will retry in", sleepBetweenRetries)
time.Sleep(sleepBetweenRetries)
}

return errors.New("Did not get a 200 OK from URL " + url + " after " + strconv.Itoa(retries) + " retries.")
return err
}
94 changes: 94 additions & 0 deletions ssh/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package ssh

import (
"golang.org/x/crypto/ssh"
"fmt"
"io"
"net"
"reflect"
"log"
)

type SshConnectionOptions struct {
Username string
Address string
Port int
AuthMethods []ssh.AuthMethod
Command string
JumpHost *SshConnectionOptions
}

func (options *SshConnectionOptions) ConnectionString() string {
return fmt.Sprintf("%s:%d", options.Address, options.Port)
}

// A container object for all resources created by an SSH session. The reason we need this is so that we can do a
// single defer in a top-level method that calls the Cleanup method to go through and ensure all of these resources are
// released and cleaned up.
type SshSession struct {
Options *SshConnectionOptions
Client *ssh.Client
Session *ssh.Session
JumpHost *JumpHostSession
}

func (sshSession *SshSession) Cleanup(logger *log.Logger) {
if sshSession == nil {
return
}


// Closing the session may result in an EOF error if it's already closed (e.g. due to hitting CTRL + D), so
// don't report those errors, as there is nothing actually wrong in that case.
Close(sshSession.Session, logger, io.EOF.Error())
Close(sshSession.Client, logger)
sshSession.JumpHost.Cleanup(logger)
}

type JumpHostSession struct {
JumpHostClient *ssh.Client
HostVirtualConnection net.Conn
HostConnection ssh.Conn
}

func (jumpHost *JumpHostSession) Cleanup(logger *log.Logger) {
if jumpHost == nil {
return
}

// Closing a connection may result in an EOF error if it's already closed (e.g. due to hitting CTRL + D), so
// don't report those errors, as there is nothing actually wrong in that case.
Close(jumpHost.HostConnection, logger, io.EOF.Error())
Close(jumpHost.HostVirtualConnection, logger, io.EOF.Error())
Close(jumpHost.JumpHostClient, logger)
}

type Closeable interface {
Close() error
}

func Close(closeable Closeable, logger *log.Logger, ignoreErrors ...string) {
if interfaceIsNil(closeable) {
return
}

if err := closeable.Close(); err != nil && !contains(ignoreErrors, err.Error()) {
logger.Printf("Error closing %s: %s", closeable, err.Error())
}
}

func contains(haystack []string, needle string) bool {
for _, hay := range haystack {
if hay == needle {
return true
}
}
return false
}

// Go is a shitty language. Checking an interface directly against nil does not work, and if you don't know the exact
// types the interface may be ahead of time, the only way to know if you're dealing with nil is to use reflection.
// http://stackoverflow.com/questions/13476349/check-for-nil-and-nil-interface-in-go
func interfaceIsNil(i interface{}) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
183 changes: 122 additions & 61 deletions ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@ package ssh
import (
"github.com/gruntwork-io/terratest"
"log"
"github.com/gruntwork-io/terratest/shell"
"errors"
"strconv"
"io/ioutil"
"os"
"fmt"
"github.com/gruntwork-io/terratest/util"
"golang.org/x/crypto/ssh"
)

type Host struct {
Expand All @@ -26,90 +20,157 @@ func CheckSshConnection(host Host, logger *log.Logger) error {

// Check that you can connect via SSH to the given host and run the given command. Returns the stdout/stderr.
func CheckSshCommand(host Host, command string, logger *log.Logger) (string, error) {
keyPairWithUniqueName := createKeyPairCopyWithUniqueName(*host.SshKeyPair)

defer cleanupKeyPairFile(keyPairWithUniqueName, logger)
writeKeyPairFile(keyPairWithUniqueName, logger)

output, sshErr := shell.RunCommandAndGetOutput(shell.Command{Command: "ssh", Args: []string{"-i", keyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", host.SshUserName + "@" + host.Hostname, command}}, logger)

exitCode, err := shell.GetExitCodeForRunCommandError(sshErr)

authMethods, err := createAuthMethodsForHost(host)
if err != nil {
return output, err
return "", err
}

hostOptions := SshConnectionOptions{
Username: host.SshUserName,
Address: host.Hostname,
Port: 22,
Command: command,
AuthMethods: authMethods,
}

if exitCode != 0 {
return output, errors.New("SSH exited with a non-zero exit code: " + strconv.Itoa(exitCode))
sshSession := &SshSession{
Options: &hostOptions,
JumpHost: &JumpHostSession{},
}

return output, nil
defer sshSession.Cleanup(logger)

return runSshCommand(sshSession)
}

// CheckPrivateSshConnection attempts to connect to privateHost (which is not addressable from the Internet) via a separate
// publicHost (which is addressable from the Internet) and then executes "command" on privateHost and returns its output.
// It is useful for checking that it's possible to SSH from a Bastion Host to a private instance.
func CheckPrivateSshConnection(publicHost Host, privateHost Host, command string, logger *log.Logger) (string, error) {
publicKeyPairWithUniqueName := createKeyPairCopyWithUniqueName(*publicHost.SshKeyPair)
privateKeyPairWithUniqueName := createKeyPairCopyWithUniqueName(*privateHost.SshKeyPair)

defer cleanupKeyPairFile(publicKeyPairWithUniqueName, logger)
writeKeyPairFile(publicKeyPairWithUniqueName, logger)
jumpHostAuthMethods, err := createAuthMethodsForHost(publicHost)
if err != nil {
return "", err
}

defer cleanupKeyPairFile(privateKeyPairWithUniqueName, logger)
writeKeyPairFile(privateKeyPairWithUniqueName, logger)
jumpHostOptions := SshConnectionOptions{
Username: publicHost.SshUserName,
Address: publicHost.Hostname,
Port: 22,
AuthMethods: jumpHostAuthMethods,
}

// We need the SSH key to be available when we SSH from the Bastion Host to the Private Host.
// We cannot guarantee ssh-agent will be in the test environment, so we use scp to copy the key to the bastion host file system.
// Start by setting permissions on the key to 0600. These permissions (read/write for file owner only) are required by ssh to access the key.
chmodErr := shell.RunCommand(shell.Command{Command: "chmod", Args: []string{"0600", privateKeyPairWithUniqueName.Name}}, logger)
exitCode, err := shell.GetExitCodeForRunCommandError(chmodErr)
hostAuthMethods, err := createAuthMethodsForHost(privateHost)
if err != nil {
return "", err
}
if exitCode != 0 {
return "", errors.New("Attempt to set permissions on local key file exited with a non-zero exit code: " + strconv.Itoa(exitCode))

hostOptions := SshConnectionOptions{
Username: privateHost.SshUserName,
Address: privateHost.Hostname,
Port: 22,
Command: command,
AuthMethods: hostAuthMethods,
JumpHost: &jumpHostOptions,
}

sshSession := &SshSession{
Options: &hostOptions,
JumpHost: &JumpHostSession{},
}

defer sshSession.Cleanup(logger)

return runSshCommand(sshSession)
}

func runSshCommand(sshSession *SshSession) (string, error) {
if err := setupSshClient(sshSession); err != nil {
return "", err
}

if err := setupSshSession(sshSession); err != nil {
return "", err
}

// Upload the key to the bastion host
sshErr := shell.RunCommand(shell.Command{Command: "scp", Args: []string{"-p", "-i", publicKeyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", privateKeyPairWithUniqueName.Name, publicHost.SshUserName + "@" + publicHost.Hostname + ":key.pem"}}, logger)
exitCode, err = shell.GetExitCodeForRunCommandError(sshErr)
bytes, err := sshSession.Session.Output(sshSession.Options.Command)
if err != nil {
return "", err
}
if exitCode != 0 {
return "", errors.New("Attempt to SSH and write key file exited with a non-zero exit code: " + strconv.Itoa(exitCode))

return string(bytes), nil
}

func setupSshClient(sshSession *SshSession) error {
if sshSession.Options.JumpHost == nil {
return fillSshClientForHost(sshSession)
} else {
return fillSshClientForJumpHost(sshSession)
}
}

func fillSshClientForHost(sshSession *SshSession) error {
client, err := createSshClient(sshSession.Options)

// Now connect directly to the privateHost
output, sshErr := shell.RunCommandAndGetOutput(shell.Command{Command: "ssh", Args: []string{"-i", publicKeyPairWithUniqueName.Name, "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", publicHost.SshUserName + "@" + publicHost.Hostname, "ssh -i key.pem -o StrictHostKeyChecking=no", privateHost.SshUserName + "@" + privateHost.Hostname, command}}, logger)
exitCode, err = shell.GetExitCodeForRunCommandError(sshErr)
if err != nil {
return output, err
return err
}
if exitCode != 0 {
return output, errors.New("Attempt to SSH to private host exited with a non-zero exit code: " + strconv.Itoa(exitCode))

sshSession.Client = client
return nil
}

func fillSshClientForJumpHost(sshSession *SshSession) error {
jumpHostClient, err := createSshClient(sshSession.Options.JumpHost)
if err != nil {
return err
}
sshSession.JumpHost.JumpHostClient = jumpHostClient

hostVirtualConn, err := jumpHostClient.Dial("tcp", sshSession.Options.ConnectionString())
if err != nil {
return err
}
sshSession.JumpHost.HostVirtualConnection = hostVirtualConn

return output, nil
hostConn, hostIncomingChannels, hostIncomingRequests, err := ssh.NewClientConn(hostVirtualConn, sshSession.Options.ConnectionString(), createSshClientConfig(sshSession.Options))
if err != nil {
return err
}
sshSession.JumpHost.HostConnection = hostConn

sshSession.Client = ssh.NewClient(hostConn, hostIncomingChannels, hostIncomingRequests)
return nil
}

func setupSshSession(sshSession *SshSession) error {
session, err := sshSession.Client.NewSession()
if err != nil {
return err
}

sshSession.Session = session
return nil
}

func writeKeyPairFile(keyPair terratest.Ec2Keypair, logger *log.Logger) error {
logger.Println("Creating test-time Key Pair file", keyPair.Name)
return ioutil.WriteFile(keyPair.Name, []byte(keyPair.PrivateKey), 0400)
func createSshClient(options *SshConnectionOptions) (*ssh.Client, error) {
sshClientConfig := createSshClientConfig(options)
return ssh.Dial("tcp", options.ConnectionString(), sshClientConfig)
}

func cleanupKeyPairFile(keyPair terratest.Ec2Keypair, logger *log.Logger) error {
logger.Println("Cleaning up test-time Key Pair file", keyPair.Name)
return os.Remove(keyPair.Name)
func createSshClientConfig(hostOptions *SshConnectionOptions) *ssh.ClientConfig {
clientConfig := &ssh.ClientConfig{
User: hostOptions.Username,
Auth: hostOptions.AuthMethods,
}
clientConfig.SetDefaults()
return clientConfig
}

// Testing SSH connectivity involves writing and deleting Key Pair files on disk. Since there might be multiple SSH
// checks happening in parallel, we use this function to give the Key Pair file a unique name, and thereby avoid the
// files overwriting each other.
func createKeyPairCopyWithUniqueName(keyPair terratest.Ec2Keypair) terratest.Ec2Keypair {
// This automatically creates a shallow copy in Go
keyPairWithUniqueName := keyPair
keyPairWithUniqueName.Name = fmt.Sprintf("%s-%s", keyPairWithUniqueName.Name, util.UniqueId())
return keyPairWithUniqueName
}
func createAuthMethodsForHost(host Host) ([]ssh.AuthMethod, error) {
signer, err := ssh.ParsePrivateKey([]byte(host.SshKeyPair.PrivateKey))
if err != nil {
return []ssh.AuthMethod{}, err
}

return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}
Loading

0 comments on commit 658be89

Please sign in to comment.