Skip to content

Commit 121e34f

Browse files
committed
Reimplement Task.sleep(nanoseconds:) without the raciness.
The prior implementation of `Task.sleep()` effectively had two different atomic words to capture the state, which could lead to cases where cancelling before a sleep operation started would fail to throw `CancellationError`. Reimplement the logic for the cancellable sleep with a more traditional lock-free approach by packing all of the state information into a single word, where we always load, figure out what to do, then compare-and-swap.
1 parent 1a024e9 commit 121e34f

File tree

1 file changed

+218
-68
lines changed

1 file changed

+218
-68
lines changed

stdlib/public/Concurrency/TaskSleep.swift

+218-68
Original file line numberDiff line numberDiff line change
@@ -32,45 +32,164 @@ extension Task where Success == Never, Failure == Never {
3232
/// sleep(nanoseconds:).
3333
private typealias SleepContinuation = UnsafeContinuation<(), Error>
3434

35+
/// Describes the state of a sleep() operation.
36+
private enum SleepState {
37+
/// The sleep continuation has not yet begun.
38+
case notStarted
39+
40+
/// The sleep continuation has been created and is available here.
41+
case activeContinuation(SleepContinuation)
42+
43+
/// The sleep has finished.
44+
case finished
45+
46+
/// The sleep was cancelled.
47+
case cancelled
48+
49+
/// The sleep was cancelled before it even got started.
50+
case cancelledBeforeStarted
51+
52+
/// Decode sleep state from the word of storage.
53+
init(word: Builtin.Word) {
54+
switch UInt(word) & 0x03 {
55+
case 0:
56+
let continuationBits = UInt(word) & ~0x03
57+
if continuationBits == 0 {
58+
self = .notStarted
59+
} else {
60+
let continuation = unsafeBitCast(
61+
continuationBits, to: SleepContinuation.self)
62+
self = .activeContinuation(continuation)
63+
}
64+
65+
case 1:
66+
self = .finished
67+
68+
case 2:
69+
self = .cancelled
70+
71+
case 3:
72+
self = .cancelledBeforeStarted
73+
74+
default:
75+
fatalError("Bitmask failure")
76+
}
77+
}
78+
79+
/// Decode sleep state by loading from the given pointer.
80+
init(loading wordPtr: UnsafeMutablePointer<Builtin.Word>) {
81+
self.init(word: Builtin.atomicload_seqcst_Word(wordPtr._rawValue))
82+
}
83+
84+
/// Encode sleep state into a word of storage.
85+
var word: UInt {
86+
switch self {
87+
case .notStarted:
88+
return 0
89+
90+
case .activeContinuation(let continuation):
91+
let continuationBits = unsafeBitCast(continuation, to: UInt.self)
92+
return continuationBits
93+
94+
case .finished:
95+
return 1
96+
97+
case .cancelled:
98+
return 2
99+
100+
case .cancelledBeforeStarted:
101+
return 3
102+
}
103+
}
104+
}
105+
35106
/// Called when the sleep(nanoseconds:) operation woke up without being
36107
/// cancelled.
37108
private static func onSleepWake(
38-
_ wordPtr: UnsafeMutablePointer<Builtin.Word>,
39-
_ continuation: UnsafeContinuation<(), Error>
109+
_ wordPtr: UnsafeMutablePointer<Builtin.Word>
40110
) {
41-
// Indicate that we've finished by putting a "1" into the flag word.
42-
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
43-
wordPtr._rawValue,
44-
UInt(0)._builtinWordValue,
45-
UInt(1)._builtinWordValue)
46-
47-
if Bool(_builtinBooleanLiteral: won) {
48-
// The sleep finished, invoke the continuation.
49-
continuation.resume()
50-
} else {
51-
// The task was cancelled first, which means the continuation was
52-
// called by the cancellation handler. We need to deallocate up the flag
53-
// word, because it was left over for this task to complete.
54-
wordPtr.deallocate()
111+
while true {
112+
let state = SleepState(loading: wordPtr)
113+
switch state {
114+
case .notStarted:
115+
fatalError("Cannot wake before we even started")
116+
117+
case .activeContinuation(let continuation):
118+
// We have an active continuation, so try to transition to the
119+
// "finished" state.
120+
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
121+
wordPtr._rawValue,
122+
state.word._builtinWordValue,
123+
SleepState.finished.word._builtinWordValue)
124+
if Bool(_builtinBooleanLiteral: won) {
125+
// The sleep finished, so invoke the continuation: we're done.
126+
continuation.resume()
127+
return
128+
}
129+
130+
// Try again!
131+
continue
132+
133+
case .finished:
134+
fatalError("Already finished normally, can't do that again")
135+
136+
case .cancelled:
137+
// The task was cancelled, which means the continuation was
138+
// called by the cancellation handler. We need to deallocate the flag
139+
// word, because it was left over for this task to complete.
140+
wordPtr.deallocate()
141+
return
142+
143+
case .cancelledBeforeStarted:
144+
// Nothing to do.
145+
return
146+
}
55147
}
56148
}
57149

