diff --git a/Package.swift b/Package.swift index 7078fb4d9..f24f15d28 100644 --- a/Package.swift +++ b/Package.swift @@ -23,11 +23,12 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-nio.git", from: "2.8.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.3.0"), ], targets: [ .target( name: "AsyncHTTPClient", - dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers"] + dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers", "NIOHTTPCompression"] ), .testTarget( name: "AsyncHTTPClientTests", diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 76f59c9ac..4e1f1dd18 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -16,6 +16,7 @@ import Foundation import NIO import NIOConcurrencyHelpers import NIOHTTP1 +import NIOHTTPCompression import NIOSSL /// HTTPClient class provides API for request execution. @@ -252,6 +253,13 @@ public class HTTPClient { case .some(let proxy): return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy) } + }.flatMap { + switch self.configuration.decompression { + case .disabled: + return channel.eventLoop.makeSucceededFuture(()) + case .enabled(let limit): + return channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: limit)) + } }.flatMap { if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) { return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout)) @@ -322,31 +330,37 @@ public class HTTPClient { public var timeout: Timeout /// Upstream proxy, defaults to no proxy. public var proxy: Proxy? + /// Enables automatic body decompression. Supported algorithms are gzip and deflate. + public var decompression: Decompression /// Ignore TLS unclean shutdown error, defaults to `false`. public var ignoreUncleanSSLShutdown: Bool - public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { - self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false) - } - - public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) { + public init(tlsConfiguration: TLSConfiguration? = nil, + followRedirects: Bool = false, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled) { self.tlsConfiguration = tlsConfiguration self.followRedirects = followRedirects self.timeout = timeout self.proxy = proxy self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown + self.decompression = decompression } - public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { - self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false) - } - - public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) { + public init(certificateVerification: CertificateVerification, + followRedirects: Bool = false, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled) { self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification) self.followRedirects = followRedirects self.timeout = timeout self.proxy = proxy self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown + self.decompression = decompression } } @@ -403,6 +417,14 @@ public class HTTPClient { return EventLoopPreference(.delegateAndChannel(on: eventLoop)) } } + + /// Specifies decompression settings. + public enum Decompression { + /// Decompression is disabled. + case disabled + /// Decompression is enabled. + case enabled(limit: NIOHTTPDecompression.DecompressionLimit) + } } extension HTTPClient.Configuration { @@ -508,6 +530,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse) /// Request does not contain `Content-Length` header. public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing) - /// Proxy Authentication Required + /// Proxy Authentication Required. public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 3913389a1..990115f92 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -447,7 +447,7 @@ extension HTTPClient { public let eventLoop: EventLoop let promise: EventLoopPromise - private var channel: Channel? + var channel: Channel? private var cancelled: Bool private let lock: Lock @@ -677,7 +677,7 @@ extension TaskHandler: ChannelDuplexHandler { } func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let response = unwrapInboundIn(data) + let response = self.unwrapInboundIn(data) switch response { case .head(let head): if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index f75bbb087..40ebad75e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -155,17 +155,29 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) } + // In order to test backpressure we need to make sure that reads will not happen + // until the backpressure promise is succeeded. Since we cannot guarantee when + // messages will be delivered to a client pipeline and we need this test to be + // fast (no waiting for arbitrary amounts of time), we do the following. + // First, we enforce NIO to send us only 1 byte at a time. Then we send a message + // of 4 bytes. This will guarantee that if we see first byte of the message, other + // bytes a ready to be read as well. This will allow us to test if subsequent reads + // are waiting for backpressure promise. func testUploadStreamingBackpressure() throws { class BackpressureTestDelegate: HTTPClientResponseDelegate { typealias Response = Void var _reads = 0 let lock: Lock - let promise: EventLoopPromise + let backpressurePromise: EventLoopPromise + let optionsApplied: EventLoopPromise + let messageReceived: EventLoopPromise - init(promise: EventLoopPromise) { + init(eventLoop: EventLoop) { self.lock = Lock() - self.promise = promise + self.backpressurePromise = eventLoop.makePromise() + self.optionsApplied = eventLoop.makePromise() + self.messageReceived = eventLoop.makePromise() } var reads: Int { @@ -174,18 +186,30 @@ class HTTPClientInternalTests: XCTestCase { } } + func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { + // This is to force NIO to send only 1 byte at a time. + let future = task.channel!.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap { + task.channel!.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) + } + future.cascade(to: self.optionsApplied) + return future + } + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + // We count a number of reads received. self.lock.withLockVoid { self._reads += 1 } - return self.promise.futureResult + // We need to notify the test when first byte of the message is arrived. + self.messageReceived.succeed(()) + return self.backpressurePromise.futureResult } func didFinishRequest(task: HTTPClient.Task) throws {} } let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) - let promise: EventLoopPromise = httpClient.eventLoopGroup.next().makePromise() + let promise = httpClient.eventLoopGroup.next().makePromise(of: Channel.self) let httpBin = HTTPBin(channelPromise: promise) defer { @@ -194,25 +218,29 @@ class HTTPClientInternalTests: XCTestCase { } let request = try Request(url: "http://localhost:\(httpBin.port)/custom") - let delegate = BackpressureTestDelegate(promise: httpClient.eventLoopGroup.next().makePromise()) + let delegate = BackpressureTestDelegate(eventLoop: httpClient.eventLoopGroup.next()) let future = httpClient.execute(request: request, delegate: delegate).futureResult let channel = try promise.futureResult.wait() + // We need to wait for channel options that limit NIO to sending only one byte at a time. + try delegate.optionsApplied.futureResult.wait() - // Send 3 parts, but only one should be received until the future is complete + // Send 4 bytes, but only one should be received until the backpressure promise is succeeded. let buffer = ByteBuffer.of(string: "1234") try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() - try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() - try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() + // Now we wait until message is delivered to client channel pipeline + try delegate.messageReceived.futureResult.wait() XCTAssertEqual(delegate.reads, 1) - delegate.promise.succeed(()) + // Succeed the backpressure promise. + delegate.backpressurePromise.succeed(()) try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait() try future.wait() - XCTAssertEqual(delegate.reads, 3) + // At this point all other bytes should be delivered. + XCTAssertEqual(delegate.reads, 4) } func testRequestURITrailingSlash() throws { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 9dd207079..c15cf6bfb 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -17,6 +17,7 @@ import Foundation import NIO import NIOConcurrencyHelpers import NIOHTTP1 +import NIOHTTPCompression import NIOSSL class TestHTTPDelegate: HTTPClientResponseDelegate { @@ -111,30 +112,40 @@ internal final class HTTPBin { return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first) } - init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise? = nil) { + init(ssl: Bool = false, compress: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise? = nil) { self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { - if let simulateProxy = simulateProxy { - let responseEncoder = HTTPResponseEncoder() - let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - - return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first) - } else { - return channel.eventLoop.makeSucceededFuture(()) + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true) + .flatMap { + if compress { + return channel.pipeline.addHandler(HTTPResponseCompressor()) + } else { + return channel.eventLoop.makeSucceededFuture(()) + } } - }.flatMap { - if ssl { - return HTTPBin.configureTLS(channel: channel).flatMap { - channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) + .flatMap { + if let simulateProxy = simulateProxy { + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + + return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first) + } else { + return channel.eventLoop.makeSucceededFuture(()) } - } else { - return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) } - } - }.bind(host: "127.0.0.1", port: 0).wait() + .flatMap { + if ssl { + return HTTPBin.configureTLS(channel: channel).flatMap { + channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) + } + } else { + return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) + } + } + } + .bind(host: "127.0.0.1", port: 0).wait() } func shutdown() throws { @@ -295,7 +306,7 @@ internal final class HttpBinHandler: ChannelInboundHandler { context.close(promise: nil) return case "/custom": - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) + context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) return case "/events/10/1": // TODO: parse path context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) @@ -461,6 +472,12 @@ extension ByteBuffer { buffer.writeString(string) return buffer } + + public static func of(bytes: [UInt8]) -> ByteBuffer { + var buffer = ByteBufferAllocator().buffer(capacity: bytes.count) + buffer.writeBytes(bytes) + return buffer + } } private let cert = """ diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index a06c148f1..8979d7a3d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -57,6 +57,8 @@ extension HTTPClientTests { ("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown), ("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown), ("testEventLoopArgument", testEventLoopArgument), + ("testDecompression", testDecompression), + ("testDecompressionLimit", testDecompressionLimit), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index fb8a35fe4..6fa219cda 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -16,6 +16,7 @@ import AsyncHTTPClient import NIO import NIOFoundationCompat import NIOHTTP1 +import NIOHTTPCompression import NIOSSL import XCTest @@ -563,4 +564,66 @@ class HTTPClientTests: XCTestCase { response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() XCTAssertEqual(true, response) } + + func testDecompression() throws { + let httpBin = HTTPBin(compress: true) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .none))) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + var body = "" + for _ in 1...1000 { + body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + } + + for algorithm in [nil, "gzip", "deflate"] { + var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST) + request.body = .string(body) + if let algorithm = algorithm { + request.headers.add(name: "Accept-Encoding", value: algorithm) + } + + let response = try httpClient.execute(request: request).wait() + let bytes = response.body!.getData(at: 0, length: response.body!.readableBytes)! + let data = try JSONDecoder().decode(RequestInfo.self, from: bytes) + + XCTAssertEqual(.ok, response.status) + XCTAssertGreaterThan(body.count, response.headers["Content-Length"].first.flatMap { Int($0) }!) + if let algorithm = algorithm { + XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first) + } else { + XCTAssertEqual("deflate", response.headers["Content-Encoding"].first) + } + XCTAssertEqual(body, data.data) + } + } + + func testDecompressionLimit() throws { + let httpBin = HTTPBin(compress: true) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .ratio(10)))) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST) + request.body = .byteBuffer(ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])) + request.headers.add(name: "Accept-Encoding", value: "deflate") + + do { + _ = try httpClient.execute(request: request).wait() + } catch let error as NIOHTTPDecompression.DecompressionError { + switch error { + case .limit: + // ok + break + default: + XCTFail("Unexptected error: \(error)") + } + } catch { + XCTFail("Unexptected error: \(error)") + } + } } diff --git a/docker/Dockerfile b/docker/Dockerfile index 4409ea799..7a15bec1e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,7 +13,7 @@ ENV LANGUAGE en_US.UTF-8 # dependencies RUN apt-get update && apt-get install -y wget -RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools # used by integration tests +RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools libz-dev # used by integration tests # ruby and jazzy for docs generation RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev