From 519ca3b105a2c13e1c3df54ed2ff4bbfd931df61 Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Mon, 28 Jun 2021 15:59:34 +0200 Subject: [PATCH 1/9] Add `HTTPRequestStateMachine` --- .../HTTPRequestStateMachine.swift | 467 ++++++++++++++++++ .../HTTPRequestStateMachineTests.swift | 137 +++++ 2 files changed, 604 insertions(+) create mode 100644 Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift create mode 100644 Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift new file mode 100644 index 000000000..a5c1b8fb1 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -0,0 +1,467 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO +import NIOHTTP1 + +struct HTTPRequestStateMachine { + fileprivate enum State { + case initialized + case running(RequestState, ResponseState) + case finished + + case failed(Error) + } + + fileprivate enum RequestState { + enum ExpectedBody { + case length(Int) + case stream + } + + enum ProducerControlState: Equatable { + case producing + case paused + } + + case verifyRequest + case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) + case endSent + } + + fileprivate enum ResponseState { + enum StreamControlState { + case downstreamHasDemand + case readEventPending + case waiting + } + + case initialized + case receivingBody(StreamControlState) + case endReceived + } + + enum Action { + enum AfterHeadContinueWith { + case sendEnd + case startBodyStream + } + + case verifyRequest + + case sendRequestHead(HTTPRequestHead, startBody: Bool, startReadTimeoutTimer: TimeAmount?) + case sendBodyPart(IOData) + case sendRequestEnd(startReadTimeoutTimer: TimeAmount?) + + case pauseRequestBodyStream + case resumeRequestBodyStream + + case forwardResponseHead(HTTPResponseHead) + case forwardResponseBodyPart(ByteBuffer, resetReadTimeoutTimer: TimeAmount?) + case forwardResponseEnd(readPending: Bool, clearReadTimeoutTimer: Bool) + + case failRequest(Error, closeStream: Bool) + + case read + case wait + } + + private var state: State = .initialized + + private var isChannelWritable: Bool + private let idleReadTimeout: TimeAmount? + + init(isChannelWritable: Bool, idleReadTimeout: TimeAmount?) { + self.isChannelWritable = isChannelWritable + self.idleReadTimeout = idleReadTimeout + } + + mutating func writabilityChanged(writable: Bool) -> Action { + self.isChannelWritable = writable + + switch self.state { + case .initialized, + .finished, + .failed: + return .wait + + case .running(.verifyRequest, _), .running(.endSent, _): + return .wait + + case .running(.streaming(let expectedBody, let sentBodyBytes, producer: .paused), let responseState): + if writable { + let requestState: RequestState = .streaming( + expectedBodyLength: expectedBody, + sentBodyBytes: sentBodyBytes, + producer: .producing + ) + + self.state = .running(requestState, responseState) + return .resumeRequestBodyStream + } else { + // no state change needed + return .wait + } + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), let responseState): + if !writable { + let requestState: RequestState = .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: .paused + ) + self.state = .running(requestState, responseState) + return .pauseRequestBodyStream + } else { + // no state change needed + return .wait + } + } + } + + mutating func readEventCaught() -> Action { + return .read + } + + mutating func errorHappened(_ error: Error) -> Action { + switch self.state { + case .initialized: + preconditionFailure("After the state machine has been initialized, start must be called immidiatly. Thus this state is unreachable") + case .running: + self.state = .failed(error) + return .failRequest(error, closeStream: true) + case .finished, .failed: + preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immidiatly from its state. Thus this state is unreachable.") + } + } + + mutating func start() -> Action { + guard case .initialized = self.state else { + preconditionFailure("Invalid state") + } + + self.state = .running(.verifyRequest, .initialized) + return .verifyRequest + } + + mutating func requestVerified(_ head: HTTPRequestHead) -> Action { + guard case .running(.verifyRequest, .initialized) = self.state else { + preconditionFailure("Invalid state") + } + + guard self.isChannelWritable else { + preconditionFailure("Unimplemented. Wait with starting the request here!") + } + + if let value = head.headers.first(name: "content-length"), let length = Int(value), length > 0 { + self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .initialized) + return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + } else if head.headers.contains(name: "transfer-encoding") { + self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .initialized) + return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + } else { + self.state = .running(.endSent, .initialized) + return .sendRequestHead(head, startBody: false, startReadTimeoutTimer: self.idleReadTimeout) + } + } + + mutating func requestVerificationFailed(_ error: Error) -> Action { + guard case .running(.verifyRequest, .initialized) = self.state else { + preconditionFailure("Invalid state") + } + + self.state = .failed(error) + return .failRequest(error, closeStream: false) + } + + mutating func requestStreamPartReceived(_ part: IOData) -> Action { + switch self.state { + case .initialized: + preconditionFailure("Invalid state: \(self.state)") + + case .running(.verifyRequest, _), + .running(.endSent, _): + preconditionFailure("Invalid state: \(self.state)") + + case .running(.streaming(let expectedBodyLength, var sentBodyBytes, let producerState), let responseState): + // More streamed data is accepted, even though the producer should stop. However + // there might be thread syncronisations situations in which the producer might not + // be aware that it needs to stop yet. + + if let expected = expectedBodyLength { + if sentBodyBytes + part.readableBytes > expected { + let error = HTTPClientError.bodyLengthMismatch + + switch responseState { + case .initialized, .receivingBody: + self.state = .failed(error) + case .endReceived: + #warning("TODO: This needs to be fixed. @Cory: What does this mean here?") + preconditionFailure("Unimplemented") + } + + return .failRequest(error, closeStream: true) + } + } + + sentBodyBytes += part.readableBytes + + let requestState: RequestState = .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: producerState + ) + + self.state = .running(requestState, responseState) + + return .sendBodyPart(part) + + case .failed: + return .wait + + case .finished: + // a request may be finished, before we send all parts. We may still receive something + // here because of a thread race + return .wait + } + } + + mutating func requestStreamFinished() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Invalid state") + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), let responseState): + if let expected = expectedBodyLength, expected != sentBodyBytes { + let error = HTTPClientError.bodyLengthMismatch + + switch responseState { + case .initialized, .receivingBody: + self.state = .failed(error) + case .endReceived: + #warning("TODO: This needs to be fixed. @Cory: What does this mean here?") + preconditionFailure("Unimplemented") + } + + return .failRequest(error, closeStream: true) + } + + self.state = .running(.endSent, responseState) + return .sendRequestEnd(startReadTimeoutTimer: self.idleReadTimeout) + + case .running(.verifyRequest, _), + .running(.endSent, _): + preconditionFailure("Invalid state") + + case .finished: + return .wait + + case .failed: + return .wait + } + } + + mutating func requestCancelled() -> Action { + switch self.state { + case .initialized, .running: + let error = HTTPClientError.cancelled + self.state = .failed(error) + return .failRequest(error, closeStream: true) + case .finished: + return .wait + case .failed: + return .wait + } + } + + mutating func channelInactive() -> Action { + switch self.state { + case .initialized, .running: + let error = HTTPClientError.remoteConnectionClosed + self.state = .failed(error) + return .failRequest(error, closeStream: false) + case .finished: + return .wait + case .failed: + // don't overwrite error + return .wait + } + } + + // MARK: - Response + + mutating func receivedHTTPResponseHead(_ head: HTTPResponseHead) -> Action { + switch self.state { + case .initialized: + preconditionFailure("How can we receive a response head before sending a request head ourselves") + + case .running(let requestState, .initialized): + switch requestState { + case .verifyRequest: + preconditionFailure("How can we receive a response head before sending a request head ourselves") + case .streaming, .endSent: + break + } + + self.state = .running(requestState, .receivingBody(.waiting)) + return .forwardResponseHead(head) + + case .running(_, .receivingBody), .running(_, .endReceived), .finished: + preconditionFailure("How can we sucessfully finish the request, before having received a head") + case .failed: + return .wait + } + } + + mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { + switch self.state { + case .initialized: + preconditionFailure("How can we receive a response head before sending a request head ourselves") + + case .running(_, .initialized): + preconditionFailure("How can we receive a response body, if we haven't a received a head") + + case .running(let requestState, .receivingBody(let streamState)): + switch streamState { + case .waiting, .readEventPending: + break + case .downstreamHasDemand: + self.state = .running(requestState, .receivingBody(.waiting)) + } + + return .forwardResponseBodyPart(body, resetReadTimeoutTimer: self.idleReadTimeout) + + case .running(_, .endReceived), .finished: + preconditionFailure("How can we sucessfully finish the request, before having received a head") + case .failed: + return .wait + } + } + + mutating func receivedHTTPResponseEnd() -> Action { + switch self.state { + case .initialized: + preconditionFailure("How can we receive a response head before sending a request head ourselves") + + case .running(_, .initialized): + preconditionFailure("How can we receive a response body, if we haven't a received a head") + + case .running(.streaming, .receivingBody(let streamState)): + preconditionFailure("Unimplemented") + #warning("@Fabian: We received response end, before sending our own request's end.") + + case .running(.endSent, .receivingBody(let streamState)): + let readPending: Bool + switch streamState { + case .readEventPending: + readPending = true + case .downstreamHasDemand, .waiting: + readPending = false + } + + self.state = .finished + return .forwardResponseEnd(readPending: readPending, clearReadTimeoutTimer: self.idleReadTimeout != nil) + + case .running(.verifyRequest, .receivingBody), + .running(_, .endReceived), .finished: + preconditionFailure("invalid state") + case .failed: + return .wait + } + } + + mutating func forwardMoreBodyParts() -> Action { + guard case .running(let requestState, .receivingBody(let streamControl)) = self.state else { + preconditionFailure("Invalid state") + } + + switch streamControl { + case .waiting: + self.state = .running(requestState, .receivingBody(.downstreamHasDemand)) + return .wait + case .readEventPending: + self.state = .running(requestState, .receivingBody(.waiting)) + return .read + case .downstreamHasDemand: + // We have received a request for more data before. Normally we only expect one request + // for more data, but a race can come into play here. + return .wait + } + } + + mutating func idleReadTimeoutTriggered() -> Action { + guard case .running(.endSent, let responseState) = self.state else { + preconditionFailure("We only schedule idle read timeouts after we have sent the complete request") + } + + if case .endReceived = responseState { + preconditionFailure("Invalid state: If we have received everything, we must not schedule further timeout timers") + } + + let error = HTTPClientError.readTimeout + self.state = .failed(error) + return .failRequest(error, closeStream: true) + } +} + +extension HTTPRequestStateMachine: CustomStringConvertible { + var description: String { + switch self.state { + case .initialized: + return "HTTPRequestStateMachine(.initialized, isWritable: \(self.isChannelWritable))" + case .running(let requestState, let responseState): + return "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" + case .finished: + return "HTTPRequestStateMachine(.finished, isWritable: \(self.isChannelWritable))" + case .failed(let error): + return "HTTPRequestStateMachine(.failed(\(error)), isWritable: \(self.isChannelWritable))" + } + } +} + +extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { + var description: String { + switch self { + case .verifyRequest: + return ".verifyRequest" + case .streaming(expectedBodyLength: let expected, let sent, producer: let producer): + return ".sendingHead(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" + case .endSent: + return ".endSent" + } + } +} + +extension HTTPRequestStateMachine.RequestState.ProducerControlState { + var description: String { + switch self { + case .paused: + return ".paused" + case .producing: + return ".producing" + } + } +} + +extension HTTPRequestStateMachine.ResponseState: CustomStringConvertible { + var description: String { + switch self { + case .initialized: + return ".initialized" + case .receivingBody(let streamState): + return ".receivingBody(streamState: \(streamState))" + case .endReceived: + return ".endReceived" + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift new file mode 100644 index 000000000..e06ac91d0 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -0,0 +1,137 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import NIO +import NIOHTTP1 +import XCTest + +class HTTPRequestStateMachineTests: XCTestCase { + func testSimpleGETRequest() { + var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + XCTAssertEqual(state.start(), .verifyRequest) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: false, startReadTimeoutTimer: nil)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead)) + let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .forwardResponseEnd(readPending: false, clearReadTimeoutTimer: false)) + } + + func testPOSTRequestWithWriterBackpressure() { + var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + XCTAssertEqual(state.start(), .verifyRequest) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) + let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) + let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) + let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + + // oh the channel reports... we should slow down producing... + XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) + + // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced + // data + XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + // however when we have put the data on the channel, we should not issue further + // .produceMoreRequestBodyData events + + // once we receive a writable event again, we can allow the producer to produce more data + XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) + XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) + XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(startReadTimeoutTimer: nil)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead)) + let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .forwardResponseEnd(readPending: false, clearReadTimeoutTimer: false)) + } + + func testPOSTContentLengthIsTooLong() { + var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + XCTAssertEqual(state.start(), .verifyRequest) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + let failAction = state.requestStreamPartReceived(part1) + guard case .failRequest(let error, closeStream: true) = failAction else { + return XCTFail("Unexpected action: \(failAction)") + } + + XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch) + } + + func testPOSTContentLengthIsTooShort() { + var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + XCTAssertEqual(state.start(), .verifyRequest) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + let failAction = state.requestStreamFinished() + guard case .failRequest(let error, closeStream: true) = failAction else { + return XCTFail("Unexpected action: \(failAction)") + } + + XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch) + } +} + +extension HTTPRequestStateMachine.Action: Equatable { + public static func == (lhs: HTTPRequestStateMachine.Action, rhs: HTTPRequestStateMachine.Action) -> Bool { + switch (lhs, rhs) { + case (.verifyRequest, .verifyRequest): + return true + + case (.sendRequestHead(let lhsHead, let lhsStartBody, let lhsIdleReadTimeout), .sendRequestHead(let rhsHead, let rhsStartBody, let rhsIdleReadTimeout)): + return lhsHead == rhsHead && lhsStartBody == rhsStartBody && lhsIdleReadTimeout == rhsIdleReadTimeout + case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): + return lhsData == rhsData + case (.sendRequestEnd, .sendRequestEnd): + return true + + case (.pauseRequestBodyStream, .pauseRequestBodyStream): + return true + case (.resumeRequestBodyStream, .resumeRequestBodyStream): + return true + + case (.forwardResponseHead(let lhsHead), .forwardResponseHead(let rhsHead)): + return lhsHead == rhsHead + case (.forwardResponseBodyPart(let lhsData, let lhsIdleReadTimeout), .forwardResponseBodyPart(let rhsData, let rhsIdleReadTimeout)): + return lhsIdleReadTimeout == rhsIdleReadTimeout && lhsData == rhsData + case (.forwardResponseEnd(readPending: let lhsPending), .forwardResponseEnd(readPending: let rhsPending)): + return lhsPending == rhsPending + + case (.failRequest(_, closeStream: let lhsClose), .failRequest(_, closeStream: let rhsClose)): + return lhsClose == rhsClose + + case (.read, .read): + return true + case (.wait, .wait): + return true + default: + return false + } + } +} From bcc60210abf1e7a19726226dc38827e928542ae2 Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@mac.com> Date: Wed, 30 Jun 2021 09:17:53 +0200 Subject: [PATCH 2/9] Apply suggestions from code review Co-authored-by: Cory Benfield <lukasa@apple.com> --- .../ConnectionPool/HTTPRequestStateMachine.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index a5c1b8fb1..30f2ff345 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -53,7 +53,7 @@ struct HTTPRequestStateMachine { } enum Action { - enum AfterHeadContinueWith { + enum NextMessageToSend { case sendEnd case startBodyStream } @@ -137,12 +137,12 @@ struct HTTPRequestStateMachine { mutating func errorHappened(_ error: Error) -> Action { switch self.state { case .initialized: - preconditionFailure("After the state machine has been initialized, start must be called immidiatly. Thus this state is unreachable") + preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") case .running: self.state = .failed(error) return .failRequest(error, closeStream: true) case .finished, .failed: - preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immidiatly from its state. Thus this state is unreachable.") + preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immediately from its state. Thus this state is unreachable.") } } @@ -329,7 +329,7 @@ struct HTTPRequestStateMachine { preconditionFailure("How can we receive a response head before sending a request head ourselves") case .running(_, .initialized): - preconditionFailure("How can we receive a response body, if we haven't a received a head") + preconditionFailure("How can we receive a response body, if we haven't received a head") case .running(let requestState, .receivingBody(let streamState)): switch streamState { From 7e5ad2636f42c3a358fe4276e6d0f40da623b16c Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Wed, 30 Jun 2021 15:05:36 +0200 Subject: [PATCH 3/9] Code review --- .../HTTPRequestStateMachine.swift | 518 ++++++++++++------ .../HTTPRequestStateMachineTests+XCTest.swift | 34 ++ .../HTTPRequestStateMachineTests.swift | 24 +- Tests/LinuxMain.swift | 1 + 4 files changed, 385 insertions(+), 192 deletions(-) create mode 100644 Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 30f2ff345..be76c1781 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -17,45 +17,75 @@ import NIOHTTP1 struct HTTPRequestStateMachine { fileprivate enum State { + /// The initial state machine state. The only valid mutation is `start()`. The state will + /// transition to `.verifyingRequest` case initialized + /// During this state the request's soundness is checked, before sending it out on the wire. + /// Valid transitions are: + /// - .waitForChannelToBecomeWritable (if the Channel is not writable) + /// - .running(.streaming, .initialized) (if the Channel is writable and a request body is expected) + /// - .running(.endSent, .initialized) (if the Channel is writable and no request body is expected) + /// - .failed (if an error was found in the request soundness check) + case verifyingRequest + /// Waiting for the channel to be writable. Valid transitions are: + /// - .running(.streaming, .initialized) (once the Channel is writable again and if a request body is expected) + /// - .running(.endSent, .initialized) (once the Channel is writable again and no request body is expected) + /// - .failed (if a connection error occurred) + case waitForChannelToBecomeWritable(HTTPRequestHead) + /// A request is on the wire. Valid transitions are: + /// - .finished + /// - .failed case running(RequestState, ResponseState) + /// The request has completed successfully case finished - + /// The request has failed case failed(Error) } + /// A sub state for a running request. More specifically for sending a request body. fileprivate enum RequestState { - enum ExpectedBody { - case length(Int) - case stream - } - + /// A sub state for sending a request body. Stores whether a producer should produce more + /// bytes or should pause. enum ProducerControlState: Equatable { + /// The request body producer should produce more body bytes. The channel is writable, case producing + /// The request body producer should pause producing more bytes. The channel is not writable, case paused } - case verifyRequest + /// The request is streaming its request body. `expectedBodyLength` has a value, if the request header contained + /// a `"content-length"` header field. It is the request header contained a `"transfer-encoding" = "chunked"` + /// header field. case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) + /// The request has sent its request body and end. case endSent } fileprivate enum ResponseState { - enum StreamControlState { + /// A sub state for receiving a response. Stores whether the consumer has either signaled demand for more data or + /// is busy consuming the so far forwarded bytes + enum ConsumerControlState { + case downstreamIsConsuming(readPending: Bool) case downstreamHasDemand - case readEventPending - case waiting } - case initialized - case receivingBody(StreamControlState) + /// A response head has not been received yet. + case waitingForHead + /// A response head has been received and we are ready to consume more data of the wire + case receivingBody(HTTPResponseHead, ConsumerControlState) + /// A response end has been received and we are ready to consume more data of the wire case endReceived } enum Action { - enum NextMessageToSend { - case sendEnd - case startBodyStream + /// A action to execute, when we consider a request "done". + enum FinalStreamAction { + /// close the connection + case close + /// trigger a read event + case read + /// do nothing + case none } case verifyRequest @@ -67,11 +97,11 @@ struct HTTPRequestStateMachine { case pauseRequestBodyStream case resumeRequestBodyStream - case forwardResponseHead(HTTPResponseHead) + case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyPart(ByteBuffer, resetReadTimeoutTimer: TimeAmount?) - case forwardResponseEnd(readPending: Bool, clearReadTimeoutTimer: Bool) - case failRequest(Error, closeStream: Bool) + case failRequest(Error, FinalStreamAction, clearReadTimeoutTimer: Bool) + case succeedRequest(FinalStreamAction, clearReadTimeoutTimer: Bool) case read case wait @@ -87,131 +117,168 @@ struct HTTPRequestStateMachine { self.idleReadTimeout = idleReadTimeout } + mutating func start() -> Action { + guard case .initialized = self.state else { + preconditionFailure("`start()` must be called first, and exactly once. Invalid state: \(self.state)") + } + self.state = .verifyingRequest + return .verifyRequest + } + mutating func writabilityChanged(writable: Bool) -> Action { - self.isChannelWritable = writable + if writable { + return self.channelIsWritable() + } else { + return self.channelIsNotWritable() + } + } + + private mutating func channelIsWritable() -> Action { + self.isChannelWritable = true switch self.state { case .initialized, + .verifyingRequest, + .running(.streaming(_, _, producer: .producing), _), + .running(.endSent, _), .finished, .failed: return .wait - case .running(.verifyRequest, _), .running(.endSent, _): + case .waitForChannelToBecomeWritable(let head): + return self.startSendingRequestHead(head) + + case .running(.streaming(_, _, producer: .paused), .receivingBody(let head, _)) where head.status.code >= 300: + // If we are receiving a response with a status of >= 300, we should not send out + // further request body parts. The remote already signaled with status >= 300 that it + // won't be interested. Let's save some bandwidth. return .wait case .running(.streaming(let expectedBody, let sentBodyBytes, producer: .paused), let responseState): - if writable { - let requestState: RequestState = .streaming( - expectedBodyLength: expectedBody, - sentBodyBytes: sentBodyBytes, - producer: .producing - ) + let requestState: RequestState = .streaming( + expectedBodyLength: expectedBody, + sentBodyBytes: sentBodyBytes, + producer: .producing + ) - self.state = .running(requestState, responseState) - return .resumeRequestBodyStream - } else { - // no state change needed - return .wait - } + self.state = .running(requestState, responseState) + return .resumeRequestBodyStream + } + } + + private mutating func channelIsNotWritable() -> Action { + self.isChannelWritable = false + + switch self.state { + case .initialized, + .verifyingRequest, + .waitForChannelToBecomeWritable, + .running(.streaming(_, _, producer: .paused), _), + .running(.endSent, _), + .finished, + .failed: + return .wait case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), let responseState): - if !writable { - let requestState: RequestState = .streaming( - expectedBodyLength: expectedBodyLength, - sentBodyBytes: sentBodyBytes, - producer: .paused - ) - self.state = .running(requestState, responseState) - return .pauseRequestBodyStream - } else { - // no state change needed - return .wait - } + let requestState: RequestState = .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: .paused + ) + self.state = .running(requestState, responseState) + return .pauseRequestBodyStream } } mutating func readEventCaught() -> Action { - return .read + switch self.state { + case .initialized, + .verifyingRequest, + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: + // If we are not in the middle of streaming the response body, we always want to get + // more data... + return .read + case .running(_, .receivingBody(_, .downstreamIsConsuming(readPending: true))): + preconditionFailure("It should not be possible to receive two reads after each other, if the first one hasn't been forwarded.") + case .running(let requestState, .receivingBody(let responseHead, .downstreamIsConsuming(readPending: false))): + self.state = .running(requestState, .receivingBody(responseHead, .downstreamIsConsuming(readPending: true))) + return .wait + case .running(let requestState, .receivingBody(let responseHead, .downstreamHasDemand)): + self.state = .running(requestState, .receivingBody(responseHead, .downstreamHasDemand)) + return .read + } } mutating func errorHappened(_ error: Error) -> Action { switch self.state { case .initialized: preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") + case .verifyingRequest, .waitForChannelToBecomeWritable: + // the request failed, before it was send onto the wire. + self.state = .failed(error) + return .failRequest(error, .none, clearReadTimeoutTimer: false) case .running: self.state = .failed(error) - return .failRequest(error, closeStream: true) + return .failRequest(error, .close, clearReadTimeoutTimer: false) case .finished, .failed: preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immediately from its state. Thus this state is unreachable.") } } - mutating func start() -> Action { - guard case .initialized = self.state else { - preconditionFailure("Invalid state") - } - - self.state = .running(.verifyRequest, .initialized) - return .verifyRequest - } - mutating func requestVerified(_ head: HTTPRequestHead) -> Action { - guard case .running(.verifyRequest, .initialized) = self.state else { - preconditionFailure("Invalid state") + guard case .verifyingRequest = self.state else { + preconditionFailure("Invalid state: \(self.state)") } guard self.isChannelWritable else { - preconditionFailure("Unimplemented. Wait with starting the request here!") - } - - if let value = head.headers.first(name: "content-length"), let length = Int(value), length > 0 { - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .initialized) - return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) - } else if head.headers.contains(name: "transfer-encoding") { - self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .initialized) - return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) - } else { - self.state = .running(.endSent, .initialized) - return .sendRequestHead(head, startBody: false, startReadTimeoutTimer: self.idleReadTimeout) - } - } - - mutating func requestVerificationFailed(_ error: Error) -> Action { - guard case .running(.verifyRequest, .initialized) = self.state else { - preconditionFailure("Invalid state") + self.state = .waitForChannelToBecomeWritable(head) + return .wait } - self.state = .failed(error) - return .failRequest(error, closeStream: false) + return self.startSendingRequestHead(head) } mutating func requestStreamPartReceived(_ part: IOData) -> Action { switch self.state { - case .initialized: - preconditionFailure("Invalid state: \(self.state)") - - case .running(.verifyRequest, _), + case .initialized, + .waitForChannelToBecomeWritable, + .verifyingRequest, .running(.endSent, _): - preconditionFailure("Invalid state: \(self.state)") + preconditionFailure("We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)") + + case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)) where head.status.code >= 300: + // If we have already received a response head with status >= 300, we won't send out any + // further request body bytes. Since the remote signaled with status >= 300, that it + // won't be interested. We expect that the producer has been informed to pause + // producing. + assert(producerState == .paused) + return .wait case .running(.streaming(let expectedBodyLength, var sentBodyBytes, let producerState), let responseState): - // More streamed data is accepted, even though the producer should stop. However - // there might be thread syncronisations situations in which the producer might not - // be aware that it needs to stop yet. + // We don't check the producer state here: + // + // No matter if the `producerState` is either `.producing` or `.paused` any bytes we + // receive shall be forwarded to the Channel right away. As long as we have not received + // a response with status >= 300. + // + // More streamed data is accepted, even though the producer may have been asked to + // pause. The reason for this is as follows: There might be thread synchronization + // situations in which the producer might not have received the plea to pause yet. if let expected = expectedBodyLength { if sentBodyBytes + part.readableBytes > expected { let error = HTTPClientError.bodyLengthMismatch - switch responseState { - case .initialized, .receivingBody: - self.state = .failed(error) - case .endReceived: - #warning("TODO: This needs to be fixed. @Cory: What does this mean here?") - preconditionFailure("Unimplemented") + var clearReadTimeoutTimer = false + if case .receivingBody = responseState, self.idleReadTimeout != nil { + clearReadTimeoutTimer = true } - return .failRequest(error, closeStream: true) + return .failRequest(error, .close, clearReadTimeoutTimer: clearReadTimeoutTimer) } } @@ -231,40 +298,59 @@ struct HTTPRequestStateMachine { return .wait case .finished: - // a request may be finished, before we send all parts. We may still receive something - // here because of a thread race + // A request may be finished, before we have send all parts. This might be the case if + // the server responded with an HTTP status code that is equal or larger to 300 + // (Redirection, Client Error or Server Error). In those cases we pause the request body + // stream as soon as we have received the response head and we succeed the request as + // when response end is received. This may mean, that we succeed a request, even though + // we have not sent all it's body parts. + + // We may still receive something, here because of potential race conditions with the + // producing thread. return .wait } } mutating func requestStreamFinished() -> Action { switch self.state { - case .initialized: - preconditionFailure("Invalid state") - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), let responseState): + case .initialized, + .verifyingRequest, + .waitForChannelToBecomeWritable, + .running(.endSent, _), + .finished: + preconditionFailure("A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)") + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .waitingForHead): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch + self.state = .failed(error) + return .failRequest(error, .close, clearReadTimeoutTimer: false) + } - switch responseState { - case .initialized, .receivingBody: - self.state = .failed(error) - case .endReceived: - #warning("TODO: This needs to be fixed. @Cory: What does this mean here?") - preconditionFailure("Unimplemented") - } + self.state = .running(.endSent, .waitingForHead) + return .sendRequestEnd(startReadTimeoutTimer: self.idleReadTimeout) - return .failRequest(error, closeStream: true) + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): + assert(head.status.code < 300) + + if let expected = expectedBodyLength, expected != sentBodyBytes { + let error = HTTPClientError.bodyLengthMismatch + self.state = .failed(error) + return .failRequest(error, .close, clearReadTimeoutTimer: self.idleReadTimeout != nil) } - self.state = .running(.endSent, responseState) + self.state = .running(.endSent, .receivingBody(head, streamState)) return .sendRequestEnd(startReadTimeoutTimer: self.idleReadTimeout) - case .running(.verifyRequest, _), - .running(.endSent, _): - preconditionFailure("Invalid state") + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): + if let expected = expectedBodyLength, expected != sentBodyBytes { + let error = HTTPClientError.bodyLengthMismatch + self.state = .failed(error) + return .failRequest(error, .close, clearReadTimeoutTimer: false) + } - case .finished: - return .wait + self.state = .finished + return .succeedRequest(.none, clearReadTimeoutTimer: false) case .failed: return .wait @@ -273,10 +359,19 @@ struct HTTPRequestStateMachine { mutating func requestCancelled() -> Action { switch self.state { - case .initialized, .running: + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + let error = HTTPClientError.cancelled + self.state = .failed(error) + return .failRequest(error, .none, clearReadTimeoutTimer: false) + case .running(_, let responseState): let error = HTTPClientError.cancelled self.state = .failed(error) - return .failRequest(error, closeStream: true) + + var clearReadTimeoutTimer = false + if case .receivingBody = responseState, self.idleReadTimeout != nil { + clearReadTimeoutTimer = true + } + return .failRequest(error, .close, clearReadTimeoutTimer: clearReadTimeoutTimer) case .finished: return .wait case .failed: @@ -286,10 +381,19 @@ struct HTTPRequestStateMachine { mutating func channelInactive() -> Action { switch self.state { - case .initialized, .running: + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: let error = HTTPClientError.remoteConnectionClosed self.state = .failed(error) - return .failRequest(error, closeStream: false) + return .failRequest(error, .none, clearReadTimeoutTimer: false) + case .running(_, let responseState): + let error = HTTPClientError.remoteConnectionClosed + self.state = .failed(error) + + var clearReadTimeoutTimer = false + if case .receivingBody = responseState, self.idleReadTimeout != nil { + clearReadTimeoutTimer = true + } + return .failRequest(error, .none, clearReadTimeoutTimer: clearReadTimeoutTimer) case .finished: return .wait case .failed: @@ -302,22 +406,42 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseHead(_ head: HTTPResponseHead) -> Action { switch self.state { - case .initialized: + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: preconditionFailure("How can we receive a response head before sending a request head ourselves") - case .running(let requestState, .initialized): - switch requestState { - case .verifyRequest: - preconditionFailure("How can we receive a response head before sending a request head ourselves") - case .streaming, .endSent: - break + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): + self.state = .running( + .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .receivingBody(head, .downstreamIsConsuming(readPending: false)) + ) + return .forwardResponseHead(head, pauseRequestBodyStream: false) + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), .waitingForHead): + guard head.status.code >= 200 else { + // we ignore any leading 1xx headers... No state change needed. + return .wait + } + + if head.status.code >= 300 { + self.state = .running( + .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .paused), + .receivingBody(head, .downstreamIsConsuming(readPending: false)) + ) + return .forwardResponseHead(head, pauseRequestBodyStream: true) + } else { + self.state = .running( + .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .receivingBody(head, .downstreamIsConsuming(readPending: false)) + ) + return .forwardResponseHead(head, pauseRequestBodyStream: false) } - self.state = .running(requestState, .receivingBody(.waiting)) - return .forwardResponseHead(head) + case .running(.endSent, .waitingForHead): + self.state = .running(.endSent, .receivingBody(head, .downstreamIsConsuming(readPending: false))) + return .forwardResponseHead(head, pauseRequestBodyStream: false) case .running(_, .receivingBody), .running(_, .endReceived), .finished: - preconditionFailure("How can we sucessfully finish the request, before having received a head") + preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") case .failed: return .wait } @@ -325,24 +449,23 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { - case .initialized: - preconditionFailure("How can we receive a response head before sending a request head ourselves") + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") - case .running(_, .initialized): - preconditionFailure("How can we receive a response body, if we haven't received a head") + case .running(_, .waitingForHead): + preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)") - case .running(let requestState, .receivingBody(let streamState)): - switch streamState { - case .waiting, .readEventPending: - break - case .downstreamHasDemand: - self.state = .running(requestState, .receivingBody(.waiting)) - } + case .running(let requestState, .receivingBody(let head, .downstreamHasDemand)): + self.state = .running(requestState, .receivingBody(head, .downstreamIsConsuming(readPending: false))) + return .forwardResponseBodyPart(body, resetReadTimeoutTimer: self.idleReadTimeout) + case .running(_, .receivingBody(_, .downstreamIsConsuming)): + // the state doesn't need to be changed. we are already in the correct state. + // just forward the data. return .forwardResponseBodyPart(body, resetReadTimeoutTimer: self.idleReadTimeout) case .running(_, .endReceived), .finished: - preconditionFailure("How can we sucessfully finish the request, before having received a head") + preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") case .failed: return .wait } @@ -350,67 +473,100 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseEnd() -> Action { switch self.state { - case .initialized: - preconditionFailure("How can we receive a response head before sending a request head ourselves") + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") - case .running(_, .initialized): - preconditionFailure("How can we receive a response body, if we haven't a received a head") + case .running(_, .waitingForHead): + preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, _)) where head.status.code < 300: + self.state = .running( + .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: producerState), + .endReceived + ) + return .wait - case .running(.streaming, .receivingBody(let streamState)): - preconditionFailure("Unimplemented") - #warning("@Fabian: We received response end, before sending our own request's end.") + case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)): + assert(head.status.code >= 300) + assert(producerState == .paused, "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)") + self.state = .finished + return .succeedRequest(.close, clearReadTimeoutTimer: self.idleReadTimeout != nil) - case .running(.endSent, .receivingBody(let streamState)): - let readPending: Bool + case .running(.endSent, .receivingBody(_, let streamState)): + let finalAction: Action.FinalStreamAction switch streamState { - case .readEventPending: - readPending = true - case .downstreamHasDemand, .waiting: - readPending = false + case .downstreamIsConsuming(readPending: true): + finalAction = .read + case .downstreamIsConsuming(readPending: false), .downstreamHasDemand: + finalAction = .none } self.state = .finished - return .forwardResponseEnd(readPending: readPending, clearReadTimeoutTimer: self.idleReadTimeout != nil) + return .succeedRequest(finalAction, clearReadTimeoutTimer: self.idleReadTimeout != nil) - case .running(.verifyRequest, .receivingBody), - .running(_, .endReceived), .finished: - preconditionFailure("invalid state") + case .running(_, .endReceived), .finished: + preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") case .failed: return .wait } } mutating func forwardMoreBodyParts() -> Action { - guard case .running(let requestState, .receivingBody(let streamControl)) = self.state else { - preconditionFailure("Invalid state") - } - - switch streamControl { - case .waiting: - self.state = .running(requestState, .receivingBody(.downstreamHasDemand)) + switch self.state { + case .initialized, + .verifyingRequest, + .running(_, .waitingForHead), + .waitForChannelToBecomeWritable: + preconditionFailure("The response is expected to only ask for more data after the response head was forwarded") + case .running(let requestState, .receivingBody(let head, .downstreamIsConsuming(readPending: false))): + self.state = .running(requestState, .receivingBody(head, .downstreamHasDemand)) return .wait - case .readEventPending: - self.state = .running(requestState, .receivingBody(.waiting)) + case .running(let requestState, .receivingBody(let head, .downstreamIsConsuming(readPending: true))): + self.state = .running(requestState, .receivingBody(head, .downstreamIsConsuming(readPending: false))) return .read - case .downstreamHasDemand: + case .running(_, .receivingBody(_, .downstreamHasDemand)): // We have received a request for more data before. Normally we only expect one request // for more data, but a race can come into play here. return .wait + case .running(_, .endReceived), + .finished, + .failed: + return .wait } } mutating func idleReadTimeoutTriggered() -> Action { - guard case .running(.endSent, let responseState) = self.state else { - preconditionFailure("We only schedule idle read timeouts after we have sent the complete request") - } + switch self.state { + case .initialized, + .verifyingRequest, + .waitForChannelToBecomeWritable, + .running(.streaming, _): + preconditionFailure("We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)") - if case .endReceived = responseState { - preconditionFailure("Invalid state: If we have received everything, we must not schedule further timeout timers") + case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): + let error = HTTPClientError.readTimeout + self.state = .failed(error) + return .failRequest(error, .close, clearReadTimeoutTimer: false) + + case .running(.endSent, .endReceived): + preconditionFailure("Invalid state. This state should be: .finished") + + case .finished, .failed: + return .wait } + } - let error = HTTPClientError.readTimeout - self.state = .failed(error) - return .failRequest(error, closeStream: true) + private mutating func startSendingRequestHead(_ head: HTTPRequestHead) -> Action { + if let value = head.headers.first(name: "content-length"), let length = Int(value), length > 0 { + self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) + return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + } else if head.headers.contains(name: "transfer-encoding") { + self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) + return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + } else { + self.state = .running(.endSent, .waitingForHead) + return .sendRequestHead(head, startBody: false, startReadTimeoutTimer: self.idleReadTimeout) + } } } @@ -419,6 +575,10 @@ extension HTTPRequestStateMachine: CustomStringConvertible { switch self.state { case .initialized: return "HTTPRequestStateMachine(.initialized, isWritable: \(self.isChannelWritable))" + case .verifyingRequest: + return "HTTPRequestStateMachine(.verifyingRequest, isWritable: \(self.isChannelWritable))" + case .waitForChannelToBecomeWritable: + return "HTTPRequestStateMachine(.waitForChannelToBecomeWritable, isWritable: \(self.isChannelWritable))" case .running(let requestState, let responseState): return "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" case .finished: @@ -432,17 +592,15 @@ extension HTTPRequestStateMachine: CustomStringConvertible { extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { var description: String { switch self { - case .verifyRequest: - return ".verifyRequest" case .streaming(expectedBodyLength: let expected, let sent, producer: let producer): - return ".sendingHead(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" + return ".streaming(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" case .endSent: return ".endSent" } } } -extension HTTPRequestStateMachine.RequestState.ProducerControlState { +extension HTTPRequestStateMachine.RequestState.ProducerControlState: CustomStringConvertible { var description: String { switch self { case .paused: @@ -456,10 +614,10 @@ extension HTTPRequestStateMachine.RequestState.ProducerControlState { extension HTTPRequestStateMachine.ResponseState: CustomStringConvertible { var description: String { switch self { - case .initialized: - return ".initialized" - case .receivingBody(let streamState): - return ".receivingBody(streamState: \(streamState))" + case .waitingForHead: + return ".waitingForHead" + case .receivingBody(let head, let streamState): + return ".receivingBody(\(head), streamState: \(streamState))" case .endReceived: return ".endReceived" } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift new file mode 100644 index 000000000..92da07447 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTPRequestStateMachineTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTPRequestStateMachineTests { + static var allTests: [(String, (HTTPRequestStateMachineTests) -> () throws -> Void)] { + return [ + ("testSimpleGETRequest", testSimpleGETRequest), + ("testPOSTRequestWithWriterBackpressure", testPOSTRequestWithWriterBackpressure), + ("testPOSTContentLengthIsTooLong", testPOSTContentLengthIsTooLong), + ("testPOSTContentLengthIsTooShort", testPOSTContentLengthIsTooShort), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index e06ac91d0..2af8abcf3 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -25,10 +25,10 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: false, startReadTimeoutTimer: nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead)) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) - XCTAssertEqual(state.receivedHTTPResponseEnd(), .forwardResponseEnd(readPending: false, clearReadTimeoutTimer: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none, clearReadTimeoutTimer: false)) } func testPOSTRequestWithWriterBackpressure() { @@ -58,10 +58,10 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(startReadTimeoutTimer: nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead)) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) - XCTAssertEqual(state.receivedHTTPResponseEnd(), .forwardResponseEnd(readPending: false, clearReadTimeoutTimer: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none, clearReadTimeoutTimer: false)) } func testPOSTContentLengthIsTooLong() { @@ -74,7 +74,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) let failAction = state.requestStreamPartReceived(part1) - guard case .failRequest(let error, closeStream: true) = failAction else { + guard case .failRequest(let error, .close, clearReadTimeoutTimer: false) = failAction else { return XCTFail("Unexpected action: \(failAction)") } @@ -90,7 +90,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) let failAction = state.requestStreamFinished() - guard case .failRequest(let error, closeStream: true) = failAction else { + guard case .failRequest(let error, .close, clearReadTimeoutTimer: false) = failAction else { return XCTFail("Unexpected action: \(failAction)") } @@ -116,15 +116,15 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead), .forwardResponseHead(let rhsHead)): - return lhsHead == rhsHead + case (.forwardResponseHead(let lhsHead, let lhsPauseRequestStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestStream)): + return lhsHead == rhsHead && lhsPauseRequestStream == rhsPauseRequestStream case (.forwardResponseBodyPart(let lhsData, let lhsIdleReadTimeout), .forwardResponseBodyPart(let rhsData, let rhsIdleReadTimeout)): return lhsIdleReadTimeout == rhsIdleReadTimeout && lhsData == rhsData - case (.forwardResponseEnd(readPending: let lhsPending), .forwardResponseEnd(readPending: let rhsPending)): - return lhsPending == rhsPending - case (.failRequest(_, closeStream: let lhsClose), .failRequest(_, closeStream: let rhsClose)): - return lhsClose == rhsClose + case (.succeedRequest(let lhsFinalAction, let lhsClearReadTimeoutTimer), .succeedRequest(let rhsFinalAction, let rhsClearReadTimeoutTimer)): + return lhsFinalAction == rhsFinalAction && lhsClearReadTimeoutTimer == rhsClearReadTimeoutTimer + case (.failRequest(_, let lhsFinalAction, let lhsClearReadTimeoutTimer), .failRequest(_, let rhsFinalAction, let rhsClearReadTimeoutTimer)): + return lhsFinalAction == rhsFinalAction && lhsClearReadTimeoutTimer == rhsClearReadTimeoutTimer case (.read, .read): return true diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 54a2c80e0..83ea08033 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -35,6 +35,7 @@ import XCTest testCase(HTTPClientSOCKSTests.allTests), testCase(HTTPClientTests.allTests), testCase(HTTPConnectionPool_FactoryTests.allTests), + testCase(HTTPRequestStateMachineTests.allTests), testCase(LRUCacheTests.allTests), testCase(RequestValidationTests.allTests), testCase(SOCKSEventsHandlerTests.allTests), From 49bd91147f673bf0832c911b1db345374687bfce Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@mac.com> Date: Thu, 1 Jul 2021 10:50:35 +0200 Subject: [PATCH 4/9] Apply suggestions from code review Co-authored-by: Cory Benfield <lukasa@apple.com> Co-authored-by: George Barnett <gbarnett@apple.com> --- .../ConnectionPool/HTTPRequestStateMachine.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index be76c1781..bee122c7e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -47,9 +47,9 @@ struct HTTPRequestStateMachine { /// A sub state for sending a request body. Stores whether a producer should produce more /// bytes or should pause. enum ProducerControlState: Equatable { - /// The request body producer should produce more body bytes. The channel is writable, + /// The request body producer should produce more body bytes. The channel is writable. case producing - /// The request body producer should pause producing more bytes. The channel is not writable, + /// The request body producer should pause producing more bytes. The channel is not writable. case paused } @@ -71,7 +71,7 @@ struct HTTPRequestStateMachine { /// A response head has not been received yet. case waitingForHead - /// A response head has been received and we are ready to consume more data of the wire + /// A response head has been received and we are ready to consume more data off the wire case receivingBody(HTTPResponseHead, ConsumerControlState) /// A response end has been received and we are ready to consume more data of the wire case endReceived @@ -218,7 +218,7 @@ struct HTTPRequestStateMachine { case .initialized: preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") case .verifyingRequest, .waitForChannelToBecomeWritable: - // the request failed, before it was send onto the wire. + // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none, clearReadTimeoutTimer: false) case .running: From eba6df3c9b8142dd979d93fcddbf931c5416d325 Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Thu, 1 Jul 2021 15:13:04 +0200 Subject: [PATCH 5/9] Code review --- .../HTTPRequestStateMachine.swift | 98 ++++++++----------- .../HTTPRequestStateMachineTests.swift | 42 ++++---- 2 files changed, 61 insertions(+), 79 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index bee122c7e..d6dad26d3 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -54,8 +54,8 @@ struct HTTPRequestStateMachine { } /// The request is streaming its request body. `expectedBodyLength` has a value, if the request header contained - /// a `"content-length"` header field. It is the request header contained a `"transfer-encoding" = "chunked"` - /// header field. + /// a `"content-length"` header field. If the request header contained a `"transfer-encoding" = "chunked"` + /// header field, the `expectedBodyLength` is `nil`. case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) /// The request has sent its request body and end. case endSent @@ -73,35 +73,36 @@ struct HTTPRequestStateMachine { case waitingForHead /// A response head has been received and we are ready to consume more data off the wire case receivingBody(HTTPResponseHead, ConsumerControlState) - /// A response end has been received and we are ready to consume more data of the wire + /// A response end has been received. We don't expect more bytes from the wire. case endReceived } enum Action { /// A action to execute, when we consider a request "done". enum FinalStreamAction { - /// close the connection + /// Close the connection case close - /// trigger a read event + /// Trigger a read event case read - /// do nothing + /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. + /// This might happen if the request is cancelled, or the request failed the soundness check. case none } case verifyRequest - case sendRequestHead(HTTPRequestHead, startBody: Bool, startReadTimeoutTimer: TimeAmount?) + case sendRequestHead(HTTPRequestHead, startBody: Bool) case sendBodyPart(IOData) - case sendRequestEnd(startReadTimeoutTimer: TimeAmount?) + case sendRequestEnd case pauseRequestBodyStream case resumeRequestBodyStream case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) - case forwardResponseBodyPart(ByteBuffer, resetReadTimeoutTimer: TimeAmount?) + case forwardResponseBodyPart(ByteBuffer) - case failRequest(Error, FinalStreamAction, clearReadTimeoutTimer: Bool) - case succeedRequest(FinalStreamAction, clearReadTimeoutTimer: Bool) + case failRequest(Error, FinalStreamAction) + case succeedRequest(FinalStreamAction) case read case wait @@ -207,8 +208,10 @@ struct HTTPRequestStateMachine { case .running(let requestState, .receivingBody(let responseHead, .downstreamIsConsuming(readPending: false))): self.state = .running(requestState, .receivingBody(responseHead, .downstreamIsConsuming(readPending: true))) return .wait - case .running(let requestState, .receivingBody(let responseHead, .downstreamHasDemand)): - self.state = .running(requestState, .receivingBody(responseHead, .downstreamHasDemand)) + case .running(_, .receivingBody(_, .downstreamHasDemand)): + // The consumer has signaled a demand for more response body bytes. If a `read` is + // caught, we pass it on right away. The state machines does not transition into another + // state. return .read } } @@ -220,10 +223,10 @@ struct HTTPRequestStateMachine { case .verifyingRequest, .waitForChannelToBecomeWritable: // the request failed, before it was sent onto the wire. self.state = .failed(error) - return .failRequest(error, .none, clearReadTimeoutTimer: false) + return .failRequest(error, .none) case .running: self.state = .failed(error) - return .failRequest(error, .close, clearReadTimeoutTimer: false) + return .failRequest(error, .close) case .finished, .failed: preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immediately from its state. Thus this state is unreachable.") } @@ -269,17 +272,10 @@ struct HTTPRequestStateMachine { // pause. The reason for this is as follows: There might be thread synchronization // situations in which the producer might not have received the plea to pause yet. - if let expected = expectedBodyLength { - if sentBodyBytes + part.readableBytes > expected { - let error = HTTPClientError.bodyLengthMismatch - - var clearReadTimeoutTimer = false - if case .receivingBody = responseState, self.idleReadTimeout != nil { - clearReadTimeoutTimer = true - } + if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { + let error = HTTPClientError.bodyLengthMismatch - return .failRequest(error, .close, clearReadTimeoutTimer: clearReadTimeoutTimer) - } + return .failRequest(error, .close) } sentBodyBytes += part.readableBytes @@ -324,11 +320,11 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close, clearReadTimeoutTimer: false) + return .failRequest(error, .close) } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd(startReadTimeoutTimer: self.idleReadTimeout) + return .sendRequestEnd case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): assert(head.status.code < 300) @@ -336,21 +332,21 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close, clearReadTimeoutTimer: self.idleReadTimeout != nil) + return .failRequest(error, .close) } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd(startReadTimeoutTimer: self.idleReadTimeout) + return .sendRequestEnd case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close, clearReadTimeoutTimer: false) + return .failRequest(error, .close) } self.state = .finished - return .succeedRequest(.none, clearReadTimeoutTimer: false) + return .succeedRequest(.none) case .failed: return .wait @@ -362,16 +358,11 @@ struct HTTPRequestStateMachine { case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: let error = HTTPClientError.cancelled self.state = .failed(error) - return .failRequest(error, .none, clearReadTimeoutTimer: false) - case .running(_, let responseState): + return .failRequest(error, .none) + case .running: let error = HTTPClientError.cancelled self.state = .failed(error) - - var clearReadTimeoutTimer = false - if case .receivingBody = responseState, self.idleReadTimeout != nil { - clearReadTimeoutTimer = true - } - return .failRequest(error, .close, clearReadTimeoutTimer: clearReadTimeoutTimer) + return .failRequest(error, .close) case .finished: return .wait case .failed: @@ -381,19 +372,10 @@ struct HTTPRequestStateMachine { mutating func channelInactive() -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: - let error = HTTPClientError.remoteConnectionClosed - self.state = .failed(error) - return .failRequest(error, .none, clearReadTimeoutTimer: false) - case .running(_, let responseState): + case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable, .running: let error = HTTPClientError.remoteConnectionClosed self.state = .failed(error) - - var clearReadTimeoutTimer = false - if case .receivingBody = responseState, self.idleReadTimeout != nil { - clearReadTimeoutTimer = true - } - return .failRequest(error, .none, clearReadTimeoutTimer: clearReadTimeoutTimer) + return .failRequest(error, .none) case .finished: return .wait case .failed: @@ -457,12 +439,12 @@ struct HTTPRequestStateMachine { case .running(let requestState, .receivingBody(let head, .downstreamHasDemand)): self.state = .running(requestState, .receivingBody(head, .downstreamIsConsuming(readPending: false))) - return .forwardResponseBodyPart(body, resetReadTimeoutTimer: self.idleReadTimeout) + return .forwardResponseBodyPart(body) case .running(_, .receivingBody(_, .downstreamIsConsuming)): // the state doesn't need to be changed. we are already in the correct state. // just forward the data. - return .forwardResponseBodyPart(body, resetReadTimeoutTimer: self.idleReadTimeout) + return .forwardResponseBodyPart(body) case .running(_, .endReceived), .finished: preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") @@ -490,7 +472,7 @@ struct HTTPRequestStateMachine { assert(head.status.code >= 300) assert(producerState == .paused, "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)") self.state = .finished - return .succeedRequest(.close, clearReadTimeoutTimer: self.idleReadTimeout != nil) + return .succeedRequest(.close) case .running(.endSent, .receivingBody(_, let streamState)): let finalAction: Action.FinalStreamAction @@ -502,7 +484,7 @@ struct HTTPRequestStateMachine { } self.state = .finished - return .succeedRequest(finalAction, clearReadTimeoutTimer: self.idleReadTimeout != nil) + return .succeedRequest(finalAction) case .running(_, .endReceived), .finished: preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") @@ -546,7 +528,7 @@ struct HTTPRequestStateMachine { case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): let error = HTTPClientError.readTimeout self.state = .failed(error) - return .failRequest(error, .close, clearReadTimeoutTimer: false) + return .failRequest(error, .close) case .running(.endSent, .endReceived): preconditionFailure("Invalid state. This state should be: .finished") @@ -559,13 +541,13 @@ struct HTTPRequestStateMachine { private mutating func startSendingRequestHead(_ head: HTTPRequestHead) -> Action { if let value = head.headers.first(name: "content-length"), let length = Int(value), length > 0 { self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + return .sendRequestHead(head, startBody: true) } else if head.headers.contains(name: "transfer-encoding") { self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true, startReadTimeoutTimer: nil) + return .sendRequestHead(head, startBody: true) } else { self.state = .running(.endSent, .waitingForHead) - return .sendRequestHead(head, startBody: false, startReadTimeoutTimer: self.idleReadTimeout) + return .sendRequestHead(head, startBody: false) } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 2af8abcf3..74619c479 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -22,20 +22,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: false, startReadTimeoutTimer: nil)) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: false)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) - XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) - XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none, clearReadTimeoutTimer: false)) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) } func testPOSTRequestWithWriterBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) @@ -55,26 +55,26 @@ class HTTPRequestStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(startReadTimeoutTimer: nil)) + XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) - XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody, resetReadTimeoutTimer: nil)) - XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none, clearReadTimeoutTimer: false)) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) } func testPOSTContentLengthIsTooLong() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) let failAction = state.requestStreamPartReceived(part1) - guard case .failRequest(let error, .close, clearReadTimeoutTimer: false) = failAction else { + guard case .failRequest(let error, .close) = failAction else { return XCTFail("Unexpected action: \(failAction)") } @@ -85,12 +85,12 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true, startReadTimeoutTimer: nil)) + XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) let failAction = state.requestStreamFinished() - guard case .failRequest(let error, .close, clearReadTimeoutTimer: false) = failAction else { + guard case .failRequest(let error, .close) = failAction else { return XCTFail("Unexpected action: \(failAction)") } @@ -104,8 +104,8 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.verifyRequest, .verifyRequest): return true - case (.sendRequestHead(let lhsHead, let lhsStartBody, let lhsIdleReadTimeout), .sendRequestHead(let rhsHead, let rhsStartBody, let rhsIdleReadTimeout)): - return lhsHead == rhsHead && lhsStartBody == rhsStartBody && lhsIdleReadTimeout == rhsIdleReadTimeout + case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): + return lhsHead == rhsHead && lhsStartBody == rhsStartBody case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): return lhsData == rhsData case (.sendRequestEnd, .sendRequestEnd): @@ -116,15 +116,15 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestStream)): - return lhsHead == rhsHead && lhsPauseRequestStream == rhsPauseRequestStream - case (.forwardResponseBodyPart(let lhsData, let lhsIdleReadTimeout), .forwardResponseBodyPart(let rhsData, let rhsIdleReadTimeout)): - return lhsIdleReadTimeout == rhsIdleReadTimeout && lhsData == rhsData + case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream + case (.forwardResponseBodyPart(let lhsData), .forwardResponseBodyPart(let rhsData)): + return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsClearReadTimeoutTimer), .succeedRequest(let rhsFinalAction, let rhsClearReadTimeoutTimer)): - return lhsFinalAction == rhsFinalAction && lhsClearReadTimeoutTimer == rhsClearReadTimeoutTimer - case (.failRequest(_, let lhsFinalAction, let lhsClearReadTimeoutTimer), .failRequest(_, let rhsFinalAction, let rhsClearReadTimeoutTimer)): - return lhsFinalAction == rhsFinalAction && lhsClearReadTimeoutTimer == rhsClearReadTimeoutTimer + case (.succeedRequest(let lhsFinalAction), .succeedRequest(let rhsFinalAction)): + return lhsFinalAction == rhsFinalAction + case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): + return lhsFinalAction == rhsFinalAction case (.read, .read): return true From 3eff1de4c5e07f75215509266cd8753519ab17e5 Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Fri, 2 Jul 2021 11:57:47 +0200 Subject: [PATCH 6/9] Remove `verifyRequest` step --- .../HTTPRequestStateMachine.swift | 89 +++++++------------ .../RequestFramingMetadata.swift | 24 +++++ .../HTTPRequestStateMachineTests.swift | 19 ++-- 3 files changed, 66 insertions(+), 66 deletions(-) create mode 100644 Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index d6dad26d3..95f1a69b2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -18,23 +18,19 @@ import NIOHTTP1 struct HTTPRequestStateMachine { fileprivate enum State { /// The initial state machine state. The only valid mutation is `start()`. The state will - /// transition to `.verifyingRequest` + /// transitions to: + /// - `.waitForChannelToBecomeWritable` + /// - `.running(.streaming, .initialized)` (if the Channel is writable and if a request body is expected) + /// - `.running(.endSent, .initialized)` (if the Channel is writable and no request body is expected) case initialized - /// During this state the request's soundness is checked, before sending it out on the wire. - /// Valid transitions are: - /// - .waitForChannelToBecomeWritable (if the Channel is not writable) - /// - .running(.streaming, .initialized) (if the Channel is writable and a request body is expected) - /// - .running(.endSent, .initialized) (if the Channel is writable and no request body is expected) - /// - .failed (if an error was found in the request soundness check) - case verifyingRequest /// Waiting for the channel to be writable. Valid transitions are: - /// - .running(.streaming, .initialized) (once the Channel is writable again and if a request body is expected) - /// - .running(.endSent, .initialized) (once the Channel is writable again and no request body is expected) - /// - .failed (if a connection error occurred) - case waitForChannelToBecomeWritable(HTTPRequestHead) + /// - `.running(.streaming, .initialized)` (once the Channel is writable again and if a request body is expected) + /// - `.running(.endSent, .initialized)` (once the Channel is writable again and no request body is expected) + /// - `.failed` (if a connection error occurred) + case waitForChannelToBecomeWritable(HTTPRequestHead, RequestFramingMetadata) /// A request is on the wire. Valid transitions are: - /// - .finished - /// - .failed + /// - `.finished` + /// - `.failed` case running(RequestState, ResponseState) /// The request has completed successfully case finished @@ -89,8 +85,6 @@ struct HTTPRequestStateMachine { case none } - case verifyRequest - case sendRequestHead(HTTPRequestHead, startBody: Bool) case sendBodyPart(IOData) case sendRequestEnd @@ -118,12 +112,17 @@ struct HTTPRequestStateMachine { self.idleReadTimeout = idleReadTimeout } - mutating func start() -> Action { + mutating func startRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { guard case .initialized = self.state else { preconditionFailure("`start()` must be called first, and exactly once. Invalid state: \(self.state)") } - self.state = .verifyingRequest - return .verifyRequest + + guard self.isChannelWritable else { + self.state = .waitForChannelToBecomeWritable(head, metadata) + return .wait + } + + return self.startSendingRequest(head: head, metadata: metadata) } mutating func writabilityChanged(writable: Bool) -> Action { @@ -139,15 +138,14 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .verifyingRequest, .running(.streaming(_, _, producer: .producing), _), .running(.endSent, _), .finished, .failed: return .wait - case .waitForChannelToBecomeWritable(let head): - return self.startSendingRequestHead(head) + case .waitForChannelToBecomeWritable(let head, let metadata): + return self.startSendingRequest(head: head, metadata: metadata) case .running(.streaming(_, _, producer: .paused), .receivingBody(let head, _)) where head.status.code >= 300: // If we are receiving a response with a status of >= 300, we should not send out @@ -172,7 +170,6 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .verifyingRequest, .waitForChannelToBecomeWritable, .running(.streaming(_, _, producer: .paused), _), .running(.endSent, _), @@ -194,7 +191,6 @@ struct HTTPRequestStateMachine { mutating func readEventCaught() -> Action { switch self.state { case .initialized, - .verifyingRequest, .waitForChannelToBecomeWritable, .running(_, .waitingForHead), .running(_, .endReceived), @@ -220,7 +216,7 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized: preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") - case .verifyingRequest, .waitForChannelToBecomeWritable: + case .waitForChannelToBecomeWritable: // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none) @@ -232,24 +228,10 @@ struct HTTPRequestStateMachine { } } - mutating func requestVerified(_ head: HTTPRequestHead) -> Action { - guard case .verifyingRequest = self.state else { - preconditionFailure("Invalid state: \(self.state)") - } - - guard self.isChannelWritable else { - self.state = .waitForChannelToBecomeWritable(head) - return .wait - } - - return self.startSendingRequestHead(head) - } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable, - .verifyingRequest, .running(.endSent, _): preconditionFailure("We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)") @@ -310,7 +292,6 @@ struct HTTPRequestStateMachine { mutating func requestStreamFinished() -> Action { switch self.state { case .initialized, - .verifyingRequest, .waitForChannelToBecomeWritable, .running(.endSent, _), .finished: @@ -355,7 +336,7 @@ struct HTTPRequestStateMachine { mutating func requestCancelled() -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + case .initialized, .waitForChannelToBecomeWritable: let error = HTTPClientError.cancelled self.state = .failed(error) return .failRequest(error, .none) @@ -372,7 +353,7 @@ struct HTTPRequestStateMachine { mutating func channelInactive() -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable, .running: + case .initialized, .waitForChannelToBecomeWritable, .running: let error = HTTPClientError.remoteConnectionClosed self.state = .failed(error) return .failRequest(error, .none) @@ -388,7 +369,7 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseHead(_ head: HTTPResponseHead) -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + case .initialized, .waitForChannelToBecomeWritable: preconditionFailure("How can we receive a response head before sending a request head ourselves") case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): @@ -431,7 +412,7 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + case .initialized, .waitForChannelToBecomeWritable: preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") case .running(_, .waitingForHead): @@ -455,7 +436,7 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseEnd() -> Action { switch self.state { - case .initialized, .verifyingRequest, .waitForChannelToBecomeWritable: + case .initialized, .waitForChannelToBecomeWritable: preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") case .running(_, .waitingForHead): @@ -496,7 +477,6 @@ struct HTTPRequestStateMachine { mutating func forwardMoreBodyParts() -> Action { switch self.state { case .initialized, - .verifyingRequest, .running(_, .waitingForHead), .waitForChannelToBecomeWritable: preconditionFailure("The response is expected to only ask for more data after the response head was forwarded") @@ -520,7 +500,6 @@ struct HTTPRequestStateMachine { mutating func idleReadTimeoutTriggered() -> Action { switch self.state { case .initialized, - .verifyingRequest, .waitForChannelToBecomeWritable, .running(.streaming, _): preconditionFailure("We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)") @@ -538,14 +517,16 @@ struct HTTPRequestStateMachine { } } - private mutating func startSendingRequestHead(_ head: HTTPRequestHead) -> Action { - if let value = head.headers.first(name: "content-length"), let length = Int(value), length > 0 { - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) - } else if head.headers.contains(name: "transfer-encoding") { + private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { + switch metadata.body { + case .stream: self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) return .sendRequestHead(head, startBody: true) - } else { + case .fixedSize(let length) where length > 0: + self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) + return .sendRequestHead(head, startBody: true) + case .none, .fixedSize: + // fallback if fixed size is 0 self.state = .running(.endSent, .waitingForHead) return .sendRequestHead(head, startBody: false) } @@ -557,8 +538,6 @@ extension HTTPRequestStateMachine: CustomStringConvertible { switch self.state { case .initialized: return "HTTPRequestStateMachine(.initialized, isWritable: \(self.isChannelWritable))" - case .verifyingRequest: - return "HTTPRequestStateMachine(.verifyingRequest, isWritable: \(self.isChannelWritable))" case .waitForChannelToBecomeWritable: return "HTTPRequestStateMachine(.waitForChannelToBecomeWritable, isWritable: \(self.isChannelWritable))" case .running(let requestState, let responseState): diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift new file mode 100644 index 000000000..741f06c2e --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +struct RequestFramingMetadata { + enum Body { + case none + case stream + case fixedSize(Int) + } + + var connectionClose: Bool + var body: Body +} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 74619c479..a7e90ac2c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -20,9 +20,9 @@ import XCTest class HTTPRequestStateMachineTests: XCTestCase { func testSimpleGETRequest() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) - XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: false)) + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -33,9 +33,9 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTRequestWithWriterBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) - XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) @@ -66,9 +66,9 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooLong() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) - XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) @@ -83,9 +83,9 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooShort() { var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) - XCTAssertEqual(state.start(), .verifyRequest) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) - XCTAssertEqual(state.requestVerified(requestHead), .sendRequestHead(requestHead, startBody: true)) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) @@ -101,9 +101,6 @@ class HTTPRequestStateMachineTests: XCTestCase { extension HTTPRequestStateMachine.Action: Equatable { public static func == (lhs: HTTPRequestStateMachine.Action, rhs: HTTPRequestStateMachine.Action) -> Bool { switch (lhs, rhs) { - case (.verifyRequest, .verifyRequest): - return true - case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): From 650d14b3c501be0901df23fe49f40f7a35233d37 Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Fri, 2 Jul 2021 12:22:39 +0200 Subject: [PATCH 7/9] Code review --- .../HTTPRequestStateMachine.swift | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 95f1a69b2..31a236b2f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -42,7 +42,7 @@ struct HTTPRequestStateMachine { fileprivate enum RequestState { /// A sub state for sending a request body. Stores whether a producer should produce more /// bytes or should pause. - enum ProducerControlState: Equatable { + enum ProducerControlState: String { /// The request body producer should produce more body bytes. The channel is writable. case producing /// The request body producer should pause producing more bytes. The channel is not writable. @@ -200,7 +200,9 @@ struct HTTPRequestStateMachine { // more data... return .read case .running(_, .receivingBody(_, .downstreamIsConsuming(readPending: true))): - preconditionFailure("It should not be possible to receive two reads after each other, if the first one hasn't been forwarded.") + // We have caught another `read` event already. We don't need to change the state and + // we should continue to wait for the consumer to call `forwardMoreBodyParts` + return .wait case .running(let requestState, .receivingBody(let responseHead, .downstreamIsConsuming(readPending: false))): self.state = .running(requestState, .receivingBody(responseHead, .downstreamIsConsuming(readPending: true))) return .wait @@ -561,17 +563,6 @@ extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { } } -extension HTTPRequestStateMachine.RequestState.ProducerControlState: CustomStringConvertible { - var description: String { - switch self { - case .paused: - return ".paused" - case .producing: - return ".producing" - } - } -} - extension HTTPRequestStateMachine.ResponseState: CustomStringConvertible { var description: String { switch self { From 69539a0062acb28dd586f6030a7f4cdbc283352f Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Fri, 2 Jul 2021 14:29:50 +0200 Subject: [PATCH 8/9] Final read events --- .../HTTPRequestStateMachine.swift | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 31a236b2f..1b0b2014a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -61,7 +61,13 @@ struct HTTPRequestStateMachine { /// A sub state for receiving a response. Stores whether the consumer has either signaled demand for more data or /// is busy consuming the so far forwarded bytes enum ConsumerControlState { + /// the state machine is in this state once it has passed down a request head or body part. If a read event + /// occurs while in this state, the readPending flag will be set true. If the consumer signals more demand + /// by invoking `forwardMoreBodyParts`, the state machine will forward the read event. case downstreamIsConsuming(readPending: Bool) + /// the state machine is in this state once the consumer has signaled more demand by invoking + /// `forwardMoreBodyParts`. If a read event occurs in this state the read event will be forwarded + /// immediately. case downstreamHasDemand } @@ -457,17 +463,20 @@ struct HTTPRequestStateMachine { self.state = .finished return .succeedRequest(.close) - case .running(.endSent, .receivingBody(_, let streamState)): - let finalAction: Action.FinalStreamAction - switch streamState { - case .downstreamIsConsuming(readPending: true): - finalAction = .read - case .downstreamIsConsuming(readPending: false), .downstreamHasDemand: - finalAction = .none - } + case .running(.endSent, .receivingBody(_, .downstreamIsConsuming(readPending: true))): + // If we have a received a read event before, we must ensure that the read event + // eventually gets onto the channel pipeline again. The end of the request gives + // us an opportunity for this clean up task. + // It is very unlikely that we can see this in the real world. If we have swallowed + // a read event we don't expect to receive further data from the channel incl. + // response ends. + self.state = .finished + return .succeedRequest(.read) + case .running(.endSent, .receivingBody(_, .downstreamIsConsuming(readPending: false))), + .running(.endSent, .receivingBody(_, .downstreamHasDemand)): self.state = .finished - return .succeedRequest(finalAction) + return .succeedRequest(.none) case .running(_, .endReceived), .finished: preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") From 660ad204ff09c87bfe4050e58b764c650e24003a Mon Sep 17 00:00:00 2001 From: Fabian Fett <fabianfett@apple.com> Date: Fri, 2 Jul 2021 17:44:32 +0200 Subject: [PATCH 9/9] tests --- .../HTTPRequestStateMachine.swift | 58 +++- .../HTTPRequestStateMachineTests+XCTest.swift | 18 ++ .../HTTPRequestStateMachineTests.swift | 305 +++++++++++++++++- 3 files changed, 357 insertions(+), 24 deletions(-) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 1b0b2014a..0e30e0fd1 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -93,7 +93,9 @@ struct HTTPRequestStateMachine { case sendRequestHead(HTTPRequestHead, startBody: Bool) case sendBodyPart(IOData) - case sendRequestEnd + /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, + /// as soon as we wrote the request end onto the wire. In this case the succeedRequest property is set. + case sendRequestEnd(succeedRequest: FinalStreamAction?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -111,11 +113,9 @@ struct HTTPRequestStateMachine { private var state: State = .initialized private var isChannelWritable: Bool - private let idleReadTimeout: TimeAmount? - init(isChannelWritable: Bool, idleReadTimeout: TimeAmount?) { + init(isChannelWritable: Bool) { self.isChannelWritable = isChannelWritable - self.idleReadTimeout = idleReadTimeout } mutating func startRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { @@ -301,8 +301,7 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, .waitForChannelToBecomeWritable, - .running(.endSent, _), - .finished: + .running(.endSent, _): preconditionFailure("A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)") case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .waitingForHead): @@ -313,7 +312,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd + return .sendRequestEnd(succeedRequest: nil) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): assert(head.status.code < 300) @@ -325,7 +324,7 @@ struct HTTPRequestStateMachine { } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd + return .sendRequestEnd(succeedRequest: nil) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { @@ -335,10 +334,22 @@ struct HTTPRequestStateMachine { } self.state = .finished - return .succeedRequest(.none) + return .sendRequestEnd(succeedRequest: .some(.none)) case .failed: return .wait + + case .finished: + // A request may be finished, before we have send all parts. This might be the case if + // the server responded with an HTTP status code that is equal or larger to 300 + // (Redirection, Client Error or Server Error). In those cases we pause the request body + // stream as soon as we have received the response head and we succeed the request as + // when response end is received. This may mean, that we succeed a request, even though + // we have not sent all it's body parts. + + // We may still receive something, here because of potential race conditions with the + // producing thread. + return .wait } } @@ -376,23 +387,23 @@ struct HTTPRequestStateMachine { // MARK: - Response mutating func receivedHTTPResponseHead(_ head: HTTPResponseHead) -> Action { + guard head.status.code >= 200 else { + // we ignore any leading 1xx headers... No state change needed. + return .wait + } + switch self.state { case .initialized, .waitForChannelToBecomeWritable: preconditionFailure("How can we receive a response head before sending a request head ourselves") case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): self.state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .paused), .receivingBody(head, .downstreamIsConsuming(readPending: false)) ) return .forwardResponseHead(head, pauseRequestBodyStream: false) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), .waitingForHead): - guard head.status.code >= 200 else { - // we ignore any leading 1xx headers... No state change needed. - return .wait - } - if head.status.code >= 300 { self.state = .running( .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .paused), @@ -450,12 +461,25 @@ struct HTTPRequestStateMachine { case .running(_, .waitingForHead): preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, _)) where head.status.code < 300: + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, let consumerState)) where head.status.code < 300: self.state = .running( .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: producerState), .endReceived ) - return .wait + + switch consumerState { + case .downstreamHasDemand, .downstreamIsConsuming(readPending: false): + return .wait + case .downstreamIsConsuming(readPending: true): + // If we have a received a read event before, we must ensure that the read event + // eventually gets onto the channel pipeline again. The end of the request gives + // us an opportunity for this clean up task. + // It is very unlikely that we can see this in the real world. If we have swallowed + // a read event we don't expect to receive further data from the channel incl. + // response ends. + + return .read + } case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)): assert(head.status.code >= 300) diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift index 92da07447..95b6acdb8 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -29,6 +29,24 @@ extension HTTPRequestStateMachineTests { ("testPOSTRequestWithWriterBackpressure", testPOSTRequestWithWriterBackpressure), ("testPOSTContentLengthIsTooLong", testPOSTContentLengthIsTooLong), ("testPOSTContentLengthIsTooShort", testPOSTContentLengthIsTooShort), + ("testRequestBodyStreamIsCancelledIfServerRespondsWith301", testRequestBodyStreamIsCancelledIfServerRespondsWith301), + ("testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure", testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure), + ("testRequestBodyStreamIsContinuedIfServerRespondsWith200", testRequestBodyStreamIsContinuedIfServerRespondsWith200), + ("testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200", testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200), + ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200), + ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200), + ("testRequestIsNotSendUntilChannelIsWritable", testRequestIsNotSendUntilChannelIsWritable), + ("testResponseReadingWithBackpressure", testResponseReadingWithBackpressure), + ("testResponseReadingWithBackpressureEndOfResponseSetsCaughtReadEventFree", testResponseReadingWithBackpressureEndOfResponseSetsCaughtReadEventFree), + ("testCancellingARequestInStateInitializedKeepsTheConnectionAlive", testCancellingARequestInStateInitializedKeepsTheConnectionAlive), + ("testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive", testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive), + ("testCancellingARequestThatIsSent", testCancellingARequestThatIsSent), + ("testRemoteSuddenlyClosesTheConnection", testRemoteSuddenlyClosesTheConnection), + ("testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnore", testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnore), + ("testResponseWithStatus1XXAreIgnored", testResponseWithStatus1XXAreIgnored), + ("testReadTimeoutThatFiresToLateIsIgnored", testReadTimeoutThatFiresToLateIsIgnored), + ("testCancellationThatIsInvokedToLateIsIgnored", testCancellationThatIsInvokedToLateIsIgnored), + ("testErrorWhileRunningARequestClosesTheStream", testErrorWhileRunningARequestClosesTheStream), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index a7e90ac2c..1c0bf48cf 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -15,11 +15,12 @@ @testable import AsyncHTTPClient import NIO import NIOHTTP1 +import NIOSSL import XCTest class HTTPRequestStateMachineTests: XCTestCase { func testSimpleGETRequest() { - var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .none) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) @@ -32,7 +33,7 @@ class HTTPRequestStateMachineTests: XCTestCase { } func testPOSTRequestWithWriterBackpressure() { - var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) @@ -55,7 +56,7 @@ class HTTPRequestStateMachineTests: XCTestCase { // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(succeedRequest: nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -65,7 +66,7 @@ class HTTPRequestStateMachineTests: XCTestCase { } func testPOSTContentLengthIsTooLong() { - var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) @@ -82,7 +83,7 @@ class HTTPRequestStateMachineTests: XCTestCase { } func testPOSTContentLengthIsTooShort() { - var state = HTTPRequestStateMachine(isChannelWritable: true, idleReadTimeout: nil) + var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) @@ -96,6 +97,296 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch) } + + func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + + // response is comming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: true)) + XCTAssertEqual(state.writabilityChanged(writable: false), .wait) + XCTAssertEqual(state.writabilityChanged(writable: true), .wait) + XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.close)) + + XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + + XCTAssertEqual(state.requestStreamFinished(), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + } + + func testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) + + // response is comming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.writabilityChanged(writable: true), .wait) + XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.close)) + + XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + + XCTAssertEqual(state.requestStreamFinished(), .wait, + "Expected to drop all stream data after having received a response head, with status >= 300") + } + + func testRequestBodyStreamIsContinuedIfServerRespondsWith200() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + // response is coming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .wait) + + let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) + XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) + XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(succeedRequest: .some(.none))) + } + + func testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + // response is coming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + + let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) + XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) + XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd(succeedRequest: nil)) + + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + } + + func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + // response is comming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .wait) + + let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) + XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close)) + XCTAssertEqual(state.channelInactive(), .wait) + } + + func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) + XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + + // response is comming before having send all data + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + + let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) + XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .wait) + } + + func testRequestIsNotSendUntilChannelIsWritable() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(responseBody), .forwardResponseBodyPart(responseBody)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + XCTAssertEqual(state.channelInactive(), .wait) + } + + func testResponseReadingWithBackpressure() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + let part0 = ByteBuffer(bytes: 0...3) + let part1 = ByteBuffer(bytes: 4...7) + let part2 = ByteBuffer(bytes: 8...11) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part0), .forwardResponseBodyPart(part0)) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part1), .forwardResponseBodyPart(part1)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.readEventCaught(), .wait, "Expected to be able to consume a second read event") + XCTAssertEqual(state.forwardMoreBodyParts(), .read) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part2), .forwardResponseBodyPart(part2)) + XCTAssertEqual(state.forwardMoreBodyParts(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + } + + func testResponseReadingWithBackpressureEndOfResponseSetsCaughtReadEventFree() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + let part0 = ByteBuffer(bytes: 0...3) + let part1 = ByteBuffer(bytes: 4...7) + let part2 = ByteBuffer(bytes: 8...11) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part0), .forwardResponseBodyPart(part0)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.forwardMoreBodyParts(), .read) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part1), .forwardResponseBodyPart(part1)) + XCTAssertEqual(state.forwardMoreBodyParts(), .wait) + XCTAssertEqual(state.forwardMoreBodyParts(), .wait, "Calling forward more bytes twice is okay") + XCTAssertEqual(state.readEventCaught(), .read) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part2), .forwardResponseBodyPart(part2)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.read)) + } + + func testCancellingARequestInStateInitializedKeepsTheConnectionAlive() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + } + + func testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + } + + func testCancellingARequestThatIsSent() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .close)) + } + + func testRemoteSuddenlyClosesTheConnection() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.remoteConnectionClosed, .close)) + XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait) + } + + func testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnore() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + let part0 = ByteBuffer(bytes: 0...3) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(part0), .forwardResponseBodyPart(part0)) + XCTAssertEqual(state.idleReadTimeoutTriggered(), .failRequest(HTTPClientError.readTimeout, .close)) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(ByteBuffer(bytes: 4...7)), .wait) + XCTAssertEqual(state.receivedHTTPResponseBodyPart(ByteBuffer(bytes: 8...11)), .wait) + XCTAssertEqual(state.forwardMoreBodyParts(), .wait) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .wait) + } + + func testResponseWithStatus1XXAreIgnored() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) + XCTAssertEqual(state.receivedHTTPResponseHead(continueHead), .wait) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + } + + func testReadTimeoutThatFiresToLateIsIgnored() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) + XCTAssertEqual(state.receivedHTTPResponseHead(continueHead), .wait) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") + } + + func testCancellationThatIsInvokedToLateIsIgnored() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) + XCTAssertEqual(state.receivedHTTPResponseHead(continueHead), .wait) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.receivedHTTPResponseHead(responseHead), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.receivedHTTPResponseEnd(), .succeedRequest(.none)) + XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") + } + + func testErrorWhileRunningARequestClosesTheStream() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .failRequest(NIOSSLError.uncleanShutdown, .close)) + XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") + } } extension HTTPRequestStateMachine.Action: Equatable { @@ -105,8 +396,8 @@ extension HTTPRequestStateMachine.Action: Equatable { return lhsHead == rhsHead && lhsStartBody == rhsStartBody case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): return lhsData == rhsData - case (.sendRequestEnd, .sendRequestEnd): - return true + case (.sendRequestEnd(let lhsFinalAction), .sendRequestEnd(let rhsFinalAction)): + return lhsFinalAction == rhsFinalAction case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true