58150
/// Called when the sleep(nanoseconds:) operation has been cancelled before
59151
/// the sleep completed.
60152
private static func onSleepCancel(
61-
_ wordPtr: UnsafeMutablePointer<Builtin.Word>,
62-
_ continuation: UnsafeContinuation<(), Error>
153+
_ wordPtr: UnsafeMutablePointer<Builtin.Word>
63154
) {
64-
// Indicate that we've finished by putting a "2" into the flag word.
65-
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
66-
wordPtr._rawValue,
67-
UInt(0)._builtinWordValue,
68-
UInt(2)._builtinWordValue)
69-
70-
if Bool(_builtinBooleanLiteral: won) {
71-
// We recorded the task cancellation before the sleep finished, so
72-
// invoke the continuation with a the cancellation error.
73-
continuation.resume(throwing: _Concurrency.CancellationError())
155+
while true {
156+
let state = SleepState(loading: wordPtr)
157+
switch state {
158+
case .notStarted:
159+
// We haven't started yet, so try to transition to the cancelled-before
160+
// started state.
161+
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
162+
wordPtr._rawValue,
163+
state.word._builtinWordValue,
164+
SleepState.cancelledBeforeStarted.word._builtinWordValue)
165+
if Bool(_builtinBooleanLiteral: won) {
166+
return
167+
}
168+
169+
// Try again!
170+
continue
171+
172+
case .activeContinuation(let continuation):
173+
// We have an active continuation, so try to transition to the
174+
// "cancelled" state.
175+
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
176+
wordPtr._rawValue,
177+
state.word._builtinWordValue,
178+
SleepState.cancelled.word._builtinWordValue)
179+
if Bool(_builtinBooleanLiteral: won) {
180+
// We recorded the task cancellation before the sleep finished, so
181+
// invoke the continuation with the cancellation error.
182+
continuation.resume(throwing: _Concurrency.CancellationError())
183+
return
184+
}
185+
186+
// Try again!
187+
continue
188+
189+
case .finished, .cancelled, .cancelledBeforeStarted:
190+
// The operation already finished, so there is nothing more to do.
191+
return
192+
}
74193
}
75194
}
76195

