Skip to content

More checks for task cancellation and tests #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import Foundation
task = dataTask(with: urlRequest)
}
return try await withTaskCancellationHandler {
try Task.checkCancellation()
let delegate = BidirectionalStreamingURLSessionDelegate(
requestBody: requestBody,
requestStreamBufferSize: requestStreamBufferSize,
Expand All @@ -47,8 +48,10 @@ import Foundation
length: .init(from: response),
iterationBehavior: .single
)
try Task.checkCancellation()
return (try HTTPResponse(response), responseBody)
} onCancel: {
debug("Concurrency task cancelled, cancelling URLSession task.")
task.cancel()
}
}
Expand Down
50 changes: 35 additions & 15 deletions Sources/OpenAPIURLSession/URLSessionTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import class Foundation.FileHandle
#if canImport(FoundationNetworking)
@preconcurrency import struct FoundationNetworking.URLRequest
import class FoundationNetworking.URLSession
import class FoundationNetworking.URLSessionTask
import class FoundationNetworking.URLResponse
import class FoundationNetworking.HTTPURLResponse
#endif
Expand Down Expand Up @@ -243,31 +244,50 @@ extension URLSession {
func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> (
HTTPResponse, HTTPBody?
) {
try Task.checkCancellation()
var urlRequest = try URLRequest(request, baseURL: baseURL)
if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) }
try Task.checkCancellation()

/// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on
/// Darwin platforms newer than our minimum deployment target, and not at all on Linux.
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
continuation in
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
if let error {
continuation.resume(throwing: error)
return
let taskBox: LockedValueBox<URLSessionTask?> = .init(nil)
return try await withTaskCancellationHandler {
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
continuation in
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
if let error {
continuation.resume(throwing: error)
return
}
guard let response else {
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
return
}
continuation.resume(with: .success((response, data)))
}
guard let response else {
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
return
// Swift concurrency task cancelled here.
taskBox.withLockedValue { boxedTask in
guard task.state == .suspended else {
debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.")
return
}
task.resume()
boxedTask = task
}
continuation.resume(with: .success((response, data)))
}
task.resume()
}

let maybeResponseBody = maybeResponseBodyData.map { data in
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
let maybeResponseBody = maybeResponseBodyData.map { data in
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
}
return (try HTTPResponse(response), maybeResponseBody)
} onCancel: {
taskBox.withLockedValue { boxedTask in
debug("Concurrency task cancelled, cancelling URLSession task.")
boxedTask?.cancel()
boxedTask = nil
}
}
return (try HTTPResponse(response), maybeResponseBody)
}
}

Expand Down
240 changes: 240 additions & 0 deletions Tests/OpenAPIURLSessionTests/TaskCancellationTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftOpenAPIGenerator open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
#if canImport(Darwin)

import Foundation
import HTTPTypes
import NIO
import OpenAPIRuntime
import XCTest
@testable import OpenAPIURLSession

enum CancellationPoint: CaseIterable {
case beforeSendingHead
case beforeSendingRequestBody
case partwayThroughSendingRequestBody
case beforeConsumingResponseBody
case partwayThroughConsumingResponseBody
case afterConsumingResponseBody
}

func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws {
let requestPath = "/hello/world"
let requestBodyElements = ["Hello,", "world!"]
let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true)
let requestBody = HTTPBody(
requestBodySequence,
length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))),
iterationBehavior: .single
)

let responseBodyMessage = "Hey!"

let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled")
let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled")

try await withThrowingTaskGroup(of: Void.self) { group in
let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in
try await connectionChannel.executeThenClose { inbound, outbound in
var requestPartIterator = inbound.makeAsyncIterator()
var accumulatedBody = ByteBuffer()
while let requestPart = try await requestPartIterator.next() {
switch requestPart {
case .head(let head):
XCTAssertEqual(head.uri, requestPath)
XCTAssertEqual(head.method, .POST)
case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer)
case .end:
switch cancellationPoint {
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody,
.afterConsumingResponseBody:
XCTAssertEqual(
String(decoding: accumulatedBody.readableBytesView, as: UTF8.self),
requestBodyElements.joined()
)
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break
}
try await outbound.write(.head(.init(version: .http1_1, status: .ok)))
try await outbound.write(.body(ByteBuffer(string: responseBodyMessage)))
try await outbound.write(.end(nil))
}
}
}
}
debug("Server running on 127.0.0.1:\(serverPort)")

