diff --git a/Sources/AsyncHTTPClient/Connection.swift b/Sources/AsyncHTTPClient/Connection.swift index 3923603ee..6542ca20a 100644 --- a/Sources/AsyncHTTPClient/Connection.swift +++ b/Sources/AsyncHTTPClient/Connection.swift @@ -72,6 +72,10 @@ extension Connection { func close() -> EventLoopFuture { return self.channel.close() } + + func close(promise: EventLoopPromise?) { + return self.channel.close(promise: promise) + } } /// Methods of Connection which are used in ConnectionsState extracted as protocol diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 1ccd82ee9..e7c68aba4 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -352,6 +352,14 @@ class HTTP1ConnectionProvider { self.state.release(connection: connection, closing: closing) } + // We close defensively here: we may have failed to actually close on other codepaths, + // or we may be expecting the server to close. In either case, we want our FD back, so + // we close now to cover our backs. We don't care about the result: if the channel is + // _already_ closed, that's fine by us. + if closing { + connection.close(promise: nil) + } + switch action { case .none: break diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 803824a0c..24daf6c94 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -812,7 +812,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssert(connection !== connection2) try! connection2.channel.eventLoop.submit { - connection2.release(closing: true, logger: HTTPClient.loggingDisabled) + connection2.release(closing: false, logger: HTTPClient.loggingDisabled) }.wait() XCTAssertTrue(connection2.channel.isActive) } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 956d67cdd..9962c7209 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -772,6 +772,53 @@ internal final class HttpBinForSSLUncleanShutdownHandler: ChannelInboundHandler } } +internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private var callback: (() -> Void)? + private var onClosePromise: EventLoopPromise? + + init(_ callback: @escaping () -> Void) { + self.callback = callback + } + + func handlerAdded(context: ChannelHandlerContext) { + self.onClosePromise = context.eventLoop.makePromise() + self.onClosePromise!.futureResult.whenSuccess(self.callback!) + self.callback = nil + } + + func handlerRemoved(context: ChannelHandlerContext) { + assert(self.onClosePromise == nil) + } + + func channelInactive(context: ChannelHandlerContext) { + if let onClosePromise = self.onClosePromise { + self.onClosePromise = nil + onClosePromise.succeed(()) + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard case .end = self.unwrapInboundIn(data) else { + return + } + + // We're gonna send a response back here, with Connection: close, but we will + // not close the connection. This reproduces #324. + let headers = HTTPHeaders([ + ("Host", "CloseWithoutClosingServerHandler"), + ("Content-Length", "0"), + ("Connection", "close"), + ]) + let head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: headers) + + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } +} + struct EventLoopFutureTimeoutError: Error {} extension EventLoopFuture { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 655f31792..f54433b1d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -126,6 +126,7 @@ extension HTTPClientTests { ("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit), ("testDoubleError", testDoubleError), ("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation), + ("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 94f649aab..9f59e2c8d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -2712,4 +2712,30 @@ class HTTPClientTests: XCTestCase { #endif } } + + func testWeCloseConnectionsWhenConnectionCloseSetByServer() throws { + let group = DispatchGroup() + group.enter() + + let server = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline().flatMap { + channel.pipeline.addHandler(CloseWithoutClosingServerHandler(group.leave)) + } + } + .bind(host: "localhost", port: 0) + .wait() + + defer { + server.close(promise: nil) + } + + // Simple request, should go great. + XCTAssertNoThrow(try self.defaultClient.get(url: "http://localhost:\(server.localAddress!.port!)/").wait()) + + // Shouldn't need more than 100ms of waiting to see the close. + let result = group.wait(timeout: DispatchTime.now() + DispatchTimeInterval.milliseconds(100)) + XCTAssertEqual(result, .success, "we never closed the connection!") + } }