Skip to content

check body length #255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 16, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case uncleanShutdown
case traceRequestWithBody
case invalidHeaderFieldNames([String])
case bodyLengthMismatch
}

private var code: Code
Expand Down Expand Up @@ -969,10 +970,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
/// Redirect Cycle detected.
public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected)
/// Unclean shutdown
/// Unclean shutdown.
public static let uncleanShutdown = HTTPClientError(code: .uncleanShutdown)
/// A body was sent in a request with method TRACE
/// A body was sent in a request with method TRACE.
public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody)
/// Header field names contain invalid characters
/// Header field names contain invalid characters.
public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) }
/// Body length is not equal to `Content-Length`.
public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch)
}
43 changes: 29 additions & 14 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
case head
case redirected(HTTPResponseHead, URL)
case body
case end
case endOrError
}

let task: HTTPClient.Task<Delegate.Response>
Expand All @@ -651,6 +651,8 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request.

var state: State = .idle
var expectedBodyLength: Int?
var actualBodyLength: Int = 0
var pendingRead = false
var mayRead = true
var closing = false {
Expand Down Expand Up @@ -780,7 +782,7 @@ extension TaskHandler: ChannelDuplexHandler {
} catch {
promise?.fail(error)
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
self.state = .end
self.state = .endOrError
return
}

Expand All @@ -794,12 +796,23 @@ extension TaskHandler: ChannelDuplexHandler {
assert(head.version == HTTPVersion(major: 1, minor: 1),
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")

let contentLengths = head.headers[canonicalForm: "content-length"]
assert(contentLengths.count <= 1)

self.expectedBodyLength = contentLengths.first.flatMap { Int($0) }

context.write(wrapOutboundOut(.head(head))).map {
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
}.flatMap {
self.writeBody(request: request, context: context)
}.flatMap {
context.eventLoop.assertInEventLoop()
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
self.state = .endOrError
let error = HTTPClientError.bodyLengthMismatch
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
return context.eventLoop.makeFailedFuture(error)
}
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
}.map {
context.eventLoop.assertInEventLoop()
Expand All @@ -808,10 +821,10 @@ extension TaskHandler: ChannelDuplexHandler {
}.flatMapErrorThrowing { error in
context.eventLoop.assertInEventLoop()
switch self.state {
case .end:
case .endOrError:
break
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
throw error
Expand All @@ -828,9 +841,11 @@ extension TaskHandler: ChannelDuplexHandler {
let promise = self.task.eventLoop.makePromise(of: Void.self)
// All writes have to be switched to the channel EL if channel and task ELs differ
if context.eventLoop.inEventLoop {
self.actualBodyLength += part.readableBytes
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
} else {
context.eventLoop.execute {
self.actualBodyLength += part.readableBytes
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
}
}
Expand Down Expand Up @@ -893,12 +908,12 @@ extension TaskHandler: ChannelDuplexHandler {
case .end:
switch self.state {
case .redirected(let head, let redirectURL):
self.state = .end
self.state = .endOrError
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
}
default:
self.state = .end
self.state = .endOrError
self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest)
}
}
Expand All @@ -913,14 +928,14 @@ extension TaskHandler: ChannelDuplexHandler {
context.read()
}
case .failure(let error):
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if (event as? IdleStateHandler.IdleStateEvent) == .read {
self.state = .end
self.state = .endOrError
let error = HTTPClientError.readTimeout
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
} else {
Expand All @@ -930,7 +945,7 @@ extension TaskHandler: ChannelDuplexHandler {

func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
if (event as? TaskCancelEvent) != nil {
self.state = .end
self.state = .endOrError
let error = HTTPClientError.cancelled
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
promise?.succeed(())
Expand All @@ -941,10 +956,10 @@ extension TaskHandler: ChannelDuplexHandler {

func channelInactive(context: ChannelHandlerContext) {
switch self.state {
case .end:
case .endOrError:
break
case .body, .head, .idle, .redirected, .sent:
self.state = .end
self.state = .endOrError
let error = HTTPClientError.remoteConnectionClosed
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
Expand All @@ -955,7 +970,7 @@ extension TaskHandler: ChannelDuplexHandler {
switch error {
case NIOSSLError.uncleanShutdown:
switch self.state {
case .end:
case .endOrError:
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
break
Expand All @@ -964,11 +979,11 @@ extension TaskHandler: ChannelDuplexHandler {
/// We can also ignore this error like `.end`.
break
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
default:
self.state = .end
self.state = .endOrError
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
}
Expand Down
32 changes: 20 additions & 12 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ internal final class HTTPBin {
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let serverChannel: Channel
let isShutdown: NIOAtomic<Bool> = .makeAtomic(value: false)
var connections: NIOAtomic<Int>
var connectionCount: NIOAtomic<Int> = .makeAtomic(value: 0)
private let activeConnCounterHandler: CountActiveConnectionsHandler
var activeConnections: Int {
Expand Down Expand Up @@ -233,6 +234,9 @@ internal final class HTTPBin {
let activeConnCounterHandler = CountActiveConnectionsHandler()
self.activeConnCounterHandler = activeConnCounterHandler

let connections = NIOAtomic.makeAtomic(value: 0)
self.connections = connections

self.serverChannel = try! ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.serverChannelInitializer { channel in
Expand Down Expand Up @@ -261,10 +265,10 @@ internal final class HTTPBin {
}.flatMap {
if ssl {
return HTTPBin.configureTLS(channel: channel).flatMap {
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
}
} else {
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
}
}
}
Expand Down Expand Up @@ -357,9 +361,6 @@ internal struct HTTPResponseBuilder {
}
}

let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)

internal struct RequestInfo: Codable {
var data: String
var requestNumber: Int
Expand All @@ -378,13 +379,13 @@ internal final class HttpBinHandler: ChannelInboundHandler {
let maxChannelAge: TimeAmount?
var shouldClose = false
var isServingRequest = false
let myConnectionNumber: Int
var currentRequestNumber: Int = -1
let connectionId: Int
var requestId: Int = 0

init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil) {
init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil, connectionId: Int) {
self.channelPromise = channelPromise
self.maxChannelAge = maxChannelAge
self.myConnectionNumber = globalConnectionCounter.add(1)
self.connectionId = connectionId
}

func handlerAdded(context: ChannelHandlerContext) {
Expand Down Expand Up @@ -424,7 +425,7 @@ internal final class HttpBinHandler: ChannelInboundHandler {
switch self.unwrapInboundIn(data) {
case .head(let req):
self.responseHeaders = HTTPHeaders()
self.currentRequestNumber = globalRequestCounter.add(1)
self.requestId += 1
self.parseAndSetOptions(from: req)
let urlComponents = URLComponents(string: req.uri)!
switch urlComponents.percentEncodedPath {
Expand Down Expand Up @@ -552,8 +553,15 @@ internal final class HttpBinHandler: ChannelInboundHandler {
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
if let body = response.body {
let requestInfo = RequestInfo(data: String(buffer: body),
requestNumber: self.currentRequestNumber,
connectionNumber: self.myConnectionNumber)
requestNumber: self.requestId,
connectionNumber: self.connectionId)
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
allocator: context.channel.allocator)
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
} else {
let requestInfo = RequestInfo(data: "",
requestNumber: self.requestId,
connectionNumber: self.connectionId)
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
allocator: context.channel.allocator)
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
Expand Down
2 changes: 2 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ extension HTTPClientTests {
("testAllMethodsLog", testAllMethodsLog),
("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground),
("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL),
("testContentLengthTooLongFails", testContentLengthTooLongFails),
("testContentLengthTooShortFails", testContentLengthTooShortFails),
]
}
}
51 changes: 47 additions & 4 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,8 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats1.requestNumber, 1)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1742,7 +1743,8 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats1.requestNumber, 1)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1773,7 +1775,7 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -1805,7 +1807,7 @@ class HTTPClientTests: XCTestCase {

// req 1 and 2 cannot share the same connection (close header)
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
XCTAssertEqual(stats2.requestNumber, 1)

// req 2 and 3 should share the same connection (keep-alive is default)
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
Expand Down Expand Up @@ -2051,4 +2053,45 @@ class HTTPClientTests: XCTestCase {

XCTAssertNoThrow(try future.wait())
}

func testContentLengthTooLongFails() throws {
let url = self.defaultHTTPBinURLPrefix + "/post"
XCTAssertThrowsError(
try self.defaultClient.execute(request:
Request(url: url,
body: .stream(length: 10) { streamWriter in
let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self)
DispatchQueue(label: "content-length-test").async {
streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise)
}
return promise.futureResult
})).wait()) { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
}
// Quickly try another request and check that it works.
var response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
let info = try response.body!.readJSONDecodable(RequestInfo.self, length: response.body!.readableBytes)
XCTAssertEqual(info!.connectionNumber, 1)
XCTAssertEqual(info!.requestNumber, 1)
}

// currently gets stuck because of #250 the server just never replies
func testContentLengthTooShortFails() throws {
let url = self.defaultHTTPBinURLPrefix + "/post"
let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n"
XCTAssertThrowsError(
try self.defaultClient.execute(request:
Request(url: url,
body: .stream(length: 1) { streamWriter in
streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong)))
})).wait()) { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
}
// Quickly try another request and check that it works. If we by accident wrote some extra bytes into the
// stream (and reuse the connection) that could cause problems.
var response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
let info = try response.body!.readJSONDecodable(RequestInfo.self, length: response.body!.readableBytes)
XCTAssertEqual(info!.connectionNumber, 1)
XCTAssertEqual(info!.requestNumber, 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extension RequestValidationTests {
("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody),
("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames),
("testValidHeaderFieldNames", testValidHeaderFieldNames),
("testMultipleContentLengthOnNilStreamLength", testMultipleContentLengthOnNilStreamLength),
]
}
}
10 changes: 10 additions & 0 deletions Tests/AsyncHTTPClientTests/RequestValidationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,14 @@ class RequestValidationTests: XCTestCase {

XCTAssertNoThrow(try headers.validate(method: .GET, body: nil))
}

func testMultipleContentLengthOnNilStreamLength() {
var headers = HTTPHeaders([("Content-Length", "1"), ("Content-Length", "2")])
var buffer = ByteBufferAllocator().buffer(capacity: 10)
buffer.writeBytes([UInt8](repeating: 12, count: 10))
let body: HTTPClient.Body = .stream { writer in
writer.write(.byteBuffer(buffer))
}
XCTAssertThrowsError(try headers.validate(method: .PUT, body: body))
}
}