Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cli/command/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration {

func (cli *DockerCli) initializeFromClient() {
ctx := context.Background()
if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") {
// @FIXME context.WithTimeout doesn't work with connhelper / ssh connections
// time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed"
var cancel func()
ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()

ping, err := cli.client.Ping(ctx)
if err != nil {
Expand Down
208 changes: 106 additions & 102 deletions cli/connhelper/commandconn/commandconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -64,100 +65,86 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {

// commandConn implements net.Conn
type commandConn struct {
cmd *exec.Cmd
cmdExited bool
cmdWaitErr error
cmdMutex sync.Mutex
stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex
stderr bytes.Buffer
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
stdinClosed bool
stdoutClosed bool
localAddr net.Addr
remoteAddr net.Addr
cmdMutex sync.Mutex // for cmd, cmdWaitErr
cmd *exec.Cmd
cmdWaitErr error
cmdExited atomic.Bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to moby/moby#45966 (comment), this means we should update our go 1.18 line in vendor.mod to at least 1.19 😅

# github.com/docker/cli/cli/connhelper/commandconn
cli/connhelper/commandconn/commandconn.go:71:22: undefined: atomic.Bool
cli/connhelper/commandconn/commandconn.go:76:22: undefined: atomic.Bool
cli/connhelper/commandconn/commandconn.go:77:22: undefined: atomic.Bool
cli/connhelper/commandconn/commandconn.go:78:22: undefined: atomic.Bool

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @tianon, are you interested in contributing to open source and opening a PR? 😉

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naw, open source isn't really my thing 😂

(totally fair callout, I'll see what I can do)

stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex // for stderr
stderr bytes.Buffer
stdinClosed atomic.Bool
stdoutClosed atomic.Bool
closing atomic.Bool
localAddr net.Addr
remoteAddr net.Addr
}

// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
func (c *commandConn) killIfStdioClosed() error {
c.stdioClosedMu.Lock()
stdioClosed := c.stdoutClosed && c.stdinClosed
c.stdioClosedMu.Unlock()
if !stdioClosed {
return nil
// kill terminates the process. On Windows it kills the process directly,
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
// the process after 3 seconds.
func (c *commandConn) kill() {
if c.cmdExited.Load() {
return
}
return c.kill()
}

// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
func killAndWait(cmd *exec.Cmd) error {
c.cmdMutex.Lock()
var werr error
if runtime.GOOS != "windows" {
werrCh := make(chan error)
go func() { werrCh <- cmd.Wait() }()
cmd.Process.Signal(syscall.SIGTERM)
go func() { werrCh <- c.cmd.Wait() }()
_ = c.cmd.Process.Signal(syscall.SIGTERM)
select {
case werr = <-werrCh:
case <-time.After(3 * time.Second):
cmd.Process.Kill()
_ = c.cmd.Process.Kill()
werr = <-werrCh
}
} else {
cmd.Process.Kill()
werr = cmd.Wait()
_ = c.cmd.Process.Kill()
werr = c.cmd.Wait()
}
return werr
c.cmdWaitErr = werr
c.cmdMutex.Unlock()
c.cmdExited.Store(true)
}

// kill returns nil if the command terminated, regardless to the exit status.
func (c *commandConn) kill() error {
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
werr = c.cmdWaitErr
} else {
werr = killAndWait(c.cmd)
c.cmdWaitErr = werr
c.cmdExited = true
}
c.cmdMutex.Unlock()
if werr == nil {
return nil
}
wExitErr, ok := werr.(*exec.ExitError)
if ok {
if wExitErr.ProcessState.Exited() {
return nil
}
// handleEOF handles io.EOF errors while reading or writing from the underlying
// command pipes.
//
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned
func (c *commandConn) handleEOF(err error) error {
if err != io.EOF {
return err
}
return errors.Wrapf(werr, "commandconn: failed to wait")
}

func (c *commandConn) onEOF(eof error) error {
// when we got EOF, the command is going to be terminated
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
defer c.cmdMutex.Unlock()

var werr error
if c.cmdExited.Load() {
werr = c.cmdWaitErr
} else {
werrCh := make(chan error)
go func() { werrCh <- c.cmd.Wait() }()
select {
case werr = <-werrCh:
c.cmdWaitErr = werr
c.cmdExited = true
c.cmdExited.Store(true)
case <-time.After(10 * time.Second):
c.cmdMutex.Unlock()
c.stderrMu.Lock()
stderr := c.stderr.String()
c.stderrMu.Unlock()
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
}
}
c.cmdMutex.Unlock()

if werr == nil {
return eof
return err
}
c.stderrMu.Lock()
stderr := c.stderr.String()
Expand All @@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error {
}

func ignorableCloseError(err error) bool {
errS := err.Error()
ss := []string{
os.ErrClosed.Error(),
return strings.Contains(err.Error(), os.ErrClosed.Error())
}

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
// check after the call to Read, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
}
for _, s := range ss {
if strings.Contains(errS, s) {
return true
}

return n, c.handleEOF(err)
}

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
// check after the call to Write, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
}
return false

return n, c.handleEOF(err)
}

// CloseRead allows commandConn to implement halfCloser
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseRead: %v", err)
return err
}
c.stdioClosedMu.Lock()
c.stdoutClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseRead: %v", err)
}
return nil
}
c.stdoutClosed.Store(true)

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdinClosed.Load() {
c.kill()
}
return n, err

return nil
}

// CloseWrite allows commandConn to implement halfCloser
func (c *commandConn) CloseWrite() error {
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseWrite: %v", err)
}
c.stdioClosedMu.Lock()
c.stdinClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseWrite: %v", err)
return err
}
return nil
}
c.stdinClosed.Store(true)

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdoutClosed.Load() {
c.kill()
}
return n, err
return nil
}

// Close is the net.Conn func that gets called
// by the transport when a dial is cancelled
// due to it's context timing out. Any blocked
// Read or Write calls will be unblocked and
// return errors. It will block until the underlying
// command has terminated.
func (c *commandConn) Close() error {
var err error
if err = c.CloseRead(); err != nil {
c.closing.Store(true)
defer c.closing.Store(false)

if err := c.CloseRead(); err != nil {
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
return err
}
if err = c.CloseWrite(); err != nil {
if err := c.CloseWrite(); err != nil {
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
return err
}
return err

return nil
}

func (c *commandConn) LocalAddr() net.Addr {
Expand Down
Loading