diff --git a/Package.swift b/Package.swift index 3f4a7d186..4bc57f85d 100644 --- a/Package.swift +++ b/Package.swift @@ -34,6 +34,7 @@ let package = Package( dependencies: [ .product(name: "NIO", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOHTTPCompression", package: "swift-nio-extras"), diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift index 0fa5c0be8..d6e64697a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1.1/HTTP1ClientChannelHandler.swift @@ -33,12 +33,17 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { /// the currently executing request private var request: HTTPExecutableRequest? { didSet { - if let request = request { - var requestLogger = request.logger + if let newRequest = self.request { + var requestLogger = newRequest.logger requestLogger[metadataKey: "ahc-connection-id"] = "\(self.connection.id)" self.logger = requestLogger + + if let idleReadTimeout = newRequest.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } } else { self.logger = self.backgroundLogger + self.idleReadTimeoutStateMachine = nil } } } @@ -100,7 +105,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let httpPart = unwrapInboundIn(data) + let httpPart = self.unwrapInboundIn(data) self.logger.trace("HTTP response part received", metadata: [ "ahc-http-part": "\(httpPart)", @@ -121,6 +126,17 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.logger.trace("Error caught", metadata: [ + "error": "\(error)", + ]) + + let action = self.state.errorHappened(error) + self.run(action, context: context) + } + + // MARK: Channel Outbound Handler + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { assert(self.request == nil, "Only write to the ChannelHandler if you are sure, it is idle!") let req = self.unwrapOutboundIn(data) @@ -145,15 +161,6 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.trace("Error caught", metadata: [ - "error": "\(error)", - ]) - - let action = self.state.errorHappened(error) - self.run(action, context: context) - } - func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { switch event { case HTTPConnectionEvent.cancelRequest: @@ -246,7 +253,6 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil - self.idleReadTimeoutStateMachine = nil switch finalAction { case .close: @@ -265,7 +271,6 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // see comment in the `succeedRequest` case. let oldRequest = self.request! self.request = nil - self.idleReadTimeoutStateMachine = nil switch finalAction { case .close: diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift new file mode 100644 index 000000000..74de17bae --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -0,0 +1,339 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP1 +import NIOHTTP2 + +final class HTTP2ClientRequestHandler: ChannelDuplexHandler { + typealias OutboundIn = HTTPExecutableRequest + typealias OutboundOut = HTTPClientRequestPart + typealias InboundIn = HTTPClientResponsePart + + private let eventLoop: EventLoop + + private var state: HTTPRequestStateMachine = .init(isChannelWritable: false) { + willSet { + self.eventLoop.assertInEventLoop() + } + } + + /// while we are in a channel pipeline, this context can be used. + private var channelContext: ChannelHandlerContext? + + private var request: HTTPExecutableRequest? { + didSet { + if let newRequest = self.request, let idleReadTimeout = newRequest.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } else { + self.idleReadTimeoutStateMachine = nil + } + } + } + + private var idleReadTimeoutStateMachine: IdleReadStateMachine? + private var idleReadTimeoutTimer: Scheduled? + + init(eventLoop: EventLoop) { + self.eventLoop = eventLoop + } + + func handlerAdded(context: ChannelHandlerContext) { + assert(context.eventLoop === self.eventLoop, + "The handler must be added to a channel that runs on the eventLoop it was initialized with.") + self.channelContext = context + + let isWritable = context.channel.isActive && context.channel.isWritable + let action = self.state.writabilityChanged(writable: isWritable) + self.run(action, context: context) + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.channelContext = nil + } + + // MARK: Channel Inbound Handler + + func channelActive(context: ChannelHandlerContext) { + let action = self.state.writabilityChanged(writable: context.channel.isWritable) + self.run(action, context: context) + } + + func channelInactive(context: ChannelHandlerContext) { + let action = self.state.channelInactive() + self.run(action, context: context) + } + + func channelWritabilityChanged(context: ChannelHandlerContext) { + let action = self.state.writabilityChanged(writable: context.channel.isWritable) + self.run(action, context: context) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let httpPart = self.unwrapInboundIn(data) + + if let timeoutAction = self.idleReadTimeoutStateMachine?.channelRead(httpPart) { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.channelRead(httpPart) + self.run(action, context: context) + } + + func channelReadComplete(context: ChannelHandlerContext) { + let action = self.state.channelReadComplete() + self.run(action, context: context) + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + let action = self.state.errorHappened(error) + self.run(action, context: context) + } + + // MARK: Channel Outbound Handler + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let request = self.unwrapOutboundIn(data) + // The `HTTPRequestStateMachine` ensures that a `HTTP2ClientRequestHandler` only handles + // a single request. + self.request = request + + request.willExecuteRequest(self) + + let action = self.state.startRequest( + head: request.requestHead, + metadata: request.requestFramingMetadata + ) + self.run(action, context: context) + } + + func read(context: ChannelHandlerContext) { + let action = self.state.read() + self.run(action, context: context) + } + + func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + switch event { + case HTTPConnectionEvent.cancelRequest: + let action = self.state.requestCancelled() + self.run(action, context: context) + default: + context.fireUserInboundEventTriggered(event) + } + } + + // MARK: - Private Methods - + + // MARK: Run Actions + + private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) { + // NOTE: We can bang the request in the following actions, since the `HTTPRequestStateMachine` + // ensures, that actions that require a request are only called, if the request is + // still present. The request is only nilled as a response to a state machine action + // (.failRequest or .succeedRequest). + + switch action { + case .sendRequestHead(let head, let startBody): + if startBody { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) + self.request!.requestHeadSent() + self.request!.resumeRequestBodyStream() + } else { + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + + self.request!.requestHeadSent() + + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + } + + case .pauseRequestBodyStream: + self.request!.pauseRequestBodyStream() + + case .sendBodyPart(let data): + context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: nil) + + case .sendRequestEnd: + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + + case .read: + context.read() + + case .wait: + break + + case .resumeRequestBodyStream: + self.request!.resumeRequestBodyStream() + + case .forwardResponseHead(let head, pauseRequestBodyStream: let pauseRequestBodyStream): + self.request!.receiveResponseHead(head) + if pauseRequestBodyStream { + self.request!.pauseRequestBodyStream() + } + + case .forwardResponseBodyParts(let parts): + self.request!.receiveResponseBodyParts(parts) + + case .failRequest(let error, let finalAction): + self.request!.fail(error) + self.request = nil + self.runFinalAction(finalAction, context: context) + + case .succeedRequest(let finalAction, let finalParts): + self.request!.succeedRequest(finalParts) + self.request = nil + self.runFinalAction(finalAction, context: context) + } + } + + private func runFinalAction(_ action: HTTPRequestStateMachine.Action.FinalStreamAction, context: ChannelHandlerContext) { + switch action { + case .close: + context.close(promise: nil) + + case .sendRequestEnd: + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + case .none: + break + } + } + + private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleReadTimeoutTimer(let timeAmount): + assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + let action = self.state.idleReadTimeoutTriggered() + self.run(action, context: context) + } + + case .resetIdleReadTimeoutTimer(let timeAmount): + if let oldTimer = self.idleReadTimeoutTimer { + oldTimer.cancel() + } + + self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + let action = self.state.idleReadTimeoutTriggered() + self.run(action, context: context) + } + + case .clearIdleReadTimeoutTimer: + if let oldTimer = self.idleReadTimeoutTimer { + self.idleReadTimeoutTimer = nil + oldTimer.cancel() + } + + case .none: + break + } + } + + // MARK: Private HTTPRequestExecutor + + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // Because the HTTPExecutingRequest may run in a different thread to our eventLoop, + // calls from the HTTPExecutingRequest to our ChannelHandler may arrive here after + // the request has been popped by the state machine or the ChannelHandler has been + // removed from the Channel pipeline. This is a normal threading issue, noone has + // screwed up. + return + } + + let action = self.state.requestStreamPartReceived(data) + self.run(action, context: context) + } + + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + let action = self.state.requestStreamFinished() + self.run(action, context: context) + } + + private func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + let action = self.state.demandMoreResponseBodyParts() + self.run(action, context: context) + } + + private func cancelRequest0(_ request: HTTPExecutableRequest) { + guard self.request === request, let context = self.channelContext else { + // See code comment in `writeRequestBodyPart0` + return + } + + let action = self.state.requestCancelled() + self.run(action, context: context) + } +} + +extension HTTP2ClientRequestHandler: HTTPRequestExecutor { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.writeRequestBodyPart0(data, request: request) + } else { + self.eventLoop.execute { + self.writeRequestBodyPart0(data, request: request) + } + } + } + + func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.finishRequestBodyStream0(request) + } else { + self.eventLoop.execute { + self.finishRequestBodyStream0(request) + } + } + } + + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.demandResponseBodyStream0(request) + } else { + self.eventLoop.execute { + self.demandResponseBodyStream0(request) + } + } + } + + func cancelRequest(_ request: HTTPExecutableRequest) { + if self.eventLoop.inEventLoop { + self.cancelRequest0(request) + } else { + self.eventLoop.execute { + self.cancelRequest0(request) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift new file mode 100644 index 000000000..c348ab8c1 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -0,0 +1,301 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP2 + +protocol HTTP2ConnectionDelegate { + func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) + func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) + func http2ConnectionGoAwayReceived(_: HTTP2Connection) + func http2ConnectionClosed(_: HTTP2Connection) +} + +struct HTTP2PushNotSupportedError: Error {} + +struct HTTP2ReceivedGoAwayBeforeSettingsError: Error {} + +final class HTTP2Connection { + let channel: Channel + let multiplexer: HTTP2StreamMultiplexer + let logger: Logger + + /// the connection pool that created the connection + let delegate: HTTP2ConnectionDelegate + + enum State { + case initialized + case starting(EventLoopPromise) + case active(maxStreams: Int) + case closing + case closed + } + + /// A structure to store a http/2 stream channel in a set. + private struct ChannelBox: Hashable { + struct ID: Hashable { + private let id: ObjectIdentifier + + init(_ channel: Channel) { + self.id = ObjectIdentifier(channel) + } + } + + let channel: Channel + + var id: ID { + ID(self.channel) + } + + init(_ channel: Channel) { + self.channel = channel + } + + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.id == rhs.id + } + + func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + } + + private var state: State + + /// We use this channel set to remember, which open streams we need to inform that + /// we want to close the connection. The channels shall than cancel their currently running + /// request. + private var openStreams = Set() + let id: HTTPConnectionPool.Connection.ID + + var closeFuture: EventLoopFuture { + self.channel.closeFuture + } + + init(channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP2ConnectionDelegate, + logger: Logger) { + self.channel = channel + self.id = connectionID + self.logger = logger + self.multiplexer = HTTP2StreamMultiplexer( + mode: .client, + channel: channel, + targetWindowSize: 8 * 1024 * 1024, // 8mb + outboundBufferSizeHighWatermark: 8196, + outboundBufferSizeLowWatermark: 4092, + inboundStreamInitializer: { (channel) -> EventLoopFuture in + channel.eventLoop.makeFailedFuture(HTTP2PushNotSupportedError()) + } + ) + self.delegate = delegate + self.state = .initialized + } + + deinit { + guard case .closed = self.state else { + preconditionFailure("Connection must be closed, before we can deinit it") + } + } + + static func start( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP2ConnectionDelegate, + configuration: HTTPClient.Configuration, + logger: Logger + ) -> EventLoopFuture { + let connection = HTTP2Connection(channel: channel, connectionID: connectionID, delegate: delegate, logger: logger) + return connection.start().map { _ in connection } + } + + func executeRequest(_ request: HTTPExecutableRequest) { + if self.channel.eventLoop.inEventLoop { + self.executeRequest0(request) + } else { + self.channel.eventLoop.execute { + self.executeRequest0(request) + } + } + } + + /// shuts down the connection by cancelling all running tasks and closing the connection once + /// all child streams/channels are closed. + func shutdown() { + if self.channel.eventLoop.inEventLoop { + self.shutdown0() + } else { + self.channel.eventLoop.execute { + self.shutdown0() + } + } + } + + func close() -> EventLoopFuture { + self.channel.close() + } + + private func start() -> EventLoopFuture { + self.channel.eventLoop.assertInEventLoop() + + let readyToAcceptConnectionsPromise = self.channel.eventLoop.makePromise(of: Void.self) + + self.state = .starting(readyToAcceptConnectionsPromise) + self.channel.closeFuture.whenComplete { _ in + self.state = .closed + self.delegate.http2ConnectionClosed(self) + } + + do { + // We create and add the http handlers ourselves here, since we need to inject an + // `HTTP2IdleHandler` between the `NIOHTTP2Handler` and the `HTTP2StreamMultiplexer`. + // The purpose of the `HTTP2IdleHandler` is to count open streams in the multiplexer. + // We use the HTTP2IdleHandler's information to notify our delegate, whether more work + // can be scheduled on this connection. + let sync = self.channel.pipeline.syncOperations + + let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: nioDefaultSettings) + let idleHandler = HTTP2IdleHandler(delegate: self, logger: self.logger) + + try sync.addHandler(http2Handler, position: .last) + try sync.addHandler(idleHandler, position: .last) + try sync.addHandler(self.multiplexer, position: .last) + } catch { + self.channel.close(mode: .all, promise: nil) + readyToAcceptConnectionsPromise.fail(error) + } + + return readyToAcceptConnectionsPromise.futureResult + } + + private func executeRequest0(_ request: HTTPExecutableRequest) { + self.channel.eventLoop.assertInEventLoop() + + switch self.state { + case .initialized, .starting: + preconditionFailure("Invalid state: \(self.state). Sending requests is not allowed before we are started.") + + case .active: + let createStreamChannelPromise = self.channel.eventLoop.makePromise(of: Channel.self) + self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { channel -> EventLoopFuture in + do { + // the connection may have been asked to shutdown while we created the child. in + // this + // channel. + guard case .active = self.state else { + throw HTTPClientError.cancelled + } + + // We only support http/2 over an https connection – using the Application-Layer + // Protocol Negotiation (ALPN). For this reason it is safe to fix this to `.https`. + let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) + let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) + + try channel.pipeline.syncOperations.addHandler(translate) + try channel.pipeline.syncOperations.addHandler(handler) + + // We must add the new channel to the list of open channels BEFORE we write the + // request to it. In case of an error, we are sure that the channel was added + // before. + let box = ChannelBox(channel) + self.openStreams.insert(box) + self.channel.closeFuture.whenComplete { _ in + self.openStreams.remove(box) + } + + channel.write(request, promise: nil) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } + + createStreamChannelPromise.futureResult.whenFailure { error in + request.fail(error) + } + + case .closing, .closed: + // Because of race conditions requests might reach this point, even though the + // connection is already closing + return request.fail(HTTPClientError.cancelled) + } + } + + private func shutdown0() { + self.channel.eventLoop.assertInEventLoop() + + self.state = .closing + + // inform all open streams, that the currently running request should be cancelled. + self.openStreams.forEach { box in + box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.cancelRequest, promise: nil) + } + + // inform the idle connection handler, that connection should be closed, once all streams + // are closed. + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.closeConnection, promise: nil) + } +} + +extension HTTP2Connection: HTTP2IdleHandlerDelegate { + func http2SettingsReceived(maxStreams: Int) { + self.channel.eventLoop.assertInEventLoop() + + switch self.state { + case .initialized: + preconditionFailure("Invalid state: \(self.state)") + + case .starting(let promise): + self.state = .active(maxStreams: maxStreams) + promise.succeed(()) + + case .active: + self.state = .active(maxStreams: maxStreams) + self.delegate.http2Connection(self, newMaxStreamSetting: maxStreams) + + case .closing, .closed: + // ignore. we only wait for all connections to be closed anyway. + break + } + } + + func http2GoAwayReceived() { + self.channel.eventLoop.assertInEventLoop() + + switch self.state { + case .initialized: + preconditionFailure("Invalid state: \(self.state)") + + case .starting(let promise): + self.state = .closing + promise.fail(HTTP2ReceivedGoAwayBeforeSettingsError()) + + case .active: + self.state = .closing + self.delegate.http2ConnectionGoAwayReceived(self) + + case .closing, .closed: + // we are already closing. Nothing new + break + } + } + + func http2StreamClosed(availableStreams: Int) { + self.channel.eventLoop.assertInEventLoop() + + self.delegate.http2ConnectionStreamClosed(self, availableStreams: availableStreams) + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift new file mode 100644 index 000000000..d0e6d8ab2 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift @@ -0,0 +1,277 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOHTTP2 + +protocol HTTP2IdleHandlerDelegate { + func http2SettingsReceived(maxStreams: Int) + + func http2GoAwayReceived() + + func http2StreamClosed(availableStreams: Int) +} + +// This is a `ChannelDuplexHandler` since we need to intercept outgoing user events. It is generic +// over its delegate to allow for specialization. +final class HTTP2IdleHandler: ChannelDuplexHandler { + typealias InboundIn = HTTP2Frame + typealias InboundOut = HTTP2Frame + typealias OutboundIn = HTTP2Frame + typealias OutboundOut = HTTP2Frame + + let logger: Logger + let delegate: Delegate + + private var state: StateMachine = .init() + + init(delegate: Delegate, logger: Logger) { + self.delegate = delegate + self.logger = logger + } + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.state.channelActive() + } + } + + func channelActive(context: ChannelHandlerContext) { + self.state.channelActive() + context.fireChannelActive() + } + + func channelInactive(context: ChannelHandlerContext) { + self.state.channelInactive() + context.fireChannelInactive() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let frame = self.unwrapInboundIn(data) + + switch frame.payload { + case .goAway: + let action = self.state.goAwayReceived() + self.run(action, context: context) + + case .settings(.settings(let settings)): + let action = self.state.settingsReceived(settings) + self.run(action, context: context) + + default: + // We're not interested in other events. + break + } + + context.fireChannelRead(data) + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + // We intercept calls between the `NIOHTTP2ChannelHandler` and the `HTTP2StreamMultiplexer` + // to learn, how many open streams we have. + switch event { + case is StreamClosedEvent: + let action = self.state.streamClosed() + self.run(action, context: context) + + case is NIOHTTP2StreamCreatedEvent: + let action = self.state.streamCreated() + self.run(action, context: context) + + default: + // We're not interested in other events. + break + } + + context.fireUserInboundEventTriggered(event) + } + + func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + switch event { + case HTTPConnectionEvent.closeConnection: + let action = self.state.closeEventReceived() + self.run(action, context: context) + + default: + context.triggerUserOutboundEvent(event, promise: promise) + } + } + + private func run(_ action: StateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .nothing: + break + + case .notifyConnectionNewMaxStreamsSettings(let maxStreams): + self.delegate.http2SettingsReceived(maxStreams: maxStreams) + + case .notifyConnectionStreamClosed(let currentlyAvailable): + self.delegate.http2StreamClosed(availableStreams: currentlyAvailable) + + case .notifyConnectionGoAwayReceived: + self.delegate.http2GoAwayReceived() + + case .close: + context.close(mode: .all, promise: nil) + } + } +} + +extension HTTP2IdleHandler { + struct StateMachine { + enum Action { + case notifyConnectionNewMaxStreamsSettings(Int) + case notifyConnectionGoAwayReceived(close: Bool) + case notifyConnectionStreamClosed(currentlyAvailable: Int) + case nothing + case close + } + + enum State { + case initialized + case connected + case active(openStreams: Int, maxStreams: Int) + case closing(openStreams: Int, maxStreams: Int) + case closed + } + + var state: State = .initialized + + mutating func channelActive() { + switch self.state { + case .initialized: + self.state = .connected + + case .connected, .active, .closing, .closed: + break + } + } + + mutating func channelInactive() { + switch self.state { + case .initialized, .connected, .active, .closing, .closed: + self.state = .closed + } + } + + mutating func settingsReceived(_ settings: HTTP2Settings) -> Action { + switch self.state { + case .initialized, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .connected: + // a settings frame might have multiple entries for `maxConcurrentStreams`. We are + // only interested in the last value! If no `maxConcurrentStreams` is set, we assume + // the http/2 default of 100. + let maxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value ?? 100 + self.state = .active(openStreams: 0, maxStreams: maxStreams) + return .notifyConnectionNewMaxStreamsSettings(maxStreams) + + case .active(openStreams: let openStreams, maxStreams: let maxStreams): + if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, newMaxStreams != maxStreams { + self.state = .active(openStreams: openStreams, maxStreams: newMaxStreams) + return .notifyConnectionNewMaxStreamsSettings(newMaxStreams) + } + return .nothing + + case .closing: + return .nothing + } + } + + mutating func goAwayReceived() -> Action { + switch self.state { + case .initialized, .closed: + preconditionFailure("Invalid state: \(self.state)") + + case .connected: + self.state = .closing(openStreams: 0, maxStreams: 0) + return .notifyConnectionGoAwayReceived(close: true) + + case .active(let openStreams, let maxStreams): + self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) + return .notifyConnectionGoAwayReceived(close: openStreams == 0) + + case .closing: + return .notifyConnectionGoAwayReceived(close: false) + } + } + + mutating func closeEventReceived() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Invalid state: \(self.state)") + + case .connected: + self.state = .closing(openStreams: 0, maxStreams: 0) + return .close + + case .active(let openStreams, let maxStreams): + if openStreams == 0 { + self.state = .closed + return .close + } + + self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) + return .nothing + + case .closed, .closing: + return .nothing + } + } + + mutating func streamCreated() -> Action { + switch self.state { + case .active(var openStreams, let maxStreams): + openStreams += 1 + self.state = .active(openStreams: openStreams, maxStreams: maxStreams) + return .nothing + + case .closing(var openStreams, let maxStreams): + // A stream might be opened, while we are closing because of race conditions. For + // this reason, we should handle this case. + openStreams += 1 + self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) + return .nothing + + case .initialized, .connected, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func streamClosed() -> Action { + switch self.state { + case .active(var openStreams, let maxStreams): + openStreams -= 1 + assert(openStreams >= 0) + self.state = .active(openStreams: openStreams, maxStreams: maxStreams) + return .notifyConnectionStreamClosed(currentlyAvailable: maxStreams - openStreams) + + case .closing(var openStreams, let maxStreams): + openStreams -= 1 + assert(openStreams >= 0) + if openStreams == 0 { + self.state = .closed + return .close + } + self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) + return .nothing + + case .initialized, .connected, .closed: + preconditionFailure("Invalid state: \(self.state)") + } + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift index 4e6d563e6..4bae049ac 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionEvent.swift @@ -14,4 +14,5 @@ enum HTTPConnectionEvent { case cancelRequest + case closeConnection } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 2f19f9224..9949c5aae 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -30,19 +30,79 @@ extension HTTPConnectionPool { let tlsConfiguration: TLSConfiguration let sslContextCache: SSLContextCache + // This property can be removed once we enable true http/2 support + let allowHTTP2Connections: Bool + init(key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, clientConfiguration: HTTPClient.Configuration, - sslContextCache: SSLContextCache) { + sslContextCache: SSLContextCache, + allowHTTP2Connections: Bool = false) { self.key = key self.clientConfiguration = clientConfiguration self.sslContextCache = sslContextCache self.tlsConfiguration = tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() + self.allowHTTP2Connections = allowHTTP2Connections } } } +protocol HTTPConnectionRequester { + func http1ConnectionCreated(_: HTTP1Connection) + func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) + func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) +} + extension HTTPConnectionPool.ConnectionFactory { + func makeConnection( + for requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + http1ConnectionDelegate: HTTP1ConnectionDelegate, + http2ConnectionDelegate: HTTP2ConnectionDelegate, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) { + var logger = logger + logger[metadataKey: "ahc-connection"] = "\(connectionID)" + + self.makeChannel(connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger).whenComplete { result in + switch result { + case .success(.http1_1(let channel)): + do { + let connection = try HTTP1Connection.start( + channel: channel, + connectionID: connectionID, + delegate: http1ConnectionDelegate, + configuration: self.clientConfiguration, + logger: logger + ) + requester.http1ConnectionCreated(connection) + } catch { + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + case .success(.http2(let channel)): + HTTP2Connection.start( + channel: channel, + connectionID: connectionID, + delegate: http2ConnectionDelegate, + configuration: self.clientConfiguration, + logger: logger + ).whenComplete { result in + switch result { + case .success(let connection): + requester.http2ConnectionCreated(connection, maximumStreams: 0) + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + } + enum NegotiatedProtocol { case http1_1(Channel) case http2(Channel) @@ -243,7 +303,14 @@ extension HTTPConnectionPool.ConnectionFactory { case .https: var tlsConfig = self.tlsConfiguration // since we can support h2, we need to advertise this in alpn - tlsConfig.applicationProtocols = ["http/1.1" /* , "h2" */ ] + if self.allowHTTP2Connections { + // "ProtocolNameList" contains the list of protocols advertised by the + // client, in descending order of preference. + // https://datatracker.ietf.org/doc/html/rfc7301#section-3.1 + tlsConfig.applicationProtocols = ["h2", "http/1.1"] + } else { + tlsConfig.applicationProtocols = ["http/1.1"] + } let tlsEventHandler = TLSEventsHandler(deadline: deadline) let sslContextFuture = self.sslContextCache.sslContext( @@ -341,7 +408,14 @@ extension HTTPConnectionPool.ConnectionFactory { -> EventLoopFuture { // since we can support h2, we need to advertise this in alpn var tlsConfig = self.tlsConfiguration - tlsConfig.applicationProtocols = ["http/1.1" /* , "h2" */ ] + if self.allowHTTP2Connections { + // "ProtocolNameList" contains the list of protocols advertised by the + // client, in descending order of preference. + // https://datatracker.ietf.org/doc/html/rfc7301#section-3.1 + tlsConfig.applicationProtocols = ["h2", "http/1.1"] + } else { + tlsConfig.applicationProtocols = ["http/1.1"] + } #if canImport(Network) if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 424f42d1a..fd0b9a846 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -200,7 +200,8 @@ struct HTTPRequestStateMachine { self.state = .failed(error) return .failRequest(error, .close) case .finished, .failed: - preconditionFailure("If the request is finished or failed, we expect the connection state machine to remove the request immediately from its state. Thus this state is unreachable.") + // ignore error + return .wait case .modifying: preconditionFailure("Invalid state: \(self.state)") } diff --git a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift index ab9bcdbeb..03fec2d4f 100644 --- a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift +++ b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift @@ -16,6 +16,7 @@ import Logging import NIO import NIOHTTP1 +import NIOHTTP2 extension EmbeddedChannel { public func receiveHeadAndVerify(_ verify: (HTTPRequestHead) throws -> Void = { _ in }) throws { diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 77457a2ad..027d4bdfa 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -105,58 +105,8 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { } func testWriteBackpressure() { - class TestWriter { - let eventLoop: EventLoop - - let parts: Int - - var finishFuture: EventLoopFuture { self.finishPromise.futureResult } - private let finishPromise: EventLoopPromise - private(set) var written: Int = 0 - - private var channelIsWritable: Bool = false - - init(eventLoop: EventLoop, parts: Int) { - self.eventLoop = eventLoop - self.parts = parts - - self.finishPromise = eventLoop.makePromise(of: Void.self) - } - - func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { - func recursive() { - XCTAssert(self.eventLoop.inEventLoop) - XCTAssert(self.channelIsWritable) - if self.written == self.parts { - self.finishPromise.succeed(()) - } else { - self.eventLoop.execute { - let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) - self.written += 1 - future.whenComplete { result in - switch result { - case .success: - recursive() - case .failure(let error): - XCTFail("Unexpected error: \(error)") - } - } - } - } - } - - recursive() - - return self.finishFuture - } - - func writabilityChanged(_ newValue: Bool) { - self.channelIsWritable = newValue - } - } - let embedded = EmbeddedChannel() - let testWriter = TestWriter(eventLoop: embedded.eventLoop, parts: 50) + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 50) var maybeTestUtils: HTTP1TestTools? XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } @@ -337,6 +287,56 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { } } +class TestBackpressureWriter { + let eventLoop: EventLoop + + let parts: Int + + var finishFuture: EventLoopFuture { self.finishPromise.futureResult } + private let finishPromise: EventLoopPromise + private(set) var written: Int = 0 + + private var channelIsWritable: Bool = false + + init(eventLoop: EventLoop, parts: Int) { + self.eventLoop = eventLoop + self.parts = parts + + self.finishPromise = eventLoop.makePromise(of: Void.self) + } + + func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + func recursive() { + XCTAssert(self.eventLoop.inEventLoop) + XCTAssert(self.channelIsWritable) + if self.written == self.parts { + self.finishPromise.succeed(()) + } else { + self.eventLoop.execute { + let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) + self.written += 1 + future.whenComplete { result in + switch result { + case .success: + recursive() + case .failure(let error): + XCTFail("Unexpected error: \(error)") + } + } + } + } + } + + recursive() + + return self.finishFuture + } + + func writabilityChanged(_ newValue: Bool) { + self.channelIsWritable = newValue + } +} + class ResponseBackpressureDelegate: HTTPClientResponseDelegate { typealias Response = Void diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift new file mode 100644 index 000000000..53be47f01 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTP2ClientRequestHandlerTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTP2ClientRequestHandlerTests { + static var allTests: [(String, (HTTP2ClientRequestHandlerTests) -> () throws -> Void)] { + return [ + ("testResponseBackpressure", testResponseBackpressure), + ("testWriteBackpressure", testWriteBackpressure), + ("testIdleReadTimeout", testIdleReadTimeout), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift new file mode 100644 index 000000000..dbdf4fd97 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -0,0 +1,235 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP1 +import XCTest + +class HTTP2ClientRequestHandlerTests: XCTestCase { + func testResponseBackpressure() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + let logger = Logger(label: "test") + + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: nil, + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.write(requestBag, promise: nil) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + + XCTAssertEqual(readEventHandler.readHitCounter, 0) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 1) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + + let part0 = ByteBuffer(bytes: 0...3) + let part1 = ByteBuffer(bytes: 4...7) + let part2 = ByteBuffer(bytes: 8...11) + + // part 0. Demand first, read second + XCTAssertEqual(readEventHandler.readHitCounter, 1) + let part0Future = delegate.next() + XCTAssertEqual(readEventHandler.readHitCounter, 1) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 2) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part0))) + XCTAssertEqual(try part0Future.wait(), part0) + + // part 1. read first, demand second + + XCTAssertEqual(readEventHandler.readHitCounter, 2) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 2) + let part1Future = delegate.next() + XCTAssertEqual(readEventHandler.readHitCounter, 3) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part1))) + XCTAssertEqual(try part1Future.wait(), part1) + + // part 2. Demand first, read second + XCTAssertEqual(readEventHandler.readHitCounter, 3) + let part2Future = delegate.next() + XCTAssertEqual(readEventHandler.readHitCounter, 3) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 4) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(part2))) + XCTAssertEqual(try part2Future.wait(), part2) + + // end. read first, demand second + XCTAssertEqual(readEventHandler.readHitCounter, 4) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 4) + let endFuture = delegate.next() + XCTAssertEqual(readEventHandler.readHitCounter, 5) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertEqual(try endFuture.wait(), .none) + + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + } + + func testWriteBackpressure() { + let embedded = EmbeddedChannel() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(requestHandler)) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 50) + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in + testWriter.start(writer: writer) + })) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: .milliseconds(200), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = false + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + embedded.write(requestBag, promise: nil) + + // the handler only writes once the channel is writable + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + }) + + // the next body write will be executed once we tick the el. before we make the channel + // unwritable + + for index in 0..<50 { + embedded.isWritable = false + testWriter.writabilityChanged(false) + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertEqual(testWriter.written, index) + + embedded.embeddedEventLoop.run() + + XCTAssertNoThrow(try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + }) + + XCTAssertEqual(testWriter.written, index + 1) + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + } + + embedded.embeddedEventLoop.run() + XCTAssertNoThrow(try embedded.receiveEnd()) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + embedded.read() + + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + } + + func testIdleReadTimeout() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + idleReadTimeout: .milliseconds(200), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.write(requestBag, promise: nil) + + XCTAssertNoThrow(try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + }) + XCTAssertNoThrow(try embedded.receiveEnd()) + + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + + XCTAssertEqual(readEventHandler.readHitCounter, 0) + embedded.read() + XCTAssertEqual(readEventHandler.readHitCounter, 1) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + + // not sending anything after the head should lead to request fail and connection close + + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .readTimeout) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift new file mode 100644 index 000000000..9f9582d9f --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTP2ConnectionTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTP2ConnectionTests { + static var allTests: [(String, (HTTP2ConnectionTests) -> () throws -> Void)] { + return [ + ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), + ("testSimpleGetRequest", testSimpleGetRequest), + ("testEveryDoneRequestLeadsToAStreamAvailableCall", testEveryDoneRequestLeadsToAStreamAvailableCall), + ("testCancelAllRunningRequests", testCancelAllRunningRequests), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift new file mode 100644 index 000000000..4ef47ce92 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -0,0 +1,481 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOConcurrencyHelpers +import NIOHTTP1 +import NIOSSL +import NIOTestUtils +import XCTest + +class HTTP2ConnectionTests: XCTestCase { + func testCreateNewConnectionFailureClosedIO() { + let embedded = EmbeddedChannel() + + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + XCTAssertNoThrow(try embedded.close().wait()) + // to really destroy the channel we need to tick once + embedded.embeddedEventLoop.run() + let logger = Logger(label: "test.http2.connection") + + XCTAssertThrowsError(try HTTP2Connection.start( + channel: embedded, + connectionID: 0, + delegate: TestHTTP2ConnectionDelegate(), + configuration: .init(), + logger: logger + ).wait()) + } + + func testSimpleGetRequest() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let httpBin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + idleReadTimeout: nil, + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + )) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to have a request bag at this point") + } + + http2Connection.executeRequest(requestBag) + + XCTAssertEqual(delegate.hitStreamClosed, 0) + var maybeResponse: HTTPClient.Response? + XCTAssertNoThrow(maybeResponse = try requestBag.task.futureResult.wait()) + XCTAssertEqual(maybeResponse?.status, .ok) + XCTAssertEqual(maybeResponse?.version, .http2) + XCTAssertEqual(delegate.hitStreamClosed, 1) + } + + func testEveryDoneRequestLeadsToAStreamAvailableCall() { + class NeverRespondChannelHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + init() {} + + func channelRead(context: ChannelHandlerContext, data: NIOAny) {} + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let httpBin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + )) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + defer { XCTAssertNoThrow(try http2Connection.close().wait()) } + + var futures = [EventLoopFuture]() + + XCTAssertEqual(delegate.hitStreamClosed, 0) + + for _ in 0..<100 { + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + idleReadTimeout: nil, + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + )) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to have a request bag at this point") + } + + http2Connection.executeRequest(requestBag) + + futures.append(requestBag.task.futureResult) + } + + for future in futures { + XCTAssertNoThrow(try future.wait()) + } + + XCTAssertEqual(delegate.hitStreamClosed, 100) + XCTAssertTrue(http2Connection.channel.isActive) + } + + func testCancelAllRunningRequests() { + class NeverRespondChannelHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + init() {} + + func channelRead(context: ChannelHandlerContext, data: NIOAny) {} + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let httpBin = HTTPBin(.http2(compress: false), handlerFactory: { _ in NeverRespondChannelHandler() }) + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection? + XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + var futures = [EventLoopFuture]() + + for _ in 0..<100 { + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + idleReadTimeout: nil, + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + )) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to have a request bag at this point") + } + + http2Connection.executeRequest(requestBag) + + XCTAssertEqual(delegate.hitStreamClosed, 0) + + futures.append(requestBag.task.futureResult) + } + + http2Connection.shutdown() + + for future in futures { + XCTAssertThrowsError(try future.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + } + + XCTAssertNoThrow(try http2Connection.closeFuture.wait()) + } +} + +class TestConnectionCreator { + enum Error: Swift.Error { + case alreadyCreatingAnotherConnection + case wantedHTTP2ConnectionButGotHTTP1 + case wantedHTTP1ConnectionButGotHTTP2 + } + + enum State { + case idle + case waitingForHTTP1Connection(EventLoopPromise) + case waitingForHTTP2Connection(EventLoopPromise) + } + + private var state: State = .idle + private let lock = Lock() + + init() {} + + func createHTTP1Connection( + to port: Int, + delegate: HTTP1ConnectionDelegate, + connectionID: HTTPConnectionPool.Connection.ID = 0, + on eventLoop: EventLoop, + logger: Logger = .init(label: "test") + ) throws -> HTTP1Connection { + let request = try! HTTPClient.Request(url: "https://localhost:\(port)") + + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + let factory = HTTPConnectionPool.ConnectionFactory( + key: .init(request), + tlsConfiguration: tlsConfiguration, + clientConfiguration: .init(), + sslContextCache: .init(), + allowHTTP2Connections: true + ) + + let promise = try self.lock.withLock { () -> EventLoopPromise in + guard case .idle = self.state else { + throw Error.alreadyCreatingAnotherConnection + } + + let promise = eventLoop.makePromise(of: HTTP1Connection.self) + self.state = .waitingForHTTP1Connection(promise) + return promise + } + + factory.makeConnection( + for: self, + connectionID: connectionID, + http1ConnectionDelegate: delegate, + http2ConnectionDelegate: EmptyHTTP2ConnectionDelegate(), + deadline: .now() + .seconds(2), + eventLoop: eventLoop, + logger: logger + ) + + return try promise.futureResult.wait() + } + + func createHTTP2Connection( + to port: Int, + delegate: HTTP2ConnectionDelegate, + connectionID: HTTPConnectionPool.Connection.ID = 0, + on eventLoop: EventLoop, + logger: Logger = .init(label: "test") + ) throws -> HTTP2Connection { + let request = try! HTTPClient.Request(url: "https://localhost:\(port)") + + var tlsConfiguration = TLSConfiguration.makeClientConfiguration() + tlsConfiguration.certificateVerification = .none + let factory = HTTPConnectionPool.ConnectionFactory( + key: .init(request), + tlsConfiguration: tlsConfiguration, + clientConfiguration: .init(), + sslContextCache: .init(), + allowHTTP2Connections: true + ) + + let promise = try self.lock.withLock { () -> EventLoopPromise in + guard case .idle = self.state else { + throw Error.alreadyCreatingAnotherConnection + } + + let promise = eventLoop.makePromise(of: HTTP2Connection.self) + self.state = .waitingForHTTP2Connection(promise) + return promise + } + + factory.makeConnection( + for: self, + connectionID: connectionID, + http1ConnectionDelegate: EmptyHTTP1ConnectionDelegate(), + http2ConnectionDelegate: delegate, + deadline: .now() + .seconds(2), + eventLoop: eventLoop, + logger: logger + ) + + return try promise.futureResult.wait() + } +} + +extension TestConnectionCreator: HTTPConnectionRequester { + enum EitherPromiseWrapper { + case succeed(EventLoopPromise, SucceedType) + case fail(EventLoopPromise, Error) + + func complete() { + switch self { + case .succeed(let promise, let success): + promise.succeed(success) + case .fail(let promise, let error): + promise.fail(error) + } + } + } + + func http1ConnectionCreated(_ connection: HTTP1Connection) { + let wrapper = self.lock.withLock { + () -> (EitherPromiseWrapper) in + + switch self.state { + case .waitingForHTTP1Connection(let promise): + return .succeed(promise, connection) + + case .waitingForHTTP2Connection(let promise): + return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) + + case .idle: + preconditionFailure("Invalid state") + } + } + wrapper.complete() + } + + func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { + let wrapper = self.lock.withLock { + () -> (EitherPromiseWrapper) in + + switch self.state { + case .waitingForHTTP1Connection(let promise): + return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) + + case .waitingForHTTP2Connection(let promise): + return .succeed(promise, connection) + + case .idle: + preconditionFailure("Invalid state") + } + } + wrapper.complete() + } + + enum FailPromiseWrapper { + case type1(EventLoopPromise) + case type2(EventLoopPromise) + + func fail(_ error: Swift.Error) { + switch self { + case .type1(let eventLoopPromise): + eventLoopPromise.fail(error) + case .type2(let eventLoopPromise): + eventLoopPromise.fail(error) + } + } + } + + func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { + let wrapper = self.lock.withLock { + () -> (FailPromiseWrapper) in + + switch self.state { + case .waitingForHTTP1Connection(let promise): + return .type1(promise) + + case .waitingForHTTP2Connection(let promise): + return .type2(promise) + + case .idle: + preconditionFailure("Invalid state") + } + } + wrapper.fail(error) + } +} + +class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { + var hitStreamClosed: Int { + self.lock.withLock { self._hitStreamClosed } + } + + var hitGoAwayReceived: Int { + self.lock.withLock { self._hitGoAwayReceived } + } + + var hitConnectionClosed: Int { + self.lock.withLock { self._hitConnectionClosed } + } + + var maxStreamSetting: Int { + self.lock.withLock { self._maxStreamSetting } + } + + private let lock = Lock() + private var _hitStreamClosed: Int = 0 + private var _hitGoAwayReceived: Int = 0 + private var _hitConnectionClosed: Int = 0 + private var _maxStreamSetting: Int = 100 + + init() {} + + func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) {} + + func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { + self.lock.withLockVoid { + self._hitStreamClosed += 1 + } + } + + func http2ConnectionGoAwayReceived(_: HTTP2Connection) { + self.lock.withLockVoid { + self._hitGoAwayReceived += 1 + } + } + + func http2ConnectionClosed(_: HTTP2Connection) { + self.lock.withLockVoid { + self._hitConnectionClosed += 1 + } + } +} + +final class EmptyHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { + func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) { + preconditionFailure("Unimplemented") + } + + func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { + preconditionFailure("Unimplemented") + } + + func http2ConnectionGoAwayReceived(_: HTTP2Connection) { + preconditionFailure("Unimplemented") + } + + func http2ConnectionClosed(_: HTTP2Connection) { + preconditionFailure("Unimplemented") + } +} + +final class EmptyHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { + func http1ConnectionReleased(_: HTTP1Connection) { + preconditionFailure("Unimplemented") + } + + func http1ConnectionClosed(_: HTTP1Connection) { + preconditionFailure("Unimplemented") + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift new file mode 100644 index 000000000..5c7021e23 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTP2IdleHandlerTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTP2IdleHandlerTests { + static var allTests: [(String, (HTTP2IdleHandlerTests) -> () throws -> Void)] { + return [ + ("testReceiveSettingsWithMaxConcurrentStreamSetting", testReceiveSettingsWithMaxConcurrentStreamSetting), + ("testReceiveSettingsWithoutMaxConcurrentStreamSetting", testReceiveSettingsWithoutMaxConcurrentStreamSetting), + ("testEmptySettingsDontOverwriteMaxConcurrentStreamSetting", testEmptySettingsDontOverwriteMaxConcurrentStreamSetting), + ("testOverwriteMaxConcurrentStreamSetting", testOverwriteMaxConcurrentStreamSetting), + ("testGoAwayReceivedBeforeSettings", testGoAwayReceivedBeforeSettings), + ("testGoAwayReceivedAfterSettings", testGoAwayReceivedAfterSettings), + ("testCloseEventBeforeFirstSettings", testCloseEventBeforeFirstSettings), + ("testCloseEventWhileNoOpenStreams", testCloseEventWhileNoOpenStreams), + ("testCloseEventWhileThereAreOpenStreams", testCloseEventWhileThereAreOpenStreams), + ("testGoAwayWhileThereAreOpenStreams", testGoAwayWhileThereAreOpenStreams), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift new file mode 100644 index 000000000..3d5197a69 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift @@ -0,0 +1,246 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIO +import NIOHTTP2 +import XCTest + +class HTTP2IdleHandlerTests: XCTestCase { + func testReceiveSettingsWithMaxConcurrentStreamSetting() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + XCTAssertEqual(delegate.maxStreams, nil) + XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) + XCTAssertEqual(delegate.maxStreams, 10) + } + + func testReceiveSettingsWithoutMaxConcurrentStreamSetting() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) + XCTAssertEqual(delegate.maxStreams, nil) + XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) + XCTAssertEqual(delegate.maxStreams, 100, "Expected to assume 100 maxConcurrentConnection, if no setting was present") + } + + func testEmptySettingsDontOverwriteMaxConcurrentStreamSetting() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + XCTAssertEqual(delegate.maxStreams, nil) + XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) + XCTAssertEqual(delegate.maxStreams, 10) + + let emptySettings = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) + XCTAssertNoThrow(try embedded.writeInbound(emptySettings)) + XCTAssertEqual(delegate.maxStreams, 10) + } + + func testOverwriteMaxConcurrentStreamSetting() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + XCTAssertEqual(delegate.maxStreams, nil) + XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) + XCTAssertEqual(delegate.maxStreams, 10) + + let emptySettings = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)]))) + XCTAssertNoThrow(try embedded.writeInbound(emptySettings)) + XCTAssertEqual(delegate.maxStreams, 20) + } + + func testGoAwayReceivedBeforeSettings() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let randomStreamID = HTTP2StreamID((0..() + + for i in 0..<(1...100).randomElement()! { + let streamID = HTTP2StreamID(i) + let event = NIOHTTP2StreamCreatedEvent(streamID: streamID, localInitialWindowSize: nil, remoteInitialWindowSize: nil) + embedded.pipeline.fireUserInboundEventTriggered(event) + openStreams.insert(streamID) + } + + embedded.pipeline.triggerUserOutboundEvent(HTTPConnectionEvent.closeConnection, promise: nil) + XCTAssertTrue(embedded.isActive) + + while let streamID = openStreams.randomElement() { + openStreams.remove(streamID) + + let event = StreamClosedEvent(streamID: streamID, reason: nil) + XCTAssertTrue(embedded.isActive) + embedded.pipeline.fireUserInboundEventTriggered(event) + if openStreams.isEmpty { + XCTAssertFalse(embedded.isActive) + } else { + XCTAssertTrue(embedded.isActive) + } + } + } + + func testGoAwayWhileThereAreOpenStreams() { + let delegate = MockHTTP2IdleHandlerDelegate() + let idleHandler = HTTP2IdleHandler(delegate: delegate, logger: Logger(label: "test")) + let embedded = EmbeddedChannel(handlers: [idleHandler]) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + XCTAssertEqual(delegate.maxStreams, nil) + XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) + XCTAssertEqual(delegate.maxStreams, 10) + + var openStreams = Set() + + for i in 0..<(1...100).randomElement()! { + let streamID = HTTP2StreamID(i) + let event = NIOHTTP2StreamCreatedEvent(streamID: streamID, localInitialWindowSize: nil, remoteInitialWindowSize: nil) + embedded.pipeline.fireUserInboundEventTriggered(event) + openStreams.insert(streamID) + } + + let goAwayStreamID = HTTP2StreamID(openStreams.count) + let goAwayFrame = HTTP2Frame(streamID: goAwayStreamID, payload: .goAway(lastStreamID: 0, errorCode: .http11Required, opaqueData: nil)) + XCTAssertEqual(delegate.goAwayReceived, false) + XCTAssertNoThrow(try embedded.writeInbound(goAwayFrame)) + XCTAssertEqual(delegate.goAwayReceived, true) + XCTAssertEqual(delegate.maxStreams, 10) + + while let streamID = openStreams.randomElement() { + openStreams.remove(streamID) + + let event = StreamClosedEvent(streamID: streamID, reason: nil) + XCTAssertTrue(embedded.isActive) + embedded.pipeline.fireUserInboundEventTriggered(event) + if openStreams.isEmpty { + XCTAssertFalse(embedded.isActive) + } else { + XCTAssertTrue(embedded.isActive) + } + } + } +} + +class MockHTTP2IdleHandlerDelegate: HTTP2IdleHandlerDelegate { + private(set) var maxStreams: Int? + private(set) var goAwayReceived: Bool = false + + private(set) var streamClosedHitCount: Int = 0 + + func http2SettingsReceived(maxStreams: Int) { + self.maxStreams = maxStreams + } + + func http2GoAwayReceived() { + self.goAwayReceived = true + } + + func http2StreamClosed(availableStreams: Int) { + self.streamClosedHitCount += 1 + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift index cc8b22a0b..7fef14658 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -42,6 +42,7 @@ extension HTTPRequestStateMachineTests { ("testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly", testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly), ("testCancellingARequestInStateInitializedKeepsTheConnectionAlive", testCancellingARequestInStateInitializedKeepsTheConnectionAlive), ("testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive", testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive), + ("testConnectionBecomesWritableBeforeFirstRequest", testConnectionBecomesWritableBeforeFirstRequest), ("testCancellingARequestThatIsSent", testCancellingARequestThatIsSent), ("testRemoteSuddenlyClosesTheConnection", testRemoteSuddenlyClosesTheConnection), ("testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored", testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored), diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 8614d9767..7ad4ecd99 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -350,6 +350,24 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) } + func testConnectionBecomesWritableBeforeFirstRequest() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + XCTAssertEqual(state.writabilityChanged(writable: true), .wait) + + // --- sending request + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .none) + XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + + // --- receiving response + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "4"]) + XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) + XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) + XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) + XCTAssertEqual(state.channelReadComplete(), .wait) + } + func testCancellingARequestThatIsSent() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 6aac03817..b31e0d29d 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -32,6 +32,9 @@ import XCTest testCase(HTTP1ConnectionStateMachineTests.allTests), testCase(HTTP1ConnectionTests.allTests), testCase(HTTP1ProxyConnectHandlerTests.allTests), + testCase(HTTP2ClientRequestHandlerTests.allTests), + testCase(HTTP2ConnectionTests.allTests), + testCase(HTTP2IdleHandlerTests.allTests), testCase(HTTPClientCookieTests.allTests), testCase(HTTPClientInternalTests.allTests), testCase(HTTPClientNIOTSTests.allTests),