From 0781d21c2ff1666a7d8859e6b4a294335badc554 Mon Sep 17 00:00:00 2001 From: Arik Kfir Date: Fri, 14 Jun 2024 15:13:56 +0300 Subject: [PATCH] bug(flags): required flags with default value still failed validation When a required flag had a default value, execution still failed, saying that the flag has not been provided. This change fixes that behavior by ensuring that default values are taken under consideration when checking whether a flag is missing or not. NOTE: environment variables will still override default values. --- execute_test.go | 31 +++++++++++++++++++++++++++++++ flag_set.go | 23 ++++++++++++++++------- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/execute_test.go b/execute_test.go index 073de16..92e9a20 100644 --- a/execute_test.go +++ b/execute_test.go @@ -221,4 +221,35 @@ Flags: With(t).Verify(rootPostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() }) + t.Run("missing required flags fail execution", func(t *testing.T) { + type ActionWithRequiredFlag struct { + TrackingAction + MyFlag string `required:"true"` + } + action := &ActionWithRequiredFlag{} + ctx := context.Background() + root := MustNew("cmd", "desc", "long desc", action, nil, nil) + + b := &bytes.Buffer{} + With(t).Verify(Execute(ctx, b, root, nil, nil)).Will(EqualTo(ExitCodeMisconfiguration)).OrFail() + With(t).Verify(action.TrackingAction.callTime).Will(BeNil()).OrFail() + With(t).Verify(b.String()).Will(EqualTo("required flag is missing: --my-flag\nUsage: cmd [--help] --my-flag=VALUE\n")).OrFail() + }) + + t.Run("required flags with default value do not fail execution", func(t *testing.T) { + type ActionWithRequiredFlag struct { + TrackingAction + MyFlag string `required:"true"` + } + action := &ActionWithRequiredFlag{ + MyFlag: "abc", + } + ctx := context.Background() + root := MustNew("cmd", "desc", "long desc", action, nil, nil) + + b := &bytes.Buffer{} + With(t).Verify(Execute(ctx, b, root, nil, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() + With(t).Verify(action.TrackingAction.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(b.String()).Will(BeEmpty()).OrFail() + }) } diff --git a/flag_set.go b/flag_set.go index f868d60..b1dc393 100644 --- a/flag_set.go +++ b/flag_set.go @@ -363,20 +363,29 @@ func (fs *flagSet) apply(envVars map[string]string, args []string) error { // Iterate flags and define them in the stdlib FlagSet for _, mfd := range mergedFlagDefs { - // Set the value to the flag's corresponding environment variable, if one was given - if v, found := envVars[*mfd.EnvVarName]; found { - if err := mfd.setValue(v); err != nil { - return err - } - } - // By definition, for the same name - all flags have the same "HasValue" value, so it should be safe to just // take it from the first one if mfd.HasValue { + + // Set the field's default value so it's marked as "applied" (and thus the "required" validation will ignore it) + if mfd.DefaultValue != "" { + if err := mfd.setValue(mfd.DefaultValue); err != nil { + return fmt.Errorf("failed applying default value for flag '%s': %w", mfd.Name, err) + } + } stdFs.Func(mfd.Name, "", func(v string) error { return mfd.setValue(v) }) + } else { stdFs.BoolFunc(mfd.Name, "", func(string) error { return mfd.setValue("true") }) } + + // Set the value to the flag's corresponding environment variable, if one was given + // Important this is done here, so it overrides the default value set earlier + if v, found := envVars[*mfd.EnvVarName]; found { + if err := mfd.setValue(v); err != nil { + return err + } + } } // Parse the given arguments, which will result in all CLI flags being set