diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift b/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift new file mode 100644 index 000000000..471f9d007 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/Internal/ServerCancellationManager.swift @@ -0,0 +1,254 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +private import Synchronization + +/// Stores cancellation state for an RPC on the server . +package final class ServerCancellationManager: Sendable { + private let state: Mutex + + package init() { + self.state = Mutex(State()) + } + + /// Returns whether the RPC has been marked as cancelled. + package var isRPCCancelled: Bool { + self.state.withLock { + return $0.isRPCCancelled + } + } + + /// Marks the RPC as cancelled, potentially running any cancellation handlers. + package func cancelRPC() { + switch self.state.withLock({ $0.cancelRPC() }) { + case .executeAndResume(let onCancelHandlers, let onCancelWaiters): + for handler in onCancelHandlers { + handler.handler() + } + + for onCancelWaiter in onCancelWaiters { + switch onCancelWaiter { + case .taskCancelled: + () + case .waiting(_, let continuation): + continuation.resume(returning: .rpc) + } + } + + case .doNothing: + () + } + } + + /// Adds a handler which is invoked when the RPC is cancelled. + /// + /// - Returns: The ID of the handler, if it was added, or `nil` if the RPC is already cancelled. + package func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? { + return self.state.withLock { state -> UInt64? in + state.addRPCCancelledHandler(handler) + } + } + + /// Removes a handler by its ID. + package func removeRPCCancelledHandler(withID id: UInt64) { + self.state.withLock { state in + state.removeRPCCancelledHandler(withID: id) + } + } + + /// Suspends until the RPC is cancelled or the `Task` is cancelled. + package func suspendUntilRPCIsCancelled() async throws(CancellationError) { + let id = self.state.withLock { $0.nextID() } + + let source = await withTaskCancellationHandler { + await withCheckedContinuation { continuation in + let onAddWaiter = self.state.withLock { + $0.addRPCIsCancelledWaiter(continuation: continuation, withID: id) + } + + switch onAddWaiter { + case .doNothing: + () + case .complete(let continuation, let result): + continuation.resume(returning: result) + } + } + } onCancel: { + switch self.state.withLock({ $0.cancelRPCCancellationWaiter(withID: id) }) { + case .resume(let continuation, let result): + continuation.resume(returning: result) + case .doNothing: + () + } + } + + switch source { + case .rpc: + () + case .task: + throw CancellationError() + } + } +} + +extension ServerCancellationManager { + enum CancellationSource { + case rpc + case task + } + + struct Handler: Sendable { + var id: UInt64 + var handler: @Sendable () -> Void + } + + enum Waiter: Sendable { + case waiting(UInt64, CheckedContinuation) + case taskCancelled(UInt64) + + var id: UInt64 { + switch self { + case .waiting(let id, _): + return id + case .taskCancelled(let id): + return id + } + } + } + + struct State { + private var handlers: [Handler] + private var waiters: [Waiter] + private var _nextID: UInt64 + var isRPCCancelled: Bool + + mutating func nextID() -> UInt64 { + let id = self._nextID + self._nextID &+= 1 + return id + } + + init() { + self.handlers = [] + self.waiters = [] + self._nextID = 0 + self.isRPCCancelled = false + } + + mutating func cancelRPC() -> OnCancelRPC { + let onCancel: OnCancelRPC + + if self.isRPCCancelled { + onCancel = .doNothing + } else { + self.isRPCCancelled = true + onCancel = .executeAndResume(self.handlers, self.waiters) + self.handlers = [] + self.waiters = [] + } + + return onCancel + } + + mutating func addRPCCancelledHandler(_ handler: @Sendable @escaping () -> Void) -> UInt64? { + if self.isRPCCancelled { + handler() + return nil + } else { + let id = self.nextID() + self.handlers.append(.init(id: id, handler: handler)) + return id + } + } + + mutating func removeRPCCancelledHandler(withID id: UInt64) { + if let index = self.handlers.firstIndex(where: { $0.id == id }) { + self.handlers.remove(at: index) + } + } + + enum OnCancelRPC { + case executeAndResume([Handler], [Waiter]) + case doNothing + } + + enum OnAddWaiter { + case complete(CheckedContinuation, CancellationSource) + case doNothing + } + + mutating func addRPCIsCancelledWaiter( + continuation: CheckedContinuation, + withID id: UInt64 + ) -> OnAddWaiter { + let onAddWaiter: OnAddWaiter + + if self.isRPCCancelled { + onAddWaiter = .complete(continuation, .rpc) + } else if let index = self.waiters.firstIndex(where: { $0.id == id }) { + switch self.waiters[index] { + case .taskCancelled: + onAddWaiter = .complete(continuation, .task) + case .waiting: + // There's already a continuation enqueued. + fatalError("Inconsistent state") + } + } else { + self.waiters.append(.waiting(id, continuation)) + onAddWaiter = .doNothing + } + + return onAddWaiter + } + + enum OnCancelRPCCancellationWaiter { + case resume(CheckedContinuation, CancellationSource) + case doNothing + } + + mutating func cancelRPCCancellationWaiter(withID id: UInt64) -> OnCancelRPCCancellationWaiter { + let onCancelWaiter: OnCancelRPCCancellationWaiter + + if let index = self.waiters.firstIndex(where: { $0.id == id }) { + let waiter = self.waiters.removeWithoutMaintainingOrder(at: index) + switch waiter { + case .taskCancelled: + onCancelWaiter = .doNothing + case .waiting(_, let continuation): + onCancelWaiter = .resume(continuation, .task) + } + } else { + self.waiters.append(.taskCancelled(id)) + onCancelWaiter = .doNothing + } + + return onCancelWaiter + } + } +} + +extension Array { + fileprivate mutating func removeWithoutMaintainingOrder(at index: Int) -> Element { + let lastElementIndex = self.index(before: self.endIndex) + + if index == lastElementIndex { + return self.remove(at: index) + } else { + self.swapAt(index, lastElementIndex) + return self.removeLast() + } + } +} diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index f85cbe318..d9a35da51 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -119,43 +119,29 @@ struct ServerRPCExecutor { _ context: ServerContext ) async throws -> StreamingServerResponse ) async { - await withTaskGroup(of: ServerExecutorTask.self) { group in + await withTaskGroup(of: Void.self) { group in group.addTask { - let result = await Result { + do { try await Task.sleep(for: timeout, clock: .continuous) + context.cancellation.cancel() + } catch { + () // Only cancel the RPC if the timeout completes. } - return .timedOut(result) } - group.addTask { - await Self._processRPC( - context: context, - metadata: metadata, - inbound: inbound, - outbound: outbound, - deserializer: deserializer, - serializer: serializer, - interceptors: interceptors, - handler: handler - ) - return .executed - } - - while let next = await group.next() { - switch next { - case .timedOut(.success): - // Timeout expired; cancel the work. - group.cancelAll() - - case .timedOut(.failure): - // Timeout failed (because it was cancelled). Wait for more tasks to finish. - () + await Self._processRPC( + context: context, + metadata: metadata, + inbound: inbound, + outbound: outbound, + deserializer: deserializer, + serializer: serializer, + interceptors: interceptors, + handler: handler + ) - case .executed: - // The work finished. Cancel any remaining tasks. - group.cancelAll() - } - } + // Cancel the timeout + group.cancelAll() } } diff --git a/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift new file mode 100644 index 000000000..5e0f63367 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/ServerContext+RPCCancellationHandle.swift @@ -0,0 +1,117 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Synchronization + +extension ServerContext { + @TaskLocal + internal static var rpcCancellation: RPCCancellationHandle? + + /// A handle for the cancellation status of the RPC. + public struct RPCCancellationHandle: Sendable { + internal let manager: ServerCancellationManager + + /// Create a cancellation handle. + /// + /// To create an instance of this handle appropriately bound to a `Task` + /// use ``withServerContextRPCCancellationHandle(_:)``. + public init() { + self.manager = ServerCancellationManager() + } + + /// Returns whether the RPC has been cancelled. + public var isCancelled: Bool { + self.manager.isRPCCancelled + } + + /// Waits until the RPC has been cancelled. + /// + /// Throws a `CancellationError` if the `Task` is cancelled. + /// + /// You can also be notified when an RPC is cancelled by using + /// ``withRPCCancellationHandler(operation:onCancelRPC:)``. + public var cancelled: Void { + get async throws { + try await self.manager.suspendUntilRPCIsCancelled() + } + } + + /// Signal that the RPC should be cancelled. + /// + /// This is idempotent: calling it more than once has no effect. + public func cancel() { + self.manager.cancelRPC() + } + } +} + +/// Execute an operation with an RPC cancellation handler that's immediately invoked +/// if the RPC is canceled. +/// +/// RPCs can be cancelled for a number of reasons including: +/// 1. The RPC was taking too long to process and a timeout passed. +/// 2. The remote peer closed the underlying stream, either because they were no longer +/// interested in the result or due to a broken connection. +/// 3. The server began shutting down. +/// +/// - Important: This only applies to RPCs on the server. +/// - Parameters: +/// - operation: The operation to execute. +/// - handler: The handler which is invoked when the RPC is cancelled. +/// - Throws: Any error thrown by the `operation` closure. +/// - Returns: The result of the `operation` closure. +public func withRPCCancellationHandler( + operation: () async throws(Failure) -> Result, + onCancelRPC handler: @Sendable @escaping () -> Void +) async throws(Failure) -> Result { + guard let manager = ServerContext.rpcCancellation?.manager, + let id = manager.addRPCCancelledHandler(handler) + else { + return try await operation() + } + + defer { + manager.removeRPCCancelledHandler(withID: id) + } + + return try await operation() +} + +/// Provides scoped access to a server RPC cancellation handle. +/// +/// The cancellation handle should be passed to a ``ServerContext`` and last +/// the duration of the RPC. +/// +/// - Important: This function is intended for use when implementing +/// a ``ServerTransport``. +/// +/// If you want to be notified about RPCs being cancelled +/// use ``withRPCCancellationHandler(operation:onCancelRPC:)``. +/// +/// - Parameter operation: The operation to execute with the handle. +public func withServerContextRPCCancellationHandle( + _ operation: (ServerContext.RPCCancellationHandle) async throws(Failure) -> Success +) async throws(Failure) -> Success { + let handle = ServerContext.RPCCancellationHandle() + let result = await ServerContext.$rpcCancellation.withValue(handle) { + // Wrap up the outcome in a result as 'withValue' doesn't support typed throws. + return await Swift.Result { () async throws(Failure) -> Success in + return try await operation(handle) + } + } + + return try result.get() +} diff --git a/Sources/GRPCCore/Call/Server/ServerContext.swift b/Sources/GRPCCore/Call/Server/ServerContext.swift index a11f09acb..4d8613f93 100644 --- a/Sources/GRPCCore/Call/Server/ServerContext.swift +++ b/Sources/GRPCCore/Call/Server/ServerContext.swift @@ -19,8 +19,17 @@ public struct ServerContext: Sendable { /// A description of the method being called. public var descriptor: MethodDescriptor + /// A handle for checking the cancellation status of an RPC. + public var cancellation: RPCCancellationHandle + /// Create a new server context. - public init(descriptor: MethodDescriptor) { + /// + /// - Parameters: + /// - descriptor: A description of the method being called. + /// - cancellation: A cancellation handle. You can create a cancellation handle + /// using ``withServerContextRPCCancellationHandle(_:)``. + public init(descriptor: MethodDescriptor, cancellation: RPCCancellationHandle) { self.descriptor = descriptor + self.cancellation = cancellation } } diff --git a/Sources/GRPCCore/Internal/Result+Catching.swift b/Sources/GRPCCore/Internal/Result+Catching.swift index bf2393752..8f9cbe59c 100644 --- a/Sources/GRPCCore/Internal/Result+Catching.swift +++ b/Sources/GRPCCore/Internal/Result+Catching.swift @@ -14,12 +14,12 @@ * limitations under the License. */ -extension Result where Failure == any Error { +extension Result { /// Like `Result(catching:)`, but `async`. /// /// - Parameter body: An `async` closure to catch the result of. @inlinable - init(catching body: () async throws -> Success) async { + init(catching body: () async throws(Failure) -> Success) async { do { self = .success(try await body()) } catch { diff --git a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift index 66e32d06f..02b132ac8 100644 --- a/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift +++ b/Sources/GRPCInProcessTransport/InProcessTransport+Server.swift @@ -15,6 +15,7 @@ */ public import GRPCCore +private import Synchronization extension InProcessTransport { /// An in-process implementation of a ``ServerTransport``. @@ -27,16 +28,54 @@ extension InProcessTransport { /// To stop listening to new requests, call ``beginGracefulShutdown()``. /// /// - SeeAlso: ``ClientTransport`` - public struct Server: ServerTransport, Sendable { + public final class Server: ServerTransport, Sendable { public typealias Inbound = RPCAsyncSequence public typealias Outbound = RPCWriter.Closable private let newStreams: AsyncStream> private let newStreamsContinuation: AsyncStream>.Continuation + private struct State: Sendable { + private var _nextID: UInt64 + private var handles: [UInt64: ServerContext.RPCCancellationHandle] + private var isShutdown: Bool + + private mutating func nextID() -> UInt64 { + let id = self._nextID + self._nextID &+= 1 + return id + } + + init() { + self._nextID = 0 + self.handles = [:] + self.isShutdown = false + } + + mutating func addHandle(_ handle: ServerContext.RPCCancellationHandle) -> (UInt64, Bool) { + let handleID = self.nextID() + self.handles[handleID] = handle + return (handleID, self.isShutdown) + } + + mutating func removeHandle(withID id: UInt64) { + self.handles.removeValue(forKey: id) + } + + mutating func beginShutdown() -> [ServerContext.RPCCancellationHandle] { + self.isShutdown = true + let values = Array(self.handles.values) + self.handles.removeAll() + return values + } + } + + private let handles: Mutex + /// Creates a new instance of ``Server``. public init() { (self.newStreams, self.newStreamsContinuation) = AsyncStream.makeStream() + self.handles = Mutex(State()) } /// Publish a new ``RPCStream``, which will be returned by the transport's ``events`` @@ -64,8 +103,21 @@ extension InProcessTransport { await withDiscardingTaskGroup { group in for await stream in self.newStreams { group.addTask { - let context = ServerContext(descriptor: stream.descriptor) - await streamHandler(stream, context) + await withServerContextRPCCancellationHandle { handle in + let (id, isShutdown) = self.handles.withLock({ $0.addHandle(handle) }) + defer { + self.handles.withLock { $0.removeHandle(withID: id) } + } + + // This happens if `beginGracefulShutdown` is called after the stream is added to + // new streams but before it's dequeued. + if isShutdown { + handle.cancel() + } + + let context = ServerContext(descriptor: stream.descriptor, cancellation: handle) + await streamHandler(stream, context) + } } } } @@ -76,6 +128,9 @@ extension InProcessTransport { /// - SeeAlso: ``ServerTransport`` public func beginGracefulShutdown() { self.newStreamsContinuation.finish() + for handle in self.handles.withLock({ $0.beginShutdown() }) { + handle.cancel() + } } } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift new file mode 100644 index 000000000..45c851de8 --- /dev/null +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerCancellationManagerTests.swift @@ -0,0 +1,91 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Testing + +@Suite +struct ServerCancellationManagerTests { + @Test("Isn't cancelled after init") + func isNotCancelled() { + let manager = ServerCancellationManager() + #expect(!manager.isRPCCancelled) + } + + @Test("Is cancelled") + func isCancelled() { + let manager = ServerCancellationManager() + manager.cancelRPC() + #expect(manager.isRPCCancelled) + } + + @Test("Cancellation handler runs") + func addCancellationHandler() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + + let id = manager.addRPCCancelledHandler { + signal.continuation.finish() + } + + #expect(id != nil) + manager.cancelRPC() + let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) } + #expect(events.isEmpty) + } + + @Test("Cancellation handler runs immediately when already cancelled") + func addCancellationHandlerAfterCancelled() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + manager.cancelRPC() + + let id = manager.addRPCCancelledHandler { + signal.continuation.finish() + } + + #expect(id == nil) + let events: [Void] = await signal.stream.reduce(into: []) { $0.append($1) } + #expect(events.isEmpty) + } + + @Test("Remove cancellation handler") + func removeCancellationHandler() async throws { + let manager = ServerCancellationManager() + let signal = AsyncStream.makeStream(of: Void.self) + + let id = manager.addRPCCancelledHandler { + Issue.record("Unexpected cancellation") + } + + #expect(id != nil) + manager.removeRPCCancelledHandler(withID: id!) + manager.cancelRPC() + } + + @Test("Wait for cancellation") + func waitForCancellation() async throws { + let manager = ServerCancellationManager() + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await manager.suspendUntilRPCIsCancelled() + } + + manager.cancelRPC() + try await group.waitForAll() + } + } +} diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift index 6cca2d4d2..e645c5c20 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift @@ -20,24 +20,29 @@ import XCTest struct ServerRPCExecutorTestHarness { struct ServerHandler: Sendable { let fn: - @Sendable (StreamingServerRequest) async throws -> StreamingServerResponse + @Sendable ( + _ request: StreamingServerRequest, + _ context: ServerContext + ) async throws -> StreamingServerResponse init( _ fn: @escaping @Sendable ( - StreamingServerRequest + _ request: StreamingServerRequest, + _ context: ServerContext ) async throws -> StreamingServerResponse ) { self.fn = fn } func handle( - _ request: StreamingServerRequest + _ request: StreamingServerRequest, + _ context: ServerContext ) async throws -> StreamingServerResponse { - try await self.fn(request) + try await self.fn(request, context) } static func throwing(_ error: any Error) -> Self { - return Self { _ in throw error } + return Self { _, _ in throw error } } } @@ -51,7 +56,8 @@ struct ServerRPCExecutorTestHarness { deserializer: some MessageDeserializer, serializer: some MessageSerializer, handler: @escaping @Sendable ( - StreamingServerRequest + StreamingServerRequest, + ServerContext ) async throws -> StreamingServerResponse, producer: @escaping @Sendable ( RPCWriter.Closable @@ -93,21 +99,27 @@ struct ServerRPCExecutorTestHarness { } group.addTask { - let context = ServerContext(descriptor: MethodDescriptor(service: "foo", method: "bar")) - await ServerRPCExecutor.execute( - context: context, - stream: RPCStream( - descriptor: context.descriptor, - inbound: RPCAsyncSequence(wrapping: input.stream), - outbound: RPCWriter.Closable(wrapping: output.continuation) - ), - deserializer: deserializer, - serializer: serializer, - interceptors: self.interceptors, - handler: { stream, context in - try await handler.handle(stream) - } - ) + await withServerContextRPCCancellationHandle { cancellation in + let context = ServerContext( + descriptor: MethodDescriptor(service: "foo", method: "bar"), + cancellation: cancellation + ) + + await ServerRPCExecutor.execute( + context: context, + stream: RPCStream( + descriptor: context.descriptor, + inbound: RPCAsyncSequence(wrapping: input.stream), + outbound: RPCWriter.Closable(wrapping: output.continuation) + ), + deserializer: deserializer, + serializer: serializer, + interceptors: self.interceptors, + handler: { stream, context in + try await handler.handle(stream, context) + } + ) + } } try await group.waitForAll() @@ -135,7 +147,7 @@ struct ServerRPCExecutorTestHarness { extension ServerRPCExecutorTestHarness.ServerHandler where Input == Output { static var echo: Self { - return Self { request in + return Self { request, context in return StreamingServerResponse(metadata: request.metadata) { writer in try await writer.write(contentsOf: request.messages) return [:] diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index f047955da..0533fe26b 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -83,7 +83,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello"]) return StreamingServerResponse(metadata: request.metadata) { writer in @@ -112,7 +112,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello", "world"]) return StreamingServerResponse(metadata: request.metadata) { writer in @@ -144,7 +144,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in return StreamingServerResponse(metadata: request.metadata) { _ in return ["bar": "baz"] } @@ -235,15 +235,9 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in - do { - try await Task.sleep(until: .now.advanced(by: .seconds(180)), clock: .continuous) - } catch is CancellationError { - throw RPCError(code: .cancelled, message: "Sleep was cancelled") - } - - XCTFail("Server handler should've been cancelled by timeout.") - return StreamingServerResponse(error: RPCError(code: .failedPrecondition, message: "")) + ) { request, context in + try await context.cancellation.cancelled + throw RPCError(code: .cancelled, message: "Cancelled from server handler") } producer: { inbound in try await inbound.write(.metadata(["grpc-timeout": "1000n"])) await inbound.finish() @@ -251,7 +245,7 @@ final class ServerRPCExecutorTests: XCTestCase { let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in XCTAssertEqual(status.code, .cancelled) - XCTAssertEqual(status.message, "Sleep was cancelled") + XCTAssertEqual(status.message, "Cancelled from server handler") } } } @@ -268,7 +262,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in XCTFail("Unexpected request") return StreamingServerResponse( of: [UInt8].self, diff --git a/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift b/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift new file mode 100644 index 000000000..c524519ff --- /dev/null +++ b/Tests/GRPCCoreTests/Call/Server/ServerContextTests.swift @@ -0,0 +1,62 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Testing + +@Suite("ServerContext") +struct ServerContextTests { + @Suite("CancellationHandle") + struct CancellationHandle { + @Test("Is cancelled") + func isCancelled() async throws { + await withServerContextRPCCancellationHandle { handle in + #expect(!handle.isCancelled) + handle.cancel() + #expect(handle.isCancelled) + } + } + + @Test("Wait for cancellation") + func waitForCancellation() async throws { + await withServerContextRPCCancellationHandle { handle in + await withTaskGroup(of: Void.self) { group in + group.addTask { + try? await handle.cancelled + } + handle.cancel() + await group.waitForAll() + } + } + } + + @Test("Binds task local") + func bindsTaskLocal() async throws { + await withServerContextRPCCancellationHandle { handle in + let signal = AsyncStream.makeStream(of: Void.self) + + await withRPCCancellationHandler { + handle.cancel() + for await _ in signal.stream {} + } onCancelRPC: { + // If the task local wasn't bound, this wouldn't run. + signal.continuation.finish() + } + } + + } + } +} diff --git a/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift b/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift new file mode 100644 index 000000000..3396b259f --- /dev/null +++ b/Tests/GRPCInProcessTransportTests/InProcessTransportTests.swift @@ -0,0 +1,125 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import GRPCInProcessTransport +import Testing + +@Suite("InProcess transport") +struct InProcessTransportTests { + private static let cancellationModes = ["await-cancelled", "with-cancellation-handler"] + + private func withTestServerAndClient( + execute: (GRPCServer, GRPCClient) async throws -> Void + ) async throws { + try await withThrowingDiscardingTaskGroup { group in + let inProcess = InProcessTransport() + + let server = GRPCServer(transport: inProcess.server, services: [TestService()]) + group.addTask { + try await server.serve() + } + + let client = GRPCClient(transport: inProcess.client) + group.addTask { + try await client.run() + } + + try await execute(server, client) + } + } + + @Test("RPC cancelled by graceful shutdown", arguments: Self.cancellationModes) + func cancelledByGracefulShutdown(mode: String) async throws { + try await self.withTestServerAndClient { server, client in + try await client.serverStreaming( + request: ClientRequest(message: mode), + descriptor: .testCancellation, + serializer: UTF8Serializer(), + deserializer: UTF8Deserializer(), + options: .defaults + ) { response in + // Got initial metadata, begin shutdown to cancel the RPC. + server.beginGracefulShutdown() + + // Now wait for the response. + let messages = try await response.messages.reduce(into: []) { $0.append($1) } + #expect(messages == ["isCancelled=true"]) + } + + // Finally, shutdown the client so its run() method returns. + client.beginGracefulShutdown() + } + } +} + +private struct TestService: RegistrableRPCService { + func cancellation( + request: ServerRequest, + context: ServerContext + ) async throws -> StreamingServerResponse { + switch request.message { + case "await-cancelled": + return StreamingServerResponse { body in + try await context.cancellation.cancelled + try await body.write("isCancelled=\(context.cancellation.isCancelled)") + return [:] + } + + case "with-cancellation-handler": + let signal = AsyncStream.makeStream(of: Void.self) + return StreamingServerResponse { body in + try await withRPCCancellationHandler { + for await _ in signal.stream {} + try await body.write("isCancelled=\(context.cancellation.isCancelled)") + return [:] + } onCancelRPC: { + signal.continuation.finish() + } + } + + default: + throw RPCError(code: .invalidArgument, message: "Invalid argument '\(request.message)'") + } + } + + func registerMethods(with router: inout RPCRouter) { + router.registerHandler( + forMethod: .testCancellation, + deserializer: UTF8Deserializer(), + serializer: UTF8Serializer(), + handler: { + try await self.cancellation(request: ServerRequest(stream: $0), context: $1) + } + ) + } +} + +extension MethodDescriptor { + fileprivate static let testCancellation = Self(service: "test", method: "cancellation") +} + +private struct UTF8Serializer: MessageSerializer { + func serialize(_ message: String) throws -> [UInt8] { + Array(message.utf8) + } +} + +private struct UTF8Deserializer: MessageDeserializer { + func deserialize(_ serializedMessageBytes: [UInt8]) throws -> String { + String(decoding: serializedMessageBytes, as: UTF8.self) + } +}