diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift index b5787f690..88107bda0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift @@ -117,10 +117,10 @@ struct HTTP1ConnectionStateMachine { self.state = .closed return .fireChannelError(error, closeConnection: false) - case .inRequest(var requestStateMachine, close: _): + case .inRequest(var requestStateMachine, close: let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.errorHappened(error) - state = .closed + state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift index 70fa91854..9abc3741a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift @@ -37,6 +37,7 @@ extension HTTP1ConnectionStateMachineTests { ("testReadsAreForwardedIfConnectionIsClosing", testReadsAreForwardedIfConnectionIsClosing), ("testChannelReadsAreIgnoredIfConnectionIsClosing", testChannelReadsAreIgnoredIfConnectionIsClosing), ("testRequestIsCancelledWhileWaitingForWritable", testRequestIsCancelledWhileWaitingForWritable), + ("testConnectionIsClosedIfErrorHappensWhileInRequest", testConnectionIsClosedIfErrorHappensWhileInRequest), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index 76c82813f..3ca42249b 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -15,6 +15,7 @@ @testable import AsyncHTTPClient import NIOCore import NIOHTTP1 +import NIOHTTPCompression import XCTest class HTTP1ConnectionStateMachineTests: XCTestCase { @@ -22,7 +23,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: false), .fireChannelActive) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) @@ -64,7 +65,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .none) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "12"]) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) @@ -141,7 +142,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: false), .fireChannelActive) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) @@ -185,11 +186,25 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { func testRequestIsCancelledWhileWaitingForWritable() { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: false), .fireChannelActive) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle)) } + + func testConnectionIsClosedIfErrorHappensWhileInRequest() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) + XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Foo Bar!\n"))), .wait) + let decompressionError = NIOHTTPDecompression.DecompressionError.limit + XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close)) + } } extension HTTP1ConnectionStateMachine.Action: Equatable {