diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift index d20c94d02..90578bc87 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift @@ -92,10 +92,21 @@ extension HTTPRequestStateMachine { return buffer } + // For all the following cases, please note: + // Normally these code paths should never be hit. However there is one way to trigger + // this: + // + // If the connection to a server is closed, NIO will forward all outstanding + // `channelRead`s without waiting for a next `context.read` call. After all + // `channelRead`s are delivered, we will also see a `channelReadComplete` call. After + // this has happened, we know that we will get a channelInactive or further + // `channelReads`. If the request ever gets to an `.end` all buffered data will be + // forwarded to the user. + case .waitingForRead, .waitingForDemand, .waitingForReadOrDemand: - preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)") + return nil case .modifying: preconditionFailure("Invalid state: \(self.state)") diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift index b0c57569b..7781d0820 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift @@ -29,6 +29,7 @@ extension HTTP1ClientChannelHandlerTests { ("testWriteBackpressure", testWriteBackpressure), ("testClientHandlerCancelsRequestIfWeWantToShutdown", testClientHandlerCancelsRequestIfWeWantToShutdown), ("testIdleReadTimeout", testIdleReadTimeout), + ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 025c18faf..30a0a287c 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -286,6 +286,64 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertEqual($0 as? HTTPClientError, .readTimeout) } } + + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.executeRequest(requestBag) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertNoThrow(try embedded.receiveEnd()) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "50")])) + + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + + // not sending anything after the head should lead to request fail and connection close + embedded.pipeline.fireChannelReadComplete() + embedded.pipeline.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2) + + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(ByteBuffer(string: "foo bar")))) + embedded.pipeline.fireChannelReadComplete() + // We miss a `embedded.pipeline.read()` here by purpose. + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2) + + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(ByteBuffer(string: "last bytes")))) + embedded.pipeline.fireChannelReadComplete() + embedded.pipeline.fireChannelInactive() + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .remoteConnectionClosed) + } + } } class TestBackpressureWriter { diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift index e3b5c72ac..dacaf1a67 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -56,6 +56,10 @@ extension HTTPRequestStateMachineTests { ("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown), ("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt), ("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt), + ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), + ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead), + ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand), + ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 2e3eccd77..74c538de0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -557,6 +557,101 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.errorHappened(HTTPParserError.invalidEOFState), .failRequest(HTTPParserError.invalidEOFState, .close)) XCTAssertEqual(state.channelInactive(), .wait) } + + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() { + var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) + let body = ByteBuffer(string: "foo bar") + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.read(), .read) + XCTAssertEqual(state.channelRead(.body(body)), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) + XCTAssertEqual(state.read(), .wait) + + XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + } + + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead() { + var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) + let body = ByteBuffer(string: "foo bar") + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.read(), .read) + XCTAssertEqual(state.channelRead(.body(body)), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) + XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) + + XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + } + + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand() { + var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) + let body = ByteBuffer(string: "foo bar") + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.read(), .read) + XCTAssertEqual(state.channelRead(.body(body)), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) + + XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + } + + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes() { + var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) + let body = ByteBuffer(string: "foo bar") + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.read(), .read) + XCTAssertEqual(state.channelRead(.body(body)), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) + + let part1 = ByteBuffer(string: "baz lightyear") + XCTAssertEqual(state.channelRead(.body(part1)), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + + let part2 = ByteBuffer(string: "nearly last") + XCTAssertEqual(state.channelRead(.body(part2)), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + + let part3 = ByteBuffer(string: "final message") + XCTAssertEqual(state.channelRead(.body(part3)), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + + XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [part1, part2, part3])) + XCTAssertEqual(state.channelReadComplete(), .wait) + + XCTAssertEqual(state.channelInactive(), .wait) + } } extension HTTPRequestStateMachine.Action: Equatable {