Skip to content

Commit 6fa20bd

Browse files
committed
Add an RPC cancellation handler
Motivation: As a service author it's useful to know if the RPC has been cancelled (because it's timed out, the remote peer closed it, the connection dropped, etc.). For cases where the stream has already closed this can be surfaced by a read or write failing. However, for cases like server-streaming RPCs where there are no reads and writes can be infrequent it's useful to have a more explicit signal. Modifications: - Add a `ServerCancellationManager`; this is internal per-stream storage for registering cancellation handlers and storing whether the RPC has been cancelled. - Add the `RPCCancellationHandle` nested within the `ServerContext`. This holds an instance of the manager and provides higher level APIs allowing users to check if the RPC has been cancelled and to wait until the RPC has been cancelled. - Add a top-level `withRPCCancellationHandler` which registers a callback with the manager. - Add a top-level `withServerContextRPCCancellationHandle` for creating and binding the task local manager. This is intended for use by transport implementations rather than users. - Update the in-process transport to cancel RPCs when shutting down gracefully. - Update the server executor to cancel RPCs when the timeout fires. Result: Users can watch for cancellation using `withRPCCancellationHandler`.
1 parent d5e0d70 commit 6fa20bd

File tree

11 files changed

+776
-73
lines changed

11 files changed

+776
-73
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
/*
2+
* Copyright 2024, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
private import Synchronization
18+
19+
/// Stores cancellation state for an RPC on the server.
20+
package final class ServerCancellationManager: Sendable {
  /// All mutable state, guarded by a mutex. User-supplied callbacks and
  /// continuations are never invoked while this lock is held.
  private let state: Mutex<State>

  package init() {
    self.state = Mutex(State())
  }

  /// Returns whether the RPC has been marked as cancelled.
  package var isRPCCancelled: Bool {
    self.state.withLock {
      return $0.isRPCCancelled
    }
  }

  /// Marks the RPC as cancelled, potentially running any cancellation handlers.
  ///
  /// Only the first call has any effect. Registered handlers are executed and
  /// suspended waiters are resumed after the lock has been released.
  package func cancelRPC() {
    switch self.state.withLock({ $0.cancelRPC() }) {
    case .executeAndResume(let onCancelHandlers, let onCancelWaiters):
      for handler in onCancelHandlers {
        handler.handler()
      }

      for onCancelWaiter in onCancelWaiters {
        switch onCancelWaiter {
        case .taskCancelled:
          // Marker only: there is no continuation to resume.
          ()
        case .waiting(_, let continuation):
          continuation.resume(returning: .rpc)
        }
      }

    case .doNothing:
      ()
    }
  }

  /// Adds a handler which is invoked when the RPC is cancelled.
  ///
  /// If the RPC has already been cancelled the handler is invoked immediately,
  /// before this function returns, and is not stored.
  ///
  /// - Parameter handler: The closure to run when the RPC is cancelled.
  /// - Returns: The ID of the handler, if it was added, or `nil` if the RPC is already cancelled.
  package func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? {
    let id = self.state.withLock { state -> UInt64? in
      if state.isRPCCancelled {
        return nil
      } else {
        return state.addRPCCancelledHandler(handler)
      }
    }

    // A `nil` ID means the RPC was already cancelled. Run the handler here,
    // outside of the lock, so that a handler which calls back into this manager
    // (for example to read `isRPCCancelled`) cannot deadlock on the mutex.
    if id == nil {
      handler()
    }

    return id
  }

  /// Removes a handler by its ID.
  ///
  /// Does nothing if no handler with the given ID is registered.
  package func removeRPCCancelledHandler(withID id: UInt64) {
    self.state.withLock { state in
      state.removeRPCCancelledHandler(withID: id)
    }
  }

  /// Suspends until the RPC is cancelled or the `Task` is cancelled.
  ///
  /// - Throws: `CancellationError` if the waiter was resumed because the `Task` was
  ///   cancelled rather than the RPC.
  package func suspendUntilRPCIsCancelled() async throws(CancellationError) {
    // Allocate the ID up front: the task cancellation handler and the continuation
    // registration both use it to find each other, whichever of them runs first.
    let id = self.state.withLock { $0.nextID() }

    let source = await withTaskCancellationHandler {
      await withCheckedContinuation { continuation in
        let onAddWaiter = self.state.withLock {
          $0.addRPCIsCancelledWaiter(continuation: continuation, withID: id)
        }

        switch onAddWaiter {
        case .doNothing:
          // The continuation was stored; it will be resumed later.
          ()
        case .complete(let continuation, let result):
          continuation.resume(returning: result)
        }
      }
    } onCancel: {
      switch self.state.withLock({ $0.cancelRPCCancellationWaiter(withID: id) }) {
      case .resume(let continuation, let result):
        continuation.resume(returning: result)
      case .doNothing:
        ()
      }
    }

    switch source {
    case .rpc:
      ()
    case .task:
      throw CancellationError()
    }
  }
}
106+
107+
extension ServerCancellationManager {
  /// Identifies what caused a waiter to be resumed.
  enum CancellationSource {
    /// The RPC was cancelled.
    case rpc
    /// The waiting `Task` was cancelled.
    case task
  }

  /// A cancellation callback paired with the ID used to remove it.
  struct Handler: Sendable {
    var id: UInt64
    var handler: @Sendable () -> Void
  }

  /// An entry tracking a caller waiting for the RPC to be cancelled.
  enum Waiter: Sendable {
    /// A suspended caller and the continuation used to resume it.
    case waiting(UInt64, CheckedContinuation<CancellationSource, Never>)
    /// A marker recording that the task for this ID was cancelled before its
    /// continuation was registered.
    case taskCancelled(UInt64)

    var id: UInt64 {
      switch self {
      case .waiting(let id, _):
        return id
      case .taskCancelled(let id):
        return id
      }
    }
  }

  /// The cancellation state machine. It never resumes continuations itself;
  /// instead it returns actions for the caller to carry out.
  struct State {
    private var handlers: [Handler]
    private var waiters: [Waiter]
    private var _nextID: UInt64
    var isRPCCancelled: Bool

    /// Returns a fresh ID, shared by handlers and waiters.
    mutating func nextID() -> UInt64 {
      let id = self._nextID
      self._nextID &+= 1
      return id
    }

    init() {
      self.handlers = []
      self.waiters = []
      self._nextID = 0
      self.isRPCCancelled = false
    }

    /// Marks the RPC as cancelled.
    ///
    /// - Returns: On the first call, the handlers to execute and the waiters to resume
    ///   (both are removed from the state); `.doNothing` on subsequent calls.
    mutating func cancelRPC() -> OnCancelRPC {
      let onCancel: OnCancelRPC

      if self.isRPCCancelled {
        onCancel = .doNothing
      } else {
        self.isRPCCancelled = true
        onCancel = .executeAndResume(self.handlers, self.waiters)
        self.handlers = []
        self.waiters = []
      }

      return onCancel
    }

    /// Stores a cancellation handler, or invokes it immediately if the RPC has
    /// already been cancelled.
    ///
    /// - Returns: The ID of the stored handler, or `nil` if the RPC was already
    ///   cancelled (in which case `handler` has been called).
    mutating func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? {
      if self.isRPCCancelled {
        handler()
        return nil
      } else {
        let id = self.nextID()
        self.handlers.append(.init(id: id, handler: handler))
        return id
      }
    }

    /// Removes the handler with the given ID, if one is stored.
    mutating func removeRPCCancelledHandler(withID id: UInt64) {
      if let index = self.handlers.firstIndex(where: { $0.id == id }) {
        self.handlers.remove(at: index)
      }
    }

    enum OnCancelRPC {
      case executeAndResume([Handler], [Waiter])
      case doNothing
    }

    enum OnAddWaiter {
      case complete(CheckedContinuation<CancellationSource, Never>, CancellationSource)
      case doNothing
    }

    /// Registers a continuation waiting for the RPC to be cancelled.
    ///
    /// - Returns: `.complete` if the continuation must be resumed straight away (the RPC
    ///   is already cancelled, or the task for `id` was cancelled before registration),
    ///   otherwise `.doNothing` after storing the continuation.
    mutating func addRPCIsCancelledWaiter(
      continuation: CheckedContinuation<CancellationSource, Never>,
      withID id: UInt64
    ) -> OnAddWaiter {
      let onAddWaiter: OnAddWaiter

      if self.isRPCCancelled {
        onAddWaiter = .complete(continuation, .rpc)
      } else if let index = self.waiters.firstIndex(where: { $0.id == id }) {
        switch self.waiters[index] {
        case .taskCancelled:
          // The marker has now been observed by the only waiter it was left for.
          // Drop it so stale markers don't accumulate for the lifetime of the RPC.
          _ = self.waiters.removeWithoutMaintainingOrder(at: index)
          onAddWaiter = .complete(continuation, .task)
        case .waiting:
          // There's already a continuation enqueued.
          fatalError("Inconsistent state")
        }
      } else {
        self.waiters.append(.waiting(id, continuation))
        onAddWaiter = .doNothing
      }

      return onAddWaiter
    }

    enum OnCancelRPCCancellationWaiter {
      case resume(CheckedContinuation<CancellationSource, Never>, CancellationSource)
      case doNothing
    }

    /// Handles cancellation of the task waiting with the given ID.
    ///
    /// - Returns: `.resume` with the stored continuation if the waiter was already
    ///   registered, otherwise `.doNothing`.
    mutating func cancelRPCCancellationWaiter(withID id: UInt64) -> OnCancelRPCCancellationWaiter {
      let onCancelWaiter: OnCancelRPCCancellationWaiter

      if let index = self.waiters.firstIndex(where: { $0.id == id }) {
        let waiter = self.waiters.removeWithoutMaintainingOrder(at: index)
        switch waiter {
        case .taskCancelled:
          onCancelWaiter = .doNothing
        case .waiting(_, let continuation):
          onCancelWaiter = .resume(continuation, .task)
        }
      } else {
        // The continuation hasn't been registered yet. Leave a marker so that
        // registration completes immediately. Once the RPC is cancelled,
        // registration completes immediately regardless, so a marker would
        // never be read and would only leak.
        if !self.isRPCCancelled {
          self.waiters.append(.taskCancelled(id))
        }
        onCancelWaiter = .doNothing
      }

      return onCancelWaiter
    }
  }
}
242+
243+
extension Array {
  /// Removes and returns the element at `index` without shifting the elements after it.
  ///
  /// The element to remove is first swapped into the final position (unless it is
  /// already there) and then popped off the end, so the relative order of the
  /// remaining elements is not preserved.
  ///
  /// - Parameter index: The position of the element to remove; must be a valid index.
  /// - Returns: The removed element.
  fileprivate mutating func removeWithoutMaintainingOrder(at index: Int) -> Element {
    let finalIndex = self.index(before: self.endIndex)

    if index != finalIndex {
      self.swapAt(index, finalIndex)
    }

    return self.removeLast()
  }
}

Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

+17-31
Original file line numberDiff line numberDiff line change
@@ -119,43 +119,29 @@ struct ServerRPCExecutor {
119119
_ context: ServerContext
120120
) async throws -> StreamingServerResponse<Output>
121121
) async {
122-
await withTaskGroup(of: ServerExecutorTask.self) { group in
122+
await withTaskGroup(of: Void.self) { group in
123123
group.addTask {
124-
let result = await Result {
124+
do {
125125
try await Task.sleep(for: timeout, clock: .continuous)
126+
context.cancellation.cancel()
127+
} catch {
128+
() // Only cancel the RPC if the timeout completes.
126129
}
127-
return .timedOut(result)
128130
}
129131

130-
group.addTask {
131-
await Self._processRPC(
132-
context: context,
133-
metadata: metadata,
134-
inbound: inbound,
135-
outbound: outbound,
136-
deserializer: deserializer,
137-
serializer: serializer,
138-
interceptors: interceptors,
139-
handler: handler
140-
)
141-
return .executed
142-
}
143-
144-
while let next = await group.next() {
145-
switch next {
146-
case .timedOut(.success):
147-
// Timeout expired; cancel the work.
148-
group.cancelAll()
149-
150-
case .timedOut(.failure):
151-
// Timeout failed (because it was cancelled). Wait for more tasks to finish.
152-
()
132+
await Self._processRPC(
133+
context: context,
134+
metadata: metadata,
135+
inbound: inbound,
136+
outbound: outbound,
137+
deserializer: deserializer,
138+
serializer: serializer,
139+
interceptors: interceptors,
140+
handler: handler
141+
)
153142

154-
case .executed:
155-
// The work finished. Cancel any remaining tasks.
156-
group.cancelAll()
157-
}
158-
}
143+
// Cancel the timeout
144+
group.cancelAll()
159145
}
160146
}
161147

0 commit comments

Comments
 (0)