Skip to content

add proxy support #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.build
Package.resolved
*.xcodeproj
DerivedData
149 changes: 149 additions & 0 deletions Sources/NIOHTTPClient/HTTPClientProxyHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIOHTTPClient open source project
//
// Copyright (c) 2018-2019 Swift Server Working Group and the SwiftNIOHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIOHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIO
import NIOHTTP1

/// Specifies the remote address of an HTTP proxy.
///
/// Adding an `HTTPClientProxy` to your client's `HTTPClientConfiguration`
/// will cause requests to be passed through the specified proxy using the
/// HTTP `CONNECT` method.
///
/// If a `TLSConfiguration` is used in conjunction with `HTTPClientProxy`,
/// TLS will be established _after_ successful proxy, between your client
/// and the destination server.
public extension HTTPClient {
struct Proxy {
internal let host: String
internal let port: Int

public static func server(host: String, port: Int) -> Proxy {
return .init(host: host, port: port)
}
}
}

internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChannelHandler {
typealias InboundIn = HTTPClientResponsePart
typealias OutboundIn = HTTPClientRequestPart
typealias OutboundOut = HTTPClientRequestPart

enum WriteItem {
case write(NIOAny, EventLoopPromise<Void>?)
case flush
}

enum ReadState {
case awaitingResponse
case connecting
case connected
}

private let host: String
private let port: Int
private var onConnect: (Channel) -> EventLoopFuture<Void>
private var writeBuffer: CircularBuffer<WriteItem>
private var readBuffer: CircularBuffer<NIOAny>
private var readState: ReadState

init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture<Void>) {
self.host = host
self.port = port
self.onConnect = onConnect
self.writeBuffer = .init()
self.readBuffer = .init()
self.readState = .awaitingResponse
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.readState {
case .awaitingResponse:
let res = self.unwrapInboundIn(data)
switch res {
case .head(let head):
switch head.status.code {
case 200..<300:
// Any 2xx (Successful) response indicates that the sender (and all
// inbound proxies) will switch to tunnel mode immediately after the
// blank line that concludes the successful response's header section
break
default:
// Any response other than a successful response
// indicates that the tunnel has not yet been formed and that the
// connection remains governed by HTTP.
context.fireErrorCaught(HTTPClientError.invalidProxyResponse)
}
case .end:
self.readState = .connecting
_ = self.handleConnect(context: context)
case .body:
break
}
case .connecting:
self.readBuffer.append(data)
case .connected:
context.fireChannelRead(data)
}
}

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
self.writeBuffer.append(.write(data, promise))
}

func flush(context: ChannelHandlerContext) {
self.writeBuffer.append(.flush)
}

func channelActive(context: ChannelHandlerContext) {
self.sendConnect(context: context)
context.fireChannelActive()
}

// MARK: Private

private func handleConnect(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
return self.onConnect(context.channel).flatMap {
self.readState = .connected

// forward any buffered reads
while !self.readBuffer.isEmpty {
context.fireChannelRead(self.readBuffer.removeFirst())
}

// calls to context.write may be re-entrant
while !self.writeBuffer.isEmpty {
switch self.writeBuffer.removeFirst() {
case .flush:
context.flush()
case .write(let data, let promise):
context.write(data, promise: promise)
}
}
return context.pipeline.removeHandler(self)
}
}

private func sendConnect(context: ChannelHandlerContext) {
var head = HTTPRequestHead(
version: .init(major: 1, minor: 1),
method: .CONNECT,
uri: "\(self.host):\(self.port)"
)
head.headers.add(name: "proxy-connection", value: "keep-alive")
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
context.flush()
}
}
91 changes: 65 additions & 26 deletions Sources/NIOHTTPClient/SwiftNIOHTTP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,33 @@ public class HTTPClient {
var bootstrap = ClientBootstrap(group: group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
.channelInitializer { channel in
channel.pipeline.addHTTPClientHandlers().flatMap {
self.configureSSL(channel: channel, useTLS: request.useTLS, hostname: request.host)
}.flatMap {
if let readTimeout = timeout.read {
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout))
} else {
return channel.eventLoop.makeSucceededFuture(())
let encoder = HTTPRequestEncoder()
let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))
return channel.pipeline.addHandlers([encoder, decoder], position: .first).flatMap {
switch self.configuration.proxy {
case .none:
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration)
case .some:
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration)
}
}.flatMap {
if let readTimeout = timeout.read {
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout))
} else {
return channel.eventLoop.makeSucceededFuture(())
}
}.flatMap {
let taskHandler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler)
return channel.pipeline.addHandler(taskHandler)
}
}.flatMap {
channel.pipeline.addHandler(TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler))
}
}

if let connectTimeout = timeout.connect {
bootstrap = bootstrap.connectTimeout(connectTimeout)
}

