Skip to content

Commit

Permalink
Merge pull request #1977 from dearchap/default_persistence
Browse files Browse the repository at this point in the history
Fix: Make flags persistent by default
  • Loading branch information
dearchap authored Oct 11, 2024
2 parents 20ef97b + af97b3f commit a288be2
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 119 deletions.
54 changes: 11 additions & 43 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,12 +636,6 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
return err
}

if err := cmd.checkPersistentRequiredFlags(); err != nil {
cmd.isInError = true
_ = ShowSubcommandHelp(cmd)
return err
}

if len(cmd.Arguments) > 0 {
rargs := cmd.Args().Slice()
tracef("calling argparse with %[1]v", rargs)
Expand Down Expand Up @@ -768,8 +762,8 @@ func (cmd *Command) parseFlags(args Args) (Args, error) {
for _, fl := range pCmd.Flags {
flNames := fl.Names()

pfl, ok := fl.(PersistentFlag)
if !ok || !pfl.IsPersistent() {
pfl, ok := fl.(LocalFlag)
if !ok || pfl.IsLocal() {
tracef("skipping non-persistent flag %[1]q (cmd=%[2]q)", flNames, cmd.Name)
continue
}
Expand Down Expand Up @@ -881,12 +875,12 @@ func (cmd *Command) appendFlag(fl Flag) {
}
}

// VisiblePersistentFlags returns a slice of [PersistentFlag] with Persistent=true and Hidden=false.
// VisiblePersistentFlags returns a slice of [LocalFlag] with Persistent=true and Hidden=false.
func (cmd *Command) VisiblePersistentFlags() []Flag {
var flags []Flag
for _, fl := range cmd.Root().Flags {
pfl, ok := fl.(PersistentFlag)
if !ok || !pfl.IsPersistent() {
pfl, ok := fl.(LocalFlag)
if !ok || pfl.IsLocal() {
continue
}
flags = append(flags, fl)
Expand Down Expand Up @@ -994,48 +988,22 @@ func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
}

func (cmd *Command) checkAllRequiredFlags() requiredFlagsErr {
if cmd.parent != nil {
if err := cmd.parent.checkRequiredFlags(); err != nil {
for pCmd := cmd; pCmd != nil; pCmd = pCmd.parent {
if err := pCmd.checkRequiredFlags(); err != nil {
return err
}
}
return cmd.checkRequiredFlags()
}

func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

for _, f := range cmd.Flags {
if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}
}

if len(missingFlags) != 0 {
tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)

return &errRequiredFlags{missingFlags: missingFlags}
}

tracef("all required flags set (cmd=%[1]q)", cmd.Name)

return nil
}

func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

for _, f := range cmd.appliedFlags {
if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}

Expand Down Expand Up @@ -1233,7 +1201,7 @@ func (cmd *Command) runFlagActions(ctx context.Context) error {
if !fl.IsSet() {
continue
}
if pf, ok := fl.(PersistentFlag); ok && pf.IsPersistent() {
if pf, ok := fl.(LocalFlag); ok && !pf.IsLocal() {
continue
}
}
Expand Down
66 changes: 29 additions & 37 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2803,7 +2803,6 @@ func TestPersistentFlag(t *testing.T) {
Flags: []Flag{
&StringFlag{
Name: "persistentCommandFlag",
Persistent: true,
Destination: &appFlag,
Action: func(context.Context, *Command, string) error {
persistentFlagActionCount++
Expand All @@ -2812,22 +2811,18 @@ func TestPersistentFlag(t *testing.T) {
},
&IntSliceFlag{
Name: "persistentCommandSliceFlag",
Persistent: true,
Destination: &persistentCommandSliceInt,
},
&FloatSliceFlag{
Name: "persistentCommandFloatSliceFlag",
Persistent: true,
Value: []float64{11.3, 12.5},
Name: "persistentCommandFloatSliceFlag",
Value: []float64{11.3, 12.5},
},
&IntFlag{
Name: "persistentCommandOverrideFlag",
Persistent: true,
Destination: &appOverrideInt,
},
&StringFlag{
Name: "persistentRequiredCommandFlag",
Persistent: true,
Required: true,
Destination: &appRequiredFlag,
},
Expand All @@ -2839,16 +2834,17 @@ func TestPersistentFlag(t *testing.T) {
&IntFlag{
Name: "cmdFlag",
Destination: &topInt,
Local: true,
},
&IntFlag{
Name: "cmdPersistentFlag",
Persistent: true,
Destination: &topPersistentInt,
},
&IntFlag{
Name: "paof",
Aliases: []string{"persistentCommandOverrideFlag"},
Destination: &appOverrideCmdInt,
Local: true,
},
},
Commands: []*Command{
Expand All @@ -2858,6 +2854,7 @@ func TestPersistentFlag(t *testing.T) {
&IntFlag{
Name: "cmdFlag",
Destination: &subCommandInt,
Local: true,
},
},
Action: func(_ context.Context, cmd *Command) error {
Expand Down Expand Up @@ -2914,8 +2911,7 @@ func TestPersistentFlagIsSet(t *testing.T) {
Name: "root",
Flags: []Flag{
&StringFlag{
Name: "result",
Persistent: true,
Name: "result",
},
},
Commands: []*Command{
Expand Down Expand Up @@ -3016,9 +3012,8 @@ func TestRequiredPersistentFlag(t *testing.T) {
Name: "root",
Flags: []Flag{
&StringFlag{
Name: "result",
Persistent: true,
Required: true,
Name: "result",
Required: true,
},
},
Commands: []*Command{
Expand Down Expand Up @@ -3418,10 +3413,10 @@ func TestCommand_IsSet_fromEnv(t *testing.T) {

cmd := &Command{
Flags: []Flag{
&FloatFlag{Name: "timeout", Aliases: []string{"t"}, Sources: EnvVars("APP_TIMEOUT_SECONDS")},
&StringFlag{Name: "password", Aliases: []string{"p"}, Sources: EnvVars("APP_PASSWORD")},
&FloatFlag{Name: "unparsable", Aliases: []string{"u"}, Sources: EnvVars("APP_UNPARSABLE")},
&FloatFlag{Name: "no-env-var", Aliases: []string{"n"}},
&FloatFlag{Name: "timeout", Aliases: []string{"t"}, Local: true, Sources: EnvVars("APP_TIMEOUT_SECONDS")},
&StringFlag{Name: "password", Aliases: []string{"p"}, Local: true, Sources: EnvVars("APP_PASSWORD")},
&FloatFlag{Name: "unparsable", Aliases: []string{"u"}, Local: true, Sources: EnvVars("APP_UNPARSABLE")},
&FloatFlag{Name: "no-env-var", Aliases: []string{"n"}, Local: true},
},
Action: func(_ context.Context, cmd *Command) error {
timeoutIsSet = cmd.IsSet("timeout")
Expand Down Expand Up @@ -3772,18 +3767,15 @@ func TestCheckRequiredFlags(t *testing.T) {
_ = os.Setenv(test.envVarInput[0], test.envVarInput[1])
}

set := flag.NewFlagSet("test", 0)
for _, flags := range test.flags {
_ = flags.Apply(set)
}
_ = set.Parse(test.parseInput)

cmd := &Command{
Flags: test.flags,
flagSet: set,
Name: "foo",
Flags: test.flags,
}
args := []string{"foo"}
args = append(args, test.parseInput...)
_ = cmd.Run(context.Background(), args)

err := cmd.checkRequiredFlags()
err := cmd.checkAllRequiredFlags()

// assertions
if test.expectedAnError {
Expand Down Expand Up @@ -4041,7 +4033,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"sub-fl",
Expand All @@ -4062,7 +4054,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"s"
Expand Down Expand Up @@ -4103,7 +4095,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4124,7 +4116,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand Down Expand Up @@ -4283,7 +4275,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"s"
Expand Down Expand Up @@ -4324,7 +4316,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4345,7 +4337,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand Down Expand Up @@ -4386,7 +4378,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "value",
"aliases": [
"s"
Expand All @@ -4406,7 +4398,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4427,7 +4419,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand All @@ -4447,7 +4439,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": true,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": null,
"takesFileArg": false,
Expand Down
10 changes: 6 additions & 4 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var VersionFlag Flag = &BoolFlag{
Aliases: []string{"v"},
Usage: "print the version",
HideDefault: true,
Local: true,
}

// HelpFlag prints the help for all commands and subcommands.
Expand All @@ -48,6 +49,7 @@ var HelpFlag Flag = &BoolFlag{
Aliases: []string{"h"},
Usage: "show help",
HideDefault: true,
Local: true,
}

// FlagStringer converts a flag definition to a string. This is used by help
Expand Down Expand Up @@ -172,10 +174,10 @@ type CategorizableFlag interface {
SetCategory(string)
}

// PersistentFlag is an interface to enable detection of flags which are persistent
// through subcommands
type PersistentFlag interface {
IsPersistent() bool
// LocalFlag is an interface to enable detection of flags which are local
// to current command
type LocalFlag interface {
IsLocal() bool
}

// IsDefaultVisible returns true if the flag is not hidden, otherwise false
Expand Down
2 changes: 1 addition & 1 deletion flag_bool_with_inverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (parent *BoolWithInverseFlag) initialize() {
Usage: child.Usage,
Required: child.Required,
Hidden: child.Hidden,
Persistent: child.Persistent,
Local: child.Local,
Value: child.Value,
Destination: parent.negDest,
TakesFile: child.TakesFile,
Expand Down
1 change: 1 addition & 0 deletions flag_bool_with_inverse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func TestBoolWithInverseEnvVars(t *testing.T) {
BoolFlag: &BoolFlag{
Name: "env",
Sources: EnvVars("ENV"),
Local: true,
},
}
}
Expand Down
Loading

0 comments on commit a288be2

Please sign in to comment.