diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutingRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutingRequest.swift new file mode 100644 index 000000000..5c4f0f6ab --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutingRequest.swift @@ -0,0 +1,244 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP1 + +/// # Protocol Overview +/// +/// To support different public request APIs we abstract the actual request implementations behind +/// protocols. During the lifetime of a request, a request must conform to different protocols +/// depending on which state it is in. +/// +/// Generally there are two main states in a request's lifetime: +/// +/// 1. **The request is scheduled to be run.** +/// In this state the HTTP client tries to acquire a connection for the request, and the request +/// may need to wait for a connection +/// 2. **The request is executing.** +/// In this state the request was written to a NIO channel. A NIO channel handler (abstracted +/// by the `HTTPRequestExecutor` protocol) writes the request's bytes onto the wire and +/// dispatches the http response bytes back to the response. +/// +/// +/// ## Request is scheduled +/// +/// When the `HTTPClient` shall send an HTTP request, it will use its `HTTPConnectionPool.Manager` to +/// determine the `HTTPConnectionPool` to run the request on. After a `HTTPConnectionPool` has been +/// found for the request, the request will be executed on this connection pool. Since the HTTP +/// request implements the `HTTPScheduledRequest` protocol, the HTTP connection pool can communicate +/// with the request. The `HTTPConnectionPool` implements the `HTTPRequestScheduler` protocol. +/// +/// 1. The `HTTPConnectionPool` tries to find an idle connection for the request based on its +/// `eventLoopPreference`. +/// +/// 2. If an idle connection is available to the request, the request will be passed to the +/// connection right away. In this case the `HTTPConnectionPool` will only use the +/// `HTTPScheduledRequest`'s `eventLoopPreference` property. No other methods will be called. +/// +/// 3. If no idle connection is available to the request, the request will be queued for execution: +/// - The `HTTPConnectionPool` will inform the request that it is queued for execution by +/// calling: `requestWasQueued(_: HTTPRequestScheduler)`. The request must store a reference +/// to the `HTTPRequestScheduler`. The request must call `cancelRequest(self)` on the +/// scheduler, if the request was cancelled, while waiting for execution. +/// - The `HTTPConnectionPool` will create a connection deadline based on the +/// `HTTPScheduledRequest`'s `connectionDeadline` property. If a connection to execute the +/// request on, was not found before this deadline the request will be failed. +/// - The HTTPConnectionPool will call `fail(_: Error)` on the `HTTPScheduledRequest` to +/// inform the request about having overrun the `connectionDeadline`. +/// +/// +/// ## Request is executing +/// +/// After the `HTTPConnectionPool` has identified a connection for the request to be run on, it will +/// execute the request on this connection. (Implementation detail: This happens by writing the +/// `HTTPExecutingRequest` to a `NIO.Channel`. We expect the last handler in the `ChannelPipeline` +/// to have an `OutboundIn` type of `HTTPExecutingRequest`. Further we expect that the handler +/// also conforms to the protocol `HTTPRequestExecutor` to allow communication of the request with +/// the executor/`ChannelHandler`). +/// +/// The request execution will work as follows: +/// +/// 1. The request executor will call `willExecuteRequest(_: HTTPRequestExecutor)` on the +/// request. The request is expected to keep a reference to the `HTTPRequestExecutor` that was +/// passed to the request for further communication. +/// 2. The request sending is started by the executor accessing the `HTTPExecutingRequest`'s +/// property `requestHead: HTTPRequestHead`. Based on the `requestHead` the executor can +/// determine if the request has a body (Is a "content-length" or "transfer-encoding" +/// header present?). +/// 3. The executor will write the request's header into the Channel. If no body is present, the +/// executor will also write a request end into the Channel. After this the executor will call +/// `requestHeadSent(_: HTTPRequestHead)` +/// 4. If the request has a body the request executor will, ask the request for body data, by +/// calling `startRequestBodyStream()`. The request is expected to call +/// `writeRequestBodyPart(_: IOData, task: HTTPExecutingRequest)` on the executor with body +/// data. +/// - The executor can signal backpressure to the request by calling +/// `pauseRequestBodyStream()`. In this case the request is expected to stop calling +/// `writeRequestBodyPart(_: IOData, task: HTTPExecutingRequest)`. However because of race +/// conditions the executor is prepared to process more data, even though it asked the +/// request to pause. +/// - Once the executor is able to send more data, it will notify the request by calling +/// `resumeRequestBodyStream()` on the request. +/// - The request shall call `finishRequestBodyStream()` on the executor to signal that the +/// request body was sent. +/// 5. Once the executor receives a http response from the Channel, it will forward the http +/// response head to the `HTTPExecutingRequest` by calling `receiveResponseHead` on it. +/// - The executor will forward all the response body parts it receives in a single read to +/// the `HTTPExecutingRequest` without any buffering by calling +/// `receiveResponseBodyPart(_ buffer: ByteBuffer)` right away. It is the task's job to +/// buffer the responses for user consumption. +/// - Once the executor has finished a read, it will not schedule another read, until the +/// request calls `demandResponseBodyStream(task: HTTPExecutingRequest)` on the executor. +/// - Once the executor has received the response's end, it will forward this message by +/// calling `receiveResponseEnd()` on the `HTTPExecutingRequest`. +/// 6. If a channel error occurs during the execution of the request, or if the channel becomes +/// inactive the executor will notify the request by calling `fail(_ error: Error)` on it. +/// 7. If the request is cancelled, while it is executing on the executor, it must call +/// `cancelRequest(task: HTTPExecutingRequest)` on the executor. +/// +/// +/// ## Further notes +/// +/// - These protocols makes no guarantees about thread safety at all. It is implementations job to +/// ensure thread safety. +/// - However all calls to the `HTTPRequestScheduler` and `HTTPRequestExecutor` require that the +/// invoking request is passed along. This helps the scheduler and executor in race conditions. +/// Example: +/// - The executor may have received an error in thread A that it passes along to the request. +/// After having passed on the error, the executor considers the request done and releases +/// the request's reference. +/// - The request may issue a call to `writeRequestBodyPart(_: IOData, task: HTTPExecutingRequest)` +/// on thread B in the same moment the request error above occurred. For this reason it may +/// happen that the executor receives, the invocation of `writeRequestBodyPart` after it has +/// failed the request. +/// Passing along the requests reference helps the executor and scheduler verify its internal +/// state. + +/// A handle to the request scheduler. +/// +/// Use this handle to cancel the request, while it is waiting for a free connection, to execute the request. +/// This protocol is only intended to be implemented by the `HTTPConnectionPool`. +protocol HTTPRequestScheduler { + /// Informs the task queuer that a request has been cancelled. + func cancelRequest(_: HTTPScheduledRequest) +} + +/// An abstraction over a request that we want to send. A request may need to communicate with its request +/// queuer and executor. The client's methods will be called synchronously on an `EventLoop` by the +/// executor. For this reason it is very important that the implementation of these functions never blocks. +protocol HTTPScheduledRequest: AnyObject { + /// The task's logger + var logger: Logger { get } + + /// A connection to run this task on needs to be found before this deadline! + var connectionDeadline: NIODeadline { get } + + /// The task's `EventLoop` preference + var eventLoopPreference: HTTPClient.EventLoopPreference { get } + + /// Informs the task, that it was queued for execution + /// + /// This happens if all available connections are currently in use + func requestWasQueued(_: HTTPRequestScheduler) + + /// Fails the queued request, with an error. + func fail(_ error: Error) +} + +/// A handle to the request executor. +/// +/// This protocol is implemented by the `HTTP1ClientChannelHandler`. +protocol HTTPRequestExecutor { + /// Writes a body part into the channel pipeline + /// + /// This method may be **called on any thread**. The executor needs to ensure thread safety. + func writeRequestBodyPart(_: IOData, request: HTTPExecutingRequest) + + /// Signals that the request body stream has finished + /// + /// This method may be **called on any thread**. The executor needs to ensure thread safety. + func finishRequestBodyStream(_ task: HTTPExecutingRequest) + + /// Signals that more bytes from response body stream can be consumed. + /// + /// The request executor will call `receiveResponseBodyPart(_ buffer: ByteBuffer)` with more data after + /// this call. + /// + /// This method may be **called on any thread**. The executor needs to ensure thread safety. + func demandResponseBodyStream(_ task: HTTPExecutingRequest) + + /// Signals that the request has been cancelled. + /// + /// This method may be **called on any thread**. The executor needs to ensure thread safety. + func cancelRequest(_ task: HTTPExecutingRequest) +} + +protocol HTTPExecutingRequest: AnyObject { + /// The request's head. + /// + /// Based on the content of the request head the task executor will call `startRequestBodyStream` + /// after `requestHeadSent` was called. + var requestHead: HTTPRequestHead { get } + + /// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel. + var idleReadTimeout: TimeAmount? { get } + + /// Will be called by the ChannelHandler to indicate that the request is going to be sent. + /// + /// This will be called on the Channel's EventLoop. Do **not block** during your execution! If the + /// request is cancelled after the `willExecuteRequest` method was called. The executing + /// request must call `executor.cancel()` to stop request execution. + func willExecuteRequest(_: HTTPRequestExecutor) + + /// Will be called by the ChannelHandler to indicate that the request head has been sent. + /// + /// This will be called on the Channel's EventLoop. Do **not block** during your execution! + func requestHeadSent() + + /// Start or resume request body streaming + /// + /// This will be called on the Channel's EventLoop. Do **not block** during your execution! + func resumeRequestBodyStream() + + /// Pause request streaming + /// + /// This will be called on the Channel's EventLoop. Do **not block** during your execution! + func pauseRequestBodyStream() + + /// Receive a response head. + /// + /// Please note that `receiveResponseHead` and `receiveResponseBodyPart` may + /// be called in quick succession. It is the task's job to buffer those events for the user. Once all + /// buffered data has been consumed the task must call `executor.demandResponseBodyStream` + /// to ask for more data. + func receiveResponseHead(_ head: HTTPResponseHead) + + /// Receive response body stream parts. + /// + /// Please note that `receiveResponseHead` and `receiveResponseBodyPart` may + /// be called in quick succession. It is the task's job to buffer those events for the user. Once all + /// buffered data has been consumed the task must call `executor.demandResponseBodyStream` + /// to ask for more data. + func receiveResponseBodyParts(_ buffer: CircularBuffer) + + /// Succeeds the executing request. The executor will not call any further methods on the request after this method. + /// + /// - Parameter buffer: The remaining response body parts, that were received before the request end + func succeedRequest(_ buffer: CircularBuffer?) + + /// Fails the executing request, with an error. + func fail(_ error: Error) +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 9400930b9..6b1c6ec31 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -923,6 +923,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case httpProxyHandshakeTimeout case tlsHandshakeTimeout case serverOfferedUnsupportedApplicationProtocol(String) + case requestStreamCancelled } private var code: Code @@ -991,4 +992,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static func serverOfferedUnsupportedApplicationProtocol(_ proto: String) -> HTTPClientError { return HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) } + + /// The remote server responded with a status code >= 300, before the full request was sent. The request stream + /// was therefore cancelled + public static let requestStreamCancelled = HTTPClientError(code: .requestStreamCancelled) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index a9c1a9e22..1ae8cdd60 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -624,6 +624,10 @@ extension URL { } } +protocol HTTPClientTaskDelegate { + func cancel() +} + extension HTTPClient { /// Response execution context. Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift new file mode 100644 index 000000000..08ff5a539 --- /dev/null +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -0,0 +1,525 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import struct Foundation.URL +import NIO +import NIOHTTP1 + +extension RequestBag { + struct StateMachine { + fileprivate enum State { + case initialized + case queued(HTTPRequestScheduler) + case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) + case finished(error: Error?) + case redirected(HTTPResponseHead, URL) + case modifying + } + + fileprivate enum RequestStreamState { + case initialized + case producing + case paused(EventLoopPromise?) + case finished + } + + fileprivate enum ResponseStreamState { + enum Next { + case askExecutorForMore + case error(Error) + case eof + } + + case initialized + case buffering(CircularBuffer, next: Next) + case waitingForRemote + } + + private var state: State = .initialized + private let redirectHandler: RedirectHandler? + + init(redirectHandler: RedirectHandler?) { + self.redirectHandler = redirectHandler + } + } +} + +extension RequestBag.StateMachine { + mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) { + guard case .initialized = self.state else { + // There might be a race between `requestWasQueued` and `willExecuteRequest`: + // + // If the request is created and passed to the HTTPClient on thread A, it will move into + // the connection pool lock in thread A. If no connection is available, thread A will + // add the request to the waiters and leave the connection pool lock. + // `requestWasQueued` will be called outside the connection pool lock on thread A. + // However if thread B has a connection that becomes available and thread B enters the + // connection pool lock directly after thread A, the request will be immediately + // scheduled for execution on thread B. After the thread B has left the lock it will + // call `willExecuteRequest` directly after. + // + // Having an order in the connection pool lock, does not guarantee an order in calling: + // `requestWasQueued` and `willExecuteRequest`. + // + // For this reason we must check the state here... If we are not `.initialized`, we are + // already executing. + return + } + + self.state = .queued(scheduler) + } + + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> Bool { + switch self.state { + case .initialized, .queued: + self.state = .executing(executor, .initialized, .initialized) + return true + case .finished(error: .some): + return false + case .executing, .redirected, .finished(error: .none), .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + enum ResumeProducingAction { + case startWriter + case succeedBackpressurePromise(EventLoopPromise?) + case none + } + + mutating func resumeRequestBodyStream() -> ResumeProducingAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("A request stream can only be resumed, if the request was started") + + case .executing(let executor, .initialized, .initialized): + self.state = .executing(executor, .producing, .initialized) + return .startWriter + + case .executing(_, .producing, _): + preconditionFailure("Expected that resume is only called when if we were paused before") + + case .executing(let executor, .paused(let promise), let responseState): + self.state = .executing(executor, .producing, responseState) + return .succeedBackpressurePromise(promise) + + case .executing(_, .finished, _): + // the channels writability changed to writable after we have forwarded all the + // request bytes. Can be ignored. + return .none + + case .executing(_, .initialized, .buffering), .executing(_, .initialized, .waitingForRemote): + preconditionFailure("Invalid states: Response can not be received before request") + + case .redirected: + // if we are redirected, we should cancel our request body stream anyway + return .none + + case .finished: + preconditionFailure("Invalid state") + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func pauseRequestBodyStream() { + switch self.state { + case .initialized, .queued: + preconditionFailure("A request stream can only be paused, if the request was started") + case .executing(let executor, let requestState, let responseState): + switch requestState { + case .initialized: + preconditionFailure("Request stream must be started before it can be paused") + case .producing: + self.state = .executing(executor, .paused(nil), responseState) + case .paused: + preconditionFailure("Expected that pause is only called when if we were producing before") + case .finished: + // the channels writability changed to not writable after we have forwarded the + // last bytes from our side. + break + } + case .redirected: + // if we are redirected, we should cancel our request body stream anyway + break + case .finished: + // the request is already finished nothing further to do + break + case .modifying: + preconditionFailure("Invalid state") + } + } + + enum WriteAction { + case write(IOData, HTTPRequestExecutor, EventLoopFuture) + + case failTask(Error) + case failFuture(Error) + } + + mutating func writeNextRequestPart(_ part: IOData, taskEventLoop: EventLoop) -> WriteAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("Invalid state: \(self.state)") + case .executing(let executor, let requestState, let responseState): + switch requestState { + case .initialized: + preconditionFailure("Request stream must be started before it can be paused") + case .producing: + return .write(part, executor, taskEventLoop.makeSucceededFuture(())) + + case .paused(.none): + // backpressure is signaled to the writer using unfulfilled futures. if there + // is no existing, unfulfilled promise, let's create a new one + let promise = taskEventLoop.makePromise(of: Void.self) + self.state = .executing(executor, .paused(promise), responseState) + return .write(part, executor, promise.futureResult) + + case .paused(.some(let promise)): + // backpressure is signaled to the writer using unfulfilled futures. if an + // unfulfilled promise already exist, let's reuse the promise + return .write(part, executor, promise.futureResult) + + case .finished: + let error = HTTPClientError.writeAfterRequestSent + self.state = .finished(error: error) + return .failTask(error) + } + case .redirected: + // if we are redirected we can cancel the upload stream + return .failFuture(HTTPClientError.cancelled) + case .finished(error: .some(let error)): + return .failFuture(error) + case .finished(error: .none): + return .failFuture(HTTPClientError.requestStreamCancelled) + case .modifying: + preconditionFailure("Invalid state") + } + } + + enum FinishAction { + case forwardStreamFinished(HTTPRequestExecutor, EventLoopPromise?) + case forwardStreamFailureAndFailTask(HTTPRequestExecutor, Error, EventLoopPromise?) + case none + } + + mutating func finishRequestBodyStream(_ result: Result) -> FinishAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("Invalid state: \(self.state)") + case .executing(let executor, let requestState, let responseState): + switch requestState { + case .initialized: + preconditionFailure("Request stream must be started before it can be finished") + case .producing: + switch result { + case .success: + self.state = .executing(executor, .finished, responseState) + return .forwardStreamFinished(executor, nil) + case .failure(let error): + self.state = .finished(error: error) + return .forwardStreamFailureAndFailTask(executor, error, nil) + } + + case .paused(let promise): + switch result { + case .success: + self.state = .executing(executor, .finished, responseState) + return .forwardStreamFinished(executor, promise) + case .failure(let error): + self.state = .finished(error: error) + return .forwardStreamFailureAndFailTask(executor, error, promise) + } + + case .finished: + preconditionFailure("How can a finished request stream, be finished again?") + } + case .redirected: + return .none + case .finished(error: _): + return .none + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// The response head has been received. + /// + /// - Parameter head: The response' head + /// - Returns: Whether the response should be forwarded to the delegate. Will be `false` if the request follows a redirect. + mutating func receiveResponseHead(_ head: HTTPResponseHead) -> Bool { + switch self.state { + case .initialized, .queued: + preconditionFailure("How can we receive a response, if the request hasn't started yet.") + case .executing(let executor, let requestState, let responseState): + guard case .initialized = responseState else { + preconditionFailure("If we receive a response, we must not have received something else before") + } + + if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { + self.state = .redirected(head, redirectURL) + return false + } else { + self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore)) + return true + } + case .redirected: + preconditionFailure("This state can only be reached after we have received a HTTP head") + case .finished(error: .some): + return false + case .finished(error: .none): + preconditionFailure("How can the request be finished without error, before receiving response head?") + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ByteBuffer? { + switch self.state { + case .initialized, .queued: + preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") + case .executing(_, _, .initialized): + preconditionFailure("If we receive a response body, we must have received a head before") + + case .executing(let executor, let requestState, .buffering(var currentBuffer, next: let next)): + guard case .askExecutorForMore = next else { + preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + } + + self.state = .modifying + if currentBuffer.isEmpty { + currentBuffer = buffer + } else { + currentBuffer.append(contentsOf: buffer) + } + self.state = .executing(executor, requestState, .buffering(currentBuffer, next: next)) + return nil + case .executing(let executor, let requestState, .waitingForRemote): + var buffer = buffer + let first = buffer.removeFirst() + self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) + return first + case .redirected: + // ignore body + return nil + case .finished(error: .some): + return nil + case .finished(error: .none): + preconditionFailure("How can the request be finished without error, before receiving response head?") + case .modifying: + preconditionFailure("Invalid state") + } + } + + enum ReceiveResponseEndAction { + case consume(ByteBuffer) + case redirect(RedirectHandler, HTTPResponseHead, URL) + case succeedRequest + case none + } + + mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") + case .executing(_, _, .initialized): + preconditionFailure("If we receive a response body, we must have received a head before") + + case .executing(let executor, let requestState, .buffering(var buffer, next: let next)): + guard case .askExecutorForMore = next else { + preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + } + + if buffer.isEmpty, newChunks == nil || newChunks!.isEmpty { + self.state = .finished(error: nil) + return .succeedRequest + } else if buffer.isEmpty, let newChunks = newChunks { + buffer = newChunks + } else if let newChunks = newChunks { + buffer.append(contentsOf: newChunks) + } + + self.state = .executing(executor, requestState, .buffering(buffer, next: .eof)) + return .none + + case .executing(let executor, let requestState, .waitingForRemote): + guard var newChunks = newChunks, !newChunks.isEmpty else { + self.state = .finished(error: nil) + return .succeedRequest + } + + let first = newChunks.removeFirst() + self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof)) + return .consume(first) + + case .redirected(let head, let redirectURL): + self.state = .finished(error: nil) + return .redirect(self.redirectHandler!, head, redirectURL) + + case .finished(error: .some): + return .none + + case .finished(error: .none): + preconditionFailure("How can the request be finished without error, before receiving response head?") + case .modifying: + preconditionFailure("Invalid state") + } + } + + enum ConsumeAction { + case requestMoreFromExecutor(HTTPRequestExecutor) + case consume(ByteBuffer) + case finishStream + case failTask(Error, executorToCancel: HTTPRequestExecutor?) + case doNothing + } + + mutating func consumeMoreBodyData(resultOfPreviousConsume result: Result) -> ConsumeAction { + switch result { + case .success: + return self.consumeMoreBodyData() + case .failure(let error): + return self.failWithConsumptionError(error) + } + } + + private mutating func failWithConsumptionError(_ error: Error) -> ConsumeAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("Invalid state") + case .executing(_, _, .initialized): + preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + + case .executing(_, _, .buffering(_, next: .error(let connectionError))): + // if an error was received from the connection, we fail the task with the one + // from the connection, since it happened first. + self.state = .finished(error: connectionError) + return .failTask(connectionError, executorToCancel: nil) + + case .executing(let executor, _, .buffering(_, _)): + self.state = .finished(error: error) + return .failTask(error, executorToCancel: executor) + + case .executing(_, _, .waitingForRemote): + preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + + case .redirected: + preconditionFailure("Invalid state... Redirect don't call out to delegate functions. Thus we should never land here.") + + case .finished(error: .some): + // don't overwrite existing errors + return .doNothing + + case .finished(error: .none): + preconditionFailure("Invalid state... If no error occured, this must not be called, after the request was finished") + + case .modifying: + preconditionFailure() + } + } + + private mutating func consumeMoreBodyData() -> ConsumeAction { + switch self.state { + case .initialized, .queued: + preconditionFailure("Invalid state") + case .executing(_, _, .initialized): + preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + case .executing(let executor, let requestState, .buffering(var buffer, next: .askExecutorForMore)): + self.state = .modifying + + if let byteBuffer = buffer.popFirst() { + self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) + return .consume(byteBuffer) + } + + // buffer is empty, wait for more + self.state = .executing(executor, requestState, .waitingForRemote) + return .requestMoreFromExecutor(executor) + + case .executing(let executor, let requestState, .buffering(var buffer, next: .eof)): + self.state = .modifying + + if let byteBuffer = buffer.popFirst() { + self.state = .executing(executor, requestState, .buffering(buffer, next: .eof)) + return .consume(byteBuffer) + } + + self.state = .finished(error: nil) + return .finishStream + + case .executing(_, _, .buffering(_, next: .error(let error))): + self.state = .finished(error: error) + return .failTask(error, executorToCancel: nil) + + case .executing(_, _, .waitingForRemote): + preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + + case .redirected: + return .doNothing + + case .finished(error: .some): + return .doNothing + + case .finished(error: .none): + preconditionFailure("Invalid state... If no error occured, this must not be called, after the request was finished") + + case .modifying: + preconditionFailure() + } + } + + enum FailAction { + case failTask(HTTPRequestScheduler?, HTTPRequestExecutor?) + case cancelExecutor(HTTPRequestExecutor) + case none + } + + mutating func fail(_ error: Error) -> FailAction { + switch self.state { + case .initialized: + self.state = .finished(error: error) + return .failTask(nil, nil) + case .queued(let queuer): + self.state = .finished(error: error) + return .failTask(queuer, nil) + case .executing(let executor, let requestState, .buffering(_, next: .eof)): + self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) + return .cancelExecutor(executor) + case .executing(let executor, let requestState, .buffering(_, next: .askExecutorForMore)): + self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) + return .cancelExecutor(executor) + case .executing(let executor, _, .buffering(_, next: .error(_))): + // this would override another error, let's keep the first one + return .cancelExecutor(executor) + case .executing(let executor, _, .initialized): + self.state = .finished(error: error) + return .failTask(nil, executor) + case .executing(let executor, _, .waitingForRemote): + self.state = .finished(error: error) + return .failTask(nil, executor) + case .redirected: + self.state = .finished(error: error) + return .failTask(nil, nil) + case .finished(.none): + // An error occurred after the request has finished. Ignore... + return .none + case .finished(.some(_)): + // this might happen, if the stream consumer has failed... let's just drop the data + return .none + case .modifying: + preconditionFailure("Invalid state") + } + } +} diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift new file mode 100644 index 000000000..222a1472b --- /dev/null +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -0,0 +1,421 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import struct Foundation.URL +import Logging +import NIO +import NIOConcurrencyHelpers +import NIOHTTP1 + +final class RequestBag { + let task: HTTPClient.Task + var eventLoop: EventLoop { + self.task.eventLoop + } + + private let delegate: Delegate + private let request: HTTPClient.Request + + // the request state is synchronized on the task eventLoop + private var state: StateMachine + + // MARK: HTTPClientTask properties + + var logger: Logger { + self.task.logger + } + + let connectionDeadline: NIODeadline + + let idleReadTimeout: TimeAmount? + + let requestHead: HTTPRequestHead + + let eventLoopPreference: HTTPClient.EventLoopPreference + + init(request: HTTPClient.Request, + eventLoopPreference: HTTPClient.EventLoopPreference, + task: HTTPClient.Task, + redirectHandler: RedirectHandler?, + connectionDeadline: NIODeadline, + idleReadTimeout: TimeAmount?, + delegate: Delegate) { + self.eventLoopPreference = eventLoopPreference + self.task = task + self.state = .init(redirectHandler: redirectHandler) + self.request = request + self.connectionDeadline = connectionDeadline + self.idleReadTimeout = idleReadTimeout + self.delegate = delegate + + self.requestHead = HTTPRequestHead( + version: .http1_1, + method: request.method, + uri: request.uri, + headers: request.headers + ) + + // TODO: comment in once we switch to using the Request bag in AHC +// self.task.taskDelegate = self +// self.task.futureResult.whenComplete { _ in +// self.task.taskDelegate = nil +// } + } + + private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) { + self.task.eventLoop.assertInEventLoop() + self.state.requestWasQueued(scheduler) + } + + // MARK: - Request - + + private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { + self.task.eventLoop.assertInEventLoop() + if !self.state.willExecuteRequest(executor) { + return executor.cancelRequest(self) + } + } + + private func requestHeadSent0() { + self.task.eventLoop.assertInEventLoop() + + self.delegate.didSendRequestHead(task: self.task, self.requestHead) + + if self.request.body == nil { + self.delegate.didSendRequest(task: self.task) + } + } + + private func resumeRequestBodyStream0() { + self.task.eventLoop.assertInEventLoop() + + let produceAction = self.state.resumeRequestBodyStream() + + switch produceAction { + case .startWriter: + guard let body = self.request.body else { + preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") + } + + let writer = HTTPClient.Body.StreamWriter { + self.writeNextRequestPart($0) + } + + body.stream(writer).whenComplete { + self.finishRequestBodyStream($0) + } + + case .succeedBackpressurePromise(let promise): + promise?.succeed(()) + + case .none: + break + } + } + + private func pauseRequestBodyStream0() { + self.task.eventLoop.assertInEventLoop() + + self.state.pauseRequestBodyStream() + } + + private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { + if self.eventLoop.inEventLoop { + return self.writeNextRequestPart0(part) + } else { + return self.eventLoop.flatSubmit { + self.writeNextRequestPart0(part) + } + } + } + + private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture { + self.task.eventLoop.assertInEventLoop() + + let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) + + switch action { + case .failTask(let error): + self.delegate.didReceiveError(task: self.task, error) + self.task.fail(with: error, delegateType: Delegate.self) + return self.task.eventLoop.makeFailedFuture(error) + + case .failFuture(let error): + return self.task.eventLoop.makeFailedFuture(error) + + case .write(let part, let writer, let future): + writer.writeRequestBodyPart(part, request: self) + self.delegate.didSendRequestPart(task: self.task, part) + return future + } + } + + private func finishRequestBodyStream(_ result: Result) { + self.task.eventLoop.assertInEventLoop() + + let action = self.state.finishRequestBodyStream(result) + + switch action { + case .none: + break + case .forwardStreamFinished(let writer, let promise): + writer.finishRequestBodyStream(self) + promise?.succeed(()) + + self.delegate.didSendRequest(task: self.task) + + case .forwardStreamFailureAndFailTask(let writer, let error, let promise): + writer.cancelRequest(self) + promise?.fail(error) + self.failTask0(error) + } + } + + // MARK: Request delegate calls + + func failTask0(_ error: Error) { + self.task.eventLoop.assertInEventLoop() + + self.delegate.didReceiveError(task: self.task, error) + self.task.promise.fail(error) + } + + // MARK: - Response - + + private func receiveResponseHead0(_ head: HTTPResponseHead) { + self.task.eventLoop.assertInEventLoop() + + // runs most likely on channel eventLoop + let forwardToDelegate = self.state.receiveResponseHead(head) + + guard forwardToDelegate else { return } + + self.delegate.didReceiveHead(task: self.task, head) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // After the head received, let's start to consume body data + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } + + private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { + self.task.eventLoop.assertInEventLoop() + + let maybeForwardBuffer = self.state.receiveResponseBodyParts(buffer) + + guard let forwardBuffer = maybeForwardBuffer else { + return + } + + self.delegate.didReceiveBodyPart(task: self.task, forwardBuffer) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // on task el + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } + + private func succeedRequest0(_ buffer: CircularBuffer?) { + self.task.eventLoop.assertInEventLoop() + let action = self.state.succeedRequest(buffer) + + switch action { + case .none: + break + case .consume(let buffer): + self.delegate.didReceiveBodyPart(task: self.task, buffer) + .hop(to: self.task.eventLoop) + .whenComplete { + switch $0 { + case .success: + self.consumeMoreBodyData0(resultOfPreviousConsume: $0) + case .failure(let error): + // if in the response stream consumption an error has occurred, we need to + // cancel the running request and fail the task. + self.fail(error) + } + } + + case .succeedRequest: + do { + let response = try self.delegate.didFinishRequest(task: task) + self.task.promise.succeed(response) + } catch { + self.task.promise.fail(error) + } + + case .redirect(let handler, let head, let newURL): + handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + } + } + + private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { + self.task.eventLoop.assertInEventLoop() + + let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) + + switch consumptionAction { + case .consume(let byteBuffer): + self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) + .hop(to: self.task.eventLoop) + .whenComplete { + switch $0 { + case .success: + self.consumeMoreBodyData0(resultOfPreviousConsume: $0) + case .failure(let error): + self.fail(error) + } + } + + case .doNothing: + break + case .finishStream: + do { + let response = try self.delegate.didFinishRequest(task: task) + self.task.promise.succeed(response) + } catch { + self.task.promise.fail(error) + } + + case .failTask(let error, let executor): + executor?.cancelRequest(self) + self.failTask0(error) + case .requestMoreFromExecutor(let executor): + executor.demandResponseBodyStream(self) + } + } + + private func fail0(_ error: Error) { + self.task.eventLoop.assertInEventLoop() + + let action = self.state.fail(error) + + switch action { + case .failTask(let scheduler, let executor): + scheduler?.cancelRequest(self) + executor?.cancelRequest(self) + self.failTask0(error) + case .cancelExecutor(let executor): + executor.cancelRequest(self) + case .none: + break + } + } +} + +extension RequestBag: HTTPScheduledRequest { + func requestWasQueued(_ scheduler: HTTPRequestScheduler) { + if self.task.eventLoop.inEventLoop { + self.requestWasQueued0(scheduler) + } else { + self.task.eventLoop.execute { + self.requestWasQueued0(scheduler) + } + } + } + + func fail(_ error: Error) { + if self.task.eventLoop.inEventLoop { + self.fail0(error) + } else { + self.task.eventLoop.execute { + self.fail0(error) + } + } + } +} + +extension RequestBag: HTTPExecutingRequest { + func willExecuteRequest(_ executor: HTTPRequestExecutor) { + if self.task.eventLoop.inEventLoop { + self.willExecuteRequest0(executor) + } else { + self.task.eventLoop.execute { + self.willExecuteRequest0(executor) + } + } + } + + func requestHeadSent() { + if self.task.eventLoop.inEventLoop { + self.requestHeadSent0() + } else { + self.task.eventLoop.execute { + self.requestHeadSent0() + } + } + } + + func resumeRequestBodyStream() { + if self.task.eventLoop.inEventLoop { + self.resumeRequestBodyStream0() + } else { + self.task.eventLoop.execute { + self.resumeRequestBodyStream0() + } + } + } + + func pauseRequestBodyStream() { + if self.task.eventLoop.inEventLoop { + self.pauseRequestBodyStream0() + } else { + self.task.eventLoop.execute { + self.pauseRequestBodyStream0() + } + } + } + + func receiveResponseHead(_ head: HTTPResponseHead) { + if self.task.eventLoop.inEventLoop { + self.receiveResponseHead0(head) + } else { + self.task.eventLoop.execute { + self.receiveResponseHead0(head) + } + } + } + + func receiveResponseBodyParts(_ buffer: CircularBuffer) { + if self.task.eventLoop.inEventLoop { + self.receiveResponseBodyParts0(buffer) + } else { + self.task.eventLoop.execute { + self.receiveResponseBodyParts0(buffer) + } + } + } + + func succeedRequest(_ buffer: CircularBuffer?) { + if self.task.eventLoop.inEventLoop { + self.succeedRequest0(buffer) + } else { + self.task.eventLoop.execute { + self.succeedRequest0(buffer) + } + } + } +} + +extension RequestBag: HTTPClientTaskDelegate { + func cancel() { + if self.task.eventLoop.inEventLoop { + self.fail0(HTTPClientError.cancelled) + } else { + self.task.eventLoop.execute { + self.fail0(HTTPClientError.cancelled) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/RequestValidation.swift b/Sources/AsyncHTTPClient/RequestValidation.swift index a9c1678f4..4c8fd9d21 100644 --- a/Sources/AsyncHTTPClient/RequestValidation.swift +++ b/Sources/AsyncHTTPClient/RequestValidation.swift @@ -59,7 +59,7 @@ extension HTTPHeaders { throw HTTPClientError.traceRequestWithBody } - guard (encodings.filter { $0 == "chunked" }.count <= 1) else { + guard (encodings.lazy.filter { $0 == "chunked" }.count <= 1) else { throw HTTPClientError.chunkedSpecifiedMultipleTimes } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift new file mode 100644 index 000000000..1c069da43 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// RequestBagTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension RequestBagTests { + static var allTests: [(String, (RequestBagTests) -> () throws -> Void)] { + return [ + ("testWriteBackpressureWorks", testWriteBackpressureWorks), + ("testTaskIsFailedIfWritingFails", testTaskIsFailedIfWritingFails), + ("testCancelFailsTaskBeforeRequestIsSent", testCancelFailsTaskBeforeRequestIsSent), + ("testCancelFailsTaskAfterRequestIsSent", testCancelFailsTaskAfterRequestIsSent), + ("testCancelFailsTaskWhenTaskIsQueued", testCancelFailsTaskWhenTaskIsQueued), + ("testHTTPUploadIsCancelledEvenThoughRequestSucceeds", testHTTPUploadIsCancelledEvenThoughRequestSucceeds), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift new file mode 100644 index 000000000..818c989e7 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -0,0 +1,483 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP1 +import XCTest + +final class RequestBagTests: XCTestCase { + func testWriteBackpressureWorks() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var writtenBytes = 0 + var writes = 0 + let bytesToSent = (3000...10000).randomElement()! + var streamIsAllowedToWrite = false + + let writeDonePromise = embeddedEventLoop.makePromise(of: Void.self) + let requestBody: HTTPClient.Body = .stream(length: bytesToSent) { writer -> EventLoopFuture in + func write(donePromise: EventLoopPromise) { + XCTAssertTrue(streamIsAllowedToWrite) + guard writtenBytes < bytesToSent else { + return donePromise.succeed(()) + } + let byteCount = min(bytesToSent - writtenBytes, 100) + let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) + writes += 1 + writer.write(.byteBuffer(buffer)).whenSuccess { _ in + writtenBytes += 100 + write(donePromise: donePromise) + } + } + + write(donePromise: writeDonePromise) + + return writeDonePromise.futureResult + } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + XCTAssert(bag.task.eventLoop === embeddedEventLoop) + + let executor = MockRequestExecutor(pauseRequestBodyPartStreamAfterASingleWrite: true) + + bag.willExecuteRequest(executor) + + XCTAssertEqual(delegate.hitDidSendRequestHead, 0) + bag.requestHeadSent() + XCTAssertEqual(delegate.hitDidSendRequestHead, 1) + streamIsAllowedToWrite = true + bag.resumeRequestBodyStream() + streamIsAllowedToWrite = false + + // after starting the body stream we should have received two writes + var eof = false + var receivedBytes = 0 + while !eof { + switch executor.nextBodyPart() { + case .body(.byteBuffer(let bytes)): + XCTAssertEqual(delegate.hitDidSendRequestPart, writes) + receivedBytes += bytes.readableBytes + case .body(.fileRegion(_)): + return XCTFail("We never send a file region. Something is really broken") + case .endOfStream: + XCTAssertEqual(delegate.hitDidSendRequest, 1) + eof = true + case .none: + // this should produce maximum two parts + streamIsAllowedToWrite = true + bag.resumeRequestBodyStream() + streamIsAllowedToWrite = false + XCTAssertLessThanOrEqual(executor.requestBodyParts.count, 2) + XCTAssertEqual(delegate.hitDidSendRequestPart, writes) + } + } + + XCTAssertEqual(receivedBytes, bytesToSent, "We have sent all request bytes...") + + XCTAssertNil(delegate.receivedHead, "Expected not to have a response head, before `receiveResponseHead`") + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: .init([ + ("Transfer-Encoding", "chunked"), + ])) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(responseHead, delegate.receivedHead) + XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) + XCTAssertTrue(executor.signalledDemandForResponseBody) + executor.resetDemandSignal() + + // we will receive 20 chunks with each 10 byteBuffers and 32 bytes + let bodyPart = ByteBuffer(bytes: 0..<32) + for i in 0..<20 { + let chunk = CircularBuffer(repeating: bodyPart, count: 10) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, i * 10) // 0 + bag.receiveResponseBodyParts(chunk) + + // consume the 10 buffers + for j in 0..<10 { + XCTAssertEqual(delegate.hitDidReceiveBodyPart, i * 10 + j + 1) + XCTAssertEqual(delegate.lastBodyPart, bodyPart) + XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) + + if j < 9 { + XCTAssertFalse(executor.signalledDemandForResponseBody) + } else { + XCTAssertTrue(executor.signalledDemandForResponseBody) + } + } + + executor.resetDemandSignal() + } + + XCTAssertEqual(delegate.hitDidReceiveResponse, 0) + bag.succeedRequest(nil) + XCTAssertEqual(delegate.hitDidReceiveResponse, 1) + + XCTAssertNoThrow(try bag.task.futureResult.wait(), "The request has succeeded") + } + + func testTaskIsFailedIfWritingFails() { + struct TestError: Error, Equatable {} + + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + let requestBody: HTTPClient.Body = .stream(length: 12) { writer -> EventLoopFuture in + + writer.write(.byteBuffer(ByteBuffer(bytes: 0...3))).flatMap { _ -> EventLoopFuture in + embeddedEventLoop.makeFailedFuture(TestError()) + } + } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + XCTAssert(bag.task.eventLoop === embeddedEventLoop) + + let executor = MockRequestExecutor() + + bag.willExecuteRequest(executor) + + XCTAssertEqual(delegate.hitDidSendRequestHead, 0) + bag.requestHeadSent() + XCTAssertEqual(delegate.hitDidSendRequestHead, 1) + XCTAssertEqual(delegate.hitDidSendRequestPart, 0) + bag.resumeRequestBodyStream() + XCTAssertEqual(delegate.hitDidSendRequestPart, 1) + XCTAssertEqual(delegate.hitDidReceiveError, 1) + XCTAssertEqual(delegate.lastError as? TestError, TestError()) + + XCTAssertTrue(executor.isCancelled) + + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? TestError, TestError()) + } + } + + func testCancelFailsTaskBeforeRequestIsSent() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let executor = MockRequestExecutor() + bag.cancel() + + bag.willExecuteRequest(executor) + XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + func testCancelFailsTaskAfterRequestIsSent() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let executor = MockRequestExecutor() + + bag.willExecuteRequest(executor) + XCTAssertFalse(executor.isCancelled) + + XCTAssertEqual(delegate.hitDidSendRequestHead, 0) + XCTAssertEqual(delegate.hitDidSendRequest, 0) + bag.requestHeadSent() + XCTAssertEqual(delegate.hitDidSendRequestHead, 1) + XCTAssertEqual(delegate.hitDidSendRequest, 1) + + bag.cancel() + XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") + + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + func testCancelFailsTaskWhenTaskIsQueued() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.cancel() + XCTAssertEqual(queuer.hitCancelCount, 1) + + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + func testHTTPUploadIsCancelledEvenThoughRequestSucceeds() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + let writeSecondPartPromise = embeddedEventLoop.makePromise(of: Void.self) + + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + headers: ["content-length": "12"], + body: .stream(length: 12) { writer -> EventLoopFuture in + var firstWriteSuccess = false + return writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in + firstWriteSuccess = true + + return writeSecondPartPromise.futureResult + }.flatMap { + return writer.write(.byteBuffer(.init(bytes: 4...7))) + }.always { result in + XCTAssertTrue(firstWriteSuccess) + + guard case .failure(let error) = result else { + return XCTFail("Expected the second write to fail") + } + XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) + } + } + )) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + let bag = RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + ) + + let executor = MockRequestExecutor() + bag.willExecuteRequest(executor) + + XCTAssertEqual(delegate.hitDidSendRequestHead, 0) + XCTAssertEqual(delegate.hitDidSendRequest, 0) + bag.requestHeadSent() + XCTAssertEqual(delegate.hitDidSendRequestHead, 1) + XCTAssertEqual(delegate.hitDidSendRequest, 0) + + bag.resumeRequestBodyStream() + XCTAssertEqual(executor.nextBodyPart(), .body(.byteBuffer(.init(bytes: 0...3)))) + // receive a 301 response immediately. + bag.receiveResponseHead(.init(version: .http1_1, status: .movedPermanently)) + bag.succeedRequest(.init()) + + // if we now write our second part of the response this should fail the backpressure promise + writeSecondPartPromise.succeed(()) + + XCTAssertEqual(delegate.receivedHead?.status, .movedPermanently) + XCTAssertNoThrow(try bag.task.futureResult.wait()) + } +} + +class MockRequestExecutor: HTTPRequestExecutor { + enum RequestParts: Equatable { + case body(IOData) + case endOfStream + } + + let pauseRequestBodyPartStreamAfterASingleWrite: Bool + + private(set) var requestBodyParts = CircularBuffer() + private(set) var isCancelled: Bool = false + private(set) var signalledDemandForResponseBody: Bool = false + + init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false) { + self.pauseRequestBodyPartStreamAfterASingleWrite = pauseRequestBodyPartStreamAfterASingleWrite + } + + func nextBodyPart() -> RequestParts? { + guard !self.requestBodyParts.isEmpty else { return nil } + return self.requestBodyParts.removeFirst() + } + + func resetDemandSignal() { + self.signalledDemandForResponseBody = false + } + + // this should always be called twice. When we receive the first call, the next call to produce + // data is already scheduled. If we call pause here, once, after the second call new subsequent + // calls should not be scheduled. + func writeRequestBodyPart(_ part: IOData, request: HTTPExecutingRequest) { + if self.requestBodyParts.isEmpty, self.pauseRequestBodyPartStreamAfterASingleWrite { + request.pauseRequestBodyStream() + } + self.requestBodyParts.append(.body(part)) + } + + func finishRequestBodyStream(_: HTTPExecutingRequest) { + self.requestBodyParts.append(.endOfStream) + } + + func demandResponseBodyStream(_: HTTPExecutingRequest) { + self.signalledDemandForResponseBody = true + } + + func cancelRequest(_: HTTPExecutingRequest) { + self.isCancelled = true + } +} + +class UploadCountingDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + let eventLoop: EventLoop + + private(set) var hitDidSendRequestHead = 0 + private(set) var hitDidSendRequestPart = 0 + private(set) var hitDidSendRequest = 0 + private(set) var hitDidReceiveResponse = 0 + private(set) var hitDidReceiveBodyPart = 0 + private(set) var hitDidReceiveError = 0 + + private(set) var receivedHead: HTTPResponseHead? + private(set) var lastBodyPart: ByteBuffer? + private(set) var backpressurePromise: EventLoopPromise? + private(set) var lastError: Error? + + init(eventLoop: EventLoop) { + self.eventLoop = eventLoop + } + + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + self.hitDidSendRequestHead += 1 + } + + func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { + self.hitDidSendRequestPart += 1 + } + + func didSendRequest(task: HTTPClient.Task) { + self.hitDidSendRequest += 1 + } + + func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { + self.receivedHead = head + return self.createBackpressurePromise() + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + assert(self.backpressurePromise == nil) + self.hitDidReceiveBodyPart += 1 + self.lastBodyPart = buffer + return self.createBackpressurePromise() + } + + func didFinishRequest(task: HTTPClient.Task) throws { + self.hitDidReceiveResponse += 1 + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.hitDidReceiveError += 1 + self.lastError = error + } + + private func createBackpressurePromise() -> EventLoopFuture { + assert(self.backpressurePromise == nil) + self.backpressurePromise = self.eventLoop.makePromise(of: Void.self) + return self.backpressurePromise!.futureResult.always { _ in + self.backpressurePromise = nil + } + } +} + +class MockTaskQueuer: HTTPRequestScheduler { + private(set) var hitCancelCount = 0 + + init() {} + + func cancelRequest(_: HTTPScheduledRequest) { + self.hitCancelCount += 1 + } +} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 83ea08033..cf5987f61 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -37,6 +37,7 @@ import XCTest testCase(HTTPConnectionPool_FactoryTests.allTests), testCase(HTTPRequestStateMachineTests.allTests), testCase(LRUCacheTests.allTests), + testCase(RequestBagTests.allTests), testCase(RequestValidationTests.allTests), testCase(SOCKSEventsHandlerTests.allTests), testCase(SSLContextCacheTests.allTests),