@@ -30,12 +30,22 @@ public protocol LifecycleTask {
30
30
var shutdownIfNotStarted : Bool { get }
31
31
func start( _ callback: @escaping ( Error ? ) -> Void )
32
32
func shutdown( _ callback: @escaping ( Error ? ) -> Void )
33
+ var logStart : Bool { get }
34
+ var logShutdown : Bool { get }
33
35
}
34
36
35
37
extension LifecycleTask {
36
38
public var shutdownIfNotStarted : Bool {
37
39
return false
38
40
}
41
+
42
+ public var logStart : Bool {
43
+ return true
44
+ }
45
+
46
+ public var logShutdown : Bool {
47
+ return true
48
+ }
39
49
}
40
50
41
51
// MARK: - LifecycleHandler
@@ -317,9 +327,14 @@ public struct ServiceLifecycle {
317
327
self . log ( " intercepted signal: \( signal) " )
318
328
self . shutdown ( )
319
329
} , cancelAfterTrap: true )
320
- self . underlying. shutdownGroup. notify ( queue: . global( ) ) {
321
- signalSource. cancel ( )
322
- }
330
+ // register cleanup as the last task
331
+ self . registerShutdown ( label: " \( signal) shutdown hook cleanup " , . sync {
332
+ // cancel if not already canceled by the trap
333
+ if !signalSource. isCancelled {
334
+ signalSource. cancel ( )
335
+ ServiceLifecycle . removeTrap ( signal: signal)
336
+ }
337
+ } )
323
338
}
324
339
}
325
340
@@ -343,22 +358,34 @@ extension ServiceLifecycle {
343
358
public static func trap( signal sig: Signal , handler: @escaping ( Signal ) -> Void , on queue: DispatchQueue = . global( ) , cancelAfterTrap: Bool = false ) -> DispatchSourceSignal {
344
359
// on linux, we can call singal() once per process
345
360
self . trappedLock. withLockVoid {
346
- if !trapped. contains ( sig. rawValue) {
361
+ if !self . trapped. contains ( sig. rawValue) {
347
362
signal ( sig. rawValue, SIG_IGN)
348
- trapped. insert ( sig. rawValue)
363
+ self . trapped. insert ( sig. rawValue)
349
364
}
350
365
}
351
366
let signalSource = DispatchSource . makeSignalSource ( signal: sig. rawValue, queue: queue)
352
367
signalSource. setEventHandler {
368
+ // run handler first
369
+ handler ( sig)
370
+ // then cancel trap if so requested
353
371
if cancelAfterTrap {
354
372
signalSource. cancel ( )
373
+ self . removeTrap ( signal: sig)
355
374
}
356
- handler ( sig)
357
375
}
358
376
signalSource. resume ( )
359
377
return signalSource
360
378
}
361
379
380
+ public static func removeTrap( signal sig: Signal ) {
381
+ self . trappedLock. withLockVoid {
382
+ if self . trapped. contains ( sig. rawValue) {
383
+ signal ( sig. rawValue, SIG_DFL)
384
+ self . trapped. remove ( sig. rawValue)
385
+ }
386
+ }
387
+ }
388
+
362
389
/// A system signal
363
390
public struct Signal : Equatable , CustomStringConvertible {
364
391
internal var rawValue : CInt
@@ -433,7 +460,7 @@ struct ShutdownError: Error {
433
460
public class ComponentLifecycle : LifecycleTask {
434
461
public let label : String
435
462
fileprivate let logger : Logger
436
- internal let shutdownGroup = DispatchGroup ( )
463
+ fileprivate let shutdownGroup = DispatchGroup ( )
437
464
438
465
private var state = State . idle ( [ ] )
439
466
private let stateLock = Lock ( )
@@ -596,13 +623,15 @@ public class ComponentLifecycle: LifecycleTask {
596
623
597
624
private func startTask( on queue: DispatchQueue , tasks: [ LifecycleTask ] , index: Int , callback: @escaping ( [ LifecycleTask ] , Error ? ) -> Void ) {
598
625
// async barrier
599
- let start = { ( callback) -> Void in queue. async { tasks [ index] . start ( callback) } }
600
- let callback = { ( index, error) -> Void in queue. async { callback ( index, error) } }
626
+ let start = { callback in queue. async { tasks [ index] . start ( callback) } }
627
+ let callback = { index, error in queue. async { callback ( index, error) } }
601
628
602
629
if index >= tasks. count {
603
630
return callback ( tasks, nil )
604
631
}
605
- self . logger. info ( " starting tasks [ \( tasks [ index] . label) ] " )
632
+ if tasks [ index] . logStart {
633
+ self . logger. info ( " starting tasks [ \( tasks [ index] . label) ] " )
634
+ }
606
635
let startTime = DispatchTime . now ( )
607
636
start { error in
608
637
Timer ( label: " \( self . label) . \( tasks [ index] . label) .lifecycle.start " ) . recordNanoseconds ( DispatchTime . now ( ) . uptimeNanoseconds - startTime. uptimeNanoseconds)
@@ -642,14 +671,16 @@ public class ComponentLifecycle: LifecycleTask {
642
671
643
672
private func shutdownTask( on queue: DispatchQueue , tasks: [ LifecycleTask ] , index: Int , errors: [ String : Error ] ? , callback: @escaping ( [ String : Error ] ? ) -> Void ) {
644
673
// async barrier
645
- let shutdown = { ( callback) -> Void in queue. async { tasks [ index] . shutdown ( callback) } }
646
- let callback = { ( errors) -> Void in queue. async { callback ( errors) } }
674
+ let shutdown = { callback in queue. async { tasks [ index] . shutdown ( callback) } }
675
+ let callback = { errors in queue. async { callback ( errors) } }
647
676
648
677
if index >= tasks. count {
649
678
return callback ( errors)
650
679
}
651
680
652
- self . logger. info ( " stopping tasks [ \( tasks [ index] . label) ] " )
681
+ if tasks [ index] . logShutdown {
682
+ self . logger. info ( " stopping tasks [ \( tasks [ index] . label) ] " )
683
+ }
653
684
let startTime = DispatchTime . now ( )
654
685
shutdown { error in
655
686
Timer ( label: " \( self . label) . \( tasks [ index] . label) .lifecycle.shutdown " ) . recordNanoseconds ( DispatchTime . now ( ) . uptimeNanoseconds - startTime. uptimeNanoseconds)
@@ -739,12 +770,16 @@ internal struct _LifecycleTask: LifecycleTask {
739
770
let shutdownIfNotStarted : Bool
740
771
let start : LifecycleHandler
741
772
let shutdown : LifecycleHandler
773
+ let logStart : Bool
774
+ let logShutdown : Bool
742
775
743
776
init ( label: String , shutdownIfNotStarted: Bool ? = nil , start: LifecycleHandler , shutdown: LifecycleHandler ) {
744
777
self . label = label
745
778
self . shutdownIfNotStarted = shutdownIfNotStarted ?? start. noop
746
779
self . start = start
747
780
self . shutdown = shutdown
781
+ self . logStart = !start. noop
782
+ self . logShutdown = !shutdown. noop
748
783
}
749
784
750
785
func start( _ callback: @escaping ( Error ? ) -> Void ) {
0 commit comments