diff --git a/internal/utils/io.go b/internal/utils/io.go index 0abdc418..ccccf5f9 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -18,30 +18,30 @@ func maybeUnwrap(err error) error { return err } -// stdinFilename is the name of the file that is used in many command line utilities -// to denote input is to be read from STDIN. +// stdinFilename is the name of the file that is used in many command +// line utilities to denote input is to be read from STDIN. const stdinFilename = "-" -// stdin points to os.Stdin. +// stdin points to STDIN through os.Stdin. var stdin = os.Stdin -// ReadFile reads the file named by filename and returns the -// contents. If filename is equal to "-", it will read from +// ReadFile reads the file identified by filename and returns +// the contents. If filename is equal to "-", it will read from // STDIN. -func ReadFile(name string) (b []byte, err error) { - if name == stdinFilename { - name = "/dev/stdin" +func ReadFile(filename string) (b []byte, err error) { + if filename == stdinFilename { + filename = "/dev/stdin" b, err = io.ReadAll(stdin) } else { var contents []byte - contents, err = os.ReadFile(name) + contents, err = os.ReadFile(filename) if err != nil { - return nil, errors.Wrapf(maybeUnwrap(err), "error reading %s", name) + return nil, errors.Wrapf(maybeUnwrap(err), "error reading %q", filename) } b, err = io.ReadAll(utfbom.SkipOnly(bytes.NewReader(contents))) } if err != nil { - return nil, errors.Wrapf(maybeUnwrap(err), "error reading %s", name) + return nil, errors.Wrapf(maybeUnwrap(err), "error reading %q", filename) } return } diff --git a/internal/utils/io_test.go b/internal/utils/io_test.go index d083ce6c..53c02ddf 100644 --- a/internal/utils/io_test.go +++ b/internal/utils/io_test.go @@ -2,15 +2,18 @@ package utils import ( "fmt" + "io" "os" "path/filepath" "reflect" "testing" "github.com/pkg/errors" + "github.com/stretchr/testify/require" ) func TestReadFile(t *testing.T) { + t.Parallel() type args struct { filename string } @@ -37,7 +40,64 @@ func TestReadFile(t *testing.T) { } } +// Set content to be read from mock STDIN +func setStdinContent(t *testing.T, content string) (cleanup func()) { + f, err := os.CreateTemp("" /* dir */, "utils-read-test") + require.NoError(t, err) + _, err = f.Write([]byte(content)) + require.NoError(t, err) + _, err = f.Seek(0, io.SeekStart) + require.NoError(t, err) + old := stdin + stdin = f + + return func() { + stdin = old + require.NoError(t, f.Close()) + require.NoError(t, os.Remove(f.Name())) + } +} + +func TestReadFromStdin(t *testing.T) { + cleanup := setStdinContent(t, "input on STDIN") + t.Cleanup(func() { + cleanup() + }) + + b, err := ReadFile(stdinFilename) + require.NoError(t, err) + require.Equal(t, "input on STDIN", string(b)) +} + +// Sets STDIN to a file that is already closed, and thus fails +// to be read from. +func setFailingStdin(t *testing.T) (cleanup func()) { + f, err := os.CreateTemp("" /* dir */, "utils-read-test") + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + old := stdin + stdin = f + + return func() { + stdin = old + require.NoError(t, os.Remove(f.Name())) + } +} + +func TestReadFromStdinFails(t *testing.T) { + cleanup := setFailingStdin(t) + t.Cleanup(func() { + cleanup() + }) + + b, err := ReadFile(stdinFilename) + require.Error(t, err) + require.Empty(t, b) +} + func TestReadPasswordFromFile(t *testing.T) { + t.Parallel() type args struct { filename string } @@ -65,11 +125,20 @@ func TestReadPasswordFromFile(t *testing.T) { } } +func TestReadPasswordFromStdin(t *testing.T) { + cleanup := setStdinContent(t, "this-is-a-secret-testing-password") + t.Cleanup(func() { + cleanup() + }) + + b, err := ReadPasswordFromFile(stdinFilename) + require.NoError(t, err) + require.Equal(t, "this-is-a-secret-testing-password", string(b)) +} + func TestWriteFile(t *testing.T) { tmpDir, err := os.MkdirTemp(os.TempDir(), "go-tests") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) t.Cleanup(func() { os.RemoveAll(tmpDir) }) @@ -97,6 +166,7 @@ func TestWriteFile(t *testing.T) { } func Test_maybeUnwrap(t *testing.T) { + t.Parallel() wantErr := fmt.Errorf("the error") type args struct {