Skip to content

Commit 15ee7ee

Browse files
tanner0101artemredkin
authored andcommitted
add proxy support (swift-server#1)
1 parent 0cf27fc commit 15ee7ee

File tree

6 files changed

+307
-33
lines changed

6 files changed

+307
-33
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.build
22
Package.resolved
33
*.xcodeproj
4+
DerivedData

Diff for: Sources/NIOHTTPClient/HTTPClientProxyHandler.swift

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftNIOHTTPClient open source project
4+
//
5+
// Copyright (c) 2018-2019 Swift Server Working Group and the SwiftNIOHTTPClient project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftNIOHTTPClient project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import NIO
16+
import NIOHTTP1
17+
18+
/// Specifies the remote address of an HTTP proxy.
19+
///
20+
/// Adding an `HTTPClientProxy` to your client's `HTTPClientConfiguration`
21+
/// will cause requests to be passed through the specified proxy using the
22+
/// HTTP `CONNECT` method.
23+
///
24+
/// If a `TLSConfiguration` is used in conjunction with `HTTPClientProxy`,
25+
/// TLS will be established _after_ successful proxy, between your client
26+
/// and the destination server.
27+
public extension HTTPClient {
28+
struct Proxy {
29+
internal let host: String
30+
internal let port: Int
31+
32+
public static func server(host: String, port: Int) -> Proxy {
33+
return .init(host: host, port: port)
34+
}
35+
}
36+
}
37+
38+
internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChannelHandler {
39+
typealias InboundIn = HTTPClientResponsePart
40+
typealias OutboundIn = HTTPClientRequestPart
41+
typealias OutboundOut = HTTPClientRequestPart
42+
43+
enum WriteItem {
44+
case write(NIOAny, EventLoopPromise<Void>?)
45+
case flush
46+
}
47+
48+
enum ReadState {
49+
case awaitingResponse
50+
case connecting
51+
case connected
52+
}
53+
54+
private let host: String
55+
private let port: Int
56+
private var onConnect: (Channel) -> EventLoopFuture<Void>
57+
private var writeBuffer: CircularBuffer<WriteItem>
58+
private var readBuffer: CircularBuffer<NIOAny>
59+
private var readState: ReadState
60+
61+
init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture<Void>) {
62+
self.host = host
63+
self.port = port
64+
self.onConnect = onConnect
65+
self.writeBuffer = .init()
66+
self.readBuffer = .init()
67+
self.readState = .awaitingResponse
68+
}
69+
70+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
71+
switch self.readState {
72+
case .awaitingResponse:
73+
let res = self.unwrapInboundIn(data)
74+
switch res {
75+
case .head(let head):
76+
switch head.status.code {
77+
case 200..<300:
78+
// Any 2xx (Successful) response indicates that the sender (and all
79+
// inbound proxies) will switch to tunnel mode immediately after the
80+
// blank line that concludes the successful response's header section
81+
break
82+
default:
83+
// Any response other than a successful response
84+
// indicates that the tunnel has not yet been formed and that the
85+
// connection remains governed by HTTP.
86+
context.fireErrorCaught(HTTPClientError.invalidProxyResponse)
87+
}
88+
case .end:
89+
self.readState = .connecting
90+
_ = self.handleConnect(context: context)
91+
case .body:
92+
break
93+
}
94+
case .connecting:
95+
self.readBuffer.append(data)
96+
case .connected:
97+
context.fireChannelRead(data)
98+
}
99+
}
100+
101+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
102+
self.writeBuffer.append(.write(data, promise))
103+
}
104+
105+
func flush(context: ChannelHandlerContext) {
106+
self.writeBuffer.append(.flush)
107+
}
108+
109+
func channelActive(context: ChannelHandlerContext) {
110+
self.sendConnect(context: context)
111+
context.fireChannelActive()
112+
}
113+
114+
// MARK: Private
115+
116+
private func handleConnect(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
117+
return self.onConnect(context.channel).flatMap {
118+
self.readState = .connected
119+
120+
// forward any buffered reads
121+
while !self.readBuffer.isEmpty {
122+
context.fireChannelRead(self.readBuffer.removeFirst())
123+
}
124+
125+
// calls to context.write may be re-entrant
126+
while !self.writeBuffer.isEmpty {
127+
switch self.writeBuffer.removeFirst() {
128+
case .flush:
129+
context.flush()
130+
case .write(let data, let promise):
131+
context.write(data, promise: promise)
132+
}
133+
}
134+
return context.pipeline.removeHandler(self)
135+
}
136+
}
137+
138+
private func sendConnect(context: ChannelHandlerContext) {
139+
var head = HTTPRequestHead(
140+
version: .init(major: 1, minor: 1),
141+
method: .CONNECT,
142+
uri: "\(self.host):\(self.port)"
143+
)
144+
head.headers.add(name: "proxy-connection", value: "keep-alive")
145+
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
146+
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
147+
context.flush()
148+
}
149+
}

