-
Notifications
You must be signed in to change notification settings - Fork 1
Add execution timeout support #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
59bf1f1
5fc1974
3162c4f
e09b577
ee69332
4a4c118
f5a45fe
c73d6f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,22 +14,26 @@ import ( | |||||||||||||||||
| "io" | ||||||||||||||||||
| "os" | ||||||||||||||||||
| "strings" | ||||||||||||||||||
| "time" | ||||||||||||||||||
|
|
||||||||||||||||||
| "github.com/DataDog/rshell/interp" | ||||||||||||||||||
| "github.com/spf13/cobra" | ||||||||||||||||||
| "mvdan.cc/sh/v3/syntax" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| const exitCodeTimeout = 124 | ||||||||||||||||||
|
|
||||||||||||||||||
| func main() { | ||||||||||||||||||
| os.Exit(run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr)) | ||||||||||||||||||
| os.Exit(run(context.Background(), os.Args[1:], os.Stdin, os.Stdout, os.Stderr)) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { | ||||||||||||||||||
| func run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) int { | ||||||||||||||||||
| var ( | ||||||||||||||||||
| command string | ||||||||||||||||||
| allowedPaths string | ||||||||||||||||||
| allowedCommands string | ||||||||||||||||||
| allowAllCmds bool | ||||||||||||||||||
| timeout time.Duration | ||||||||||||||||||
| procPath string | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -49,6 +53,17 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { | |||||||||||||||||
| return fmt.Errorf("cannot use -c with file arguments") | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| if timeout < 0 { | ||||||||||||||||||
| return fmt.Errorf("--timeout must be >= 0") | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| runCtx := cmd.Context() | ||||||||||||||||||
| if timeout > 0 { | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd suggest combining these
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this suggestion above doesn't account for timeout == 0 case: |
||||||||||||||||||
| var cancel context.CancelFunc | ||||||||||||||||||
| runCtx, cancel = context.WithTimeout(runCtx, timeout) | ||||||||||||||||||
| defer cancel() | ||||||||||||||||||
|
Comment on lines
+61
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The timeout implementation only derives a child context with Useful? React with 👍 / 👎. |
||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| var paths []string | ||||||||||||||||||
| if allowedPaths != "" { | ||||||||||||||||||
| paths = strings.Split(allowedPaths, ",") | ||||||||||||||||||
|
|
@@ -67,35 +82,40 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { | |||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| if commandSet { | ||||||||||||||||||
| return execute(cmd.Context(), command, "", execOpts, stdin, stdout, stderr) | ||||||||||||||||||
| return execute(runCtx, command, "", execOpts, stdin, stdout, stderr) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| if len(args) > 0 { | ||||||||||||||||||
| // Read stdin once so each execute() call gets its own | ||||||||||||||||||
| // reader, avoiding a data race on the shared io.Reader. | ||||||||||||||||||
| stdinData, err := io.ReadAll(stdin) | ||||||||||||||||||
| stdinData, err := readAllContext(runCtx, stdin) | ||||||||||||||||||
| if err != nil { | ||||||||||||||||||
| return fmt.Errorf("reading stdin: %w", err) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| for _, file := range args { | ||||||||||||||||||
| data, err := os.ReadFile(file) | ||||||||||||||||||
| f, err := os.Open(file) | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These direct
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note that this is cmd/rshell/main.go building ./rshell CLI , only for dev/testing |
||||||||||||||||||
| if err != nil { | ||||||||||||||||||
| return fmt.Errorf("reading %s: %w", file, err) | ||||||||||||||||||
| } | ||||||||||||||||||
| data, err := readAllContext(runCtx, f) | ||||||||||||||||||
| f.Close() | ||||||||||||||||||
| if err != nil { | ||||||||||||||||||
| return fmt.Errorf("reading %s: %w", file, err) | ||||||||||||||||||
| } | ||||||||||||||||||
| if err := execute(cmd.Context(), string(data), file, execOpts, bytes.NewReader(stdinData), stdout, stderr); err != nil { | ||||||||||||||||||
| if err := execute(runCtx, string(data), file, execOpts, bytes.NewReader(stdinData), stdout, stderr); err != nil { | ||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When multiple script files are passed and the timeout fires between file N and N+1, the next file is still read from disk before the already-expired context is checked (inside Consider an explicit check before each read:
Suggested change
|
||||||||||||||||||
| return err | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| return nil | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // No -c and no file args: read from stdin. | ||||||||||||||||||
| stdinData, err := io.ReadAll(stdin) | ||||||||||||||||||
| stdinData, err := readAllContext(runCtx, stdin) | ||||||||||||||||||
| if err != nil { | ||||||||||||||||||
| return fmt.Errorf("reading stdin: %w", err) | ||||||||||||||||||
| } | ||||||||||||||||||
| return execute(cmd.Context(), string(stdinData), "", execOpts, strings.NewReader(""), stdout, stderr) | ||||||||||||||||||
| return execute(runCtx, string(stdinData), "", execOpts, strings.NewReader(""), stdout, stderr) | ||||||||||||||||||
| }, | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -109,19 +129,55 @@ func run(args []string, stdin io.Reader, stdout, stderr io.Writer) int { | |||||||||||||||||
| cmd.Flags().StringVarP(&allowedPaths, "allowed-paths", "p", "", "comma-separated list of directories the shell is allowed to access") | ||||||||||||||||||
| cmd.Flags().StringVar(&allowedCommands, "allowed-commands", "", "comma-separated list of namespaced commands (e.g. rshell:cat,rshell:find)") | ||||||||||||||||||
| cmd.Flags().BoolVar(&allowAllCmds, "allow-all-commands", false, "allow execution of all commands (builtins and external)") | ||||||||||||||||||
| cmd.Flags().DurationVar(&timeout, "timeout", 0, "maximum execution time for the entire shell run (e.g. 100ms, 5s, 1m)") | ||||||||||||||||||
| cmd.Flags().StringVar(&procPath, "proc-path", "", "path to the proc filesystem used by ps (default \"/proc\")") | ||||||||||||||||||
|
|
||||||||||||||||||
| if err := cmd.Execute(); err != nil { | ||||||||||||||||||
| if err := cmd.ExecuteContext(ctx); err != nil { | ||||||||||||||||||
| var status interp.ExitStatus | ||||||||||||||||||
| if errors.As(err, &status) { | ||||||||||||||||||
| return int(status) | ||||||||||||||||||
| } | ||||||||||||||||||
| if errors.Is(err, context.DeadlineExceeded) { | ||||||||||||||||||
|
AlexandreYang marked this conversation as resolved.
|
||||||||||||||||||
| if timeout > 0 { | ||||||||||||||||||
| fmt.Fprintf(stderr, "error: execution timed out after %s\n", timeout) | ||||||||||||||||||
| } else { | ||||||||||||||||||
| fmt.Fprintln(stderr, "error: execution timed out") | ||||||||||||||||||
| } | ||||||||||||||||||
| return exitCodeTimeout | ||||||||||||||||||
| } | ||||||||||||||||||
| if errors.Is(err, context.Canceled) { | ||||||||||||||||||
| fmt.Fprintln(stderr, "error: execution canceled") | ||||||||||||||||||
| return exitCodeTimeout | ||||||||||||||||||
| } | ||||||||||||||||||
| fmt.Fprintf(stderr, "error: %v\n", err) | ||||||||||||||||||
| return 1 | ||||||||||||||||||
| } | ||||||||||||||||||
| return 0 | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // readAllContext reads all bytes from r, but returns ctx.Err() immediately if | ||||||||||||||||||
| // the context is cancelled or its deadline expires before the read completes. | ||||||||||||||||||
| // It spawns a goroutine to perform the read; the goroutine may outlive this | ||||||||||||||||||
| // call if the underlying reader blocks (e.g. stdin from a pipe), but it will | ||||||||||||||||||
| // be reclaimed when the process exits. | ||||||||||||||||||
| func readAllContext(ctx context.Context, r io.Reader) ([]byte, error) { | ||||||||||||||||||
| type result struct { | ||||||||||||||||||
| data []byte | ||||||||||||||||||
| err error | ||||||||||||||||||
| } | ||||||||||||||||||
| ch := make(chan result, 1) | ||||||||||||||||||
| go func() { | ||||||||||||||||||
| data, err := io.ReadAll(r) | ||||||||||||||||||
| ch <- result{data, err} | ||||||||||||||||||
| }() | ||||||||||||||||||
| select { | ||||||||||||||||||
| case <-ctx.Done(): | ||||||||||||||||||
| return nil, ctx.Err() | ||||||||||||||||||
| case res := <-ch: | ||||||||||||||||||
| return res.data, res.err | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // rejectLongCommand scans raw CLI args for "--command" or "--command=..." and | ||||||||||||||||||
| // returns an error if found. The flag is registered with a long name so that | ||||||||||||||||||
| // cobra/pflag help formatting works correctly, but only the -c shorthand is | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ package main | |
|
|
||
| import ( | ||
| "bytes" | ||
| "context" | ||
| "io" | ||
| "os" | ||
| "path/filepath" | ||
| "runtime" | ||
|
|
@@ -18,9 +20,14 @@ import ( | |
| ) | ||
|
|
||
| func runCLI(t *testing.T, args ...string) (exitCode int, stdout, stderr string) { | ||
| t.Helper() | ||
| return runCLIContext(t, context.Background(), args...) | ||
| } | ||
|
|
||
| func runCLIContext(t *testing.T, ctx context.Context, args ...string) (exitCode int, stdout, stderr string) { | ||
| t.Helper() | ||
| var out, errBuf bytes.Buffer | ||
| code := run(args, strings.NewReader(""), &out, &errBuf) | ||
| code := run(ctx, args, strings.NewReader(""), &out, &errBuf) | ||
|
AlexandreYang marked this conversation as resolved.
|
||
| return code, out.String(), errBuf.String() | ||
| } | ||
|
|
||
|
|
@@ -39,7 +46,7 @@ func TestShortFlag(t *testing.T) { | |
| func runCLIWithStdin(t *testing.T, stdin string, args ...string) (exitCode int, stdout, stderr string) { | ||
| t.Helper() | ||
| var out, errBuf bytes.Buffer | ||
| code := run(args, strings.NewReader(stdin), &out, &errBuf) | ||
| code := run(context.Background(), args, strings.NewReader(stdin), &out, &errBuf) | ||
| return code, out.String(), errBuf.String() | ||
| } | ||
|
|
||
|
|
@@ -141,6 +148,7 @@ func TestHelp(t *testing.T) { | |
| assert.Contains(t, stdout, "--allowed-paths") | ||
| assert.Contains(t, stdout, "--allowed-commands") | ||
| assert.Contains(t, stdout, "--allow-all-commands") | ||
| assert.Contains(t, stdout, "--timeout") | ||
| assert.NotContains(t, stdout, "--command", "-c/--command should be hidden from help") | ||
| } | ||
|
|
||
|
|
@@ -243,6 +251,31 @@ func TestCommandLongFormRejected(t *testing.T) { | |
| assert.Contains(t, stderr, "unknown flag: --command") | ||
| } | ||
|
|
||
| func TestTimeoutFlagTimesOutExecution(t *testing.T) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is a known limitation: the CLI cannot inject a slow execHandler, and |
||
| // Feed a blocking stdin with no -c flag so the timeout fires while readAllContext | ||
| // is waiting for EOF. 50ms is well above Windows' ~15ms timer resolution. | ||
| pr, pw := io.Pipe() | ||
| defer pw.Close() | ||
| var out, errBuf bytes.Buffer | ||
| code := run(context.Background(), []string{"--timeout", "50ms"}, pr, &out, &errBuf) | ||
| assert.Equal(t, exitCodeTimeout, code) | ||
| assert.Contains(t, errBuf.String(), "execution timed out") | ||
| } | ||
|
|
||
| func TestCanceledContextExitsWithTimeoutCode(t *testing.T) { | ||
| ctx, cancel := context.WithCancel(context.Background()) | ||
| cancel() // cancel before execution starts | ||
| code, _, stderr := runCLIContext(t, ctx, "--allow-all-commands", "-c", `echo hello`) | ||
| assert.Equal(t, exitCodeTimeout, code) | ||
| assert.Contains(t, stderr, "execution canceled") | ||
| } | ||
|
|
||
| func TestTimeoutFlagRejectsNegative(t *testing.T) { | ||
| code, _, stderr := runCLI(t, "--timeout", "-1s", "-c", `echo hello`) | ||
| assert.Equal(t, 1, code) | ||
| assert.Contains(t, stderr, "--timeout must be >= 0") | ||
| } | ||
|
|
||
| func TestProcPathFlagInHelp(t *testing.T) { | ||
| code, stdout, _ := runCLI(t, "--help") | ||
| assert.Equal(t, 0, code) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.