diff --git a/allowedsymbols/symbols_interp.go b/allowedsymbols/symbols_interp.go index 5960185b..810e9bdf 100644 --- a/allowedsymbols/symbols_interp.go +++ b/allowedsymbols/symbols_interp.go @@ -24,6 +24,7 @@ var interpAllowedSymbols = []string{ "context.WithTimeout", // 🟢 derives a context with a deadline; needed for execution timeout support. "context.WithValue", // 🟢 derives a context carrying a key-value pair; pure function. "errors.As", // 🟢 error type assertion; pure function, no I/O. + "errors.New", // 🟢 creates a sentinel error value; pure function, no I/O. "fmt.Errorf", // 🟢 formatted error creation; pure function, no I/O. "fmt.Fprintf", // 🟠 formatted write to an io.Writer; delegates to Write, no filesystem access. "fmt.Fprintln", // 🟠 writes to an io.Writer with newline; delegates to Write, no filesystem access. diff --git a/interp/api.go b/interp/api.go index ce817918..94e10489 100644 --- a/interp/api.go +++ b/interp/api.go @@ -421,6 +421,14 @@ func (r *Runner) Reset() { r.didReset = true } +// ErrOutputLimitExceeded is returned by Run when a script produces more stdout +// than maxStdoutBytes. Partial output up to the limit is still delivered to the +// caller's writer. Use errors.Is to check for this condition. +var ErrOutputLimitExceeded = errors.New(fmt.Sprintf( + "stdout limit exceeded: script produced more than %d MiB of output", + maxStdoutBytes/(1024*1024), +)) + // ExitStatus is a non-zero status code resulting from running a shell node. type ExitStatus uint8 @@ -458,6 +466,15 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) (retErr error) { return r.exit.err } } + // Wrap stdout with a cap. Bytes beyond maxStdoutBytes are silently + // discarded so that builtins never see a write error mid-execution, but + // the exceeded flag is set so Run() can surface a well-defined error to + // the caller after the script finishes. Restore r.stdout on return so + // that repeated Run() calls without Reset() do not double-wrap the writer. + prevStdout := r.stdout + stdoutCap := &limitWriter{w: prevStdout, limit: maxStdoutBytes} + r.stdout = stdoutCap + defer func() { r.stdout = prevStdout }() r.startTime = time.Now() r.globReadDirCount = &atomic.Int64{} r.fillExpandConfig(ctx) @@ -478,10 +495,15 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) (retErr error) { default: return fmt.Errorf("node can only be File, Stmt, or Command: %T", node) } - // Return the first of: a fatal error, a non-fatal handler error, or the exit code. + // Return the first of: a fatal/handler error, stdout cap exceeded, or the exit code. + // Fatal errors take precedence over ErrOutputLimitExceeded so that cancellation + // and handler failures are not masked when the cap is also hit. if err := r.exit.err; err != nil { return err } + if stdoutCap.isExceeded() { + return ErrOutputLimitExceeded + } if code := r.exit.code; code != 0 { return ExitStatus(code) } diff --git a/interp/runner_expand.go b/interp/runner_expand.go index 2307a72d..42252e19 100644 --- a/interp/runner_expand.go +++ b/interp/runner_expand.go @@ -14,6 +14,7 @@ import ( "io/fs" "os" "strings" + "sync" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/syntax" @@ -49,6 +50,12 @@ func (r *Runner) updateExpandOpts() { // commands that produce unbounded output. const maxCmdSubstOutput = 1 << 20 // 1 MiB +// maxStdoutBytes is the maximum number of bytes a script can write to stdout +// before further output is silently discarded. This caps total script output +// to prevent memory exhaustion from runaway commands (e.g. infinite loops +// writing to stdout). +const maxStdoutBytes = 10 * 1024 * 1024 // 10 MiB + // MaxGlobReadDirCalls is the maximum number of ReadDirForGlob invocations // allowed per Run() call. This prevents memory exhaustion from scripts that // trigger an excessive number of glob expansions (e.g. millions of unquoted @@ -130,14 +137,27 @@ func catShortcutArg(stmt *syntax.Stmt) *syntax.Word { } // limitWriter wraps a writer and stops writing after limit bytes. +// When the limit is exceeded, exceeded is set to true and further writes +// are silently discarded so that callers do not see spurious short-write +// errors mid-execution. The exceeded flag can be checked after execution +// via isExceeded to surface the event as an error. +// +// limitWriter is safe for concurrent use: the mutex serialises writes so +// that the byte counter and exceeded flag are always consistent, even when +// background goroutines write to the same writer concurrently. type limitWriter struct { - w io.Writer - limit int64 - n int64 + mu sync.Mutex + w io.Writer + limit int64 + n int64 + exceeded bool } func (lw *limitWriter) Write(p []byte) (int, error) { + lw.mu.Lock() + defer lw.mu.Unlock() if lw.n >= lw.limit { + lw.exceeded = true return len(p), nil // silently discard excess } remaining := lw.limit - lw.n @@ -146,6 +166,7 @@ func (lw *limitWriter) Write(p []byte) (int, error) { return int(remaining), err } lw.n = lw.limit + lw.exceeded = true return len(p), nil // report all bytes consumed to avoid short-write errors } n, err := lw.w.Write(p) @@ -153,6 +174,14 @@ func (lw *limitWriter) Write(p []byte) (int, error) { return n, err } +// isExceeded reports whether any write has exceeded the byte limit. +// It is safe to call concurrently with Write. +func (lw *limitWriter) isExceeded() bool { + lw.mu.Lock() + defer lw.mu.Unlock() + return lw.exceeded +} + func (r *Runner) expandErr(err error) { if err == nil { return diff --git a/interp/tests/cmdsubst_hardening_test.go b/interp/tests/cmdsubst_hardening_test.go index 8392dd83..6108727b 100644 --- a/interp/tests/cmdsubst_hardening_test.go +++ b/interp/tests/cmdsubst_hardening_test.go @@ -6,7 +6,9 @@ package tests_test import ( + "bytes" "context" + "errors" "fmt" "os" "path/filepath" @@ -16,10 +18,139 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "mvdan.cc/sh/v3/syntax" + + "github.com/DataDog/rshell/internal/interpoption" + "github.com/DataDog/rshell/interp" ) // --- Memory limits: output capping --- +// TestGlobalStdoutCapReturnsError verifies that Run returns ErrOutputLimitExceeded +// when a script exceeds the 10 MiB stdout cap. The script runs to completion +// but partial output (up to the limit) is still delivered, and the caller +// receives a well-defined error rather than a silent truncation. +func TestGlobalStdoutCapReturnsError(t *testing.T) { + dir := t.TempDir() + + // Create a file of exactly 1 MiB. + content := strings.Repeat("A", 1<<20) + require.NoError(t, os.WriteFile(filepath.Join(dir, "mb.txt"), []byte(content), 0644)) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // cat the file 11 times — produces 11 MiB, exceeding the 10 MiB cap. + script := `for i in 1 2 3 4 5 6 7 8 9 10 11; do cat mb.txt; done` + var outBuf bytes.Buffer + runner, err := interp.New( + interp.StdIO(nil, &outBuf, nil), + interp.AllowedPaths([]string{dir}), + interpoption.AllowAllCommands().(interp.RunnerOption), + ) + require.NoError(t, err) + defer runner.Close() + runner.Dir = dir + + prog, err := syntax.NewParser().Parse(strings.NewReader(script), "test") + require.NoError(t, err) + + runErr := runner.Run(ctx, prog) + assert.ErrorIs(t, runErr, interp.ErrOutputLimitExceeded, + "Run must return ErrOutputLimitExceeded when stdout cap is exceeded") + // Output up to the cap must still be delivered. + assert.LessOrEqual(t, outBuf.Len(), 10*1024*1024, + "stdout must not exceed 10 MiB; got %d bytes", outBuf.Len()) + assert.Greater(t, outBuf.Len(), 0, "expected non-empty stdout before cap") +} + +// TestGlobalStdoutCapMultipleRuns verifies that repeated Run() calls on the +// same Runner without Reset() do not double-wrap the stdout writer. The first +// call must not leave r.stdout pointing at the limitWriter, so the second call +// starts with a fresh 10 MiB budget rather than inheriting the first call's +// byte counter. +func TestGlobalStdoutCapMultipleRuns(t *testing.T) { + dir := t.TempDir() + content := strings.Repeat("A", 1<<20) // 1 MiB + require.NoError(t, os.WriteFile(filepath.Join(dir, "mb.txt"), []byte(content), 0644)) + + var outBuf bytes.Buffer + runner, err := interp.New( + interp.StdIO(nil, &outBuf, nil), + interp.AllowedPaths([]string{dir}), + interpoption.AllowAllCommands().(interp.RunnerOption), + ) + require.NoError(t, err) + defer runner.Close() + runner.Dir = dir + + parse := func(script string) *syntax.File { + t.Helper() + prog, parseErr := syntax.NewParser().Parse(strings.NewReader(script), "test") + require.NoError(t, parseErr) + return prog + } + + ctx := context.Background() + + // First call: write 9 MiB — just under the cap. Must succeed. + outBuf.Reset() + ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second) + defer cancel1() + err = runner.Run(ctx1, parse(`for i in 1 2 3 4 5 6 7 8 9; do cat mb.txt; done`)) + assert.NoError(t, err, "first run (9 MiB) must not exceed cap") + assert.Equal(t, 9<<20, outBuf.Len(), "first run must deliver exactly 9 MiB") + + // Second call: write another 9 MiB. If r.stdout was not restored, the + // wrapped limitWriter from call 1 already has 9 MiB counted and would + // silently drop all output here — returning no error. A fresh budget means + // this call also succeeds with 9 MiB delivered. + outBuf.Reset() + ctx2, cancel2 := context.WithTimeout(ctx, 10*time.Second) + defer cancel2() + err = runner.Run(ctx2, parse(`for i in 1 2 3 4 5 6 7 8 9; do cat mb.txt; done`)) + assert.NoError(t, err, "second run (9 MiB) must not exceed cap") + assert.Equal(t, 9<<20, outBuf.Len(), "second run must deliver exactly 9 MiB (fresh cap)") +} + +// TestGlobalStdoutCapPrecedenceOverExitCode verifies that ErrOutputLimitExceeded +// takes precedence over a non-zero exit code when both occur in the same Run() +// call. Fatal handler errors (r.exit.err) still take precedence over the cap +// per the ordering in Run(), but that path is not easily triggerable from a +// script-level test. +func TestGlobalStdoutCapPrecedenceOverExitCode(t *testing.T) { + dir := t.TempDir() + content := strings.Repeat("A", 1<<20) // 1 MiB + require.NoError(t, os.WriteFile(filepath.Join(dir, "mb.txt"), []byte(content), 0644)) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var outBuf bytes.Buffer + runner, err := interp.New( + interp.StdIO(nil, &outBuf, nil), + interp.AllowedPaths([]string{dir}), + interpoption.AllowAllCommands().(interp.RunnerOption), + ) + require.NoError(t, err) + defer runner.Close() + runner.Dir = dir + + // Exceed the cap then exit non-zero. ErrOutputLimitExceeded must be returned, + // not ExitStatus(1). + prog, parseErr := syntax.NewParser().Parse(strings.NewReader( + `for i in 1 2 3 4 5 6 7 8 9 10 11; do cat mb.txt; done; exit 1`, + ), "test") + require.NoError(t, parseErr) + + runErr := runner.Run(ctx, prog) + assert.ErrorIs(t, runErr, interp.ErrOutputLimitExceeded, + "ErrOutputLimitExceeded must take precedence over a non-zero exit code") + var es interp.ExitStatus + assert.False(t, errors.As(runErr, &es), + "must not return ExitStatus when stdout cap was exceeded") +} + func TestCmdSubstOutputCapped(t *testing.T) { // Generate output exceeding 1 MiB inside command substitution. // The output should be truncated, not cause OOM.