diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 2d6ad5dd8..380f7386e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -190,6 +190,11 @@ struct HTTPRequestStateMachine { } mutating func errorHappened(_ error: Error) -> Action { + if let error = error as? NIOSSLError, + error == .uncleanShutdown, + let action = self.handleNIOSSLUncleanShutdownError() { + return action + } switch self.state { case .initialized: preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") @@ -197,16 +202,30 @@ struct HTTPRequestStateMachine { // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none) + case .running: + self.state = .failed(error) + return .failRequest(error, .close) + + case .finished, .failed: + // ignore error + return .wait + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + private mutating func handleNIOSSLUncleanShutdownError() -> Action? { + switch self.state { case .running(.streaming, .waitingForHead), - .running(.endSent, .waitingForHead) where error as? NIOSSLError == .uncleanShutdown: + .running(.endSent, .waitingForHead): // if we received a NIOSSL.uncleanShutdown before we got an answer we should handle // this like a normal connection close. We will receive a call to channelInactive after // this error. return .wait case .running(.streaming, .receivingBody(let responseHead, _)), - .running(.endSent, .receivingBody(let responseHead, _)) where error as? NIOSSLError == .uncleanShutdown: + .running(.endSent, .receivingBody(let responseHead, _)): // This code is only reachable for request and responses, which we expect to have a body. // We depend on logic from the HTTPResponseDecoder here. The decoder will emit an // HTTPResponsePart.end right after the HTTPResponsePart.head, for every request with a @@ -226,19 +245,11 @@ struct HTTPRequestStateMachine { // If the response is EOF terminated, we need to rely on a clean tls shutdown to be sure // we have received all necessary bytes. For this reason we forward the uncleanShutdown // error to the user. - self.state = .failed(error) - return .failRequest(error, .close) - - case .running: - self.state = .failed(error) - return .failRequest(error, .close) - - case .finished, .failed: - // ignore error - return .wait + self.state = .failed(NIOSSLError.uncleanShutdown) + return .failRequest(NIOSSLError.uncleanShutdown, .close) - case .modifying: - preconditionFailure("Invalid state: \(self.state)") + case .waitForChannelToBecomeWritable, .running, .finished, .failed, .initialized, .modifying: + return nil } } @@ -270,7 +281,7 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { let error = HTTPClientError.bodyLengthMismatch - + self.state = .failed(error) return .failRequest(error, .close) } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift index 3dd6c7b30..b54865fd8 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -54,6 +54,8 @@ extension HTTPRequestStateMachineTests { ("testCanReadHTTP1_0ResponseWithBody", testCanReadHTTP1_0ResponseWithBody), ("testFailHTTP1_0RequestThatIsStillUploading", testFailHTTP1_0RequestThatIsStillUploading), ("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown), + ("testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState", testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState), + ("testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState", testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState), ("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt), ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead), diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 3e92dc87f..ab55345c9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -76,12 +76,10 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - let failAction = state.requestStreamPartReceived(part1) - guard case .failRequest(let error, .close) = failAction else { - return XCTFail("Unexpected action: \(failAction)") - } + state.requestStreamPartReceived(part1).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) - XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch) + // if another error happens the new one is ignored + XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait) } func testPOSTContentLengthIsTooShort() { @@ -92,12 +90,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - let failAction = state.requestStreamFinished() - guard case .failRequest(let error, .close) = failAction else { - return XCTFail("Unexpected action: \(failAction)") - } - - XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch) + state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { @@ -206,7 +199,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close)) + state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) XCTAssertEqual(state.channelInactive(), .wait) } @@ -224,7 +217,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close)) + state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -249,7 +242,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) } func testResponseReadingWithBackpressure() { @@ -339,7 +332,7 @@ class HTTPRequestStateMachineTests: XCTestCase { func testCancellingARequestInStateInitializedKeepsTheConnectionAlive() { var state = HTTPRequestStateMachine(isChannelWritable: false) - XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .none) } func testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive() { @@ -347,7 +340,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .none) } func testConnectionBecomesWritableBeforeFirstRequest() { @@ -373,7 +366,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .close)) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) } func testRemoteSuddenlyClosesTheConnection() { @@ -381,7 +374,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) - XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.remoteConnectionClosed, .close)) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait) } @@ -395,7 +388,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let part0 = ByteBuffer(bytes: 0...3) XCTAssertEqual(state.channelRead(.body(part0)), .wait) - XCTAssertEqual(state.idleReadTimeoutTriggered(), .failRequest(HTTPClientError.readTimeout, .close)) + state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 4...7))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 8...11))), .wait) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -448,7 +441,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - XCTAssertEqual(state.errorHappened(HTTPParserError.invalidChunkSize), .failRequest(HTTPParserError.invalidChunkSize, .close)) + state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -502,7 +495,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - XCTAssertEqual(state.channelRead(.end(nil)), .failRequest(HTTPClientError.remoteConnectionClosed, .close)) + state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close) XCTAssertEqual(state.channelInactive(), .wait) } @@ -517,11 +510,32 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .failRequest(NIOSSLError.uncleanShutdown, .close)) + state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close) XCTAssertEqual(state.channelRead(.end(nil)), .wait) XCTAssertEqual(state.channelInactive(), .wait) } + func testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState() { + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) + state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) + } + + func testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState() { + struct ArbitraryError: Error, Equatable {} + var state = HTTPRequestStateMachine(isChannelWritable: true) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close) + XCTAssertEqual(state.channelInactive(), .wait) + } + func testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") @@ -536,7 +550,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(body)), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) - XCTAssertEqual(state.errorHappened(HTTPParserError.invalidEOFState), .failRequest(HTTPParserError.invalidEOFState, .close)) + state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close) XCTAssertEqual(state.channelInactive(), .wait) } @@ -558,7 +572,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) } func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead() { @@ -579,7 +593,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) } func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand() { @@ -599,7 +613,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) } func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes() { @@ -676,3 +690,22 @@ extension HTTPRequestStateMachine.Action: Equatable { } } } + +extension HTTPRequestStateMachine.Action { + fileprivate func assertFailRequest( + _ expectedError: Error, + _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalStreamAction, + file: StaticString = #file, + line: UInt = #line + ) where Error: Swift.Error & Equatable { + guard case .failRequest(let actualError, let actualFinalStreamAction) = self else { + return XCTFail("expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", file: file, line: line) + } + if let actualError = actualError as? Error { + XCTAssertEqual(actualError, expectedError, file: file, line: line) + } else { + XCTFail("\(actualError) is not equal to \(expectedError)", file: file, line: line) + } + XCTAssertEqual(actualFinalStreamAction, expectedFinalStreamAction, file: file, line: line) + } +}