diff --git a/cli-plugins/socket/socket.go b/cli-plugins/socket/socket.go index 67ba11562e38..f46cf3ce6d1d 100644 --- a/cli-plugins/socket/socket.go +++ b/cli-plugins/socket/socket.go @@ -13,18 +13,22 @@ import ( // executed the socket name it should listen on to coordinate with the host CLI. const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET" -// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections -// and update the conn pointer, and returns the listener for the socket (which the caller -// is responsible for closing when it's no longer needed). -func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) { +// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections. +// A listener is returned, along with a connection channel that will receive the established +// connection. The channel *may* return a nil connection and should be checked once received. +// The caller is responsible for closing the listener when it's no longer needed. +func SetupConn() (*net.UnixListener, <-chan *net.UnixConn, error) { listener, err := listen("docker_cli_" + randomID()) if err != nil { - return nil, err + return nil, nil, err } - accept(listener, conn) + // accept starts a background goroutine + // to accept a new connection + // once accepted, the connChan will be updated. + connChan := accept(listener) - return listener, nil + return listener, connChan, nil } func randomID() string { @@ -35,16 +39,25 @@ func randomID() string { return hex.EncodeToString(b) } -func accept(listener *net.UnixListener, conn **net.UnixConn) { +// accept creates a new Unix socket connection +// and sends it to the *net.UnixConn channel +func accept(listener *net.UnixListener) <-chan *net.UnixConn { + connChan := make(chan *net.UnixConn, 1) + go func() { - for { - // ignore error here, if we failed to accept a connection, - // conn is nil and we fallback to previous behavior - *conn, _ = listener.AcceptUnix() - // perform any platform-specific actions on accept (e.g. unlink non-abstract sockets) - onAccept(*conn, listener) - } + // close the channel to signal we won't accept any more connections + defer close(connChan) + // this is a blocking call and will wait + // until a new connection is accepted + // or until the timout is reached + conn, _ := listener.AcceptUnix() + + // perform any platform-specific actions on accept (e.g. unlink non-abstract sockets) + onAccept(listener) + connChan <- conn }() + + return connChan } // ConnectAndWait connects to the socket passed via well-known env var, diff --git a/cli-plugins/socket/socket_darwin.go b/cli-plugins/socket/socket_darwin.go index 17ab6aa69e6e..2531a9c8044a 100644 --- a/cli-plugins/socket/socket_darwin.go +++ b/cli-plugins/socket/socket_darwin.go @@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) { }) } -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { +func onAccept(listener *net.UnixListener) { syscall.Unlink(listener.Addr().String()) } diff --git a/cli-plugins/socket/socket_nodarwin.go b/cli-plugins/socket/socket_nodarwin.go index aa6065ecb446..ed5d4cbcc6f0 100644 --- a/cli-plugins/socket/socket_nodarwin.go +++ b/cli-plugins/socket/socket_nodarwin.go @@ -13,7 +13,7 @@ func listen(socketname string) (*net.UnixListener, error) { }) } -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { +func onAccept(listener *net.UnixListener) { // do nothing // while on darwin and OpenBSD we would unlink here; // on non-darwin the socket is abstract and not present on the filesystem diff --git a/cli-plugins/socket/socket_openbsd.go b/cli-plugins/socket/socket_openbsd.go index 17ab6aa69e6e..2531a9c8044a 100644 --- a/cli-plugins/socket/socket_openbsd.go +++ b/cli-plugins/socket/socket_openbsd.go @@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) { }) } -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { +func onAccept(listener *net.UnixListener) { syscall.Unlink(listener.Addr().String()) } diff --git a/cli-plugins/socket/socket_test.go b/cli-plugins/socket/socket_test.go index 409eb689485c..d9947a7ab881 100644 --- a/cli-plugins/socket/socket_test.go +++ b/cli-plugins/socket/socket_test.go @@ -15,8 +15,7 @@ import ( func TestSetupConn(t *testing.T) { t.Run("updates conn when connected", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + listener, conn, err := SetupConn() assert.NilError(t, err) assert.Check(t, listener != nil, "returned nil listener but no error") addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) @@ -25,29 +24,11 @@ func TestSetupConn(t *testing.T) { _, err = net.DialUnix("unix", nil, addr) assert.NilError(t, err, "failed to dial returned listener") - pollConnNotNil(t, &conn) - }) - - t.Run("allows reconnects", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) - assert.NilError(t, err) - assert.Check(t, listener != nil, "returned nil listener but no error") - addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) - assert.NilError(t, err, "failed to resolve listener address") - - otherConn, err := net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to dial returned listener") - - otherConn.Close() - - _, err = net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to redial listener") + pollConnNotNil(t, conn) }) t.Run("does not leak sockets to local directory", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + listener, _, err := SetupConn() assert.NilError(t, err) assert.Check(t, listener != nil, "returned nil listener but no error") checkDirNoPluginSocket(t) @@ -78,8 +59,7 @@ func checkDirNoPluginSocket(t *testing.T) { func TestConnectAndWait(t *testing.T) { t.Run("calls cancel func on EOF", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + listener, connChan, err := SetupConn() assert.NilError(t, err, "failed to setup listener") done := make(chan struct{}) @@ -87,8 +67,9 @@ func TestConnectAndWait(t *testing.T) { cancelFunc := func() { done <- struct{}{} } + ConnectAndWait(cancelFunc) - pollConnNotNil(t, &conn) + conn := pollConnNotNil(t, connChan) conn.Close() select { @@ -101,8 +82,7 @@ func TestConnectAndWait(t *testing.T) { // TODO: this test cannot be executed with `t.Parallel()`, due to // relying on goroutine numbers to ensure correct behaviour t.Run("connect goroutine exits after EOF", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + listener, connChan, err := SetupConn() assert.NilError(t, err, "failed to setup listener") t.Setenv(EnvKey, listener.Addr().String()) numGoroutines := runtime.NumGoroutine() @@ -110,8 +90,9 @@ func TestConnectAndWait(t *testing.T) { ConnectAndWait(func() {}) assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) - pollConnNotNil(t, &conn) + conn := pollConnNotNil(t, connChan) conn.Close() + poll.WaitOn(t, func(t poll.LogT) poll.Result { if runtime.NumGoroutine() > numGoroutines+1 { return poll.Continue("waiting for connect goroutine to exit") @@ -121,13 +102,17 @@ func TestConnectAndWait(t *testing.T) { }) } -func pollConnNotNil(t *testing.T, conn **net.UnixConn) { +func pollConnNotNil(t *testing.T, conn <-chan *net.UnixConn) *net.UnixConn { t.Helper() - poll.WaitOn(t, func(t poll.LogT) poll.Result { - if *conn == nil { - return poll.Continue("waiting for conn to not be nil") + select { + case c := <-conn: + if c == nil { + t.Fatal("conn is nil") } - return poll.Success() - }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) + return c + case <-time.After(10 * time.Millisecond): + t.Fatal("timeout waiting for conn to be set") + } + return nil } diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go index cfc53a6fa170..addf674aa2c1 100644 --- a/cmd/docker/docker.go +++ b/cmd/docker/docker.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "net" "os" "os/exec" "os/signal" @@ -222,8 +221,7 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, } // Establish the plugin socket, adding it to the environment under a well-known key if successful. - var conn *net.UnixConn - listener, err := socket.SetupConn(&conn) + listener, connChan, err := socket.SetupConn() if err == nil { envs = append(envs, socket.EnvKey+"="+listener.Addr().String()) defer listener.Close() @@ -247,11 +245,17 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, // receive signals due to sharing a pgid with the parent CLI continue } - if conn != nil { - if err := conn.Close(); err != nil { - _, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err) + select { + // connChan will close itself once we receive the connection + // thus further loops will not block on connChan + case c := <-connChan: + if c != nil { + if err := c.Close(); err != nil { + _, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err) + } } - conn = nil + default: + // fallthrough and continue with the loop. } retries++ if retries >= exitLimit {