diff --git a/README.md b/README.md index 5a68ac5d..47b688fe 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ import ( "context" "os" "strings" + "time" "github.com/DataDog/rshell/interp" "mvdan.cc/sh/v3/syntax" @@ -33,6 +34,7 @@ func main() { runner, _ := interp.New( interp.StdIO(nil, os.Stdout, os.Stderr), interp.AllowedCommands([]string{"rshell:echo"}), + interp.MaxExecutionTime(5*time.Second), ) defer runner.Close() @@ -40,6 +42,12 @@ func main() { } ``` +CLI usage also supports a whole-run timeout: + +```bash +rshell --allow-all-commands --timeout 5s -c 'echo "hello from rshell"' +``` + ## Security Model Every access path is default-deny: diff --git a/SHELL_FEATURES.md b/SHELL_FEATURES.md index 07492524..d3540988 100644 --- a/SHELL_FEATURES.md +++ b/SHELL_FEATURES.md @@ -98,6 +98,7 @@ Blocked features are rejected before execution with exit code 2. - ✅ AllowedCommands — restricts which commands (builtins or external) may be executed; commands require the `rshell:` namespace prefix (e.g. `rshell:cat`); if not set, no commands are allowed - ✅ AllowAllCommands — permits any command (testing convenience) - ✅ AllowedPaths filesystem sandboxing — restricts all file access to specified directories +- ✅ Whole-run execution timeout — callers can bound a `Run()` call via `context.Context`, `interp.MaxExecutionTime`, or the CLI `--timeout` flag; the deadline applies to the entire script, not each individual command - ✅ ProcPath — overrides the proc filesystem path used by `ps` (default `/proc`; Linux-only; useful for testing/container environments) - ❌ External commands — blocked by default; requires an ExecHandler to be configured and the binary to be within AllowedPaths - ❌ Background execution: `cmd &` diff --git a/allowedsymbols/symbols_interp.go b/allowedsymbols/symbols_interp.go index d86d235c..b688a9a0 100644 --- a/allowedsymbols/symbols_interp.go +++ b/allowedsymbols/symbols_interp.go @@ -18,7 +18,9 @@ package allowedsymbols // The permanently banned packages (reflect, unsafe) apply here too. var interpAllowedSymbols = []string{ "bytes.Buffer", // 🟢 in-memory byte buffer; pure data structure, no I/O. + "context.CancelFunc", // 🟢 function type returned by WithTimeout/WithCancel; pure function type, no side effects. "context.Context", // 🟢 deadline/cancellation plumbing; pure interface, no side effects. + "context.WithTimeout", // 🟢 derives a context with a deadline; needed for execution timeout support. "context.WithValue", // 🟢 derives a context carrying a key-value pair; pure function. "errors.As", // 🟢 error type assertion; pure function, no I/O. "fmt.Errorf", // 🟢 formatted error creation; pure function, no I/O. @@ -58,6 +60,7 @@ var interpAllowedSymbols = []string{ "sync.Mutex", // 🟢 mutual exclusion lock; concurrency primitive, no I/O. "sync.Once", // 🟢 ensures a function runs exactly once; concurrency primitive, no I/O. "sync.WaitGroup", // 🟢 waits for goroutines to finish; concurrency primitive, no I/O. + "time.Duration", // 🟢 numeric duration type; pure type, no side effects. "time.Now", // 🟠 returns current time; read-only, no mutation. "time.Time", // 🟢 time value type; pure data, no side effects. diff --git a/cmd/rshell/main.go b/cmd/rshell/main.go index 838bc36d..3e20f1d1 100644 --- a/cmd/rshell/main.go +++ b/cmd/rshell/main.go @@ -14,22 +14,26 @@ import ( "io" "os" "strings" + "time" "github.com/DataDog/rshell/interp" "github.com/spf13/cobra" "mvdan.cc/sh/v3/syntax" ) +const exitCodeTimeout = 124 + func main() { - os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr)) + os.Exit(run(context.Background(), os.Args[1:], os.Stdin, os.Stdout, os.Stderr)) } -func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { +func run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) int { var ( command string allowedPaths string allowedCommands string allowAllCmds bool + timeout time.Duration procPath string ) @@ -49,6 +53,17 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { return fmt.Errorf("cannot use -c with file arguments") } + if timeout < 0 { + return fmt.Errorf("--timeout must be >= 0") + } + + runCtx := cmd.Context() + if timeout > 0 { + var cancel context.CancelFunc + runCtx, cancel = context.WithTimeout(runCtx, timeout) + defer cancel() + } + var paths []string if allowedPaths != "" { paths = strings.Split(allowedPaths, ",") @@ -67,23 +82,28 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { } if commandSet { - return execute(cmd.Context(), command, "", execOpts, stdin, stdout, stderr) + return execute(runCtx, command, "", execOpts, stdin, stdout, stderr) } if len(args) > 0 { // Read stdin once so each execute() call gets its own // reader, avoiding a data race on the shared io.Reader. - stdinData, err := io.ReadAll(stdin) + stdinData, err := readAllContext(runCtx, stdin) if err != nil { return fmt.Errorf("reading stdin: %w", err) } for _, file := range args { - data, err := os.ReadFile(file) + f, err := os.Open(file) + if err != nil { + return fmt.Errorf("reading %s: %w", file, err) + } + data, err := readAllContext(runCtx, f) + f.Close() if err != nil { return fmt.Errorf("reading %s: %w", file, err) } - if err := execute(cmd.Context(), string(data), file, execOpts, bytes.NewReader(stdinData), stdout, stderr); err != nil { + if err := execute(runCtx, string(data), file, execOpts, bytes.NewReader(stdinData), stdout, stderr); err != nil { return err } } @@ -91,11 +111,11 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { } // No -c and no file args: read from stdin. - stdinData, err := io.ReadAll(stdin) + stdinData, err := readAllContext(runCtx, stdin) if err != nil { return fmt.Errorf("reading stdin: %w", err) } - return execute(cmd.Context(), string(stdinData), "", execOpts, strings.NewReader(""), stdout, stderr) + return execute(runCtx, string(stdinData), "", execOpts, strings.NewReader(""), stdout, stderr) }, } @@ -109,19 +129,55 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { cmd.Flags().StringVarP(&allowedPaths, "allowed-paths", "p", "", "comma-separated list of directories the shell is allowed to access") cmd.Flags().StringVar(&allowedCommands, "allowed-commands", "", "comma-separated list of namespaced commands (e.g. rshell:cat,rshell:find)") cmd.Flags().BoolVar(&allowAllCmds, "allow-all-commands", false, "allow execution of all commands (builtins and external)") + cmd.Flags().DurationVar(&timeout, "timeout", 0, "maximum execution time for the entire shell run (e.g. 100ms, 5s, 1m)") cmd.Flags().StringVar(&procPath, "proc-path", "", "path to the proc filesystem used by ps (default \"/proc\")") - if err := cmd.Execute(); err != nil { + if err := cmd.ExecuteContext(ctx); err != nil { var status interp.ExitStatus if errors.As(err, &status) { return int(status) } + if errors.Is(err, context.DeadlineExceeded) { + if timeout > 0 { + fmt.Fprintf(stderr, "error: execution timed out after %s\n", timeout) + } else { + fmt.Fprintln(stderr, "error: execution timed out") + } + return exitCodeTimeout + } + if errors.Is(err, context.Canceled) { + fmt.Fprintln(stderr, "error: execution canceled") + return exitCodeTimeout + } fmt.Fprintf(stderr, "error: %v\n", err) return 1 } return 0 } +// readAllContext reads all bytes from r, but returns ctx.Err() immediately if +// the context is cancelled or its deadline expires before the read completes. +// It spawns a goroutine to perform the read; the goroutine may outlive this +// call if the underlying reader blocks (e.g. stdin from a pipe), but it will +// be reclaimed when the process exits. +func readAllContext(ctx context.Context, r io.Reader) ([]byte, error) { + type result struct { + data []byte + err error + } + ch := make(chan result, 1) + go func() { + data, err := io.ReadAll(r) + ch <- result{data, err} + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case res := <-ch: + return res.data, res.err + } +} + // rejectLongCommand scans raw CLI args for "--command" or "--command=..." and // returns an error if found. The flag is registered with a long name so that // cobra/pflag help formatting works correctly, but only the -c shorthand is diff --git a/cmd/rshell/main_test.go b/cmd/rshell/main_test.go index b4575503..fe24bf29 100644 --- a/cmd/rshell/main_test.go +++ b/cmd/rshell/main_test.go @@ -7,6 +7,8 @@ package main import ( "bytes" + "context" + "io" "os" "path/filepath" "runtime" @@ -18,9 +20,14 @@ import ( ) func runCLI(t *testing.T, args ...string) (exitCode int, stdout, stderr string) { + t.Helper() + return runCLIContext(t, context.Background(), args...) +} + +func runCLIContext(t *testing.T, ctx context.Context, args ...string) (exitCode int, stdout, stderr string) { t.Helper() var out, errBuf bytes.Buffer - code := run(args, strings.NewReader(""), &out, &errBuf) + code := run(ctx, args, strings.NewReader(""), &out, &errBuf) return code, out.String(), errBuf.String() } @@ -39,7 +46,7 @@ func TestShortFlag(t *testing.T) { func runCLIWithStdin(t *testing.T, stdin string, args ...string) (exitCode int, stdout, stderr string) { t.Helper() var out, errBuf bytes.Buffer - code := run(args, strings.NewReader(stdin), &out, &errBuf) + code := run(context.Background(), args, strings.NewReader(stdin), &out, &errBuf) return code, out.String(), errBuf.String() } @@ -141,6 +148,7 @@ func TestHelp(t *testing.T) { assert.Contains(t, stdout, "--allowed-paths") assert.Contains(t, stdout, "--allowed-commands") assert.Contains(t, stdout, "--allow-all-commands") + assert.Contains(t, stdout, "--timeout") assert.NotContains(t, stdout, "--command", "-c/--command should be hidden from help") } @@ -243,6 +251,31 @@ func TestCommandLongFormRejected(t *testing.T) { assert.Contains(t, stderr, "unknown flag: --command") } +func TestTimeoutFlagTimesOutExecution(t *testing.T) { + // Feed a blocking stdin with no -c flag so the timeout fires while readAllContext + // is waiting for EOF. 50ms is well above Windows' ~15ms timer resolution. + pr, pw := io.Pipe() + defer pw.Close() + var out, errBuf bytes.Buffer + code := run(context.Background(), []string{"--timeout", "50ms"}, pr, &out, &errBuf) + assert.Equal(t, exitCodeTimeout, code) + assert.Contains(t, errBuf.String(), "execution timed out") +} + +func TestCanceledContextExitsWithTimeoutCode(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel before execution starts + code, _, stderr := runCLIContext(t, ctx, "--allow-all-commands", "-c", `echo hello`) + assert.Equal(t, exitCodeTimeout, code) + assert.Contains(t, stderr, "execution canceled") +} + +func TestTimeoutFlagRejectsNegative(t *testing.T) { + code, _, stderr := runCLI(t, "--timeout", "-1s", "-c", `echo hello`) + assert.Equal(t, 1, code) + assert.Contains(t, stderr, "--timeout must be >= 0") +} + func TestProcPathFlagInHelp(t *testing.T) { code, stdout, _ := runCLI(t, "--help") assert.Equal(t, 0, code) diff --git a/interp/api.go b/interp/api.go index bba2242a..a7c0a546 100644 --- a/interp/api.go +++ b/interp/api.go @@ -59,6 +59,10 @@ type runnerConfig struct { // command. Intended for testing convenience. allowAllCommands bool + // maxExecutionTime bounds the duration of each Run call. Zero disables + // the limit. When non-zero, Run derives a child context with this timeout. + maxExecutionTime time.Duration + // procPath is the path to the proc filesystem used by the ps builtin. // Defaults to "/proc" when empty. procPath string @@ -292,6 +296,22 @@ func StdIO(in io.Reader, out, err io.Writer) RunnerOption { } } +// MaxExecutionTime bounds the total execution time of each [Runner.Run] call. +// +// When d is zero, no timeout is applied. Negative values are rejected. +// +// The timeout is applied per Run call rather than when the Runner is created, +// so reusing a Runner across multiple runs yields a fresh deadline each time. +func MaxExecutionTime(d time.Duration) RunnerOption { + return func(r *Runner) error { + if d < 0 { + return fmt.Errorf("MaxExecutionTime: duration must be >= 0") + } + r.maxExecutionTime = d + return nil + } +} + // Reset returns a runner to its initial state, right before the first call to // Run or Reset. // @@ -377,6 +397,11 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) (retErr error) { retErr = fmt.Errorf("internal error") } }() + if r.maxExecutionTime > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.maxExecutionTime) + defer cancel() + } if !r.didReset { r.Reset() if r.exit.fatalExit { diff --git a/interp/timeout_test.go b/interp/timeout_test.go new file mode 100644 index 00000000..4d49f689 --- /dev/null +++ b/interp/timeout_test.go @@ -0,0 +1,110 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2026-present Datadog, Inc. + +package interp + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTimeoutRunner(t *testing.T, opts ...RunnerOption) *Runner { + t.Helper() + allOpts := append([]RunnerOption{AllowAllCommands()}, opts...) + r, err := New(allOpts...) + require.NoError(t, err) + t.Cleanup(func() { _ = r.Close() }) + r.Reset() + return r +} + +func TestMaxExecutionTimeRejectsNegative(t *testing.T) { + _, err := New(MaxExecutionTime(-time.Second)) + require.Error(t, err) + assert.Contains(t, err.Error(), "MaxExecutionTime") +} + +func TestMaxExecutionTimeStopsRun(t *testing.T) { + r := newTimeoutRunner(t, MaxExecutionTime(20*time.Millisecond)) + r.execHandler = func(ctx context.Context, _ []string) error { + <-ctx.Done() + return ctx.Err() + } + + err := r.Run(context.Background(), parseScript(t, "slowcmd")) + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestMaxExecutionTimeRespectsEarlierParentDeadline(t *testing.T) { + r := newTimeoutRunner(t, MaxExecutionTime(time.Second)) + var got time.Time + r.execHandler = func(ctx context.Context, _ []string) error { + var ok bool + got, ok = ctx.Deadline() + require.True(t, ok, "expected deadline on exec handler context") + return nil + } + + parent, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + parentDeadline, ok := parent.Deadline() + require.True(t, ok) + + err := r.Run(parent, parseScript(t, "slowcmd")) + require.NoError(t, err) + // context.WithTimeout takes the earlier of the two deadlines, so the runner's 1s + // MaxExecutionTime must not override the parent's tighter 25ms deadline. + assert.WithinDuration(t, parentDeadline, got, 5*time.Millisecond) +} + +func TestMaxExecutionTimeStopsForLoop(t *testing.T) { + // Exercises the interpreter's own ctx.Err() check inside the for-loop body + // (runner_exec.go), not just the execHandler cooperative-cancellation path. + // while/until loops are not supported, so we use a for loop with an + // execHandler that sleeps per iteration to make the loop outlast the timeout. + r := newTimeoutRunner(t, MaxExecutionTime(50*time.Millisecond)) + r.execHandler = func(ctx context.Context, _ []string) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(20 * time.Millisecond): + return nil + } + } + + // 10 iterations × 20ms each = 200ms total, well beyond the 50ms timeout. + err := r.Run(context.Background(), parseScript(t, "for x in 1 2 3 4 5 6 7 8 9 10; do cmd; done")) + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestMaxExecutionTimeIsRefreshedPerRun(t *testing.T) { + r := newTimeoutRunner(t, MaxExecutionTime(100*time.Millisecond)) + var deadlines []time.Time + r.execHandler = func(ctx context.Context, _ []string) error { + deadline, ok := ctx.Deadline() + require.True(t, ok, "expected deadline on exec handler context") + deadlines = append(deadlines, deadline) + return nil + } + + prog := parseScript(t, "slowcmd") + + err := r.Run(context.Background(), prog) + require.NoError(t, err) + + time.Sleep(20 * time.Millisecond) + + err = r.Run(context.Background(), prog) + require.NoError(t, err) + + require.Len(t, deadlines, 2) + assert.True(t, deadlines[1].After(deadlines[0]), "expected a fresh deadline on each Run") +}