Skip to content

Commit f070ede

Browse files
committed
shutdown() should cancel the signal handlers installed by start()
motivation: allow easier testing of shutdown hooks changes: * introduce ServiceLifecycle.removeTrap which removes a trap * call ServiceLifecycle.removeTrap when setting up the shutdown hook * make the shutdown hook cleanup into a lifecycle task to ensure correct ordering * add tests * improve logging rdar://89552798
1 parent e63be9e commit f070ede

File tree

6 files changed

+102
-20
lines changed

6 files changed

+102
-20
lines changed

Sources/Lifecycle/Lifecycle.swift

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,22 @@ public protocol LifecycleTask {
3030
var shutdownIfNotStarted: Bool { get }
3131
func start(_ callback: @escaping (Error?) -> Void)
3232
func shutdown(_ callback: @escaping (Error?) -> Void)
33+
var logStart: Bool { get }
34+
var logShutdown: Bool { get }
3335
}
3436

3537
extension LifecycleTask {
3638
public var shutdownIfNotStarted: Bool {
3739
return false
3840
}
41+
42+
public var logStart: Bool {
43+
return true
44+
}
45+
46+
public var logShutdown: Bool {
47+
return true
48+
}
3949
}
4050

4151
// MARK: - LifecycleHandler
@@ -317,9 +327,14 @@ public struct ServiceLifecycle {
317327
self.log("intercepted signal: \(signal)")
318328
self.shutdown()
319329
}, cancelAfterTrap: true)
320-
self.underlying.shutdownGroup.notify(queue: .global()) {
321-
signalSource.cancel()
322-
}
330+
// register cleanup as the last task
331+
self.registerShutdown(label: "\(signal) shutdown hook cleanup", .sync {
332+
// cancel if not already canceled by the trap
333+
if !signalSource.isCancelled {
334+
signalSource.cancel()
335+
ServiceLifecycle.removeTrap(signal: signal)
336+
}
337+
})
323338
}
324339
}
325340

