Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,28 @@ func (i PreRunHookFunc) PreRun(ctx context.Context) error {
}
}

type PostRunHook interface {
PostRun(context.Context, error, ExitCode) error
}

type PostRunHookFunc func(context.Context, error, ExitCode) error

func (i PostRunHookFunc) PostRun(ctx context.Context, err error, exitCode ExitCode) error {
if i != nil {
return i(ctx, err, exitCode)
} else {
return nil
}
}

// Command is a command instance, created by [New] and can be composed with more Command instances to form a CLI command
// hierarchy.
type Command struct {
name string
shortDescription string
longDescription string
preRunHooks []PreRunHook
postRunHooks []PostRunHook
action Action
flags *flagSet
parent *Command
Expand All @@ -64,8 +79,8 @@ type Command struct {
// MustNew creates a new command using [New], but will panic if it returns an error.
//
//goland:noinspection GoUnusedExportedFunction
func MustNew(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, subCommands ...*Command) *Command {
cmd, err := New(name, shortDescription, longDescription, action, preRunHooks, subCommands...)
func MustNew(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, postRunHooks []PostRunHook, subCommands ...*Command) *Command {
cmd, err := New(name, shortDescription, longDescription, action, preRunHooks, postRunHooks, subCommands...)
if err != nil {
panic(err)
}
Expand All @@ -74,7 +89,7 @@ func MustNew(name, shortDescription, longDescription string, action Action, preR

// New creates a new command with the given name, short & long descriptions, and the given executor. The executor object
// is also scanned for configuration structs via reflection.
func New(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, subCommands ...*Command) (*Command, error) {
func New(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, postRunHooks []PostRunHook, subCommands ...*Command) (*Command, error) {
if name == "" {
return nil, fmt.Errorf("%w: empty name", ErrInvalidCommand)
} else if shortDescription == "" {
Expand All @@ -88,6 +103,7 @@ func New(name, shortDescription, longDescription string, action Action, preRunHo
longDescription: longDescription,
action: action,
preRunHooks: preRunHooks,
postRunHooks: postRunHooks,
HelpConfig: &HelpConfig{},
}

Expand Down
69 changes: 37 additions & 32 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ func TestNew(t *testing.T) {
testCases := map[string]testCase{
"empty name": {
commandFactory: func(t T, tc *testCase) (*Command, error) {
return New("", "short desc", "long desc", nil, nil)
return New("", "short desc", "long desc", nil, nil, nil)
},
expectedError: `^invalid command: empty name$`,
},
"empty short description": {
commandFactory: func(t T, tc *testCase) (*Command, error) {
return New("cmd", "", "long desc", nil, nil)
return New("cmd", "", "long desc", nil, nil, nil)
},
expectedError: `^invalid command: empty short description$`,
},
"no flags": {
commandFactory: func(t T, tc *testCase) (*Command, error) {
return New("cmd", "desc", "long desc", nil, nil)
return New("cmd", "desc", "long desc", nil, nil, nil)
},
expectedName: "cmd",
expectedShortDescription: "desc",
Expand All @@ -54,6 +54,7 @@ func TestNew(t *testing.T) {
MyFlag string `flag:"true"`
}{},
nil,
nil,
)
},
expectedFlagSet: &flagSet{
Expand Down Expand Up @@ -96,13 +97,13 @@ func TestNew(t *testing.T) {
func TestAddSubCommand(t *testing.T) {
t.Parallel()

root, err := New("root", "desc", "description", nil, nil)
root, err := New("root", "desc", "description", nil, nil, nil)
With(t).Verify(err).Will(BeNil()).OrFail()

sub1, err := New("sub1", "sub1 desc", "sub1 description", nil, nil)
sub1, err := New("sub1", "sub1 desc", "sub1 description", nil, nil, nil)
With(t).Verify(err).Will(BeNil()).OrFail()

sub2, err := New("sub2", "sub2 desc", "sub2 description", nil, nil)
sub2, err := New("sub2", "sub2 desc", "sub2 description", nil, nil, nil)
With(t).Verify(err).Will(BeNil()).OrFail()

With(t).Verify(root.AddSubCommand(sub1)).Will(BeNil()).OrFail()
Expand All @@ -123,10 +124,10 @@ func Test_inferCommandAndArgs(t *testing.T) {
testCases := map[string]testCase{
"No arguments": {
root: MustNew(
"root", "desc", "description", nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil,
MustNew("sub3", "sub3 desc", "sub3 description", nil, nil),
"root", "desc", "description", nil, nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil,
MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil),
),
),
),
Expand All @@ -137,9 +138,9 @@ func Test_inferCommandAndArgs(t *testing.T) {
},
"Flags for root command": {
root: MustNew(
"root", "desc", "description", nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil),
"root", "desc", "description", nil, nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil),
),
),
args: strings.Split("-f1 -f2", " "),
Expand All @@ -149,9 +150,9 @@ func Test_inferCommandAndArgs(t *testing.T) {
},
"Flags and positionals for root command": {
root: MustNew(
"root", "desc", "description", nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil),
"root", "desc", "description", nil, nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil),
),
),
args: strings.Split("-f1 a -f2 b", " "),
Expand All @@ -161,9 +162,9 @@ func Test_inferCommandAndArgs(t *testing.T) {
},
"Flags and positionals for sub1 command": {
root: MustNew(
"root", "desc", "description", nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil),
"root", "desc", "description", nil, nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil),
),
),
args: strings.Split("-f1 sub1 -f2 a b", " "),
Expand All @@ -173,9 +174,9 @@ func Test_inferCommandAndArgs(t *testing.T) {
},
"Flags and positionals for sub2 command": {
root: MustNew(
"root", "desc", "description", nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil),
"root", "desc", "description", nil, nil, nil,
MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil,
MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil),
),
),
args: strings.Split("-f1 sub1 -f2 a b sub2 c", " "),
Expand All @@ -200,10 +201,10 @@ func Test_getFullName(t *testing.T) {
cmd *Command
expectedFullName string
}
sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil)
sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, sub3)
sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, sub2)
root := MustNew("root", "desc", "description", nil, nil, sub1)
sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil)
sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil, sub3)
sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, sub2)
root := MustNew("root", "desc", "description", nil, nil, nil, sub1)
testCases := map[string]testCase{
"root": {
cmd: root,
Expand Down Expand Up @@ -235,10 +236,10 @@ func Test_getChain(t *testing.T) {
cmd *Command
expectedChain []string
}
sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil)
sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, sub3)
sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, sub2)
root := MustNew("root", "desc", "description", nil, nil, sub1)
sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil)
sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil, sub3)
sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, sub2)
root := MustNew("root", "desc", "description", nil, nil, nil, sub1)
testCases := map[string]testCase{
"root": {
cmd: root,
Expand Down Expand Up @@ -281,7 +282,7 @@ func TestPrintHelp(t *testing.T) {
"no flags & no positionals": {
commandFactory: func(*testCase) *Command {
ligen := loremipsum.NewWithSeed(4321)
return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), nil, nil)
return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), nil, nil, nil)
},
expectedHelpUsageOutput: `
Usage: cmd [--help]
Expand Down Expand Up @@ -324,7 +325,9 @@ Flags:
MyFlag string `desc:"flag description"`
Args []string `args:"true"`
}{},
nil)
nil,
nil,
)
},
expectedHelpUsageOutput: `
Usage: cmd [--help]
Expand Down Expand Up @@ -374,6 +377,7 @@ Flags:
Args []string `args:"true"`
}{},
nil,
nil,
MustNew(
"child1",
ligen.Sentence(),
Expand All @@ -384,6 +388,7 @@ Flags:
Args []string `args:"true"`
}{},
nil,
nil,
),
)
},
Expand Down
59 changes: 45 additions & 14 deletions execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ const (
// Execute the correct command in the given command hierarchy (starting at "root"), configured from the given CLI args
// and environment variables. The command will be executed with the given context after all pre-RunFunc hooks have been
// successfully executed in the command hierarchy.
func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) ExitCode {
func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode ExitCode) {
exitCode = ExitCodeSuccess

// We insist on getting the root command - so that we can infer correctly which command the user wanted to invoke
if root.parent != nil {
_, _ = fmt.Fprintf(w, "%s: command must be the root command", errors.ErrUnsupported)
return ExitCodeError
exitCode = ExitCodeError
return
}

// Extract the command, CLI flags, positional arguments & the command hierarchy
Expand All @@ -33,25 +37,51 @@ func Execute(ctx context.Context, w io.Writer, root *Command, args []string, env
_, _ = fmt.Fprintln(w, err)
if err := cmd.PrintUsageLine(w, getTerminalWidth()); err != nil {
_, _ = fmt.Fprintf(w, "%s\n", err)
return ExitCodeError
exitCode = ExitCodeError
return
} else {
return ExitCodeMisconfiguration
exitCode = ExitCodeMisconfiguration
return
}
} else if cmd.HelpConfig.Help {
if err := cmd.PrintHelp(w, getTerminalWidth()); err != nil {
_, _ = fmt.Fprintf(w, "%s\n", err)
return ExitCodeError
exitCode = ExitCodeMisconfiguration
return
} else {
return ExitCodeSuccess
exitCode = ExitCodeSuccess
return
}
}

// Results
var actionError error

// Ensure we invoke post-run hooks before we return
chain := cmd.getChain()
defer func() {
for i := len(chain) - 1; i >= 0; i-- {
c := chain[i]
for j := len(c.postRunHooks) - 1; j >= 0; j-- {
h := c.postRunHooks[j]
if err := h.PostRun(ctx, actionError, exitCode); err != nil {
_, _ = fmt.Fprintln(w, err)
exitCode = ExitCodeError
}
}
}
}()

// Invoke all "PreRun" hooks on the whole chain of commands (starting at the root)
for _, c := range cmd.getChain() {
for _, hook := range c.preRunHooks {
if err := hook.PreRun(ctx); err != nil {
for i := 0; i < len(chain); i++ {
c := chain[i]
for j := 0; j < len(c.preRunHooks); j++ {
h := c.preRunHooks[j]
if err := h.PreRun(ctx); err != nil {
_, _ = fmt.Fprintln(w, err)
return ExitCodeError
actionError = err
exitCode = ExitCodeError
return
}
}
}
Expand All @@ -60,15 +90,16 @@ func Execute(ctx context.Context, w io.Writer, root *Command, args []string, env
if cmd.action != nil {
if err := cmd.action.Run(ctx); err != nil {
_, _ = fmt.Fprintln(w, err)
return ExitCodeError
actionError = err
exitCode = ExitCodeError
}
} else {
// Command is not a runner - print help
if err := cmd.PrintHelp(w, getTerminalWidth()); err != nil {
_, _ = fmt.Fprintf(w, "%s\n", err)
return ExitCodeError
actionError = err
exitCode = ExitCodeError
}
}
return ExitCodeSuccess

return
}
Loading