Skip to content

Commit 9a11003

Browse files
More checks for task cancellation and tests
1 parent 144464e commit 9a11003

File tree

3 files changed

+254
-19
lines changed

3 files changed

+254
-19
lines changed

Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import Foundation
3232
task = dataTask(with: urlRequest)
3333
}
3434
return try await withTaskCancellationHandler {
35+
try Task.checkCancellation()
3536
let delegate = BidirectionalStreamingURLSessionDelegate(
3637
requestBody: requestBody,
3738
requestStreamBufferSize: requestStreamBufferSize,
@@ -47,8 +48,10 @@ import Foundation
4748
length: .init(from: response),
4849
iterationBehavior: .single
4950
)
51+
try Task.checkCancellation()
5052
return (try HTTPResponse(response), responseBody)
5153
} onCancel: {
54+
debug("Concurrency task cancelled, cancelling URLSession task.")
5255
task.cancel()
5356
}
5457
}

Sources/OpenAPIURLSession/URLSessionTransport.swift

+34-15
Original file line numberDiff line numberDiff line change
@@ -243,31 +243,50 @@ extension URLSession {
243243
func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> (
244244
HTTPResponse, HTTPBody?
245245
) {
246+
try Task.checkCancellation()
246247
var urlRequest = try URLRequest(request, baseURL: baseURL)
247248
if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) }
249+
try Task.checkCancellation()
248250

249251
/// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on
250252
/// Darwin platforms newer than our minimum deployment target, and not at all on Linux.
251-
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
252-
continuation in
253-
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
254-
if let error {
255-
continuation.resume(throwing: error)
256-
return
253+
let taskBox: LockedValueBox<URLSessionTask?> = .init(nil)
254+
return try await withTaskCancellationHandler {
255+
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
256+
continuation in
257+
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
258+
if let error {
259+
continuation.resume(throwing: error)
260+
return
261+
}
262+
guard let response else {
263+
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
264+
return
265+
}
266+
continuation.resume(with: .success((response, data)))
257267
}
258-
guard let response else {
259-
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
260-
return
268+
// Swift concurrency task cancelled here.
269+
taskBox.withLockedValue { boxedTask in
270+
guard task.state == .suspended else {
271+
debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.")
272+
return
273+
}
274+
task.resume()
275+
boxedTask = task
261276
}
262-
continuation.resume(with: .success((response, data)))
263277
}
264-
task.resume()
265-
}
266278

267-
let maybeResponseBody = maybeResponseBodyData.map { data in
268-
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
279+
let maybeResponseBody = maybeResponseBodyData.map { data in
280+
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
281+
}
282+
return (try HTTPResponse(response), maybeResponseBody)
283+
} onCancel: {
284+
taskBox.withLockedValue { boxedTask in
285+
debug("Concurrency task cancelled, cancelling URLSession task.")
286+
boxedTask?.cancel()
287+
boxedTask = nil
288+
}
269289
}
270-
return (try HTTPResponse(response), maybeResponseBody)
271290
}
272291
}
273292

Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift

