diff --git a/shellwords.go b/shellwords.go index 576b792..fc29fa4 100644 --- a/shellwords.go +++ b/shellwords.go @@ -1,10 +1,12 @@ package shellwords import ( + "bytes" "errors" "os" "regexp" "strings" + "unicode" ) var ( @@ -27,13 +29,72 @@ func replaceEnv(getenv func(string) string, s string) string { getenv = os.Getenv } - return envRe.ReplaceAllStringFunc(s, func(s string) string { - s = s[1:] - if s[0] == '{' { - s = s[1 : len(s)-1] + var buf bytes.Buffer + rs := []rune(s) + for i := 0; i < len(rs); i++ { + r := rs[i] + if r == '\\' { + i++ + if i == len(rs) { + break + } + buf.WriteRune(rs[i]) + continue + } else if r == '$' { + i++ + if i == len(rs) { + buf.WriteRune(r) + break + } + if rs[i] == 0x7b { + i++ + p := i + for ; i < len(rs); i++ { + r = rs[i] + if r == '\\' { + i++ + if i == len(rs) { + return s + } + continue + } + if r == 0x7d || (!unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r)) { + break + } + } + if r != 0x7d { + return s + } + if i > p { + buf.WriteString(getenv(s[p:i])) + } + } else { + p := i + for ; i < len(rs); i++ { + r := rs[i] + if r == '\\' { + i++ + if i == len(rs) { + return s + } + continue + } + if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) { + break + } + } + if i > p { + buf.WriteString(getenv(s[p:i])) + i-- + } else { + buf.WriteString(s[p:]) + } + } + } else { + buf.WriteRune(r) } - return getenv(s) - }) + } + return buf.String() } type Parser struct { @@ -56,6 +117,14 @@ func NewParser() *Parser { } } +type argType int + +const ( + argNo argType = iota + argSingle + argQuoted +) + func (p *Parser) Parse(line string) ([]string, error) { args := []string{} buf := "" @@ -63,7 +132,7 @@ func (p *Parser) Parse(line string) ([]string, error) { backtick := "" pos := -1 - got := false + got := argNo i := -1 loop: @@ -72,7 +141,7 @@ loop: if escaped { buf += string(r) escaped = false - got = true + got = argSingle continue } @@ -89,21 +158,25 @@ loop: if singleQuoted || doubleQuoted || backQuote || dollarQuote { buf += string(r) backtick += string(r) - } else if got { + } else if got != argNo { if p.ParseEnv { - parser := &Parser{ParseEnv: false, ParseBacktick: false, Position: 0, Dir: p.Dir} - strs, err := parser.Parse(replaceEnv(p.Getenv, buf)) - if err != nil { - return nil, err - } - for _, str := range strs { - args = append(args, str) + if got == argSingle { + parser := &Parser{ParseEnv: false, ParseBacktick: false, Position: 0, Dir: p.Dir} + strs, err := parser.Parse(replaceEnv(p.Getenv, buf)) + if err != nil { + return nil, err + } + for _, str := range strs { + args = append(args, str) + } + } else { + args = append(args, replaceEnv(p.Getenv, buf)) } } else { args = append(args, buf) } buf = "" - got = false + got = argNo } continue } @@ -156,7 +229,7 @@ loop: case '"': if !singleQuoted && !dollarQuote { if doubleQuoted { - got = true + got = argQuoted } doubleQuoted = !doubleQuoted continue @@ -164,7 +237,7 @@ loop: case '\'': if !doubleQuoted && !dollarQuote { if singleQuoted { - got = true + got = argSingle } singleQuoted = !singleQuoted continue @@ -174,7 +247,7 @@ loop: if r == '>' && len(buf) > 0 { if c := buf[0]; '0' <= c && c <= '9' { i -= 1 - got = false + got = argNo } } pos = i @@ -182,22 +255,26 @@ loop: } } - got = true + got = argSingle buf += string(r) if backQuote || dollarQuote { backtick += string(r) } } - if got { + if got != argNo { if p.ParseEnv { - parser := &Parser{ParseEnv: false, ParseBacktick: false, Position: 0, Dir: p.Dir} - strs, err := parser.Parse(replaceEnv(p.Getenv, buf)) - if err != nil { - return nil, err - } - for _, str := range strs { - args = append(args, str) + if got == argSingle { + parser := &Parser{ParseEnv: false, ParseBacktick: false, Position: 0, Dir: p.Dir} + strs, err := parser.Parse(replaceEnv(p.Getenv, buf)) + if err != nil { + return nil, err + } + for _, str := range strs { + args = append(args, str) + } + } else { + args = append(args, replaceEnv(p.Getenv, buf)) } } else { args = append(args, buf) diff --git a/shellwords_test.go b/shellwords_test.go index cfe818c..b32a493 100644 --- a/shellwords_test.go +++ b/shellwords_test.go @@ -288,9 +288,9 @@ func TestEnvArgumentsFail(t *testing.T) { t.Fatal("Should be an error") } os.Setenv("FOO", "bar `") - _, err = parser.Parse("$FOO ") + result, err := parser.Parse("$FOO ") if err == nil { - t.Fatal("Should be an error") + t.Fatal("Should be an error: ", result) } } @@ -300,20 +300,20 @@ func TestDupEnv(t *testing.T) { parser := NewParser() parser.ParseEnv = true - args, err := parser.Parse("echo $$FOO$") + args, err := parser.Parse("echo $FOO$") if err != nil { t.Fatal(err) } - expected := []string{"echo", "$bar$"} + expected := []string{"echo", "bar$"} if !reflect.DeepEqual(args, expected) { t.Fatalf("Expected %#v, but %#v:", expected, args) } - args, err = parser.Parse("echo $${FOO_BAR}$") + args, err = parser.Parse("echo ${FOO_BAR}$") if err != nil { t.Fatal(err) } - expected = []string{"echo", "$baz$"} + expected = []string{"echo", "baz$"} if !reflect.DeepEqual(args, expected) { t.Fatalf("Expected %#v, but %#v:", expected, args) } @@ -383,3 +383,27 @@ func TestBackquoteInFlag(t *testing.T) { t.Fatalf("Expected %#v, but %#v:", expected, args) } } + +func TestEnvInQuoted(t *testing.T) { + os.Setenv("FOO", "bar") + + parser := NewParser() + parser.ParseEnv = true + args, err := parser.Parse(`ssh 127.0.0.1 "echo $FOO"`) + if err != nil { + panic(err) + } + expected := []string{"ssh", "127.0.0.1", "echo bar"} + if !reflect.DeepEqual(args, expected) { + t.Fatalf("Expected %#v, but %#v:", expected, args) + } + + args, err = parser.Parse(`ssh 127.0.0.1 "echo \\$FOO"`) + if err != nil { + panic(err) + } + expected = []string{"ssh", "127.0.0.1", "echo $FOO"} + if !reflect.DeepEqual(args, expected) { + t.Fatalf("Expected %#v, but %#v:", expected, args) + } +}