Skip to content

Commit 5d9b784

Browse files
authored
Fixes bi-directional streaming (#344)
Motivation: When we stream request body, current implementation expects that body will finish streaming _before_ we start to receive response body parts. This is not correct, reponse body parts can start to arrive before we finish sending the request. Modifications: - Simplifies state machine, we only case about request being fully sent to prevent sending body parts after .end, but response state machine is mostly ignored and correct flow will be handled by NIOHTTP1 pipeline - Adds HTTPEchoHandler, that replies to each response body part - Adds bi-directional streaming test Result: Closes #327
1 parent ba845ee commit 5d9b784

File tree

6 files changed

+214
-34
lines changed

6 files changed

+214
-34
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+47-17
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,9 @@ extension HTTPClient {
634634
self.promise.fail(error)
635635
connection.channel.close(promise: nil)
636636
}
637+
} else {
638+
// this is used in tests where we don't want to bootstrap the whole connection pool
639+
self.promise.fail(error)
637640
}
638641
}
639642

@@ -665,11 +668,11 @@ internal struct TaskCancelEvent {}
665668
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
666669
enum State {
667670
case idle
668-
case bodySent
669-
case sent
670-
case head
671+
case sendingBodyWaitingResponseHead
672+
case sendingBodyResponseHeadReceived
673+
case bodySentWaitingResponseHead
674+
case bodySentResponseHeadReceived
671675
case redirected(HTTPResponseHead, URL)
672-
case body
673676
case endOrError
674677
}
675678

