diff --git a/login1/dbus.go b/login1/dbus.go index c7df1f18..e98124bd 100644 --- a/login1/dbus.go +++ b/login1/dbus.go @@ -36,8 +36,33 @@ const ( // Conn is a connection to systemds dbus endpoint. type Conn struct { - conn *dbus.Conn - object dbus.BusObject + conn Connection + connManager connectionManager + object Caller +} + +// Connection describes functionality required from a given D-Bus connection. +type Connection interface { + Object(string, dbus.ObjectPath) dbus.BusObject + Signal(ch chan<- *dbus.Signal) + Connected() bool + AddMatchSignalContext(ctx context.Context, options ...dbus.MatchOption) error +} + +// connectionManager explicitly wraps dependencies on established D-Bus connection. +type connectionManager interface { + Hello() error + Auth(authMethods []dbus.Auth) error + Close() error + + Connection +} + +// Caller describes required functionality from D-Bus object. +type Caller interface { + // TODO: This method should eventually be removed, as it provides no context support. + Call(method string, flags dbus.Flags, args ...any) *dbus.Call + CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...any) *dbus.Call } // New establishes a connection to the system bus and authenticates. @@ -51,14 +76,26 @@ func New() (*Conn, error) { return c, nil } +// NewWithConnection creates new login1 client using given D-Bus connection. +func NewWithConnection(connection Connection) (*Conn, error) { + if connection == nil { + return nil, errors.New("no connection given") + } + + return &Conn{ + conn: connection, + object: connection.Object(dbusDest, dbusPath), + }, nil +} + // Close closes the dbus connection func (c *Conn) Close() { if c == nil { return } - if c.conn != nil { - c.conn.Close() + if c.conn != nil && c.connManager != nil { + c.connManager.Close() } } @@ -69,7 +106,7 @@ func (c *Conn) Connected() bool { func (c *Conn) initConnection() error { var err error - c.conn, err = dbus.SystemBusPrivate() + c.connManager, err = dbus.SystemBusPrivate() if err != nil { return err } @@ -79,18 +116,19 @@ func (c *Conn) initConnection() error { // libc) methods := []dbus.Auth{dbus.AuthExternal(strconv.Itoa(os.Getuid()))} - err = c.conn.Auth(methods) + err = c.connManager.Auth(methods) if err != nil { - c.conn.Close() + c.connManager.Close() return err } - err = c.conn.Hello() + err = c.connManager.Hello() if err != nil { - c.conn.Close() + c.connManager.Close() return err } + c.conn = c.connManager c.object = c.conn.Object("org.freedesktop.login1", dbus.ObjectPath(dbusPath)) return nil @@ -354,14 +392,23 @@ func (c *Conn) Inhibit(what, who, why, mode string) (*os.File, error) { return os.NewFile(uintptr(fd), "inhibit"), nil } -// Subscribe to signals on the logind dbus -func (c *Conn) Subscribe(members ...string) chan *dbus.Signal { +// SubscribeWithContext subscribes to signals on the logind dbus. If adding match signals fails, an error is returned. +func (c *Conn) SubscribeWithContext(ctx context.Context, members ...string) (chan *dbus.Signal, error) { for _, member := range members { - c.conn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, - fmt.Sprintf("type='signal',interface='org.freedesktop.login1.Manager',member='%s'", member)) + if err := c.conn.AddMatchSignalContext(ctx, dbus.WithMatchInterface(dbusManagerInterface), dbus.WithMatchMember(member)); err != nil { + return nil, fmt.Errorf("adding match for signal %s: %w", member, err) + } } ch := make(chan *dbus.Signal, 10) c.conn.Signal(ch) + return ch, nil +} + +// Subscribe subscribes to signals on the logind dbus. If adding match signals fails, errors are ignored. +// +// Deprecated: use SubscribeWithContext instead. +func (c *Conn) Subscribe(members ...string) chan *dbus.Signal { + ch, _ := c.SubscribeWithContext(context.Background(), members...) return ch } diff --git a/login1/dbus_test.go b/login1/dbus_test.go index 2f8be4f2..c409bc26 100644 --- a/login1/dbus_test.go +++ b/login1/dbus_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package login1 +package login1_test import ( "context" @@ -21,18 +21,21 @@ import ( "regexp" "testing" "time" + + "github.com/godbus/dbus/v5" + + "github.com/coreos/go-systemd/v22/login1" ) // TestNew ensures that New() works without errors. func TestNew(t *testing.T) { - _, err := New() - if err != nil { + if _, err := login1.New(); err != nil { t.Fatal(err) } } func TestListSessions(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -61,7 +64,7 @@ func TestListSessions(t *testing.T) { } func TestListUsers(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -90,7 +93,7 @@ func TestListUsers(t *testing.T) { } func TestConn_GetSessionPropertiesContext(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -117,7 +120,7 @@ func TestConn_GetSessionPropertiesContext(t *testing.T) { } func TestConn_GetSessionPropertyContext(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -141,7 +144,7 @@ func TestConn_GetSessionPropertyContext(t *testing.T) { } func TestConn_GetUserPropertiesContext(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -168,7 +171,7 @@ func TestConn_GetUserPropertiesContext(t *testing.T) { } func TestConn_GetUserPropertyContext(t *testing.T) { - c, err := New() + c, err := login1.New() if err != nil { t.Fatal(err) } @@ -190,3 +193,246 @@ func TestConn_GetUserPropertyContext(t *testing.T) { }() } } + +func Test_Creating_new_connection_with_custom_connection(t *testing.T) { + t.Parallel() + + t.Run("connects_to_global_login1_path_and_interface", func(t *testing.T) { + t.Parallel() + + objectConstructorCalled := false + + connectionWithContextCheck := &mockConnection{ + ObjectF: func(dest string, path dbus.ObjectPath) dbus.BusObject { + objectConstructorCalled = true + + expectedDest := "org.freedesktop.login1" + + if dest != expectedDest { + t.Fatalf("Expected D-Bus destination %q, got %q", expectedDest, dest) + } + + expectedPath := dbus.ObjectPath("/org/freedesktop/login1") + + if path != expectedPath { + t.Fatalf("Expected D-Bus path %q, got %q", expectedPath, path) + } + + return nil + }, + } + + if _, err := login1.NewWithConnection(connectionWithContextCheck); err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if !objectConstructorCalled { + t.Fatalf("Expected object constructor to be called") + } + }) + + t.Run("returns_error_when_no_custom_connection_is_given", func(t *testing.T) { + t.Parallel() + + testConn, err := login1.NewWithConnection(nil) + if err == nil { + t.Fatalf("Expected error creating connection with no connector") + } + + if testConn != nil { + t.Fatalf("Expected connection to be nil when New returns error") + } + }) +} + +func Test_Subscribing_to_signals(t *testing.T) { + t.Parallel() + + t.Run("subscribes_to", func(t *testing.T) { + t.Parallel() + + t.Run("login1_interface", func(t *testing.T) { + t.Parallel() + + addMatchCalled := false + + connectionWithInterfaceCheck := &mockConnection{ + AddMatchSignalContextF: func(ctx context.Context, options ...dbus.MatchOption) error { + addMatchCalled = true + if len(options) < 2 { + t.Fatalf("Expected at least 2 match options (interface and member)") + } + return nil + }, + } + + conn, err := login1.NewWithConnection(connectionWithInterfaceCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + if _, err := conn.SubscribeWithContext(t.Context(), "SessionNew"); err != nil { + t.Fatalf("Unexpected error subscribing to signals: %v", err) + } + + if !addMatchCalled { + t.Fatalf("Expected AddMatchSignalContext to be called") + } + }) + + t.Run("for_all_given_members", func(t *testing.T) { + t.Parallel() + + callCount := 0 + + connectionWithMemberCheck := &mockConnection{ + AddMatchSignalContextF: func(ctx context.Context, options ...dbus.MatchOption) error { + callCount++ + return nil + }, + } + + conn, err := login1.NewWithConnection(connectionWithMemberCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + expectedMembers := []string{"SessionNew", "SessionRemoved", "UserNew"} + if _, err := conn.SubscribeWithContext(t.Context(), expectedMembers...); err != nil { + t.Fatalf("Unexpected error subscribing to signals: %v", err) + } + + if callCount != len(expectedMembers) { + t.Fatalf("Expected AddMatchSignalContext to be called %d times, got %d", len(expectedMembers), callCount) + } + }) + }) + + t.Run("passes_received_signals_to_channel", func(t *testing.T) { + t.Parallel() + + signalChannelProvided := false + + connectionWithSignalCheck := &mockConnection{ + SignalF: func(ch chan<- *dbus.Signal) { + signalChannelProvided = ch != nil + // Send a test signal to verify the channel works + go func() { + ch <- &dbus.Signal{ + Sender: "org.freedesktop.login1", + Path: "/org/freedesktop/login1", + Name: "org.freedesktop.login1.Manager.SessionNew", + Body: []any{"session1", dbus.ObjectPath("/org/freedesktop/login1/session/session1")}, + } + }() + }, + } + + conn, err := login1.NewWithConnection(connectionWithSignalCheck) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + ch, err := conn.SubscribeWithContext(t.Context(), "SessionNew") + if err != nil { + t.Fatalf("Unexpected error subscribing to signals: %v", err) + } + + if ch == nil { + t.Fatalf("Expected signal channel to be returned") + } + + if !signalChannelProvided { + t.Fatalf("Expected signal channel to be passed to connection") + } + + // Verify we can receive signals through the channel + ctx, cancel := context.WithTimeout(t.Context(), time.Second*3) + defer cancel() + + select { + case sig := <-ch: + if sig == nil { + t.Fatalf("Received nil signal") + } + if sig.Name != "org.freedesktop.login1.Manager.SessionNew" { + t.Fatalf("Expected signal name %q, got %q", "org.freedesktop.login1.Manager.SessionNew", sig.Name) + } + case <-ctx.Done(): + t.Fatalf("Timeout waiting for signal") + } + }) + + t.Run("returns_error_when_adding_match_signal_fails", func(t *testing.T) { + t.Parallel() + + expectedError := fmt.Errorf("failed to add match") + + connectionWithError := &mockConnection{ + AddMatchSignalContextF: func(ctx context.Context, options ...dbus.MatchOption) error { + return expectedError + }, + } + + conn, err := login1.NewWithConnection(connectionWithError) + if err != nil { + t.Fatalf("Unexpected error creating connection: %v", err) + } + + _, err = conn.SubscribeWithContext(t.Context(), "SessionNew") + if err == nil { + t.Fatalf("Expected error when adding match signal fails") + } + }) +} + +// mockConnection is a test helper for mocking dbus.Conn. +type mockConnection struct { + ObjectF func(string, dbus.ObjectPath) dbus.BusObject + AddMatchSignalContextF func(context.Context, ...dbus.MatchOption) error + SignalF func(chan<- *dbus.Signal) +} + +// AddMatchSignalContext ... +func (m *mockConnection) AddMatchSignalContext(ctx context.Context, options ...dbus.MatchOption) error { + if m.AddMatchSignalContextF != nil { + return m.AddMatchSignalContextF(ctx, options...) + } + return nil +} + +// Auth ... +func (m *mockConnection) Auth(authMethods []dbus.Auth) error { + return nil +} + +// Hello ... +func (m *mockConnection) Hello() error { + return nil +} + +// Signal ... +func (m *mockConnection) Signal(ch chan<- *dbus.Signal) { + if m.SignalF != nil { + m.SignalF(ch) + } +} + +// Object ... +func (m *mockConnection) Object(dest string, path dbus.ObjectPath) dbus.BusObject { + if m.ObjectF == nil { + return nil + } + + return m.ObjectF(dest, path) +} + +// Close ... +func (m *mockConnection) Close() error { + return nil +} + +// Connected ... +func (m *mockConnection) Connected() bool { + return true +}