diff --git a/cmd/src/cmd.go b/cmd/src/cmd.go index 95c16e6e0e..3a92fb497c 100644 --- a/cmd/src/cmd.go +++ b/cmd/src/cmd.go @@ -42,86 +42,147 @@ func (c *command) matches(name string) bool { // commander represents a top-level command with subcommands. type commander []*command -// run runs the command. +// Run the command func (c commander) run(flagSet *flag.FlagSet, cmdName, usageText string, args []string) { - // Parse flags. - flagSet.Usage = func() { - _, _ = fmt.Fprint(flag.CommandLine.Output(), usageText) + + // NOTE: This function is quite brittle + // Especially with printing helper text at all 3 different levels of depth + + // Check if --help args are anywhere in the command + // If yes, then remove it from the list of args at this point, + // then append it to the deepest command / subcommand, later, + // to avoid outputting usage text for a commander when a subcommand is specified + filteredArgs := make([]string, 0, len(args)) + helpRequested := false + + helpFlags := []string{ + "help", + "-help", + "--help", + "-h", + "--h", } - if !flagSet.Parsed() { - _ = flagSet.Parse(args) + + for _, arg := range args { + if slices.Contains(helpFlags, arg) { + helpRequested = true + } else { + filteredArgs = append(filteredArgs, arg) + } } - // Print usage if the command is "help". - if flagSet.Arg(0) == "help" || flagSet.NArg() == 0 { - flagSet.SetOutput(os.Stdout) - flagSet.Usage() - os.Exit(0) + // Define the usage function for the commander + flagSet.Usage = func() { + _, _ = fmt.Fprint(flag.CommandLine.Output(), usageText) } - // Configure default usage funcs for commands. - for _, cmd := range c { - cmd := cmd - if cmd.usageFunc != nil { - cmd.flagSet.Usage = cmd.usageFunc - continue - } - cmd.flagSet.Usage = func() { - _, _ = fmt.Fprintf(flag.CommandLine.Output(), "Usage of '%s %s':\n", cmdName, cmd.flagSet.Name()) - cmd.flagSet.PrintDefaults() - } + // Parse the commander's flags, if not already parsed + if !flagSet.Parsed() { + _ = flagSet.Parse(filteredArgs) } - // Find the subcommand to execute. + // Find the subcommand to execute + // This assumes the subcommand is the first arg in the flagSet, + // i.e. any global args have been removed from the flagSet name := flagSet.Arg(0) + + // Loop through the list of all registered subcommands for _, cmd := range c { + + // If the first arg is not this registered commmand in the loop, try the next registered command if !cmd.matches(name) { continue } + // If the first arg is this registered commmand in the loop, then try and run it, then exit - // Read global configuration now. + // Set up the usage function for this subcommand + if cmd.usageFunc != nil { + // If the subcommand has a usageFunc defined, then use it + cmd.flagSet.Usage = cmd.usageFunc + } else { + // If the subcommand does not have a usageFunc defined, + // then define a simple default one, + // using the list of flags defined in the subcommand, and their description strings + cmd.flagSet.Usage = func() { + _, _ = fmt.Fprintf(flag.CommandLine.Output(), "Usage of '%s %s':\n", cmdName, cmd.flagSet.Name()) + cmd.flagSet.PrintDefaults() + } + } + + // Read global configuration var err error cfg, err = readConfig() if err != nil { log.Fatal("reading config: ", err) } - // Print help to stdout if requested - if slices.IndexFunc(args, func(s string) bool { - return s == "--help" - }) >= 0 { - cmd.flagSet.SetOutput(os.Stdout) - flag.CommandLine.SetOutput(os.Stdout) - cmd.flagSet.Usage() - os.Exit(0) + // Get the remainder of the args, excluding the first arg / this command name + args := flagSet.Args()[1:] + + // Set output to stdout, for usage / helper text printed for the --help flag (flag package defaults to stderr) + cmd.flagSet.SetOutput(os.Stdout) + flag.CommandLine.SetOutput(os.Stdout) + + // If the --help arg was provided, re-add it here for the lowest command to parse and action + if helpRequested { + args = append(args, "-h") } - // Parse subcommand flags. - args := flagSet.Args()[1:] + // Parse the subcommand's args, on its behalf, to test and ensure flag.ExitOnError is set + // just in case any future authors of subcommands forget to set flag.ExitOnError if err := cmd.flagSet.Parse(args); err != nil { fmt.Printf("Error parsing subcommand flags: %s\n", err) panic(fmt.Sprintf("all registered commands should use flag.ExitOnError: error: %s", err)) } - // Execute the subcommand. - if err := cmd.handler(flagSet.Args()[1:]); err != nil { + // Execute the subcommand + // Handle any errors returned + if err := cmd.handler(args); err != nil { + + // If the returned error is of type UsageError if _, ok := err.(*cmderrors.UsageError); ok { + // then print the error and usage helper text, both to stderr log.Printf("error: %s\n\n", err) cmd.flagSet.SetOutput(os.Stderr) flag.CommandLine.SetOutput(os.Stderr) cmd.flagSet.Usage() os.Exit(2) } + + // If the returned error is of type ExitCodeError if e, ok := err.(*cmderrors.ExitCodeError); ok { + // Then log the error and exit with the exit code if e.HasError() { log.Println(e) } os.Exit(e.Code()) } + + // For all other types of errors, log them as fatal, and exit log.Fatal(err) } + + // If no error was returned, then just exit the application cleanly os.Exit(0) } - log.Printf("%s: unknown subcommand %q", cmdName, name) - log.Fatalf("Run '%s help' for usage.", cmdName) + + // To make it after the big loop, that means name didn't match any registered commands + if name != "" { + log.Printf("%s: unknown command %q", cmdName, name) + flagSet.Usage() + os.Exit(2) + } + + // Special case to handle --help usage text for src command + if helpRequested { + // Set output to stdout, for usage / helper text printed for the --help flag (flag package defaults to stderr) + flagSet.SetOutput(os.Stdout) + flagSet.Usage() + os.Exit(0) + } + + // Special case to handle src command with no args + flagSet.Usage() + os.Exit(2) + } diff --git a/cmd/src/cmd_test.go b/cmd/src/cmd_test.go new file mode 100644 index 0000000000..f62323798e --- /dev/null +++ b/cmd/src/cmd_test.go @@ -0,0 +1,346 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "os" + "os/exec" + "testing" + + "github.com/sourcegraph/src-cli/internal/cmderrors" +) + +func TestCommand_Matches(t *testing.T) { + tests := []struct { + name string + cmd *command + input string + expected bool + }{ + { + name: "matches command name", + cmd: &command{ + flagSet: flag.NewFlagSet("test", flag.ExitOnError), + }, + input: "test", + expected: true, + }, + { + name: "matches alias", + cmd: &command{ + flagSet: flag.NewFlagSet("test", flag.ExitOnError), + aliases: []string{"t", "tst"}, + }, + input: "t", + expected: true, + }, + { + name: "matches second alias", + cmd: &command{ + flagSet: flag.NewFlagSet("test", flag.ExitOnError), + aliases: []string{"t", "tst"}, + }, + input: "tst", + expected: true, + }, + { + name: "no match", + cmd: &command{ + flagSet: flag.NewFlagSet("test", flag.ExitOnError), + aliases: []string{"t"}, + }, + input: "other", + expected: false, + }, + { + name: "empty string no match", + cmd: &command{ + flagSet: flag.NewFlagSet("test", flag.ExitOnError), + }, + input: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.cmd.matches(tt.input) + if result != tt.expected { + t.Errorf("matches(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestCommander_Run_ErrorHandling(t *testing.T) { + tests := []struct { + name string + handlerError error + expectedExit int + description string + }{ + { + name: "usage error", + handlerError: cmderrors.Usage("invalid usage"), + expectedExit: 2, + description: "should exit with code 2 for usage errors", + }, + { + name: "exit code error without message", + handlerError: cmderrors.ExitCode(42, nil), + expectedExit: 42, + description: "should exit with custom exit code", + }, + { + name: "exit code error with message", + handlerError: cmderrors.ExitCode(1, cmderrors.Usage("command failed")), + expectedExit: 1, + description: "should exit with custom exit code and log error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Logf("Test case: %s", tt.description) + }) + } +} + +func TestCommander_Run_UnknownCommand(t *testing.T) { + if os.Getenv("TEST_SUBPROCESS") == "1" { + testHomeDir = os.Getenv("TEST_TEMP_DIR") + cmdr := commander{ + &command{ + flagSet: flag.NewFlagSet("version", flag.ContinueOnError), + handler: func(args []string) error { return nil }, + }, + } + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + cmdr.run(flagSet, "src", "usage text", []string{"beans"}) + return + } + + tempDir := t.TempDir() + cmd := exec.Command(os.Args[0], "-test.run=^TestCommander_Run_UnknownCommand$") + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "TEST_TEMP_DIR="+tempDir) + var stderr bytes.Buffer + cmd.Stderr = &stderr + err := cmd.Run() + + if err == nil { + t.Fatal("expected command to exit with non-zero code") + } + + if e, ok := err.(*exec.ExitError); ok { + if e.ExitCode() != 2 { + t.Errorf("expected exit code 2 for unknown command, got %d\nstderr: %s", e.ExitCode(), stderr.String()) + } + } else { + t.Errorf("unexpected error type: %v", err) + } +} + +func TestCommander_Run_HelpFlag(t *testing.T) { + if os.Getenv("TEST_SUBPROCESS") == "1" { + testHomeDir = os.Getenv("TEST_TEMP_DIR") + arg := os.Getenv("TEST_ARG") + cmdr := commander{} + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + cmdr.run(flagSet, "src", "usage text", []string{arg}) + return + } + + tests := []struct { + name string + arg string + contains string + expectedExit int + }{ + { + name: "help flag at root", + arg: "--help", + contains: "usage text", + expectedExit: 0, + }, + { + name: "-h flag at root", + arg: "-h", + contains: "usage text", + expectedExit: 0, + }, + { + name: "help command at root", + arg: "help", + contains: "usage text", + expectedExit: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + cmd := exec.Command(os.Args[0], "-test.run=^TestCommander_Run_HelpFlag$") + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "TEST_TEMP_DIR="+tempDir, "TEST_ARG="+tt.arg) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + + output := stdout.String() + stderr.String() + + if tt.expectedExit == 0 && err != nil { + t.Errorf("expected success, got error: %v\noutput: %s", err, output) + } else if tt.expectedExit != 0 { + if err == nil { + t.Errorf("expected exit code %d, got success", tt.expectedExit) + } else if e, ok := err.(*exec.ExitError); ok && e.ExitCode() != tt.expectedExit { + t.Errorf("expected exit code %d, got %d\noutput: %s", tt.expectedExit, e.ExitCode(), output) + } + } + + if !bytes.Contains([]byte(output), []byte(tt.contains)) { + t.Errorf("expected output to contain %q, got:\n%s", tt.contains, output) + } + }) + } +} + +func TestCommander_Run_NestedHelpFlags(t *testing.T) { + if os.Getenv("TEST_SUBPROCESS") == "1" { + testHomeDir = os.Getenv("TEST_TEMP_DIR") + + uploadFlagSet := flag.NewFlagSet("upload", flag.ExitOnError) + uploadCmd := &command{ + flagSet: uploadFlagSet, + handler: func(args []string) error { return nil }, + usageFunc: func() { + fmt.Fprint(flag.CommandLine.Output(), "upload usage text") + }, + } + + snapshotCommands := commander{uploadCmd} + + snapshotFlagSet := flag.NewFlagSet("snapshot", flag.ExitOnError) + snapshotCmd := &command{ + flagSet: snapshotFlagSet, + handler: func(args []string) error { + snapshotCommands.run(snapshotFlagSet, "src snapshot", "snapshot usage text", args) + return nil + }, + usageFunc: func() { + fmt.Fprint(flag.CommandLine.Output(), "snapshot usage text") + }, + } + + cmdr := commander{snapshotCmd} + flagSet := flag.NewFlagSet("test", flag.ExitOnError) + args := []string{"snapshot", "upload", "--h"} + cmdr.run(flagSet, "src", "root usage", args) + return + } + + tempDir := t.TempDir() + cmd := exec.Command(os.Args[0], "-test.run=^TestCommander_Run_NestedHelpFlags$") + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "TEST_TEMP_DIR="+tempDir) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + + output := stdout.String() + stderr.String() + + if err != nil { + t.Errorf("expected success, got error: %v\noutput: %s", err, output) + } + + if !bytes.Contains([]byte(output), []byte("upload usage text")) { + t.Errorf("expected output to contain 'upload usage text', got:\n%s", output) + } + + if bytes.Contains([]byte(output), []byte("snapshot usage text")) { + t.Errorf("expected output NOT to contain 'snapshot usage text', got:\n%s", output) + } +} + +func TestCommander_Run_InvalidSubcommand(t *testing.T) { + if os.Getenv("TEST_SUBPROCESS") == "1" { + testHomeDir = os.Getenv("TEST_TEMP_DIR") + arg := os.Getenv("TEST_ARG") + cmdr := commander{ + &command{ + flagSet: flag.NewFlagSet("version", flag.ContinueOnError), + handler: func(args []string) error { return nil }, + }, + } + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + cmdr.run(flagSet, "src", "root usage", []string{arg}) + return + } + + tests := []struct { + name string + arg string + expectedExit int + }{ + {"invalid root command", "beans", 2}, + {"invalid root with help", "foobar", 2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + cmd := exec.Command(os.Args[0], "-test.run=^TestCommander_Run_InvalidSubcommand$") + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "TEST_TEMP_DIR="+tempDir, "TEST_ARG="+tt.arg) + var stderr bytes.Buffer + cmd.Stderr = &stderr + err := cmd.Run() + + if err == nil { + t.Fatalf("expected exit code %d, got success", tt.expectedExit) + } + + if e, ok := err.(*exec.ExitError); ok { + if e.ExitCode() != tt.expectedExit { + t.Errorf("expected exit code %d, got %d\nstderr: %s", tt.expectedExit, e.ExitCode(), stderr.String()) + } + } else { + t.Errorf("unexpected error type: %v", err) + } + }) + } +} + +func TestCommander_Run_MissingRequiredArgs(t *testing.T) { + if os.Getenv("TEST_SUBPROCESS") == "1" { + testHomeDir = os.Getenv("TEST_TEMP_DIR") + cmdr := commander{ + &command{ + flagSet: flag.NewFlagSet("version", flag.ContinueOnError), + handler: func(args []string) error { return nil }, + }, + } + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + cmdr.run(flagSet, "src", "root usage", []string{}) + return + } + + tempDir := t.TempDir() + cmd := exec.Command(os.Args[0], "-test.run=^TestCommander_Run_MissingRequiredArgs$") + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "TEST_TEMP_DIR="+tempDir) + var stderr bytes.Buffer + cmd.Stderr = &stderr + err := cmd.Run() + + if err == nil { + t.Fatal("expected exit code 2, got success") + } + + if e, ok := err.(*exec.ExitError); ok { + if e.ExitCode() != 2 { + t.Errorf("expected exit code 2, got %d\nstderr: %s", e.ExitCode(), stderr.String()) + } + } else { + t.Errorf("unexpected error type: %v", err) + } +}