Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 75 additions & 44 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"flag"
"fmt"
"io"
"os"
"reflect"
"regexp"
"slices"
Expand All @@ -18,10 +17,7 @@ import (
var Version = "0.0.0-unknown"

var (
tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`)
builtinConfig = &BuiltinConfig{
Help: false,
}
tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`)
)

func New(parent *Command, spec Spec) *Command {
Expand All @@ -33,12 +29,14 @@ func New(parent *Command, spec Spec) *Command {
ShortDescription: spec.ShortDescription,
LongDescription: spec.LongDescription,
Config: spec.Config,
PreSubCommandRun: spec.OnSubCommandRun,
Run: spec.Run,
Parent: parent,
builtinConfig: &BuiltinConfig{Help: false},
parent: parent,
createdByNewCommand: true,
}
if cmd.Parent != nil {
cmd.Parent.subCommands = append(cmd.Parent.subCommands, cmd)
if cmd.parent != nil {
cmd.parent.subCommands = append(cmd.parent.subCommands, cmd)
}
return cmd
}
Expand All @@ -48,17 +46,20 @@ type Spec struct {
ShortDescription string
LongDescription string
Config any
OnSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error
Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error
}

type Command struct {
Name string
ShortDescription string
LongDescription string
Parent *Command
subCommands []*Command
Config any
PreSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error
Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error
builtinConfig any
parent *Command
subCommands []*Command
createdByNewCommand bool
envVarsMapping map[string]reflect.Value
flagSet *flag.FlagSet
Expand Down Expand Up @@ -91,12 +92,12 @@ func (c *Command) initializeFlagSet() error {

// Create a flag set
name := c.Name
for parent := c.Parent; parent != nil; parent = parent.Parent {
for parent := c.parent; parent != nil; parent = parent.parent {
name = parent.Name + " " + name
}
c.flagSet = flag.NewFlagSet(name, flag.ContinueOnError)
c.flagSet.SetOutput(io.Discard)
if err := c.initializeFlagSetFromStruct(reflect.ValueOf(builtinConfig).Elem()); err != nil {
if err := c.initializeFlagSetFromStruct(reflect.ValueOf(c.builtinConfig).Elem()); err != nil {
return fmt.Errorf("failed to process builtin configuration fields: %w", err)
}

Expand Down Expand Up @@ -261,19 +262,17 @@ func (c *Command) applyEnvironmentVariables(envVars map[string]string) error {
return nil
}

func (c *Command) configure(envVars map[string]string, args []string) error {
func (c *Command) applyCLIArguments(args []string) error {

// Apply environment variables first
if err := c.applyEnvironmentVariables(envVars); err != nil {
return fmt.Errorf("failed to apply environment variables: %w", err)
}

// Override with CLI arguments
// Update config with CLI arguments
if err := c.flagSet.Parse(args); err != nil {
return fmt.Errorf("failed to apply CLI arguments: %w", err)
}

// Ensure all required flags have been provided via either CLI or via environment variables
return nil
}

func (c *Command) validateRequiredFlagsWereProvided(envVars map[string]string) error {
var missingRequiredFlags []string
copy(missingRequiredFlags, c.requiredFlags)
c.flagSet.Visit(func(f *flag.Flag) {
Expand All @@ -289,20 +288,40 @@ func (c *Command) configure(envVars map[string]string, args []string) error {
})
}
}
if len(missingRequiredFlags) > 0 {
return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags)
}
return nil
}

func (c *Command) configure(envVars map[string]string, args []string) error {

// Initialize the flagSet for the chosen command
if err := c.initializeFlagSet(); err != nil {
panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", c.Name, err))
}

// Apply environment variables first
if err := c.applyEnvironmentVariables(envVars); err != nil {
return fmt.Errorf("failed to apply environment variables: %w", err)
}

// Override with CLI arguments
if err := c.flagSet.Parse(args); err != nil {
return fmt.Errorf("failed to apply CLI arguments: %w", err)
}

// Apply positional arguments
if c.positionalArgsTarget != nil {
*c.positionalArgsTarget = c.flagSet.Args()
}
if len(missingRequiredFlags) > 0 {
return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags)
}

return nil
}

func (c *Command) printCommandUsage(w io.Writer, short bool) {
cmdChain := c.Name
for cmd := c.Parent; cmd != nil; cmd = cmd.Parent {
for cmd := c.parent; cmd != nil; cmd = cmd.parent {
cmdChain = cmd.Name + " " + cmdChain
}

Expand Down Expand Up @@ -392,10 +411,9 @@ func (c *Command) printCommandUsage(w io.Writer, short bool) {
}
}

//goland:noinspection GoUnusedExportedFunction
func Execute(root *Command, args []string, envVars map[string]string) {
func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode int) {
if !root.createdByNewCommand {
panic("illegal root command was specified - was it created by 'command.New(...)'?")
panic("invalid root command given, indicating it may not have been created by 'command.New(...)'")
}

// Iterate CLI args, separate them to flags & positional args, but also infer the command to execute from the given
Expand All @@ -411,31 +429,44 @@ func Execute(root *Command, args []string, envVars map[string]string) {
// positional args: [something, sub3, a, b, c]: no "cmd1", "sub1" and "sub2" as they are commands in the hierarchy
cmd, flagArgs, positionalArgs := inferCommandFlagsAndPositionals(root, args)

// Initialize the flagSet for the chosen command
if err := cmd.initializeFlagSet(); err != nil {
panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", cmd.Name, err))
// Build the command chain from top-to-bottom (so index 0 is the root)
commandChain := []*Command{cmd}
parent := cmd.parent
for parent != nil {
commandChain = append([]*Command{parent}, commandChain...)
parent = parent.parent
}

// Parse the arguments as returned in the parsing step
if err := cmd.configure(envVars, append(flagArgs, positionalArgs...)); err != nil {
cmd.PrintShortUsage(os.Stderr)
os.Exit(1)
} else if cmd.flagSet.Lookup("help").Value.String() == "true" {
cmd.PrintFullUsage(os.Stderr)
os.Exit(0)
// Configure commands up the chain, in order to invoke their "PreSubCommandRun" function
for _, current := range commandChain {
if err := current.configure(envVars, append(flagArgs, positionalArgs...)); err != nil {
current.PrintShortUsage(w)
return 1
}

if err := current.PreSubCommandRun(ctx, current.Config, current); err != nil {
_, _ = fmt.Fprintln(w, err.Error())
return 1
}
}

// If "--help" was provided, show usage and exit immediately
if cmd.flagSet.Lookup("help").Value.String() == "true" {
cmd.PrintFullUsage(w)
return 0
}

// If command has no "Run" function, it's an intermediate probably - just print its usage and exit successfully
if cmd.Run == nil {
cmd.PrintFullUsage(os.Stderr)
cmd.PrintFullUsage(w)
return 0
}

// Run the command with a fresh context
ctx, cancel := context.WithCancel(SetupSignalHandler())
defer cancel()
// Run the command
if err := cmd.Run(ctx, cmd.Config, cmd); err != nil {
cancel() // os.Exit might not invoke the deferred cancel call
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
_, _ = fmt.Fprintln(w, err.Error())
return 1
}

return 0
}
Loading