Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module github.com/DataDog/rshell
go 1.25.6

require (
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.9
github.com/stretchr/testify v1.11.1
gopkg.in/yaml.v3 v3.0.1
mvdan.cc/sh/v3 v3.12.0
Expand All @@ -12,6 +14,4 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/cobra v1.10.2 // indirect
github.com/spf13/pflag v1.0.9 // indirect
)
2 changes: 1 addition & 1 deletion interp/builtins/break/break.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// Cmd is the break builtin command descriptor.
var Cmd = builtins.Command{Name: "break", Run: run}
var Cmd = builtins.Command{Name: "break", MakeFlags: builtins.NoFlags(run)}

func run(_ context.Context, callCtx *builtins.CallContext, args []string) builtins.Result {
return loopctl.LoopControl(callCtx, "break", args)
Expand Down
63 changes: 51 additions & 12 deletions interp/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,60 @@ import (
"fmt"
"io"
"os"

"github.com/spf13/pflag"
)

// HandlerFunc is the signature for a builtin command implementation.
// FlagSet is a type alias for pflag.FlagSet. Command files receive a *FlagSet
// from the framework without needing to import pflag directly (the builtins
// package is always allowed by the import allowlist).
type FlagSet = pflag.FlagSet

// HandlerFunc is the bound handler called by the framework after flags are
// parsed. args contains only the positional (non-flag) arguments.
type HandlerFunc func(ctx context.Context, callCtx *CallContext, args []string) Result

// Command pairs a builtin name with its flag-declaring factory. MakeFlags
// registers any flags on the provided FlagSet and returns the bound handler.
// Commands that accept no flags may ignore fs via NoFlags.
type Command struct {
Name string
MakeFlags func(*FlagSet) HandlerFunc
}

// NoFlags wraps a HandlerFunc in the MakeFlags format for commands that
// declare no flags.
func NoFlags(fn HandlerFunc) func(*FlagSet) HandlerFunc {
return func(_ *FlagSet) HandlerFunc { return fn }
}

// Register adds the Command to the builtin registry. For each invocation the
// framework creates a fresh *FlagSet, passes it to MakeFlags so the command
// can register its flags, parses the raw args, writes any error to stderr
// (exit 1), and then calls the bound handler with positional args only.
//
// If MakeFlags registers no flags (e.g. via NoFlags), the framework skips
// parsing entirely and passes all raw args to the handler unchanged. This
// lets commands like echo treat flag-shaped literals (e.g. -n) correctly.
func (c Command) Register() {
name := c.Name
factory := c.MakeFlags
addToRegistry(name, func(ctx context.Context, callCtx *CallContext, args []string) Result {
fs := pflag.NewFlagSet(name, pflag.ContinueOnError)
fs.SetOutput(io.Discard) // handler formats errors itself
handler := factory(fs)
if !fs.HasFlags() {
// No flags declared: pass all args through unchanged.
return handler(ctx, callCtx, args)
}
if err := fs.Parse(args); err != nil {
callCtx.Errf("%s: %v\n", name, err)
return Result{Code: 1}
}
return handler(ctx, callCtx, fs.Args())
})
}

// CallContext provides the capabilities available to builtin commands.
// It is created by the Runner for each builtin invocation.
type CallContext struct {
Expand Down Expand Up @@ -65,18 +114,9 @@ type Result struct {
ContinueN int
}

// Command pairs a builtin name with its handler, used for explicit
// registration in the all package.
type Command struct {
Name string
Run HandlerFunc
}

var registry = map[string]HandlerFunc{}

// Register adds a builtin command to the registry.
// It panics if name is already registered, catching duplicate registrations at startup.
func Register(name string, fn HandlerFunc) {
func addToRegistry(name string, fn HandlerFunc) {
if _, exists := registry[name]; exists {
panic("builtin already registered: " + name)
}
Expand All @@ -88,4 +128,3 @@ func Lookup(name string) (HandlerFunc, bool) {
fn, ok := registry[name]
return fn, ok
}

135 changes: 63 additions & 72 deletions interp/builtins/cat/cat.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,11 @@ import (
"io"
"os"

"github.com/spf13/pflag"

"github.com/DataDog/rshell/interp/builtins"
)

// Cmd is the cat builtin command descriptor.
var Cmd = builtins.Command{Name: "cat", Run: run}
var Cmd = builtins.Command{Name: "cat", MakeFlags: registerFlags}

// MaxLineBytes is the per-line buffer cap for the line scanner. Lines
// longer than this are reported as an error instead of being buffered.
Expand All @@ -90,10 +88,7 @@ const (
lineNumWidth = 6 // GNU cat line-number field width
)

func run(ctx context.Context, callCtx *builtins.CallContext, args []string) builtins.Result {
fs := pflag.NewFlagSet("cat", pflag.ContinueOnError)
fs.SetOutput(io.Discard)

func registerFlags(fs *builtins.FlagSet) builtins.HandlerFunc {
help := fs.BoolP("help", "h", false, "print usage and exit")
number := fs.BoolP("number", "n", false, "number all output lines")
numberNonblank := fs.BoolP("number-nonblank", "b", false, "number non-blank output lines, overrides -n")
Expand All @@ -106,80 +101,76 @@ func run(ctx context.Context, callCtx *builtins.CallContext, args []string) buil
flagT := fs.BoolP("show-nonprinting-tabs", "t", false, "equivalent to -vT")
_ = fs.BoolP("unbuffered", "u", false, "ignored")

if err := fs.Parse(args); err != nil {
callCtx.Errf("cat: %v\n", err)
return builtins.Result{Code: 1}
}

if *help {
callCtx.Out("Usage: cat [OPTION]... [FILE]...\n")
callCtx.Out("Concatenate FILE(s) to standard output.\n")
callCtx.Out("With no FILE, or when FILE is -, read standard input.\n\n")
fs.SetOutput(callCtx.Stdout)
fs.PrintDefaults()
return builtins.Result{}
}

if *showAll {
*showNonprinting = true
*showEnds = true
*showTabs = true
}
if *flagE {
*showNonprinting = true
*showEnds = true
}
if *flagT {
*showNonprinting = true
*showTabs = true
}
if *numberNonblank {
*number = false
}

needsLineProcessing := *number || *numberNonblank || *squeezeBlank ||
*showEnds || *showTabs || *showNonprinting
return func(ctx context.Context, callCtx *builtins.CallContext, files []string) builtins.Result {
if *help {
callCtx.Out("Usage: cat [OPTION]... [FILE]...\n")
callCtx.Out("Concatenate FILE(s) to standard output.\n")
callCtx.Out("With no FILE, or when FILE is -, read standard input.\n\n")
fs.SetOutput(callCtx.Stdout)
fs.PrintDefaults()
return builtins.Result{}
}

files := fs.Args()
if len(files) == 0 {
files = []string{"-"}
}
if *showAll {
*showNonprinting = true
*showEnds = true
*showTabs = true
}
if *flagE {
*showNonprinting = true
*showEnds = true
}
if *flagT {
*showNonprinting = true
*showTabs = true
}
if *numberNonblank {
*number = false
}

st := &state{
number: *number,
numberNonblank: *numberNonblank,
squeezeBlank: *squeezeBlank,
showEnds: *showEnds,
showTabs: *showTabs,
showNonprinting: *showNonprinting,
lineNum: 1,
}
needsLineProcessing := *number || *numberNonblank || *squeezeBlank ||
*showEnds || *showTabs || *showNonprinting

var failed bool
for _, file := range files {
if ctx.Err() != nil {
break
if len(files) == 0 {
files = []string{"-"}
}
var err error
if needsLineProcessing {
err = catLines(ctx, callCtx, file, st)
} else {
err = catRaw(ctx, callCtx, file)

st := &state{
number: *number,
numberNonblank: *numberNonblank,
squeezeBlank: *squeezeBlank,
showEnds: *showEnds,
showTabs: *showTabs,
showNonprinting: *showNonprinting,
lineNum: 1,
}
if err != nil {
name := file
if file == "-" {
name = "standard input"

var failed bool
for _, file := range files {
if ctx.Err() != nil {
break
}
var err error
if needsLineProcessing {
err = catLines(ctx, callCtx, file, st)
} else {
err = catRaw(ctx, callCtx, file)
}
if err != nil {
name := file
if file == "-" {
name = "standard input"
}
callCtx.Errf("cat: %s: %s\n", name, callCtx.PortableErr(err))
failed = true
}
callCtx.Errf("cat: %s: %s\n", name, callCtx.PortableErr(err))
failed = true
}
}

if failed {
return builtins.Result{Code: 1}
if failed {
return builtins.Result{Code: 1}
}
return builtins.Result{}
}
return builtins.Result{}
}

type state struct {
Expand Down
2 changes: 1 addition & 1 deletion interp/builtins/continue/continue.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// Cmd is the continue builtin command descriptor.
var Cmd = builtins.Command{Name: "continue", Run: run}
var Cmd = builtins.Command{Name: "continue", MakeFlags: builtins.NoFlags(run)}

func run(_ context.Context, callCtx *builtins.CallContext, args []string) builtins.Result {
return loopctl.LoopControl(callCtx, "continue", args)
Expand Down
2 changes: 1 addition & 1 deletion interp/builtins/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

// Cmd is the echo builtin command descriptor.
var Cmd = builtins.Command{Name: "echo", Run: run}
var Cmd = builtins.Command{Name: "echo", MakeFlags: builtins.NoFlags(run)}

func run(_ context.Context, callCtx *builtins.CallContext, args []string) builtins.Result {
for i, arg := range args {
Expand Down
2 changes: 1 addition & 1 deletion interp/builtins/exit/exit.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

// Cmd is the exit builtin command descriptor.
var Cmd = builtins.Command{Name: "exit", Run: run}
var Cmd = builtins.Command{Name: "exit", MakeFlags: builtins.NoFlags(run)}

func run(_ context.Context, callCtx *builtins.CallContext, args []string) builtins.Result {
var r builtins.Result
Expand Down
2 changes: 1 addition & 1 deletion interp/builtins/false/false.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

// Cmd is the false builtin command descriptor.
var Cmd = builtins.Command{Name: "false", Run: run}
var Cmd = builtins.Command{Name: "false", MakeFlags: builtins.NoFlags(run)}

func run(_ context.Context, _ *builtins.CallContext, _ []string) builtins.Result {
return builtins.Result{Code: 1}
Expand Down
Loading