Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 87 additions & 19 deletions cli-plugins/socket/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,104 @@ import (
"io"
"net"
"os"
"runtime"
"sync"
)

// EnvKey represents the well-known environment variable used to pass the plugin being
// executed the socket name it should listen on to coordinate with the host CLI.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"

// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
// and update the conn pointer, and returns the listener for the socket (which the caller
// is responsible for closing when it's no longer needed).
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
listener, err := listen("docker_cli_" + randomID())
// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
// `h` is called for each new connection to the socket in a goroutine.
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
l, err := listen("docker_cli_" + randomID())
if err != nil {
return nil, err
}

accept(listener, conn)
if h == nil {
h = func(net.Conn) {}
}

pl := &PluginServer{
l: l,
h: h,
}

go func() {
defer pl.Close()
for {
err := pl.accept()
if err != nil {
return
}
}
}()

return pl, nil
}

type PluginServer struct {
mu sync.Mutex
conns []net.Conn
l *net.UnixListener
h func(net.Conn)
closed bool
}

func (pl *PluginServer) accept() error {
conn, err := pl.l.Accept()
if err != nil {
return err
}

pl.mu.Lock()
defer pl.mu.Unlock()

if pl.closed {
// handle potential race condition between Close and Accept
conn.Close()
return errors.New("plugin server is closed")
}

return listener, nil
pl.conns = append(pl.conns, conn)

go pl.h(conn)
return nil
}

func (pl *PluginServer) Addr() net.Addr {
return pl.l.Addr()
}

// Close ensures that the server is no longer accepting new connections and closes all existing connections.
// Existing connections will receive [io.EOF].
func (pl *PluginServer) Close() error {
// Remove the listener socket, if it exists on the filesystem.
unlink(pl.l)

// Close connections first to ensure the connections get io.EOF instead of a connection reset.
pl.closeAllConns()

// Try to ensure that any active connections have a chance to receive io.EOF
runtime.Gosched()

return pl.l.Close()
}

func (pl *PluginServer) closeAllConns() {
pl.mu.Lock()
defer pl.mu.Unlock()

// Prevent new connections from being accepted
pl.closed = true

for _, conn := range pl.conns {
conn.Close()
}

pl.conns = nil
}

func randomID() string {
Expand All @@ -35,18 +115,6 @@ func randomID() string {
return hex.EncodeToString(b)
}

func accept(listener *net.UnixListener, conn **net.UnixConn) {
go func() {
for {
// ignore error here, if we failed to accept a connection,
// conn is nil and we fallback to previous behavior
*conn, _ = listener.AcceptUnix()
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(*conn, listener)
}
}()
}

// ConnectAndWait connects to the socket passed via well-known env var,
// if present, and attempts to read from it until it receives an EOF, at which
// point cb is called.
Expand Down
20 changes: 20 additions & 0 deletions cli-plugins/socket/socket_abstract.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//go:build windows || linux

package socket

import (
"net"
)

func listen(socketname string) (*net.UnixListener, error) {
// Create an abstract socket -- this socket can be opened by name, but is
// not present in the filesystem.
return net.ListenUnix("unix", &net.UnixAddr{
Name: "@" + socketname,
Net: "unix",
})
}

func unlink(listener *net.UnixListener) {
// Do nothing; the socket is not present in the filesystem.
}
19 changes: 0 additions & 19 deletions cli-plugins/socket/socket_darwin.go

This file was deleted.

25 changes: 25 additions & 0 deletions cli-plugins/socket/socket_noabstract.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//go:build !windows && !linux

package socket

import (
"net"
"os"
"path/filepath"
"syscall"
)

func listen(socketname string) (*net.UnixListener, error) {
// Because abstract sockets are unavailable, we create a socket in the
// system temporary directory instead.
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}

func unlink(listener *net.UnixListener) {
// unlink(2) is best effort here; if it fails, we may 'leak' a socket
// into the filesystem, but this is unlikely and overall harmless.
_ = syscall.Unlink(listener.Addr().String())
}
20 changes: 0 additions & 20 deletions cli-plugins/socket/socket_nodarwin.go

This file was deleted.

19 changes: 0 additions & 19 deletions cli-plugins/socket/socket_openbsd.go

This file was deleted.

Loading