Skip to content

add response decompression support #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b0f15a6
add response decompression support
artemredkin Aug 18, 2019
35fa5e0
Merge branch 'master' into support_response_decompression
artemredkin Aug 22, 2019
cc18c6d
Merge branch 'master' into support_response_decompression
artemredkin Sep 16, 2019
88b674a
review fix: add decompression limit
artemredkin Sep 17, 2019
a51fe72
make limit configurable
artemredkin Sep 17, 2019
6d96880
fix missing linux tests
artemredkin Sep 17, 2019
b577f79
fix formatting
artemredkin Sep 17, 2019
18438fe
Merge branch 'master' into support_response_decompression
artemredkin Sep 17, 2019
0782d81
formatting fix after merge
artemredkin Sep 17, 2019
8e8c30d
add docker dependency for zlib
artemredkin Sep 17, 2019
69ec369
review fixes: unset all pointers after use and make inflate methods i…
artemredkin Sep 18, 2019
2891b7d
review fix: re-factor to not use a callback
artemredkin Sep 18, 2019
dbcfd44
review fixes: throw instead of precondition
artemredkin Sep 18, 2019
5d59a66
fix formatting
artemredkin Sep 18, 2019
2f1959f
review fix: flatten compression settings
artemredkin Sep 20, 2019
8a95c2a
Merge branch 'master' into support_response_decompression
artemredkin Sep 20, 2019
0e08fdb
Merge branch 'master' into support_response_decompression
artemredkin Sep 23, 2019
dd83963
use new decompression support from nio-extras
artemredkin Sep 30, 2019
f1848d2
Merge branch 'master' into support_response_decompression
artemredkin Oct 8, 2019
495b27a
remove unused types
artemredkin Oct 8, 2019
4c4ac19
rewrite backpressure test
artemredkin Oct 14, 2019
807761e
rewrite backpressure test
artemredkin Oct 14, 2019
f2c8b09
Merge branch 'support_response_decompression' of github.com:swift-ser…
artemredkin Oct 14, 2019
646d3c7
use real version
artemredkin Oct 16, 2019
1ed16d4
remove commented code
artemredkin Oct 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.0.0"),
],
targets: [
.target(
name: "AsyncHTTPClient",
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers"]
dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers", "NIOHTTPCompression"]
),
.testTarget(
name: "AsyncHTTPClientTests",
Expand Down
73 changes: 62 additions & 11 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ public class HTTPClient {
case .some(let proxy):
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy)
}
}.flatMap {
switch self.configuration.decompression {
case .disabled:
return channel.eventLoop.makeSucceededFuture(())
case .enabled(let limit):
return channel.pipeline.addHandler(HTTPResponseDecompressor(limit: limit))
}
}.flatMap {
if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) {
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout))
Expand Down Expand Up @@ -310,31 +317,37 @@ public class HTTPClient {
public var timeout: Timeout
/// Upstream proxy, defaults to no proxy.
public var proxy: Proxy?
/// Enables automatic body decompression. Supported algorithms are gzip and deflate.
public var decompression: Decompression
/// Ignore TLS unclean shutdown error, defaults to `false`.
public var ignoreUncleanSSLShutdown: Bool

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(tlsConfiguration: TLSConfiguration? = nil,
followRedirects: Bool = false,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
self.decompression = decompression
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false)
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) {
public init(certificateVerification: CertificateVerification,
followRedirects: Bool = false,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
self.decompression = decompression
}
}

Expand Down Expand Up @@ -368,6 +381,35 @@ public class HTTPClient {
return EventLoopPreference(.prefers(eventLoop))
}
}

/// Specifies decompression settings.
public enum Decompression {
/// Decompression is disabled.
case disabled
/// Decompression is enabled.
case enabled(limit: DecompressionLimit)
}

/// Specifies how to limit decompression inflation.
public enum DecompressionLimit {
/// No limit will be set.
case none
/// Limit will be set on the request body size.
case size(Int)
/// Limit will be set on a ratio between compressed body size and decompressed result.
case ratio(Int)

func exceeded(compressed: Int, decompressed: Int) -> Bool {
switch self {
case .none:
return false
case .size(let allowed):
return compressed > allowed
case .ratio(let ratio):
return decompressed > compressed * ratio
}
}
}
}

extension HTTPClient.Configuration {
Expand Down Expand Up @@ -437,6 +479,9 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case invalidProxyResponse
case contentLengthMissing
case proxyAuthenticationRequired
case decompressionLimit
case decompressionInitialization(Int32)
case decompression(Int32)
}

private var code: Code
Expand Down Expand Up @@ -473,6 +518,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
/// Request does not contain `Content-Length` header.
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
/// Proxy Authentication Required
/// Proxy Authentication Required.
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
/// Decompression limit reached.
public static let decompressionLimit = HTTPClientError(code: .decompressionLimit)
/// Decompression initialization failed.
public static func decompressionInitialization(_ code: Int32) -> HTTPClientError { return HTTPClientError(code: .decompressionInitialization(code)) }
/// Decompression failed.
public static func decompression(_ code: Int32) -> HTTPClientError { return HTTPClientError(code: .decompression(code)) }
}
171 changes: 171 additions & 0 deletions Sources/AsyncHTTPClient/HTTPDecompressor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2018-2019 Swift Server Working Group and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import CNIOExtrasZlib
import NIO
import NIOHTTP1

private enum CompressionAlgorithm: String {
case gzip
case deflate
}

