diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift index 012544441..32cbe500c 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ConnectionStateMachine.swift @@ -261,7 +261,9 @@ struct HTTP1ConnectionStateMachine { let action = requestStateMachine.channelRead(part) if case .head(let head) = part, close == false { - close = !head.isKeepAlive + // since the HTTPClient does not support protocol switching, we must close any + // connection that has received a status `.switchingProtocols` + close = !head.isKeepAlive || head.status == .switchingProtocols } state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index b7d05e415..655123053 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -450,8 +450,10 @@ struct HTTPRequestStateMachine { } private mutating func receivedHTTPResponseHead(_ head: HTTPResponseHead) -> Action { - guard head.status.code >= 200 else { - // we ignore any leading 1xx headers... No state change needed. + guard head.status.code >= 200 || head.status == .switchingProtocols else { + // We ignore any leading 1xx headers except for 101 (switching protocols). The + // HTTP1ConnectionStateMachine ensures the connection close for 101 after the `.end` is + // received. return .wait } @@ -527,7 +529,13 @@ struct HTTPRequestStateMachine { preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") + // If we receive a http response header with a status code of 1xx, we ignore the header + // except for 101, which we consume. + // If the remote closes the connection after sending a 1xx (not 101) response head, we + // will receive a response end from the parser. We need to protect against this case. + let error = HTTPClientError.httpEndReceivedAfterHeadWith1xx + self.state = .failed(error) + return .failRequest(error, .close) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, var responseStreamState)) where head.status.code < 300: diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 0edac6b16..2c31392ac 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -901,6 +901,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case requestStreamCancelled case getConnectionFromPoolTimeout case deadlineExceeded + case httpEndReceivedAfterHeadWith1xx } private var code: Code @@ -983,4 +984,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// - A connection could not be created within the timout period. /// - Tasks are not processed fast enough on the existing connections, to process all waiters in time public static let getConnectionFromPoolTimeout = HTTPClientError(code: .getConnectionFromPoolTimeout) + + public static let httpEndReceivedAfterHeadWith1xx = HTTPClientError(code: .httpEndReceivedAfterHeadWith1xx) } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift index 17e45387a..75bc6c017 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift @@ -40,6 +40,8 @@ extension HTTP1ConnectionStateMachineTests { ("testChannelReadsAreIgnoredIfConnectionIsClosing", testChannelReadsAreIgnoredIfConnectionIsClosing), ("testRequestIsCancelledWhileWaitingForWritable", testRequestIsCancelledWhileWaitingForWritable), ("testConnectionIsClosedIfErrorHappensWhileInRequest", testConnectionIsClosedIfErrorHappensWhileInRequest), + ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), + ("testWeDontCrashAfterEarlyHintsAndConnectionClose", testWeDontCrashAfterEarlyHintsAndConnectionClose), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index d3e6a37a4..2f5117e91 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -243,6 +243,30 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let decompressionError = NIOHTTPDecompression.DecompressionError.limit XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close)) } + + func testConnectionIsClosedAfterSwitchingProtocols() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata, ignoreUncleanSSLShutdown: false) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + let responseHead = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) + } + + func testWeDontCrashAfterEarlyHintsAndConnectionClose() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata, ignoreUncleanSSLShutdown: false) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + let responseHead = HTTPResponseHead(version: .http1_1, status: .init(statusCode: 103, reasonPhrase: "Early Hints")) + XCTAssertEqual(state.channelRead(.head(responseHead)), .wait) + XCTAssertEqual(state.channelRead(.end(nil)), .failRequest(HTTPClientError.httpEndReceivedAfterHeadWith1xx, .close)) + } } extension HTTP1ConnectionStateMachine.Action: Equatable { diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift index 35c510e4f..b979690d3 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift @@ -32,6 +32,8 @@ extension HTTP1ConnectionTests { ("testConnectionClosesOnCloseHeader", testConnectionClosesOnCloseHeader), ("testConnectionClosesOnRandomlyAppearingCloseHeader", testConnectionClosesOnRandomlyAppearingCloseHeader), ("testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader", testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader), + ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), + ("testConnectionDoesntCrashAfterConnectionCloseAndEarlyHints", testConnectionDoesntCrashAfterConnectionCloseAndEarlyHints), ("testDownloadStreamingBackpressure", testDownloadStreamingBackpressure), ] } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift index 6fc8f3e94..825a65d89 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -365,6 +365,131 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(httpBin.activeConnections, 0) } + func testConnectionIsClosedAfterSwitchingProtocols() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http1.connection") + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + var maybeConnection: HTTP1Connection? + let connectionDelegate = MockConnectionDelegate() + XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + configuration: .init(decompression: .enabled(limit: .ratio(4))), + logger: logger + )) + guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://swift.org/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + connection.executeRequest(requestBag) + + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + + let responseString = """ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ + Connection: upgrade\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\nfoo bar baz + """ + + XCTAssertTrue(embedded.isActive) + XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) + XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) + XCTAssertFalse(embedded.isActive) + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) + + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try requestBag.task.futureResult.wait()) + XCTAssertEqual(response?.status, .switchingProtocols) + XCTAssertEqual(response?.headers.count, 4) + XCTAssertEqual(response?.body, nil) + } + + func testConnectionDoesntCrashAfterConnectionCloseAndEarlyHints() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http1.connection") + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + var maybeConnection: HTTP1Connection? + let connectionDelegate = MockConnectionDelegate() + XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + configuration: .init(decompression: .enabled(limit: .ratio(4))), + logger: logger + )) + guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://swift.org/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + connection.executeRequest(requestBag) + + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + + let responseString = """ + HTTP/1.1 103 Early Hints\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ + + XCTAssertTrue(embedded.isActive) + XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) + XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) + XCTAssertFalse(embedded.isActive) + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .httpEndReceivedAfterHeadWith1xx) + } + } + // In order to test backpressure we need to make sure that reads will not happen // until the backpressure promise is succeeded. Since we cannot guarantee when // messages will be delivered to a client pipeline and we need this test to be