diff --git a/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift index 007b9f2..4cd091f 100644 --- a/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift +++ b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift @@ -32,6 +32,7 @@ import Foundation task = dataTask(with: urlRequest) } return try await withTaskCancellationHandler { + try Task.checkCancellation() let delegate = BidirectionalStreamingURLSessionDelegate( requestBody: requestBody, requestStreamBufferSize: requestStreamBufferSize, @@ -47,8 +48,10 @@ import Foundation length: .init(from: response), iterationBehavior: .single ) + try Task.checkCancellation() return (try HTTPResponse(response), responseBody) } onCancel: { + debug("Concurrency task cancelled, cancelling URLSession task.") task.cancel() } } diff --git a/Sources/OpenAPIURLSession/URLSessionTransport.swift b/Sources/OpenAPIURLSession/URLSessionTransport.swift index 0364f29..582a352 100644 --- a/Sources/OpenAPIURLSession/URLSessionTransport.swift +++ b/Sources/OpenAPIURLSession/URLSessionTransport.swift @@ -24,6 +24,7 @@ import class Foundation.FileHandle #if canImport(FoundationNetworking) @preconcurrency import struct FoundationNetworking.URLRequest import class FoundationNetworking.URLSession +import class FoundationNetworking.URLSessionTask import class FoundationNetworking.URLResponse import class FoundationNetworking.HTTPURLResponse #endif @@ -243,31 +244,50 @@ extension URLSession { func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> ( HTTPResponse, HTTPBody? ) { + try Task.checkCancellation() var urlRequest = try URLRequest(request, baseURL: baseURL) if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) } + try Task.checkCancellation() /// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on /// Darwin platforms newer than our minimum deployment target, and not at all on Linux. - let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation { - continuation in - let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in - if let error { - continuation.resume(throwing: error) - return + let taskBox: LockedValueBox = .init(nil) + return try await withTaskCancellationHandler { + let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation { + continuation in + let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let response else { + continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url)) + return + } + continuation.resume(with: .success((response, data))) } - guard let response else { - continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url)) - return + // Swift concurrency task cancelled here. + taskBox.withLockedValue { boxedTask in + guard task.state == .suspended else { + debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.") + return + } + task.resume() + boxedTask = task } - continuation.resume(with: .success((response, data))) } - task.resume() - } - let maybeResponseBody = maybeResponseBodyData.map { data in - HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple) + let maybeResponseBody = maybeResponseBodyData.map { data in + HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple) + } + return (try HTTPResponse(response), maybeResponseBody) + } onCancel: { + taskBox.withLockedValue { boxedTask in + debug("Concurrency task cancelled, cancelling URLSession task.") + boxedTask?.cancel() + boxedTask = nil + } } - return (try HTTPResponse(response), maybeResponseBody) } } diff --git a/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift b/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift new file mode 100644 index 0000000..a10d3f9 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift @@ -0,0 +1,240 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if canImport(Darwin) + +import Foundation +import HTTPTypes +import NIO +import OpenAPIRuntime +import XCTest +@testable import OpenAPIURLSession + +enum CancellationPoint: CaseIterable { + case beforeSendingHead + case beforeSendingRequestBody + case partwayThroughSendingRequestBody + case beforeConsumingResponseBody + case partwayThroughConsumingResponseBody + case afterConsumingResponseBody +} + +func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws { + let requestPath = "/hello/world" + let requestBodyElements = ["Hello,", "world!"] + let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true) + let requestBody = HTTPBody( + requestBodySequence, + length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))), + iterationBehavior: .single + ) + + let responseBodyMessage = "Hey!" + + let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled") + let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled") + + try await withThrowingTaskGroup(of: Void.self) { group in + let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in + try await connectionChannel.executeThenClose { inbound, outbound in + var requestPartIterator = inbound.makeAsyncIterator() + var accumulatedBody = ByteBuffer() + while let requestPart = try await requestPartIterator.next() { + switch requestPart { + case .head(let head): + XCTAssertEqual(head.uri, requestPath) + XCTAssertEqual(head.method, .POST) + case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer) + case .end: + switch cancellationPoint { + case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, + .afterConsumingResponseBody: + XCTAssertEqual( + String(decoding: accumulatedBody.readableBytesView, as: UTF8.self), + requestBodyElements.joined() + ) + case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break + } + try await outbound.write(.head(.init(version: .http1_1, status: .ok))) + try await outbound.write(.body(ByteBuffer(string: responseBodyMessage))) + try await outbound.write(.end(nil)) + } + } + } + } + debug("Server running on 127.0.0.1:\(serverPort)") + + let task = Task { + if case .beforeSendingHead = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + debug("Client starting request") + async let (asyncResponse, asyncResponseBody) = try await transport.send( + HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath), + body: requestBody, + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + operationID: "unused" + ) + + if case .beforeSendingRequestBody = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + + requestBodySequence.openGate(for: 1) + + if case .partwayThroughSendingRequestBody = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + + requestBodySequence.openGate() + + let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody) + + debug("Client received response head: \(response)") + XCTAssertEqual(response.status, .ok) + let responseBody = try XCTUnwrap(maybeResponseBody) + + if case .beforeConsumingResponseBody = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + + var iterator = responseBody.makeAsyncIterator() + + _ = try await iterator.next() + + if case .partwayThroughConsumingResponseBody = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + + while try await iterator.next() != nil { + + } + + if case .afterConsumingResponseBody = cancellationPoint { + taskShouldCancel.fulfill() + await fulfillment(of: [taskCancelled]) + } + + } + + await fulfillment(of: [taskShouldCancel]) + task.cancel() + taskCancelled.fulfill() + + switch transport.configuration.implementation { + case .buffering: + switch cancellationPoint { + case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: + await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } + case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody: + try await task.value + } + case .streaming: + switch cancellationPoint { + case .beforeSendingHead: + await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } + case .beforeSendingRequestBody, .partwayThroughSendingRequestBody: + await XCTAssertThrowsError(try await task.value) { error in + guard let urlError = error as? URLError else { + XCTFail() + return + } + XCTAssertEqual(urlError.code, .cancelled) + } + case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody: + try await task.value + } + } + + group.cancelAll() + } + +} + +func fulfillment( + of expectations: [XCTestExpectation], + timeout seconds: TimeInterval = .infinity, + enforceOrder enforceOrderOfFulfillment: Bool = false, + file: StaticString = #file, + line: UInt = #line +) async { + guard + case .completed = await XCTWaiter.fulfillment( + of: expectations, + timeout: seconds, + enforceOrder: enforceOrderOfFulfillment + ) + else { + XCTFail("Expectation was not fulfilled", file: file, line: line) + return + } +} + +extension URLSessionTransportBufferedTests { + func testCancellation_beforeSendingHead() async throws { + try await testTaskCancelled(.beforeSendingHead, transport: transport) + } + + func testCancellation_beforeSendingRequestBody() async throws { + try await testTaskCancelled(.beforeSendingRequestBody, transport: transport) + } + + func testCancellation_partwayThroughSendingRequestBody() async throws { + try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport) + } + + func testCancellation_beforeConsumingResponseBody() async throws { + try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport) + } + + func testCancellation_partwayThroughConsumingResponseBody() async throws { + try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport) + } + + func testCancellation_afterConsumingResponseBody() async throws { + try await testTaskCancelled(.afterConsumingResponseBody, transport: transport) + } +} + +extension URLSessionTransportStreamingTests { + func testCancellation_beforeSendingHead() async throws { + try await testTaskCancelled(.beforeSendingHead, transport: transport) + } + + func testCancellation_beforeSendingRequestBody() async throws { + try await testTaskCancelled(.beforeSendingRequestBody, transport: transport) + } + + func testCancellation_partwayThroughSendingRequestBody() async throws { + try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport) + } + + func testCancellation_beforeConsumingResponseBody() async throws { + try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport) + } + + func testCancellation_partwayThroughConsumingResponseBody() async throws { + try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport) + } + + func testCancellation_afterConsumingResponseBody() async throws { + try await testTaskCancelled(.afterConsumingResponseBody, transport: transport) + } +} + +#endif // canImport(Darwin) diff --git a/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift b/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift index 122c5e9..5903527 100644 --- a/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift +++ b/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift @@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase { // swift-format-ignore: AllPublicDeclarationsHaveDocumentation class URLSessionTransportBufferedTests: XCTestCase { - var transport: (any ClientTransport)! + var transport: URLSessionTransport! static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false } @@ -66,7 +66,7 @@ class URLSessionTransportBufferedTests: XCTestCase { func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) } - func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) } + func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) } #if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307. func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws { @@ -89,7 +89,7 @@ class URLSessionTransportBufferedTests: XCTestCase { // swift-format-ignore: AllPublicDeclarationsHaveDocumentation class URLSessionTransportStreamingTests: XCTestCase { - var transport: (any ClientTransport)! + var transport: URLSessionTransport! static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false } @@ -107,7 +107,7 @@ class URLSessionTransportStreamingTests: XCTestCase { func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) } - func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) } + func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) } #if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307. func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {