Skip to content

Commit

Permalink
Make sure grouped flags are added to the command flag set (#1180)
Browse files Browse the repository at this point in the history
## Changes
Make sure grouped flags are added to the command flag set

## Tests
Added regression tests
  • Loading branch information
andrewnester authored Feb 7, 2024
1 parent 0b5fdcc commit de363fa
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 11 deletions.
12 changes: 8 additions & 4 deletions bundle/run/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ type Options struct {
}

func (o *Options) Define(cmd *cobra.Command) {
wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd)
jobGroup := wrappedCmd.AddFlagGroup("Job")
jobGroup := cmdgroup.NewFlagGroup("Job")
o.Job.DefineJobOptions(jobGroup.FlagSet())

jobTaskGroup := wrappedCmd.AddFlagGroup("Job Task")
jobTaskGroup := cmdgroup.NewFlagGroup("Job Task")
jobTaskGroup.SetDescription(`Note: please prefer use of job-level parameters (--param) over task-level parameters.
For more information, see https://docs.databricks.com/en/workflows/jobs/create-run-jobs.html#pass-parameters-to-a-databricks-job-task`)
o.Job.DefineTaskOptions(jobTaskGroup.FlagSet())

pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline")
pipelineGroup := cmdgroup.NewFlagGroup("Pipeline")
o.Pipeline.Define(pipelineGroup.FlagSet())

wrappedCmd := cmdgroup.NewCommandWithGroupFlag(cmd)
wrappedCmd.AddFlagGroup(jobGroup)
wrappedCmd.AddFlagGroup(jobTaskGroup)
wrappedCmd.AddFlagGroup(pipelineGroup)
}
31 changes: 28 additions & 3 deletions libs/cmdgroup/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@ func (c *CommandWithGroupFlag) FlagGroups() []*FlagGroup {
return c.flagGroups
}

func (c *CommandWithGroupFlag) NonGroupedFlags() *pflag.FlagSet {
nonGrouped := pflag.NewFlagSet("non-grouped", pflag.ContinueOnError)
c.cmd.LocalFlags().VisitAll(func(f *pflag.Flag) {
for _, fg := range c.flagGroups {
if fg.Has(f) {
return
}
}
nonGrouped.AddFlag(f)
})

return nonGrouped
}

func (c *CommandWithGroupFlag) HasNonGroupedFlags() bool {
return c.NonGroupedFlags().HasFlags()
}

func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag {
cmdWithFlagGroups := &CommandWithGroupFlag{cmd: cmd, flagGroups: make([]*FlagGroup, 0)}
cmd.SetUsageFunc(func(c *cobra.Command) error {
Expand All @@ -36,10 +54,9 @@ func NewCommandWithGroupFlag(cmd *cobra.Command) *CommandWithGroupFlag {
return cmdWithFlagGroups
}

func (c *CommandWithGroupFlag) AddFlagGroup(name string) *FlagGroup {
fg := &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)}
func (c *CommandWithGroupFlag) AddFlagGroup(fg *FlagGroup) {
c.flagGroups = append(c.flagGroups, fg)
return fg
c.cmd.Flags().AddFlagSet(fg.FlagSet())
}

type FlagGroup struct {
Expand All @@ -48,6 +65,10 @@ type FlagGroup struct {
flagSet *pflag.FlagSet
}

func NewFlagGroup(name string) *FlagGroup {
return &FlagGroup{name: name, flagSet: pflag.NewFlagSet(name, pflag.ContinueOnError)}
}

func (c *FlagGroup) Name() string {
return c.name
}
Expand All @@ -64,6 +85,10 @@ func (c *FlagGroup) FlagSet() *pflag.FlagSet {
return c.flagSet
}

func (c *FlagGroup) Has(f *pflag.Flag) bool {
return c.flagSet.Lookup(f.Name) != nil
}

var templateFuncs = template.FuncMap{
"trim": strings.TrimSpace,
"trimRightSpace": trimRightSpace,
Expand Down
12 changes: 10 additions & 2 deletions libs/cmdgroup/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ func TestCommandFlagGrouping(t *testing.T) {
}

wrappedCmd := NewCommandWithGroupFlag(cmd)
jobGroup := wrappedCmd.AddFlagGroup("Job")
jobGroup := NewFlagGroup("Job")
fs := jobGroup.FlagSet()
fs.String("job-name", "", "Name of the job")
fs.String("job-type", "", "Type of the job")
wrappedCmd.AddFlagGroup(jobGroup)

pipelineGroup := wrappedCmd.AddFlagGroup("Pipeline")
pipelineGroup := NewFlagGroup("Pipeline")
fs = pipelineGroup.FlagSet()
fs.String("pipeline-name", "", "Name of the pipeline")
fs.String("pipeline-type", "", "Type of the pipeline")
wrappedCmd.AddFlagGroup(pipelineGroup)

cmd.Flags().BoolP("bool", "b", false, "Bool flag")

Expand All @@ -48,4 +50,10 @@ Pipeline Flags:
Flags:
-b, --bool Bool flag`
require.Equal(t, expected, buf.String())

require.NotNil(t, cmd.Flags().Lookup("job-name"))
require.NotNil(t, cmd.Flags().Lookup("job-type"))
require.NotNil(t, cmd.Flags().Lookup("pipeline-name"))
require.NotNil(t, cmd.Flags().Lookup("pipeline-type"))
require.NotNil(t, cmd.Flags().Lookup("bool"))
}
4 changes: 2 additions & 2 deletions libs/cmdgroup/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const usageTemplate = `Usage:{{if .Command.Runnable}}
{{.Description}}{{end}}
{{.FlagSet.FlagUsages | trimTrailingWhitespaces}}
{{end}}
{{if .Command.HasAvailableLocalFlags}}Flags:
{{.Command.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}}
{{if .HasNonGroupedFlags}}Flags:
{{.NonGroupedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .Command.HasAvailableInheritedFlags}}
Global Flags:
{{.Command.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}`

0 comments on commit de363fa

Please sign in to comment.