@@ -794,7 +797,8 @@ extension TaskHandler: ChannelDuplexHandler {
794797
typealias OutboundOut = HTTPClientRequestPart
795798

796799
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
797-
self.state = .idle
800+
self.state = .sendingBodyWaitingResponseHead
801+
798802
let request = self.unwrapOutboundIn(data)
799803

800804
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
@@ -840,23 +844,37 @@ extension TaskHandler: ChannelDuplexHandler {
840844
self.writeBody(request: request, context: context)
841845
}.flatMap {
842846
context.eventLoop.assertInEventLoop()
843-
if case .endOrError = self.state {
847+
switch self.state {
848+
case .idle:
849+
// since this code path is called from `write` and write sets state to sendingBody
850+
preconditionFailure("should not happen")
851+
case .sendingBodyWaitingResponseHead:
852+
self.state = .bodySentWaitingResponseHead
853+
case .sendingBodyResponseHeadReceived:
854+
self.state = .bodySentResponseHeadReceived
855+
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived:
856+
preconditionFailure("should not happen, state is \(self.state)")
857+
case .redirected:
858+
break
859+
case .endOrError:
860+
// If the state is .endOrError, it means that request was failed and there is nothing to do here:
861+
// we cannot write .end since channel is most likely closed, and we should not fail the future,
862+
// since the task would already be failed, no need to fail the writer too.
844863
return context.eventLoop.makeSucceededFuture(())
845864
}
846865

847-
self.state = .bodySent
848866
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
849867
let error = HTTPClientError.bodyLengthMismatch
850868
return context.eventLoop.makeFailedFuture(error)
851869
}
852870
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
853871
}.map {
854872
context.eventLoop.assertInEventLoop()
873+
855874
if case .endOrError = self.state {
856875
return
857876
}
858877

859-
self.state = .sent
860878
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
861879
}.flatMapErrorThrowing { error in
862880
context.eventLoop.assertInEventLoop()
@@ -903,6 +921,9 @@ extension TaskHandler: ChannelDuplexHandler {
903921
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
904922
switch self.state {
905923
case .idle:
924+
// this function is called on the codepath starting with write, so it cannot be in state .idle
925+
preconditionFailure("should not happen")
926+
case .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .redirected:
906927
if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit {
907928
let error = HTTPClientError.bodyLengthMismatch
908929
self.errorCaught(context: context, error: error)
@@ -911,7 +932,7 @@ extension TaskHandler: ChannelDuplexHandler {
911932
}
912933
self.actualBodyLength += part.readableBytes
913934
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
914-
default:
935+
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .endOrError:
915936
let error = HTTPClientError.writeAfterRequestSent
916937
self.errorCaught(context: context, error: error)
917938
promise.fail(error)
@@ -931,7 +952,18 @@ extension TaskHandler: ChannelDuplexHandler {
931952
let response = self.unwrapInboundIn(data)
932953
switch response {
933954
case .head(let head):
934-
if case .endOrError = self.state {
955+
switch self.state {
956+
case .idle:
957+
// should be prevented by NIO HTTP1 pipeline, see testHTTPResponseHeadBeforeRequestHead
958+
preconditionFailure("should not happen")
959+
case .sendingBodyWaitingResponseHead:
960+
self.state = .sendingBodyResponseHeadReceived
961+
case .bodySentWaitingResponseHead:
962+
self.state = .bodySentResponseHeadReceived
963+
case .sendingBodyResponseHeadReceived, .bodySentResponseHeadReceived, .redirected:
964+
// should be prevented by NIO HTTP1 pipeline, aee testHTTPResponseDoubleHead
965+
preconditionFailure("should not happen")
966+
case .endOrError:
935967
return
936968
}
937969

@@ -942,7 +974,6 @@ extension TaskHandler: ChannelDuplexHandler {
942974
if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
943975
self.state = .redirected(head, redirectURL)
944976
} else {
945-
self.state = .head
946977
self.mayRead = false
947978
self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
948979
.whenComplete { result in
@@ -954,7 +985,6 @@ extension TaskHandler: ChannelDuplexHandler {
954985
case .redirected, .endOrError:
955986
break
956987
default:
957-
self.state = .body
958988
self.mayRead = false
959989
self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart)
960990
.whenComplete { result in
@@ -1009,10 +1039,10 @@ extension TaskHandler: ChannelDuplexHandler {
10091039

10101040
func channelInactive(context: ChannelHandlerContext) {
10111041
switch self.state {
1042+
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
1043+
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10121044
case .endOrError:
10131045
break
1014-
case .body, .head, .idle, .redirected, .sent, .bodySent:
1015-
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10161046
}
10171047
context.fireChannelInactive()
10181048
}
@@ -1025,8 +1055,8 @@ extension TaskHandler: ChannelDuplexHandler {
10251055
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
10261056
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
10271057
break
1028-
case .head where self.ignoreUncleanSSLShutdown,
1029-
.body where self.ignoreUncleanSSLShutdown:
1058+
case .sendingBodyResponseHeadReceived where self.ignoreUncleanSSLShutdown,
1059+
.bodySentResponseHeadReceived where self.ignoreUncleanSSLShutdown:
10301060
/// We can also ignore this error like `.end`.
10311061
break
10321062
default:
@@ -1035,7 +1065,7 @@ extension TaskHandler: ChannelDuplexHandler {
10351065
}
10361066
default:
10371067
switch self.state {
1038-
case .idle, .bodySent, .sent, .head, .redirected, .body:
1068+
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
10391069
self.state = .endOrError
10401070
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
10411071
case .endOrError:

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ extension HTTPClientInternalTests {
2929
("testBadHTTPRequest", testBadHTTPRequest),
3030
("testHostPort", testHostPort),
3131
("testHTTPPartsHandlerMultiBody", testHTTPPartsHandlerMultiBody),
32+
("testHTTPResponseHeadBeforeRequestHead", testHTTPResponseHeadBeforeRequestHead),
33+
("testHTTPResponseDoubleHead", testHTTPResponseDoubleHead),
34+
("testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes", testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes),
3235
("testProxyStreaming", testProxyStreaming),
3336
("testProxyStreamingFailure", testProxyStreamingFailure),
3437
("testUploadStreamingBackpressure", testUploadStreamingBackpressure),

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+95-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class HTTPClientInternalTests: XCTestCase {
145145

146146
try channel.pipeline.addHandler(handler).wait()
147147

148-
handler.state = .sent
148+
handler.state = .bodySentWaitingResponseHead
149149
var body = channel.allocator.buffer(capacity: 4)
150150
body.writeStaticString("1234")
151151

@@ -161,6 +161,100 @@ class HTTPClientInternalTests: XCTestCase {
161161
}
162162
}
163163

164+
func testHTTPResponseHeadBeforeRequestHead() throws {
165+
let channel = EmbeddedChannel()
166+
XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait())
167+
168+
let delegate = TestHTTPDelegate()
169+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
170+
let handler = TaskHandler(task: task,
171+
kind: .host,
172+
delegate: delegate,
173+
redirectHandler: nil,
174+
ignoreUncleanSSLShutdown: false,
175+
logger: HTTPClient.loggingDisabled)
176+
177+
XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
178+
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())
179+
180+
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.0 200 OK\r\n\r\n")))
181+
182+
XCTAssertThrowsError(try task.futureResult.wait()) { error in
183+
XCTAssertEqual(error as? NIOHTTPDecoderError, NIOHTTPDecoderError.unsolicitedResponse)
184+
}
185+
}
186+
187+
func testHTTPResponseDoubleHead() throws {
188+
let channel = EmbeddedChannel()
189+
XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait())
190+
191+
let delegate = TestHTTPDelegate()
192+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
193+
let handler = TaskHandler(task: task,
194+
kind: .host,
195+
delegate: delegate,
196+
redirectHandler: nil,
197+
ignoreUncleanSSLShutdown: false,
198+
logger: HTTPClient.loggingDisabled)
199+
200+
XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
201+
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())
202+
203+
let request = try HTTPClient.Request(url: "http://localhost/get")
204+
XCTAssertNoThrow(try channel.writeOutbound(request))
205+
206+
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.0 200 OK\r\nHTTP/1.0 200 OK\r\n\r\n")))
207+
208+
XCTAssertThrowsError(try task.futureResult.wait()) { error in
209+
XCTAssertEqual((error as? HTTPParserError)?.debugDescription, "invalid character in header")
210+
}
211+
}
212+
213+
func testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes() throws {
214+
let channel = EmbeddedChannel()
215+
216+
var request = try Request(url: "http://localhost:8080/get")
217+
// This promise is needed to force task handler to process incoming redirecting head before finishing sending the request
218+
let promise = channel.eventLoop.makePromise(of: Void.self)
219+
request.body = .stream(length: 6) { writer in
220+
promise.futureResult.flatMap {
221+
writer.write(.byteBuffer(ByteBuffer(string: "helllo")))
222+
}
223+
}
224+
225+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
226+
let redirecter = RedirectHandler<Void>(request: request) { _ in
227+
task
228+
}
229+
230+
let handler = TaskHandler(task: task,
231+
kind: .host,
232+
delegate: TestHTTPDelegate(),
233+
redirectHandler: redirecter,
234+
ignoreUncleanSSLShutdown: false,
235+
logger: HTTPClient.loggingDisabled)
236+
237+
XCTAssertNoThrow(try channel.pipeline.addHTTPClientHandlers().wait())
238+
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())
239+
240+
let future = channel.write(request)
241+
channel.flush()
242+
243+
XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // expecting to read head from the client
244+
// sending redirect before client finishesh processing the request
245+
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "HTTP/1.1 302 Found\r\nLocation: /follow\r\n\r\n")))
246+
channel.flush()
247+
248+
promise.succeed(())
249+
250+
// we expect client to fully send us all bytes
251+
XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "helllo"))
252+
XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: ""))
253+
XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "\r\n")))
254+
255+
XCTAssertNoThrow(try future.wait())
256+
}
257+
164258
func testProxyStreaming() throws {
165259
let httpBin = HTTPBin()
166260
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup))

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+24
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,30 @@ struct CollectEverythingLogHandler: LogHandler {
896896
}
897897
}
898898

899+
class HTTPEchoHandler: ChannelInboundHandler {
900+
typealias InboundIn = HTTPServerRequestPart
901+
typealias OutboundOut = HTTPServerResponsePart
902+
903+
var promises: CircularBuffer<EventLoopPromise<Void>> = CircularBuffer()
904+
905+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
906+
let request = self.unwrapInboundIn(data)
907+
switch request {
908+
case .head:
909+
context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), promise: nil)
910+
case .body(let bytes):
911+
context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes)))).whenSuccess {
912+
if let promise = self.promises.popFirst() {
913+
promise.succeed(())
914+
}
915+
}
916+
case .end:
917+
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
918+
context.close(promise: nil)
919+
}
920+
}
921+
}
922+
899923
private let cert = """
900924
-----BEGIN CERTIFICATE-----
901925
MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ extension HTTPClientTests {
128128
("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation),
129129
("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose),
130130
("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer),
131+
("testBiDirectionalStreaming", testBiDirectionalStreaming),
131132
]
132133
}
133134
}

0 commit comments

Comments
 (0)