Skip to content

Commit 16a427e

Browse files
committed
shutdown() should cancel the signal handlers installed by start()
motivation: allow easier testing of shutdown hooks changes: * introduce ServiceLifecycle.removeTrap which removes a trap * call ServiceLifecycle.removeTrap when setting up the shutdown hook * make the shutdown hook cleanup into a lifecycle task to ensure correct ordering * add tests * improve logging rdar://89552798
1 parent e63be9e commit 16a427e

File tree

7 files changed

+133
-51
lines changed

7 files changed

+133
-51
lines changed

Sources/Lifecycle/Lifecycle.swift

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,22 @@ public protocol LifecycleTask {
3030
var shutdownIfNotStarted: Bool { get }
3131
func start(_ callback: @escaping (Error?) -> Void)
3232
func shutdown(_ callback: @escaping (Error?) -> Void)
33+
var logStart: Bool { get }
34+
var logShutdown: Bool { get }
3335
}
3436

35-
extension LifecycleTask {
36-
public var shutdownIfNotStarted: Bool {
37+
public extension LifecycleTask {
38+
var shutdownIfNotStarted: Bool {
3739
return false
3840
}
41+
42+
var logStart: Bool {
43+
return true
44+
}
45+
46+
var logShutdown: Bool {
47+
return true
48+
}
3949
}
4050

4151
// MARK: - LifecycleHandler
@@ -97,8 +107,8 @@ public struct LifecycleHandler {
97107

98108
#if canImport(_Concurrency) && compiler(>=5.5.2)
99109
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
100-
extension LifecycleHandler {
101-
public init(_ handler: @escaping () async throws -> Void) {
110+
public extension LifecycleHandler {
111+
init(_ handler: @escaping () async throws -> Void) {
102112
self = LifecycleHandler { callback in
103113
Task {
104114
do {
@@ -111,7 +121,7 @@ extension LifecycleHandler {
111121
}
112122
}
113123

114-
public static func async(_ handler: @escaping () async throws -> Void) -> LifecycleHandler {
124+
static func async(_ handler: @escaping () async throws -> Void) -> LifecycleHandler {
115125
return LifecycleHandler(handler)
116126
}
117127
}
@@ -161,8 +171,8 @@ public struct LifecycleStartHandler<State> {
161171

162172
#if canImport(_Concurrency) && compiler(>=5.5.2)
163173
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
164-
extension LifecycleStartHandler {
165-
public init(_ handler: @escaping () async throws -> State) {
174+
public extension LifecycleStartHandler {
175+
init(_ handler: @escaping () async throws -> State) {
166176
self = LifecycleStartHandler { callback in
167177
Task {
168178
do {
@@ -175,7 +185,7 @@ extension LifecycleStartHandler {
175185
}
176186
}
177187

178-
public static func async(_ handler: @escaping () async throws -> State) -> LifecycleStartHandler {
188+
static func async(_ handler: @escaping () async throws -> State) -> LifecycleStartHandler {
179189
return LifecycleStartHandler(handler)
180190
}
181191
}
@@ -223,8 +233,8 @@ public struct LifecycleShutdownHandler<State> {
223233

224234
#if canImport(_Concurrency) && compiler(>=5.5.2)
225235
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
226-
extension LifecycleShutdownHandler {
227-
public init(_ handler: @escaping (State) async throws -> Void) {
236+
public extension LifecycleShutdownHandler {
237+
init(_ handler: @escaping (State) async throws -> Void) {
228238
self = LifecycleShutdownHandler { state, callback in
229239
Task {
230240
do {
@@ -237,7 +247,7 @@ extension LifecycleShutdownHandler {
237247
}
238248
}
239249

240-
public static func async(_ handler: @escaping (State) async throws -> Void) -> LifecycleShutdownHandler {
250+
static func async(_ handler: @escaping (State) async throws -> Void) -> LifecycleShutdownHandler {
241251
return LifecycleShutdownHandler(handler)
242252
}
243253
}
@@ -317,9 +327,14 @@ public struct ServiceLifecycle {
317327
self.log("intercepted signal: \(signal)")
318328
self.shutdown()
319329
}, cancelAfterTrap: true)
320-
self.underlying.shutdownGroup.notify(queue: .global()) {
321-
signalSource.cancel()
322-
}
330+
// register cleanup as the last task
331+
self.registerShutdown(label: "\(signal) shutdown hook cleanup", .sync {
332+
// cancel if not already canceled by the trap
333+
if !signalSource.isCancelled {
334+
signalSource.cancel()
335+
ServiceLifecycle.removeTrap(signal: signal)
336+
}
337+
})
323338
}
324339
}
325340

@@ -328,7 +343,7 @@ public struct ServiceLifecycle {
328343
}
329344
}
330345

331-
extension ServiceLifecycle {
346+
public extension ServiceLifecycle {
332347
private static var trapped: Set<Int32> = []
333348
private static let trappedLock = Lock()
334349

@@ -340,27 +355,39 @@ extension ServiceLifecycle {
340355
/// - on: DispatchQueue to run the signal handler on (default global dispatch queue)
341356
/// - cancelAfterTrap: Defaults to false, which means the signal handler can be run multiple times. If true, the DispatchSignalSource will be cancelled after being trapped once.
342357
/// - returns: a `DispatchSourceSignal` for the given trap. The source must be cancelled by the caller.
343-
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
358+
static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
344359
// on linux, we can call singal() once per process
345360
self.trappedLock.withLockVoid {
346-
if !trapped.contains(sig.rawValue) {
361+
if !self.trapped.contains(sig.rawValue) {
347362
signal(sig.rawValue, SIG_IGN)
348-
trapped.insert(sig.rawValue)
363+
self.trapped.insert(sig.rawValue)
349364
}
350365
}
351366
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
352367
signalSource.setEventHandler {
368+
// run handler first
369+
handler(sig)
370+
// then cancel trap if so requested
353371
if cancelAfterTrap {
354372
signalSource.cancel()
373+
self.removeTrap(signal: sig)
355374
}
356-
handler(sig)
357375
}
358376
signalSource.resume()
359377
return signalSource
360378
}
361379

380+
static func removeTrap(signal sig: Signal) {
381+
self.trappedLock.withLockVoid {
382+
if self.trapped.contains(sig.rawValue) {
383+
signal(sig.rawValue, SIG_DFL)
384+
self.trapped.remove(sig.rawValue)
385+
}
386+
}
387+
}
388+
362389
/// A system signal
363-
public struct Signal: Equatable, CustomStringConvertible {
390+
struct Signal: Equatable, CustomStringConvertible {
364391
internal var rawValue: CInt
365392

366393
public static let TERM = Signal(rawValue: SIGTERM)
@@ -395,9 +422,9 @@ extension ServiceLifecycle: LifecycleTasksContainer {
395422
}
396423
}
397424

398-
extension ServiceLifecycle {
425+
public extension ServiceLifecycle {
399426
/// `ServiceLifecycle` configuration options.
400-
public struct Configuration {
427+
struct Configuration {
401428
/// Defines the `label` for the lifeycle and its Logger
402429
public var label: String
403430
/// Defines the `Logger` to log with.
@@ -413,7 +440,8 @@ extension ServiceLifecycle {
413440
logger: Logger? = nil,
414441
callbackQueue: DispatchQueue = .global(),
415442
shutdownSignal: [Signal]? = [.TERM, .INT],
416-
installBacktrace: Bool = true) {
443+
installBacktrace: Bool = true)
444+
{
417445
self.label = label
418446
self.logger = logger ?? Logger(label: label)
419447
self.callbackQueue = callbackQueue
@@ -433,7 +461,7 @@ struct ShutdownError: Error {
433461
public class ComponentLifecycle: LifecycleTask {
434462
public let label: String
435463
fileprivate let logger: Logger
436-
internal let shutdownGroup = DispatchGroup()
464+
fileprivate let shutdownGroup = DispatchGroup()
437465

438466
private var state = State.idle([])
439467
private let stateLock = Lock()
@@ -596,13 +624,15 @@ public class ComponentLifecycle: LifecycleTask {
596624

597625
private func startTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, callback: @escaping ([LifecycleTask], Error?) -> Void) {
598626
// async barrier
599-
let start = { (callback) -> Void in queue.async { tasks[index].start(callback) } }
600-
let callback = { (index, error) -> Void in queue.async { callback(index, error) } }
627+
let start = { callback in queue.async { tasks[index].start(callback) } }
628+
let callback = { index, error in queue.async { callback(index, error) } }
601629

602630
if index >= tasks.count {
603631
return callback(tasks, nil)
604632
}
605-
self.logger.info("starting tasks [\(tasks[index].label)]")
633+
if tasks[index].logStart {
634+
self.logger.info("starting tasks [\(tasks[index].label)]")
635+
}
606636
let startTime = DispatchTime.now()
607637
start { error in
608638
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.start").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -642,14 +672,16 @@ public class ComponentLifecycle: LifecycleTask {
642672

643673
private func shutdownTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, errors: [String: Error]?, callback: @escaping ([String: Error]?) -> Void) {
644674
// async barrier
645-
let shutdown = { (callback) -> Void in queue.async { tasks[index].shutdown(callback) } }
646-
let callback = { (errors) -> Void in queue.async { callback(errors) } }
675+
let shutdown = { callback in queue.async { tasks[index].shutdown(callback) } }
676+
let callback = { errors in queue.async { callback(errors) } }
647677

648678
if index >= tasks.count {
649679
return callback(errors)
650680
}
651681

652-
self.logger.info("stopping tasks [\(tasks[index].label)]")
682+
if tasks[index].logShutdown {
683+
self.logger.info("stopping tasks [\(tasks[index].label)]")
684+
}
653685
let startTime = DispatchTime.now()
654686
shutdown { error in
655687
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.shutdown").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
@@ -694,12 +726,12 @@ public protocol LifecycleTasksContainer {
694726
func register(_ tasks: [LifecycleTask])
695727
}
696728

697-
extension LifecycleTasksContainer {
729+
public extension LifecycleTasksContainer {
698730
/// Adds a `LifecycleTask` to a `LifecycleTasks` collection.
699731
///
700732
/// - parameters:
701733
/// - tasks: one or more `LifecycleTask`.
702-
public func register(_ tasks: LifecycleTask ...) {
734+
func register(_ tasks: LifecycleTask ...) {
703735
self.register(tasks)
704736
}
705737

@@ -709,7 +741,7 @@ extension LifecycleTasksContainer {
709741
/// - label: label of the item, useful for debugging.
710742
/// - start: `Handler` to perform the startup.
711743
/// - shutdown: `Handler` to perform the shutdown.
712-
public func register(label: String, start: LifecycleHandler, shutdown: LifecycleHandler, shutdownIfNotStarted: Bool? = nil) {
744+
func register(label: String, start: LifecycleHandler, shutdown: LifecycleHandler, shutdownIfNotStarted: Bool? = nil) {
713745
self.register(_LifecycleTask(label: label, shutdownIfNotStarted: shutdownIfNotStarted, start: start, shutdown: shutdown))
714746
}
715747

@@ -718,7 +750,7 @@ extension LifecycleTasksContainer {
718750
/// - parameters:
719751
/// - label: label of the item, useful for debugging.
720752
/// - handler: `Handler` to perform the shutdown.
721-
public func registerShutdown(label: String, _ handler: LifecycleHandler) {
753+
func registerShutdown(label: String, _ handler: LifecycleHandler) {
722754
self.register(label: label, start: .none, shutdown: handler)
723755
}
724756

@@ -728,7 +760,7 @@ extension LifecycleTasksContainer {
728760
/// - label: label of the item, useful for debugging.
729761
/// - start: `LifecycleStartHandler` to perform the startup and return the state.
730762
/// - shutdown: `LifecycleShutdownHandler` to perform the shutdown given the state.
731-
public func registerStateful<State>(label: String, start: LifecycleStartHandler<State>, shutdown: LifecycleShutdownHandler<State>) {
763+
func registerStateful<State>(label: String, start: LifecycleStartHandler<State>, shutdown: LifecycleShutdownHandler<State>) {
732764
self.register(StatefulLifecycleTask(label: label, start: start, shutdown: shutdown))
733765
}
734766
}
@@ -739,12 +771,16 @@ internal struct _LifecycleTask: LifecycleTask {
739771
let shutdownIfNotStarted: Bool
740772
let start: LifecycleHandler
741773
let shutdown: LifecycleHandler
774+
let logStart: Bool
775+
let logShutdown: Bool
742776

743777
init(label: String, shutdownIfNotStarted: Bool? = nil, start: LifecycleHandler, shutdown: LifecycleHandler) {
744778
self.label = label
745779
self.shutdownIfNotStarted = shutdownIfNotStarted ?? start.noop
746780
self.start = start
747781
self.shutdown = shutdown
782+
self.logStart = !start.noop
783+
self.logShutdown = !shutdown.noop
748784
}
749785

750786
func start(_ callback: @escaping (Error?) -> Void) {

Sources/Lifecycle/Locks.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ extension Lock {
8181
/// - Parameter body: The block to execute while holding the lock.
8282
/// - Returns: The value returned by the block.
8383
@inlinable
84-
internal func withLock<T>(_ body: () throws -> T) rethrows -> T {
84+
func withLock<T>(_ body: () throws -> T) rethrows -> T {
8585
self.lock()
8686
defer {
8787
self.unlock()
@@ -91,7 +91,7 @@ extension Lock {
9191

9292
// specialise Void return (for performance)
9393
@inlinable
94-
internal func withLockVoid(_ body: () throws -> Void) rethrows {
94+
func withLockVoid(_ body: () throws -> Void) rethrows {
9595
try self.withLock(body)
9696
}
9797
}

Sources/LifecycleNIOCompat/Bridge.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
import Lifecycle
1616
import NIO
1717

18-
extension LifecycleHandler {
18+
public extension LifecycleHandler {
1919
/// Asynchronous `LifecycleHandler` based on an `EventLoopFuture`.
2020
///
2121
/// - parameters:
2222
/// - future: function returning the underlying `EventLoopFuture`
23-
public static func eventLoopFuture(_ future: @escaping () -> EventLoopFuture<Void>) -> LifecycleHandler {
23+
static func eventLoopFuture(_ future: @escaping () -> EventLoopFuture<Void>) -> LifecycleHandler {
2424
return LifecycleHandler { callback in
2525
future().whenComplete { result in
2626
switch result {
@@ -34,13 +34,13 @@ extension LifecycleHandler {
3434
}
3535
}
3636

37-
extension LifecycleHandler {
37+
public extension LifecycleHandler {
3838
/// `LifecycleHandler` that cancels a `RepeatedTask`.
3939
///
4040
/// - parameters:
4141
/// - task: `RepeatedTask` to be cancelled
4242
/// - on: `EventLoop` to use for cancelling the task
43-
public static func cancelRepeatedTask(_ task: RepeatedTask, on eventLoop: EventLoop) -> LifecycleHandler {
43+
static func cancelRepeatedTask(_ task: RepeatedTask, on eventLoop: EventLoop) -> LifecycleHandler {
4444
return self.eventLoopFuture {
4545
let promise = eventLoop.makePromise(of: Void.self)
4646
task.cancel(promise: promise)
@@ -49,12 +49,12 @@ extension LifecycleHandler {
4949
}
5050
}
5151

52-
extension LifecycleStartHandler {
52+
public extension LifecycleStartHandler {
5353
/// Asynchronous `LifecycleStartHandler` based on an `EventLoopFuture`.
5454
///
5555
/// - parameters:
5656
/// - future: function returning the underlying `EventLoopFuture`
57-
public static func eventLoopFuture(_ future: @escaping () -> EventLoopFuture<State>) -> LifecycleStartHandler {
57+
static func eventLoopFuture(_ future: @escaping () -> EventLoopFuture<State>) -> LifecycleStartHandler {
5858
return LifecycleStartHandler { callback in
5959
future().whenComplete { result in
6060
callback(result)
@@ -63,12 +63,12 @@ extension LifecycleStartHandler {
6363
}
6464
}
6565

66-
extension LifecycleShutdownHandler {
66+
public extension LifecycleShutdownHandler {
6767
/// Asynchronous `LifecycleShutdownHandler` based on an `EventLoopFuture`.
6868
///
6969
/// - parameters:
7070
/// - future: function returning the underlying `EventLoopFuture`
71-
public static func eventLoopFuture(_ future: @escaping (State) -> EventLoopFuture<Void>) -> LifecycleShutdownHandler {
71+
static func eventLoopFuture(_ future: @escaping (State) -> EventLoopFuture<Void>) -> LifecycleShutdownHandler {
7272
return LifecycleShutdownHandler { state, callback in
7373
future(state).whenComplete { result in
7474
switch result {
@@ -82,13 +82,13 @@ extension LifecycleShutdownHandler {
8282
}
8383
}
8484

85-
extension ComponentLifecycle {
85+
public extension ComponentLifecycle {
8686
/// Starts the provided `LifecycleItem` array.
8787
/// Startup is performed in the order of items provided.
8888
///
8989
/// - parameters:
9090
/// - eventLoop: The `eventLoop` which is used to generate the `EventLoopFuture` that is returned. After the start the future is fulfilled:
91-
public func start(on eventLoop: EventLoop) -> EventLoopFuture<Void> {
91+
func start(on eventLoop: EventLoop) -> EventLoopFuture<Void> {
9292
let promise = eventLoop.makePromise(of: Void.self)
9393
self.start { error in
9494
if let error = error {

Tests/LifecycleTests/ComponentLifecycleTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ final class ComponentLifecycleTests: XCTestCase {
5353
dispatchPrecondition(condition: .onQueue(.global()))
5454
XCTAssertTrue(startCalls.contains(id))
5555
stopCalls.append(id)
56-
})
56+
})
5757
}
5858
lifecycle.register(items)
5959

@@ -92,7 +92,7 @@ final class ComponentLifecycleTests: XCTestCase {
9292
dispatchPrecondition(condition: .onQueue(testQueue))
9393
XCTAssertTrue(startCalls.contains(id))
9494
stopCalls.append(id)
95-
})
95+
})
9696
}
9797
lifecycle.register(items)
9898

0 commit comments

Comments
 (0)