Skip to content

Commit 191c4ba

Browse files
artemredkinweissi
authored andcommitted
add response decompression support (#86)
fixes #44
1 parent d84a1da commit 191c4ba

File tree

8 files changed

+177
-44
lines changed

8 files changed

+177
-44
lines changed

Diff for: Package.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ let package = Package(
2323
dependencies: [
2424
.package(url: "https://github.com/apple/swift-nio.git", from: "2.8.0"),
2525
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"),
26+
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.3.0"),
2627
],
2728
targets: [
2829
.target(
2930
name: "AsyncHTTPClient",
30-
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers"]
31+
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers", "NIOHTTPCompression"]
3132
),
3233
.testTarget(
3334
name: "AsyncHTTPClientTests",

Diff for: Sources/AsyncHTTPClient/HTTPClient.swift

+33-11
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import Foundation
1616
import NIO
1717
import NIOConcurrencyHelpers
1818
import NIOHTTP1
19+
import NIOHTTPCompression
1920
import NIOSSL
2021

2122
/// HTTPClient class provides API for request execution.
@@ -252,6 +253,13 @@ public class HTTPClient {
252253
case .some(let proxy):
253254
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy)
254255
}
256+
}.flatMap {
257+
switch self.configuration.decompression {
258+
case .disabled:
259+
return channel.eventLoop.makeSucceededFuture(())
260+
case .enabled(let limit):
261+
return channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: limit))
262+
}
255263
}.flatMap {
256264
if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) {
257265
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout))
@@ -322,31 +330,37 @@ public class HTTPClient {
322330
public var timeout: Timeout
323331
/// Upstream proxy, defaults to no proxy.
324332
public var proxy: Proxy?
333+
/// Enables automatic body decompression. Supported algorithms are gzip and deflate.
334+
public var decompression: Decompression
325335
/// Ignore TLS unclean shutdown error, defaults to `false`.
326336
public var ignoreUncleanSSLShutdown: Bool
327337

328-
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
329-
self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
330-
}
331-
332-
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
338+
public init(tlsConfiguration: TLSConfiguration? = nil,
339+
followRedirects: Bool = false,
340+
timeout: Timeout = Timeout(),
341+
proxy: Proxy? = nil,
342+
ignoreUncleanSSLShutdown: Bool = false,
343+
decompression: Decompression = .disabled) {
333344
self.tlsConfiguration = tlsConfiguration
334345
self.followRedirects = followRedirects
335346
self.timeout = timeout
336347
self.proxy = proxy
337348
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
349+
self.decompression = decompression
338350
}
339351

340-
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
341-
self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
342-
}
343-
344-
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
352+
public init(certificateVerification: CertificateVerification,
353+
followRedirects: Bool = false,
354+
timeout: Timeout = Timeout(),
355+
proxy: Proxy? = nil,
356+
ignoreUncleanSSLShutdown: Bool = false,
357+
decompression: Decompression = .disabled) {
345358
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
346359
self.followRedirects = followRedirects
347360
self.timeout = timeout
348361
self.proxy = proxy
349362
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
363+
self.decompression = decompression
350364
}
351365
}
352366

@@ -403,6 +417,14 @@ public class HTTPClient {
403417
return EventLoopPreference(.delegateAndChannel(on: eventLoop))
404418
}
405419
}
420+
421+
/// Specifies decompression settings.
422+
public enum Decompression {
423+
/// Decompression is disabled.
424+
case disabled
425+
/// Decompression is enabled.
426+
case enabled(limit: NIOHTTPDecompression.DecompressionLimit)
427+
}
406428
}
407429

