Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions cli-plugins/socket/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@ import (
// executed the socket name it should listen on to coordinate with the host CLI.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"

// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
// and update the conn pointer, and returns the listener for the socket (which the caller
// is responsible for closing when it's no longer needed).
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections.
// A listener is returned, along with a connection channel that will receive the established
// connection. The channel *may* return a nil connection and should be checked once received.
// The caller is responsible for closing the listener when it's no longer needed.
func SetupConn() (*net.UnixListener, <-chan *net.UnixConn, error) {
listener, err := listen("docker_cli_" + randomID())
if err != nil {
return nil, err
return nil, nil, err
}

accept(listener, conn)
// accept starts a background goroutine
// to accept a new connection
// once accepted, the connChan will be updated.
connChan := accept(listener)

return listener, nil
return listener, connChan, nil
}

func randomID() string {
Expand All @@ -35,16 +39,25 @@ func randomID() string {
return hex.EncodeToString(b)
}

func accept(listener *net.UnixListener, conn **net.UnixConn) {
// accept creates a new Unix socket connection
// and sends it to the *net.UnixConn channel
func accept(listener *net.UnixListener) <-chan *net.UnixConn {
connChan := make(chan *net.UnixConn, 1)

go func() {
for {
// ignore error here, if we failed to accept a connection,
// conn is nil and we fallback to previous behavior
*conn, _ = listener.AcceptUnix()
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(*conn, listener)
}
// close the channel to signal we won't accept any more connections
defer close(connChan)
// this is a blocking call and will wait
// until a new connection is accepted
// or until the timout is reached
conn, _ := listener.AcceptUnix()

// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(listener)
connChan <- conn
}()

return connChan
}

// ConnectAndWait connects to the socket passed via well-known env var,
Expand Down
2 changes: 1 addition & 1 deletion cli-plugins/socket/socket_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) {
})
}

func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
func onAccept(listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}
2 changes: 1 addition & 1 deletion cli-plugins/socket/socket_nodarwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func listen(socketname string) (*net.UnixListener, error) {
})
}

func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
func onAccept(listener *net.UnixListener) {
// do nothing
// while on darwin and OpenBSD we would unlink here;
// on non-darwin the socket is abstract and not present on the filesystem
Expand Down
2 changes: 1 addition & 1 deletion cli-plugins/socket/socket_openbsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) {
})
}

func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
func onAccept(listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}
53 changes: 19 additions & 34 deletions cli-plugins/socket/socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ import (

func TestSetupConn(t *testing.T) {
t.Run("updates conn when connected", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
listener, conn, err := SetupConn()
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
Expand All @@ -25,29 +24,11 @@ func TestSetupConn(t *testing.T) {
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")

pollConnNotNil(t, &conn)
})

t.Run("allows reconnects", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")

otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")

otherConn.Close()

_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial listener")
pollConnNotNil(t, conn)
})

t.Run("does not leak sockets to local directory", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
listener, _, err := SetupConn()
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
checkDirNoPluginSocket(t)
Expand Down Expand Up @@ -78,17 +59,17 @@ func checkDirNoPluginSocket(t *testing.T) {

func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
listener, connChan, err := SetupConn()
assert.NilError(t, err, "failed to setup listener")

done := make(chan struct{})
t.Setenv(EnvKey, listener.Addr().String())
cancelFunc := func() {
done <- struct{}{}
}

ConnectAndWait(cancelFunc)
pollConnNotNil(t, &conn)
conn := pollConnNotNil(t, connChan)
conn.Close()

select {
Expand All @@ -101,17 +82,17 @@ func TestConnectAndWait(t *testing.T) {
// TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
listener, connChan, err := SetupConn()
assert.NilError(t, err, "failed to setup listener")
t.Setenv(EnvKey, listener.Addr().String())
numGoroutines := runtime.NumGoroutine()

ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)

pollConnNotNil(t, &conn)
conn := pollConnNotNil(t, connChan)
conn.Close()

poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit")
Expand All @@ -121,13 +102,17 @@ func TestConnectAndWait(t *testing.T) {
})
}

func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
func pollConnNotNil(t *testing.T, conn <-chan *net.UnixConn) *net.UnixConn {
t.Helper()

poll.WaitOn(t, func(t poll.LogT) poll.Result {
if *conn == nil {
return poll.Continue("waiting for conn to not be nil")
select {
case c := <-conn:
if c == nil {
t.Fatal("conn is nil")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
return c
case <-time.After(10 * time.Millisecond):
t.Fatal("timeout waiting for conn to be set")
}
return nil
}
18 changes: 11 additions & 7 deletions cmd/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"fmt"
"net"
"os"
"os/exec"
"os/signal"
Expand Down Expand Up @@ -222,8 +221,7 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
}

// Establish the plugin socket, adding it to the environment under a well-known key if successful.
var conn *net.UnixConn
listener, err := socket.SetupConn(&conn)
listener, connChan, err := socket.SetupConn()
if err == nil {
envs = append(envs, socket.EnvKey+"="+listener.Addr().String())
defer listener.Close()
Expand All @@ -247,11 +245,17 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
// receive signals due to sharing a pgid with the parent CLI
continue
}
if conn != nil {
if err := conn.Close(); err != nil {
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
select {
// connChan will close itself once we receive the connection
// thus further loops will not block on connChan
case c := <-connChan:
if c != nil {
if err := c.Close(); err != nil {
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
}
}
conn = nil
default:
// fallthrough and continue with the loop.
}
retries++
if retries >= exitLimit {
Expand Down