Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions allowedsymbols/symbols_interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var interpAllowedSymbols = []string{
"context.WithTimeout", // 🟢 derives a context with a deadline; needed for execution timeout support.
"context.WithValue", // 🟢 derives a context carrying a key-value pair; pure function.
"errors.As", // 🟢 error type assertion; pure function, no I/O.
"errors.New", // 🟢 creates a sentinel error value; pure function, no I/O.
"fmt.Errorf", // 🟢 formatted error creation; pure function, no I/O.
"fmt.Fprintf", // 🟠 formatted write to an io.Writer; delegates to Write, no filesystem access.
"fmt.Fprintln", // 🟠 writes to an io.Writer with newline; delegates to Write, no filesystem access.
Expand Down
24 changes: 23 additions & 1 deletion interp/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,14 @@ func (r *Runner) Reset() {
r.didReset = true
}

// ErrOutputLimitExceeded is the sentinel error Run returns when a script
// writes more than maxStdoutBytes to stdout. The caller's writer still
// receives the output produced up to the limit. Match it with errors.Is.
//
// The message reports the limit in whole MiB, derived from maxStdoutBytes so
// the text cannot drift from the constant.
var ErrOutputLimitExceeded = errors.New(
	fmt.Sprintf("stdout limit exceeded: script produced more than %d MiB of output", maxStdoutBytes/(1<<20)),
)

// ExitStatus is a non-zero status code resulting from running a shell node.
type ExitStatus uint8

Expand Down Expand Up @@ -458,6 +466,15 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) (retErr error) {
return r.exit.err
}
}
// Wrap stdout with a cap. Bytes beyond maxStdoutBytes are silently
// discarded so that builtins never see a write error mid-execution, but
// the exceeded flag is set so Run() can surface a well-defined error to
// the caller after the script finishes. Restore r.stdout on return so
// that repeated Run() calls without Reset() do not double-wrap the writer.
prevStdout := r.stdout
stdoutCap := &limitWriter{w: prevStdout, limit: maxStdoutBytes}
r.stdout = stdoutCap
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 — r.stdout is never restored after Run() returns.

After this line, r.stdout points to stdoutCap. The next Run() call (without Reset()) will do:

stdoutCap2 := &limitWriter{w: stdoutCap  /* still from call 1 */, limit: maxStdoutBytes}
r.stdout = stdoutCap2

The outer limitWriter has a fresh n=0, so the per-call check works. But all writes flow through stdoutCap (from call 1) whose n still reflects call 1's bytes. If call 1 wrote 9 MiB, call 2 can only write 1 MiB before silent discard — and stdoutCap2.exceeded stays false, so Run() returns nil with truncated output.

Reset() between calls does restore r.origStdout correctly (line 399). But Run() is documented to be callable multiple times incrementally without Reset().

Fix:

prevStdout := r.stdout
stdoutCap := &limitWriter{w: prevStdout, limit: maxStdoutBytes}
r.stdout = stdoutCap
defer func() { r.stdout = prevStdout }()

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 9f6b8e2. Added prevStdout := r.stdout before wrapping and a defer func() { r.stdout = prevStdout }() so the writer is always restored — including through the panic recovery path.

defer func() { r.stdout = prevStdout }()
r.startTime = time.Now()
r.globReadDirCount = &atomic.Int64{}
r.fillExpandConfig(ctx)
Expand All @@ -478,10 +495,15 @@ func (r *Runner) Run(ctx context.Context, node syntax.Node) (retErr error) {
default:
return fmt.Errorf("node can only be File, Stmt, or Command: %T", node)
}
// Return the first of: a fatal error, a non-fatal handler error, or the exit code.
// Return the first of: a fatal/handler error, stdout cap exceeded, or the exit code.
// Fatal errors take precedence over ErrOutputLimitExceeded so that cancellation
// and handler failures are not masked when the cap is also hit.
if err := r.exit.err; err != nil {
return err
}
if stdoutCap.isExceeded() {
return ErrOutputLimitExceeded
}
if code := r.exit.code; code != 0 {
return ExitStatus(code)
}
Expand Down
35 changes: 32 additions & 3 deletions interp/runner_expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"io/fs"
"os"
"strings"
"sync"