408430
extension HTTPClient.Configuration {
@@ -508,6 +530,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
508530
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
509531
/// Request does not contain `Content-Length` header.
510532
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
511-
/// Proxy Authentication Required
533+
/// Proxy Authentication Required.
512534
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
513535
}

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ extension HTTPClient {
447447
public let eventLoop: EventLoop
448448

449449
let promise: EventLoopPromise<Response>
450-
private var channel: Channel?
450+
var channel: Channel?
451451
private var cancelled: Bool
452452
private let lock: Lock
453453

@@ -677,7 +677,7 @@ extension TaskHandler: ChannelDuplexHandler {
677677
}
678678

679679
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
680-
let response = unwrapInboundIn(data)
680+
let response = self.unwrapInboundIn(data)
681681
switch response {
682682
case .head(let head):
683683
if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+39-11
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,29 @@ class HTTPClientInternalTests: XCTestCase {
155155
XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait())
156156
}
157157

158+
// In order to test backpressure we need to make sure that reads will not happen
159+
// until the backpressure promise is succeeded. Since we cannot guarantee when
160+
// messages will be delivered to a client pipeline and we need this test to be
161+
// fast (no waiting for arbitrary amounts of time), we do the following.
162+
// First, we enforce NIO to send us only 1 byte at a time. Then we send a message
163+
// of 4 bytes. This will guarantee that if we see first byte of the message, other
164+
// bytes a ready to be read as well. This will allow us to test if subsequent reads
165+
// are waiting for backpressure promise.
158166
func testUploadStreamingBackpressure() throws {
159167
class BackpressureTestDelegate: HTTPClientResponseDelegate {
160168
typealias Response = Void
161169

162170
var _reads = 0
163171
let lock: Lock
164-
let promise: EventLoopPromise<Void>
172+
let backpressurePromise: EventLoopPromise<Void>
173+
let optionsApplied: EventLoopPromise<Void>
174+
let messageReceived: EventLoopPromise<Void>
165175

166-
init(promise: EventLoopPromise<Void>) {
176+
init(eventLoop: EventLoop) {
167177
self.lock = Lock()
168-
self.promise = promise
178+
self.backpressurePromise = eventLoop.makePromise()
179+
self.optionsApplied = eventLoop.makePromise()
180+
self.messageReceived = eventLoop.makePromise()
169181
}
170182

171183
var reads: Int {
@@ -174,18 +186,30 @@ class HTTPClientInternalTests: XCTestCase {
174186
}
175187
}
176188

189+
func didReceiveHead(task: HTTPClient.Task<Void>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
190+
// This is to force NIO to send only 1 byte at a time.
191+
let future = task.channel!.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap {
192+
task.channel!.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1))
193+
}
194+
future.cascade(to: self.optionsApplied)
195+
return future
196+
}
197+
177198
func didReceiveBodyPart(task: HTTPClient.Task<Response>, _ buffer: ByteBuffer) -> EventLoopFuture<Void> {
199+
// We count a number of reads received.
178200
self.lock.withLockVoid {
179201
self._reads += 1
180202
}
181-
return self.promise.futureResult
203+
// We need to notify the test when first byte of the message is arrived.
204+
self.messageReceived.succeed(())
205+
return self.backpressurePromise.futureResult
182206
}
183207

184208
func didFinishRequest(task: HTTPClient.Task<Response>) throws {}
185209
}
186210

187211
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
188-
let promise: EventLoopPromise<Channel> = httpClient.eventLoopGroup.next().makePromise()
212+
let promise = httpClient.eventLoopGroup.next().makePromise(of: Channel.self)
189213
let httpBin = HTTPBin(channelPromise: promise)
190214

191215
defer {
@@ -194,25 +218,29 @@ class HTTPClientInternalTests: XCTestCase {
194218
}
195219

196220
let request = try Request(url: "http://localhost:\(httpBin.port)/custom")
197-
let delegate = BackpressureTestDelegate(promise: httpClient.eventLoopGroup.next().makePromise())
221+
let delegate = BackpressureTestDelegate(eventLoop: httpClient.eventLoopGroup.next())
198222
let future = httpClient.execute(request: request, delegate: delegate).futureResult
199223

200224
let channel = try promise.futureResult.wait()
225+
// We need to wait for channel options that limit NIO to sending only one byte at a time.
226+
try delegate.optionsApplied.futureResult.wait()
201227

202-
// Send 3 parts, but only one should be received until the future is complete
228+
// Send 4 bytes, but only one should be received until the backpressure promise is succeeded.
203229
let buffer = ByteBuffer.of(string: "1234")
204230
try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait()
205-
try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait()
206-
try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait()
207231

232+
// Now we wait until message is delivered to client channel pipeline
233+
try delegate.messageReceived.futureResult.wait()
208234
XCTAssertEqual(delegate.reads, 1)
209235

210-
delegate.promise.succeed(())
236+
// Succeed the backpressure promise.
237+
delegate.backpressurePromise.succeed(())
211238

212239
try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()
213240
try future.wait()
214241

215-
XCTAssertEqual(delegate.reads, 3)
242+
// At this point all other bytes should be delivered.
243+
XCTAssertEqual(delegate.reads, 4)
216244
}
217245

218246
func testRequestURITrailingSlash() throws {

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+35-18
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Foundation
1717
import NIO
1818
import NIOConcurrencyHelpers
1919
import NIOHTTP1
20+
import NIOHTTPCompression
2021
import NIOSSL
2122

2223
class TestHTTPDelegate: HTTPClientResponseDelegate {
@@ -111,30 +112,40 @@ internal final class HTTPBin {
111112
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first)
112113
}
113114

114-
init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise<Channel>? = nil) {
115+
init(ssl: Bool = false, compress: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise<Channel>? = nil) {
115116
self.serverChannel = try! ServerBootstrap(group: self.group)
116117
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
117118
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
118119
.childChannelInitializer { channel in
119-
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
120-
if let simulateProxy = simulateProxy {
121-
let responseEncoder = HTTPResponseEncoder()
122-
let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
123-
124-
return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first)
125-
} else {
126-
return channel.eventLoop.makeSucceededFuture(())
120+
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true)
121+
.flatMap {
122+
if compress {
123+
return channel.pipeline.addHandler(HTTPResponseCompressor())
124+
} else {
125+
return channel.eventLoop.makeSucceededFuture(())
126+
}
127127
}
128-
}.flatMap {
129-
if ssl {
130-
return HTTPBin.configureTLS(channel: channel).flatMap {
131-
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
128+
.flatMap {
129+
if let simulateProxy = simulateProxy {
130+
let responseEncoder = HTTPResponseEncoder()
131+
let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
132+
133+
return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first)
134+
} else {
135+
return channel.eventLoop.makeSucceededFuture(())
132136
}
133-
} else {
134-
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
135137
}
136-
}
137-
}.bind(host: "127.0.0.1", port: 0).wait()
138+
.flatMap {
139+
if ssl {
140+
return HTTPBin.configureTLS(channel: channel).flatMap {
141+
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
142+
}
143+
} else {
144+
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
145+
}
146+
}
147+
}
148+
.bind(host: "127.0.0.1", port: 0).wait()
138149
}
139150

