Skip to content

Fixes bi-directional streaming #344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ extension HTTPClient {
self.promise.fail(error)
connection.channel.close(promise: nil)
}
} else {
// this is used in tests where we don't want to bootstrap the whole connection pool
self.promise.fail(error)
}
}

Expand Down Expand Up @@ -665,11 +668,11 @@ internal struct TaskCancelEvent {}
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
enum State {
case idle
case bodySent
case sent
case head
case sendingBodyWaitingResponseHead
case sendingBodyResponseHeadReceived
case bodySentWaitingResponseHead
case bodySentResponseHeadReceived
case redirected(HTTPResponseHead, URL)
case body
case endOrError
}

Expand Down Expand Up @@ -794,7 +797,8 @@ extension TaskHandler: ChannelDuplexHandler {
typealias OutboundOut = HTTPClientRequestPart

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
self.state = .idle
self.state = .sendingBodyWaitingResponseHead

let request = self.unwrapOutboundIn(data)

var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
Expand Down Expand Up @@ -840,23 +844,37 @@ extension TaskHandler: ChannelDuplexHandler {
self.writeBody(request: request, context: context)
}.flatMap {
context.eventLoop.assertInEventLoop()
if case .endOrError = self.state {
switch self.state {
case .idle:
// since this code path is called from `write` and write sets state to sendingBody
preconditionFailure("should not happen")
case .sendingBodyWaitingResponseHead:
self.state = .bodySentWaitingResponseHead
case .sendingBodyResponseHeadReceived:
self.state = .bodySentResponseHeadReceived
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived:
preconditionFailure("should not happen, state is \(self.state)")
case .redirected:
break
case .endOrError:
// If the state is .endOrError, it means that request was failed and there is nothing to do here:
// we cannot write .end since channel is most likely closed, and we should not fail the future,
// since the task would already be failed, no need to fail the writer too.
return context.eventLoop.makeSucceededFuture(())
}

self.state = .bodySent
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
let error = HTTPClientError.bodyLengthMismatch
return context.eventLoop.makeFailedFuture(error)
}
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
}.map {
context.eventLoop.assertInEventLoop()

if case .endOrError = self.state {
return
}

self.state = .sent
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
}.flatMapErrorThrowing { error in
context.eventLoop.assertInEventLoop()
Expand Down Expand Up @@ -903,6 +921,9 @@ extension TaskHandler: ChannelDuplexHandler {
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
switch self.state {
case .idle:
// this function is called on the codepath starting with write, so it cannot be in state .idle
preconditionFailure("should not happen")
case .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .redirected:
if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit {
let error = HTTPClientError.bodyLengthMismatch
self.errorCaught(context: context, error: error)
Expand All @@ -911,7 +932,7 @@ extension TaskHandler: ChannelDuplexHandler {
}
self.actualBodyLength += part.readableBytes
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
default:
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .endOrError:
let error = HTTPClientError.writeAfterRequestSent
self.errorCaught(context: context, error: error)
promise.fail(error)
Expand All @@ -931,7 +952,18 @@ extension TaskHandler: ChannelDuplexHandler {
let response = self.unwrapInboundIn(data)
switch response {
case .head(let head):
if case .endOrError = self.state {
switch self.state {
case .idle:
// should be prevented by NIO HTTP1 pipeline, see testHTTPResponseHeadBeforeRequestHead
preconditionFailure("should not happen")
case .sendingBodyWaitingResponseHead:
self.state = .sendingBodyResponseHeadReceived
case .bodySentWaitingResponseHead:
self.state = .bodySentResponseHeadReceived
case .sendingBodyResponseHeadReceived, .bodySentResponseHeadReceived, .redirected:
// should be prevented by NIO HTTP1 pipeline, aee testHTTPResponseDoubleHead
preconditionFailure("should not happen")
case .endOrError:
return
}

Expand All @@ -942,7 +974,6 @@ extension TaskHandler: ChannelDuplexHandler {
if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
self.state = .redirected(head, redirectURL)
} else {
self.state = .head
self.mayRead = false
self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
.whenComplete { result in
Expand All @@ -954,7 +985,6 @@ extension TaskHandler: ChannelDuplexHandler {
case .redirected, .endOrError:
break
default:
self.state = .body
self.mayRead = false
self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart)
.whenComplete { result in
Expand Down Expand Up @@ -1009,10 +1039,10 @@ extension TaskHandler: ChannelDuplexHandler {

func channelInactive(context: ChannelHandlerContext) {
switch self.state {
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
case .endOrError:
break
case .body, .head, .idle, .redirected, .sent, .bodySent:
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
}
context.fireChannelInactive()
}
Expand All @@ -1025,8 +1055,8 @@ extension TaskHandler: ChannelDuplexHandler {
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
break
case .head where self.ignoreUncleanSSLShutdown,
.body where self.ignoreUncleanSSLShutdown:
case .sendingBodyResponseHeadReceived where self.ignoreUncleanSSLShutdown,
.bodySentResponseHeadReceived where self.ignoreUncleanSSLShutdown:
/// We can also ignore this error like `.end`.
break
default:
Expand All @@ -1035,7 +1065,7 @@ extension TaskHandler: ChannelDuplexHandler {
}
default:
switch self.state {
case .idle, .bodySent, .sent, .head, .redirected, .body:
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
case .endOrError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ extension HTTPClientInternalTests {
("testBadHTTPRequest", testBadHTTPRequest),
("testHostPort", testHostPort),
("testHTTPPartsHandlerMultiBody", testHTTPPartsHandlerMultiBody),
("testHTTPResponseHeadBeforeRequestHead", testHTTPResponseHeadBeforeRequestHead),
("testHTTPResponseDoubleHead", testHTTPResponseDoubleHead),
("testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes", testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes),
("testProxyStreaming", testProxyStreaming),
("testProxyStreamingFailure", testProxyStreamingFailure),
("testUploadStreamingBackpressure", testUploadStreamingBackpressure),
Expand Down
96 changes: 95 additions & 1 deletion Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class HTTPClientInternalTests: XCTestCase {

try channel.pipeline.addHandler(handler).wait()

handler.state = .sent
handler.state = .bodySentWaitingResponseHead
var body = channel.allocator.buffer(capacity: 4)
body.writeStaticString("1234")

Expand All @@ -161,6 +161,100 @@ class HTTPClientInternalTests: XCTestCase {
}
}

func testHTTPResponseHeadBeforeRequestHead() throws {
let channel = EmbeddedChannel()
XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait())

let delegate = TestHTTPDelegate()
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
let handler = TaskHandler(task: task,
kind: .host,
delegate: delegate,
redirectHandler: nil,
ignoreUncleanSSLShutdown: false,
logger: HTTPClient.loggingDisabled)

XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())

XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.0 200 OK\r\n\r\n")))

XCTAssertThrowsError(try task.futureResult.wait()) { error in
XCTAssertEqual(error as? NIOHTTPDecoderError, NIOHTTPDecoderError.unsolicitedResponse)
}
}

func testHTTPResponseDoubleHead() throws {
let channel = EmbeddedChannel()
XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait())

let delegate = TestHTTPDelegate()
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
let handler = TaskHandler(task: task,
kind: .host,
delegate: delegate,
redirectHandler: nil,
ignoreUncleanSSLShutdown: false,
logger: HTTPClient.loggingDisabled)

XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())

let request = try HTTPClient.Request(url: "http://localhost/get")
XCTAssertNoThrow(try channel.writeOutbound(request))

XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.0 200 OK\r\nHTTP/1.0 200 OK\r\n\r\n")))

XCTAssertThrowsError(try task.futureResult.wait()) { error in
XCTAssertEqual((error as? HTTPParserError)?.debugDescription, "invalid character in header")
}
}

func testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes() throws {
let channel = EmbeddedChannel()

var request = try Request(url: "http://localhost:8080/get")
// This promise is needed to force task handler to process incoming redirecting head before finishing sending the request
let promise = channel.eventLoop.makePromise(of: Void.self)
request.body = .stream(length: 6) { writer in
promise.futureResult.flatMap {
writer.write(.byteBuffer(ByteBuffer(string: "helllo")))
}
}

let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
let redirecter = RedirectHandler<Void>(request: request) { _ in
task
}

let handler = TaskHandler(task: task,
kind: .host,
delegate: TestHTTPDelegate(),
redirectHandler: redirecter,
ignoreUncleanSSLShutdown: false,
logger: HTTPClient.loggingDisabled)

XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())

let future = channel.write(request)
channel.flush()

XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // expecting to read head from the client
// sending redirect before client finishesh processing the request
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.1 302 Found\r\nLocation: /follow\r\n\r\n")))
channel.flush()

promise.succeed(())

// we expect client to fully send us all bytes
XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "helllo"))
XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: ""))
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "\r\n")))

XCTAssertNoThrow(try future.wait())
}

func testProxyStreaming() throws {
let httpBin = HTTPBin()
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup))
Expand Down
24 changes: 24 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,30 @@ struct CollectEverythingLogHandler: LogHandler {
}
}

class HTTPEchoHandler: ChannelInboundHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart

var promises: CircularBuffer<EventLoopPromise<Void>> = CircularBuffer()

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)
switch request {
case .head:
context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), promise: nil)
case .body(let bytes):
context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes)))).whenSuccess {
if let promise = self.promises.popFirst() {
promise.succeed(())
}
}
case .end:
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
context.close(promise: nil)
}
}
}

private let cert = """
-----BEGIN CERTIFICATE-----
MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ extension HTTPClientTests {
("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation),
("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose),
("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer),
("testBiDirectionalStreaming", testBiDirectionalStreaming),
]
}
}
Loading