From 8fbeb16470949229a1defecd801f3760a370109e Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 24 Jun 2021 13:54:30 +0200 Subject: [PATCH 1/3] Add h2 support to HTTPBin --- Package.swift | 2 + .../HTTPClient+SOCKSTests.swift | 2 +- .../HTTPClientInternalTests.swift | 63 ++- .../HTTPClientNIOTSTests.swift | 6 +- .../HTTPClientTestUtils.swift | 371 +++++++++++++----- .../HTTPClientTests.swift | 46 +-- 6 files changed, 348 insertions(+), 142 deletions(-) diff --git a/Package.swift b/Package.swift index 28d229473..3f4a7d186 100644 --- a/Package.swift +++ b/Package.swift @@ -23,6 +23,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.30.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.0"), + .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.18.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.10.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), @@ -48,6 +49,7 @@ let package = Package( .product(name: "NIO", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOHTTP2", package: "swift-nio-http2"), "AsyncHTTPClient", .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift index 2c65a7f3b..2de0d6e4a 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift @@ -23,7 +23,7 @@ class HTTPClientSOCKSTests: XCTestCase { var clientGroup: EventLoopGroup! var serverGroup: EventLoopGroup! - var defaultHTTPBin: HTTPBin! + var defaultHTTPBin: HTTPBin! var defaultClient: HTTPClient! var backgroundLogStore: CollectEverythingLogHandler.LogStore! diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 89648c726..9f170217c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -368,11 +368,49 @@ class HTTPClientInternalTests: XCTestCase { func didFinishRequest(task: HTTPClient.Task) throws {} } + final class WriteAfterFutureSucceedsHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + let bodyFuture: EventLoopFuture + let endFuture: EventLoopFuture + + init(bodyFuture: EventLoopFuture, endFuture: EventLoopFuture) { + self.bodyFuture = bodyFuture + self.endFuture = endFuture + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head: + let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok) + context.writeAndFlush(wrapOutboundOut(.head(head)), promise: nil) + case .body: + // ignore + break + case .end: + self.bodyFuture.hop(to: context.eventLoop).whenSuccess { + let buffer = context.channel.allocator.buffer(string: "1234") + context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) + } + + self.endFuture.hop(to: context.eventLoop).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + } + } + // cannot test with NIOTS as `maxMessagesPerRead` is not supported let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) - let promise = httpClient.eventLoopGroup.next().makePromise(of: Channel.self) - let httpBin = HTTPBin(channelPromise: promise) + let delegate = BackpressureTestDelegate(eventLoop: httpClient.eventLoopGroup.next()) + let httpBin = HTTPBin { _ in + WriteAfterFutureSucceedsHandler( + bodyFuture: delegate.optionsApplied.futureResult, + endFuture: delegate.backpressurePromise.futureResult + ) + } defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -381,17 +419,14 @@ class HTTPClientInternalTests: XCTestCase { } let request = try Request(url: "http://localhost:\(httpBin.port)/custom") - let delegate = BackpressureTestDelegate(eventLoop: httpClient.eventLoopGroup.next()) - let future = httpClient.execute(request: request, delegate: delegate).futureResult - let channel = try promise.futureResult.wait() + let requestFuture = httpClient.execute(request: request, delegate: delegate).futureResult // We need to wait for channel options that limit NIO to sending only one byte at a time. try delegate.optionsApplied.futureResult.wait() + // // Send 4 bytes, but only one should be received until the backpressure promise is succeeded. - let buffer = channel.allocator.buffer(string: "1234") - try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() // Now we wait until message is delivered to client channel pipeline try delegate.messageReceived.futureResult.wait() @@ -399,9 +434,7 @@ class HTTPClientInternalTests: XCTestCase { // Succeed the backpressure promise. delegate.backpressurePromise.succeed(()) - - try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait() - try future.wait() + try requestFuture.wait() // At this point all other bytes should be delivered. XCTAssertEqual(delegate.reads, 4) @@ -602,7 +635,7 @@ class HTTPClientInternalTests: XCTestCase { } func testResponseConnectionCloseGet() throws { - let httpBin = HTTPBin(ssl: false) + let httpBin = HTTPBin(.http1_1()) let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -756,14 +789,14 @@ class HTTPClientInternalTests: XCTestCase { struct NoChannelError: Error {} let client = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) - var maybeServersAndChannels: [(HTTPBin, Channel)]? + var maybeServersAndChannels: [(HTTPBin, Channel)]? XCTAssertNoThrow(maybeServersAndChannels = try (0..<10).map { _ in let web = HTTPBin() defer { XCTAssertNoThrow(try web.shutdown()) } - let req = try! HTTPClient.Request(url: "http://localhost:\(web.serverChannel.localAddress!.port!)/get", + let req = try! HTTPClient.Request(url: "http://localhost:\(web.port)/get", method: .GET, body: nil) var maybeConnection: Connection? @@ -847,7 +880,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try client.syncShutdown()) } - let req = try! HTTPClient.Request(url: "http://localhost:\(web.serverChannel.localAddress!.port!)/get", + let req = try! HTTPClient.Request(url: "http://localhost:\(web.port)/get", method: .GET, body: nil) @@ -1083,7 +1116,7 @@ class HTTPClientInternalTests: XCTestCase { let el1 = elg.next() let el2 = elg.next() - let httpBin = HTTPBin(refusesConnections: true) + let httpBin = HTTPBin(.refuse) let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) defer { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift index 8edb24b75..47ecd7a40 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift @@ -52,7 +52,7 @@ class HTTPClientNIOTSTests: XCTestCase { func testTLSFailError() { guard isTestingNIOTS() else { return } - let httpBin = HTTPBin(ssl: true) + let httpBin = HTTPBin(.http1_1(ssl: true)) let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -76,7 +76,7 @@ class HTTPClientNIOTSTests: XCTestCase { func testConnectionFailError() { guard isTestingNIOTS() else { return } - let httpBin = HTTPBin(ssl: true) + let httpBin = HTTPBin(.http1_1(ssl: true)) let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(100)))) @@ -96,7 +96,7 @@ class HTTPClientNIOTSTests: XCTestCase { func testTLSVersionError() { guard isTestingNIOTS() else { return } #if canImport(Network) - let httpBin = HTTPBin(ssl: true) + let httpBin = HTTPBin(.http1_1(ssl: true)) var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = .none tlsConfig.minimumTLSVersion = .tlsv11 diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 7a264e33f..edd064e4b 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -18,8 +18,10 @@ import Logging import NIO import NIOConcurrencyHelpers import NIOHTTP1 +import NIOHTTP2 import NIOHTTPCompression import NIOSSL +import NIOTLS import NIOTransportServices import XCTest @@ -264,22 +266,50 @@ enum TemporaryFileHelpers { } } -internal final class HTTPBin { - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let serverChannel: Channel - let isShutdown: NIOAtomic = .makeAtomic(value: false) - var connections: NIOAtomic - var connectionCount: NIOAtomic = .makeAtomic(value: 0) - private let activeConnCounterHandler: CountActiveConnectionsHandler - var activeConnections: Int { - return self.activeConnCounterHandler.currentlyActiveConnections - } +enum TestTLS { + static let certificate = try! NIOSSLCertificate(bytes: Array(cert.utf8), format: .pem) + static let privateKey = try! NIOSSLPrivateKey(bytes: Array(key.utf8), format: .pem) +} +internal final class HTTPBin where + RequestHandler.InboundIn == HTTPServerRequestPart, + RequestHandler.OutboundOut == HTTPServerResponsePart { enum BindTarget { case unixDomainSocket(String) case localhostIPv4RandomPort } + enum Mode { + // refuses all connections + case refuse + // supports http1.1 connections only, which can be either plain text or encrypted + case http1_1(ssl: Bool = false, compress: Bool = false) + // supports http1.1 and http2 connections which must be always encrypted + case http2(compress: Bool) + + // supports request decompression and http response compression + var compress: Bool { + switch self { + case .refuse: + return false + case .http1_1(ssl: _, compress: let compress), .http2(compress: let compress): + return compress + } + } + } + + enum Proxy { + case none + case simulate(authorization: String?) + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + private let activeConnCounterHandler: CountActiveConnectionsHandler + var activeConnections: Int { + return self.activeConnCounterHandler.currentlyActiveConnections + } + var port: Int { return Int(self.serverChannel.localAddress!.port!) } @@ -288,21 +318,22 @@ internal final class HTTPBin { return self.serverChannel.localAddress! } - static func configureTLS(channel: Channel) -> EventLoopFuture { - let configuration = TLSConfiguration.makeServerConfiguration(certificateChain: [.certificate(try! NIOSSLCertificate(bytes: Array(cert.utf8), format: .pem))], - privateKey: .privateKey(try! NIOSSLPrivateKey(bytes: Array(key.utf8), format: .pem))) - let context = try! NIOSSLContext(configuration: configuration) - return channel.pipeline.addHandler(NIOSSLServerHandler(context: context), position: .first) - } + private let mode: Mode + private let sslContext: NIOSSLContext? + private var serverChannel: Channel! + private let isShutdown: NIOAtomic = .makeAtomic(value: false) + private let handlerFactory: (Int) -> (RequestHandler) + + init( + _ mode: Mode = .http1_1(ssl: false, compress: false), + proxy: Proxy = .none, + bindTarget: BindTarget = .localhostIPv4RandomPort, + handlerFactory: @escaping (Int) -> (RequestHandler) + ) { + self.mode = mode + self.sslContext = HTTPBin.sslContext(for: mode) + self.handlerFactory = handlerFactory - init(ssl: Bool = false, - compress: Bool = false, - bindTarget: BindTarget = .localhostIPv4RandomPort, - simulateProxy: HTTPProxySimulator.Option? = nil, - channelPromise: EventLoopPromise? = nil, - connectionDelay: TimeAmount = .seconds(0), - maxChannelAge: TimeAmount? = nil, - refusesConnections: Bool = false) { let socketAddress: SocketAddress switch bindTarget { case .localhostIPv4RandomPort: @@ -314,47 +345,197 @@ internal final class HTTPBin { let activeConnCounterHandler = CountActiveConnectionsHandler() self.activeConnCounterHandler = activeConnCounterHandler - let connections = NIOAtomic.makeAtomic(value: 0) - self.connections = connections + let connectionIDAtomic = NIOAtomic.makeAtomic(value: 0) self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .serverChannelInitializer { channel in channel.pipeline.addHandler(activeConnCounterHandler) }.childChannelInitializer { channel in - guard !refusesConnections else { - return channel.eventLoop.makeFailedFuture(HTTPBinError.refusedConnection) - } - return channel.eventLoop.scheduleTask(in: connectionDelay) {}.futureResult.flatMap { - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { - if compress { - return channel.pipeline.addHandler(HTTPResponseCompressor()) - } else { - return channel.eventLoop.makeSucceededFuture(()) - } + do { + let connectionID = connectionIDAtomic.add(1) + + if case .refuse = mode { + throw HTTPBinError.refusedConnection } - .flatMap { - if let simulateProxy = simulateProxy { - let responseEncoder = HTTPResponseEncoder() - let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - - return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first) - } else { - return channel.eventLoop.makeSucceededFuture(()) - } - }.flatMap { - if ssl { - return HTTPBin.configureTLS(channel: channel).flatMap { - channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1))) - } - } else { - return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1))) - } + + // if we need to simulate a proxy, we need to add those handlers first + if case .simulate(authorization: let expectedAuthorization) = proxy { + try self.syncAddHTTPProxyHandlers( + to: channel, + connectionID: connectionID, + expectedAuthroization: expectedAuthorization + ) + return channel.eventLoop.makeSucceededVoidFuture() } + + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try self.addTLSHandlerAndUpgrader( + to: channel, + sslContext: sslContext, + connectionID: connectionID + ) + return channel.eventLoop.makeSucceededVoidFuture() + } + + // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly + try self.syncAddHTTP1Handlers(to: channel, connectionID: connectionID) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) } }.bind(to: socketAddress).wait() } + private func syncAddHTTPProxyHandlers( + to channel: Channel, + connectionID: Int, + expectedAuthroization: String? + ) throws { + let sync = channel.pipeline.syncOperations + let promise = channel.eventLoop.makePromise(of: Void.self) + + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuhorization: expectedAuthroization) + + try sync.addHandler(responseEncoder) + try sync.addHandler(requestDecoder) + try sync.addHandler(proxySimulator) + + promise.futureResult.flatMap { _ in + channel.pipeline.removeHandler(proxySimulator) + }.flatMap { _ in + channel.pipeline.removeHandler(responseEncoder) + }.flatMap { _ in + channel.pipeline.removeHandler(requestDecoder) + }.whenComplete { result in + switch result { + case .failure: + channel.close(mode: .all, promise: nil) + case .success: + self.httpProxyEstablished(channel, connectionID: connectionID) + } + } + } + + func syncAddHTTP1Handlers(to channel: Channel, connectionID: Int) throws { + let sync = channel.pipeline.syncOperations + try sync.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true) + + if self.mode.compress { + try sync.addHandler(HTTPResponseCompressor()) + } + + try sync.addHandler(self.handlerFactory(connectionID)) + } + + private func httpProxyEstablished(_ channel: Channel, connectionID: Int) { + do { + // if a connection has been established, we need to negotiate TLS before + // anything else. Depending on the negotiation, the HTTPHandlers will be added. + if let sslContext = self.sslContext { + try self.addTLSHandlerAndUpgrader( + to: channel, + sslContext: sslContext, + connectionID: connectionID + ) + return + } + + // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers directly + try self.syncAddHTTP1Handlers(to: channel, connectionID: connectionID) + } catch { + // in case of an while modifying the pipeline we should close the connection + channel.close(mode: .all, promise: nil) + } + } + + private static func tlsConfiguration(for mode: Mode) -> TLSConfiguration? { + var configuration: TLSConfiguration? + + switch mode { + case .refuse, .http1_1(ssl: false, compress: _): + break + case .http2: + configuration = .makeServerConfiguration( + certificateChain: [.certificate(TestTLS.certificate)], + privateKey: .privateKey(TestTLS.privateKey) + ) + configuration!.applicationProtocols = NIOHTTP2SupportedALPNProtocols + case .http1_1(ssl: true, compress: _): + configuration = .makeServerConfiguration( + certificateChain: [.certificate(TestTLS.certificate)], + privateKey: .privateKey(TestTLS.privateKey) + ) + } + + return configuration + } + + private static func sslContext(for mode: Mode) -> NIOSSLContext? { + if let tlsConfiguration = self.tlsConfiguration(for: mode) { + return try! NIOSSLContext(configuration: tlsConfiguration) + } + return nil + } + + private func addTLSHandlerAndUpgrader(to channel: Channel, sslContext: NIOSSLContext, connectionID: Int) throws { + let sslHandler = NIOSSLServerHandler(context: sslContext) + + // copy pasted from NIOHTTP2 + let alpnHandler = ApplicationProtocolNegotiationHandler { result in + do { + switch result { + case .negotiated("h2"): + // Successful upgrade to HTTP/2. Let the user configure the pipeline. + let http2Handler = NIOHTTP2Handler( + mode: .server, + initialSettings: NIOHTTP2.nioDefaultSettings + ) + let multiplexer = HTTP2StreamMultiplexer( + mode: .server, + channel: channel, + inboundStreamInitializer: { channel in + do { + let sync = channel.pipeline.syncOperations + + try sync.addHandler(HTTP2FramePayloadToHTTP1ServerCodec()) + try sync.addHandler(self.handlerFactory(connectionID)) + + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } + ) + + let sync = channel.pipeline.syncOperations + + try sync.addHandler(http2Handler) + try sync.addHandler(multiplexer) + case .negotiated("http/1.1"), .fallback: + // Explicit or implicit HTTP/1.1 choice. + try self.syncAddHTTP1Handlers(to: channel, connectionID: connectionID) + case .negotiated: + // We negotiated something that isn't HTTP/1.1. This is a bad scene, and is a good indication + // of a user configuration error. We're going to close the connection directly. + channel.close(mode: .all, promise: nil) + throw NIOHTTP2Errors.invalidALPNToken() + } + + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } + + try channel.pipeline.syncOperations.addHandler(alpnHandler) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(alpnHandler)) + } + func shutdown() throws { self.isShutdown.store(true) try self.group.syncShutdownGracefully() @@ -365,8 +546,19 @@ internal final class HTTPBin { } } +extension HTTPBin where RequestHandler == HTTPBinHandler { + convenience init( + _ mode: Mode = .http1_1(ssl: false, compress: false), + proxy: Proxy = .none, + bindTarget: BindTarget = .localhostIPv4RandomPort + ) { + self.init(mode, proxy: proxy, bindTarget: bindTarget) { HTTPBinHandler(connectionID: $0) } + } +} + enum HTTPBinError: Error { case refusedConnection + case invalidProxyRequest } final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { @@ -374,21 +566,16 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { typealias InboundOut = HTTPServerResponsePart typealias OutboundOut = HTTPServerResponsePart - enum Option { - case plaintext - case tls - } + // the promise to succeed, once the proxy connection is setup + let promise: EventLoopPromise + let expectedAuhorization: String? - let option: Option - let encoder: HTTPResponseEncoder - let decoder: ByteToMessageHandler var head: HTTPResponseHead - init(option: Option, encoder: HTTPResponseEncoder, decoder: ByteToMessageHandler) { - self.option = option - self.encoder = encoder - self.decoder = decoder - self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0"), ("Connection", "close")])) + init(promise: EventLoopPromise, expectedAuhorization: String?) { + self.promise = promise + self.expectedAuhorization = expectedAuhorization + self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -396,27 +583,28 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { switch request { case .head(let head): guard head.method == .CONNECT else { - fatalError("Expected a CONNECT request") + self.head.status = .badRequest + return } - if head.headers.contains(name: "proxy-authorization") { - if head.headers["proxy-authorization"].first != "Basic YWxhZGRpbjpvcGVuc2VzYW1l" { + + if let expectedAuhorization = self.expectedAuhorization { + guard let authorization = head.headers["proxy-authorization"].first, + expectedAuhorization == authorization else { self.head.status = .proxyAuthenticationRequired + return } } + case .body: () case .end: + let okay = self.head.status != .ok context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - - context.channel.pipeline.removeHandler(self, promise: nil) - context.channel.pipeline.removeHandler(self.decoder, promise: nil) - context.channel.pipeline.removeHandler(self.encoder, promise: nil) - - switch self.option { - case .tls: - _ = HTTPBin.configureTLS(channel: context.channel) - case .plaintext: break + if okay { + self.promise.fail(HTTPBinError.invalidProxyRequest) + } else { + self.promise.succeed(()) } } } @@ -447,37 +635,21 @@ internal struct RequestInfo: Codable { var connectionNumber: Int } -internal final class HttpBinHandler: ChannelInboundHandler { +internal final class HTTPBinHandler: ChannelInboundHandler { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart - let channelPromise: EventLoopPromise? var resps = CircularBuffer() var responseHeaders = HTTPHeaders() var delay: TimeAmount = .seconds(0) let creationDate = Date() - let maxChannelAge: TimeAmount? var shouldClose = false var isServingRequest = false - let connectionId: Int + let connectionID: Int var requestId: Int = 0 - init(channelPromise: EventLoopPromise? = nil, maxChannelAge: TimeAmount? = nil, connectionId: Int) { - self.channelPromise = channelPromise - self.maxChannelAge = maxChannelAge - self.connectionId = connectionId - } - - func handlerAdded(context: ChannelHandlerContext) { - if let maxChannelAge = self.maxChannelAge { - context.eventLoop.scheduleTask(in: maxChannelAge) { - if !self.isServingRequest { - context.close(promise: nil) - } else { - self.shouldClose = true - } - } - } + init(connectionID: Int) { + self.connectionID = connectionID } func parseAndSetOptions(from head: HTTPRequestHead) { @@ -675,7 +847,6 @@ internal final class HttpBinHandler: ChannelInboundHandler { response.add(body) self.resps.prepend(response) case .end: - self.channelPromise?.succeed(context.channel) if self.resps.isEmpty { return } @@ -685,14 +856,14 @@ internal final class HttpBinHandler: ChannelInboundHandler { if let body = response.body { let requestInfo = RequestInfo(data: String(buffer: body), requestNumber: self.requestId, - connectionNumber: self.connectionId) + connectionNumber: self.connectionID) let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, allocator: context.channel.allocator) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } else { let requestInfo = RequestInfo(data: "", requestNumber: self.requestId, - connectionNumber: self.connectionId) + connectionNumber: self.connectionID) let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, allocator: context.channel.allocator) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 49cda659c..e72fa3f67 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -32,7 +32,7 @@ class HTTPClientTests: XCTestCase { var clientGroup: EventLoopGroup! var serverGroup: EventLoopGroup! - var defaultHTTPBin: HTTPBin! + var defaultHTTPBin: HTTPBin! var defaultClient: HTTPClient! var backgroundLogStore: CollectEverythingLogHandler.LogStore! @@ -249,7 +249,7 @@ class HTTPClientTests: XCTestCase { func testConvenienceExecuteMethodsOverSecureSocket() throws { XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(path)) + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true, compress: false), bindTarget: .unixDomainSocket(path)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -288,7 +288,7 @@ class HTTPClientTests: XCTestCase { } func testGetHttps() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -301,7 +301,7 @@ class HTTPClientTests: XCTestCase { } func testGetHttpsWithIP() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -320,7 +320,7 @@ class HTTPClientTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -333,7 +333,7 @@ class HTTPClientTests: XCTestCase { } func testPostHttps() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -352,7 +352,7 @@ class HTTPClientTests: XCTestCase { } func testHttpRedirect() throws { - let httpsBin = HTTPBin(ssl: true) + let httpsBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) @@ -370,7 +370,7 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) - let socketHTTPSBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(httpsSocketPath)) + let socketHTTPSBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(httpsSocketPath)) defer { XCTAssertNoThrow(try socketHTTPBin.shutdown()) XCTAssertNoThrow(try socketHTTPSBin.shutdown()) @@ -647,7 +647,7 @@ class HTTPClientTests: XCTestCase { } func testProxyPlaintext() throws { - let localHTTPBin = HTTPBin(simulateProxy: .plaintext) + let localHTTPBin = HTTPBin(proxy: .simulate(authorization: nil)) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port)) @@ -661,7 +661,7 @@ class HTTPClientTests: XCTestCase { } func testProxyTLS() throws { - let localHTTPBin = HTTPBin(simulateProxy: .tls) + let localHTTPBin = HTTPBin(.http1_1(ssl: true), proxy: .simulate(authorization: nil)) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init( @@ -678,7 +678,7 @@ class HTTPClientTests: XCTestCase { } func testProxyPlaintextWithCorrectlyAuthorization() throws { - let localHTTPBin = HTTPBin(simulateProxy: .plaintext) + let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) @@ -692,7 +692,7 @@ class HTTPClientTests: XCTestCase { } func testProxyPlaintextWithIncorrectlyAuthorization() throws { - let localHTTPBin = HTTPBin(simulateProxy: .plaintext) + let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port, @@ -924,7 +924,7 @@ class HTTPClientTests: XCTestCase { } func testDecompression() throws { - let localHTTPBin = HTTPBin(compress: true) + let localHTTPBin = HTTPBin(.http1_1(compress: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .none))) @@ -961,7 +961,7 @@ class HTTPClientTests: XCTestCase { } func testDecompressionLimit() throws { - let localHTTPBin = HTTPBin(compress: true) + let localHTTPBin = HTTPBin(.http1_1(compress: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .ratio(1)))) defer { @@ -982,7 +982,7 @@ class HTTPClientTests: XCTestCase { } func testLoopDetectionRedirectLimit() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false))) @@ -997,7 +997,7 @@ class HTTPClientTests: XCTestCase { } func testCountRedirectLimit() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) @@ -1196,7 +1196,7 @@ class HTTPClientTests: XCTestCase { } func testStressGetHttps() throws { - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -1252,7 +1252,7 @@ class HTTPClientTests: XCTestCase { } func testFailingConnectionIsReleased() { - let localHTTPBin = HTTPBin(refusesConnections: true) + let localHTTPBin = HTTPBin(.refuse) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -1683,7 +1683,7 @@ class HTTPClientTests: XCTestCase { func testHTTPSPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(path)) + let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { @@ -2318,7 +2318,7 @@ class HTTPClientTests: XCTestCase { }) backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(path)) + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none), backgroundActivityLogger: backgroundLogger) @@ -2417,7 +2417,7 @@ class HTTPClientTests: XCTestCase { }) backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(path)) + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none), backgroundActivityLogger: backgroundLogger) @@ -2862,7 +2862,7 @@ class HTTPClientTests: XCTestCase { tlsConfig.minimumTLSVersion = .tlsv13 tlsConfig.maximumTLSVersion = .tlsv12 tlsConfig.certificateVerification = .none - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig)) defer { @@ -2943,7 +2943,7 @@ class HTTPClientTests: XCTestCase { timeout: .init(), ignoreUncleanSSLShutdown: false, decompression: .disabled) - let localHTTPBin = HTTPBin(ssl: true) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: configuration) let decoder = JSONDecoder() From c4c7212f03740b341de1c2290caef64d5851c839 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Thu, 24 Jun 2021 20:21:03 +0200 Subject: [PATCH 2/3] Code review --- Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index edd064e4b..5e75d5cc6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -342,15 +342,14 @@ internal final class HTTPBin where socketAddress = try! SocketAddress(unixDomainSocketPath: path) } - let activeConnCounterHandler = CountActiveConnectionsHandler() - self.activeConnCounterHandler = activeConnCounterHandler + self.activeConnCounterHandler = CountActiveConnectionsHandler() let connectionIDAtomic = NIOAtomic.makeAtomic(value: 0) self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .serverChannelInitializer { channel in - channel.pipeline.addHandler(activeConnCounterHandler) + channel.pipeline.addHandler(self.activeConnCounterHandler) }.childChannelInitializer { channel in do { let connectionID = connectionIDAtomic.add(1) @@ -364,7 +363,7 @@ internal final class HTTPBin where try self.syncAddHTTPProxyHandlers( to: channel, connectionID: connectionID, - expectedAuthroization: expectedAuthorization + expectedAuthorization: expectedAuthorization ) return channel.eventLoop.makeSucceededVoidFuture() } @@ -392,14 +391,14 @@ internal final class HTTPBin where private func syncAddHTTPProxyHandlers( to channel: Channel, connectionID: Int, - expectedAuthroization: String? + expectedAuthorization: String? ) throws { let sync = channel.pipeline.syncOperations let promise = channel.eventLoop.makePromise(of: Void.self) let responseEncoder = HTTPResponseEncoder() let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuhorization: expectedAuthroization) + let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuhorization: expectedAuthorization) try sync.addHandler(responseEncoder) try sync.addHandler(requestDecoder) From 3bf080c6f9dfe328306816ef87786377d33b5b3f Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Fri, 25 Jun 2021 11:44:18 +0200 Subject: [PATCH 3/3] Update Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift Co-authored-by: David Evans --- Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 9f170217c..807f111f0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -425,7 +425,6 @@ class HTTPClientInternalTests: XCTestCase { // We need to wait for channel options that limit NIO to sending only one byte at a time. try delegate.optionsApplied.futureResult.wait() - // // Send 4 bytes, but only one should be received until the backpressure promise is succeeded. // Now we wait until message is delivered to client channel pipeline