diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index bb6779a91..b02396b75 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -604,7 +604,13 @@ internal class TaskHandler: RemovableChann var state: State = .idle var pendingRead = false var mayRead = true - var closing = false + var closing = false { + didSet { + assert(self.closing || !oldValue, + "BUG in AsyncHTTPClient: TaskHandler.closing went from true (no conn reuse) to true (do reuse).") + } + } + let kind: HTTPClient.Request.Kind init(task: HTTPClient.Task, @@ -736,6 +742,14 @@ extension TaskHandler: ChannelDuplexHandler { head.headers = headers + if head.headers[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") { + self.closing = true + } + // This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example + // in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too. + assert(head.version == HTTPVersion(major: 1, minor: 1), + "Sending a request in HTTP version \(head.version) which is unsupported by the above `if`") + context.write(wrapOutboundOut(.head(head))).map { self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead) }.flatMap { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index d52eea258..98085cf14 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -139,11 +139,10 @@ class HTTPClientInternalTests: XCTestCase { } let upload = try! httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait() - let bytes = upload.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } - let data = try! JSONDecoder().decode(RequestInfo.self, from: bytes!) + let data = upload.body.flatMap { try? JSONDecoder().decode(RequestInfo.self, from: $0) } XCTAssertEqual(.ok, upload.status) - XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data.data) + XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data?.data) } func testProxyStreamingFailure() throws { @@ -466,7 +465,9 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpBin.shutdown()) } - let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["Connection": "close"], body: nil) + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", + method: .GET, + headers: ["X-Send-Back-Header-Connection": "close"], body: nil) _ = try! httpClient.execute(request: req).wait() let el = httpClient.eventLoopGroup.next() try! el.scheduleTask(in: .milliseconds(500)) { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 06c2cc81c..d1e109844 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -339,7 +339,7 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { } internal struct HTTPResponseBuilder { - let head: HTTPResponseHead + var head: HTTPResponseHead var body: ByteBuffer? init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) { @@ -357,8 +357,13 @@ internal struct HTTPResponseBuilder { } } +let globalRequestCounter = NIOAtomic.makeAtomic(value: 0) +let globalConnectionCounter = NIOAtomic.makeAtomic(value: 0) + internal struct RequestInfo: Codable { - let data: String + var data: String + var requestNumber: Int + var connectionNumber: Int } internal final class HttpBinHandler: ChannelInboundHandler { @@ -367,16 +372,19 @@ internal final class HttpBinHandler: ChannelInboundHandler { let channelPromise: EventLoopPromise? var resps = CircularBuffer() - var closeAfterResponse = false + var responseHeaders = HTTPHeaders() var delay: TimeAmount = .seconds(0) let creationDate = Date() let maxChannelAge: TimeAmount? var shouldClose = false var isServingRequest = false + let myConnectionNumber: Int + var currentRequestNumber: Int = -1 init(channelPromise: EventLoopPromise? = nil, maxChannelAge: TimeAmount? = nil) { self.channelPromise = channelPromise self.maxChannelAge = maxChannelAge + self.myConnectionNumber = globalConnectionCounter.add(1) } func handlerAdded(context: ChannelHandlerContext) { @@ -402,10 +410,12 @@ internal final class HttpBinHandler: ChannelInboundHandler { self.delay = .nanoseconds(0) } - if let connection = head.headers["Connection"].first { - self.closeAfterResponse = (connection == "close") - } else { - self.closeAfterResponse = false + for header in head.headers { + let needle = "x-send-back-header-" + if header.name.lowercased().starts(with: needle) { + self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)), + value: header.value) + } } } @@ -413,16 +423,18 @@ internal final class HttpBinHandler: ChannelInboundHandler { self.isServingRequest = true switch self.unwrapInboundIn(data) { case .head(let req): + self.responseHeaders = HTTPHeaders() + self.currentRequestNumber = globalRequestCounter.add(1) self.parseAndSetOptions(from: req) let urlComponents = URLComponents(string: req.uri)! switch urlComponents.percentEncodedPath { case "/": - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "X-Is-This-Slash", value: "Yes") self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers)) return case "/echo-uri": - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "X-Calling-URI", value: req.uri) self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers)) return @@ -436,6 +448,13 @@ internal final class HttpBinHandler: ChannelInboundHandler { } self.resps.append(HTTPResponseBuilder(status: .ok)) return + case "/stats": + var body = context.channel.allocator.buffer(capacity: 1) + body.writeString("Just some stats mate.") + var builder = HTTPResponseBuilder(status: .ok) + builder.add(body) + + self.resps.append(builder) case "/post": if req.method != .POST { self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed)) @@ -444,29 +463,29 @@ internal final class HttpBinHandler: ChannelInboundHandler { self.resps.append(HTTPResponseBuilder(status: .ok)) return case "/redirect/302": - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "location", value: "/ok") self.resps.append(HTTPResponseBuilder(status: .found, headers: headers)) return case "/redirect/https": let port = self.value(for: "port", from: urlComponents.query!) - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "Location", value: "https://localhost:\(port)/ok") self.resps.append(HTTPResponseBuilder(status: .found, headers: headers)) return case "/redirect/loopback": let port = self.value(for: "port", from: urlComponents.query!) - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader") self.resps.append(HTTPResponseBuilder(status: .found, headers: headers)) return case "/redirect/infinite1": - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "Location", value: "/redirect/infinite2") self.resps.append(HTTPResponseBuilder(status: .found, headers: headers)) return case "/redirect/infinite2": - var headers = HTTPHeaders() + var headers = self.responseHeaders headers.add(name: "Location", value: "/redirect/infinite1") self.resps.append(HTTPResponseBuilder(status: .found, headers: headers)) return @@ -528,15 +547,15 @@ internal final class HttpBinHandler: ChannelInboundHandler { if self.resps.isEmpty { return } - let response = self.resps.removeFirst() + var response = self.resps.removeFirst() + response.head.headers.add(contentsOf: self.responseHeaders) context.write(wrapOutboundOut(.head(response.head)), promise: nil) - if var body = response.body { - let data = body.readData(length: body.readableBytes)! - let serialized = try! JSONEncoder().encode(RequestInfo(data: String(decoding: data, - as: Unicode.UTF8.self))) - - var responseBody = context.channel.allocator.buffer(capacity: serialized.count) - responseBody.writeBytes(serialized) + if let body = response.body { + let requestInfo = RequestInfo(data: String(buffer: body), + requestNumber: self.currentRequestNumber, + connectionNumber: self.myConnectionNumber) + let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, + allocator: context.channel.allocator) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } context.eventLoop.scheduleTask(in: self.delay) { @@ -549,7 +568,8 @@ internal final class HttpBinHandler: ChannelInboundHandler { self.isServingRequest = false switch result { case .success: - if self.closeAfterResponse || self.shouldClose { + if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") || + self.shouldClose { context.close(promise: nil) } case .failure(let error): diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 16726a7cf..f13a4ca5e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -99,6 +99,10 @@ extension HTTPClientTests { ("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced), ("testUploadsReallyStream", testUploadsReallyStream), ("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL), + ("testWeHandleUsSendingACloseHeaderCorrectly", testWeHandleUsSendingACloseHeaderCorrectly), + ("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly), + ("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly), + ("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index a2fc6da7b..ffb37bf77 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -186,7 +186,7 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(response.status, .ok) } - func testHttpHostRedirect() throws { + func testHttpHostRedirect() { let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) @@ -194,18 +194,15 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try localClient.syncShutdown()) } - let response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/loopback?port=\(self.defaultHTTPBin.port)").wait() - guard var body = response.body else { - XCTFail("The target page should have a body containing the value of the Host header") + let url = self.defaultHTTPBinURLPrefix + "redirect/loopback?port=\(self.defaultHTTPBin.port)" + var maybeResponse: HTTPClient.Response? + XCTAssertNoThrow(maybeResponse = try localClient.get(url: url).wait()) + guard let response = maybeResponse, let body = response.body else { + XCTFail("request failed") return } - guard let responseData = body.readData(length: body.readableBytes) else { - XCTFail("Read data shouldn't be nil since we passed body.readableBytes to body.readData") - return - } - let decoder = JSONDecoder() - let hostName = try decoder.decode([String: String].self, from: responseData)["data"] - XCTAssert(hostName == "127.0.0.1") + let hostName = try? JSONDecoder().decode(RequestInfo.self, from: body).data + XCTAssertEqual("127.0.0.1", hostName) } func testPercentEncoded() throws { @@ -1637,4 +1634,126 @@ class HTTPClientTests: XCTestCase { }) XCTAssertNoThrow(try self.defaultClient.execute(request: request).wait()) } + + func testWeHandleUsSendingACloseHeaderCorrectly() { + guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["connection": "close"]), + let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + XCTFail("request 1 didn't work") + return + } + guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + XCTFail("request 2 didn't work") + return + } + guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + XCTFail("request 3 didn't work") + return + } + + // req 1 and 2 cannot share the same connection (close header) + XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) + XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + + // req 2 and 3 should share the same connection (keep-alive is default) + XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) + XCTAssertEqual(stats2.connectionNumber, stats3.connectionNumber) + } + + func testWeHandleUsReceivingACloseHeaderCorrectly() { + guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-Connection": "close"]), + let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + XCTFail("request 1 didn't work") + return + } + guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + XCTFail("request 2 didn't work") + return + } + guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + XCTFail("request 3 didn't work") + return + } + + // req 1 and 2 cannot share the same connection (close header) + XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) + XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + + // req 2 and 3 should share the same connection (keep-alive is default) + XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) + XCTAssertEqual(stats2.connectionNumber, stats3.connectionNumber) + } + + func testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { + for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { + guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar"]), + let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + XCTFail("request 1 didn't work") + return + } + guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + XCTFail("request 2 didn't work") + return + } + guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + XCTFail("request 3 didn't work") + return + } + + // req 1 and 2 cannot share the same connection (close header) + XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) + XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + + // req 2 and 3 should share the same connection (keep-alive is default) + XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) + XCTAssertEqual(stats2.connectionNumber, stats3.connectionNumber) + } + } + + func testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { + for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { + guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar"]), + let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + XCTFail("request 1 didn't work") + return + } + guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + XCTFail("request 2 didn't work") + return + } + guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + XCTFail("request 3 didn't work") + return + } + + // req 1 and 2 cannot share the same connection (close header) + XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) + XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + + // req 2 and 3 should share the same connection (keep-alive is default) + XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) + XCTAssertEqual(stats2.connectionNumber, stats3.connectionNumber) + } + } }