Skip to content

Commit f69b68f

Browse files
authored
fail if user tries writing bytes after request is sent (#270)
1 parent 2e6a64a commit f69b68f

File tree

6 files changed

+156
-18
lines changed

6 files changed

+156
-18
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPClient.swift

+3
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
977977
case traceRequestWithBody
978978
case invalidHeaderFieldNames([String])
979979
case bodyLengthMismatch
980+
case writeAfterRequestSent
980981
case incompatibleHeaders
981982
}
982983

@@ -1030,6 +1031,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
10301031
public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) }
10311032
/// Body length is not equal to `Content-Length`.
10321033
public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch)
1034+
/// Body part was written after request was fully sent.
1035+
public static let writeAfterRequestSent = HTTPClientError(code: .writeAfterRequestSent)
10331036
/// Incompatible headers specified, for example `Transfer-Encoding` and `Content-Length`.
10341037
public static let incompatibleHeaders = HTTPClientError(code: .incompatibleHeaders)
10351038
}

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+25-5
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ internal struct TaskCancelEvent {}
665665
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
666666
enum State {
667667
case idle
668+
case bodySent
668669
case sent
669670
case head
670671
case redirected(HTTPResponseHead, URL)
@@ -839,6 +840,7 @@ extension TaskHandler: ChannelDuplexHandler {
839840
}.flatMap {
840841
self.writeBody(request: request, context: context)
841842
}.flatMap {
843+
self.state = .bodySent
842844
context.eventLoop.assertInEventLoop()
843845
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
844846
self.state = .endOrError
@@ -876,12 +878,10 @@ extension TaskHandler: ChannelDuplexHandler {
876878
let promise = self.task.eventLoop.makePromise(of: Void.self)
877879
// All writes have to be switched to the channel EL if channel and task ELs differ
878880
if channel.eventLoop.inEventLoop {
879-
self.actualBodyLength += part.readableBytes
880-
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
881+
self.writeBodyPart(context: context, part: part, promise: promise)
881882
} else {
882883
channel.eventLoop.execute {
883-
self.actualBodyLength += part.readableBytes
884-
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
884+
self.writeBodyPart(context: context, part: part, promise: promise)
885885
}
886886
}
887887

@@ -901,6 +901,26 @@ extension TaskHandler: ChannelDuplexHandler {
901901
}
902902
}
903903

904+
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
905+
switch self.state {
906+
case .idle:
907+
if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit {
908+
let error = HTTPClientError.bodyLengthMismatch
909+
self.state = .endOrError
910+
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
911+
promise.fail(error)
912+
return
913+
}
914+
self.actualBodyLength += part.readableBytes
915+
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
916+
default:
917+
let error = HTTPClientError.writeAfterRequestSent
918+
self.state = .endOrError
919+
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
920+
promise.fail(error)
921+
}
922+
}
923+
904924
public func read(context: ChannelHandlerContext) {
905925
if self.mayRead {
906926
self.pendingRead = false
@@ -993,7 +1013,7 @@ extension TaskHandler: ChannelDuplexHandler {
9931013
switch self.state {
9941014
case .endOrError:
9951015
break
996-
case .body, .head, .idle, .redirected, .sent:
1016+
case .body, .head, .idle, .redirected, .sent, .bodySent:
9971017
self.state = .endOrError
9981018
let error = HTTPClientError.remoteConnectionClosed
9991019
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ extension HTTPClientInternalTests {
4545
("testTaskPromiseBoundToEL", testTaskPromiseBoundToEL),
4646
("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL),
4747
("testInternalRequestURI", testInternalRequestURI),
48+
("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification),
4849
]
4950
}
5051
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+64-9
Original file line numberDiff line numberDiff line change
@@ -415,20 +415,21 @@ class HTTPClientInternalTests: XCTestCase {
415415
}
416416

417417
let group = getDefaultEventLoopGroup(numberOfThreads: 3)
418+
let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
418419
defer {
419420
XCTAssertNoThrow(try group.syncShutdownGracefully())
421+
XCTAssertNoThrow(try serverGroup.syncShutdownGracefully())
420422
}
421423

422424
let channelEL = group.next()
423425
let delegateEL = group.next()
424426
let randoEL = group.next()
425427

426428
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group))
427-
let promise: EventLoopPromise<Channel> = httpClient.eventLoopGroup.next().makePromise()
428-
let httpBin = HTTPBin(channelPromise: promise)
429+
let server = NIOHTTP1TestServer(group: serverGroup)
429430
defer {
431+
XCTAssertNoThrow(try server.stop())
430432
XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true))
431-
XCTAssertNoThrow(try httpBin.shutdown())
432433
}
433434

434435
let body: HTTPClient.Body = .stream(length: 8) { writer in
@@ -439,21 +440,26 @@ class HTTPClientInternalTests: XCTestCase {
439440
}
440441
}
441442

442-
let request = try Request(url: "http://127.0.0.1:\(httpBin.port)/custom",
443+
let request = try Request(url: "http://127.0.0.1:\(server.serverPort)/custom",
443444
body: body)
444445
let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL)
445446
let future = httpClient.execute(request: request,
446447
delegate: delegate,
447448
eventLoop: .init(.testOnly_exact(channelOn: channelEL,
448449
delegateOn: delegateEL))).futureResult
449450

450-
let channel = try promise.futureResult.wait()
451+
XCTAssertNoThrow(try server.readInbound()) // .head
452+
XCTAssertNoThrow(try server.readInbound()) // .body
453+
XCTAssertNoThrow(try server.readInbound()) // .end
451454

452455
// Send 3 parts, but only one should be received until the future is complete
453-
let buffer = channel.allocator.buffer(string: "1234")
454-
try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait()
456+
XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1),
457+
status: .ok,
458+
headers: HTTPHeaders([("Transfer-Encoding", "chunked")])))))
459+
let buffer = ByteBuffer(string: "1234")
460+
XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(buffer))))
461+
XCTAssertNoThrow(try server.writeOutbound(.end(nil)))
455462

