From 6fa20bdc188115d9b4f427966f29ab3ca1f3fd8c Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 8 Oct 2024 15:41:54 +0100 Subject: [PATCH 1/3] Add an RPC cancellation handler Motivation: As a service author it's useful to know if the RPC has been cancelled (because it's timed out, the remote peer closed it, the connection dropped etc). For cases where the stream has already closed this can be surfaced by a read or write failing. However, for cases like server-streaming RPCs where there are no reads and writes can be infrequent it's useful to have a more explicit signal. Modifications: - Add a `ServerCancellationManager`, this is internal per-stream storage for registering cancellation handlers and storing whether the RPC has been cancelled. - Add the `RPCCancellationHandle` nested within the `ServerContext`. This holds an instance of the manager and provides higher level APIs allowing users to check if the RPC has been cancellation and to wait until the RPC has been cancelled. - Add a top-level `withRPCCancellationHandler` which registers a callback with the manager. - Add a top-level `withServerContextRPCCancellationHandle` for creating and binding the task local manager. This is intended for use by transport implementations rather than users. - Update the in-process transport to cancel RPCs when shutting down gracefully. - Update the server executor to cancel RPCs when the timeout fires. Result: Users can watch for cancellation using `withRPCCancellationHandler`. --- .../Internal/ServerCancellationManager.swift | 254 ++++++++++++++++++ .../Server/Internal/ServerRPCExecutor.swift | 48 ++-- .../ServerContext+RPCCancellationHandle.swift | 115 ++++++++ .../GRPCCore/Call/Server/ServerContext.swift | 11 +- .../GRPCCore/Internal/Result+Catching.swift | 4 +- .../InProcessTransport+Server.swift | 61 ++++- .../ServerCancellationManagerTests.swift | 91 +++++++ .../ServerRPCExecutorTestHarness.swift | 56 ++-- .../Internal/ServerRPCExecutorTests.swift | 22 +- .../Call/Server/ServerContextTests.swift | 62 +++++ .../InProcessTransportTests.swift | 125 +++++++++ 11 files changed, 776 insertions(+), 73 deletions(-) create mode 100644 Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift create mode 100644 Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift create mode 100644 Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift create mode 100644 Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift create mode 100644 Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift b/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift new file mode 100644 index 000000000..471f9d007 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift @@ -0,0 +1,254 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +private import Synchronization + +/// Stores cancellation state for an RPC on the server . +package final class ServerCancellationManager: Sendable { + private let state: Mutex + + package init() { + self.state = Mutex(State()) + } + + /// Returns whether the RPC has been marked as cancelled. + package var isRPCCancelled: Bool { + self.state.withLock { + return $0.isRPCCancelled + } + } + + /// Marks the RPC as cancelled, potentially running any cancellation handlers. + package func cancelRPC() { + switch self.state.withLock({ $0.cancelRPC() }) { + case .executeAndResume(let onCancelHandlers, let onCancelWaiters): + for handler in onCancelHandlers { + handler.handler() + } + + for onCancelWaiter in onCancelWaiters { + switch onCancelWaiter { + case .taskCancelled: + () + case .waiting(_, let continuation): + continuation.resume(returning: .rpc) + } + } + + case .doNothing: + () + } + } + + /// Adds a handler which is invoked when the RPC is cancelled. + /// + /// - Returns: The ID of the handler, if it was added, or `nil` if the RPC is already cancelled. + package func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? { + return self.state.withLock { state -> UInt64? in + state.addRPCCancelledHandler(handler) + } + } + + /// Removes a handler by its ID. + package func removeRPCCancelledHandler(withID id: UInt64) { + self.state.withLock { state in + state.removeRPCCancelledHandler(withID: id) + } + } + + /// Suspends until the RPC is cancelled or the `Task` is cancelled. + package func suspendUntilRPCIsCancelled() async throws(CancellationError) { + let id = self.state.withLock { $0.nextID() } + + let source = await withTaskCancellationHandler { + await withCheckedContinuation { continuation in + let onAddWaiter = self.state.withLock { + $0.addRPCIsCancelledWaiter(continuation: continuation, withID: id) + } + + switch onAddWaiter { + case .doNothing: + () + case .complete(let continuation, let result): + continuation.resume(returning: result) + } + } + } onCancel: { + switch self.state.withLock({ $0.cancelRPCCancellationWaiter(withID: id) }) { + case .resume(let continuation, let result): + continuation.resume(returning: result) + case .doNothing: + () + } + } + + switch source { + case .rpc: + () + case .task: + throw CancellationError() + } + } +} + +extension ServerCancellationManager { + enum CancellationSource { + case rpc + case task + } + + struct Handler: Sendable { + var id: UInt64 + var handler: @Sendable () -> Void + } + + enum Waiter: Sendable { + case waiting(UInt64, CheckedContinuation) + case taskCancelled(UInt64) + + var id: UInt64 { + switch self { + case .waiting(let id, _): + return id + case .taskCancelled(let id): + return id + } + } + } + + struct State { + private var handlers: [Handler] + private var waiters: [Waiter] + private var _nextID: UInt64 + var isRPCCancelled: Bool + + mutating func nextID() -> UInt64 { + let id = self._nextID + self._nextID &+= 1 + return id + } + + init() { + self.handlers = [] + self.waiters = [] + self._nextID = 0 + self.isRPCCancelled = false + } + + mutating func cancelRPC() -> OnCancelRPC { + let onCancel: OnCancelRPC + + if self.isRPCCancelled { + onCancel = .doNothing + } else { + self.isRPCCancelled = true + onCancel = .executeAndResume(self.handlers, self.waiters) + self.handlers = [] + self.waiters = [] + } + + return onCancel + } + + mutating func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? { + if self.isRPCCancelled { + handler() + return nil + } else { + let id = self.nextID() + self.handlers.append(.init(id: id, handler: handler)) + return id + } + } + + mutating func removeRPCCancelledHandler(withID id: UInt64) { + if let index = self.handlers.firstIndex(where: { $0.id == id }) { + self.handlers.remove(at: index) + } + } + + enum OnCancelRPC { + case executeAndResume([Handler], [Waiter]) + case doNothing + } + + enum OnAddWaiter { + case complete(CheckedContinuation, CancellationSource) + case doNothing + } + + mutating func addRPCIsCancelledWaiter( + continuation: CheckedContinuation, + withID id: UInt64 + ) -> OnAddWaiter { + let onAddWaiter: OnAddWaiter + + if self.isRPCCancelled { + onAddWaiter = .complete(continuation, .rpc) + } else if let index = self.waiters.firstIndex(where: { $0.id == id }) { + switch self.waiters[index] { + case .taskCancelled: + onAddWaiter = .complete(continuation, .task) + case .waiting: + // There's already a continuation enqueued. + fatalError("Inconsistent state") + } + } else { + self.waiters.append(.waiting(id, continuation)) + onAddWaiter = .doNothing + } + + return onAddWaiter + } + + enum OnCancelRPCCancellationWaiter { + case resume(CheckedContinuation, CancellationSource) + case doNothing + } + + mutating func cancelRPCCancellationWaiter(withID id: UInt64) -> OnCancelRPCCancellationWaiter { + let onCancelWaiter: OnCancelRPCCancellationWaiter + + if let index = self.waiters.firstIndex(where: { $0.id == id }) { + let waiter = self.waiters.removeWithoutMaintainingOrder(at: index) + switch waiter { + case .taskCancelled: + onCancelWaiter = .doNothing + case .waiting(_, let continuation): + onCancelWaiter = .resume(continuation, .task) + } + } else { + self.waiters.append(.taskCancelled(id)) + onCancelWaiter = .doNothing + } + + return onCancelWaiter + } + } +} + +extension Array { + fileprivate mutating func removeWithoutMaintainingOrder(at index: Int) -> Element { + let lastElementIndex = self.index(before: self.endIndex) + + if index == lastElementIndex { + return self.remove(at: index) + } else { + self.swapAt(index, lastElementIndex) + return self.removeLast() + } + } +} diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index f85cbe318..d9a35da51 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -119,43 +119,29 @@ struct ServerRPCExecutor { _ context: ServerContext ) async throws -> StreamingServerResponse ) async { - await withTaskGroup(of: ServerExecutorTask.self) { group in + await withTaskGroup(of: Void.self) { group in group.addTask { - let result = await Result { + do { try await Task.sleep(for: timeout, clock: .continuous) + context.cancellation.cancel() + } catch { + () // Only cancel the RPC if the timeout completes. } - return .timedOut(result) } - group.addTask { - await Self._processRPC( - context: context, - metadata: metadata, - inbound: inbound, - outbound: outbound, - deserializer: deserializer, - serializer: serializer, - interceptors: interceptors, - handler: handler - ) - return .executed - } - - while let next = await group.next() { - switch next { - case .timedOut(.success): - // Timeout expired; cancel the work. - group.cancelAll() - - case .timedOut(.failure): - // Timeout failed (because it was cancelled). Wait for more tasks to finish. - () + await Self._processRPC( + context: context, + metadata: metadata, + inbound: inbound, + outbound: outbound, + deserializer: deserializer, + serializer: serializer, + interceptors: interceptors, + handler: handler + ) - case .executed: - // The work finished. Cancel any remaining tasks. - group.cancelAll() - } - } + // Cancel the timeout + group.cancelAll() } } diff --git a/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift new file mode 100644 index 000000000..90f609df9 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift @@ -0,0 +1,115 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Synchronization + +extension ServerContext { + @TaskLocal + internal static var rpcCancellation: RPCCancellationHandle? + + /// A handle for the cancellation status of the RPC. + public struct RPCCancellationHandle: Sendable { + internal let manager: ServerCancellationManager + + /// Create a cancellation handle. + /// + /// To create an instance of this handle appropriately bound to a `Task` + /// use ``withServerContextRPCCancellationHandle(_:)``. + public init() { + self.manager = ServerCancellationManager() + } + + /// Returns whether the RPC has been cancelled. + public var isCancelled: Bool { + self.manager.isRPCCancelled + } + + /// Waits until the RPC has been cancelled. + /// + /// Throws a `CancellationError` if the `Task` is cancelled. + /// + /// You can also be notified when an RPC is cancelled by using + /// ``withRPCCancellationHandler(operation:onCancelRPC:)``. + public var cancelled: Void { + get async throws { + try await self.manager.suspendUntilRPCIsCancelled() + } + } + + /// Signal that the RPC should be cancelled. + /// + /// This is idempotent: calling it more than once has no effect. + public func cancel() { + self.manager.cancelRPC() + } + } +} + +/// Execute an operation with an RPC cancellation handler that's immediately invoked +/// if the RPC is canceled. +/// +/// RPCs can be cancelled for a number of reasons including: +/// 1. The RPC was taking too long to process and a timeout passed. +/// 2. The remote peer closed the underlying stream, either because they were no longer +/// interested in the result or due to a broken connection. +/// 3. The server began shutting down. +/// +/// - Important: This only applies to RPCs on the server. +/// - Parameters: +/// - operation: The operation to execute. +/// - handler: The handler which is invoked when the RPC is cancelled. +/// - Throws: Any error thrown by the `operation` closure. +/// - Returns: The result of the `operation` closure. +public func withRPCCancellationHandler( + operation: () async throws(Failure) -> Result, + onCancelRPC handler: @Sendable @escaping () -> Void +) async throws(Failure) -> Result { + guard let manager = ServerContext.rpcCancellation?.manager, + let id = manager.addRPCCancelledHandler(handler) + else { + return try await operation() + } + + defer { + manager.removeRPCCancelledHandler(withID: id) + } + + return try await operation() +} + +/// Provides scoped access to a server RPC cancellation handle. +/// +/// The cancellation handle should be passed to a ``ServerContext`` and last +/// the duration of the RPC. This function is intended for use when implementing +/// a ``ServerTransport``. +/// +/// If you want to be notified about RPCs being cancelled +/// use ``withRPCCancellationHandler(operation:onCancelRPC:)``. +/// +/// - Parameter operation: The operation to execute with the handle. +public func withServerContextRPCCancellationHandle( + _ operation: (ServerContext.RPCCancellationHandle) async throws(Failure) -> Success +) async throws(Failure) -> Success { + let handle = ServerContext.RPCCancellationHandle() + let result = await ServerContext.$rpcCancellation.withValue(handle) { + // Wrap up the outcome in a result as 'withValue' doesn't support typed throws. + return await Swift.Result { () async throws(Failure) -> Success in + return try await operation(handle) + } + } + + return try result.get() +} diff --git a/Sources/GRPCCore/Call/Server/ServerContext.swift b/Sources/GRPCCore/Call/Server/ServerContext.swift index a11f09acb..4d8613f93 100644 --- a/Sources/GRPCCore/Call/Server/ServerContext.swift +++ b/Sources/GRPCCore/Call/Server/ServerContext.swift @@ -19,8 +19,17 @@ public struct ServerContext: Sendable { /// A description of the method being called. public var descriptor: MethodDescriptor + /// A handle for checking the cancellation status of an RPC. + public var cancellation: RPCCancellationHandle + /// Create a new server context. - public init(descriptor: MethodDescriptor) { + /// + /// - Parameters: + /// - descriptor: A description of the method being called. + /// - cancellation: A cancellation handle. You can create a cancellation handle + /// using ``withServerContextRPCCancellationHandle(_:)``. + public init(descriptor: MethodDescriptor, cancellation: RPCCancellationHandle) { self.descriptor = descriptor + self.cancellation = cancellation } } diff --git a/Sources/GRPCCore/Internal/Result+Catching.swift b/Sources/GRPCCore/Internal/Result+Catching.swift index bf2393752..8f9cbe59c 100644 --- a/Sources/GRPCCore/Internal/Result+Catching.swift +++ b/Sources/GRPCCore/Internal/Result+Catching.swift @@ -14,12 +14,12 @@ * limitations under the License. */ -extension Result where Failure == any Error { +extension Result { /// Like `Result(catching:)`, but `async`. /// /// - Parameter body: An `async` closure to catch the result of. @inlinable - init(catching body: () async throws -> Success) async { + init(catching body: () async throws(Failure) -> Success) async { do { self = .success(try await body()) } catch { diff --git a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift index 66e32d06f..c5f0a55fc 100644 --- a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift +++ b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift @@ -15,6 +15,7 @@ */ public import GRPCCore +private import Synchronization extension InProcessTransport { /// An in-process implementation of a ``ServerTransport``. @@ -27,16 +28,54 @@ extension InProcessTransport { /// To stop listening to new requests, call ``beginGracefulShutdown()``. /// /// - SeeAlso: ``ClientTransport`` - public struct Server: ServerTransport, Sendable { + public final class Server: ServerTransport, Sendable { public typealias Inbound = RPCAsyncSequence public typealias Outbound = RPCWriter.Closable private let newStreams: AsyncStream> private let newStreamsContinuation: AsyncStream>.Continuation + private struct State: Sendable { + private var id: UInt64 + private var handles: [UInt64: ServerContext.RPCCancellationHandle] + private var isShutdown: Bool + + private mutating func nextID() -> UInt64 { + let id = self.id + self.id &+= 1 + return id + } + + init() { + self.id = 0 + self.handles = [:] + self.isShutdown = false + } + + mutating func addHandle(_ handle: ServerContext.RPCCancellationHandle) -> (UInt64, Bool) { + let handleID = self.nextID() + self.handles[handleID] = handle + return (handleID, self.isShutdown) + } + + mutating func removeHandle(withID id: UInt64) { + self.handles.removeValue(forKey: id) + } + + mutating func beginShutdown() -> [ServerContext.RPCCancellationHandle] { + self.isShutdown = true + let values = Array(self.handles.values) + self.handles.removeAll() + return values + } + } + + private let handles: Mutex + /// Creates a new instance of ``Server``. public init() { (self.newStreams, self.newStreamsContinuation) = AsyncStream.makeStream() + self.handles = Mutex(State()) } /// Publish a new ``RPCStream``, which will be returned by the transport's ``events`` @@ -64,8 +103,21 @@ extension InProcessTransport { await withDiscardingTaskGroup { group in for await stream in self.newStreams { group.addTask { - let context = ServerContext(descriptor: stream.descriptor) - await streamHandler(stream, context) + await withServerContextRPCCancellationHandle { handle in + let (id, isShutdown) = self.handles.withLock({ $0.addHandle(handle) }) + defer { + self.handles.withLock { $0.removeHandle(withID: id) } + } + + // This happens if `beginGracefulShutdown` is called after the stream is added to + // new streams but before it's dequeued. + if isShutdown { + handle.cancel() + } + + let context = ServerContext(descriptor: stream.descriptor, cancellation: handle) + await streamHandler(stream, context) + } } } } @@ -76,6 +128,9 @@ extension InProcessTransport { /// - SeeAlso: ``ServerTransport`` public func beginGracefulShutdown() { self.newStreamsContinuation.finish() + for handle in self.handles.withLock({ $0.beginShutdown() }) { + handle.cancel() + } } } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift new file mode 100644 index 000000000..45c851de8 --- /dev/null +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift @@ -0,0 +1,91 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Testing + +@Suite +struct ServerCancellationManagerTests { + @Test("Isn't cancelled after init") + func isNotCancelled() { + let manager = ServerCancellationManager() + #expect(!manager.isRPCCancelled) + } + + @Test("Is cancelled") + func isCancelled() { + let manager = ServerCancellationManager() + manager.cancelRPC() + #expect(manager.isRPCCancelled) + } + + @Test("Cancellation handler runs") + func addCancellationHandler() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + + let id = manager.addRPCCancelledHandler { + signal.continuation.finish() + } + + #expect(id != nil) + manager.cancelRPC() + let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) } + #expect(events.isEmpty) + } + + @Test("Cancellation handler runs immediately when already cancelled") + func addCancellationHandlerAfterCancelled() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + manager.cancelRPC() + + let id = manager.addRPCCancelledHandler { + signal.continuation.finish() + } + + #expect(id == nil) + let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) } + #expect(events.isEmpty) + } + + @Test("Remove cancellation handler") + func removeCancellationHandler() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + + let id = manager.addRPCCancelledHandler { + Issue.record("Unexpected cancellation") + } + + #expect(id != nil) + manager.removeRPCCancelledHandler(withID: id!) + manager.cancelRPC() + } + + @Test("Wait for cancellation") + func waitForCancellation() async throws { + let manager = ServerCancellationManager() + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await manager.suspendUntilRPCIsCancelled() + } + + manager.cancelRPC() + try await group.waitForAll() + } + } +} diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift index 6cca2d4d2..e645c5c20 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift @@ -20,24 +20,29 @@ import XCTest struct ServerRPCExecutorTestHarness { struct ServerHandler: Sendable { let fn: - @Sendable (StreamingServerRequest) async throws -> StreamingServerResponse + @Sendable ( + _ request: StreamingServerRequest, + _ context: ServerContext + ) async throws -> StreamingServerResponse init( _ fn: @escaping @Sendable ( - StreamingServerRequest + _ request: StreamingServerRequest, + _ context: ServerContext ) async throws -> StreamingServerResponse ) { self.fn = fn } func handle( - _ request: StreamingServerRequest + _ request: StreamingServerRequest, + _ context: ServerContext ) async throws -> StreamingServerResponse { - try await self.fn(request) + try await self.fn(request, context) } static func throwing(_ error: any Error) -> Self { - return Self { _ in throw error } + return Self { _, _ in throw error } } } @@ -51,7 +56,8 @@ struct ServerRPCExecutorTestHarness { deserializer: some MessageDeserializer, serializer: some MessageSerializer, handler: @escaping @Sendable ( - StreamingServerRequest + StreamingServerRequest, + ServerContext ) async throws -> StreamingServerResponse, producer: @escaping @Sendable ( RPCWriter.Closable @@ -93,21 +99,27 @@ struct ServerRPCExecutorTestHarness { } group.addTask { - let context = ServerContext(descriptor: MethodDescriptor(service: "foo", method: "bar")) - await ServerRPCExecutor.execute( - context: context, - stream: RPCStream( - descriptor: context.descriptor, - inbound: RPCAsyncSequence(wrapping: input.stream), - outbound: RPCWriter.Closable(wrapping: output.continuation) - ), - deserializer: deserializer, - serializer: serializer, - interceptors: self.interceptors, - handler: { stream, context in - try await handler.handle(stream) - } - ) + await withServerContextRPCCancellationHandle { cancellation in + let context = ServerContext( + descriptor: MethodDescriptor(service: "foo", method: "bar"), + cancellation: cancellation + ) + + await ServerRPCExecutor.execute( + context: context, + stream: RPCStream( + descriptor: context.descriptor, + inbound: RPCAsyncSequence(wrapping: input.stream), + outbound: RPCWriter.Closable(wrapping: output.continuation) + ), + deserializer: deserializer, + serializer: serializer, + interceptors: self.interceptors, + handler: { stream, context in + try await handler.handle(stream, context) + } + ) + } } try await group.waitForAll() @@ -135,7 +147,7 @@ struct ServerRPCExecutorTestHarness { extension ServerRPCExecutorTestHarness.ServerHandler where Input == Output { static var echo: Self { - return Self { request in + return Self { request, context in return StreamingServerResponse(metadata: request.metadata) { writer in try await writer.write(contentsOf: request.messages) return [:] diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index f047955da..0533fe26b 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -83,7 +83,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello"]) return StreamingServerResponse(metadata: request.metadata) { writer in @@ -112,7 +112,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello", "world"]) return StreamingServerResponse(metadata: request.metadata) { writer in @@ -144,7 +144,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in return StreamingServerResponse(metadata: request.metadata) { _ in return ["bar": "baz"] } @@ -235,15 +235,9 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in - do { - try await Task.sleep(until: .now.advanced(by: .seconds(180)), clock: .continuous) - } catch is CancellationError { - throw RPCError(code: .cancelled, message: "Sleep was cancelled") - } - - XCTFail("Server handler should've been cancelled by timeout.") - return StreamingServerResponse(error: RPCError(code: .failedPrecondition, message: "")) + ) { request, context in + try await context.cancellation.cancelled + throw RPCError(code: .cancelled, message: "Cancelled from server handler") } producer: { inbound in try await inbound.write(.metadata(["grpc-timeout": "1000n"])) await inbound.finish() @@ -251,7 +245,7 @@ final class ServerRPCExecutorTests: XCTestCase { let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in XCTAssertEqual(status.code, .cancelled) - XCTAssertEqual(status.message, "Sleep was cancelled") + XCTAssertEqual(status.message, "Cancelled from server handler") } } } @@ -268,7 +262,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in XCTFail("Unexpected request") return StreamingServerResponse( of: [UInt8].self, diff --git a/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift b/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift new file mode 100644 index 000000000..c524519ff --- /dev/null +++ b/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift @@ -0,0 +1,62 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Testing + +@Suite("ServerContext") +struct ServerContextTests { + @Suite("CancellationHandle") + struct CancellationHandle { + @Test("Is cancelled") + func isCancelled() async throws { + await withServerContextRPCCancellationHandle { handle in + #expect(!handle.isCancelled) + handle.cancel() + #expect(handle.isCancelled) + } + } + + @Test("Wait for cancellation") + func waitForCancellation() async throws { + await withServerContextRPCCancellationHandle { handle in + await withTaskGroup(of: Void.self) { group in + group.addTask { + try? await handle.cancelled + } + handle.cancel() + await group.waitForAll() + } + } + } + + @Test("Binds task local") + func bindsTaskLocal() async throws { + await withServerContextRPCCancellationHandle { handle in + let signal = AsyncStream.makeStream(of: Void.self) + + await withRPCCancellationHandler { + handle.cancel() + for await _ in signal.stream {} + } onCancelRPC: { + // If the task local wasn't bound, this wouldn't run. + signal.continuation.finish() + } + } + + } + } +} diff --git a/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift b/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift new file mode 100644 index 000000000..3396b259f --- /dev/null +++ b/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift @@ -0,0 +1,125 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import GRPCInProcessTransport +import Testing + +@Suite("InProcess transport") +struct InProcessTransportTests { + private static let cancellationModes = ["await-cancelled", "with-cancellation-handler"] + + private func withTestServerAndClient( + execute: (GRPCServer, GRPCClient) async throws -> Void + ) async throws { + try await withThrowingDiscardingTaskGroup { group in + let inProcess = InProcessTransport() + + let server = GRPCServer(transport: inProcess.server, services: [TestService()]) + group.addTask { + try await server.serve() + } + + let client = GRPCClient(transport: inProcess.client) + group.addTask { + try await client.run() + } + + try await execute(server, client) + } + } + + @Test("RPC cancelled by graceful shutdown", arguments: Self.cancellationModes) + func cancelledByGracefulShutdown(mode: String) async throws { + try await self.withTestServerAndClient { server, client in + try await client.serverStreaming( + request: ClientRequest(message: mode), + descriptor: .testCancellation, + serializer: UTF8Serializer(), + deserializer: UTF8Deserializer(), + options: .defaults + ) { response in + // Got initial metadata, begin shutdown to cancel the RPC. + server.beginGracefulShutdown() + + // Now wait for the response. + let messages = try await response.messages.reduce(into: []) { $0.append($1) } + #expect(messages == ["isCancelled=true"]) + } + + // Finally, shutdown the client so its run() method returns. + client.beginGracefulShutdown() + } + } +} + +private struct TestService: RegistrableRPCService { + func cancellation( + request: ServerRequest, + context: ServerContext + ) async throws -> StreamingServerResponse { + switch request.message { + case "await-cancelled": + return StreamingServerResponse { body in + try await context.cancellation.cancelled + try await body.write("isCancelled=\(context.cancellation.isCancelled)") + return [:] + } + + case "with-cancellation-handler": + let signal = AsyncStream.makeStream(of: Void.self) + return StreamingServerResponse { body in + try await withRPCCancellationHandler { + for await _ in signal.stream {} + try await body.write("isCancelled=\(context.cancellation.isCancelled)") + return [:] + } onCancelRPC: { + signal.continuation.finish() + } + } + + default: + throw RPCError(code: .invalidArgument, message: "Invalid argument '\(request.message)'") + } + } + + func registerMethods(with router: inout RPCRouter) { + router.registerHandler( + forMethod: .testCancellation, + deserializer: UTF8Deserializer(), + serializer: UTF8Serializer(), + handler: { + try await self.cancellation(request: ServerRequest(stream: $0), context: $1) + } + ) + } +} + +extension MethodDescriptor { + fileprivate static let testCancellation = Self(service: "test", method: "cancellation") +} + +private struct UTF8Serializer: MessageSerializer { + func serialize(_ message: String) throws -> [UInt8] { + Array(message.utf8) + } +} + +private struct UTF8Deserializer: MessageDeserializer { + func deserialize(_ serializedMessageBytes: [UInt8]) throws -> String { + String(decoding: serializedMessageBytes, as: UTF8.self) + } +} From 5d58b1f55d2c7c2ad2bbd0e355b6437c01601130 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Wed, 9 Oct 2024 12:47:09 +0100 Subject: [PATCH 2/3] Update doc --- .../Call/Server/ServerContext+RPCCancellationHandle.swift | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift index 90f609df9..5e0f63367 100644 --- a/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift +++ b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift @@ -93,8 +93,10 @@ public func withRPCCancellationHandler( /// Provides scoped access to a server RPC cancellation handle. /// /// The cancellation handle should be passed to a ``ServerContext`` and last -/// the duration of the RPC. This function is intended for use when implementing -/// a ``ServerTransport``. +/// the duration of the RPC. +/// +/// - Important: This function is intended for use when implementing +/// a ``ServerTransport``. /// /// If you want to be notified about RPCs being cancelled /// use ``withRPCCancellationHandler(operation:onCancelRPC:)``. From df15ef81d3d25858da6215b3c7848eba0e61aecf Mon Sep 17 00:00:00 2001 From: George Barnett Date: Wed, 9 Oct 2024 12:47:56 +0100 Subject: [PATCH 3/3] Improve name --- .../InProcessTransport+Server.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift index c5f0a55fc..02b132ac8 100644 --- a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift +++ b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift @@ -36,18 +36,18 @@ extension InProcessTransport { private let newStreamsContinuation: AsyncStream>.Continuation private struct State: Sendable { - private var id: UInt64 + private var _nextID: UInt64 private var handles: [UInt64: ServerContext.RPCCancellationHandle] private var isShutdown: Bool private mutating func nextID() -> UInt64 { - let id = self.id - self.id &+= 1 + let id = self._nextID + self._nextID &+= 1 return id } init() { - self.id = 0 + self._nextID = 0 self.handles = [:] self.isShutdown = false }