Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 11 additions & 43 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,12 +636,6 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
return err
}

if err := cmd.checkPersistentRequiredFlags(); err != nil {
cmd.isInError = true
_ = ShowSubcommandHelp(cmd)
return err
}

if len(cmd.Arguments) > 0 {
rargs := cmd.Args().Slice()
tracef("calling argparse with %[1]v", rargs)
Expand Down Expand Up @@ -768,8 +762,8 @@ func (cmd *Command) parseFlags(args Args) (Args, error) {
for _, fl := range pCmd.Flags {
flNames := fl.Names()

pfl, ok := fl.(PersistentFlag)
if !ok || !pfl.IsPersistent() {
pfl, ok := fl.(LocalFlag)
if !ok || pfl.IsLocal() {
tracef("skipping non-persistent flag %[1]q (cmd=%[2]q)", flNames, cmd.Name)
continue
}
Expand Down Expand Up @@ -881,12 +875,12 @@ func (cmd *Command) appendFlag(fl Flag) {
}
}

// VisiblePersistentFlags returns a slice of [PersistentFlag] with Persistent=true and Hidden=false.
// VisiblePersistentFlags returns a slice of [LocalFlag] with Persistent=true and Hidden=false.
func (cmd *Command) VisiblePersistentFlags() []Flag {
var flags []Flag
for _, fl := range cmd.Root().Flags {
pfl, ok := fl.(PersistentFlag)
if !ok || !pfl.IsPersistent() {
pfl, ok := fl.(LocalFlag)
if !ok || pfl.IsLocal() {
continue
}
flags = append(flags, fl)
Expand Down Expand Up @@ -994,48 +988,22 @@ func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
}

func (cmd *Command) checkAllRequiredFlags() requiredFlagsErr {
if cmd.parent != nil {
if err := cmd.parent.checkRequiredFlags(); err != nil {
for pCmd := cmd; pCmd != nil; pCmd = pCmd.parent {
if err := pCmd.checkRequiredFlags(); err != nil {
return err
}
}
return cmd.checkRequiredFlags()
}

func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

for _, f := range cmd.Flags {
if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}
}

if len(missingFlags) != 0 {
tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)

return &errRequiredFlags{missingFlags: missingFlags}
}

tracef("all required flags set (cmd=%[1]q)", cmd.Name)

return nil
}

func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

for _, f := range cmd.appliedFlags {
if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}

Expand Down Expand Up @@ -1233,7 +1201,7 @@ func (cmd *Command) runFlagActions(ctx context.Context) error {
if !fl.IsSet() {
continue
}
if pf, ok := fl.(PersistentFlag); ok && pf.IsPersistent() {
if pf, ok := fl.(LocalFlag); ok && !pf.IsLocal() {
continue
}
}
Expand Down
66 changes: 29 additions & 37 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2803,7 +2803,6 @@ func TestPersistentFlag(t *testing.T) {
Flags: []Flag{
&StringFlag{
Name: "persistentCommandFlag",
Persistent: true,
Destination: &appFlag,
Action: func(context.Context, *Command, string) error {
persistentFlagActionCount++
Expand All @@ -2812,22 +2811,18 @@ func TestPersistentFlag(t *testing.T) {
},
&IntSliceFlag{
Name: "persistentCommandSliceFlag",
Persistent: true,
Destination: &persistentCommandSliceInt,
},
&FloatSliceFlag{
Name: "persistentCommandFloatSliceFlag",
Persistent: true,
Value: []float64{11.3, 12.5},
Name: "persistentCommandFloatSliceFlag",
Value: []float64{11.3, 12.5},
},
&IntFlag{
Name: "persistentCommandOverrideFlag",
Persistent: true,
Destination: &appOverrideInt,
},
&StringFlag{
Name: "persistentRequiredCommandFlag",
Persistent: true,
Required: true,
Destination: &appRequiredFlag,
},
Expand All @@ -2839,16 +2834,17 @@ func TestPersistentFlag(t *testing.T) {
&IntFlag{
Name: "cmdFlag",
Destination: &topInt,
Local: true,
},
&IntFlag{
Name: "cmdPersistentFlag",
Persistent: true,
Destination: &topPersistentInt,
},
&IntFlag{
Name: "paof",
Aliases: []string{"persistentCommandOverrideFlag"},
Destination: &appOverrideCmdInt,
Local: true,
},
},
Commands: []*Command{
Expand All @@ -2858,6 +2854,7 @@ func TestPersistentFlag(t *testing.T) {
&IntFlag{
Name: "cmdFlag",
Destination: &subCommandInt,
Local: true,
},
},
Action: func(_ context.Context, cmd *Command) error {
Expand Down Expand Up @@ -2914,8 +2911,7 @@ func TestPersistentFlagIsSet(t *testing.T) {
Name: "root",
Flags: []Flag{
&StringFlag{
Name: "result",
Persistent: true,
Name: "result",
},
},
Commands: []*Command{
Expand Down Expand Up @@ -3016,9 +3012,8 @@ func TestRequiredPersistentFlag(t *testing.T) {
Name: "root",
Flags: []Flag{
&StringFlag{
Name: "result",
Persistent: true,
Required: true,
Name: "result",
Required: true,
},
},
Commands: []*Command{
Expand Down Expand Up @@ -3418,10 +3413,10 @@ func TestCommand_IsSet_fromEnv(t *testing.T) {

cmd := &Command{
Flags: []Flag{
&FloatFlag{Name: "timeout", Aliases: []string{"t"}, Sources: EnvVars("APP_TIMEOUT_SECONDS")},
&StringFlag{Name: "password", Aliases: []string{"p"}, Sources: EnvVars("APP_PASSWORD")},
&FloatFlag{Name: "unparsable", Aliases: []string{"u"}, Sources: EnvVars("APP_UNPARSABLE")},
&FloatFlag{Name: "no-env-var", Aliases: []string{"n"}},
&FloatFlag{Name: "timeout", Aliases: []string{"t"}, Local: true, Sources: EnvVars("APP_TIMEOUT_SECONDS")},
&StringFlag{Name: "password", Aliases: []string{"p"}, Local: true, Sources: EnvVars("APP_PASSWORD")},
&FloatFlag{Name: "unparsable", Aliases: []string{"u"}, Local: true, Sources: EnvVars("APP_UNPARSABLE")},
&FloatFlag{Name: "no-env-var", Aliases: []string{"n"}, Local: true},
},
Action: func(_ context.Context, cmd *Command) error {
timeoutIsSet = cmd.IsSet("timeout")
Expand Down Expand Up @@ -3772,18 +3767,15 @@ func TestCheckRequiredFlags(t *testing.T) {
_ = os.Setenv(test.envVarInput[0], test.envVarInput[1])
}

set := flag.NewFlagSet("test", 0)
for _, flags := range test.flags {
_ = flags.Apply(set)
}
_ = set.Parse(test.parseInput)

cmd := &Command{
Flags: test.flags,
flagSet: set,
Name: "foo",
Flags: test.flags,
}
args := []string{"foo"}
args = append(args, test.parseInput...)
_ = cmd.Run(context.Background(), args)

err := cmd.checkRequiredFlags()
err := cmd.checkAllRequiredFlags()

// assertions
if test.expectedAnError {
Expand Down Expand Up @@ -4041,7 +4033,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"sub-fl",
Expand All @@ -4062,7 +4054,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"s"
Expand Down Expand Up @@ -4103,7 +4095,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4124,7 +4116,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand Down Expand Up @@ -4283,7 +4275,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"s"
Expand Down Expand Up @@ -4324,7 +4316,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4345,7 +4337,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand Down Expand Up @@ -4386,7 +4378,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "value",
"aliases": [
"s"
Expand All @@ -4406,7 +4398,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": "",
"aliases": [
"fl",
Expand All @@ -4427,7 +4419,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": false,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": [
"b"
Expand All @@ -4447,7 +4439,7 @@ func TestJSONExportCommand(t *testing.T) {
"required": false,
"hidden": true,
"hideDefault": false,
"persistent": false,
"local": false,
"defaultValue": false,
"aliases": null,
"takesFileArg": false,
Expand Down
10 changes: 6 additions & 4 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var VersionFlag Flag = &BoolFlag{
Aliases: []string{"v"},
Usage: "print the version",
HideDefault: true,
Local: true,
}

// HelpFlag prints the help for all commands and subcommands.
Expand All @@ -48,6 +49,7 @@ var HelpFlag Flag = &BoolFlag{
Aliases: []string{"h"},
Usage: "show help",
HideDefault: true,
Local: true,
}

// FlagStringer converts a flag definition to a string. This is used by help
Expand Down Expand Up @@ -172,10 +174,10 @@ type CategorizableFlag interface {
SetCategory(string)
}

// PersistentFlag is an interface to enable detection of flags which are persistent
// through subcommands
type PersistentFlag interface {
IsPersistent() bool
// LocalFlag is an interface to enable detection of flags which are local
// to current command
type LocalFlag interface {
IsLocal() bool
}

// IsDefaultVisible returns true if the flag is not hidden, otherwise false
Expand Down
2 changes: 1 addition & 1 deletion flag_bool_with_inverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (parent *BoolWithInverseFlag) initialize() {
Usage: child.Usage,
Required: child.Required,
Hidden: child.Hidden,
Persistent: child.Persistent,
Local: child.Local,
Value: child.Value,
Destination: parent.negDest,
TakesFile: child.TakesFile,
Expand Down
1 change: 1 addition & 0 deletions flag_bool_with_inverse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func TestBoolWithInverseEnvVars(t *testing.T) {
BoolFlag: &BoolFlag{
Name: "env",
Sources: EnvVars("ENV"),
Local: true,
},
}
}
Expand Down
Loading