From b56e0679d12e0d6bf0e1504ec1c29d2bdc124899 Mon Sep 17 00:00:00 2001 From: Jonathan Flat <50605158+jrflat@users.noreply.github.com> Date: Thu, 6 Jun 2024 08:24:51 -0700 Subject: [PATCH] Add async URLSession methods (#4970) * Add `data(from:delegate:)` method. * Add async URLSession methods --------- Co-authored-by: ichiho --- .../DataURLProtocol.swift | 3 +- .../URLSession/FTP/FTPURLProtocol.swift | 2 +- .../URLSession/HTTP/HTTPURLProtocol.swift | 6 +- .../URLSession/NativeProtocol.swift | 84 +++--- .../URLSession/TaskRegistry.swift | 4 + .../URLSession/URLSession.swift | 241 ++++++++++++++++++ .../URLSession/URLSessionDelegate.swift | 44 +++- .../URLSession/URLSessionTask.swift | 74 +++++- Tests/Foundation/Tests/TestURLSession.swift | 96 ++++++- 9 files changed, 495 insertions(+), 59 deletions(-) diff --git a/Sources/FoundationNetworking/DataURLProtocol.swift b/Sources/FoundationNetworking/DataURLProtocol.swift index 014f34b558..c4783b2dc4 100644 --- a/Sources/FoundationNetworking/DataURLProtocol.swift +++ b/Sources/FoundationNetworking/DataURLProtocol.swift @@ -91,8 +91,7 @@ internal class _DataURLProtocol: URLProtocol { urlClient.urlProtocolDidFinishLoading(self) } else { let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadURL) - if let session = self.task?.session as? URLSession, let delegate = session.delegate as? URLSessionTaskDelegate, - let task = self.task { + if let task = self.task, let session = task.actualSession, let delegate = task.delegate { delegate.urlSession(session, task: task, didCompleteWithError: error) } } diff --git a/Sources/FoundationNetworking/URLSession/FTP/FTPURLProtocol.swift b/Sources/FoundationNetworking/URLSession/FTP/FTPURLProtocol.swift index 55583fd2b8..932600cbe2 100644 --- a/Sources/FoundationNetworking/URLSession/FTP/FTPURLProtocol.swift +++ b/Sources/FoundationNetworking/URLSession/FTP/FTPURLProtocol.swift @@ -119,7 +119,7 @@ internal extension _FTPURLProtocol { switch session.behaviour(for: self.task!) { case .noDelegate: break - case .taskDelegate: + case .taskDelegate, .dataCompletionHandlerWithTaskDelegate, .downloadCompletionHandlerWithTaskDelegate: self.client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) case .dataCompletionHandler: break diff --git a/Sources/FoundationNetworking/URLSession/HTTP/HTTPURLProtocol.swift b/Sources/FoundationNetworking/URLSession/HTTP/HTTPURLProtocol.swift index abf6623435..c0722fb040 100644 --- a/Sources/FoundationNetworking/URLSession/HTTP/HTTPURLProtocol.swift +++ b/Sources/FoundationNetworking/URLSession/HTTP/HTTPURLProtocol.swift @@ -475,7 +475,7 @@ internal class _HTTPURLProtocol: _NativeProtocol { guard let session = task?.session as? URLSession else { fatalError() } - if let delegate = session.delegate as? URLSessionTaskDelegate { + if let delegate = task?.delegate { // At this point we need to change the internal state to note // that we're waiting for the delegate to call the completion // handler. Then we'll call the delegate callback @@ -524,7 +524,9 @@ internal class _HTTPURLProtocol: _NativeProtocol { switch session.behaviour(for: self.task!) { case .noDelegate: break - case .taskDelegate: + case .taskDelegate, + .dataCompletionHandlerWithTaskDelegate, + .downloadCompletionHandlerWithTaskDelegate: //TODO: There's a problem with libcurl / with how we're using it. // We're currently unable to pause the transfer / the easy handle: // https://curl.haxx.se/mail/lib-2016-03/0222.html diff --git a/Sources/FoundationNetworking/URLSession/NativeProtocol.swift b/Sources/FoundationNetworking/URLSession/NativeProtocol.swift index 53a195f5a8..95da2e9bdb 100644 --- a/Sources/FoundationNetworking/URLSession/NativeProtocol.swift +++ b/Sources/FoundationNetworking/URLSession/NativeProtocol.swift @@ -129,43 +129,59 @@ internal class _NativeProtocol: URLProtocol, _EasyHandleDelegate { } fileprivate func notifyDelegate(aboutReceivedData data: Data) { - guard let t = self.task else { + guard let task = self.task, let session = task.session as? URLSession else { fatalError("Cannot notify") } - if case .taskDelegate(let delegate) = t.session.behaviour(for: self.task!), - let dataDelegate = delegate as? URLSessionDataDelegate, - let task = self.task as? URLSessionDataTask { - // Forward to the delegate: - guard let s = self.task?.session as? URLSession else { - fatalError() - } - s.delegateQueue.addOperation { - dataDelegate.urlSession(s, dataTask: task, didReceive: data) - } - } else if case .taskDelegate(let delegate) = t.session.behaviour(for: self.task!), - let downloadDelegate = delegate as? URLSessionDownloadDelegate, - let task = self.task as? URLSessionDownloadTask { - guard let s = self.task?.session as? URLSession else { - fatalError() - } - let fileHandle = try! FileHandle(forWritingTo: self.tempFileURL) - _ = fileHandle.seekToEndOfFile() - fileHandle.write(data) - task.countOfBytesReceived += Int64(data.count) - s.delegateQueue.addOperation { - downloadDelegate.urlSession(s, downloadTask: task, didWriteData: Int64(data.count), totalBytesWritten: task.countOfBytesReceived, - totalBytesExpectedToWrite: task.countOfBytesExpectedToReceive) + switch task.session.behaviour(for: task) { + case .taskDelegate(let delegate), + .dataCompletionHandlerWithTaskDelegate(_, let delegate), + .downloadCompletionHandlerWithTaskDelegate(_, let delegate): + if let dataDelegate = delegate as? URLSessionDataDelegate, + let dataTask = task as? URLSessionDataTask { + session.delegateQueue.addOperation { + dataDelegate.urlSession(session, dataTask: dataTask, didReceive: data) + } + } else if let downloadDelegate = delegate as? URLSessionDownloadDelegate, + let downloadTask = task as? URLSessionDownloadTask { + let fileHandle = try! FileHandle(forWritingTo: self.tempFileURL) + _ = fileHandle.seekToEndOfFile() + fileHandle.write(data) + task.countOfBytesReceived += Int64(data.count) + session.delegateQueue.addOperation { + downloadDelegate.urlSession( + session, + downloadTask: downloadTask, + didWriteData: Int64(data.count), + totalBytesWritten: task.countOfBytesReceived, + totalBytesExpectedToWrite: task.countOfBytesExpectedToReceive + ) + } } + default: + break } } fileprivate func notifyDelegate(aboutUploadedData count: Int64) { - guard let task = self.task, let session = task.session as? URLSession, - case .taskDelegate(let delegate) = session.behaviour(for: task) else { return } - task.countOfBytesSent += count - session.delegateQueue.addOperation { - delegate.urlSession(session, task: task, didSendBodyData: count, - totalBytesSent: task.countOfBytesSent, totalBytesExpectedToSend: task.countOfBytesExpectedToSend) + guard let task = self.task, let session = task.session as? URLSession else { + return + } + switch session.behaviour(for: task) { + case .taskDelegate(let delegate), + .dataCompletionHandlerWithTaskDelegate(_, let delegate), + .downloadCompletionHandlerWithTaskDelegate(_, let delegate): + task.countOfBytesSent += count + session.delegateQueue.addOperation { + delegate.urlSession( + session, + task: task, + didSendBodyData: count, + totalBytesSent: task.countOfBytesSent, + totalBytesExpectedToSend: task.countOfBytesExpectedToSend + ) + } + default: + break } } @@ -284,7 +300,7 @@ internal class _NativeProtocol: URLProtocol, _EasyHandleDelegate { var currentInputStream: InputStream? - if let delegate = session.delegate as? URLSessionTaskDelegate { + if let delegate = task?.delegate { let dispatchGroup = DispatchGroup() dispatchGroup.enter() @@ -338,11 +354,13 @@ internal class _NativeProtocol: URLProtocol, _EasyHandleDelegate { // Data will be forwarded to the delegate as we receive it, we don't // need to do anything about it. return .ignore - case .dataCompletionHandler: + case .dataCompletionHandler, + .dataCompletionHandlerWithTaskDelegate: // Data needs to be concatenated in-memory such that we can pass it // to the completion handler upon completion. return .inMemory(nil) - case .downloadCompletionHandler: + case .downloadCompletionHandler, + .downloadCompletionHandlerWithTaskDelegate: // Data needs to be written to a file (i.e. a download task). let fileHandle = try! FileHandle(forWritingTo: self.tempFileURL) return .toFile(self.tempFileURL, fileHandle) diff --git a/Sources/FoundationNetworking/URLSession/TaskRegistry.swift b/Sources/FoundationNetworking/URLSession/TaskRegistry.swift index 9066a4a9cc..3e958891dd 100644 --- a/Sources/FoundationNetworking/URLSession/TaskRegistry.swift +++ b/Sources/FoundationNetworking/URLSession/TaskRegistry.swift @@ -45,8 +45,12 @@ extension URLSession { case callDelegate /// Default action for all events, except for completion. case dataCompletionHandler(DataTaskCompletion) + /// Default action for all asynchronous events. + case dataCompletionHandlerWithTaskDelegate(DataTaskCompletion, URLSessionTaskDelegate?) /// Default action for all events, except for completion. case downloadCompletionHandler(DownloadTaskCompletion) + /// Default action for all asynchronous events. + case downloadCompletionHandlerWithTaskDelegate(DownloadTaskCompletion, URLSessionTaskDelegate?) } fileprivate var tasks: [Int: URLSessionTask] = [:] diff --git a/Sources/FoundationNetworking/URLSession/URLSession.swift b/Sources/FoundationNetworking/URLSession/URLSession.swift index 2dcb000a22..d01fc316c7 100644 --- a/Sources/FoundationNetworking/URLSession/URLSession.swift +++ b/Sources/FoundationNetworking/URLSession/URLSession.swift @@ -648,15 +648,31 @@ internal extension URLSession { /// Default action for all events, except for completion. /// - SeeAlso: URLSession.TaskRegistry.Behaviour.dataCompletionHandler case dataCompletionHandler(URLSession._TaskRegistry.DataTaskCompletion) + /// Default action for all asynchronous events. + /// - SeeAlso: URLsession.TaskRegistry.Behaviour.dataCompletionHandlerWithTaskDelegate + case dataCompletionHandlerWithTaskDelegate(URLSession._TaskRegistry.DataTaskCompletion, URLSessionTaskDelegate) /// Default action for all events, except for completion. /// - SeeAlso: URLSession.TaskRegistry.Behaviour.downloadCompletionHandler case downloadCompletionHandler(URLSession._TaskRegistry.DownloadTaskCompletion) + /// Default action for all asynchronous events. + /// - SeeAlso: URLsession.TaskRegistry.Behaviour.downloadCompletionHandlerWithTaskDelegate + case downloadCompletionHandlerWithTaskDelegate(URLSession._TaskRegistry.DownloadTaskCompletion, URLSessionTaskDelegate) } func behaviour(for task: URLSessionTask) -> _TaskBehaviour { switch taskRegistry.behaviour(for: task) { case .dataCompletionHandler(let c): return .dataCompletionHandler(c) + case .dataCompletionHandlerWithTaskDelegate(let c, let d): + guard let d else { + return .dataCompletionHandler(c) + } + return .dataCompletionHandlerWithTaskDelegate(c, d) case .downloadCompletionHandler(let c): return .downloadCompletionHandler(c) + case .downloadCompletionHandlerWithTaskDelegate(let c, let d): + guard let d else { + return .downloadCompletionHandler(c) + } + return .downloadCompletionHandlerWithTaskDelegate(c, d) case .callDelegate: guard let d = delegate as? URLSessionTaskDelegate else { return .noDelegate @@ -666,6 +682,231 @@ internal extension URLSession { } } +fileprivate struct Lock: @unchecked Sendable { + let stateLock: ManagedBuffer + init(initialState: State) { + stateLock = .create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointerToElements { lock in + lock.initialize(to: .init()) + } + return initialState + } + } + + func withLock(_ body: @Sendable (inout State) throws -> R) rethrows -> R where R : Sendable { + return try stateLock.withUnsafeMutablePointers { header, lock in + lock.pointee.lock() + defer { + lock.pointee.unlock() + } + return try body(&header.pointee) + } + } +} + +fileprivate extension URLSession { + final class CancelState: Sendable { + struct State { + var isCancelled: Bool + var task: URLSessionTask? + } + let lock: Lock + init() { + lock = Lock(initialState: State(isCancelled: false, task: nil)) + } + + func cancel() { + let task = lock.withLock { state in + state.isCancelled = true + let result = state.task + state.task = nil + return result + } + task?.cancel() + } + + func activate(task: URLSessionTask) { + let taskUsed = lock.withLock { state in + if state.task != nil { + fatalError("Cannot activate twice") + } + if state.isCancelled { + return false + } else { + state.isCancelled = false + state.task = task + return true + } + } + + if !taskUsed { + task.cancel() + } + } + } +} + +@available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) +extension URLSession { + /// Convenience method to load data using a URLRequest, creates and resumes a URLSessionDataTask internally. + /// + /// - Parameter request: The URLRequest for which to load data. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Data and response. + public func data(for request: URLRequest, delegate: URLSessionTaskDelegate? = nil) async throws -> (Data, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DataTaskCompletion = { data, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (data!, response!)) + } + } + let task = dataTask(with: _Request(request), behaviour: .dataCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } + + /// Convenience method to load data using a URL, creates and resumes a URLSessionDataTask internally. + /// + /// - Parameter url: The URL for which to load data. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Data and response. + public func data(from url: URL, delegate: URLSessionTaskDelegate? = nil) async throws -> (Data, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DataTaskCompletion = { data, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (data!, response!)) + } + } + let task = dataTask(with: _Request(url), behaviour: .dataCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } + + /// Convenience method to upload data using a URLRequest, creates and resumes a URLSessionUploadTask internally. + /// + /// - Parameter request: The URLRequest for which to upload data. + /// - Parameter fileURL: File to upload. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Data and response. + public func upload(for request: URLRequest, fromFile fileURL: URL, delegate: URLSessionTaskDelegate? = nil) async throws -> (Data, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DataTaskCompletion = { data, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (data!, response!)) + } + } + let task = uploadTask(with: _Request(request), body: .file(fileURL), behaviour: .dataCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } + + /// Convenience method to upload data using a URLRequest, creates and resumes a URLSessionUploadTask internally. + /// + /// - Parameter request: The URLRequest for which to upload data. + /// - Parameter bodyData: Data to upload. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Data and response. + public func upload(for request: URLRequest, from bodyData: Data, delegate: URLSessionTaskDelegate? = nil) async throws -> (Data, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DataTaskCompletion = { data, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (data!, response!)) + } + } + let task = uploadTask(with: _Request(request), body: .data(createDispatchData(bodyData)), behaviour: .dataCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } + + /// Convenience method to download using a URLRequest, creates and resumes a URLSessionDownloadTask internally. + /// + /// - Parameter request: The URLRequest for which to download. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Downloaded file URL and response. The file will not be removed automatically. + public func download(for request: URLRequest, delegate: URLSessionTaskDelegate? = nil) async throws -> (URL, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DownloadTaskCompletion = { location, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (location!, response!)) + } + } + let task = downloadTask(with: _Request(request), behavior: .downloadCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } + + /// Convenience method to download using a URL, creates and resumes a URLSessionDownloadTask internally. + /// + /// - Parameter url: The URL for which to download. + /// - Parameter delegate: Task-specific delegate. + /// - Returns: Downloaded file URL and response. The file will not be removed automatically. + public func download(from url: URL, delegate: URLSessionTaskDelegate? = nil) async throws -> (URL, URLResponse) { + let cancelState = CancelState() + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + let completionHandler: URLSession._TaskRegistry.DownloadTaskCompletion = { location, response, error in + if let error = error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: (location!, response!)) + } + } + let task = downloadTask(with: _Request(url), behavior: .downloadCompletionHandlerWithTaskDelegate(completionHandler, delegate)) + task._callCompletionHandlerInline = true + task.resume() + cancelState.activate(task: task) + } + } onCancel: { + cancelState.cancel() + } + } +} + internal protocol URLSessionProtocol: AnyObject { func add(handle: _EasyHandle) diff --git a/Sources/FoundationNetworking/URLSession/URLSessionDelegate.swift b/Sources/FoundationNetworking/URLSession/URLSessionDelegate.swift index 4cb7d41351..bd061f55dc 100644 --- a/Sources/FoundationNetworking/URLSession/URLSessionDelegate.swift +++ b/Sources/FoundationNetworking/URLSession/URLSessionDelegate.swift @@ -134,24 +134,54 @@ public protocol URLSessionTaskDelegate : URLSessionDelegate { extension URLSessionTaskDelegate { public func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { - completionHandler(request) + // If the task's delegate does not implement this function, check if the session's delegate does + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, willPerformHTTPRedirection: response, newRequest: request, completionHandler: completionHandler) + } else { + // Default handling + completionHandler(request) + } } public func urlSession(_ session: URLSession, task: URLSessionTask, didReceive challenge: URLAuthenticationChallenge, completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) { - completionHandler(.performDefaultHandling, nil) + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, didReceive: challenge, completionHandler: completionHandler) + } else { + completionHandler(.performDefaultHandling, nil) + } } public func urlSession(_ session: URLSession, task: URLSessionTask, needNewBodyStream completionHandler: @escaping (InputStream?) -> Void) { - completionHandler(nil) + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, needNewBodyStream: completionHandler) + } else { + completionHandler(nil) + } } - public func urlSession(_ session: URLSession, task: URLSessionTask, didSendBodyData bytesSent: Int64, totalBytesSent: Int64, totalBytesExpectedToSend: Int64) { } + public func urlSession(_ session: URLSession, task: URLSessionTask, didSendBodyData bytesSent: Int64, totalBytesSent: Int64, totalBytesExpectedToSend: Int64) { + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, didSendBodyData: bytesSent, totalBytesSent: totalBytesSent, totalBytesExpectedToSend: totalBytesExpectedToSend) + } + } - public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { } + public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, didCompleteWithError: error) + } + } - public func urlSession(_ session: URLSession, task: URLSessionTask, willBeginDelayedRequest request: URLRequest, completionHandler: @escaping (URLSession.DelayedRequestDisposition, URLRequest?) -> Void) { } + public func urlSession(_ session: URLSession, task: URLSessionTask, willBeginDelayedRequest request: URLRequest, completionHandler: @escaping (URLSession.DelayedRequestDisposition, URLRequest?) -> Void) { + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, willBeginDelayedRequest: request, completionHandler: completionHandler) + } + } - public func urlSession(_ session: URLSession, task: URLSessionTask, didFinishCollecting metrics: URLSessionTaskMetrics) { } + public func urlSession(_ session: URLSession, task: URLSessionTask, didFinishCollecting metrics: URLSessionTaskMetrics) { + if self === task.delegate, let sessionDelegate = session.delegate as? URLSessionTaskDelegate, self !== sessionDelegate { + sessionDelegate.urlSession(session, task: task, didFinishCollecting: metrics) + } + } } /* diff --git a/Sources/FoundationNetworking/URLSession/URLSessionTask.swift b/Sources/FoundationNetworking/URLSession/URLSessionTask.swift index 6a342c6ad2..3771945750 100644 --- a/Sources/FoundationNetworking/URLSession/URLSessionTask.swift +++ b/Sources/FoundationNetworking/URLSession/URLSessionTask.swift @@ -104,6 +104,22 @@ open class URLSessionTask : NSObject, NSCopying { internal var actualSession: URLSession? { return session as? URLSession } internal var session: URLSessionProtocol! //change to nil when task completes + private var _taskDelegate: URLSessionTaskDelegate? + open var delegate: URLSessionTaskDelegate? { + get { + if let _taskDelegate { return _taskDelegate } + return self.actualSession?.delegate as? URLSessionTaskDelegate + } + set { + guard !self.hasTriggeredResume else { + fatalError("Cannot set task delegate after resumption") + } + _taskDelegate = newValue + } + } + + internal var _callCompletionHandlerInline = false + fileprivate enum ProtocolState { case toBeCreated case awaitingCacheReply(Bag<(URLProtocol?) -> Void>) @@ -211,7 +227,7 @@ open class URLSessionTask : NSObject, NSCopying { return } - if let session = actualSession, let delegate = session.delegate as? URLSessionTaskDelegate { + if let session = actualSession, let delegate = self.delegate { delegate.urlSession(session, task: self) { (stream) in if let stream = stream { completion(.stream(stream)) @@ -1044,7 +1060,9 @@ extension _ProtocolClient : URLProtocolClient { } switch session.behaviour(for: task) { - case .taskDelegate(let delegate): + case .taskDelegate(let delegate), + .dataCompletionHandlerWithTaskDelegate(_, let delegate), + .downloadCompletionHandlerWithTaskDelegate(_, let delegate): if let dataDelegate = delegate as? URLSessionDataDelegate, let dataTask = task as? URLSessionDataTask { session.delegateQueue.addOperation { @@ -1119,7 +1137,7 @@ extension _ProtocolClient : URLProtocolClient { let cacheable = CachedURLResponse(response: response, data: Data(data.joined()), storagePolicy: cachePolicy) let protocolAllows = (urlProtocol as? _NativeProtocol)?.canCache(cacheable) ?? false if protocolAllows { - if let delegate = task.session.delegate as? URLSessionDataDelegate { + if let delegate = task.delegate as? URLSessionDataDelegate { delegate.urlSession(task.session as! URLSession, dataTask: task, willCacheResponse: cacheable) { (actualCacheable) in if let actualCacheable = actualCacheable { cache.storeCachedResponse(actualCacheable, for: task) @@ -1157,8 +1175,9 @@ extension _ProtocolClient : URLProtocolClient { session.workQueue.async { session.taskRegistry.remove(task) } - case .dataCompletionHandler(let completion): - session.delegateQueue.addOperation { + case .dataCompletionHandler(let completion), + .dataCompletionHandlerWithTaskDelegate(let completion, _): + let dataCompletion = { guard task.state != .completed else { return } completion(urlProtocol.properties[URLProtocol._PropertyKey.responseData] as? Data ?? Data(), task.response, nil) task.state = .completed @@ -1166,8 +1185,16 @@ extension _ProtocolClient : URLProtocolClient { session.taskRegistry.remove(task) } } - case .downloadCompletionHandler(let completion): - session.delegateQueue.addOperation { + if task._callCompletionHandlerInline { + dataCompletion() + } else { + session.delegateQueue.addOperation { + dataCompletion() + } + } + case .downloadCompletionHandler(let completion), + .downloadCompletionHandlerWithTaskDelegate(let completion, _): + let downloadCompletion = { guard task.state != .completed else { return } completion(urlProtocol.properties[URLProtocol._PropertyKey.temporaryFileURL] as? URL, task.response, nil) task.state = .completed @@ -1175,6 +1202,13 @@ extension _ProtocolClient : URLProtocolClient { session.taskRegistry.remove(task) } } + if task._callCompletionHandlerInline { + downloadCompletion() + } else { + session.delegateQueue.addOperation { + downloadCompletion() + } + } } task._invalidateProtocol() } @@ -1224,7 +1258,7 @@ extension _ProtocolClient : URLProtocolClient { } } - if let delegate = session.delegate as? URLSessionTaskDelegate { + if let delegate = task.delegate { session.delegateQueue.addOperation { delegate.urlSession(session, task: task, didReceive: challenge) { disposition, credential in @@ -1297,8 +1331,9 @@ extension _ProtocolClient : URLProtocolClient { session.workQueue.async { session.taskRegistry.remove(task) } - case .dataCompletionHandler(let completion): - session.delegateQueue.addOperation { + case .dataCompletionHandler(let completion), + .dataCompletionHandlerWithTaskDelegate(let completion, _): + let dataCompletion = { guard task.state != .completed else { return } completion(nil, nil, error) task.state = .completed @@ -1306,8 +1341,16 @@ extension _ProtocolClient : URLProtocolClient { session.taskRegistry.remove(task) } } - case .downloadCompletionHandler(let completion): - session.delegateQueue.addOperation { + if task._callCompletionHandlerInline { + dataCompletion() + } else { + session.delegateQueue.addOperation { + dataCompletion() + } + } + case .downloadCompletionHandler(let completion), + .downloadCompletionHandlerWithTaskDelegate(let completion, _): + let downloadCompletion = { guard task.state != .completed else { return } completion(nil, nil, error) task.state = .completed @@ -1315,6 +1358,13 @@ extension _ProtocolClient : URLProtocolClient { session.taskRegistry.remove(task) } } + if task._callCompletionHandlerInline { + downloadCompletion() + } else { + session.delegateQueue.addOperation { + downloadCompletion() + } + } } task._invalidateProtocol() } diff --git a/Tests/Foundation/Tests/TestURLSession.swift b/Tests/Foundation/Tests/TestURLSession.swift index 8c04855589..688f5db3b1 100644 --- a/Tests/Foundation/Tests/TestURLSession.swift +++ b/Tests/Foundation/Tests/TestURLSession.swift @@ -92,7 +92,42 @@ class TestURLSession: LoopbackServerTest { task.resume() waitForExpectations(timeout: 12) } - + + func test_asyncDataFromURL() async throws { + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/UK" + let (data, response) = try await URLSession.shared.data(from: URL(string: urlString)!, delegate: nil) + guard let httpResponse = response as? HTTPURLResponse else { + XCTFail("Did not get response") + return + } + XCTAssertEqual(200, httpResponse.statusCode, "HTTP response code is not 200") + let result = String(data: data, encoding: .utf8) ?? "" + XCTAssertEqual("London", result, "Did not receive expected value") + } + + func test_asyncDataFromURLWithDelegate() async throws { + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + class CapitalDataTaskDelegate: NSObject, URLSessionDataDelegate { + var capital: String = "unknown" + public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { + capital = String(data: data, encoding: .utf8)! + } + } + let delegate = CapitalDataTaskDelegate() + + let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/UK" + let (data, response) = try await URLSession.shared.data(from: URL(string: urlString)!, delegate: delegate) + guard let httpResponse = response as? HTTPURLResponse else { + XCTFail("Did not get response") + return + } + XCTAssertEqual(200, httpResponse.statusCode, "HTTP response code is not 200") + let result = String(data: data, encoding: .utf8) ?? "" + XCTAssertEqual("London", result, "Did not receive expected value") + XCTAssertEqual("London", delegate.capital) + } + func test_dataTaskWithHttpInputStream() throws { let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/jsonBody" let url = try XCTUnwrap(URL(string: urlString)) @@ -266,6 +301,44 @@ class TestURLSession: LoopbackServerTest { waitForExpectations(timeout: 12) } + func test_asyncDownloadFromURL() async throws { + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/country.txt" + let (location, response) = try await URLSession.shared.download(from: URL(string: urlString)!) + guard let httpResponse = response as? HTTPURLResponse else { + XCTFail("Did not get response") + return + } + XCTAssertEqual(200, httpResponse.statusCode, "HTTP response code is not 200") + XCTAssertNotNil(location, "Download location was nil") + } + + func test_asyncDownloadFromURLWithDelegate() async throws { + guard #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) else { return } + class AsyncDownloadDelegate : NSObject, URLSessionDownloadDelegate { + func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { + XCTFail("Should not be called for async downloads") + } + + var totalBytesWritten = Int64(0) + public func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didWriteData bytesWritten: Int64, + totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) -> Void { + self.totalBytesWritten = totalBytesWritten + } + } + let delegate = AsyncDownloadDelegate() + + let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/country.txt" + let (location, response) = try await URLSession.shared.download(from: URL(string: urlString)!, delegate: delegate) + guard let httpResponse = response as? HTTPURLResponse else { + XCTFail("Did not get response") + return + } + XCTAssertEqual(200, httpResponse.statusCode, "HTTP response code is not 200") + XCTAssertNotNil(location, "Download location was nil") + XCTAssertTrue(delegate.totalBytesWritten > 0) + } + func test_gzippedDownloadTask() { let urlString = "http://127.0.0.1:\(TestURLSession.serverPort)/gzipped-response" let url = URL(string: urlString)! @@ -1611,6 +1684,21 @@ class TestURLSession: LoopbackServerTest { XCTAssertNil(session.delegate) } + func test_sessionDelegateCalledIfTaskDelegateDoesNotImplement() throws { + let expectation = XCTestExpectation(description: "task finished") + let delegate = SessionDelegate(with: expectation) + let session = URLSession(configuration: .default, delegate: delegate, delegateQueue: nil) + + class EmptyTaskDelegate: NSObject, URLSessionTaskDelegate { } + let url = URL(string: "http://127.0.0.1:\(TestURLSession.serverPort)/country.txt")! + let request = URLRequest(url: url) + let task = session.dataTask(with: request) + task.delegate = EmptyTaskDelegate() + task.resume() + + wait(for: [expectation], timeout: 5) + } + func test_getAllTasks() throws { let expect = expectation(description: "Tasks URLSession.getAllTasks") @@ -2088,6 +2176,7 @@ class TestURLSession: LoopbackServerTest { ("test_checkErrorTypeAfterInvalidateAndCancel", test_checkErrorTypeAfterInvalidateAndCancel), ("test_taskCountAfterInvalidateAndCancel", test_taskCountAfterInvalidateAndCancel), ("test_sessionDelegateAfterInvalidateAndCancel", test_sessionDelegateAfterInvalidateAndCancel), + ("test_sessionDelegateCalledIfTaskDelegateDoesNotImplement", test_sessionDelegateCalledIfTaskDelegateDoesNotImplement), /* ⚠️ */ ("test_getAllTasks", testExpectedToFail(test_getAllTasks, "This test causes later ones to crash")), /* ⚠️ */ ("test_getTasksWithCompletion", testExpectedToFail(test_getTasksWithCompletion, "Flaky tests")), /* ⚠️ */ ("test_invalidResumeDataForDownloadTask", @@ -2100,6 +2189,10 @@ class TestURLSession: LoopbackServerTest { ] if #available(macOS 12.0, *) { retVal.append(contentsOf: [ + ("test_asyncDataFromURL", asyncTest(test_asyncDataFromURL)), + ("test_asyncDataFromURLWithDelegate", asyncTest(test_asyncDataFromURLWithDelegate)), + ("test_asyncDownloadFromURL", asyncTest(test_asyncDownloadFromURL)), + ("test_asyncDownloadFromURLWithDelegate", asyncTest(test_asyncDownloadFromURLWithDelegate)), ("test_webSocket", asyncTest(test_webSocket)), ("test_webSocketSpecificProtocol", asyncTest(test_webSocketSpecificProtocol)), ("test_webSocketAbruptClose", asyncTest(test_webSocketAbruptClose)), @@ -2296,7 +2389,6 @@ extension SessionDelegate: URLSessionDataDelegate { } } - class DataTask : NSObject { let syncQ = dispatchQueueMake("org.swift.TestFoundation.TestURLSession.DataTask.syncQ") let dataTaskExpectation: XCTestExpectation!