extension z_stream {
mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws {
try input.readWithUnsafeMutableReadableBytes { (dataPtr: UnsafeMutableRawBufferPointer) -> Int in
let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self)
let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, count: dataPtr.count)

self.avail_in = UInt32(typedDataPtr.count)
self.next_in = typedDataPtr.baseAddress!

defer {
self.avail_in = 0
self.next_in = nil
self.avail_out = 0
self.next_out = nil
}

try self.inflatePart(to: &output)

return typedDataPtr.count - Int(self.avail_in)
}
}

private mutating func inflatePart(to buffer: inout ByteBuffer) throws {
try buffer.writeWithUnsafeMutableBytes { outputPtr in
let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), count: outputPtr.count)

self.avail_out = UInt32(typedOutputPtr.count)
self.next_out = typedOutputPtr.baseAddress!

let rc = inflate(&self, Z_NO_FLUSH)
guard rc == Z_OK || rc == Z_STREAM_END else {
throw HTTPClientError.decompression(rc)
}

return typedOutputPtr.count - Int(self.avail_out)
}
}
}

final class HTTPResponseDecompressor: ChannelDuplexHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
typealias InboundOut = HTTPClientResponsePart
typealias OutboundIn = HTTPClientRequestPart
typealias OutboundOut = HTTPClientRequestPart

private enum State {
case empty
case compressed(CompressionAlgorithm, Int)
}

private let limit: HTTPClient.DecompressionLimit
private var state = State.empty
private var stream = z_stream()
private var inflated = 0

init(limit: HTTPClient.DecompressionLimit) {
self.limit = limit
}

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let request = self.unwrapOutboundIn(data)
switch request {
case .head(var head):
if head.headers.contains(name: "Accept-Encoding") {
context.write(data, promise: promise)
} else {
head.headers.replaceOrAdd(name: "Accept-Encoding", value: "deflate, gzip")
context.write(self.wrapOutboundOut(.head(head)), promise: promise)
}
default:
context.write(data, promise: promise)
}
}

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head(let head):
let algorithm: CompressionAlgorithm?
let contentType = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased()
if contentType == "gzip" {
algorithm = .gzip
} else if contentType == "deflate" {
algorithm = .deflate
} else {
algorithm = nil
}

let length = head.headers[canonicalForm: "Content-Length"].first.flatMap { Int($0) }

if let algorithm = algorithm, let length = length {
do {
try self.initializeDecoder(encoding: algorithm, length: length)
} catch {
context.fireErrorCaught(error)
}
}

context.fireChannelRead(data)
case .body(var part):
switch self.state {
case .compressed(_, let originalLength):
while part.readableBytes > 0 {
do {
var buffer = context.channel.allocator.buffer(capacity: 16384)
try self.stream.inflatePart(input: &part, output: &buffer)
self.inflated += buffer.readableBytes

if self.limit.exceeded(compressed: originalLength, decompressed: self.inflated) {
context.fireErrorCaught(HTTPClientError.decompressionLimit)
return
}

context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
} catch {
context.fireErrorCaught(error)
return
}
}
default:
context.fireChannelRead(data)
}
case .end:
deflateEnd(&self.stream)
context.fireChannelRead(data)
}
}

private func initializeDecoder(encoding: CompressionAlgorithm, length: Int) throws {
self.state = .compressed(encoding, length)

self.stream.zalloc = nil
self.stream.zfree = nil
self.stream.opaque = nil

let window: Int32
switch encoding {
case .gzip:
window = 15 + 16
default:
window = 15
}

let rc = CNIOExtrasZlib_inflateInit2(&self.stream, window)
guard rc == Z_OK else {
throw HTTPClientError.decompressionInitialization(rc)
}
}
}
2 changes: 1 addition & 1 deletion Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: ChannelInbound
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = unwrapInboundIn(data)
let response = self.unwrapInboundIn(data)
switch response {
case .head(let head):
if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ extension HTTPClientInternalTests {
("testProxyStreamingFailure", testProxyStreamingFailure),
("testUploadStreamingBackpressure", testUploadStreamingBackpressure),
("testRequestURITrailingSlash", testRequestURITrailingSlash),
("testDecompressionNoLimit", testDecompressionNoLimit),
("testDecompressionLimitRatio", testDecompressionLimitRatio),
("testDecompressionLimitSize", testDecompressionLimitSize),
]
}
}
45 changes: 45 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,49 @@ class HTTPClientInternalTests: XCTestCase {
let request11 = try Request(url: "https://someserver.com/some%20path")
XCTAssertEqual(request11.url.uri, "/some%20path")
}

func testDecompressionNoLimit() throws {
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(HTTPResponseDecompressor(limit: .none)).wait()

let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))

let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(body)))
}

func testDecompressionLimitRatio() throws {
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(HTTPResponseDecompressor(limit: .ratio(10))).wait()

let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))

let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
do {
try channel.writeInbound(HTTPClientResponsePart.body(body))
} catch let error as HTTPClientError {
XCTAssertEqual(error, .decompressionLimit)
} catch {
XCTFail("Unexptected error: \(error)")
}
}

func testDecompressionLimitSize() throws {
let channel = EmbeddedChannel()
try channel.pipeline.addHandler(HTTPResponseDecompressor(limit: .size(10))).wait()

let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "13")])
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers)))

let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
do {
try channel.writeInbound(HTTPClientResponsePart.body(body))
} catch let error as HTTPClientError {
XCTAssertEqual(error, .decompressionLimit)
} catch {
XCTFail("Unexptected error: \(error)")
}
}
}
Loading