456-
try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()
457463
let (receivedMessages, sentMessages) = try future.wait()
458464
XCTAssertEqual(2, receivedMessages.count)
459465
XCTAssertEqual(4, sentMessages.count)
@@ -488,7 +494,7 @@ class HTTPClientInternalTests: XCTestCase {
488494

489495
switch receivedMessages.dropFirst(0).first {
490496
case .some(.head(let head)):
491-
XCTAssertEqual(["transfer-encoding": "chunked"], head.headers)
497+
XCTAssertEqual(head.headers["transfer-encoding"].first, "chunked")
492498
default:
493499
XCTFail("wrong message")
494500
}
@@ -1025,4 +1031,53 @@ class HTTPClientInternalTests: XCTestCase {
10251031
XCTAssertEqual(request5.socketPath, "/tmp/file")
10261032
XCTAssertEqual(request5.uri, "/file/path")
10271033
}
1034+
1035+
func testBodyPartStreamStateChangedBeforeNotification() throws {
1036+
class StateValidationDelegate: HTTPClientResponseDelegate {
1037+
typealias Response = Void
1038+
1039+
var handler: TaskHandler<StateValidationDelegate>!
1040+
var triggered = false
1041+
1042+
func didReceiveError(task: HTTPClient.Task<Response>, _ error: Error) {
1043+
self.triggered = true
1044+
switch self.handler.state {
1045+
case .endOrError:
1046+
// expected
1047+
break
1048+
default:
1049+
XCTFail("unexpected state: \(self.handler.state)")
1050+
}
1051+
}
1052+
1053+
func didFinishRequest(task: HTTPClient.Task<Void>) throws {}
1054+
}
1055+
1056+
let channel = EmbeddedChannel()
1057+
XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait())
1058+
1059+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
1060+
1061+
let delegate = StateValidationDelegate()
1062+
let handler = TaskHandler(task: task,
1063+
kind: .host,
1064+
delegate: delegate,
1065+
redirectHandler: nil,
1066+
ignoreUncleanSSLShutdown: false,
1067+
logger: HTTPClient.loggingDisabled)
1068+
1069+
delegate.handler = handler
1070+
try channel.pipeline.addHandler(handler).wait()
1071+
1072+
var request = try Request(url: "http://localhost:8080/post")
1073+
request.body = .stream(length: 1) { writer in
1074+
writer.write(.byteBuffer(ByteBuffer(string: "1234")))
1075+
}
1076+
1077+
XCTAssertThrowsError(try channel.writeOutbound(request))
1078+
XCTAssertTrue(delegate.triggered)
1079+
1080+
XCTAssertNoThrow(try channel.readOutbound(as: HTTPClientRequestPart.self)) // .head
1081+
XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean))
1082+
}
10281083
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ extension HTTPClientTests {
120120
("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL),
121121
("testContentLengthTooLongFails", testContentLengthTooLongFails),
122122
("testContentLengthTooShortFails", testContentLengthTooShortFails),
123+
("testBodyUploadAfterEndFails", testBodyUploadAfterEndFails),
124+
("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit),
123125
]
124126
}
125127
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+61-4
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ class HTTPClientTests: XCTestCase {
15241524
XCTAssertEqual(.ok, firstResponse.status)
15251525
return localClient.get(url: url) // <== interesting bit here
15261526
}
1527-
}.wait().status))
1527+
}.wait().status))
15281528
}
15291529

