diff --git a/ecs-cli/modules/cli/local/project/project.go b/ecs-cli/modules/cli/local/project/project.go index 909fe3102..f7740cce2 100644 --- a/ecs-cli/modules/cli/local/project/project.go +++ b/ecs-cli/modules/cli/local/project/project.go @@ -163,40 +163,52 @@ func (p *localProject) Convert() error { func (p *localProject) Write() error { // Will error if the file already exists, otherwise create - outputFileName := p.context.String(flags.LocalOutputFlag) - if outputFileName == "" { - outputFileName = LocalOutDefaultFileName + p.localOutFileName = LocalOutDefaultFileName + if fileName := p.context.String(flags.LocalOutputFlag); fileName != "" { + p.localOutFileName = fileName } - p.localOutFileName = outputFileName - out, err := os.OpenFile(outputFileName, os.O_WRONLY|os.O_CREATE|os.O_EXCL, LocalOutFileMode) - defer out.Close() + return p.writeFile() +} - data := p.localBytes +func (p *localProject) writeFile() error { + out, err := openFile(p.localOutFileName) + defer out.Close() + // File already exists if err != nil { - fmt.Printf("%s file already exists. Do you want to write over this file? [y/N]\n", outputFileName) + return p.overwriteFile() + } - reader := bufio.NewReader(os.Stdin) - input, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("Error reading input: %s", err.Error()) - } + _, err = out.Write(p.localBytes) - formattedInput := strings.ToLower(strings.TrimSpace(input)) + return err +} - if formattedInput != "yes" && formattedInput != "y" { - return fmt.Errorf("Aborted writing compose file. To retry, rename or move %s", outputFileName) // TODO add force flag - } +// Facilitates test mocking +var openFile func(filename string) (*os.File, error) = func(filename string) (*os.File, error) { + return os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, LocalOutFileMode) +} - // Overwrite local compose file - err = ioutil.WriteFile(outputFileName, data, LocalOutFileMode) - return err +func (p *localProject) overwriteFile() error { + filename := p.localOutFileName + + fmt.Printf("%s file already exists. Do you want to write over this file? [y/N]\n", filename) + + reader := bufio.NewReader(os.Stdin) + stdin, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("Error reading stdin: %s", err.Error()) } - _, err = out.Write(data) + input := strings.ToLower(strings.TrimSpace(stdin)) - return err + if input != "yes" && input != "y" { + return fmt.Errorf("Aborted writing compose file. To retry, rename or move %s", filename) // TODO add force flag + } + + // Overwrite local compose file + return ioutil.WriteFile(filename, p.localBytes, LocalOutFileMode) } // Get secret value stored in AWS Secrets Manager diff --git a/ecs-cli/modules/cli/local/project/project_test.go b/ecs-cli/modules/cli/local/project/project_test.go new file mode 100644 index 000000000..8aa59585a --- /dev/null +++ b/ecs-cli/modules/cli/local/project/project_test.go @@ -0,0 +1,78 @@ +// Copyright 2015-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +// Package localproject defines LocalProject interface and implements them on localProject + +package localproject + +import ( + "flag" + "io/ioutil" + "os" + "testing" + + "github.com/aws/amazon-ecs-cli/ecs-cli/modules/commands/flags" + "github.com/stretchr/testify/assert" + "github.com/urfave/cli" +) + +func TestWrite(t *testing.T) { + // GIVEN + flagSet := flag.NewFlagSet("ecs-cli", 0) // No flags specified + context := cli.NewContext(nil, flagSet, nil) + project := New(context) + + oldOpenFile := openFile + openFile = func(filename string) (*os.File, error) { + tmpfile, err := ioutil.TempFile("", filename) + assert.NoError(t, err, "Unexpected error in creating temp compose file") + defer os.Remove(tmpfile.Name()) + + return tmpfile, nil + } + defer func() { openFile = oldOpenFile }() + + // WHEN + err := project.Write() + + // THEN + assert.NoError(t, err, "Unexpected error in writing local compose file") + assert.Equal(t, LocalOutDefaultFileName, project.LocalOutFileName()) +} + +func TestWrite_WithOutputFlag(t *testing.T) { + // GIVEN + expectedOutputFile := "foo.yml" + flagSet := flag.NewFlagSet("ecs-cli", 0) + flagSet.String(flags.LocalOutputFlag, expectedOutputFile, "") + context := cli.NewContext(nil, flagSet, nil) + project := New(context) + + oldOpenFile := openFile + openFile = func(filename string) (*os.File, error) { + tmpfile, err := ioutil.TempFile("", filename) + assert.NoError(t, err, "Unexpected error in creating temp compose file") + defer os.Remove(tmpfile.Name()) + + return tmpfile, nil + } + + defer func() { openFile = oldOpenFile }() + + // WHEN + err := project.Write() + + // THEN + assert.NoError(t, err, "Unexpected error in writing local compose file") + assert.Equal(t, expectedOutputFile, project.LocalOutFileName()) +}