bootstrap.connect(host: request.host, port: request.port)

let address = self.resolveAddress(request: request, proxy: self.configuration.proxy)
bootstrap.connect(host: address.host, port: address.port)
.map { channel in
task.setChannel(channel)
}
Expand All @@ -155,36 +164,33 @@ public class HTTPClient {
return task
}

private func configureSSL(channel: Channel, useTLS: Bool, hostname: String) -> EventLoopFuture<Void> {
if useTLS {
do {
let tlsConfiguration = self.configuration.tlsConfiguration ?? TLSConfiguration.forClient()
let context = try NIOSSLContext(configuration: tlsConfiguration)
return channel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: hostname),
position: .first)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
} else {
return channel.eventLoop.makeSucceededFuture(())
private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) {
switch self.configuration.proxy {
case .none:
return (request.host, request.port)
case .some(let proxy):
return (proxy.host, proxy.port)
}
}

public struct Configuration {
public var tlsConfiguration: TLSConfiguration?
public var followRedirects: Bool
public var timeout: Timeout
public var proxy: Proxy?

public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
}

public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.timeout = timeout
self.proxy = proxy
}
}

Expand All @@ -199,6 +205,37 @@ public class HTTPClient {
}
}

private extension ChannelPipeline {
func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in
return channel.pipeline.removeHandler(decoder).flatMap {
return channel.pipeline.addHandler(
ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)),
position: .after(encoder)
)
}.flatMap {
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: tlsConfiguration)
}
})
return self.addHandler(handler)
}

func addSSLHandlerIfNeeded(for request: HTTPClient.Request, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
guard request.useTLS else {
return self.eventLoop.makeSucceededFuture(())
}

do {
let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient()
let context = try NIOSSLContext(configuration: tlsConfiguration)
return self.addHandler(try NIOSSLClientHandler(context: context, serverHostname: request.host),
position: .first)
} catch {
return self.eventLoop.makeFailedFuture(error)
}
}
}

public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
private enum Code: Equatable {
case invalidURL
Expand All @@ -211,6 +248,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case cancelled
case identityCodingIncorrectlyPresent
case chunkedSpecifiedMultipleTimes
case invalidProxyResponse
}

private var code: Code
Expand All @@ -233,4 +271,5 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let cancelled = HTTPClientError(code: .cancelled)
public static let identityCodingIncorrectlyPresent = HTTPClientError(code: .identityCodingIncorrectlyPresent)
public static let chunkedSpecifiedMultipleTimes = HTTPClientError(code: .chunkedSpecifiedMultipleTimes)
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
}
66 changes: 59 additions & 7 deletions Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,27 @@ internal class HttpBin {
return self.serverChannel.localAddress!
}

init(ssl: Bool = false) {
self.serverChannel = try! ServerBootstrap(group: self.group)
static func configureTLS(channel: Channel) -> EventLoopFuture<Void> {
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
let context = try! NIOSSLContext(configuration: configuration)
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first)
}

init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil) {
self.serverChannel = try! ServerBootstrap(group: group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
.childChannelInitializer { channel in
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
if let simulateProxy = simulateProxy {
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
} else {
return channel.eventLoop.makeSucceededFuture(())
}
}.flatMap {
if ssl {
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
let context = try! NIOSSLContext(configuration: configuration)
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first).flatMap {
return HttpBin.configureTLS(channel: channel).flatMap {
channel.pipeline.addHandler(HttpBinHandler())
}
} else {
Expand All @@ -111,6 +121,48 @@ internal class HttpBin {
}
}

final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
typealias OutboundOut = ByteBuffer

enum Option {
case plaintext
case tls
}

let option: Option

init(option: Option) {
self.option = option
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = """
HTTP/1.1 200 OK\r\n\
Content-Length: 0\r\n\
Connection: close\r\n\
\r\n
"""
var buffer = self.unwrapInboundIn(data)
let request = buffer.readString(length: buffer.readableBytes)!
if request.hasPrefix("CONNECT") {
var buffer = context.channel.allocator.buffer(capacity: 0)
buffer.writeString(response)
context.write(self.wrapInboundOut(buffer), promise: nil)
context.flush()
context.channel.pipeline.removeHandler(self, promise: nil)
switch self.option {
case .tls:
_ = HttpBin.configureTLS(channel: context.channel)
case .plaintext: break
}
} else {
fatalError("Expected a CONNECT request")
}
}
}

internal struct HTTPResponseBuilder {
let head: HTTPResponseHead
var body: ByteBuffer?
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ extension SwiftHTTPTests {
("testRemoteClose", testRemoteClose),
("testReadTimeout", testReadTimeout),
("testCancel", testCancel),
("testProxyPlaintext", testProxyPlaintext),
("testProxyTLS", testProxyTLS),
]
}
}
Loading