Skip to content

Commit b3aa1e3

Browse files
committed
login1: Add RebootWithContext method
Existing Reboot() method does not allow using context not inspecting D-Bus call errors, which makes it difficult to debug and use. This commit adds new RebootWithContext() method which addresses those shortcomings. Closes #387 Signed-off-by: Mateusz Gozdek <[email protected]>
1 parent 04c09ee commit b3aa1e3

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed

login1/dbus.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package login1
1717

1818
import (
19+
"context"
1920
"fmt"
2021
"os"
2122
"strconv"
@@ -59,6 +60,7 @@ type connectionManager interface {
5960
type Caller interface {
6061
// TODO: This method should eventually be removed, as it provides no context support.
6162
Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call
63+
CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call
6264
}
6365

6466
// New establishes a connection to the system bus and authenticates.
@@ -347,6 +349,15 @@ func (c *Conn) Reboot(askForAuth bool) {
347349
c.object.Call(dbusInterface+".Reboot", 0, askForAuth)
348350
}
349351

352+
// Reboot asks logind for a reboot using given context, optionally asking for auth.
353+
func (c *Conn) RebootWithContext(ctx context.Context, askForAuth bool) error {
354+
if call := c.object.CallWithContext(ctx, dbusInterface+".Reboot", 0, askForAuth); call.Err != nil {
355+
return fmt.Errorf("calling reboot: %w", call.Err)
356+
}
357+
358+
return nil
359+
}
360+
350361
// Inhibit takes inhibition lock in logind.
351362
func (c *Conn) Inhibit(what, who, why, mode string) (*os.File, error) {
352363
var fd dbus.UnixFD

login1/dbus_test.go

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package login1_test
1616

1717
import (
18+
"context"
19+
"errors"
1820
"fmt"
1921
"os/user"
2022
"regexp"
@@ -142,6 +144,168 @@ func Test_Creating_new_connection_with_custom_connection(t *testing.T) {
142144
})
143145
}
144146

147+
//nolint:funlen // Many subtests.
148+
func Test_Rebooting_with_context(t *testing.T) {
149+
t.Parallel()
150+
151+
t.Run("calls_login1_reboot_method_on_manager_interface", func(t *testing.T) {
152+
t.Parallel()
153+
154+
rebootCalled := false
155+
156+
askForReboot := false
157+
158+
connectionWithContextCheck := &mockConnection{
159+
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
160+
return &mockObject{
161+
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
162+
rebootCalled = true
163+
164+
expectedMethodName := "org.freedesktop.login1.Manager.Reboot"
165+
166+
if method != expectedMethodName {
167+
t.Fatalf("Expected method %q being called, got %q", expectedMethodName, method)
168+
}
169+
170+
if len(args) != 1 {
171+
t.Fatalf("Expected one argument to call, got %q", args)
172+
}
173+
174+
askedForReboot, ok := args[0].(bool)
175+
if !ok {
176+
t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0])
177+
}
178+
179+
if askForReboot != askedForReboot {
180+
t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot)
181+
}
182+
183+
return &dbus.Call{}
184+
},
185+
}
186+
},
187+
}
188+
189+
testConn, err := login1.NewWithConnection(connectionWithContextCheck)
190+
if err != nil {
191+
t.Fatalf("Unexpected error creating connection: %v", err)
192+
}
193+
194+
if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil {
195+
t.Fatalf("Unexpected error rebooting: %v", err)
196+
}
197+
198+
if !rebootCalled {
199+
t.Fatalf("Expected reboot method call on given D-Bus connection")
200+
}
201+
})
202+
203+
t.Run("asks_for_auth_when_requested", func(t *testing.T) {
204+
t.Parallel()
205+
206+
rebootCalled := false
207+
208+
askForReboot := true
209+
210+
connectionWithContextCheck := &mockConnection{
211+
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
212+
return &mockObject{
213+
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
214+
rebootCalled = true
215+
216+
if len(args) != 1 {
217+
t.Fatalf("Expected one argument to call, got %q", args)
218+
}
219+
220+
askedForReboot, ok := args[0].(bool)
221+
if !ok {
222+
t.Fatalf("Expected first argument to be of type %T, got %T", askForReboot, args[0])
223+
}
224+
225+
if askForReboot != askedForReboot {
226+
t.Fatalf("Expected argument to be %t, got %t", askForReboot, askedForReboot)
227+
}
228+
229+
return &dbus.Call{}
230+
},
231+
}
232+
},
233+
}
234+
235+
testConn, err := login1.NewWithConnection(connectionWithContextCheck)
236+
if err != nil {
237+
t.Fatalf("Unexpected error creating connection: %v", err)
238+
}
239+
240+
if err := testConn.RebootWithContext(context.Background(), askForReboot); err != nil {
241+
t.Fatalf("Unexpected error rebooting: %v", err)
242+
}
243+
244+
if !rebootCalled {
245+
t.Fatalf("Expected reboot method call on given D-Bus connection")
246+
}
247+
})
248+
249+
t.Run("use_given_context_for_D-Bus_call", func(t *testing.T) {
250+
t.Parallel()
251+
252+
testKey := struct{}{}
253+
expectedValue := "bar"
254+
255+
ctx := context.WithValue(context.Background(), testKey, expectedValue)
256+
257+
connectionWithContextCheck := &mockConnection{
258+
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
259+
return &mockObject{
260+
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
261+
if val := ctx.Value(testKey); val != expectedValue {
262+
t.Fatalf("Got unexpected context on call")
263+
}
264+
265+
return &dbus.Call{}
266+
},
267+
}
268+
},
269+
}
270+
271+
testConn, err := login1.NewWithConnection(connectionWithContextCheck)
272+
if err != nil {
273+
t.Fatalf("Unexpected error creating connection: %v", err)
274+
}
275+
276+
if err := testConn.RebootWithContext(ctx, false); err != nil {
277+
t.Fatalf("Unexpected error rebooting: %v", err)
278+
}
279+
})
280+
281+
t.Run("returns_error_when_D-Bus_call_fails", func(t *testing.T) {
282+
t.Parallel()
283+
284+
expectedError := fmt.Errorf("reboot error")
285+
286+
connectionWithFailingObjectCall := &mockConnection{
287+
ObjectF: func(string, dbus.ObjectPath) dbus.BusObject {
288+
return &mockObject{
289+
CallWithContextF: func(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
290+
return &dbus.Call{
291+
Err: expectedError,
292+
}
293+
},
294+
}
295+
},
296+
}
297+
298+
testConn, err := login1.NewWithConnection(connectionWithFailingObjectCall)
299+
if err != nil {
300+
t.Fatalf("Unexpected error creating connection: %v", err)
301+
}
302+
303+
if err := testConn.RebootWithContext(context.Background(), false); !errors.Is(err, expectedError) {
304+
t.Fatalf("Unexpected error rebooting: %v", err)
305+
}
306+
})
307+
}
308+
145309
// mockConnection is a test helper for mocking dbus.Conn.
146310
type mockConnection struct {
147311
ObjectF func(string, dbus.ObjectPath) dbus.BusObject
@@ -178,3 +342,70 @@ func (m *mockConnection) Close() error {
178342
func (m *mockConnection) BusObject() dbus.BusObject {
179343
return nil
180344
}
345+
346+
// mockObject is a mock of dbus.BusObject.
347+
type mockObject struct {
348+
CallWithContextF func(context.Context, string, dbus.Flags, ...interface{}) *dbus.Call
349+
CallF func(string, dbus.Flags, ...interface{}) *dbus.Call
350+
}
351+
352+
// mockObject must implement dbus.BusObject to be usable for other packages in tests, though not
353+
// all methods must actually be mockable. See https://github.com/dbus/dbus/issues/252 for details.
354+
var _ dbus.BusObject = &mockObject{}
355+
356+
// CallWithContext ...
357+
//
358+
//nolint:lll // Upstream signature, can't do much with that.
359+
func (m *mockObject) CallWithContext(ctx context.Context, method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
360+
if m.CallWithContextF == nil {
361+
return &dbus.Call{}
362+
}
363+
364+
return m.CallWithContextF(ctx, method, flags, args...)
365+
}
366+
367+
// Call ...
368+
func (m *mockObject) Call(method string, flags dbus.Flags, args ...interface{}) *dbus.Call {
369+
if m.CallF == nil {
370+
return &dbus.Call{}
371+
}
372+
373+
return m.CallF(method, flags, args...)
374+
}
375+
376+
// Go ...
377+
func (m *mockObject) Go(method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call {
378+
return &dbus.Call{}
379+
}
380+
381+
// GoWithContext ...
382+
//
383+
//nolint:lll // Upstream signature, can't do much with that.
384+
func (m *mockObject) GoWithContext(ctx context.Context, method string, flags dbus.Flags, ch chan *dbus.Call, args ...interface{}) *dbus.Call {
385+
return &dbus.Call{}
386+
}
387+
388+
// AddMatchSignal ...
389+
func (m *mockObject) AddMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call {
390+
return &dbus.Call{}
391+
}
392+
393+
// RemoveMatchSignal ...
394+
func (m *mockObject) RemoveMatchSignal(iface, member string, options ...dbus.MatchOption) *dbus.Call {
395+
return &dbus.Call{}
396+
}
397+
398+
// GetProperty ...
399+
func (m *mockObject) GetProperty(p string) (dbus.Variant, error) { return dbus.Variant{}, nil }
400+
401+
// StoreProperty ...
402+
func (m *mockObject) StoreProperty(p string, value interface{}) error { return nil }
403+
404+
// SetProperty ...
405+
func (m *mockObject) SetProperty(p string, v interface{}) error { return nil }
406+
407+
// Destination ...
408+
func (m *mockObject) Destination() string { return "" }
409+
410+
// Path ...
411+
func (m *mockObject) Path() dbus.ObjectPath { return "" }

0 commit comments

Comments
 (0)