diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift new file mode 100644 index 000000000..0fa5c0be8 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift @@ -0,0 +1,470 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP1 + +final class HTTP1ClientChannelHandler: ChannelDuplexHandler { + typealias OutboundIn = HTTPExecutableRequest + typealias OutboundOut = HTTPClientRequestPart + typealias InboundIn = HTTPClientResponsePart + + private var state: HTTP1ConnectionStateMachine = .init() { + didSet { + self.eventLoop.assertInEventLoop() + } + } + + /// while we are in a channel pipeline, this context can be used. + private var channelContext: ChannelHandlerContext? + + /// the currently executing request + private var request: HTTPExecutableRequest? { + didSet { + if let request = request { + var requestLogger = request.logger + requestLogger[metadataKey: "ahc-connection-id"] = "\(self.connection.id)" + self.logger = requestLogger + } else { + self.logger = self.backgroundLogger + } + } + } + + private var idleReadTimeoutStateMachine: IdleReadStateMachine? + private var idleReadTimeoutTimer: Scheduled? + + private let backgroundLogger: Logger + private var logger: Logger + + let connection: HTTP1Connection + let eventLoop: EventLoop + + init(connection: HTTP1Connection, eventLoop: EventLoop, logger: Logger) { + self.connection = connection + self.eventLoop = eventLoop + self.backgroundLogger = logger + self.logger = self.backgroundLogger + } + + func handlerAdded(context: ChannelHandlerContext) { + self.channelContext = context + + if context.channel.isActive { + let action = self.state.channelActive(isWritable: context.channel.isWritable) + self.run(action, context: context) + } + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.channelContext = nil + } + + // MARK: Channel Inbound Handler + + func channelActive(context: ChannelHandlerContext) { + self.logger.trace("Channel active", metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)", + ]) + + let action = self.state.channelActive(isWritable: context.channel.isWritable) + self.run(action, context: context) + } + + func channelInactive(context: ChannelHandlerContext) { + self.logger.trace("Channel inactive") + + let action = self.state.channelInactive() + self.run(action, context: context) + } + + func channelWritabilityChanged(context: ChannelHandlerContext) { + self.logger.trace("Channel writability changed", metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)", + ]) + + let action = self.state.writabilityChanged(writable: context.channel.isWritable) + self.run(action, context: context) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let httpPart = unwrapInboundIn(data) + + self.logger.trace("HTTP response part received", metadata: [ + "ahc-http-part": "\(httpPart)", + ]) + + if let timeoutAction = self.idleReadTimeoutStateMachine?.channelRead(httpPart) { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.channelRead(httpPart) + self.run(action, context: context) + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.logger.trace("Read complete caught") + + let action = self.state.channelReadComplete() + self.run(action, context: context) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + assert(self.request == nil, "Only write to the ChannelHandler if you are sure, it is idle!") + let req = self.unwrapOutboundIn(data) + self.request = req + + self.logger.trace("New request to execute") + + if let idleReadTimeout = self.request?.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } + + req.willExecuteRequest(self) + + let action = self.state.runNewRequest(head: req.requestHead, metadata: req.requestFramingMetadata) + self.run(action, context: context) + } + + func read(context: ChannelHandlerContext) { + self.logger.trace("Read event caught") + + let action = self.state.read() + self.run(action, context: context) + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.logger.trace("Error caught", metadata: [ + "error": "\(error)", + ]) + + let action = self.state.errorHappened(error) + self.run(action, context: context) + } + + func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + switch event { + case HTTPConnectionEvent.cancelRequest: + self.logger.trace("User outbound event triggered: Cancel request for connection close") + let action = self.state.requestCancelled(closeConnection: true) + self.run(action, context: context) + default: + context.fireUserInboundEventTriggered(event) + } + } + + // MARK: - Private Methods - + + // MARK: Run Actions + + private func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .sendRequestHead(let head, startBody: let startBody): + if startBody { + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.flush() + + self.request!.requestHeadSent() + self.request!.resumeRequestBodyStream() + } else { + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + + self.request!.requestHeadSent() + + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + } + + case .sendBodyPart(let part): + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: nil) + + case .sendRequestEnd: + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + + case .pauseRequestBodyStream: + self.request!.pauseRequestBodyStream() + + case .resumeRequestBodyStream: + self.request!.resumeRequestBodyStream() + + case .fireChannelActive: + context.fireChannelActive() + + case .fireChannelInactive: + context.fireChannelInactive() + + case .fireChannelError(let error, let close): + context.fireErrorCaught(error) + if close { + context.close(promise: nil) + } + + case .read: + context.read() + + case .close: + context.close(promise: nil) + + case .wait: + break + + case .forwardResponseHead(let head, let pauseRequestBodyStream): + self.request!.receiveResponseHead(head) + if pauseRequestBodyStream { + self.request!.pauseRequestBodyStream() + } + + case .forwardResponseBodyParts(let buffer): + self.request!.receiveResponseBodyParts(buffer) + + case .succeedRequest(let finalAction, let buffer): + // The order here is very important... + // We first nil our own task property! `taskCompleted` will potentially lead to + // situations in which we get a new request right away. We should finish the task + // after the connection was notified, that we finished. A + // `HTTPClient.shutdown(requiresCleanShutdown: true)` will fail if we do it the + // other way around. + + let oldRequest = self.request! + self.request = nil + self.idleReadTimeoutStateMachine = nil + + switch finalAction { + case .close: + context.close(promise: nil) + case .sendRequestEnd: + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .informConnectionIsIdle: + self.connection.taskCompleted() + case .none: + break + } + + oldRequest.succeedRequest(buffer) + + case .failRequest(let error, let finalAction): + // see comment in the `succeedRequest` case. + let oldRequest = self.request! + self.request = nil + self.idleReadTimeoutStateMachine = nil + + switch finalAction { + case .close: + context.close(promise: nil) + case .sendRequestEnd: + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .informConnectionIsIdle: + self.connection.taskCompleted() + case .none: + break + } + + oldRequest.fail(error) + } + } + + private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleReadTimeoutTimer(let timeAmount): + assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + let action = self.state.idleReadTimeoutTriggered() + self.run(action, context: context) + } + + case .resetIdleReadTimeoutTimer(let timeAmount): + if let oldTimer = self.idleReadTimeoutTimer { + oldTimer.cancel() + } + + self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + let action = self.state.idleReadTimeoutTriggered() + self.run(action, context: context) + } + + case .clearIdleReadTimeoutTimer: + if let oldTimer = self.idleReadTimeoutTimer { + self.idleReadTimeoutTimer = nil + oldTimer.cancel() + } + + case .none: + break + } + } + + // MARK: Private HTTPRequestExecutor + + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, + // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after + // the request has been popped by the state machine or the ChannelHandler has been + // removed from the Channel pipeline. This is a normal threading issue, noone has + // screwed up. + return + } + + let action = self.state.requestStreamPartReceived(data) + self.run(action, context: context) + } + + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + let action = self.state.requestStreamFinished() + self.run(action, context: context) + } + + private func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + self.logger.trace("Downstream requests more response body data") + + let action = self.state.demandMoreResponseBodyParts() + self.run(action, context: context) + } + + private func cancelRequest0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + self.logger.trace("Request was cancelled") + + let action = self.state.requestCancelled(closeConnection: true) + self.run(action, context: context) + } +} + +extension HTTP1ClientChannelHandler: HTTPRequestExecutor { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.writeRequestBodyPart0(data, request: request) + } else { + self.eventLoop.execute { + self.writeRequestBodyPart0(data, request: request) + } + } + } + + func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.finishRequestBodyStream0(request) + } else { + self.eventLoop.execute { + self.finishRequestBodyStream0(request) + } + } + } + + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.demandResponseBodyStream0(request) + } else { + self.eventLoop.execute { + self.demandResponseBodyStream0(request) + } + } + } + + func cancelRequest(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.cancelRequest0(request) + } else { + self.eventLoop.execute { + self.cancelRequest0(request) + } + } + } +} + +struct IdleReadStateMachine { + enum Action { + case startIdleReadTimeoutTimer(TimeAmount) + case resetIdleReadTimeoutTimer(TimeAmount) + case clearIdleReadTimeoutTimer + case none + } + + enum State { + case waitingForRequestEnd + case waitingForMoreResponseData + case responseEndReceived + } + + private var state: State = .waitingForRequestEnd + private let timeAmount: TimeAmount + + init(timeAmount: TimeAmount) { + self.timeAmount = timeAmount + } + + mutating func requestEndSent() -> Action { + switch self.state { + case .waitingForRequestEnd: + self.state = .waitingForMoreResponseData + return .startIdleReadTimeoutTimer(self.timeAmount) + + case .waitingForMoreResponseData: + preconditionFailure("Invalid state. Waiting for response data must start after request head was sent") + + case .responseEndReceived: + // the response end was received, before we send the request head. Idle timeout timer + // must never be started. + return .none + } + } + + mutating func channelRead(_ part: HTTPClientResponsePart) -> Action { + switch self.state { + case .waitingForRequestEnd: + switch part { + case .head, .body: + return .none + case .end: + self.state = .responseEndReceived + return .none + } + + case .waitingForMoreResponseData: + switch part { + case .head, .body: + return .resetIdleReadTimeoutTimer(self.timeAmount) + case .end: + self.state = .responseEndReceived + return .clearIdleReadTimeoutTimer + } + + case .responseEndReceived: + preconditionFailure("How can we receive more data, if we already received the response end?") + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1Connection.swift new file mode 100644 index 000000000..575747c17 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1Connection.swift @@ -0,0 +1,132 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP1 +import NIOHTTPCompression + +protocol HTTP1ConnectionDelegate { + func http1ConnectionReleased(_: HTTP1Connection) + func http1ConnectionClosed(_: HTTP1Connection) +} + +final class HTTP1Connection { + let channel: Channel + + /// the connection's delegate, that will be informed about connection close and connection release + /// (ready to run next request). + let delegate: HTTP1ConnectionDelegate + + enum State { + case initialized + case active + case closed + } + + private var state: State = .initialized + + let id: HTTPConnectionPool.Connection.ID + + init(channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP1ConnectionDelegate) { + self.channel = channel + self.id = connectionID + self.delegate = delegate + } + + deinit { + guard case .closed = self.state else { + preconditionFailure("Connection must be closed, before we can deinit it") + } + } + + static func start( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP1ConnectionDelegate, + configuration: HTTPClient.Configuration, + logger: Logger + ) throws -> HTTP1Connection { + let connection = HTTP1Connection(channel: channel, connectionID: connectionID, delegate: delegate) + try connection.start(configuration: configuration, logger: logger) + return connection + } + + func execute(request: HTTPExecutableRequest) { + if self.channel.eventLoop.inEventLoop { + self.execute0(request: request) + } else { + self.channel.eventLoop.execute { + self.execute0(request: request) + } + } + } + + func cancel() { + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.cancelRequest, promise: nil) + } + + func close() -> EventLoopFuture { + return self.channel.close() + } + + func taskCompleted() { + self.delegate.http1ConnectionReleased(self) + } + + private func execute0(request: HTTPExecutableRequest) { + guard self.channel.isActive else { + return request.fail(ChannelError.ioOnClosedChannel) + } + + self.channel.write(request, promise: nil) + } + + private func start(configuration: HTTPClient.Configuration, logger: Logger) throws { + self.channel.eventLoop.assertInEventLoop() + + guard case .initialized = self.state else { + preconditionFailure("Connection must be initialized, to start it") + } + + self.state = .active + self.channel.closeFuture.whenComplete { _ in + self.state = .closed + self.delegate.http1ConnectionClosed(self) + } + + do { + let sync = self.channel.pipeline.syncOperations + try sync.addHTTPClientHandlers() + + if case .enabled(let limit) = configuration.decompression { + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + try sync.addHandler(decompressHandler) + } + + let channelHandler = HTTP1ClientChannelHandler( + connection: self, + eventLoop: channel.eventLoop, + logger: logger + ) + + try sync.addHandler(channelHandler) + } catch { + self.channel.close(mode: .all, promise: nil) + throw error + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift new file mode 100644 index 000000000..4e6d563e6 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +enum HTTPConnectionEvent { + case cancelRequest +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index ae9ec7547..3dec4ea33 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -120,7 +120,7 @@ import NIOHTTP1 /// - The executor may have received an error in thread A that it passes along to the request. /// After having passed on the error, the executor considers the request done and releases /// the request's reference. -/// - The request may issue a call to `writeRequestBodyPart(_: IOData, task: HTTPExecutingRequest)` +/// - The request may issue a call to `writeRequestBodyPart(_: IOData, task: HTTPExecutableRequest)` /// on thread B in the same moment the request error above occurred. For this reason it may /// happen that the executor receives, the invocation of `writeRequestBodyPart` after it has /// failed the request. @@ -187,6 +187,9 @@ protocol HTTPRequestExecutor { } protocol HTTPExecutableRequest: AnyObject { + /// The request's logger + var logger: Logger { get } + /// The request's head. /// /// The HTTP request head, that shall be sent. The HTTPRequestExecutor **will not** run any validation diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 08ff5a539..34b42f8fb 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -497,9 +497,9 @@ extension RequestBag.StateMachine { case .executing(let executor, let requestState, .buffering(_, next: .eof)): self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) return .cancelExecutor(executor) - case .executing(let executor, let requestState, .buffering(_, next: .askExecutorForMore)): - self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) - return .cancelExecutor(executor) + case .executing(let executor, _, .buffering(_, next: .askExecutorForMore)): + self.state = .finished(error: error) + return .failTask(nil, executor) case .executing(let executor, _, .buffering(_, next: .error(_))): // this would override another error, let's keep the first one return .cancelExecutor(executor) diff --git a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift new file mode 100644 index 000000000..ab9bcdbeb --- /dev/null +++ b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift @@ -0,0 +1,114 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP1 + +extension EmbeddedChannel { + public func receiveHeadAndVerify(_ verify: (HTTPRequestHead) throws -> Void = { _ in }) throws { + let part = try self.readOutbound(as: HTTPClientRequestPart.self) + switch part { + case .head(let head): + try verify(head) + case .body, .end: + throw HTTP1EmbeddedChannelError(reason: "Expected .head but got '\(part!)'") + case .none: + throw HTTP1EmbeddedChannelError(reason: "Nothing in buffer") + } + } + + public func receiveBodyAndVerify(_ verify: (IOData) throws -> Void = { _ in }) throws { + let part = try self.readOutbound(as: HTTPClientRequestPart.self) + switch part { + case .body(let iodata): + try verify(iodata) + case .head, .end: + throw HTTP1EmbeddedChannelError(reason: "Expected .head but got '\(part!)'") + case .none: + throw HTTP1EmbeddedChannelError(reason: "Nothing in buffer") + } + } + + public func receiveEnd() throws { + let part = try self.readOutbound(as: HTTPClientRequestPart.self) + switch part { + case .end: + break + case .head, .body: + throw HTTP1EmbeddedChannelError(reason: "Expected .head but got '\(part!)'") + case .none: + throw HTTP1EmbeddedChannelError(reason: "Nothing in buffer") + } + } +} + +struct HTTP1TestTools { + let connection: HTTP1Connection + let connectionDelegate: MockConnectionDelegate + let readEventHandler: ReadEventHitHandler + let logger: Logger +} + +extension EmbeddedChannel { + func setupHTTP1Connection() throws -> HTTP1TestTools { + let logger = Logger(label: "test") + let readEventHandler = ReadEventHitHandler() + + try self.pipeline.syncOperations.addHandler(readEventHandler) + try self.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait() + + let connectionDelegate = MockConnectionDelegate() + let connection = try HTTP1Connection.start( + channel: self, + connectionID: 1, + delegate: connectionDelegate, + configuration: .init(), + logger: logger + ) + + // remove HTTP client encoder and decoder + + let decoder = try self.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + let encoder = try self.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self) + + let removeDecoderFuture = self.pipeline.removeHandler(decoder) + let removeEncoderFuture = self.pipeline.removeHandler(encoder) + + self.embeddedEventLoop.run() + + try removeDecoderFuture.wait() + try removeEncoderFuture.wait() + + return .init( + connection: connection, + connectionDelegate: connectionDelegate, + readEventHandler: readEventHandler, + logger: logger + ) + } +} + +public struct HTTP1EmbeddedChannelError: Error, Hashable, CustomStringConvertible { + public var reason: String + + public init(reason: String) { + self.reason = reason + } + + public var description: String { + return self.reason + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift new file mode 100644 index 000000000..b0c57569b --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTP1ClientChannelHandlerTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTP1ClientChannelHandlerTests { + static var allTests: [(String, (HTTP1ClientChannelHandlerTests) -> () throws -> Void)] { + return [ + ("testResponseBackpressure", testResponseBackpressure), + ("testWriteBackpressure", testWriteBackpressure), + ("testClientHandlerCancelsRequestIfWeWantToShutdown", testClientHandlerCancelsRequestIfWeWantToShutdown), + ("testIdleReadTimeout", testIdleReadTimeout), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift new file mode 100644 index 000000000..77457a2ad --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -0,0 +1,481 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP1 +import XCTest + +class HTTP1ClientChannelHandlerTests: XCTestCase { + func testResponseBackpressure() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.execute(request: requestBag) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + + let part0 = ByteBuffer(bytes: 0...3) + let part1 = ByteBuffer(bytes: 4...7) + let part2 = ByteBuffer(bytes: 8...11) + + // part 0. Demand first, read second + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1) + let part0Future = delegate.next() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part0))) + XCTAssertEqual(try part0Future.wait(), part0) + + // part 1. read first, demand second + + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2) + let part1Future = delegate.next() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 3) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part1))) + XCTAssertEqual(try part1Future.wait(), part1) + + // part 2. Demand first, read second + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 3) + let part2Future = delegate.next() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 3) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 4) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part2))) + XCTAssertEqual(try part2Future.wait(), part2) + + // end. read first, demand second + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 4) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 4) + let endFuture = delegate.next() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 5) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 1) + XCTAssertEqual(try endFuture.wait(), .none) + + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + } + + func testWriteBackpressure() { + class TestWriter { + let eventLoop: EventLoop + + let parts: Int + + var finishFuture: EventLoopFuture { self.finishPromise.futureResult } + private let finishPromise: EventLoopPromise + private(set) var written: Int = 0 + + private var channelIsWritable: Bool = false + + init(eventLoop: EventLoop, parts: Int) { + self.eventLoop = eventLoop + self.parts = parts + + self.finishPromise = eventLoop.makePromise(of: Void.self) + } + + func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + func recursive() { + XCTAssert(self.eventLoop.inEventLoop) + XCTAssert(self.channelIsWritable) + if self.written == self.parts { + self.finishPromise.succeed(()) + } else { + self.eventLoop.execute { + let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) + self.written += 1 + future.whenComplete { result in + switch result { + case .success: + recursive() + case .failure(let error): + XCTFail("Unexpected error: \(error)") + } + } + } + } + } + + recursive() + + return self.finishFuture + } + + func writabilityChanged(_ newValue: Bool) { + self.channelIsWritable = newValue + } + } + + let embedded = EmbeddedChannel() + let testWriter = TestWriter(eventLoop: embedded.eventLoop, parts: 50) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in + testWriter.start(writer: writer) + })) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: .milliseconds(200), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + // the handler only writes once the channel is writable + embedded.isWritable = false + testWriter.writabilityChanged(false) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.execute(request: requestBag) + + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + }) + + // the next body write will be executed once we tick the el. before we make the channel + // unwritable + + for index in 0..<50 { + embedded.isWritable = false + testWriter.writabilityChanged(false) + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertEqual(testWriter.written, index) + + embedded.embeddedEventLoop.run() + + XCTAssertNoThrow(try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + }) + + XCTAssertEqual(testWriter.written, index + 1) + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + } + + embedded.embeddedEventLoop.run() + XCTAssertNoThrow(try embedded.receiveEnd()) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + embedded.read() + + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 1) + + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + } + + func testClientHandlerCancelsRequestIfWeWantToShutdown() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: .milliseconds(200), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.execute(request: requestBag) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertNoThrow(try embedded.receiveEnd()) + + XCTAssertTrue(embedded.isActive) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + testUtils.connection.cancel() + XCTAssertFalse(embedded.isActive) + embedded.embeddedEventLoop.run() + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + func testIdleReadTimeout() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: .milliseconds(200), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.execute(request: requestBag) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertNoThrow(try embedded.receiveEnd()) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) + embedded.read() + XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + + // not sending anything after the head should lead to request fail and connection close + + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .readTimeout) + } + } +} + +class ResponseBackpressureDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + enum State { + case consuming(EventLoopPromise) + case waitingForRemote(CircularBuffer>) + case buffering((ByteBuffer?, EventLoopPromise)?) + case done + } + + let eventLoop: EventLoop + private var state: State = .buffering(nil) + + init(eventLoop: EventLoop) { + self.eventLoop = eventLoop + + self.state = .consuming(self.eventLoop.makePromise(of: Void.self)) + } + + func next() -> EventLoopFuture { + switch self.state { + case .consuming(let backpressurePromise): + var promiseBuffer = CircularBuffer>() + let newPromise = self.eventLoop.makePromise(of: ByteBuffer?.self) + promiseBuffer.append(newPromise) + self.state = .waitingForRemote(promiseBuffer) + backpressurePromise.succeed(()) + return newPromise.futureResult + + case .waitingForRemote(var promiseBuffer): + assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) + promiseBuffer.append(promise) + self.state = .waitingForRemote(promiseBuffer) + return promise.futureResult + + case .buffering(.none): + var promiseBuffer = CircularBuffer>() + let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) + promiseBuffer.append(promise) + self.state = .waitingForRemote(promiseBuffer) + return promise.futureResult + + case .buffering(.some((let buffer, let promise))): + self.state = .buffering(nil) + promise.succeed(()) + return self.eventLoop.makeSucceededFuture(buffer) + + case .done: + return self.eventLoop.makeSucceededFuture(.none) + } + } + + func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { + switch self.state { + case .consuming(let backpressurePromise): + return backpressurePromise.futureResult + + case .waitingForRemote: + return self.eventLoop.makeSucceededVoidFuture() + + case .buffering, .done: + preconditionFailure("State must be either waitingForRemote or initialized") + } + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + switch self.state { + case .waitingForRemote(var promiseBuffer): + assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + let promise = promiseBuffer.removeFirst() + if promiseBuffer.isEmpty { + let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self) + self.state = .consuming(newBackpressurePromise) + promise.succeed(buffer) + return newBackpressurePromise.futureResult + } else { + self.state = .waitingForRemote(promiseBuffer) + promise.succeed(buffer) + return self.eventLoop.makeSucceededVoidFuture() + } + + case .buffering(.none): + let promise = self.eventLoop.makePromise(of: Void.self) + self.state = .buffering((buffer, promise)) + return promise.futureResult + + case .buffering(.some): + preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + + case .done, .consuming: + preconditionFailure("Invalid state: \(self.state)") + } + } + + func didFinishRequest(task: HTTPClient.Task) throws { + switch self.state { + case .waitingForRemote(let promiseBuffer): + promiseBuffer.forEach { + $0.succeed(.none) + } + self.state = .done + + case .buffering(.none): + self.state = .done + + case .done, .consuming: + preconditionFailure("Invalid state: \(self.state)") + + case .buffering(.some): + preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + } + } +} + +class ReadEventHitHandler: ChannelOutboundHandler { + public typealias OutboundIn = NIOAny + + private(set) var readHitCounter = 0 + + public init() {} + + public func read(context: ChannelHandlerContext) { + self.readHitCounter += 1 + context.read() + } +} + +class MockConnectionDelegate: HTTP1ConnectionDelegate { + private(set) var hitConnectionReleased = 0 + private(set) var hitConnectionClosed = 0 + + init() {} + + func http1ConnectionReleased(_: HTTP1Connection) { + self.hitConnectionReleased += 1 + } + + func http1ConnectionClosed(_: HTTP1Connection) { + self.hitConnectionClosed += 1 + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift new file mode 100644 index 000000000..8ff56e3e4 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTP1ConnectionTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTP1ConnectionTests { + static var allTests: [(String, (HTTP1ConnectionTests) -> () throws -> Void)] { + return [ + ("testCreateNewConnectionWithDecompression", testCreateNewConnectionWithDecompression), + ("testCreateNewConnectionWithoutDecompression", testCreateNewConnectionWithoutDecompression), + ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), + ("testGETRequest", testGETRequest), + ("testConnectionClosesOnCloseHeader", testConnectionClosesOnCloseHeader), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift new file mode 100644 index 000000000..2b8be194a --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -0,0 +1,302 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP1 +import NIOHTTPCompression +import NIOTestUtils +import XCTest + +class HTTP1ConnectionTests: XCTestCase { + func testCreateNewConnectionWithDecompression() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http1.connection") + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + var connection: HTTP1Connection? + XCTAssertNoThrow(connection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + configuration: .init(decompression: .enabled(limit: .ratio(4))), + logger: logger + )) + + XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) + XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) + XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) + + XCTAssertNoThrow(try connection?.close().wait()) + embedded.embeddedEventLoop.run() + XCTAssert(!embedded.isActive) + } + + func testCreateNewConnectionWithoutDecompression() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http1.connection") + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + XCTAssertNoThrow(try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + configuration: .init(decompression: .disabled), + logger: logger + )) + + XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) + XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) + XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { error in + XCTAssertEqual(error as? ChannelPipelineError, .notFound) + } + } + + func testCreateNewConnectionFailureClosedIO() { + let embedded = EmbeddedChannel() + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + XCTAssertNoThrow(try embedded.close().wait()) + // to really destroy the channel we need to tick once + embedded.embeddedEventLoop.run() + let logger = Logger(label: "test.http1.connection") + + XCTAssertThrowsError(try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + configuration: .init(), + logger: logger + )) + } + + func testGETRequest() { + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 2) + let clientEL = elg.next() + let serverEL = elg.next() + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + let server = NIOHTTP1TestServer(group: serverEL) + defer { XCTAssertNoThrow(try server.stop()) } + + let logger = Logger(label: "test") + let delegate = MockHTTP1ConnectionDelegate() + delegate.closePromise = clientEL.makePromise(of: Void.self) + + let connection = try! ClientBootstrap(group: clientEL) + .connect(to: .init(ipAddress: "127.0.0.1", port: server.serverPort)) + .flatMapThrowing { + try HTTP1Connection.start( + channel: $0, + connectionID: 0, + delegate: delegate, + configuration: .init(decompression: .disabled), + logger: logger + ) + } + .wait() + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( + url: "http://localhost/hello/swift", + method: .POST, + body: .stream(length: 4) { writer -> EventLoopFuture in + func recursive(count: UInt8, promise: EventLoopPromise) { + guard count < 4 else { + return promise.succeed(()) + } + + writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in + switch result { + case .failure(let error): + XCTFail("Unexpected error: \(error)") + case .success: + recursive(count: count + 1, promise: promise) + } + } + } + + let promise = clientEL.makePromise(of: Void.self) + recursive(count: 0, promise: promise) + return promise.futureResult + } + )) + + guard let request = maybeRequest else { + return XCTFail("Expected to have a connection and a request") + } + + let task = HTTPClient.Task(eventLoop: clientEL, logger: logger) + + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: clientEL), + task: task, + redirectHandler: nil, + connectionDeadline: .now() + .seconds(60), + idleReadTimeout: nil, + delegate: ResponseAccumulator(request: request) + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + connection.execute(request: requestBag) + + XCTAssertNoThrow(try server.receiveHeadAndVerify { head in + XCTAssertEqual(head.method, .POST) + XCTAssertEqual(head.uri, "/hello/swift") + XCTAssertEqual(head.headers["content-length"].first, "4") + }) + + var received: UInt8 = 0 + while received < 4 { + XCTAssertNoThrow(try server.receiveBodyAndVerify { body in + var body = body + while let read = body.readInteger(as: UInt8.self) { + XCTAssertEqual(received, read) + received += 1 + } + }) + } + XCTAssertEqual(received, 4) + XCTAssertNoThrow(try server.receiveEnd()) + + XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3]))))) + XCTAssertNoThrow(try server.writeOutbound(.end(nil))) + + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try task.futureResult.wait()) + + XCTAssertEqual(response?.body, ByteBuffer(bytes: [0, 1, 2, 3])) + + // connection is closed + XCTAssertNoThrow(try XCTUnwrap(delegate.closePromise).futureResult.wait()) + } + + func testConnectionClosesOnCloseHeader() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let closeOnRequest = (30...100).randomElement()! + let httpBin = HTTPBin(handlerFactory: { _ in SuddenlySendsCloseHeaderChannel(closeOnRequest: closeOnRequest) }) + + var maybeChannel: Channel? + + XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + let connectionDelegate = MockConnectionDelegate() + let logger = Logger(label: "test") + var maybeConnection: HTTP1Connection? + XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + configuration: .init(), + logger: logger + ) }.wait()) + guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } + + var counter = 0 + while true { + counter += 1 + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + connection.execute(request: requestBag) + + var response: HTTPClient.Response? + if counter <= closeOnRequest { + XCTAssertNoThrow(response = try requestBag.task.futureResult.wait()) + XCTAssertEqual(response?.status, .ok) + + if response?.headers.first(name: "connection") == "close" { + XCTAssertEqual(closeOnRequest, counter) + XCTAssertEqual(maybeChannel?.isActive, false) + } + } else { + // io on close channel leads to error + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? ChannelError, .ioOnClosedChannel) + } + + break // the loop + } + } + } +} + +class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { + var releasePromise: EventLoopPromise? + var closePromise: EventLoopPromise? + + func http1ConnectionReleased(_: HTTP1Connection) { + self.releasePromise?.succeed(()) + } + + func http1ConnectionClosed(_: HTTP1Connection) { + self.closePromise?.succeed(()) + } +} + +class SuddenlySendsCloseHeaderChannel: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + var counter = 1 + let closeOnRequest: Int + + init(closeOnRequest: Int) { + self.closeOnRequest = closeOnRequest + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head(let head): + XCTAssertLessThanOrEqual(self.counter, self.closeOnRequest) + XCTAssertTrue(head.headers.contains(name: "host")) + XCTAssertEqual(head.method, .GET) + case .body: + break + case .end: + if self.closeOnRequest == self.counter { + context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"]))), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + self.counter += 1 + } else { + context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + self.counter += 1 + } + } + } +} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift index 1c069da43..308c8dd07 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift @@ -30,6 +30,7 @@ extension RequestBagTests { ("testCancelFailsTaskBeforeRequestIsSent", testCancelFailsTaskBeforeRequestIsSent), ("testCancelFailsTaskAfterRequestIsSent", testCancelFailsTaskAfterRequestIsSent), ("testCancelFailsTaskWhenTaskIsQueued", testCancelFailsTaskWhenTaskIsQueued), + ("testFailsTaskWhenTaskIsWaitingForMoreFromServer", testFailsTaskWhenTaskIsWaitingForMoreFromServer), ("testHTTPUploadIsCancelledEvenThoughRequestSucceeds", testHTTPUploadIsCancelledEvenThoughRequestSucceeds), ] } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 10ae49527..320b93cfd 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -306,6 +306,40 @@ final class RequestBagTests: XCTestCase { } } + func testFailsTaskWhenTaskIsWaitingForMoreFromServer() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + )) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor() + bag.willExecuteRequest(executor) + bag.requestHeadSent() + bag.receiveResponseHead(.init(version: .http1_1, status: .ok)) + XCTAssertEqual(executor.isCancelled, false) + bag.fail(HTTPClientError.readTimeout) + XCTAssertEqual(executor.isCancelled, true) + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .readTimeout) + } + } + func testHTTPUploadIsCancelledEvenThoughRequestSucceeds() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index a7c7ad1c7..6aac03817 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -28,7 +28,9 @@ import XCTest XCTMain([ testCase(ConnectionPoolTests.allTests), testCase(ConnectionTests.allTests), + testCase(HTTP1ClientChannelHandlerTests.allTests), testCase(HTTP1ConnectionStateMachineTests.allTests), + testCase(HTTP1ConnectionTests.allTests), testCase(HTTP1ProxyConnectHandlerTests.allTests), testCase(HTTPClientCookieTests.allTests), testCase(HTTPClientInternalTests.allTests),