@@ -80,64 +199,95 @@ extension Task where Success == Never, Failure == Never {
80199
///
81200
/// This function does _not_ block the underlying thread.
82201
public static func sleep(nanoseconds duration: UInt64) async throws {
83-
// If the task was already cancelled, go ahead and throw now.
84-
try checkCancellation()
85-
86-
// Allocate storage for the flag word and continuation.
87-
let wordPtr = UnsafeMutablePointer<Builtin.Word>.allocate(capacity: 2)
202+
// Allocate storage for the state word.
203+
let wordPtr = UnsafeMutablePointer<Builtin.Word>.allocate(capacity: 1)
88204

89-
// Initialize the flag word to 0, which means the continuation has not
90-
// executed.
205+
// Initialize the flag word to "not started", which means the continuation
206+
// has neither been created nor completed.
91207
Builtin.atomicstore_seqcst_Word(
92-
wordPtr._rawValue, UInt(0)._builtinWordValue)
93-
94-
// A pointer to the storage continuation. Also initialize it to zero, to
95-
// indicate that there is no continuation.
96-
let continuationPtr = wordPtr + 1
97-
Builtin.atomicstore_seqcst_Word(
98-
continuationPtr._rawValue, UInt(0)._builtinWordValue)
208+
wordPtr._rawValue, SleepState.notStarted.word._builtinWordValue)
99209

100210
do {
101211
// Install a cancellation handler to resume the continuation by
102212
// throwing CancellationError.
103213
try await withTaskCancellationHandler {
104214
let _: () = try await withUnsafeThrowingContinuation { continuation in
105-
// Stash the continuation so the cancellation handler can see it.
106-
Builtin.atomicstore_seqcst_Word(
107-
continuationPtr._rawValue,
108-
unsafeBitCast(continuation, to: Builtin.Word.self))
109-
110-
// Create a task that resumes the continuation normally if it
111-
// finishes first. Enqueue it directly with the delay, so it fires
112-
// when we're done sleeping.
113-
let sleepTaskFlags = taskCreateFlags(
114-
priority: nil, isChildTask: false, copyTaskLocals: false,
115-
inheritContext: false, enqueueJob: false,
116-
addPendingGroupTaskUnconditionally: false)
117-
let (sleepTask, _) = Builtin.createAsyncTask(sleepTaskFlags) {
118-
onSleepWake(wordPtr, continuation)
215+
while true {
216+
let state = SleepState(loading: wordPtr)
217+
switch state {
218+
case .notStarted:
219+
// The word that describes the active continuation state.
220+
let continuationWord =
221+
SleepState.activeContinuation(continuation).word
222+
223+
// Try to swap in the continuation word.
224+
let (_, won) = Builtin.cmpxchg_seqcst_seqcst_Word(
225+
wordPtr._rawValue,
226+
state.word._builtinWordValue,
227+
continuationWord._builtinWordValue)
228+
if !Bool(_builtinBooleanLiteral: won) {
229+
// Keep trying!
230+
continue
231+
}
232+
233+
// Create a task that resumes the continuation normally if it
234+
// finishes first. Enqueue it directly with the delay, so it fires
235+
// when we're done sleeping.
236+
let sleepTaskFlags = taskCreateFlags(
237+
priority: nil, isChildTask: false, copyTaskLocals: false,
238+
inheritContext: false, enqueueJob: false,
239+
addPendingGroupTaskUnconditionally: false)
240+
let (sleepTask, _) = Builtin.createAsyncTask(sleepTaskFlags) {
241+
onSleepWake(wordPtr)
242+
}
243+
_enqueueJobGlobalWithDelay(
244+
duration, Builtin.convertTaskToJob(sleepTask))
245+
return
246+
247+
case .activeContinuation, .finished:
248+
fatalError("Impossible to have multiple active continuations")
249+
250+
case .cancelled:
251+
fatalError("Impossible to have cancelled before we began")
252+
253+
case .cancelledBeforeStarted:
254+
// Finish the continuation normally. We'll throw later, after
255+
// we clean up.
256+
continuation.resume()
257+
return
119258
}
120-
_enqueueJobGlobalWithDelay(
121-
duration, Builtin.convertTaskToJob(sleepTask))
122259
}
123-
} onCancel: {
124-
let continuationWord = continuationPtr.pointee
125-
if UInt(continuationWord) != 0 {
126-
// Try to cancel, which will resume the continuation by throwing a
127-
// CancellationError if the continuation hasn't already been resumed.
128-
continuationPtr.withMemoryRebound(
129-
to: SleepContinuation.self, capacity: 1) {
130-
onSleepCancel(wordPtr, $0.pointee)
131-
}
132260
}
261+
} onCancel: {
262+
onSleepCancel(wordPtr)
263+
}
264+
265+
// Determine whether we got cancelled before we even started.
266+
let cancelledBeforeStarted: Bool
267+
switch SleepState(loading: wordPtr) {
268+
case .notStarted, .activeContinuation, .cancelled:
269+
fatalError("Invalid state for non-cancelled sleep task")
270+
271+
case .cancelledBeforeStarted:
272+
cancelledBeforeStarted = true
273+
274+
case .finished:
275+
cancelledBeforeStarted = false
133276
}
134277

135278
// The continuation resumed without throwing, so deallocate the storage for
136279
// the flag word and continuation.
137280
wordPtr.deallocate()
281+
282+
// If we got cancelled before we even started, throw the cancellation
283+
// error now.
284+
if cancelledBeforeStarted {
285+
throw _Concurrency.CancellationError()
286+
}
138287
} catch {
139288
// The task was cancelled; propagate the error. The "on wake" task is
140-
// responsible for deallocating the flag word.
289+
// responsible for deallocating the flag word and continuation, if it's
290+
// still running.
141291
throw error
142292
}
143293
}

0 commit comments

Comments
 (0)