Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,35 @@ Flags:
With(t).Verify(rootPostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail()
})

t.Run("missing required flags fail execution", func(t *testing.T) {
type ActionWithRequiredFlag struct {
TrackingAction
MyFlag string `required:"true"`
}
action := &ActionWithRequiredFlag{}
ctx := context.Background()
root := MustNew("cmd", "desc", "long desc", action, nil, nil)

b := &bytes.Buffer{}
With(t).Verify(Execute(ctx, b, root, nil, nil)).Will(EqualTo(ExitCodeMisconfiguration)).OrFail()
With(t).Verify(action.TrackingAction.callTime).Will(BeNil()).OrFail()
With(t).Verify(b.String()).Will(EqualTo("required flag is missing: --my-flag\nUsage: cmd [--help] --my-flag=VALUE\n")).OrFail()
})

t.Run("required flags with default value do not fail execution", func(t *testing.T) {
type ActionWithRequiredFlag struct {
TrackingAction
MyFlag string `required:"true"`
}
action := &ActionWithRequiredFlag{
MyFlag: "abc",
}
ctx := context.Background()
root := MustNew("cmd", "desc", "long desc", action, nil, nil)

b := &bytes.Buffer{}
With(t).Verify(Execute(ctx, b, root, nil, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail()
With(t).Verify(action.TrackingAction.callTime).Will(Not(BeNil())).OrFail()
With(t).Verify(b.String()).Will(BeEmpty()).OrFail()
})
}
23 changes: 16 additions & 7 deletions flag_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,20 +363,29 @@ func (fs *flagSet) apply(envVars map[string]string, args []string) error {
// Iterate flags and define them in the stdlib FlagSet
for _, mfd := range mergedFlagDefs {

// Set the value to the flag's corresponding environment variable, if one was given
if v, found := envVars[*mfd.EnvVarName]; found {
if err := mfd.setValue(v); err != nil {
return err
}
}

// By definition, for the same name - all flags have the same "HasValue" value, so it should be safe to just
// take it from the first one
if mfd.HasValue {

// Set the field's default value so it's marked as "applied" (and thus the "required" validation will ignore it)
if mfd.DefaultValue != "" {
if err := mfd.setValue(mfd.DefaultValue); err != nil {
return fmt.Errorf("failed applying default value for flag '%s': %w", mfd.Name, err)
}
}
stdFs.Func(mfd.Name, "", func(v string) error { return mfd.setValue(v) })

} else {
stdFs.BoolFunc(mfd.Name, "", func(string) error { return mfd.setValue("true") })
}

// Set the value to the flag's corresponding environment variable, if one was given
// Important this is done here, so it overrides the default value set earlier
if v, found := envVars[*mfd.EnvVarName]; found {
if err := mfd.setValue(v); err != nil {
return err
}
}
}

// Parse the given arguments, which will result in all CLI flags being set
Expand Down