Skip to content

Commit 244bc00

Browse files
committed
shutdown() should cancel the signal handlers installed by start()
motivation: allow easier testing of shutdown hooks changes: * introduce ServiceLifecycle.removeTrap which removes a trap * call ServiceLifecycle.removeTrap when setting up the shutdown hook * make the shutdown hook cleanup into a lifecycle task to ensure correct ordering * add tests * improve logging rdar://89552798
1 parent e63be9e commit 244bc00

File tree

5 files changed

+96
-17
lines changed

5 files changed

+96
-17
lines changed

Sources/Lifecycle/Lifecycle.swift

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,22 @@ public protocol LifecycleTask {
3030
var shutdownIfNotStarted: Bool { get }
3131
func start(_ callback: @escaping (Error?) -> Void)
3232
func shutdown(_ callback: @escaping (Error?) -> Void)
33+
var logStart: Bool { get }
34+
var logShutdown: Bool { get }
3335
}
3436

3537
extension LifecycleTask {
3638
public var shutdownIfNotStarted: Bool {
3739
return false
3840
}
41+
42+
public var logStart: Bool {
43+
return true
44+
}
45+
46+
public var logShutdown: Bool {
47+
return true
48+
}
3949
}
4050

4151
// MARK: - LifecycleHandler
@@ -317,9 +327,14 @@ public struct ServiceLifecycle {
317327
self.log("intercepted signal: \(signal)")
318328
self.shutdown()
319329
}, cancelAfterTrap: true)
320-
self.underlying.shutdownGroup.notify(queue: .global()) {
321-
signalSource.cancel()
322-
}
330+
// register cleanup as the last task
331+
self.registerShutdown(label: "\(signal) shutdown hook cleanup", .sync {
332+
// cancel if not already canceled by the trap
333+
if !signalSource.isCancelled {
334+
signalSource.cancel()
335+
ServiceLifecycle.removeTrap(signal: signal)
336+
}
337+
})
323338
}
324339
}
325340

