Skip to content

cpool: don't reuse connection if we sent close #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,13 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
var state: State = .idle
var pendingRead = false
var mayRead = true
var closing = false
var closing = false {
didSet {
assert(self.closing || !oldValue,
"BUG in AsyncHTTPClient: TaskHandler.closing went from true (no conn reuse) to true (do reuse).")
}
}

let kind: HTTPClient.Request.Kind

init(task: HTTPClient.Task<Delegate.Response>,
Expand Down Expand Up @@ -736,6 +742,14 @@ extension TaskHandler: ChannelDuplexHandler {

head.headers = headers

if head.headers[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") {
self.closing = true
}
// This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
// in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
assert(head.version == HTTPVersion(major: 1, minor: 1),
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")

context.write(wrapOutboundOut(.head(head))).map {
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
}.flatMap {
Expand Down
9 changes: 5 additions & 4 deletions Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,10 @@ class HTTPClientInternalTests: XCTestCase {
}

let upload = try! httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()
let bytes = upload.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) }
let data = try! JSONDecoder().decode(RequestInfo.self, from: bytes!)
let data = upload.body.flatMap { try? JSONDecoder().decode(RequestInfo.self, from: $0) }

XCTAssertEqual(.ok, upload.status)
XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data.data)
XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data?.data)
}

func testProxyStreamingFailure() throws {
Expand Down Expand Up @@ -466,7 +465,9 @@ class HTTPClientInternalTests: XCTestCase {
XCTAssertNoThrow(try httpBin.shutdown())
}

let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["Connection": "close"], body: nil)
let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get",
method: .GET,
headers: ["X-Send-Back-Header-Connection": "close"], body: nil)
_ = try! httpClient.execute(request: req).wait()
let el = httpClient.eventLoopGroup.next()
try! el.scheduleTask(in: .milliseconds(500)) {
Expand Down
66 changes: 43 additions & 23 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
}

internal struct HTTPResponseBuilder {
let head: HTTPResponseHead
var head: HTTPResponseHead
var body: ByteBuffer?

init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) {
Expand All @@ -357,8 +357,13 @@ internal struct HTTPResponseBuilder {
}
}

let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)

internal struct RequestInfo: Codable {
let data: String
var data: String
var requestNumber: Int
var connectionNumber: Int
}

internal final class HttpBinHandler: ChannelInboundHandler {
Expand All @@ -367,16 +372,19 @@ internal final class HttpBinHandler: ChannelInboundHandler {

let channelPromise: EventLoopPromise<Channel>?
var resps = CircularBuffer<HTTPResponseBuilder>()
var closeAfterResponse = false
var responseHeaders = HTTPHeaders()
var delay: TimeAmount = .seconds(0)
let creationDate = Date()
let maxChannelAge: TimeAmount?
var shouldClose = false
var isServingRequest = false
let myConnectionNumber: Int
var currentRequestNumber: Int = -1

init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil) {
self.channelPromise = channelPromise
self.maxChannelAge = maxChannelAge
self.myConnectionNumber = globalConnectionCounter.add(1)
}

func handlerAdded(context: ChannelHandlerContext) {
Expand All @@ -402,27 +410,31 @@ internal final class HttpBinHandler: ChannelInboundHandler {
self.delay = .nanoseconds(0)
}

if let connection = head.headers["Connection"].first {
self.closeAfterResponse = (connection == "close")
} else {
self.closeAfterResponse = false
for header in head.headers {
let needle = "x-send-back-header-"
if header.name.lowercased().starts(with: needle) {
self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)),
value: header.value)
}
}
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.isServingRequest = true
switch self.unwrapInboundIn(data) {
case .head(let req):
self.responseHeaders = HTTPHeaders()
self.currentRequestNumber = globalRequestCounter.add(1)
self.parseAndSetOptions(from: req)
let urlComponents = URLComponents(string: req.uri)!
switch urlComponents.percentEncodedPath {
case "/":
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "X-Is-This-Slash", value: "Yes")
self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers))
return
case "/echo-uri":
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "X-Calling-URI", value: req.uri)
self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers))
return
Expand All @@ -436,6 +448,13 @@ internal final class HttpBinHandler: ChannelInboundHandler {
}
self.resps.append(HTTPResponseBuilder(status: .ok))
return
case "/stats":
var body = context.channel.allocator.buffer(capacity: 1)
body.writeString("Just some stats mate.")
var builder = HTTPResponseBuilder(status: .ok)
builder.add(body)

self.resps.append(builder)
case "/post":
if req.method != .POST {
self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed))
Expand All @@ -444,29 +463,29 @@ internal final class HttpBinHandler: ChannelInboundHandler {
self.resps.append(HTTPResponseBuilder(status: .ok))
return
case "/redirect/302":
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "location", value: "/ok")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/https":
let port = self.value(for: "port", from: urlComponents.query!)
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "Location", value: "https://localhost:\(port)/ok")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/loopback":
let port = self.value(for: "port", from: urlComponents.query!)
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite1":
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "Location", value: "/redirect/infinite2")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite2":
var headers = HTTPHeaders()
var headers = self.responseHeaders
headers.add(name: "Location", value: "/redirect/infinite1")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
Expand Down Expand Up @@ -528,15 +547,15 @@ internal final class HttpBinHandler: ChannelInboundHandler {
if self.resps.isEmpty {
return
}
let response = self.resps.removeFirst()
var response = self.resps.removeFirst()
response.head.headers.add(contentsOf: self.responseHeaders)
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
if var body = response.body {
let data = body.readData(length: body.readableBytes)!
let serialized = try! JSONEncoder().encode(RequestInfo(data: String(decoding: data,
as: Unicode.UTF8.self)))

var responseBody = context.channel.allocator.buffer(capacity: serialized.count)
responseBody.writeBytes(serialized)
if let body = response.body {
let requestInfo = RequestInfo(data: String(buffer: body),
requestNumber: self.currentRequestNumber,
connectionNumber: self.myConnectionNumber)
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
allocator: context.channel.allocator)
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
}
context.eventLoop.scheduleTask(in: self.delay) {
Expand All @@ -549,7 +568,8 @@ internal final class HttpBinHandler: ChannelInboundHandler {
self.isServingRequest = false
switch result {
case .success:
if self.closeAfterResponse || self.shouldClose {
if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") ||
self.shouldClose {
context.close(promise: nil)
}
case .failure(let error):
Expand Down
4 changes: 4 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ extension HTTPClientTests {
("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced),
("testUploadsReallyStream", testUploadsReallyStream),
("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL),
("testWeHandleUsSendingACloseHeaderCorrectly", testWeHandleUsSendingACloseHeaderCorrectly),
("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly),
("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly),
("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly),
]
}
}
Loading