@@ -343,22 +358,34 @@ extension ServiceLifecycle {
343358
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
344359
// on linux, we can call singal() once per process
345360
self.trappedLock.withLockVoid {
346-
if !trapped.contains(sig.rawValue) {
361+
if !self.trapped.contains(sig.rawValue) {
347362
signal(sig.rawValue, SIG_IGN)
348-
trapped.insert(sig.rawValue)
363+
self.trapped.insert(sig.rawValue)
349364
}
350365
}
351366
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
352367
signalSource.setEventHandler {
368+
// run handler first
369+
handler(sig)
370+
// then cancel trap if so requested
353371
if cancelAfterTrap {
354372
signalSource.cancel()
373+
self.removeTrap(signal: sig)
355374
}
356-
handler(sig)
357375
}
358376
signalSource.resume()
359377
return signalSource
360378
}
361379

380+
public static func removeTrap(signal sig: Signal) {
381+
self.trappedLock.withLockVoid {
382+
if self.trapped.contains(sig.rawValue) {
383+
signal(sig.rawValue, SIG_DFL)
384+
self.trapped.remove(sig.rawValue)
385+
}
386+
}
387+
}
388+
362389
/// A system signal
363390
public struct Signal: Equatable, CustomStringConvertible {
364391
internal var rawValue: CInt
@@ -413,7 +440,8 @@ extension ServiceLifecycle {
413440
logger: Logger? = nil,
414441
callbackQueue: DispatchQueue = .global(),
415442
shutdownSignal: [Signal]? = [.TERM, .INT],
416-
installBacktrace: Bool = true) {
443+
installBacktrace: Bool = true)
444+
{
417445
self.label = label
418446
self.logger = logger ?? Logger(label: label)
419447
self.callbackQueue = callbackQueue
@@ -433,7 +461,7 @@ struct ShutdownError: Error {
433461
public class ComponentLifecycle: LifecycleTask {
434462
public let label: String
435463
fileprivate let logger: Logger
436-
internal let shutdownGroup = DispatchGroup()
464+
fileprivate let shutdownGroup = DispatchGroup()
437465

438466
private var state = State.idle([])
439467
private let stateLock = Lock()
@@ -596,13 +624,15 @@ public class ComponentLifecycle: LifecycleTask {
596624

597625
private func startTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, callback: @escaping ([LifecycleTask], Error?) -> Void) {
598626
// async barrier
599-
let start = { (callback) -> Void in queue.async { tasks[index].start(callback) } }
600-
let callback = { (index, error) -> Void in queue.async { callback(index, error) } }
627+
let start = { callback in queue.async { tasks[index].start(callback) } }
628+
let callback = { index, error in queue.async { callback(index, error) } }
601629

602630
if index >= tasks.count {
603631
return callback(tasks, nil)
604632
}
605-
self.logger.info("starting tasks [\(tasks[index].label)]")
633+
if tasks[index].logStart {
634+
self.logger.info("starting tasks [\(tasks[index].label)]")
635+
}
606636
let startTime = DispatchTime.now()
607637
start { error in
608638
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.start").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -642,14 +672,16 @@ public class ComponentLifecycle: LifecycleTask {
642672

643673
private func shutdownTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, errors: [String: Error]?, callback: @escaping ([String: Error]?) -> Void) {
644674
// async barrier
645-
let shutdown = { (callback) -> Void in queue.async { tasks[index].shutdown(callback) } }
646-
let callback = { (errors) -> Void in queue.async { callback(errors) } }
675+
let shutdown = { callback in queue.async { tasks[index].shutdown(callback) } }
676+
let callback = { errors in queue.async { callback(errors) } }
647677

648678
if index >= tasks.count {
649679
return callback(errors)
650680
}
651681

652-
self.logger.info("stopping tasks [\(tasks[index].label)]")
682+
if tasks[index].logShutdown {
683+
self.logger.info("stopping tasks [\(tasks[index].label)]")
684+
}
653685
let startTime = DispatchTime.now()
654686
shutdown { error in
655687
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.shutdown").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -739,12 +771,16 @@ internal struct _LifecycleTask: LifecycleTask {
739771
let shutdownIfNotStarted: Bool
740772
let start: LifecycleHandler
741773
let shutdown: LifecycleHandler
774+
let logStart: Bool
775+
let logShutdown: Bool
742776

743777
init(label: String, shutdownIfNotStarted: Bool? = nil, start: LifecycleHandler, shutdown: LifecycleHandler) {
744778
self.label = label
745779
self.shutdownIfNotStarted = shutdownIfNotStarted ?? start.noop
746780
self.start = start
747781
self.shutdown = shutdown
782+
self.logStart = !start.noop
783+
self.logShutdown = !shutdown.noop
748784
}
749785

750786
func start(_ callback: @escaping (Error?) -> Void) {

Sources/Lifecycle/Locks.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ extension Lock {
8181
/// - Parameter body: The block to execute while holding the lock.
8282
/// - Returns: The value returned by the block.
8383
@inlinable
84-
internal func withLock<T>(_ body: () throws -> T) rethrows -> T {
84+
func withLock<T>(_ body: () throws -> T) rethrows -> T {
8585
self.lock()
8686
defer {
8787
self.unlock()
@@ -91,7 +91,7 @@ extension Lock {
9191

9292
// specialise Void return (for performance)
9393
@inlinable
94-
internal func withLockVoid(_ body: () throws -> Void) rethrows {
94+
func withLockVoid(_ body: () throws -> Void) rethrows {
9595
try self.withLock(body)
9696
}
9797
}

Tests/LifecycleTests/ComponentLifecycleTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ final class ComponentLifecycleTests: XCTestCase {
5353
dispatchPrecondition(condition: .onQueue(.global()))
5454
XCTAssertTrue(startCalls.contains(id))
5555
stopCalls.append(id)
56-
})
56+
})
5757
}
5858
lifecycle.register(items)
5959

@@ -92,7 +92,7 @@ final class ComponentLifecycleTests: XCTestCase {
9292
dispatchPrecondition(condition: .onQueue(testQueue))
9393
XCTAssertTrue(startCalls.contains(id))
9494
stopCalls.append(id)
95-
})
95+
})
9696
}
9797
lifecycle.register(items)
9898

Tests/LifecycleTests/Helpers.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class GoodItem: LifecycleTask {
2929

3030
init(id: String = UUID().uuidString,
3131
startDelay: Double = Double.random(in: 0.01 ... 0.1),
32-
shutdownDelay: Double = Double.random(in: 0.01 ... 0.1)) {
32+
shutdownDelay: Double = Double.random(in: 0.01 ... 0.1))
33+
{
3334
self.id = id
3435
self.startDelay = startDelay
3536
self.shutdownDelay = shutdownDelay
@@ -72,7 +73,8 @@ class NIOItem {
7273
init(eventLoopGroup: EventLoopGroup,
7374
id: String = UUID().uuidString,
7475
startDelay: Int64 = Int64.random(in: 10 ... 20),
75-
shutdownDelay: Int64 = Int64.random(in: 10 ... 20)) {
76+
shutdownDelay: Int64 = Int64.random(in: 10 ... 20))
77+
{
7678
self.id = id
7779
self.eventLoopGroup = eventLoopGroup
7880
self.startDelay = startDelay

Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ extension ServiceLifecycleTests {
3535
("testSignalDescription", testSignalDescription),
3636
("testBacktracesInstalledOnce", testBacktracesInstalledOnce),
3737
("testRepeatShutdown", testRepeatShutdown),
38+
("testShutdownCancelSignal", testShutdownCancelSignal),
3839
]
3940
}
4041
}

Tests/LifecycleTests/ServiceLifecycleTests.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,47 @@ final class ServiceLifecycleTests: XCTestCase {
271271

272272
XCTAssertEqual(attempts, count)
273273
}
274+
275+
func testShutdownCancelSignal() {
276+
if ProcessInfo.processInfo.environment["SKIP_SIGNAL_TEST"].flatMap(Bool.init) ?? false {
277+
print("skipping testRepeatShutdown")
278+
return
279+
}
280+
281+
struct Service {
282+
static let signal = ServiceLifecycle.Signal.ALRM
283+
284+
let lifecycle: ServiceLifecycle
285+
286+
init() {
287+
self.lifecycle = ServiceLifecycle(configuration: .init(shutdownSignal: [Service.signal]))
288+
self.lifecycle.register(GoodItem())
289+
}
290+
}
291+
292+
let service = Service()
293+
service.lifecycle.start { error in
294+
XCTAssertNil(error, "not expecting error")
295+
kill(getpid(), Service.signal.rawValue)
296+
}
297+
service.lifecycle.wait()
298+
299+
var count = 0
300+
let sync = DispatchGroup()
301+
sync.enter()
302+
let signalSource = ServiceLifecycle.trap(signal: Service.signal, handler: { _ in
303+
count = count + 1 // not thread safe but fine for this purpose
304+
sync.leave()
305+
}, cancelAfterTrap: false)
306+
307+
// since we are removing the hook added by lifecycle on shutdown,
308+
// this will fail unless a new hook is set up as done above
309+
kill(getpid(), Service.signal.rawValue)
310+
311+
XCTAssertEqual(.success, sync.wait(timeout: .now() + 2))
312+
XCTAssertEqual(count, 1)
313+
314+
signalSource.cancel()
315+
ServiceLifecycle.removeTrap(signal: Service.signal)
316+
}
274317
}

0 commit comments

Comments
 (0)