15301530
func testMakeSecondRequestWhilstFirstIsOngoing() {
@@ -1910,7 +1910,7 @@ class HTTPClientTests: XCTestCase {
19101910
body: .stream { streamWriter in
19111911
streamWriterPromise.succeed(streamWriter)
19121912
return sentOffAllBodyPartsPromise.futureResult
1913-
})
1913+
})
19141914
}
19151915

19161916
guard let server = makeServer(), let request = makeRequest(server: server) else {
@@ -2502,7 +2502,7 @@ class HTTPClientTests: XCTestCase {
25022502
streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise)
25032503
}
25042504
return promise.futureResult
2505-
})).wait()) { error in
2505+
})).wait()) { error in
25062506
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
25072507
}
25082508
// Quickly try another request and check that it works.
@@ -2528,7 +2528,7 @@ class HTTPClientTests: XCTestCase {
25282528
Request(url: url,
25292529
body: .stream(length: 1) { streamWriter in
25302530
streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong)))
2531-
})).wait()) { error in
2531+
})).wait()) { error in
25322532
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
25332533
}
25342534
// Quickly try another request and check that it works. If we by accident wrote some extra bytes into the
@@ -2545,4 +2545,61 @@ class HTTPClientTests: XCTestCase {
25452545
XCTAssertEqual(info.connectionNumber, 1)
25462546
XCTAssertEqual(info.requestNumber, 1)
25472547
}
2548+
2549+
func testBodyUploadAfterEndFails() {
2550+
let url = self.defaultHTTPBinURLPrefix + "post"
2551+
2552+
func uploader(_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture<Void> {
2553+
let done = streamWriter.write(.byteBuffer(ByteBuffer(string: "X")))
2554+
done.recover { error -> Void in
2555+
XCTFail("unexpected error \(error)")
2556+
}.whenSuccess {
2557+
// This is executed when we have already sent the end of the request.
2558+
done.eventLoop.execute {
2559+
streamWriter.write(.byteBuffer(ByteBuffer(string: "BAD BAD BAD"))).whenComplete { result in
2560+
switch result {
2561+
case .success:
2562+
XCTFail("we succeeded writing bytes after the end!?")
2563+
case .failure(let error):
2564+
XCTAssertEqual(HTTPClientError.writeAfterRequestSent, error as? HTTPClientError)
2565+
}
2566+
}
2567+
}
2568+
}
2569+
return done
2570+
}
2571+
2572+
XCTAssertThrowsError(
2573+
try self.defaultClient.execute(request:
2574+
Request(url: url,
2575+
body: .stream(length: 1, uploader))).wait()) { error in
2576+
XCTAssertEqual(HTTPClientError.writeAfterRequestSent, error as? HTTPClientError)
2577+
}
2578+
2579+
// Quickly try another request and check that it works. If we by accident wrote some extra bytes into the
2580+
// stream (and reuse the connection) that could cause problems.
2581+
XCTAssertNoThrow(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait())
2582+
}
2583+
2584+
func testNoBytesSentOverBodyLimit() throws {
2585+
let server = NIOHTTP1TestServer(group: self.serverGroup)
2586+
defer {
2587+
XCTAssertNoThrow(try server.stop())
2588+
}
2589+
2590+
let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n"
2591+
let future = self.defaultClient.execute(
2592+
request: try Request(url: "http://localhost:\(server.serverPort)",
2593+
body: .stream(length: 1) { streamWriter in
2594+
streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong)))
2595+
}))
2596+
2597+
XCTAssertNoThrow(try server.readInbound()) // .head
2598+
// this should fail if client detects that we are about to send more bytes than body limit and closes the connection
2599+
// We can test that this test actually fails if we remove limit check in `writeBodyPart` - it will send bytes, meaning that the next
2600+
// call will not throw, but the future will still throw body mismatch error
2601+
XCTAssertThrowsError(try server.readInbound()) { error in XCTAssertEqual(error as? HTTPParserError, HTTPParserError.invalidEOFState) }
2602+
2603+
XCTAssertThrowsError(try future.wait())
2604+
}
25482605
}

0 commit comments

Comments
 (0)