From 3f064329dc5a32dc1b7181d7c6c3f722846f77bd Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Tue, 16 Jan 2024 08:10:44 +0000 Subject: [PATCH] Tolerate both CancellationError and URLError in cancellation tests When cancelling a Swift concurrency task during a streaming request then the error returned might be `URLError` with `.cancelled` code or `CancellationError`. The tests tried to be smart and expect just one of these depending on which stage of the request we were at, but there are still some races, and this test fails very rarely because a `CancellationError` was thrown instead of a `URLError`. This patch updates the test to tolerate both kinds of error at this stage of the request. Now the tests do not fail when run repeatedly. --- .../NIOAsyncHTTP1TestServer.swift | 2 +- .../TaskCancellationTests.swift | 8 +++--- .../MockAsyncSequence.swift | 25 +++++++++++-------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift b/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift index 9b4f3ae..20fb663 100644 --- a/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift +++ b/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift @@ -59,7 +59,7 @@ final class AsyncTestHTTP1Server { for try await connectionChannel in inbound { group.addTask { do { - debug("Sevrer handling new connection") + debug("Server handling new connection") try await connectionHandler(connectionChannel) debug("Server done handling connection") } catch { debug("Server error handling connection: \(error)") } diff --git a/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift b/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift index a10d3f9..a5f4c86 100644 --- a/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift +++ b/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift @@ -150,11 +150,11 @@ func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSes await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } case .beforeSendingRequestBody, .partwayThroughSendingRequestBody: await XCTAssertThrowsError(try await task.value) { error in - guard let urlError = error as? URLError else { - XCTFail() - return + switch error { + case is CancellationError: break + case is URLError: XCTAssertEqual((error as! URLError).code, .cancelled) + default: XCTFail("Unexpected error: \(error)") } - XCTAssertEqual(urlError.code, .cancelled) } case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody: try await task.value diff --git a/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockAsyncSequence.swift b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockAsyncSequence.swift index 4bbfc92..7f413ac 100644 --- a/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockAsyncSequence.swift +++ b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockAsyncSequence.swift @@ -20,44 +20,47 @@ struct MockAsyncSequence: AsyncSequence, Sendable where Element: Sendab var elementsToVend: [Element] private let _elementsVended: LockedValueBox<[Element]> var elementsVended: [Element] { _elementsVended.withValue { $0 } } - private let semaphore: DispatchSemaphore? + private let gateOpeningsStream: AsyncStream + private let gateOpeningsContinuation: AsyncStream.Continuation init(elementsToVend: [Element], gatingProduction: Bool) { self.elementsToVend = elementsToVend self._elementsVended = LockedValueBox([]) - self.semaphore = gatingProduction ? DispatchSemaphore(value: 0) : nil + (self.gateOpeningsStream, self.gateOpeningsContinuation) = AsyncStream.makeStream(of: Void.self) + if !gatingProduction { openGate() } } - func openGate(for count: Int) { for _ in 0.. AsyncIterator { - AsyncIterator(elementsToVend: elementsToVend[...], semaphore: semaphore, elementsVended: _elementsVended) + AsyncIterator( + elementsToVend: elementsToVend[...], + gateOpenings: gateOpeningsStream.makeAsyncIterator(), + elementsVended: _elementsVended + ) } final class AsyncIterator: AsyncIteratorProtocol { var elementsToVend: ArraySlice - var semaphore: DispatchSemaphore? + var gateOpenings: AsyncStream.Iterator var elementsVended: LockedValueBox<[Element]> init( elementsToVend: ArraySlice, - semaphore: DispatchSemaphore?, + gateOpenings: AsyncStream.Iterator, elementsVended: LockedValueBox<[Element]> ) { self.elementsToVend = elementsToVend - self.semaphore = semaphore + self.gateOpenings = gateOpenings self.elementsVended = elementsVended } func next() async throws -> Element? { - await withCheckedContinuation { continuation in - semaphore?.wait() - continuation.resume() - } + guard await gateOpenings.next() != nil else { throw CancellationError() } guard let element = elementsToVend.popFirst() else { return nil } elementsVended.withValue { $0.append(element) } return element