diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index a67bfbe37..9ad35e5a9 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -120,43 +120,30 @@ struct ServerRPCExecutor { _ context: ServerContext ) async throws -> ServerResponse.Stream ) async { - await withTaskGroup(of: ServerExecutorTask.self) { group in + await withTaskGroup(of: Void.self) { group in group.addTask { - let result = await Result { + do { try await Task.sleep(for: timeout, clock: .continuous) + // Cancel the RPC if the timeout passes. + context.streamState.events.continuation.yield(.rpcCancelled) + } catch { + () // Sleep was cancelled, the RPC completed. } - return .timedOut(result) } - group.addTask { - await Self._processRPC( - context: context, - metadata: metadata, - inbound: inbound, - outbound: outbound, - deserializer: deserializer, - serializer: serializer, - interceptors: interceptors, - handler: handler - ) - return .executed - } - - while let next = await group.next() { - switch next { - case .timedOut(.success): - // Timeout expired; cancel the work. - group.cancelAll() - - case .timedOut(.failure): - // Timeout failed (because it was cancelled). Wait for more tasks to finish. - () + await Self._processRPC( + context: context, + metadata: metadata, + inbound: inbound, + outbound: outbound, + deserializer: deserializer, + serializer: serializer, + interceptors: interceptors, + handler: handler + ) - case .executed: - // The work finished. Cancel any remaining tasks. - group.cancelAll() - } - } + // Cancel the timeout, if it's still running. + group.cancelAll() } } diff --git a/Sources/GRPCCore/Call/Server/RPCRouter.swift b/Sources/GRPCCore/Call/Server/RPCRouter.swift index bc2f58fef..490133743 100644 --- a/Sources/GRPCCore/Call/Server/RPCRouter.swift +++ b/Sources/GRPCCore/Call/Server/RPCRouter.swift @@ -155,7 +155,7 @@ extension RPCRouter { context: ServerContext, interceptors: [any ServerInterceptor] ) async { - if let handler = self.handlers[stream.descriptor] { + if let handler = self.handlers[context.descriptor] { await handler.handle(stream: stream, context: context, interceptors: interceptors) } else { // If this throws then the stream must be closed which we can't do anything about, so ignore diff --git a/Sources/GRPCCore/Call/Server/ServerContext.swift b/Sources/GRPCCore/Call/Server/ServerContext.swift index a11f09acb..c3fb57fdf 100644 --- a/Sources/GRPCCore/Call/Server/ServerContext.swift +++ b/Sources/GRPCCore/Call/Server/ServerContext.swift @@ -15,12 +15,17 @@ */ /// Additional information about an RPC handled by a server. +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) public struct ServerContext: Sendable { /// A description of the method being called. public var descriptor: MethodDescriptor + /// The state of the server stream. + public var streamState: ServerStreamState + /// Create a new server context. - public init(descriptor: MethodDescriptor) { + public init(descriptor: MethodDescriptor, streamState: ServerStreamState) { self.descriptor = descriptor + self.streamState = streamState } } diff --git a/Sources/GRPCCore/Call/Server/ServerStreamEvent.swift b/Sources/GRPCCore/Call/Server/ServerStreamEvent.swift new file mode 100644 index 000000000..c9defe709 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/ServerStreamEvent.swift @@ -0,0 +1,47 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// An out-of-band event which can happen to the underlying stream +/// which an RPC is executing on. +public struct ServerStreamEvent: Hashable, Sendable { + internal enum Value: Hashable, Sendable { + case rpcCancelled + } + + internal var value: Value + + private init(_ value: Value) { + self.value = value + } + + /// The RPC was cancelled and the service should stop processing it. + /// + /// RPCs can be cancelled for a number of reasons including, but not limited to: + /// - it took too long to complete + /// - the client closed the underlying stream + /// - the stream closed unexpectedly (due to a network failure, for example) + /// - the server initiated a graceful shutdown + /// + /// You should stop processing the RPC and cleanup any associated state if you + /// receive this event. + public static let rpcCancelled = Self(.rpcCancelled) +} + +extension ServerStreamEvent: CustomStringConvertible { + public var description: String { + String(describing: self.value) + } +} diff --git a/Sources/GRPCCore/Call/Server/ServerStreamState.swift b/Sources/GRPCCore/Call/Server/ServerStreamState.swift new file mode 100644 index 000000000..43ff5d884 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/ServerStreamState.swift @@ -0,0 +1,213 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Synchronization + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +public struct ServerStreamState: Sendable { + /// Returns whether the RPC has been cancelled. + /// + /// - SeeAlso: ``ServerStreamEvent/rpcCancelled``. + public var isRPCCancelled: Bool { + self.events.isRPCCancelled + } + + /// Events which can happen to the underlying stream the RPC is being run on. + public let events: Events + + private init(events: Events) { + self.events = events + } + + public static func makeState() -> (streamState: Self, eventContinuation: Events.Continuation) { + let events = Events() + return (ServerStreamState(events: events), eventContinuation: events.continuation) + } +} + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +extension ServerStreamState { + /// An `AsyncSequence` of events which can happen to the stream. + /// + /// Each event will be delivered at most once. + /// + /// - Note: This sequence supports _multiple_ concurrent iterators. + public struct Events { + private let storage: Storage + @usableFromInline + internal let continuation: Continuation + } +} + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +extension ServerStreamState.Events: AsyncSequence, Sendable { + public typealias Element = ServerStreamEvent + public typealias Failure = Never + + init() { + self.storage = Storage() + self.continuation = Continuation(storage: self.storage) + } + + fileprivate var isRPCCancelled: Bool { + self.storage.eventSet(contains: .rpcCancelled) + } + + public func makeAsyncIterator() -> AsyncIterator { + let streamEvents = AsyncStream.makeStream(of: ServerStreamEvent.self) + self.storage.registerContinuation(streamEvents.continuation) + return AsyncIterator(iterator: streamEvents.stream.makeAsyncIterator()) + } + + public struct AsyncIterator: AsyncIteratorProtocol { + private var iterator: AsyncStream.AsyncIterator + + fileprivate init(iterator: AsyncStream.AsyncIterator) { + self.iterator = iterator + } + + public mutating func next() async throws(Never) -> ServerStreamEvent? { + await self.next(isolation: nil) + } + + public mutating func next( + isolation actor: isolated (any Actor)? + ) async throws(Never) -> ServerStreamEvent? { + return await self.iterator.next(isolation: actor) + } + } + + public struct Continuation: Sendable { + private let storage: Storage + + init(storage: Storage) { + self.storage = storage + } + + /// Yield an event to the stream. + /// + /// - Important: Events are only delivered once. If the event has already been yielded + /// then attempting to yield it again is a no-op. + /// - Parameter event: The event to yield. + public func yield(_ event: ServerStreamEvent) { + self.storage.yield(event) + } + + /// Indicate that no more events will be delivered. + public func finish() { + self.storage.finish() + } + } +} + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +extension ServerStreamState.Events { + final class Storage: Sendable { + private let state: Mutex + + init() { + self.state = Mutex(State()) + } + + func eventSet(contains event: ServerStreamEvent) -> Bool { + self.state.withLock { $0.eventSet.contains(EventSet(event)) } + } + + func registerContinuation(_ continuation: AsyncStream.Continuation) { + self.state.withLock { + let events = $0.registerContinuation(continuation) + + if events.contains(.rpcCancelled) { + continuation.yield(.rpcCancelled) + } + + if events.contains(.finished) { + continuation.finish() + } + } + } + + func yield(_ event: ServerStreamEvent) { + self.state.withLock { + for continuation in $0.publishStreamEvent(event) { + continuation.yield(event) + } + } + } + + func finish() { + self.state.withLock { + for continuation in $0.finish() { + continuation.finish() + } + } + } + } + + private struct EventSet: OptionSet, Hashable, Sendable { + var rawValue: UInt8 + + init(rawValue: UInt8) { + self.rawValue = rawValue + } + + init(_ event: ServerStreamEvent) { + switch event.value { + case .rpcCancelled: + self = .rpcCancelled + } + } + + static let finished = EventSet(rawValue: 1 << 0) + static let rpcCancelled = EventSet(rawValue: 1 << 1) + } + + private struct State: Sendable { + private(set) var eventSet: EventSet + private var continuations: [AsyncStream.Continuation] + + init() { + self.eventSet = EventSet() + self.continuations = [] + } + + mutating func registerContinuation( + _ continuation: AsyncStream.Continuation + ) -> EventSet { + if !self.eventSet.contains(.finished) { + self.continuations.append(continuation) + } + + return self.eventSet + } + + mutating func publishStreamEvent( + _ streamEvent: ServerStreamEvent + ) -> [AsyncStream.Continuation] { + if self.eventSet.contains(.finished) { + return [] + } else { + let (inserted, _) = self.eventSet.insert(EventSet(streamEvent)) + return inserted ? self.continuations : [] + } + } + + mutating func finish() -> [AsyncStream.Continuation] { + let (inserted, _) = self.eventSet.insert(.finished) + return inserted ? self.continuations : [] + } + } +} diff --git a/Sources/GRPCInProcessTransport/InProcessServerTransport.swift b/Sources/GRPCInProcessTransport/InProcessServerTransport.swift index 2bb2ed57d..9bd848d58 100644 --- a/Sources/GRPCInProcessTransport/InProcessServerTransport.swift +++ b/Sources/GRPCInProcessTransport/InProcessServerTransport.swift @@ -15,6 +15,7 @@ */ public import GRPCCore +private import Synchronization /// An in-process implementation of a ``ServerTransport``. /// @@ -27,16 +28,47 @@ public import GRPCCore /// /// - SeeAlso: ``ClientTransport`` @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) -public struct InProcessServerTransport: ServerTransport, Sendable { +public final class InProcessServerTransport: ServerTransport, Sendable { public typealias Inbound = RPCAsyncSequence public typealias Outbound = RPCWriter.Closable private let newStreams: AsyncStream> private let newStreamsContinuation: AsyncStream>.Continuation + private let streamCounter: Atomic + private let state: Mutex + + private struct State: Sendable { + private var streamStateContinuations: [Int: ServerStreamState.Events.Continuation] + private var isShuttingDown: Bool + + init() { + self.streamStateContinuations = [:] + self.isShuttingDown = false + } + + mutating func beginGracefulShutdown() -> [ServerStreamState.Events.Continuation] { + self.isShuttingDown = true + return Array(self.streamStateContinuations.values) + } + + mutating func registerContinuation( + _ continuation: ServerStreamState.Events.Continuation, + id: Int + ) -> Bool { + self.streamStateContinuations[id] = continuation + return self.isShuttingDown + } + + mutating func deregisterContinuation(id: Int) -> ServerStreamState.Events.Continuation? { + return self.streamStateContinuations.removeValue(forKey: id) + } + } /// Creates a new instance of ``InProcessServerTransport``. public init() { (self.newStreams, self.newStreamsContinuation) = AsyncStream.makeStream() + self.streamCounter = Atomic(0) + self.state = Mutex(State()) } /// Publish a new ``RPCStream``, which will be returned by the transport's ``events`` @@ -64,7 +96,21 @@ public struct InProcessServerTransport: ServerTransport, Sendable { await withDiscardingTaskGroup { group in for await stream in self.newStreams { group.addTask { - let context = ServerContext(descriptor: stream.descriptor) + let (streamState, continuation) = ServerStreamState.makeState() + let (id, isShuttingDown) = self.registerStreamStateContinuation(continuation) + + // This can happen if the stream was accepted but not dequeued + // before 'beginGracefulShutdown' was called. Let the RPC run. + if isShuttingDown { + continuation.yield(.rpcCancelled) + } + + defer { + let continuation = self.deregisterStreamStateContinuation(id) + continuation?.finish() + } + + let context = ServerContext(descriptor: stream.descriptor, streamState: streamState) await streamHandler(stream, context) } } @@ -76,5 +122,28 @@ public struct InProcessServerTransport: ServerTransport, Sendable { /// - SeeAlso: ``ServerTransport`` public func beginGracefulShutdown() { self.newStreamsContinuation.finish() + let continuations = self.state.withLock { $0.beginGracefulShutdown() } + for continuation in continuations { + continuation.yield(.rpcCancelled) + } + } +} + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +extension InProcessServerTransport { + private func registerStreamStateContinuation( + _ continuation: ServerStreamState.Events.Continuation + ) -> (Int, isShuttingDown: Bool) { + let (id, _) = self.streamCounter.add(1, ordering: .relaxed) + let isShuttingDown = self.state.withLock { + $0.registerContinuation(continuation, id: id) + } + return (id, isShuttingDown) + } + + private func deregisterStreamStateContinuation( + _ id: Int + ) -> ServerStreamState.Events.Continuation? { + self.state.withLock { $0.deregisterContinuation(id: id) } } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift index 8d7e0a543..d69004354 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTestSupport/ServerRPCExecutorTestHarness.swift @@ -20,24 +20,30 @@ import XCTest @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) struct ServerRPCExecutorTestHarness { struct ServerHandler: Sendable { - let fn: @Sendable (ServerRequest.Stream) async throws -> ServerResponse.Stream + let fn: + @Sendable ( + _ request: ServerRequest.Stream, + _ context: ServerContext + ) async throws -> ServerResponse.Stream init( _ fn: @escaping @Sendable ( - ServerRequest.Stream + ServerRequest.Stream, + ServerContext ) async throws -> ServerResponse.Stream ) { self.fn = fn } func handle( - _ request: ServerRequest.Stream + _ request: ServerRequest.Stream, + _ context: ServerContext ) async throws -> ServerResponse.Stream { - try await self.fn(request) + try await self.fn(request, context) } static func throwing(_ error: any Error) -> Self { - return Self { _ in throw error } + return Self { _, _ in throw error } } } @@ -51,7 +57,8 @@ struct ServerRPCExecutorTestHarness { deserializer: some MessageDeserializer, serializer: some MessageSerializer, handler: @escaping @Sendable ( - ServerRequest.Stream + ServerRequest.Stream, + ServerContext ) async throws -> ServerResponse.Stream, producer: @escaping @Sendable ( RPCWriter.Closable @@ -93,7 +100,12 @@ struct ServerRPCExecutorTestHarness { } group.addTask { - let context = ServerContext(descriptor: MethodDescriptor(service: "foo", method: "bar")) + let (streamState, _) = ServerStreamState.makeState() + let context = ServerContext( + descriptor: MethodDescriptor(service: "foo", method: "bar"), + streamState: streamState + ) + await ServerRPCExecutor.execute( context: context, stream: RPCStream( @@ -105,7 +117,7 @@ struct ServerRPCExecutorTestHarness { serializer: serializer, interceptors: self.interceptors, handler: { stream, context in - try await handler.handle(stream) + try await handler.handle(stream, context) } ) } @@ -136,7 +148,7 @@ struct ServerRPCExecutorTestHarness { @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) extension ServerRPCExecutorTestHarness.ServerHandler where Input == Output { static var echo: Self { - return Self { request in + return Self { request, _ in return ServerResponse.Stream(metadata: request.metadata) { writer in try await writer.write(contentsOf: request.messages) return [:] diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index 5d2aa0029..bb64e0d3d 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -84,7 +84,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello"]) return ServerResponse.Stream(metadata: request.metadata) { writer in @@ -113,7 +113,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: JSONDeserializer(), serializer: JSONSerializer() - ) { request in + ) { request, _ in let messages = try await request.messages.collect() XCTAssertEqual(messages, ["hello", "world"]) return ServerResponse.Stream(metadata: request.metadata) { writer in @@ -145,7 +145,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in return ServerResponse.Stream(metadata: request.metadata) { _ in return ["bar": "baz"] } @@ -236,11 +236,16 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in - do { - try await Task.sleep(until: .now.advanced(by: .seconds(180)), clock: .continuous) - } catch is CancellationError { - throw RPCError(code: .cancelled, message: "Sleep was cancelled") + ) { request, context in + for await event in context.streamState.events { + switch event { + case .rpcCancelled: + return ServerResponse.Stream( + error: RPCError(code: .cancelled, message: "received 'rpcCancelled' event") + ) + default: + continue + } } XCTFail("Server handler should've been cancelled by timeout.") @@ -252,7 +257,7 @@ final class ServerRPCExecutorTests: XCTestCase { let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in XCTAssertEqual(status.code, .cancelled) - XCTAssertEqual(status.message, "Sleep was cancelled") + XCTAssertEqual(status.message, "received 'rpcCancelled' event") } } } @@ -269,7 +274,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute( deserializer: IdentityDeserializer(), serializer: IdentitySerializer() - ) { request in + ) { request, _ in XCTFail("Unexpected request") return ServerResponse.Stream( of: [UInt8].self, diff --git a/Tests/GRPCCoreTests/Call/Server/ServerStreamStateTests.swift b/Tests/GRPCCoreTests/Call/Server/ServerStreamStateTests.swift new file mode 100644 index 000000000..9c0ca0256 --- /dev/null +++ b/Tests/GRPCCoreTests/Call/Server/ServerStreamStateTests.swift @@ -0,0 +1,70 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Testing + +@Suite("ServerStreamState") +struct ServerStreamStateTests { + @Test("Does nothing on init") + @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) + func nothingOnInit() { + let (state, _) = ServerStreamState.makeState() + #expect(!state.isRPCCancelled) + } + + @Test("Multiple iterators are allowed") + @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) + func multipleIteratorsAreAllowed() async { + let (state, continuation) = ServerStreamState.makeState() + await withTaskGroup(of: [ServerStreamEvent].self) { group in + for _ in 0 ..< 100 { + group.addTask { + await state.events.reduce(into: []) { $0.append($1) } + } + } + + continuation.yield(.rpcCancelled) + continuation.finish() + + for await events in group { + #expect(events == [.rpcCancelled]) + } + } + } + + @Test("State is set after event is yielded") + @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) + func stateSetAfterYield() async { + let (state, continuation) = ServerStreamState.makeState() + #expect(!state.isRPCCancelled) + continuation.yield(.rpcCancelled) + #expect(state.isRPCCancelled) + } + + @Test("Events are only delivered once") + @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) + func eventsAreOnlyDeliveredOnce() async { + let (state, continuation) = ServerStreamState.makeState() + continuation.yield(.rpcCancelled) + continuation.yield(.rpcCancelled) + continuation.yield(.rpcCancelled) + continuation.finish() + + let events = await state.events.reduce(into: []) { $0.append($1) } + #expect(events == [.rpcCancelled]) + } +} diff --git a/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift b/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift index 804be7d52..f953e7bcb 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift @@ -51,7 +51,7 @@ struct ThrowOnStreamCreationTransport: ClientTransport { } } -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) struct ThrowOnRunServerTransport: ServerTransport { func listen( streamHandler: ( @@ -70,7 +70,7 @@ struct ThrowOnRunServerTransport: ServerTransport { } } -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) struct ThrowOnSignalServerTransport: ServerTransport { let signal: AsyncStream diff --git a/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift b/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift index d33b2774d..a110fd7ea 100644 --- a/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift +++ b/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift @@ -284,6 +284,65 @@ final class InProcessClientTransportTests: XCTestCase { } } + func testRPCIsCancelledIsPropagated() async throws { + let inProcess = InProcessTransport.makePair() + let shouldBeginShutdown = AsyncStream.makeStream(of: Void.self) + + var router = RPCRouter() + router.registerHandler( + forMethod: MethodDescriptor(service: "foo", method: "bar"), + deserializer: PassthroughDeserializer(), + serializer: PassthroughSerializer() + ) { stream, context in + shouldBeginShutdown.continuation.finish() + + for await event in context.streamState.events { + switch event { + case .rpcCancelled: + return ServerResponse.Stream(error: RPCError(code: .cancelled, message: "")) + default: + continue + } + } + return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: "")) + } + + let server = GRPCServer(transport: inProcess.server, router: router) + let client = GRPCClient(transport: inProcess.client) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await server.serve() + } + + group.addTask { + for await _ in shouldBeginShutdown.stream {} + server.beginGracefulShutdown() + } + + group.addTask { + try await client.run() + } + + try await client.unary( + request: ClientRequest.Single(message: [UInt8]()), + descriptor: MethodDescriptor(service: "foo", method: "bar"), + serializer: PassthroughSerializer(), + deserializer: PassthroughDeserializer(), + options: .defaults + ) { response in + switch response.accepted { + case .success: + XCTFail("Expected error") + case .failure(let error): + XCTAssertEqual(error.code, .cancelled) + } + } + + client.beginGracefulShutdown() + } + } + func makeClient( server: InProcessServerTransport = InProcessServerTransport() ) -> InProcessClientTransport { @@ -310,3 +369,15 @@ final class InProcessClientTransportTests: XCTestCase { ) } } + +struct PassthroughSerializer: MessageSerializer { + func serialize(_ message: [UInt8]) throws -> [UInt8] { + return message + } +} + +struct PassthroughDeserializer: MessageDeserializer { + func deserialize(_ messageBytes: [UInt8]) throws -> [UInt8] { + return messageBytes + } +}