From 1dd9ddcf8475829e7f20e006b6040d41b04bbaa0 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Tue, 8 Feb 2022 18:37:21 -0500 Subject: [PATCH 1/2] IO relay bugs Updating cmd to wait on all relayIO operations, and using CloseStdin to signal closure of the process. `hcs.Process` would attempt to send `ModifyProcessRequest` to close stdin even after process exits. Was also inconsistent about returning error on already closed IO handles. Changed `gcs` Process to mirror `hcs.Process` and `jobcontainer.Process`. `.CloseStd*` now return errors if the process was already closed, `.Close()` zeros out the `gc` connection, and `.CloseStdin` both performs a `CloseWrite` and closes the channel. Signed-off-by: Hamza El-Saawy --- internal/cmd/cmd.go | 145 +++++++++++++++++++++++++++++++-------- internal/cmd/cmd_test.go | 129 +++++++++++++++++++++++++++++++++- internal/cmd/diag.go | 6 ++ internal/cmd/io.go | 30 +++++++- internal/gcs/process.go | 137 +++++++++++++++++++++++++++++------- internal/hcs/process.go | 88 +++++++++++++++--------- 6 files changed, 448 insertions(+), 87 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index d7228619eb..741855fb86 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -5,19 +5,23 @@ package cmd import ( "bytes" "context" + "errors" "fmt" "io" "strings" "sync/atomic" "time" - "github.com/Microsoft/hcsshim/internal/cow" - hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" - "github.com/Microsoft/hcsshim/internal/log" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/windows" + + "github.com/Microsoft/hcsshim/internal/cow" + "github.com/Microsoft/hcsshim/internal/hcs" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" ) // CmdProcessRequest stores information on command requests made through this package. 
@@ -62,9 +66,13 @@ type Cmd struct { // ExitState is filled out after Wait() (or Run() or Output()) completes. ExitState *ExitState + // afterExitFuns are run after the process exits, but before wait returns and + // IO is waited on + afterExitFuns []func(context.Context) error + iogrp errgroup.Group stdinErr atomic.Value - allDoneCh chan struct{} + allDoneCh chan struct{} // closed after Wait finishes } // ExitState contains whether a process has exited and with which exit code. @@ -136,6 +144,11 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg // Start starts a command. The caller must ensure that if Start succeeds, // Wait is eventually called to clean up resources. func (c *Cmd) Start() error { + ctx := context.Background() + if c.Context != nil { + ctx = c.Context + } + c.allDoneCh = make(chan struct{}) var x interface{} if !c.Host.IsOCI() { @@ -184,10 +197,10 @@ func (c *Cmd) Start() error { } x = lpp } - if c.Context != nil && c.Context.Err() != nil { - return c.Context.Err() + if err := ctx.Err(); err != nil { + return err } - p, err := c.Host.CreateProcess(context.TODO(), x) + p, err := c.Host.CreateProcess(ctx, x) if err != nil { return err } @@ -200,18 +213,29 @@ func (c *Cmd) Start() error { stdin, stdout, stderr := p.Stdio() if c.Stdin != nil { // Do not make stdin part of the error group because there is no way for - // us or the caller to reliably unblock the c.Stdin read when the - // process exits. + // us to reliably unblock the c.Stdin read when the process exits. + // Even if `stdin` is closed, the runtime can block indefinitely on reading + // c.Stdin, so the only reliable way to unblock this is with: + // c.Stdin.CloseWrite() (if it implements it) or c.Stdin.Close(). + // However, we are only passed on the Reader end of Stdin, and closing the + // upstream c.Stdin breaks with the functionality that os.exec.Cmd implements. 
go func() { - _, err := relayIO(stdin, c.Stdin, c.Log, "stdin") + _, err := relayIO(stdin, c.Stdin, c.Log, "stdin", isIOChannelClosedErr) // Report the stdin copy error. If the process has exited, then the // caller may never see it, but if the error was due to a failure in // stdin read, then it is likely the process is still running. if err != nil { - c.stdinErr.Store(err) + select { + case <-c.allDoneCh: + // Wait has returned, err will be ignored. + // relayIO will log the error, so no need for logging here + default: + c.stdinErr.Store(err) + } } - // Notify the process that there is no more input. - if err := p.CloseStdin(context.TODO()); err != nil && c.Log != nil { + // Notify the process that there is no more input, in the case that + // c.Stdin closed while the process is running + if err := p.CloseStdin(ctx); err != nil && !isIOChannelClosedErr(err) && c.Log != nil { c.Log.WithError(err).Warn("failed to close Cmd stdin") } }() @@ -219,9 +243,11 @@ func (c *Cmd) Start() error { if c.Stdout != nil { c.iogrp.Go(func() error { - _, err := relayIO(c.Stdout, stdout, c.Log, "stdout") - if err := p.CloseStdout(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stdout") + _, err := relayIO(c.Stdout, stdout, c.Log, "stdout", nil /*skipErr*/) + // Notify the process that upstream IO closed its std out, if the process + // is still running + if cerr := p.CloseStdout(ctx); cerr != nil && !isIOChannelClosedErr(cerr) && c.Log != nil { + c.Log.WithError(cerr).Warn("failed to close Cmd stdout") } return err }) @@ -229,26 +255,29 @@ func (c *Cmd) Start() error { if c.Stderr != nil { c.iogrp.Go(func() error { - _, err := relayIO(c.Stderr, stderr, c.Log, "stderr") - if err := p.CloseStderr(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stderr") + _, err := relayIO(c.Stderr, stderr, c.Log, "stderr", nil /*skipErr*/) + // Notify the process that upstream IO closed its std err, if the process + // is still running + 
if cerr := p.CloseStderr(ctx); cerr != nil && !isIOChannelClosedErr(cerr) && c.Log != nil { + c.Log.WithError(cerr).Warn("failed to close Cmd stderr") } return err }) } + // if ctx is `Background()`, then don't bother launching this, since ctx will + // never be cancelled if c.Context != nil { go func() { select { case <-c.Context.Done(): // Process.Kill (via Process.Signal) will not send an RPC if the // provided context in is cancelled (bridge.AsyncRPC will end early) - ctx := c.Context - if ctx == nil { - ctx = context.Background() + ctx := context.Background() + if c.Context != nil { + ctx = log.Copy(context.Background(), c.Context) } - kctx := log.Copy(context.Background(), ctx) - _, _ = c.Process.Kill(kctx) + _, _ = c.Process.Kill(ctx) case <-c.allDoneCh: } }() @@ -260,6 +289,11 @@ func (c *Cmd) Start() error { // process. It can only be called once. It returns an ExitError if the command // runs and returns a non-zero exit code. func (c *Cmd) Wait() error { + ctx := context.Background() + if c.Context != nil { + ctx = c.Context + } + waitErr := c.Process.Wait() if waitErr != nil && c.Log != nil { c.Log.WithError(waitErr).Warn("process wait failed") @@ -270,8 +304,14 @@ func (c *Cmd) Wait() error { state.exited = true state.code = code } + + err := c.afterExit(ctx) + if err != nil { + log.G(ctx).WithError(err).Warn("error when running after exit functions") + } + // Terminate the IO if the copy does not complete in the requested time. - if c.CopyAfterExitTimeout != 0 { + if c.CopyAfterExitTimeout > 0 { go func() { t := time.NewTimer(c.CopyAfterExitTimeout) defer t.Stop() @@ -281,17 +321,42 @@ func (c *Cmd) Wait() error { // Close the process to cancel any reads to stdout or stderr. c.Process.Close() if c.Log != nil { - c.Log.Warn("timed out waiting for stdio relay") + c.Log. + WithField(logfields.Timeout, c.CopyAfterExitTimeout.String()). 
+ Warn("timed out waiting for stdio relay") } } }() } ioErr := c.iogrp.Wait() - if ioErr == nil { - ioErr, _ = c.stdinErr.Load().(error) + if inErr, _ := c.stdinErr.Load().(error); inErr != nil { + if ioErr == nil { + ioErr = inErr + } else { + // cannot wrap two errors at once, so one will be wrapped via `%v` + ioErr = fmt.Errorf("multiple IO copy errors: %v; %w", inErr, ioErr) + } } + if err == nil { + err = ioErr + } else { + // wrap (prioritize) IO errors over `afterExit` errors + err = fmt.Errorf("io error: %w; other error: %v", ioErr, err) + } + // close the channel first, to prevent `Process.Kill` being called mid-`Process.Close` close(c.allDoneCh) - c.Process.Close() + // Process could have been closed by IO timeout handling + if cerr := c.Process.Close(); cerr != nil && !errors.Is(cerr, hcs.ErrAlreadyClosed) { + if c.Log != nil { + c.Log.WithError(cerr).Warn("error closing the process") + } + if err == nil { + err = cerr + } else { + err = fmt.Errorf("error closing process: %w; other error: %v", cerr, err) + } + } + c.ExitState = state if exitErr != nil { return exitErr @@ -299,7 +364,7 @@ func (c *Cmd) Wait() error { if state.exited && state.code != 0 { return &ExitError{state} } - return ioErr + return err } // Run is equivalent to Start followed by Wait. 
@@ -319,3 +384,23 @@ func (c *Cmd) Output() ([]byte, error) { err := c.Run() return b.Bytes(), err } + +func (c *Cmd) afterExit(ctx context.Context) (err error) { + for _, f := range c.afterExitFuns { + if ferr := f(ctx); err != nil { + err = ferr + log.G(ctx).WithError(err).Warn("error running function after process exit") + } + } + return err +} + +// RegisterAfterExitFun registers a function to be run after the process exits, but +// before the IO copy operations are waited on and the process is closed +func (c *Cmd) RegisterAfterExitFun(f func(context.Context) error) { + // TODO: find a better way to check if the process is started + if c.allDoneCh != nil { + return + } + c.afterExitFuns = append(c.afterExitFuns, f) +} diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index c87fd07793..fe16e816c2 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -22,6 +22,8 @@ import ( type localProcessHost struct { } +var _ cow.ProcessHost = &localProcessHost{} + type localProcess struct { p *os.Process state *os.ProcessState @@ -29,6 +31,8 @@ type localProcess struct { stdin, stdout, stderr *os.File } +var _ cow.Process = &localProcess{} + func (h *localProcessHost) OS() string { return "windows" } @@ -171,15 +175,17 @@ func TestCmdOutput(t *testing.T) { func TestCmdContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() + cmd := CommandContext(ctx, &localProcessHost{}, "cmd", "/c", "pause") r, w := io.Pipe() cmd.Stdin = r + cmd.RegisterAfterExitFun(func(_ context.Context) error { return w.Close() }) + err := cmd.Start() if err != nil { t.Fatal(err) } - _ = cmd.Process.Wait() - w.Close() + err = cmd.Wait() if e, ok := err.(*ExitError); !ok || e.ExitCode() != 1 || ctx.Err() == nil { t.Fatal(err) @@ -213,6 +219,72 @@ func TestCmdStdinBlocked(t *testing.T) { } } +func TestCmdAfterExitFun(t *testing.T) { + cmd := Command(&localProcessHost{}, "cmd", "/c") + + c := make(chan struct{}) 
+ cmd.RegisterAfterExitFun(func(_ context.Context) error { + close(c) + time.Sleep(50 * time.Millisecond) + return nil + }) + + err := cmd.Start() + if err != nil { + t.Fatal(err) + } + + // call cmd.Wait to make sure after funs are called + done := make(chan error) + go func() { + done <- cmd.Wait() + close(done) + }() + + err = cmd.Process.Wait() + if err != nil { + t.Fatal(err) + } + + select { + case <-c: + // if they both finish at the same time, it is undefined which case is chosen... + case <-cmd.allDoneCh: + t.Fatalf("after exit did not finish before cmd.Wait returned") + } + + // still check for errors during cmd.Wait(): + err = <-done + if err != nil { + t.Fatalf("cmd failed: %v", err) + } +} + +func TestCmdAfterExitFunRegistration(t *testing.T) { + cmd := Command(&localProcessHost{}, "cmd", "/c", "echo", "hello") + + l := len(cmd.afterExitFuns) + cmd.RegisterAfterExitFun(func(_ context.Context) error { + return nil + }) + if len(cmd.afterExitFuns) != l+1 { + t.Fatalf("function registration failed") + } + + err := cmd.Start() + if err != nil { + t.Fatalf("cmd Run failed: %v", err) + } + + cmd.RegisterAfterExitFun(func(_ context.Context) error { + return errors.New("this error should never be raised") + }) + + if len(cmd.afterExitFuns) != l+1 { + t.Fatalf("function should not have been registered") + } +} + type stuckIoProcessHost struct { cow.ProcessHost } @@ -256,3 +328,56 @@ func TestCmdStuckIo(t *testing.T) { t.Fatal(err) } } + +// check that io Copy will wait indefintely if pipes are not closed +func TestCmdStuckStdoutNotClosed(t *testing.T) { + cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c") + r, w := io.Pipe() + defer r.Close() + cmd.Stdout = w + + done := make(chan error) + go func() { + done <- cmd.Run() + close(done) + }() + + tr := time.NewTimer(250 * time.Millisecond) // give the cmd a chance to finish running + defer tr.Stop() + select { + case err := <-done: + if err != nil { + t.Fatalf("cmd run failed: %v", err) + } + 
t.Fatal("command should have blocked indefinitely") + case <-tr.C: + } +} + +func TestCmdStuckStdoutClosed(t *testing.T) { + cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c") + r, w := io.Pipe() + defer r.Close() + cmd.Stdout = w + cmd.RegisterAfterExitFun(func(ctx context.Context) error { + p := cmd.Process.(*stuckIoProcess) + return p.stdout.Close() + }) + + done := make(chan error) + go func() { + done <- cmd.Run() + close(done) + }() + + tr := time.NewTimer(250 * time.Millisecond) + defer tr.Stop() + select { + case err := <-done: + if err != io.ErrClosedPipe { + t.Fatalf("cmd run failed: %v", err) + } + case <-tr.C: + t.Fatal("command did not exit") + } +} diff --git a/internal/cmd/diag.go b/internal/cmd/diag.go index e397bb85ee..71391becdd 100644 --- a/internal/cmd/diag.go +++ b/internal/cmd/diag.go @@ -33,6 +33,12 @@ func ExecInUvm(ctx context.Context, vm *uvm.UtilityVM, req *CmdProcessRequest) ( cmd.Stdin = np.Stdin() cmd.Stdout = np.Stdout() cmd.Stderr = np.Stderr() + if cmd.Stdin != nil { + cmd.RegisterAfterExitFun(func(ctx context.Context) error { + np.CloseStdin(ctx) + return nil + }) + } cmd.Log = log.G(ctx).WithField(logfields.UVMID, vm.ID()) err = cmd.Run() return cmd.ExitState.ExitCode(), err diff --git a/internal/cmd/io.go b/internal/cmd/io.go index 75ddd1f355..08aed757b9 100644 --- a/internal/cmd/io.go +++ b/internal/cmd/io.go @@ -5,11 +5,16 @@ package cmd import ( "context" "io" + "net" "net/url" + "os" "time" + "github.com/Microsoft/go-winio" "github.com/pkg/errors" "github.com/sirupsen/logrus" + + "github.com/Microsoft/hcsshim/internal/hcs" ) // UpstreamIO is an interface describing the IO to connect to above the shim. 
@@ -63,8 +68,28 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin return NewBinaryIO(ctx, id, u) } +// check if the error is from the file or channel already being closed +func isIOChannelClosedErr(err error) bool { + for _, e := range []error{ + os.ErrClosed, + net.ErrClosed, + io.ErrClosedPipe, + winio.ErrFileClosed, + hcs.ErrAlreadyClosed, + } { + if errors.Is(err, e) { + return true + } + } + return false +} + // relayIO is a glorified io.Copy that also logs when the copy has completed. -func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, error) { +// +// Use skipErr to ignore and not log errors that it returns true for: passing in +// `isIOChannelClosedErr` will ignore errors raised by reading from or writing to +// a closed Reader or Writer, respectively. It can be `nil` +func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string, skipErr func(error) bool) (int64, error) { n, err := io.Copy(w, r) if log != nil { lvl := logrus.DebugLevel @@ -72,6 +97,9 @@ func relayIO(w io.Writer, r io.Reader, log *logrus.Entry, name string) (int64, e "file": name, "bytes": n, }) + if err != nil && skipErr != nil && skipErr(err) { + err = nil + } if err != nil { lvl = logrus.ErrorLevel log = log.WithError(err) diff --git a/internal/gcs/process.go b/internal/gcs/process.go index fab6af75c7..c0d5ae1934 100644 --- a/internal/gcs/process.go +++ b/internal/gcs/process.go @@ -11,12 +11,14 @@ import ( "sync" "github.com/Microsoft/go-winio" + "github.com/sirupsen/logrus" + "go.opencensus.io/trace" + "github.com/Microsoft/hcsshim/internal/cow" + "github.com/Microsoft/hcsshim/internal/hcs" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" "github.com/Microsoft/hcsshim/internal/oc" - "github.com/sirupsen/logrus" - "go.opencensus.io/trace" ) const ( @@ -25,14 +27,16 @@ const ( // Process represents a process in a container or container host. 
type Process struct { - gc *GuestConnection - cid string - id uint32 - waitCall *rpc - waitResp containerWaitForProcessResponse + gcLock sync.RWMutex + gc *GuestConnection + + cid string + id uint32 + waitCall *rpc + waitResp containerWaitForProcessResponse + + stdioLock sync.Mutex stdin, stdout, stderr *ioChannel - stdinCloseWriteOnce sync.Once - stdinCloseWriteErr error } var _ cow.Process = &Process{} @@ -123,26 +127,56 @@ func (gc *GuestConnection) exec(ctx context.Context, cid string, params interfac // Close releases resources associated with the process and closes the // associated standard IO streams. -func (p *Process) Close() error { +func (p *Process) Close() (err error) { ctx, span := oc.StartSpan(context.Background(), "gcs::Process::Close") defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - if err := p.stdin.Close(); err != nil { - log.G(ctx).WithError(err).Warn("close stdin failed") + p.gcLock.RLock() + if p.gc == nil { + p.gcLock.RUnlock() + return hcs.ErrAlreadyClosed } - if err := p.stdout.Close(); err != nil { - log.G(ctx).WithError(err).Warn("close stdout failed") + p.gcLock.RUnlock() + + p.stdioLock.Lock() + if p.stdin != nil { + if err = p.stdin.Close(); err != nil { + log.G(ctx).WithError(err).Warn("close stdin failed") + } + p.stdin = nil } - if err := p.stderr.Close(); err != nil { - log.G(ctx).WithError(err).Warn("close stderr failed") + if p.stdout != nil { + if serr := p.stdout.Close(); serr != nil { + log.G(ctx).WithError(serr).Warn("close stdout failed") + if err == nil { + err = serr + } + } + p.stdout = nil } + if p.stderr != nil { + if serr := p.stderr.Close(); serr != nil { + log.G(ctx).WithError(serr).Warn("close stderr failed") + if err == nil { + err = serr + } + } + p.stderr = nil + } + p.stdioLock.Unlock() + + p.gcLock.Lock() + defer p.gcLock.Unlock() + p.gc = nil + return nil } -// CloseStdin causes the process 
to read EOF on its stdin stream. +// CloseStdin causes the process to read EOF on its stdin stream, and then closes stdin. func (p *Process) CloseStdin(ctx context.Context) (err error) { ctx, span := oc.StartSpan(ctx, "gcs::Process::CloseStdin") //nolint:ineffassign,staticcheck defer span.End() @@ -151,10 +185,27 @@ func (p *Process) CloseStdin(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - p.stdinCloseWriteOnce.Do(func() { - p.stdinCloseWriteErr = p.stdin.CloseWrite() - }) - return p.stdinCloseWriteErr + p.gcLock.RLock() + if p.gc == nil { + p.gcLock.RUnlock() + return hcs.ErrAlreadyClosed + } + p.gcLock.RUnlock() + + p.stdioLock.Lock() + defer p.stdioLock.Unlock() + if p.stdin != nil { + // First close the channel for writing, sending an EOF to readers, then + // close the file. + if err = p.stdin.CloseWrite(); err != nil { + return err + } + if err = p.stdin.Close(); err == nil { + return err + } + p.stdin = nil + } + return err } func (p *Process) CloseStdout(ctx context.Context) (err error) { @@ -165,7 +216,20 @@ func (p *Process) CloseStdout(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - return p.stdout.Close() + p.gcLock.RLock() + if p.gc == nil { + p.gcLock.RUnlock() + return hcs.ErrAlreadyClosed + } + p.gcLock.RUnlock() + + p.stdioLock.Lock() + defer p.stdioLock.Unlock() + if p.stdout != nil { + err = p.stdout.Close() + p.stdout = nil + } + return err } func (p *Process) CloseStderr(ctx context.Context) (err error) { @@ -176,7 +240,20 @@ func (p *Process) CloseStderr(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - return p.stderr.Close() + p.gcLock.RLock() + if p.gc == nil { + p.gcLock.RUnlock() + return hcs.ErrAlreadyClosed + } + p.gcLock.RUnlock() + + p.stdioLock.Lock() + defer p.stdioLock.Unlock() + if p.stderr != nil { + err = p.stderr.Close() + 
p.stderr = nil + } + return err } // ExitCode returns the process's exit code, or an error if the process is still @@ -220,6 +297,12 @@ func (p *Process) ResizeConsole(ctx context.Context, width, height uint16) (err trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) + p.gcLock.RLock() + defer p.gcLock.RUnlock() + if p.gc == nil { + return hcs.ErrAlreadyClosed + } + req := containerResizeConsole{ requestBase: makeRequest(ctx, p.cid), ProcessID: p.id, @@ -239,6 +322,12 @@ func (p *Process) Signal(ctx context.Context, options interface{}) (_ bool, err trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) + p.gcLock.RLock() + defer p.gcLock.RUnlock() + if p.gc == nil { + return false, hcs.ErrAlreadyClosed + } + req := containerSignalProcess{ requestBase: makeRequest(ctx, p.cid), ProcessID: p.id, @@ -267,6 +356,8 @@ func (p *Process) Signal(ctx context.Context, options interface{}) (_ bool, err // Stdio returns the standard IO streams associated with the container. They // will be closed when Close is called. 
func (p *Process) Stdio() (stdin io.Writer, stdout, stderr io.Reader) { + p.stdioLock.Lock() + defer p.stdioLock.Unlock() return p.stdin, p.stdout, p.stderr } diff --git a/internal/hcs/process.go b/internal/hcs/process.go index 8a3c18437a..e3a10618e5 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -12,10 +12,12 @@ import ( "syscall" "time" + "go.opencensus.io/trace" + + "github.com/Microsoft/hcsshim/internal/cow" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/vmcompute" - "go.opencensus.io/trace" ) // ContainerError is an error encountered in HCS @@ -38,6 +40,8 @@ type Process struct { waitError error } +var _ cow.Process = &Process{} + func newProcess(process vmcompute.HcsProcess, processID int, computeSystem *System) *Process { return &Process{ handle: process, @@ -309,6 +313,16 @@ func (process *Process) ExitCode() (int, error) { } } +// Exited returns if the process has exited +func (process *Process) exited() bool { + select { + case <-process.waitBlock: + return true + default: + return false + } +} + // StdioLegacy returns the stdin, stdout, and stderr pipes, respectively. Closing // these pipes does not close the underlying pipes. Once returned, these pipes // are the responsibility of the caller to close. @@ -352,55 +366,66 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R } // Stdio returns the stdin, stdout, and stderr pipes, respectively. -// To close them, close the process handle. +// To close them, close the process handle, or use the `CloseStd*` functions. func (process *Process) Stdio() (stdin io.Writer, stdout, stderr io.Reader) { - process.stdioLock.Lock() - defer process.stdioLock.Unlock() return process.stdin, process.stdout, process.stderr } // CloseStdin closes the write side of the stdin pipe so that the process is // notified on the read side that there is no more data in stdin. 
-func (process *Process) CloseStdin(ctx context.Context) error { +func (process *Process) CloseStdin(ctx context.Context) (err error) { + operation := "hcs::Process::CloseStdin" + ctx, span := trace.StartSpan(ctx, operation) //nolint:ineffassign,staticcheck + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("cid", process.SystemID()), + trace.Int64Attribute("pid", int64(process.processID))) + process.handleLock.RLock() defer process.handleLock.RUnlock() - operation := "hcs::Process::CloseStdin" - if process.handle == 0 { return makeProcessError(process, operation, ErrAlreadyClosed, nil) } - modifyRequest := processModifyRequest{ - Operation: modifyCloseHandle, - CloseHandle: &closeHandle{ - Handle: stdIn, - }, + process.stdioLock.Lock() + defer process.stdioLock.Unlock() + if process.stdin == nil { + return nil } - modifyRequestb, err := json.Marshal(modifyRequest) - if err != nil { - return err - } + // The HcsModifyProcess request to close stdin will fail if the process has already + // exited or been closed before. 
+ if !process.exited() { + modifyRequest := processModifyRequest{ + Operation: modifyCloseHandle, + CloseHandle: &closeHandle{ + Handle: stdIn, + }, + } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) - events := processHcsResult(ctx, resultJSON) - if err != nil { - return makeProcessError(process, operation, err, events) - } + modifyRequestB, err := json.Marshal(modifyRequest) + if err != nil { + return err + } - process.stdioLock.Lock() - if process.stdin != nil { - process.stdin.Close() - process.stdin = nil + resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestB)) + events := processHcsResult(ctx, resultJSON) + if err != nil { + return makeProcessError(process, operation, err, events) + } } - process.stdioLock.Unlock() - return nil + err = process.stdin.Close() + process.stdin = nil + + return err } func (process *Process) CloseStdout(ctx context.Context) (err error) { - ctx, span := oc.StartSpan(ctx, "hcs::Process::CloseStdout") //nolint:ineffassign,staticcheck + operation := "hcs::Process::CloseStdout" + ctx, span := oc.StartSpan(ctx, operation) //nolint:ineffassign,staticcheck defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( @@ -411,7 +436,7 @@ func (process *Process) CloseStdout(ctx context.Context) (err error) { defer process.handleLock.Unlock() if process.handle == 0 { - return nil + return makeProcessError(process, operation, ErrAlreadyClosed, nil) } process.stdioLock.Lock() @@ -424,7 +449,8 @@ func (process *Process) CloseStdout(ctx context.Context) (err error) { } func (process *Process) CloseStderr(ctx context.Context) (err error) { - ctx, span := oc.StartSpan(ctx, "hcs::Process::CloseStderr") //nolint:ineffassign,staticcheck + operation := "hcs::Process::CloseStderr" + ctx, span := oc.StartSpan(ctx, operation) //nolint:ineffassign,staticcheck defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( @@ -435,7 
+461,7 @@ func (process *Process) CloseStderr(ctx context.Context) (err error) { defer process.handleLock.Unlock() if process.handle == 0 { - return nil + return makeProcessError(process, operation, ErrAlreadyClosed, nil) } process.stdioLock.Lock() From cb6f10b1d1461f3816aaa08b573276ba22057301 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Tue, 12 Apr 2022 10:39:12 -0400 Subject: [PATCH 2/2] PR: simplify callbacks, go1.13 errors, assignment Bug with how stdin streams were closed for processes fixed. removed ineffectual assignment to ctx in c.Wait() Signed-off-by: Hamza El-Saawy --- internal/cmd/cmd.go | 66 ++++++++++------------- internal/cmd/cmd_test.go | 111 ++++---------------------------------- internal/cmd/diag.go | 12 ++--- internal/cmd/io_binary.go | 2 +- internal/gcs/process.go | 40 ++++++-------- internal/hcs/process.go | 25 +++++---- 6 files changed, 76 insertions(+), 180 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 741855fb86..fa66c5e581 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -47,6 +47,10 @@ type Cmd struct { Stdout io.Writer Stderr io.Writer + // CloseStdIn attempts to cast `Stdin` to call `CloseRead() error` or `Close() error` + // on the upstream StdIn IO stream after the process ends but before `Wait` completes. + CloseStdIn bool + // Log provides a logrus entry to use in logging IO copying status. Log *logrus.Entry @@ -66,10 +70,6 @@ type Cmd struct { // ExitState is filled out after Wait() (or Run() or Output()) completes. ExitState *ExitState - // afterExitFuns are run after the process exits, but before wait returns and - // IO is waited on - afterExitFuns []func(context.Context) error - iogrp errgroup.Group stdinErr atomic.Value allDoneCh chan struct{} // closed after Wait finishes @@ -144,9 +144,12 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg // Start starts a command. 
The caller must ensure that if Start succeeds, // Wait is eventually called to clean up resources. func (c *Cmd) Start() error { - ctx := context.Background() - if c.Context != nil { - ctx = c.Context + ctx := c.Context + if ctx == nil { + ctx = context.Background() + } + if c.Log == nil { + c.Log = log.L.Dup() } c.allDoneCh = make(chan struct{}) @@ -288,11 +291,8 @@ func (c *Cmd) Start() error { // Wait waits for a command and its IO to complete and closes the underlying // process. It can only be called once. It returns an ExitError if the command // runs and returns a non-zero exit code. -func (c *Cmd) Wait() error { - ctx := context.Background() - if c.Context != nil { - ctx = c.Context - } +func (c *Cmd) Wait() (err error) { + // c.Context and c.Log should have been properly initialized in c.Start() waitErr := c.Process.Wait() if waitErr != nil && c.Log != nil { @@ -305,9 +305,21 @@ func (c *Cmd) Wait() error { state.code = code } - err := c.afterExit(ctx) - if err != nil { - log.G(ctx).WithError(err).Warn("error when running after exit functions") + if c.Stdin != nil && c.CloseStdIn { + // try to close the stdin to end the `relayIO`/`io.Copy` go routine + if cstdin, ok := c.Stdin.(interface{ CloseRead() error }); ok { + if stdinErr := cstdin.CloseRead(); !isIOChannelClosedErr(stdinErr) { + err = stdinErr + } + } else if cstdin, ok := c.Stdin.(io.Closer); ok { + if stdinErr := cstdin.Close(); !isIOChannelClosedErr(stdinErr) { + err = stdinErr + } + } + + if err != nil { + c.Log.WithError(err).Warn("could not close upstram Stdin after process finished") + } } // Terminate the IO if the copy does not complete in the requested time. @@ -321,13 +333,13 @@ func (c *Cmd) Wait() error { // Close the process to cancel any reads to stdout or stderr. c.Process.Close() if c.Log != nil { - c.Log. - WithField(logfields.Timeout, c.CopyAfterExitTimeout.String()). + c.Log.WithField(logfields.Timeout, c.CopyAfterExitTimeout.String()). 
Warn("timed out waiting for stdio relay") } } }() } + ioErr := c.iogrp.Wait() if inErr, _ := c.stdinErr.Load().(error); inErr != nil { if ioErr == nil { @@ -384,23 +396,3 @@ func (c *Cmd) Output() ([]byte, error) { err := c.Run() return b.Bytes(), err } - -func (c *Cmd) afterExit(ctx context.Context) (err error) { - for _, f := range c.afterExitFuns { - if ferr := f(ctx); err != nil { - err = ferr - log.G(ctx).WithError(err).Warn("error running function after process exit") - } - } - return err -} - -// RegisterAfterExitFun registers a function to be run after the process exits, but -// before the IO copy operations are waited on and the process is closed -func (c *Cmd) RegisterAfterExitFun(f func(context.Context) error) { - // TODO: find a better way to check if the process is started - if c.allDoneCh != nil { - return - } - c.afterExitFuns = append(c.afterExitFuns, f) -} diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index fe16e816c2..c176a4ee28 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -156,8 +156,9 @@ func (p *localProcess) Wait() error { func TestCmdExitCode(t *testing.T) { cmd := Command(&localProcessHost{}, "cmd", "/c", "exit", "/b", "64") err := cmd.Run() - if e, ok := err.(*ExitError); !ok || e.ExitCode() != 64 { - t.Fatal("expected exit code 64, got ", err) + eerr := &ExitError{} + if !errors.As(err, &eerr) || eerr.ExitCode() != 64 { + t.Fatalf("expected exit code 64, got %v", err) } } @@ -172,14 +173,15 @@ func TestCmdOutput(t *testing.T) { } } -func TestCmdContext(t *testing.T) { +func TestCmdContextCloseStdIn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() cmd := CommandContext(ctx, &localProcessHost{}, "cmd", "/c", "pause") r, w := io.Pipe() + defer w.Close() cmd.Stdin = r - cmd.RegisterAfterExitFun(func(_ context.Context) error { return w.Close() }) + cmd.CloseStdIn = true err := cmd.Start() if err != nil { @@ -187,8 +189,9 @@ func 
TestCmdContext(t *testing.T) { } err = cmd.Wait() - if e, ok := err.(*ExitError); !ok || e.ExitCode() != 1 || ctx.Err() == nil { - t.Fatal(err) + eerr := &ExitError{} + if !errors.As(err, &eerr) || eerr.ExitCode() != 1 || ctx.Err() == nil { + t.Fatalf("expected context timeout or exit code 1, got: %v", err) } } @@ -219,72 +222,6 @@ func TestCmdStdinBlocked(t *testing.T) { } } -func TestCmdAfterExitFun(t *testing.T) { - cmd := Command(&localProcessHost{}, "cmd", "/c") - - c := make(chan struct{}) - cmd.RegisterAfterExitFun(func(_ context.Context) error { - close(c) - time.Sleep(50 * time.Millisecond) - return nil - }) - - err := cmd.Start() - if err != nil { - t.Fatal(err) - } - - // call cmd.Wait to make sure after funs are called - done := make(chan error) - go func() { - done <- cmd.Wait() - close(done) - }() - - err = cmd.Process.Wait() - if err != nil { - t.Fatal(err) - } - - select { - case <-c: - // if they both finish at the same time, it is undefined which case is chosen... - case <-cmd.allDoneCh: - t.Fatalf("after exit did not finish before cmd.Wait returned") - } - - // still check for errors during cmd.Wait(): - err = <-done - if err != nil { - t.Fatalf("cmd failed: %v", err) - } -} - -func TestCmdAfterExitFunRegistration(t *testing.T) { - cmd := Command(&localProcessHost{}, "cmd", "/c", "echo", "hello") - - l := len(cmd.afterExitFuns) - cmd.RegisterAfterExitFun(func(_ context.Context) error { - return nil - }) - if len(cmd.afterExitFuns) != l+1 { - t.Fatalf("function registration failed") - } - - err := cmd.Start() - if err != nil { - t.Fatalf("cmd Run failed: %v", err) - } - - cmd.RegisterAfterExitFun(func(_ context.Context) error { - return errors.New("this error should never be raised") - }) - - if len(cmd.afterExitFuns) != l+1 { - t.Fatalf("function should not have been registered") - } -} - type stuckIoProcessHost struct { cow.ProcessHost } @@ -324,7 +261,7 @@ func TestCmdStuckIo(t *testing.T) { cmd := 
Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello") cmd.CopyAfterExitTimeout = time.Millisecond * 200 _, err := cmd.Output() - if err != io.ErrClosedPipe { + if !errors.Is(err, io.ErrClosedPipe) { t.Fatal(err) } } @@ -353,31 +290,3 @@ func TestCmdStuckStdoutNotClosed(t *testing.T) { case <-tr.C: } } - -func TestCmdStuckStdoutClosed(t *testing.T) { - cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c") - r, w := io.Pipe() - defer r.Close() - cmd.Stdout = w - cmd.RegisterAfterExitFun(func(ctx context.Context) error { - p := cmd.Process.(*stuckIoProcess) - return p.stdout.Close() - }) - - done := make(chan error) - go func() { - done <- cmd.Run() - close(done) - }() - - tr := time.NewTimer(250 * time.Millisecond) - defer tr.Stop() - select { - case err := <-done: - if err != io.ErrClosedPipe { - t.Fatalf("cmd run failed: %v", err) - } - case <-tr.C: - t.Fatal("command did not exit") - } -} diff --git a/internal/cmd/diag.go b/internal/cmd/diag.go index 71391becdd..4de6b1a8d1 100644 --- a/internal/cmd/diag.go +++ b/internal/cmd/diag.go @@ -31,14 +31,11 @@ func ExecInUvm(ctx context.Context, vm *uvm.UtilityVM, req *CmdProcessRequest) ( } cmd.Spec.Terminal = req.Terminal cmd.Stdin = np.Stdin() - cmd.Stdout = np.Stdout() - cmd.Stderr = np.Stderr() if cmd.Stdin != nil { - cmd.RegisterAfterExitFun(func(ctx context.Context) error { - np.CloseStdin(ctx) - return nil - }) + cmd.CloseStdIn = true } + cmd.Stdout = np.Stdout() + cmd.Stderr = np.Stderr() cmd.Log = log.G(ctx).WithField(logfields.UVMID, vm.ID()) err = cmd.Run() return cmd.ExitState.ExitCode(), err @@ -65,7 +62,8 @@ func ExecInShimHost(ctx context.Context, req *CmdProcessRequest) (int, error) { cmd.Stderr = np.Stderr() err = cmd.Run() if err != nil { - if exiterr, ok := err.(*exec.ExitError); ok { + exiterr := &ExitError{} + if errors.As(err, &exiterr) { return exiterr.ExitCode(), exiterr } return -1, err diff --git a/internal/cmd/io_binary.go b/internal/cmd/io_binary.go 
index 989a53c93c..0115f7e075 100644 --- a/internal/cmd/io_binary.go +++ b/internal/cmd/io_binary.go @@ -89,7 +89,7 @@ func NewBinaryIO(ctx context.Context, id string, uri *url.URL) (_ UpstreamIO, er // Wait for logging driver to signal to the wait pipe that it's ready to consume IO go func() { b := make([]byte, 1) - if _, err := waitPipe.Read(b); err != nil && err != io.EOF { + if _, err := waitPipe.Read(b); err != nil && !errors.Is(err, io.EOF) { errCh <- err return } diff --git a/internal/gcs/process.go b/internal/gcs/process.go index c0d5ae1934..87ecfb8727 100644 --- a/internal/gcs/process.go +++ b/internal/gcs/process.go @@ -135,12 +135,9 @@ func (p *Process) Close() (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - p.gcLock.RLock() - if p.gc == nil { - p.gcLock.RUnlock() + if p.gcClosed() { return hcs.ErrAlreadyClosed } - p.gcLock.RUnlock() p.stdioLock.Lock() if p.stdin != nil { @@ -185,23 +182,18 @@ func (p *Process) CloseStdin(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - p.gcLock.RLock() - if p.gc == nil { - p.gcLock.RUnlock() + if p.gcClosed() { return hcs.ErrAlreadyClosed } - p.gcLock.RUnlock() p.stdioLock.Lock() defer p.stdioLock.Unlock() if p.stdin != nil { - // First close the channel for writing, sending an EOF to readers, then - // close the file. - if err = p.stdin.CloseWrite(); err != nil { - return err - } - if err = p.stdin.Close(); err == nil { - return err + // First (try to) close the channel for writing, sending an EOF to readers, then + // close the file. If CloseWrite fails, close the stream regardless. 
+ err = p.stdin.CloseWrite() + if cerr := p.stdin.Close(); cerr != nil { + err = cerr } p.stdin = nil } @@ -216,12 +208,9 @@ func (p *Process) CloseStdout(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - p.gcLock.RLock() - if p.gc == nil { - p.gcLock.RUnlock() + if p.gcClosed() { return hcs.ErrAlreadyClosed } - p.gcLock.RUnlock() p.stdioLock.Lock() defer p.stdioLock.Unlock() @@ -240,12 +229,9 @@ func (p *Process) CloseStderr(ctx context.Context) (err error) { trace.StringAttribute("cid", p.cid), trace.Int64Attribute("pid", int64(p.id))) - p.gcLock.RLock() - if p.gc == nil { - p.gcLock.RUnlock() + if p.gcClosed() { return hcs.ErrAlreadyClosed } - p.gcLock.RUnlock() p.stdioLock.Lock() defer p.stdioLock.Unlock() @@ -382,3 +368,11 @@ func (p *Process) waitBackground() { log.G(ctx).WithField("exitCode", ec).Debug("process exited") oc.SetSpanStatus(span, err) } + +// gcClosed checks if the guest connection has been set to `nil`. +// Must be able to acquire `gcLock` for reading +func (p *Process) gcClosed() bool { + p.gcLock.RLock() + defer p.gcLock.RUnlock() + return p.gc == nil +} diff --git a/internal/hcs/process.go b/internal/hcs/process.go index e3a10618e5..d4ac2e7361 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -368,6 +368,8 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R // Stdio returns the stdin, stdout, and stderr pipes, respectively. // To close them, close the process handle, or use the `CloseStd*` functions. func (process *Process) Stdio() (stdin io.Writer, stdout, stderr io.Reader) { + process.stdioLock.Lock() + defer process.stdioLock.Unlock() return process.stdin, process.stdout, process.stderr } @@ -375,7 +377,7 @@ func (process *Process) Stdio() (stdin io.Writer, stdout, stderr io.Reader) { // notified on the read side that there is no more data in stdin. 
func (process *Process) CloseStdin(ctx context.Context) (err error) { operation := "hcs::Process::CloseStdin" - ctx, span := trace.StartSpan(ctx, operation) //nolint:ineffassign,staticcheck + ctx, span := trace.StartSpan(ctx, operation) defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( @@ -405,19 +407,20 @@ func (process *Process) CloseStdin(ctx context.Context) (err error) { }, } - modifyRequestB, err := json.Marshal(modifyRequest) - if err != nil { - return err - } - - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestB)) - events := processHcsResult(ctx, resultJSON) - if err != nil { - return makeProcessError(process, operation, err, events) + var b []byte + b, err = json.Marshal(modifyRequest) + // don't return on errors, and still try to close the stream from the host + if err == nil { + var resultJSON string + resultJSON, err = vmcompute.HcsModifyProcess(ctx, process.handle, string(b)) + events := processHcsResult(ctx, resultJSON) + if err != nil { + err = makeProcessError(process, operation, err, events) + } } } - err = process.stdin.Close() + process.stdin.Close() process.stdin = nil return err