@@ -343,22 +358,34 @@ extension ServiceLifecycle {
343358
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
344359
// on linux, we can call singal() once per process
345360
self.trappedLock.withLockVoid {
346-
if !trapped.contains(sig.rawValue) {
361+
if !self.trapped.contains(sig.rawValue) {
347362
signal(sig.rawValue, SIG_IGN)
348-
trapped.insert(sig.rawValue)
363+
self.trapped.insert(sig.rawValue)
349364
}
350365
}
351366
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
352367
signalSource.setEventHandler {
368+
// run handler first
369+
handler(sig)
370+
// then cancel trap if so requested
353371
if cancelAfterTrap {
354372
signalSource.cancel()
373+
self.removeTrap(signal: sig)
355374
}
356-
handler(sig)
357375
}
358376
signalSource.resume()
359377
return signalSource
360378
}
361379

380+
public static func removeTrap(signal sig: Signal) {
381+
self.trappedLock.withLockVoid {
382+
if self.trapped.contains(sig.rawValue) {
383+
signal(sig.rawValue, SIG_DFL)
384+
self.trapped.remove(sig.rawValue)
385+
}
386+
}
387+
}
388+
362389
/// A system signal
363390
public struct Signal: Equatable, CustomStringConvertible {
364391
internal var rawValue: CInt
@@ -433,7 +460,7 @@ struct ShutdownError: Error {
433460
public class ComponentLifecycle: LifecycleTask {
434461
public let label: String
435462
fileprivate let logger: Logger
436-
internal let shutdownGroup = DispatchGroup()
463+
fileprivate let shutdownGroup = DispatchGroup()
437464

438465
private var state = State.idle([])
439466
private let stateLock = Lock()
@@ -596,13 +623,15 @@ public class ComponentLifecycle: LifecycleTask {
596623

597624
private func startTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, callback: @escaping ([LifecycleTask], Error?) -> Void) {
598625
// async barrier
599-
let start = { (callback) -> Void in queue.async { tasks[index].start(callback) } }
600-
let callback = { (index, error) -> Void in queue.async { callback(index, error) } }
626+
let start = { callback in queue.async { tasks[index].start(callback) } }
627+
let callback = { index, error in queue.async { callback(index, error) } }
601628

602629
if index >= tasks.count {
603630
return callback(tasks, nil)
604631
}
605-
self.logger.info("starting tasks [\(tasks[index].label)]")
632+
if tasks[index].logStart {
633+
self.logger.info("starting tasks [\(tasks[index].label)]")
634+
}
606635
let startTime = DispatchTime.now()
607636
start { error in
608637
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.start").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -642,14 +671,16 @@ public class ComponentLifecycle: LifecycleTask {
642671

643672
private func shutdownTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, errors: [String: Error]?, callback: @escaping ([String: Error]?) -> Void) {
644673
// async barrier
645-
let shutdown = { (callback) -> Void in queue.async { tasks[index].shutdown(callback) } }
646-
let callback = { (errors) -> Void in queue.async { callback(errors) } }
674+
let shutdown = { callback in queue.async { tasks[index].shutdown(callback) } }
675+
let callback = { errors in queue.async { callback(errors) } }
647676

648677
if index >= tasks.count {
649678
return callback(errors)
650679
}
651680

652-
self.logger.info("stopping tasks [\(tasks[index].label)]")
681+
if tasks[index].logShutdown {
682+
self.logger.info("stopping tasks [\(tasks[index].label)]")
683+
}
653684
let startTime = DispatchTime.now()
654685
shutdown { error in
655686
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.shutdown").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -739,12 +770,16 @@ internal struct _LifecycleTask: LifecycleTask {
739770
let shutdownIfNotStarted: Bool
740771
let start: LifecycleHandler
741772
let shutdown: LifecycleHandler
773+
let logStart: Bool
774+
let logShutdown: Bool
742775

743776
init(label: String, shutdownIfNotStarted: Bool? = nil, start: LifecycleHandler, shutdown: LifecycleHandler) {
744777
self.label = label
745778
self.shutdownIfNotStarted = shutdownIfNotStarted ?? start.noop
746779
self.start = start
747780
self.shutdown = shutdown
781+
self.logStart = !start.noop
782+
self.logShutdown = !shutdown.noop
748783
}
749784

750785
func start(_ callback: @escaping (Error?) -> Void) {

Sources/Lifecycle/Locks.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ extension Lock {
8181
/// - Parameter body: The block to execute while holding the lock.
8282
/// - Returns: The value returned by the block.
8383
@inlinable
84-
internal func withLock<T>(_ body: () throws -> T) rethrows -> T {
84+
func withLock<T>(_ body: () throws -> T) rethrows -> T {
8585
self.lock()
8686
defer {
8787
self.unlock()
@@ -91,7 +91,7 @@ extension Lock {
9191

9292
// specialise Void return (for performance)
9393
@inlinable
94-
internal func withLockVoid(_ body: () throws -> Void) rethrows {
94+
func withLockVoid(_ body: () throws -> Void) rethrows {
9595
try self.withLock(body)
9696
}
9797
}

Tests/LifecycleTests/ComponentLifecycleTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ final class ComponentLifecycleTests: XCTestCase {
5353
dispatchPrecondition(condition: .onQueue(.global()))
5454
XCTAssertTrue(startCalls.contains(id))
5555
stopCalls.append(id)
56-
})
56+
})
5757
}
5858
lifecycle.register(items)
5959

@@ -92,7 +92,7 @@ final class ComponentLifecycleTests: XCTestCase {
9292
dispatchPrecondition(condition: .onQueue(testQueue))
9393
XCTAssertTrue(startCalls.contains(id))
9494
stopCalls.append(id)
95-
})
95+
})
9696
}
9797
lifecycle.register(items)
9898

Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ extension ServiceLifecycleTests {
3535
("testSignalDescription", testSignalDescription),
3636
("testBacktracesInstalledOnce", testBacktracesInstalledOnce),
3737
("testRepeatShutdown", testRepeatShutdown),
38+
("testShutdownCancelSignal", testShutdownCancelSignal),
3839
]
3940
}
4041
}

Tests/LifecycleTests/ServiceLifecycleTests.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,47 @@ final class ServiceLifecycleTests: XCTestCase {
271271

272272
XCTAssertEqual(attempts, count)
273273
}
274+
275+
func testShutdownCancelSignal() {
276+
if ProcessInfo.processInfo.environment["SKIP_SIGNAL_TEST"].flatMap(Bool.init) ?? false {
277+
print("skipping testRepeatShutdown")
278+
return
279+
}
280+
281+
struct Service {
282+
static let signal = ServiceLifecycle.Signal.ALRM
283+
284+
let lifecycle: ServiceLifecycle
285+
286+
init() {
287+
self.lifecycle = ServiceLifecycle(configuration: .init(shutdownSignal: [Service.signal]))
288+
self.lifecycle.register(GoodItem())
289+
}
290+
}
291+
292+
let service = Service()
293+
service.lifecycle.start { error in
294+
XCTAssertNil(error, "not expecting error")
295+
kill(getpid(), Service.signal.rawValue)
296+
}
297+
service.lifecycle.wait()
298+
299+
var count = 0
300+
let sync = DispatchGroup()
301+
sync.enter()
302+
let signalSource = ServiceLifecycle.trap(signal: Service.signal, handler: { _ in
303+
count = count + 1 // not thread safe but fine for this purpose
304+
sync.leave()
305+
}, cancelAfterTrap: false)
306+
307+
// since we are removing the hook added by lifecycle on shutdown,
308+
// this will fail unless a new hook is set up as done above
309+
kill(getpid(), Service.signal.rawValue)
310+
311+
XCTAssertEqual(.success, sync.wait(timeout: .now() + 2))
312+
XCTAssertEqual(count, 1)
313+
314+
signalSource.cancel()
315+
ServiceLifecycle.removeTrap(signal: Service.signal)
316+
}
274317
}

0 commit comments

Comments
 (0)