+217-4
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase {
5656

5757
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
5858
class URLSessionTransportBufferedTests: XCTestCase {
59-
var transport: (any ClientTransport)!
59+
var transport: URLSessionTransport!
6060

6161
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
6262

@@ -66,7 +66,31 @@ class URLSessionTransportBufferedTests: XCTestCase {
6666

6767
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
6868

69-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
69+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
70+
71+
func testCancellation_beforeSendingHead() async throws {
72+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
73+
}
74+
75+
func testCancellation_beforeSendingRequestBody() async throws {
76+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
77+
}
78+
79+
func testCancellation_partwayThroughSendingRequestBody() async throws {
80+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
81+
}
82+
83+
func testCancellation_beforeConsumingResponseBody() async throws {
84+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
85+
}
86+
87+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
88+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
89+
}
90+
91+
func testCancellation_afterConsumingResponseBody() async throws {
92+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
93+
}
7094

7195
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
7296
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
@@ -89,7 +113,7 @@ class URLSessionTransportBufferedTests: XCTestCase {
89113

90114
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
91115
class URLSessionTransportStreamingTests: XCTestCase {
92-
var transport: (any ClientTransport)!
116+
var transport: URLSessionTransport!
93117

94118
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
95119

@@ -107,7 +131,31 @@ class URLSessionTransportStreamingTests: XCTestCase {
107131

108132
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
109133

110-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
134+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
135+
136+
func testCancellation_beforeSendingHead() async throws {
137+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
138+
}
139+
140+
func testCancellation_beforeSendingRequestBody() async throws {
141+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
142+
}
143+
144+
func testCancellation_partwayThroughSendingRequestBody() async throws {
145+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
146+
}
147+
148+
func testCancellation_beforeConsumingResponseBody() async throws {
149+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
150+
}
151+
152+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
153+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
154+
}
155+
156+
func testCancellation_afterConsumingResponseBody() async throws {
157+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
158+
}
111159

112160
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
113161
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
@@ -311,6 +359,171 @@ func testHTTPBasicPost(transport: any ClientTransport) async throws {
311359
}
312360
}
313361

362+
enum CancellationPoint: CaseIterable {
363+
case beforeSendingHead
364+
case beforeSendingRequestBody
365+
case partwayThroughSendingRequestBody
366+
case beforeConsumingResponseBody
367+
case partwayThroughConsumingResponseBody
368+
case afterConsumingResponseBody
369+
}
370+
371+
func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws {
372+
let requestPath = "/hello/world"
373+
let requestBodyElements = ["Hello,", "world!"]
374+
let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true)
375+
let requestBody = HTTPBody(
376+
requestBodySequence,
377+
length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))),
378+
iterationBehavior: .single
379+
)
380+
381+
let responseBodyMessage = "Hey!"
382+
383+
let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled")
384+
let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled")
385+
386+
try await withThrowingTaskGroup(of: Void.self) { group in
387+
let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in
388+
try await connectionChannel.executeThenClose { inbound, outbound in
389+
var requestPartIterator = inbound.makeAsyncIterator()
390+
var accumulatedBody = ByteBuffer()
391+
while let requestPart = try await requestPartIterator.next() {
392+
switch requestPart {
393+
case .head(let head):
394+
XCTAssertEqual(head.uri, requestPath)
395+
XCTAssertEqual(head.method, .POST)
396+
case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer)
397+
case .end:
398+
switch cancellationPoint {
399+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody,
400+
.afterConsumingResponseBody:
401+
XCTAssertEqual(
402+
String(decoding: accumulatedBody.readableBytesView, as: UTF8.self),
403+
requestBodyElements.joined()
404+
)
405+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break
406+
}
407+
try await outbound.write(.head(.init(version: .http1_1, status: .ok)))
408+
try await outbound.write(.body(ByteBuffer(string: responseBodyMessage)))
409+
try await outbound.write(.end(nil))
410+
}
411+
}
412+
}
413+
}
414+
debug("Server running on 127.0.0.1:\(serverPort)")
415+
416+
let task = Task {
417+
if case .beforeSendingHead = cancellationPoint {
418+
taskShouldCancel.fulfill()
419+
await fulfillment(of: [taskCancelled])
420+
}
421+
debug("Client starting request")
422+
async let (asyncResponse, asyncResponseBody) = try await transport.send(
423+
HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath),
424+
body: requestBody,
425+
baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!,
426+
operationID: "unused"
427+
)
428+
429+
if case .beforeSendingRequestBody = cancellationPoint {
430+
taskShouldCancel.fulfill()
431+
await fulfillment(of: [taskCancelled])
432+
}
433+
434+
requestBodySequence.openGate(for: 1)
435+
436+
if case .partwayThroughSendingRequestBody = cancellationPoint {
437+
taskShouldCancel.fulfill()
438+
await fulfillment(of: [taskCancelled])
439+
}
440+
441+
requestBodySequence.openGate()
442+
443+
let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody)
444+
445+
debug("Client received response head: \(response)")
446+
XCTAssertEqual(response.status, .ok)
447+
let responseBody = try XCTUnwrap(maybeResponseBody)
448+
449+
if case .beforeConsumingResponseBody = cancellationPoint {
450+
taskShouldCancel.fulfill()
451+
await fulfillment(of: [taskCancelled])
452+
}
453+
454+
var iterator = responseBody.makeAsyncIterator()
455+
456+
_ = try await iterator.next()
457+
458+
if case .partwayThroughConsumingResponseBody = cancellationPoint {
459+
taskShouldCancel.fulfill()
460+
await fulfillment(of: [taskCancelled])
461+
}
462+
463+
while try await iterator.next() != nil {
464+
465+
}
466+
467+
if case .afterConsumingResponseBody = cancellationPoint {
468+
taskShouldCancel.fulfill()
469+
await fulfillment(of: [taskCancelled])
470+
}
471+
472+
}
473+
474+
await fulfillment(of: [taskShouldCancel])
475+
task.cancel()
476+
taskCancelled.fulfill()
477+
478+
switch transport.configuration.implementation {
479+
case .buffering:
480+
switch cancellationPoint {
481+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
482+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
483+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
484+
try await task.value
485+
}
486+
case .streaming:
487+
switch cancellationPoint {
488+
case .beforeSendingHead:
489+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
490+
case .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
491+
await XCTAssertThrowsError(try await task.value) { error in
492+
guard let urlError = error as? URLError else {
493+
XCTFail()
494+
return
495+
}
496+
XCTAssertEqual(urlError.code, .cancelled)
497+
}
498+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
499+
try await task.value
500+
}
501+
}
502+
503+
group.cancelAll()
504+
}
505+
506+
}
507+
508+
func fulfillment(
509+
of expectations: [XCTestExpectation],
510+
timeout seconds: TimeInterval = .infinity,
511+
enforceOrder enforceOrderOfFulfillment: Bool = false,
512+
file: StaticString = #file,
513+
line: UInt = #line
514+
) async {
515+
guard
516+
case .completed = await XCTWaiter.fulfillment(
517+
of: expectations,
518+
timeout: seconds,
519+
enforceOrder: enforceOrderOfFulfillment
520+
)
521+
else {
522+
XCTFail("Expectation was not fulfilled", file: file, line: line)
523+
return
524+
}
525+
}
526+
314527
class URLSessionTransportDebugLoggingTests: XCTestCase {
315528
func testDebugLoggingEnabled() {
316529
let expectation = expectation(description: "message autoclosure evaluated")

0 commit comments

Comments
 (0)