Diff for: Sources/NIOHTTPClient/SwiftNIOHTTP.swift

+65-26
Original file line numberDiff line numberDiff line change
@@ -124,24 +124,33 @@ public class HTTPClient {
124124
var bootstrap = ClientBootstrap(group: group)
125125
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
126126
.channelInitializer { channel in
127-
channel.pipeline.addHTTPClientHandlers().flatMap {
128-
self.configureSSL(channel: channel, useTLS: request.useTLS, hostname: request.host)
129-
}.flatMap {
130-
if let readTimeout = timeout.read {
131-
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout))
132-
} else {
133-
return channel.eventLoop.makeSucceededFuture(())
127+
let encoder = HTTPRequestEncoder()
128+
let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))
129+
return channel.pipeline.addHandlers([encoder, decoder], position: .first).flatMap {
130+
switch self.configuration.proxy {
131+
case .none:
132+
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration)
133+
case .some:
134+
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration)
135+
}
136+
}.flatMap {
137+
if let readTimeout = timeout.read {
138+
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout))
139+
} else {
140+
return channel.eventLoop.makeSucceededFuture(())
141+
}
142+
}.flatMap {
143+
let taskHandler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler)
144+
return channel.pipeline.addHandler(taskHandler)
134145
}
135-
}.flatMap {
136-
channel.pipeline.addHandler(TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler))
137-
}
138146
}
139147

140148
if let connectTimeout = timeout.connect {
141149
bootstrap = bootstrap.connectTimeout(connectTimeout)
142150
}
143-
144-
bootstrap.connect(host: request.host, port: request.port)
151+
152+
let address = self.resolveAddress(request: request, proxy: self.configuration.proxy)
153+
bootstrap.connect(host: address.host, port: address.port)
145154
.map { channel in
146155
task.setChannel(channel)
147156
}
@@ -155,36 +164,33 @@ public class HTTPClient {
155164
return task
156165
}
157166

158-
private func configureSSL(channel: Channel, useTLS: Bool, hostname: String) -> EventLoopFuture<Void> {
159-
if useTLS {
160-
do {
161-
let tlsConfiguration = self.configuration.tlsConfiguration ?? TLSConfiguration.forClient()
162-
let context = try NIOSSLContext(configuration: tlsConfiguration)
163-
return channel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: hostname),
164-
position: .first)
165-
} catch {
166-
return channel.eventLoop.makeFailedFuture(error)
167-
}
168-
} else {
169-
return channel.eventLoop.makeSucceededFuture(())
167+
private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) {
168+
switch self.configuration.proxy {
169+
case .none:
170+
return (request.host, request.port)
171+
case .some(let proxy):
172+
return (proxy.host, proxy.port)
170173
}
171174
}
172175

173176
public struct Configuration {
174177
public var tlsConfiguration: TLSConfiguration?
175178
public var followRedirects: Bool
176179
public var timeout: Timeout
180+
public var proxy: Proxy?
177181

178-
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
182+
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
179183
self.tlsConfiguration = tlsConfiguration
180184
self.followRedirects = followRedirects
181185
self.timeout = timeout
186+
self.proxy = proxy
182187
}
183188

184-
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
189+
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) {
185190
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
186191
self.followRedirects = followRedirects
187192
self.timeout = timeout
193+
self.proxy = proxy
188194
}
189195
}
190196

@@ -199,6 +205,37 @@ public class HTTPClient {
199205
}
200206
}
201207

