Skip to content

add response decompression support #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b0f15a6
add response decompression support
artemredkin Aug 18, 2019
35fa5e0
Merge branch 'master' into support_response_decompression
artemredkin Aug 22, 2019
cc18c6d
Merge branch 'master' into support_response_decompression
artemredkin Sep 16, 2019
88b674a
review fix: add decompression limit
artemredkin Sep 17, 2019
a51fe72
make limit configurable
artemredkin Sep 17, 2019
6d96880
fix missing linux tests
artemredkin Sep 17, 2019
b577f79
fix formatting
artemredkin Sep 17, 2019
18438fe
Merge branch 'master' into support_response_decompression
artemredkin Sep 17, 2019
0782d81
formatting fix after merge
artemredkin Sep 17, 2019
8e8c30d
add docker dependency for zlib
artemredkin Sep 17, 2019
69ec369
review fixes: unset all pointers after use and make inflate methods i…
artemredkin Sep 18, 2019
2891b7d
review fix: re-factor to not use a callback
artemredkin Sep 18, 2019
dbcfd44
review fixes: throw instead of precondition
artemredkin Sep 18, 2019
5d59a66
fix formatting
artemredkin Sep 18, 2019
2f1959f
review fix: flatten compression settings
artemredkin Sep 20, 2019
8a95c2a
Merge branch 'master' into support_response_decompression
artemredkin Sep 20, 2019
0e08fdb
Merge branch 'master' into support_response_decompression
artemredkin Sep 23, 2019
dd83963
use new decompression support from nio-extras
artemredkin Sep 30, 2019
f1848d2
Merge branch 'master' into support_response_decompression
artemredkin Oct 8, 2019
495b27a
remove unused types
artemredkin Oct 8, 2019
4c4ac19
rewrite backpressure test
artemredkin Oct 14, 2019
807761e
rewrite backpressure test
artemredkin Oct 14, 2019
f2c8b09
Merge branch 'support_response_decompression' of github.com:swift-ser…
artemredkin Oct 14, 2019
646d3c7
use real version
artemredkin Oct 16, 2019
1ed16d4
remove commented code
artemredkin Oct 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.0.0"),
],
targets: [
.target(
name: "AsyncHTTPClient",
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers"]
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers", "NIOHTTPCompression"]
),
.testTarget(
name: "AsyncHTTPClientTests",
Expand Down
32 changes: 22 additions & 10 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ public class HTTPClient {
case .some:
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration)
}
}.flatMap {
if self.configuration.decompression {
return channel.pipeline.addHandler(HTTPResponseDecompressor())
} else {
return channel.eventLoop.makeSucceededFuture(())
}
}.flatMap {
if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) {
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout))
Expand Down Expand Up @@ -276,30 +282,36 @@ public class HTTPClient {
public var timeout: Timeout
/// Upstream proxy, defaults to no proxy.
public var proxy: Proxy?
/// Enables automatic body decompression. Supported algorithms are gzip and deflate.
public var decompression: Bool
/// Ignore TLS unclean shutdown error, defaults to `false`.
public var ignoreUncleanSSLShutdown: Bool

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(tlsConfiguration: TLSConfiguration? = nil,
followRedirects: Bool = false,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Bool = false) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
self.decompression = decompression
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(certificateVerification: CertificateVerification,
followRedirects: Bool = false,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Bool = false) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
self.decompression = decompression
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
}
}
Expand Down
140 changes: 140 additions & 0 deletions Sources/AsyncHTTPClient/HTTPDecompressor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2018-2019 Swift Server Working Group and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import CNIOExtrasZlib
import NIO
import NIOHTTP1

private enum CompressionAlgorithm: String {
case gzip
case deflate
}

extension z_stream {
public mutating func inflatePart(input: inout ByteBuffer, allocator: ByteBufferAllocator, consumer: (ByteBuffer) -> Void) {
input.readWithUnsafeMutableReadableBytes { dataPtr in
let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self)
let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, count: dataPtr.count)

self.avail_in = UInt32(typedDataPtr.count)
self.next_in = typedDataPtr.baseAddress!

repeat {
var buffer = allocator.buffer(capacity: 16384)
self.inflatePart(to: &buffer)
consumer(buffer)
} while self.avail_out == 0

return Int(self.avail_in)
}
}

public mutating func inflatePart(to buffer: inout ByteBuffer) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API must not be public, as it is extremely unsafe. Please make it private.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks!

buffer.writeWithUnsafeMutableBytes { outputPtr in
let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), count: outputPtr.count)

self.avail_out = UInt32(typedOutputPtr.count)
self.next_out = typedOutputPtr.baseAddress!

let rc = inflate(&self, Z_NO_FLUSH)
precondition(rc == Z_OK || rc == Z_STREAM_END, "decompression failed: \(rc)")

return typedOutputPtr.count - Int(self.avail_out)
}
}
}

