diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index d17612c38..ebede078b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -383,32 +383,33 @@ final class ConnectionPool { } return channel.flatMap { channel -> EventLoopFuture in - channel.pipeline.addSSLHandlerIfNeeded(for: self.key, tlsConfiguration: self.configuration.tlsConfiguration, handshakePromise: handshakePromise).flatMap { + channel.pipeline.addSSLHandlerIfNeeded(for: self.key, tlsConfiguration: self.configuration.tlsConfiguration, handshakePromise: handshakePromise) + return handshakePromise.futureResult.flatMap { channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) }.map { let connection = Connection(key: self.key, channel: channel, parentPool: self.parentPool) connection.isLeased = true return connection } - }.flatMap { connection in - handshakePromise.futureResult.map { - self.configureCloseCallback(of: connection) - return connection - }.flatMapError { error in - connection.closePromise.succeed(()) - let action = self.parentPool.connectionProvidersLock.withLock { - self.stateLock.withLock { - self.state.failedConnectionAction() - } - } - switch action { - case .makeConnectionAndComplete(let el, let promise): - self.makeConnection(on: el).cascade(to: promise) - case .none: - break + }.map { connection in + self.configureCloseCallback(of: connection) + return connection + }.flatMapError { error in + // This promise may not have been completed if we reach this + // so we fail it to avoid any leak + handshakePromise.fail(error) + let action = self.parentPool.connectionProvidersLock.withLock { + self.stateLock.withLock { + self.state.failedConnectionAction() } - return self.eventLoop.makeFailedFuture(error) } + switch action { + case .makeConnectionAndComplete(let el, let promise): + self.makeConnection(on: el).cascade(to: promise) + case .none: + break + } + return self.eventLoop.makeFailedFuture(error) } } diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index da6ca3952..0ecd49a05 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -625,10 +625,10 @@ extension ChannelPipeline { return addHandlers([encoder, decoder, handler]) } - func addSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise) -> EventLoopFuture { + func addSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise) { guard key.scheme == .https else { handshakePromise.succeed(()) - return self.eventLoop.makeSucceededFuture(()) + return } do { @@ -638,10 +638,9 @@ extension ChannelPipeline { try NIOSSLClientHandler(context: context, serverHostname: key.host.isIPAddress ? nil : key.host), TLSEventsHandler(completionPromise: handshakePromise), ] - - return self.addHandlers(handlers) + self.addHandlers(handlers).cascadeFailure(to: handshakePromise) } catch { - return self.eventLoop.makeFailedFuture(error) + handshakePromise.fail(error) } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 07cca50d7..9a53e36ce 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -95,6 +95,7 @@ extension HTTPClientTests { ("testWeRecoverFromServerThatClosesTheConnectionOnUs", testWeRecoverFromServerThatClosesTheConnectionOnUs), ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), ("testRacePoolIdleConnectionsAndGet", testRacePoolIdleConnectionsAndGet), + ("testAvoidLeakingTLSHandshakeCompletionPromise", testAvoidLeakingTLSHandshakeCompletionPromise), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 049ce78d4..bba05a63a 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1656,4 +1656,21 @@ class HTTPClientTests: XCTestCase { Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.05...0.05)) } } + + func testAvoidLeakingTLSHandshakeCompletionPromise() { + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpBin = HTTPBin() + let port = httpBin.port + XCTAssertNoThrow(try httpBin.shutdown()) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + + XCTAssertThrowsError(try httpClient.get(url: "http://localhost:\(port)").wait()) { error in + guard error is NIOConnectionError else { + XCTFail("Unexpected error: \(error)") + return + } + } + } }