Skip to content

Commit

Permalink
Issue #107: add basic zsh completion (command hierarchy only) (#497)
Browse files Browse the repository at this point in the history
Add basic zsh completion (command hierarchy only)

Partially fixes #107
See PR #497
  • Loading branch information
bpicode authored and anthonyfok committed Jul 30, 2017
1 parent 9e024b6 commit 1bdc55b
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 0 deletions.
114 changes: 114 additions & 0 deletions zsh_completions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package cobra

import (
"bytes"
"fmt"
"io"
"strings"
)

// GenZshCompletion generates a zsh completion file and writes to the passed writer.
func (cmd *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer)

writeHeader(buf, cmd)
maxDepth := maxDepth(cmd)
writeLevelMapping(buf, maxDepth)
writeLevelCases(buf, maxDepth, cmd)

_, err := buf.WriteTo(w)
return err
}

func writeHeader(w io.Writer, cmd *Command) {
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
}

func maxDepth(c *Command) int {
if len(c.Commands()) == 0 {
return 0
}
maxDepthSub := 0
for _, s := range c.Commands() {
subDepth := maxDepth(s)
if subDepth > maxDepthSub {
maxDepthSub = subDepth
}
}
return 1 + maxDepthSub
}

func writeLevelMapping(w io.Writer, numLevels int) {
fmt.Fprintln(w, `_arguments \`)
for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
fmt.Fprintln(w)
}
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
fmt.Fprintln(w)
}

func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")

for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}

func writeLevel(w io.Writer, root *Command, i int) {
fmt.Fprintf(w, " case $words[%d] in\n", i)
defer fmt.Fprintln(w, " esac")

commands := filterByLevel(root, i)
byParent := groupByParent(commands)

for p, c := range byParent {
names := names(c)
fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")

}

func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0)
if l == 0 {
cs = append(cs, c)
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}

func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands {
parent := c.Parent()
if parent == nil {
continue
}
m[parent.Name()] = append(m[parent.Name()], c)
}
return m
}

func names(commands []*Command) []string {
ns := make([]string, len(commands))
for i, c := range commands {
ns[i] = c.Name()
}
return ns
}
88 changes: 88 additions & 0 deletions zsh_completions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package cobra

import (
"bytes"
"strings"
"testing"
)

func TestZshCompletion(t *testing.T) {
tcs := []struct {
name string
root *Command
expectedExpressions []string
}{
{
name: "trivial",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{"#compdef trivial"},
},
{
name: "linear",
root: func() *Command {
r := &Command{Use: "linear"}

sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)

sub2 := &Command{Use: "sub2"}
sub1.AddCommand(sub2)

sub3 := &Command{Use: "sub3"}
sub2.AddCommand(sub3)
return r
}(),
expectedExpressions: []string{"sub1", "sub2", "sub3"},
},
{
name: "flat",
root: func() *Command {
r := &Command{Use: "flat"}
r.AddCommand(&Command{Use: "c1"})
r.AddCommand(&Command{Use: "c2"})
return r
}(),
expectedExpressions: []string{"(c1 c2)"},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}

sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)

sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}

sub1.AddCommand(sub11)
sub1.AddCommand(sub12)

sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)

sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}

sub2.AddCommand(sub21)
sub2.AddCommand(sub22)

return r
}(),
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
tc.root.GenZshCompletion(buf)
completion := buf.String()
for _, expectedExpression := range tc.expectedExpressions {
if !strings.Contains(completion, expectedExpression) {
t.Errorf("expected completion to contain '%v' somewhere; got '%v'", expectedExpression, completion)
}
}
})
}
}

0 comments on commit 1bdc55b

Please sign in to comment.