140151
func shutdown() throws {
@@ -295,7 +306,7 @@ internal final class HttpBinHandler: ChannelInboundHandler {
295306
context.close(promise: nil)
296307
return
297308
case "/custom":
298-
context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil)
309+
context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil)
299310
return
300311
case "/events/10/1": // TODO: parse path
301312
context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil)
@@ -461,6 +472,12 @@ extension ByteBuffer {
461472
buffer.writeString(string)
462473
return buffer
463474
}
475+
476+
public static func of(bytes: [UInt8]) -> ByteBuffer {
477+
var buffer = ByteBufferAllocator().buffer(capacity: bytes.count)
478+
buffer.writeBytes(bytes)
479+
return buffer
480+
}
464481
}
465482

466483
private let cert = """

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ extension HTTPClientTests {
5757
("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown),
5858
("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown),
5959
("testEventLoopArgument", testEventLoopArgument),
60+
("testDecompression", testDecompression),
61+
("testDecompressionLimit", testDecompressionLimit),
6062
]
6163
}
6264
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+63
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import AsyncHTTPClient
1616
import NIO
1717
import NIOFoundationCompat
1818
import NIOHTTP1
19+
import NIOHTTPCompression
1920
import NIOSSL
2021
import XCTest
2122

@@ -563,4 +564,66 @@ class HTTPClientTests: XCTestCase {
563564
response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait()
564565
XCTAssertEqual(true, response)
565566
}
567+
568+
func testDecompression() throws {
569+
let httpBin = HTTPBin(compress: true)
570+
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .none)))
571+
defer {
572+
XCTAssertNoThrow(try httpClient.syncShutdown())
573+
XCTAssertNoThrow(try httpBin.shutdown())
574+
}
575+
576+
var body = ""
577+
for _ in 1...1000 {
578+
body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
579+
}
580+
581+
for algorithm in [nil, "gzip", "deflate"] {
582+
var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST)
583+
request.body = .string(body)
584+
if let algorithm = algorithm {
585+
request.headers.add(name: "Accept-Encoding", value: algorithm)
586+
}
587+
588+
let response = try httpClient.execute(request: request).wait()
589+
let bytes = response.body!.getData(at: 0, length: response.body!.readableBytes)!
590+
let data = try JSONDecoder().decode(RequestInfo.self, from: bytes)
591+
592+
XCTAssertEqual(.ok, response.status)
593+
XCTAssertGreaterThan(body.count, response.headers["Content-Length"].first.flatMap { Int($0) }!)
594+
if let algorithm = algorithm {
595+
XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first)
596+
} else {
597+
XCTAssertEqual("deflate", response.headers["Content-Encoding"].first)
598+
}
599+
XCTAssertEqual(body, data.data)
600+
}
601+
}
602+
603+
func testDecompressionLimit() throws {
604+
let httpBin = HTTPBin(compress: true)
605+
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .ratio(10))))
606+
defer {
607+
XCTAssertNoThrow(try httpClient.syncShutdown())
608+
XCTAssertNoThrow(try httpBin.shutdown())
609+
}
610+
611+
var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST)
612+
request.body = .byteBuffer(ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]))
613+
request.headers.add(name: "Accept-Encoding", value: "deflate")
614+
615+
do {
616+
_ = try httpClient.execute(request: request).wait()
617+
} catch let error as NIOHTTPDecompression.DecompressionError {
618+
switch error {
619+
case .limit:
620+
// ok
621+
break
622+
default:
623+
XCTFail("Unexptected error: \(error)")
624+
}
625+
} catch {
626+
XCTFail("Unexptected error: \(error)")
627+
}
628+
}
566629
}

0 commit comments

Comments
 (0)