diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index de2c5e551..387183c8f 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -977,6 +977,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case traceRequestWithBody case invalidHeaderFieldNames([String]) case bodyLengthMismatch + case writeAfterRequestSent case incompatibleHeaders } @@ -1030,6 +1031,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) } /// Body length is not equal to `Content-Length`. public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch) + /// Body part was written after request was fully sent. + public static let writeAfterRequestSent = HTTPClientError(code: .writeAfterRequestSent) /// Incompatible headers specified, for example `Transfer-Encoding` and `Content-Length`. public static let incompatibleHeaders = HTTPClientError(code: .incompatibleHeaders) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 41867ac3a..b949cff21 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -665,6 +665,7 @@ internal struct TaskCancelEvent {} internal class TaskHandler: RemovableChannelHandler { enum State { case idle + case bodySent case sent case head case redirected(HTTPResponseHead, URL) @@ -839,6 +840,7 @@ extension TaskHandler: ChannelDuplexHandler { }.flatMap { self.writeBody(request: request, context: context) }.flatMap { + self.state = .bodySent context.eventLoop.assertInEventLoop() if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { self.state = .endOrError @@ -876,12 +878,10 @@ extension TaskHandler: ChannelDuplexHandler { let promise = self.task.eventLoop.makePromise(of: Void.self) // All writes have to be switched to the channel EL if channel and task ELs differ if channel.eventLoop.inEventLoop { - self.actualBodyLength += part.readableBytes - context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) + self.writeBodyPart(context: context, part: part, promise: promise) } else { channel.eventLoop.execute { - self.actualBodyLength += part.readableBytes - context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) + self.writeBodyPart(context: context, part: part, promise: promise) } } @@ -901,6 +901,26 @@ extension TaskHandler: ChannelDuplexHandler { } } + private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise) { + switch self.state { + case .idle: + if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit { + let error = HTTPClientError.bodyLengthMismatch + self.state = .endOrError + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + promise.fail(error) + return + } + self.actualBodyLength += part.readableBytes + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) + default: + let error = HTTPClientError.writeAfterRequestSent + self.state = .endOrError + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + promise.fail(error) + } + } + public func read(context: ChannelHandlerContext) { if self.mayRead { self.pendingRead = false @@ -993,7 +1013,7 @@ extension TaskHandler: ChannelDuplexHandler { switch self.state { case .endOrError: break - case .body, .head, .idle, .redirected, .sent: + case .body, .head, .idle, .redirected, .sent, .bodySent: self.state = .endOrError let error = HTTPClientError.remoteConnectionClosed self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 14c830e51..6177127d0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -45,6 +45,7 @@ extension HTTPClientInternalTests { ("testTaskPromiseBoundToEL", testTaskPromiseBoundToEL), ("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL), ("testInternalRequestURI", testInternalRequestURI), + ("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 34891c632..52348ba94 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -415,8 +415,10 @@ class HTTPClientInternalTests: XCTestCase { } let group = getDefaultEventLoopGroup(numberOfThreads: 3) + let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) + XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } let channelEL = group.next() @@ -424,11 +426,10 @@ class HTTPClientInternalTests: XCTestCase { let randoEL = group.next() let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) - let promise: EventLoopPromise = httpClient.eventLoopGroup.next().makePromise() - let httpBin = HTTPBin(channelPromise: promise) + let server = NIOHTTP1TestServer(group: serverGroup) defer { + XCTAssertNoThrow(try server.stop()) XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) - XCTAssertNoThrow(try httpBin.shutdown()) } let body: HTTPClient.Body = .stream(length: 8) { writer in @@ -439,7 +440,7 @@ class HTTPClientInternalTests: XCTestCase { } } - let request = try Request(url: "http://127.0.0.1:\(httpBin.port)/custom", + let request = try Request(url: "http://127.0.0.1:\(server.serverPort)/custom", body: body) let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL) let future = httpClient.execute(request: request, @@ -447,13 +448,18 @@ class HTTPClientInternalTests: XCTestCase { eventLoop: .init(.testOnly_exact(channelOn: channelEL, delegateOn: delegateEL))).futureResult - let channel = try promise.futureResult.wait() + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end // Send 3 parts, but only one should be received until the future is complete - let buffer = channel.allocator.buffer(string: "1234") - try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() + XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), + status: .ok, + headers: HTTPHeaders([("Transfer-Encoding", "chunked")]))))) + let buffer = ByteBuffer(string: "1234") + XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(buffer)))) + XCTAssertNoThrow(try server.writeOutbound(.end(nil))) - try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait() let (receivedMessages, sentMessages) = try future.wait() XCTAssertEqual(2, receivedMessages.count) XCTAssertEqual(4, sentMessages.count) @@ -488,7 +494,7 @@ class HTTPClientInternalTests: XCTestCase { switch receivedMessages.dropFirst(0).first { case .some(.head(let head)): - XCTAssertEqual(["transfer-encoding": "chunked"], head.headers) + XCTAssertEqual(head.headers["transfer-encoding"].first, "chunked") default: XCTFail("wrong message") } @@ -1025,4 +1031,53 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertEqual(request5.socketPath, "/tmp/file") XCTAssertEqual(request5.uri, "/file/path") } + + func testBodyPartStreamStateChangedBeforeNotification() throws { + class StateValidationDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + var handler: TaskHandler! + var triggered = false + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.triggered = true + switch self.handler.state { + case .endOrError: + // expected + break + default: + XCTFail("unexpected state: \(self.handler.state)") + } + } + + func didFinishRequest(task: HTTPClient.Task) throws {} + } + + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.connect(to: try SocketAddress(unixDomainSocketPath: "/fake")).wait()) + + let task = Task(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled) + + let delegate = StateValidationDelegate() + let handler = TaskHandler(task: task, + kind: .host, + delegate: delegate, + redirectHandler: nil, + ignoreUncleanSSLShutdown: false, + logger: HTTPClient.loggingDisabled) + + delegate.handler = handler + try channel.pipeline.addHandler(handler).wait() + + var request = try Request(url: "http://localhost:8080/post") + request.body = .stream(length: 1) { writer in + writer.write(.byteBuffer(ByteBuffer(string: "1234"))) + } + + XCTAssertThrowsError(try channel.writeOutbound(request)) + XCTAssertTrue(delegate.triggered) + + XCTAssertNoThrow(try channel.readOutbound(as: HTTPClientRequestPart.self)) // .head + XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 4a5acc298..eb4dd7cb5 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -120,6 +120,8 @@ extension HTTPClientTests { ("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL), ("testContentLengthTooLongFails", testContentLengthTooLongFails), ("testContentLengthTooShortFails", testContentLengthTooShortFails), + ("testBodyUploadAfterEndFails", testBodyUploadAfterEndFails), + ("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 1127dd394..63441c08a 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1524,7 +1524,7 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, firstResponse.status) return localClient.get(url: url) // <== interesting bit here } - }.wait().status)) + }.wait().status)) } func testMakeSecondRequestWhilstFirstIsOngoing() { @@ -1910,7 +1910,7 @@ class HTTPClientTests: XCTestCase { body: .stream { streamWriter in streamWriterPromise.succeed(streamWriter) return sentOffAllBodyPartsPromise.futureResult - }) + }) } guard let server = makeServer(), let request = makeRequest(server: server) else { @@ -2502,7 +2502,7 @@ class HTTPClientTests: XCTestCase { streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) } return promise.futureResult - })).wait()) { error in + })).wait()) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. @@ -2528,7 +2528,7 @@ class HTTPClientTests: XCTestCase { Request(url: url, body: .stream(length: 1) { streamWriter in streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - })).wait()) { error in + })).wait()) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the @@ -2545,4 +2545,61 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(info.connectionNumber, 1) XCTAssertEqual(info.requestNumber, 1) } + + func testBodyUploadAfterEndFails() { + let url = self.defaultHTTPBinURLPrefix + "post" + + func uploader(_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + let done = streamWriter.write(.byteBuffer(ByteBuffer(string: "X"))) + done.recover { error -> Void in + XCTFail("unexpected error \(error)") + }.whenSuccess { + // This is executed when we have already sent the end of the request. + done.eventLoop.execute { + streamWriter.write(.byteBuffer(ByteBuffer(string: "BAD BAD BAD"))).whenComplete { result in + switch result { + case .success: + XCTFail("we succeeded writing bytes after the end!?") + case .failure(let error): + XCTAssertEqual(HTTPClientError.writeAfterRequestSent, error as? HTTPClientError) + } + } + } + } + return done + } + + XCTAssertThrowsError( + try self.defaultClient.execute(request: + Request(url: url, + body: .stream(length: 1, uploader))).wait()) { error in + XCTAssertEqual(HTTPClientError.writeAfterRequestSent, error as? HTTPClientError) + } + + // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the + // stream (and reuse the connection) that could cause problems. + XCTAssertNoThrow(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) + } + + func testNoBytesSentOverBodyLimit() throws { + let server = NIOHTTP1TestServer(group: self.serverGroup) + defer { + XCTAssertNoThrow(try server.stop()) + } + + let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" + let future = self.defaultClient.execute( + request: try Request(url: "http://localhost:\(server.serverPort)", + body: .stream(length: 1) { streamWriter in + streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) + })) + + XCTAssertNoThrow(try server.readInbound()) // .head + // this should fail if client detects that we are about to send more bytes than body limit and closes the connection + // We can test that this test actually fails if we remove limit check in `writeBodyPart` - it will send bytes, meaning that the next + // call will not throw, but the future will still throw body mismatch error + XCTAssertThrowsError(try server.readInbound()) { error in XCTAssertEqual(error as? HTTPParserError, HTTPParserError.invalidEOFState) } + + XCTAssertThrowsError(try future.wait()) + } }