208+
private extension ChannelPipeline {
209+
func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
210+
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in
211+
return channel.pipeline.removeHandler(decoder).flatMap {
212+
return channel.pipeline.addHandler(
213+
ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)),
214+
position: .after(encoder)
215+
)
216+
}.flatMap {
217+
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: tlsConfiguration)
218+
}
219+
})
220+
return self.addHandler(handler)
221+
}
222+
223+
func addSSLHandlerIfNeeded(for request: HTTPClient.Request, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
224+
guard request.useTLS else {
225+
return self.eventLoop.makeSucceededFuture(())
226+
}
227+
228+
do {
229+
let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient()
230+
let context = try NIOSSLContext(configuration: tlsConfiguration)
231+
return self.addHandler(try NIOSSLClientHandler(context: context, serverHostname: request.host),
232+
position: .first)
233+
} catch {
234+
return self.eventLoop.makeFailedFuture(error)
235+
}
236+
}
237+
}
238+
202239
public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
203240
private enum Code: Equatable {
204241
case invalidURL
@@ -211,6 +248,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
211248
case cancelled
212249
case identityCodingIncorrectlyPresent
213250
case chunkedSpecifiedMultipleTimes
251+
case invalidProxyResponse
214252
}
215253

216254
private var code: Code
@@ -233,4 +271,5 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
233271
public static let cancelled = HTTPClientError(code: .cancelled)
234272
public static let identityCodingIncorrectlyPresent = HTTPClientError(code: .identityCodingIncorrectlyPresent)
235273
public static let chunkedSpecifiedMultipleTimes = HTTPClientError(code: .chunkedSpecifiedMultipleTimes)
274+
public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse)
236275
}

Diff for: Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift

+59-7
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,27 @@ internal class HttpBin {
8686
return self.serverChannel.localAddress!
8787
}
8888

89-
init(ssl: Bool = false) {
90-
self.serverChannel = try! ServerBootstrap(group: self.group)
89+
static func configureTLS(channel: Channel) -> EventLoopFuture<Void> {
90+
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
91+
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
92+
let context = try! NIOSSLContext(configuration: configuration)
93+
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first)
94+
}
95+
96+
init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil) {
97+
self.serverChannel = try! ServerBootstrap(group: group)
9198
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
9299
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
93100
.childChannelInitializer { channel in
94-
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
101+
return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
102+
if let simulateProxy = simulateProxy {
103+
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
104+
} else {
105+
return channel.eventLoop.makeSucceededFuture(())
106+
}
107+
}.flatMap {
95108
if ssl {
96-
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
97-
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
98-
let context = try! NIOSSLContext(configuration: configuration)
99-
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first).flatMap {
109+
return HttpBin.configureTLS(channel: channel).flatMap {
100110
channel.pipeline.addHandler(HttpBinHandler())
101111
}
102112
} else {
@@ -111,6 +121,48 @@ internal class HttpBin {
111121
}
112122
}
113123

124+
final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
125+
typealias InboundIn = ByteBuffer
126+
typealias InboundOut = ByteBuffer
127+
typealias OutboundOut = ByteBuffer
128+
129+
enum Option {
130+
case plaintext
131+
case tls
132+
}
133+
134+
let option: Option
135+
136+
init(option: Option) {
137+
self.option = option
138+
}
139+
140+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
141+
let response = """
142+
HTTP/1.1 200 OK\r\n\
143+
Content-Length: 0\r\n\
144+
Connection: close\r\n\
145+
\r\n
146+
"""
147+
var buffer = self.unwrapInboundIn(data)
148+
let request = buffer.readString(length: buffer.readableBytes)!
149+
if request.hasPrefix("CONNECT") {
150+
var buffer = context.channel.allocator.buffer(capacity: 0)
151+
buffer.writeString(response)
152+
context.write(self.wrapInboundOut(buffer), promise: nil)
153+
context.flush()
154+
context.channel.pipeline.removeHandler(self, promise: nil)
155+
switch self.option {
156+
case .tls:
157+
_ = HttpBin.configureTLS(channel: context.channel)
158+
case .plaintext: break
159+
}
160+
} else {
161+
fatalError("Expected a CONNECT request")
162+
}
163+
}
164+
}
165+
114166
internal struct HTTPResponseBuilder {
115167
let head: HTTPResponseHead
116168
var body: ByteBuffer?

Diff for: Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ extension SwiftHTTPTests {
3838
("testRemoteClose", testRemoteClose),
3939
("testReadTimeout", testReadTimeout),
4040
("testCancel", testCancel),
41+
("testProxyPlaintext", testProxyPlaintext),
42+
("testProxyTLS", testProxyTLS),
4143
]
4244
}
4345
}

0 commit comments

Comments
 (0)