"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
Expand Down Expand Up @@ -49,6 +50,12 @@ func (r *Runner) updateExpandOpts() {
// commands that produce unbounded output.
const maxCmdSubstOutput = 1 << 20 // 1 MiB

// maxStdoutBytes caps the number of bytes a script may write to stdout.
// Anything written past the cap is silently discarded. The cap bounds total
// script output so that runaway commands (for example an infinite loop
// printing to stdout) cannot exhaust memory.
const maxStdoutBytes = 10 << 20 // 10 MiB

// MaxGlobReadDirCalls is the maximum number of ReadDirForGlob invocations
// allowed per Run() call. This prevents memory exhaustion from scripts that
// trigger an excessive number of glob expansions (e.g. millions of unquoted
Expand Down Expand Up @@ -130,14 +137,27 @@ func catShortcutArg(stmt *syntax.Stmt) *syntax.Word {
}

// limitWriter wraps a writer and stops writing after limit bytes.
// When the limit is exceeded, exceeded is set to true and further writes
// are silently discarded so that callers do not see spurious short-write
// errors mid-execution. The exceeded flag can be checked after execution
// via isExceeded to surface the event as an error.
//
// limitWriter is safe for concurrent use: the mutex serialises writes so
// that the byte counter and exceeded flag are always consistent, even when
// background goroutines write to the same writer concurrently.
type limitWriter struct {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's add a comment stating that this is not a thread-safe object (it is passed by pointer today, and its thread safety depends on caller discipline around overwriting the stdout field).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Claude Sonnet 4.6] Fixed in 2151ff4 — made limitWriter concurrency-safe by adding a sync.Mutex that is held for the full duration of each Write call (protecting both n and exceeded). Also added an isExceeded() method so the caller in api.go reads the flag under the same lock. The doc comment now describes the guarantee rather than a caveat.

w io.Writer
limit int64
n int64
mu sync.Mutex
w io.Writer
limit int64
n int64
exceeded bool
}

func (lw *limitWriter) Write(p []byte) (int, error) {
lw.mu.Lock()
defer lw.mu.Unlock()
if lw.n >= lw.limit {
lw.exceeded = true
return len(p), nil // silently discard excess
}
remaining := lw.limit - lw.n
Expand All @@ -146,13 +166,22 @@ func (lw *limitWriter) Write(p []byte) (int, error) {
return int(remaining), err
}
lw.n = lw.limit
lw.exceeded = true
return len(p), nil // report all bytes consumed to avoid short-write errors
}
n, err := lw.w.Write(p)
lw.n += int64(n)
return n, err
}

// isExceeded returns true once at least one write has gone past the byte
// limit. The flag is read while holding the writer's mutex, so it may be
// called while other goroutines are in Write.
func (lw *limitWriter) isExceeded() bool {
	lw.mu.Lock()
	overLimit := lw.exceeded
	lw.mu.Unlock()
	return overLimit
}

func (r *Runner) expandErr(err error) {
if err == nil {
return
Expand Down
131 changes: 131 additions & 0 deletions interp/tests/cmdsubst_hardening_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package tests_test

import (
"bytes"
"context"
"errors"
"fmt"
"os"
"path/filepath"
Expand All @@ -16,10 +18,139 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"mvdan.cc/sh/v3/syntax"

"github.com/DataDog/rshell/internal/interpoption"
"github.com/DataDog/rshell/interp"
)

// --- Memory limits: output capping ---

// TestGlobalStdoutCapReturnsError checks the single-run behaviour of the
// 10 MiB stdout cap: a script that produces more than the cap must make Run
// return ErrOutputLimitExceeded, while the output written before the cap was
// hit is still delivered to the caller's writer (no silent truncation).
func TestGlobalStdoutCapReturnsError(t *testing.T) {
	workDir := t.TempDir()

	// Fixture: one file of exactly 1 MiB that the script cats repeatedly.
	oneMiB := bytes.Repeat([]byte("A"), 1<<20)
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "mb.txt"), oneMiB, 0644))

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Eleven cats of a 1 MiB file yield 11 MiB, one MiB over the cap.
	const elevenMiBScript = `for i in 1 2 3 4 5 6 7 8 9 10 11; do cat mb.txt; done`

	var stdout bytes.Buffer
	runner, err := interp.New(
		interp.StdIO(nil, &stdout, nil),
		interp.AllowedPaths([]string{workDir}),
		interpoption.AllowAllCommands().(interp.RunnerOption),
	)
	require.NoError(t, err)
	defer runner.Close()
	runner.Dir = workDir

	prog, err := syntax.NewParser().Parse(strings.NewReader(elevenMiBScript), "test")
	require.NoError(t, err)

	runErr := runner.Run(ctx, prog)
	assert.ErrorIs(t, runErr, interp.ErrOutputLimitExceeded,
		"Run must return ErrOutputLimitExceeded when stdout cap is exceeded")

	// Whatever fit under the cap must have reached the buffer, and nothing more.
	delivered := stdout.Len()
	assert.LessOrEqual(t, delivered, 10*1024*1024,
		"stdout must not exceed 10 MiB; got %d bytes", delivered)
	assert.Greater(t, delivered, 0, "expected non-empty stdout before cap")
}
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 — Missing: test for multiple Run() calls without Reset() to verify caps do not stack.

The current test only exercises a single Run() call. A second test that runs two scripts on the same runner (first writes 9 MiB, second writes 5 MiB) would catch the stdout-stacking bug and confirm the second call gets the full 10 MiB budget.

