Skip to content

Unconditionally insert TLSEventsHandler #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
@@ -900,27 +900,25 @@ extension ChannelPipeline {
try sync.addHandler(handler)
}

func syncAddSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, addSSLClient: Bool, handshakePromise: EventLoopPromise<Void>) {
guard key.scheme.requiresTLS else {
handshakePromise.succeed(())
return
}
func syncAddLateSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise<Void>) {
precondition(key.scheme.requiresTLS)

do {
let synchronousPipelineView = self.syncOperations

// We add the TLSEventsHandler first so that it's always in the pipeline before any other TLS handler we add.
// If we're here, we must not have one in the channel already.
assert((try? synchronousPipelineView.context(name: TLSEventsHandler.handlerName)) == nil)
let eventsHandler = TLSEventsHandler(completionPromise: handshakePromise)
try synchronousPipelineView.addHandler(eventsHandler)

if addSSLClient {
let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient()
let context = try NIOSSLContext(configuration: tlsConfiguration)
try synchronousPipelineView.addHandler(
try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host),
position: .before(eventsHandler)
)
}
try synchronousPipelineView.addHandler(eventsHandler, name: TLSEventsHandler.handlerName)

// Then we add the SSL handler.
let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient()
let context = try NIOSSLContext(configuration: tlsConfiguration)
try synchronousPipelineView.addHandler(
try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host),
position: .before(eventsHandler)
)
} catch {
handshakePromise.fail(error)
}
@@ -930,7 +928,9 @@ extension ChannelPipeline {
class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = NIOAny

var completionPromise: EventLoopPromise<Void>?
static let handlerName: String = "AsyncHTTPClient.HTTPClient.TLSEventsHandler"

var completionPromise: EventLoopPromise<Void>

init(completionPromise: EventLoopPromise<Void>) {
self.completionPromise = completionPromise
@@ -940,9 +940,7 @@ class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler {
if let tlsEvent = event as? TLSUserEvent {
switch tlsEvent {
case .handshakeCompleted:
self.completionPromise?.succeed(())
self.completionPromise = nil
context.pipeline.removeHandler(self, promise: nil)
self.completionPromise.succeed(())
case .shutdownCompleted:
break
}
@@ -951,15 +949,13 @@ class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler {
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.completionPromise?.fail(error)
self.completionPromise = nil
context.pipeline.removeHandler(self, promise: nil)
self.completionPromise.fail(error)
context.fireErrorCaught(error)
}

func handlerRemoved(context: ChannelHandlerContext) {
struct NoResult: Error {}
self.completionPromise?.fail(NoResult())
self.completionPromise.fail(NoResult())
}
}

33 changes: 28 additions & 5 deletions Sources/AsyncHTTPClient/Utils.swift
Original file line number Diff line number Diff line change
@@ -131,6 +131,11 @@ extension NIOClientTCPBootstrap {
do {
if let proxy = configuration.proxy {
try channel.pipeline.syncAddProxyHandler(host: host, port: port, authorization: proxy.authorization)
} else if requiresTLS {
// We only add the handshake verifier if we need TLS and we're not going through a proxy. If we're going
// through a proxy we add it later.
let completionPromise = channel.eventLoop.makePromise(of: Void.self)
try channel.pipeline.syncOperations.addHandler(TLSEventsHandler(completionPromise: completionPromise), name: TLSEventsHandler.handlerName)
}
return channel.eventLoop.makeSucceededVoidFuture()
} catch {
@@ -162,14 +167,32 @@ extension NIOClientTCPBootstrap {
}

return channel.flatMap { channel in
let requiresSSLHandler = configuration.proxy != nil && key.scheme.requiresTLS
let handshakePromise = channel.eventLoop.makePromise(of: Void.self)

channel.pipeline.syncAddSSLHandlerIfNeeded(for: key, tlsConfiguration: configuration.tlsConfiguration, addSSLClient: requiresSSLHandler, handshakePromise: handshakePromise)
let requiresTLS = key.scheme.requiresTLS
let requiresLateSSLHandler = configuration.proxy != nil && requiresTLS
let handshakeFuture: EventLoopFuture<Void>

if requiresLateSSLHandler {
let handshakePromise = channel.eventLoop.makePromise(of: Void.self)
channel.pipeline.syncAddLateSSLHandlerIfNeeded(for: key, tlsConfiguration: configuration.tlsConfiguration, handshakePromise: handshakePromise)
handshakeFuture = handshakePromise.futureResult
} else if requiresTLS {
do {
handshakeFuture = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self).completionPromise.futureResult
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
} else {
handshakeFuture = channel.eventLoop.makeSucceededVoidFuture()
}

return handshakePromise.futureResult.flatMapThrowing {
return handshakeFuture.flatMapThrowing {
let syncOperations = channel.pipeline.syncOperations

// If we got here and we had a TLSEventsHandler in the pipeline, we can remove it ow.
if requiresTLS {
channel.pipeline.removeHandler(name: TLSEventsHandler.handlerName, promise: nil)
}

try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes)

#if canImport(Network)
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
@@ -129,6 +129,7 @@ extension HTTPClientTests {
("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose),
("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer),
("testBiDirectionalStreaming", testBiDirectionalStreaming),
("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting),
]
}
}
23 changes: 23 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
@@ -2821,4 +2821,27 @@ class HTTPClientTests: XCTestCase {

XCTAssertNoThrow(try future.wait())
}

func testSynchronousHandshakeErrorReporting() throws {
// This only affects cases where we use NIOSSL.
guard !isTestingNIOTS() else { return }

// We use a specially crafted client that has no cipher suites to offer. To do this we ask
// only for cipher suites incompatible with our TLS version.
let tlsConfig = TLSConfiguration.forClient(minimumTLSVersion: .tlsv13, maximumTLSVersion: .tlsv12, certificateVerification: .none)
let localHTTPBin = HTTPBin(ssl: true)
let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup),
configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig))
defer {
XCTAssertNoThrow(try localClient.syncShutdown())
XCTAssertNoThrow(try localHTTPBin.shutdown())
}

XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/").wait()) { error in
guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else {
XCTFail("Unexpected error: \(error)")
return
}
}
}
}