final class HTTPResponseDecompressor: ChannelDuplexHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
typealias InboundOut = HTTPClientResponsePart
typealias OutboundIn = HTTPClientRequestPart
typealias OutboundOut = HTTPClientRequestPart

private enum State {
case empty
case compressed(CompressionAlgorithm, Int)
}

private var state = State.empty
private var stream = z_stream()

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let request = self.unwrapOutboundIn(data)
switch request {
case .head(var head):
if head.headers.contains(name: "Accept-Encoding") {
context.write(data, promise: promise)
} else {
head.headers.replaceOrAdd(name: "Accept-Encoding", value: "deflate, gzip")
context.write(self.wrapOutboundOut(.head(head)), promise: promise)
}
default:
context.write(data, promise: promise)
}
}

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head(let head):
let algorithm: CompressionAlgorithm?
let contentType = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased()
if contentType == "gzip" {
algorithm = .gzip
} else if contentType == "deflate" {
algorithm = .deflate
} else {
algorithm = nil
}

let length = head.headers[canonicalForm: "Content-Length"].first.flatMap { Int($0) }

if let algorithm = algorithm, let length = length {
self.initializeDecoder(encoding: algorithm, length: length)
}

context.fireChannelRead(data)
case .body(var part):
switch self.state {
case .compressed:
self.stream.inflatePart(input: &part, allocator: context.channel.allocator) { output in
context.fireChannelRead(self.wrapInboundOut(.body(output)))
}
default:
context.fireChannelRead(data)
}
case .end:
deflateEnd(&self.stream)
context.fireChannelRead(data)
}
}

private func initializeDecoder(encoding: CompressionAlgorithm, length: Int) {
self.state = .compressed(encoding, length)

self.stream.zalloc = nil
self.stream.zfree = nil
self.stream.opaque = nil

let window: Int32
switch encoding {
case .gzip:
window = 15 + 16
default:
window = 15
}

let rc = CNIOExtrasZlib_inflateInit2(&self.stream, window)
precondition(rc == Z_OK, "Unexpected return from zlib init: \(rc)")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably throw.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, thanks!

}
}
2 changes: 1 addition & 1 deletion Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = unwrapInboundIn(data)
let response = self.unwrapInboundIn(data)
switch response {
case .head(let head):
if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
Expand Down
36 changes: 23 additions & 13 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import AsyncHTTPClient
import Foundation
import NIO
import NIOHTTP1
import NIOHTTPCompression
import NIOSSL

class TestHTTPDelegate: HTTPClientResponseDelegate {
Expand Down Expand Up @@ -103,26 +104,35 @@ internal class HttpBin {
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first)
}

init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise<Channel>? = nil) {
init(ssl: Bool = false, compress: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil, channelPromise: EventLoopPromise<Channel>? = nil) {
self.serverChannel = try! ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
.childChannelInitializer { channel in
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
if let simulateProxy = simulateProxy {
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
} else {
return channel.eventLoop.makeSucceededFuture(())
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true)
.flatMap {
if compress {
return channel.pipeline.addHandler(HTTPResponseCompressor())
} else {
return channel.eventLoop.makeSucceededFuture(())
}
}
}.flatMap {
if ssl {
return HttpBin.configureTLS(channel: channel).flatMap {
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
.flatMap {
if let simulateProxy = simulateProxy {
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
} else {
return channel.eventLoop.makeSucceededFuture(())
}
}
.flatMap {
if ssl {
return HttpBin.configureTLS(channel: channel).flatMap {
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
}
} else {
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
}
} else {
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise))
}
}
}.bind(host: "127.0.0.1", port: 0).wait()
}

Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ extension HTTPClientTests {
("testNoResponseWithIgnoreErrorForSSLUncleanShutdown", testNoResponseWithIgnoreErrorForSSLUncleanShutdown),
("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown),
("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown),
("testDecompression", testDecompression),
]
}
}
35 changes: 35 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -484,4 +484,39 @@ class HTTPClientTests: XCTestCase {
}
}
}

func testDecompression() throws {
let httpBin = HttpBin(compress: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: true))
defer {
try! httpClient.syncShutdown()
httpBin.shutdown()
}

var body = ""
for _ in 1 ... 1000 {
body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
}

for algorithm in [nil, "gzip", "deflate"] {
var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST)
request.body = .string(body)
if let algorithm = algorithm {
request.headers.add(name: "Accept-Encoding", value: algorithm)
}

let response = try httpClient.execute(request: request).wait()
let bytes = response.body!.getData(at: 0, length: response.body!.readableBytes)!
let data = try JSONDecoder().decode(RequestInfo.self, from: bytes)

XCTAssertEqual(.ok, response.status)
XCTAssertGreaterThan(body.count, response.headers["Content-Length"].first.flatMap { Int($0) }!)
if let algorithm = algorithm {
XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first)
} else {
XCTAssertEqual("deflate", response.headers["Content-Encoding"].first)
}
XCTAssertEqual(body, data.data)
}
}
}