Also missing: a test for a script that both exits non-zero and exceeds the cap, to document which error takes priority.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 9f6b8e2. Added TestGlobalStdoutCapMultipleRuns (covers the double-wrap regression — two 9 MiB runs on the same runner, each must succeed with a fresh cap) and TestGlobalStdoutCapPrecedenceOverExitCode (verifies ErrOutputLimitExceeded takes precedence over a non-zero exit code).


// TestGlobalStdoutCapMultipleRuns guards against the stdout cap stacking
// across Run() calls made on one Runner without an intervening Reset(). Each
// Run() must wrap the caller's original writer, not the limitWriter left over
// from the previous call, so every call gets its own full 10 MiB budget.
//
// If the wrapper from the first call were still in place, it would already
// have counted 9 MiB; the second call would then deliver only the remaining
// 1 MiB, silently drop the rest, and still return no error.
func TestGlobalStdoutCapMultipleRuns(t *testing.T) {
	workDir := t.TempDir()
	oneMiB := bytes.Repeat([]byte("A"), 1<<20) // 1 MiB
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "mb.txt"), oneMiB, 0644))

	var stdout bytes.Buffer
	runner, err := interp.New(
		interp.StdIO(nil, &stdout, nil),
		interp.AllowedPaths([]string{workDir}),
		interpoption.AllowAllCommands().(interp.RunnerOption),
	)
	require.NoError(t, err)
	defer runner.Close()
	runner.Dir = workDir

	// Nine cats of the 1 MiB file: 9 MiB, just under the cap.
	const nineMiBScript = `for i in 1 2 3 4 5 6 7 8 9; do cat mb.txt; done`

	// Both runs execute the same script; only the failure messages differ.
	runs := []struct {
		noErrMsg string
		sizeMsg  string
	}{
		{
			noErrMsg: "first run (9 MiB) must not exceed cap",
			sizeMsg:  "first run must deliver exactly 9 MiB",
		},
		{
			noErrMsg: "second run (9 MiB) must not exceed cap",
			sizeMsg:  "second run must deliver exactly 9 MiB (fresh cap)",
		},
	}

	base := context.Background()
	for _, run := range runs {
		stdout.Reset()

		runCtx, cancel := context.WithTimeout(base, 10*time.Second)
		// Intentionally deferred to function exit rather than cancelled per
		// iteration: the first run's context stays live during the second run.
		defer cancel()

		prog, parseErr := syntax.NewParser().Parse(strings.NewReader(nineMiBScript), "test")
		require.NoError(t, parseErr)

		err = runner.Run(runCtx, prog)
		assert.NoError(t, err, run.noErrMsg)
		assert.Equal(t, 9<<20, stdout.Len(), run.sizeMsg)
	}
}

// TestGlobalStdoutCapPrecedenceOverExitCode pins the error ordering in Run()
// for a script that both overruns the stdout cap and exits non-zero: the
// caller must see ErrOutputLimitExceeded, not an ExitStatus. Fatal and handler
// errors (r.exit.err) rank above the cap in Run(), but that path is not easily
// triggerable from a script-level test and is not covered here.
func TestGlobalStdoutCapPrecedenceOverExitCode(t *testing.T) {
	workDir := t.TempDir()
	oneMiB := bytes.Repeat([]byte("A"), 1<<20) // 1 MiB
	require.NoError(t, os.WriteFile(filepath.Join(workDir, "mb.txt"), oneMiB, 0644))

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	var stdout bytes.Buffer
	runner, err := interp.New(
		interp.StdIO(nil, &stdout, nil),
		interp.AllowedPaths([]string{workDir}),
		interpoption.AllowAllCommands().(interp.RunnerOption),
	)
	require.NoError(t, err)
	defer runner.Close()
	runner.Dir = workDir

	// Write 11 MiB (over the cap), then finish with exit status 1.
	const overCapThenFail = `for i in 1 2 3 4 5 6 7 8 9 10 11; do cat mb.txt; done; exit 1`
	prog, parseErr := syntax.NewParser().Parse(strings.NewReader(overCapThenFail), "test")
	require.NoError(t, parseErr)

	runErr := runner.Run(ctx, prog)
	assert.ErrorIs(t, runErr, interp.ErrOutputLimitExceeded,
		"ErrOutputLimitExceeded must take precedence over a non-zero exit code")

	// The returned error must not be, or wrap, an ExitStatus.
	var status interp.ExitStatus
	gotExitStatus := errors.As(runErr, &status)
	assert.False(t, gotExitStatus,
		"must not return ExitStatus when stdout cap was exceeded")
}

func TestCmdSubstOutputCapped(t *testing.T) {
// Generate output exceeding 1 MiB inside command substitution.
// The output should be truncated, not cause OOM.
Expand Down
Loading