let task = Task {
if case .beforeSendingHead = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}
debug("Client starting request")
async let (asyncResponse, asyncResponseBody) = try await transport.send(
HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath),
body: requestBody,
baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!,
operationID: "unused"
)

if case .beforeSendingRequestBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

requestBodySequence.openGate(for: 1)

if case .partwayThroughSendingRequestBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

requestBodySequence.openGate()

let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody)

debug("Client received response head: \(response)")
XCTAssertEqual(response.status, .ok)
let responseBody = try XCTUnwrap(maybeResponseBody)

if case .beforeConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

var iterator = responseBody.makeAsyncIterator()

_ = try await iterator.next()

if case .partwayThroughConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

while try await iterator.next() != nil {

}

if case .afterConsumingResponseBody = cancellationPoint {
taskShouldCancel.fulfill()
await fulfillment(of: [taskCancelled])
}

}

await fulfillment(of: [taskShouldCancel])
task.cancel()
taskCancelled.fulfill()

switch transport.configuration.implementation {
case .buffering:
switch cancellationPoint {
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
try await task.value
}
case .streaming:
switch cancellationPoint {
case .beforeSendingHead:
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
case .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
await XCTAssertThrowsError(try await task.value) { error in
guard let urlError = error as? URLError else {
XCTFail()
return
}
XCTAssertEqual(urlError.code, .cancelled)
}
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
try await task.value
}
}

group.cancelAll()
}

}

func fulfillment(
of expectations: [XCTestExpectation],
timeout seconds: TimeInterval = .infinity,
enforceOrder enforceOrderOfFulfillment: Bool = false,
file: StaticString = #file,
line: UInt = #line
) async {
guard
case .completed = await XCTWaiter.fulfillment(
of: expectations,
timeout: seconds,
enforceOrder: enforceOrderOfFulfillment
)
else {
XCTFail("Expectation was not fulfilled", file: file, line: line)
return
}
}

extension URLSessionTransportBufferedTests {
func testCancellation_beforeSendingHead() async throws {
try await testTaskCancelled(.beforeSendingHead, transport: transport)
}

func testCancellation_beforeSendingRequestBody() async throws {
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
}

func testCancellation_partwayThroughSendingRequestBody() async throws {
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
}

func testCancellation_beforeConsumingResponseBody() async throws {
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
}

func testCancellation_partwayThroughConsumingResponseBody() async throws {
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
}

func testCancellation_afterConsumingResponseBody() async throws {
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
}
}

extension URLSessionTransportStreamingTests {
func testCancellation_beforeSendingHead() async throws {
try await testTaskCancelled(.beforeSendingHead, transport: transport)
}

func testCancellation_beforeSendingRequestBody() async throws {
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
}

func testCancellation_partwayThroughSendingRequestBody() async throws {
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
}

func testCancellation_beforeConsumingResponseBody() async throws {
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
}

func testCancellation_partwayThroughConsumingResponseBody() async throws {
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
}

func testCancellation_afterConsumingResponseBody() async throws {
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
}
}

#endif // canImport(Darwin)
8 changes: 4 additions & 4 deletions Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase {

// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
class URLSessionTransportBufferedTests: XCTestCase {
var transport: (any ClientTransport)!
var transport: URLSessionTransport!

static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }

Expand All @@ -66,7 +66,7 @@ class URLSessionTransportBufferedTests: XCTestCase {

func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }

func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }

#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
Expand All @@ -89,7 +89,7 @@ class URLSessionTransportBufferedTests: XCTestCase {

// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
class URLSessionTransportStreamingTests: XCTestCase {
var transport: (any ClientTransport)!
var transport: URLSessionTransport!

static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }

Expand All @@ -107,7 +107,7 @@ class URLSessionTransportStreamingTests: XCTestCase {

func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }

func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }

#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
Expand Down