Skip to content

Commit

Permalink
Merge pull request #7 from coder/completion
Browse files Browse the repository at this point in the history
Add auto-completion
  • Loading branch information
ammario authored Aug 1, 2024
2 parents 6e88789 + c365495 commit 91966a2
Show file tree
Hide file tree
Showing 13 changed files with 855 additions and 93 deletions.
106 changes: 101 additions & 5 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ type Command struct {
Middleware MiddlewareFunc
Handler HandlerFunc
HelpHandler HandlerFunc
// CompletionHandler is called when the command is run in completion
// mode. If nil, only the default completion handler is used.
//
// Flag and option parsing is best-effort in this mode, so even if an Option
// is "required" it may not be set.
CompletionHandler CompletionHandlerFunc
}

// AddSubcommands adds the given subcommands, setting their
Expand Down Expand Up @@ -193,15 +199,22 @@ type Invocation struct {
ctx context.Context
Command *Command
parsedFlags *pflag.FlagSet
Args []string

// Args is reduced into the remaining arguments after parsing flags
// during Run.
Args []string

// Environ is a list of environment variables. Use EnvsWithPrefix to parse
// os.Environ.
Environ Environ
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
Logger slog.Logger
Net Net

// Deprecated
Logger slog.Logger
// Deprecated
Net Net

// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
Expand Down Expand Up @@ -282,6 +295,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
return fs2
}

func (inv *Invocation) CurWords() (prev string, cur string) {
if len(inv.Args) == 1 {
cur = inv.Args[0]
prev = ""
} else {
cur = inv.Args[len(inv.Args)-1]
prev = inv.Args[len(inv.Args)-2]
}
return
}

// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
Expand Down Expand Up @@ -378,8 +402,19 @@ func (inv *Invocation) run(state *runState) error {
}
}

// Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already.
// We only look at the current word to figure out handler to run, or what directory to inspect.
if inv.IsCompletionMode() {
for _, e := range inv.complete() {
fmt.Fprintln(inv.Stdout, e)
}
return nil
}

ignoreFlagParseErrors := inv.Command.RawArgs

// Flag parse errors are irrelevant for raw args commands.
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
Expand All @@ -401,7 +436,7 @@ func (inv *Invocation) run(state *runState) error {
}
}
// Don't error for missing flags if `--help` was supplied.
if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", "))
}

Expand Down Expand Up @@ -558,6 +593,65 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
return &i2
}

func (inv *Invocation) complete() []string {
prev, cur := inv.CurWords()

// If the current word is a flag
if strings.HasPrefix(cur, "--") {
flagParts := strings.Split(cur, "=")
flagName := flagParts[0][2:]
// If it's an equals flag
if len(flagParts) == 2 {
if out := inv.completeFlag(flagName); out != nil {
for i, o := range out {
out[i] = fmt.Sprintf("--%s=%s", flagName, o)
}
return out
}
} else if out := inv.Command.Options.ByFlag(flagName); out != nil {
// If the current word is a valid flag, auto-complete it so the
// shell moves the cursor
return []string{cur}
}
}
// If the previous word is a flag, then we're writing it's value
// and we should check it's handler
if strings.HasPrefix(prev, "--") {
word := prev[2:]
if out := inv.completeFlag(word); out != nil {
return out
}
}
// If the current word is the command, move the shell cursor
if inv.Command.Name() == cur {
return []string{inv.Command.Name()}
}
var completions []string

if inv.Command.CompletionHandler != nil {
completions = append(completions, inv.Command.CompletionHandler(inv)...)
}

completions = append(completions, DefaultCompletionHandler(inv)...)

return completions
}

func (inv *Invocation) completeFlag(word string) []string {
opt := inv.Command.Options.ByFlag(word)
if opt == nil {
return nil
}
if opt.CompletionHandler != nil {
return opt.CompletionHandler(inv)
}
val, ok := opt.Value.(*Enum)
if ok {
return val.Choices
}
return nil
}

// MiddlewareFunc returns the next handler in the chain,
// or nil if there are no more.
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
Expand Down Expand Up @@ -642,3 +736,5 @@ func RequireRangeArgs(start, end int) MiddlewareFunc {

// HandlerFunc handles an Invocation of a command.
type HandlerFunc func(i *Invocation) error

type CompletionHandlerFunc func(i *Invocation) []string
216 changes: 132 additions & 84 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/xerrors"

serpent "github.com/coder/serpent"
"github.com/coder/serpent/completion"
)

// ioBufs is the standard input, output, and error for a command.
Expand All @@ -30,100 +31,147 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
return &b
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command {
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
)
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
func sampleCommand(t *testing.T) *serpent.Command {
t.Helper()
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
reqArr []string
fileArr []string
enumStr string
)
enumChoices := []string{"foo", "bar", "qux"}
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
serpent.Option{
Name: "req-string",
Flag: "req-string",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
FlagShorthand: "b",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
serpent.Option{
Name: "req-string",
Flag: "req-string",
FlagShorthand: "s",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
serpent.Option{
Name: "req-enum",
Flag: "req-enum",
Value: serpent.EnumOf(&enumStr, enumChoices...),
},
serpent.Option{
Name: "req-array",
Flag: "req-array",
FlagShorthand: "a",
Value: serpent.StringArrayOf(&reqArr),
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
},
},
{
Use: "file <file>",
Handler: func(inv *serpent.Invocation) error {
return nil
},
CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool {
return true
}),
Middleware: serpent.RequireNArgs(1),
},
{
Use: "altfile",
Handler: func(inv *serpent.Invocation) error {
return nil
},
Options: serpent.OptionSet{
{
Name: "extra",
Flag: "extra",
Description: "Extra files.",
Value: serpent.StringArrayOf(&fileArr),
},
},
CompletionHandler: func(i *serpent.Invocation) []string {
return []string{"doesntexist.go"}
},
},
}
},
}
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command { return sampleCommand(t) }

t.Run("SimpleOK", func(t *testing.T) {
t.Parallel()
Expand Down
Loading

0 comments on commit 91966a2

Please sign in to comment.