Skip to content

Refactor Request Validation #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutingRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,18 @@ protocol HTTPRequestExecutor {
protocol HTTPExecutingRequest: AnyObject {
/// The request's head.
///
/// Based on the content of the request head the task executor will call `startRequestBodyStream`
/// after `requestHeadSent` was called.
/// The HTTP request head, that shall be sent. The HTTPRequestExecutor **will not** run any validation
/// check on the request head. All necessary metadata about the request head the executor expects in
/// the ``requestFramingMetadata``.
var requestHead: HTTPRequestHead { get }

/// The request's framing metadata.
///
/// The request framing metadata that is derived from the ``requestHead``. Based on the content of the
/// request framing metadata the executor will call ``startRequestBodyStream`` after
/// ``requestHeadSent``.
var requestFramingMetadata: RequestFramingMetadata { get }

/// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel.
var idleReadTimeout: TimeAmount? { get }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//

struct RequestFramingMetadata {
enum Body {
struct RequestFramingMetadata: Equatable {
enum Body: Equatable {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make these Hashable: Equatable by itself is rarely sensible.

case none
case stream
case fixedSize(Int)
Expand Down
62 changes: 36 additions & 26 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,33 @@ extension HTTPClient {
public var port: Int {
return self.url.port ?? (self.useTLS ? 443 : 80)
}

func createRequestHead() throws -> (HTTPRequestHead, RequestFramingMetadata) {
var head = HTTPRequestHead(
version: .http1_1,
method: self.method,
uri: self.uri,
headers: self.headers
)

if !head.headers.contains(name: "host") {
let port = self.port
var host = self.host
if !(port == 80 && self.scheme == "http"), !(port == 443 && self.scheme == "https") {
host += ":\(port)"
}
head.headers.add(name: "host", value: host)
}

let metadata = try head.headers.validate(method: self.method, body: self.body)

// This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
// in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
assert(head.version == HTTPVersion(major: 1, minor: 1),
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this assert is meaningfully enforcing here, is it? You (rightly) kept the original, but this one seems superfluous.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I still don't think this assertion is useful. We literally set the HTTP version to a hardcoded constant just above.


return (head, metadata)
}
}

/// Represent HTTP response.
Expand Down Expand Up @@ -877,46 +904,29 @@ extension TaskHandler: ChannelDuplexHandler {
typealias OutboundOut = HTTPClientRequestPart

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
self.state = .sendingBodyWaitingResponseHead

let request = self.unwrapOutboundIn(data)
self.state = .sendingBodyWaitingResponseHead

var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
method: request.method,
uri: request.uri)
var headers = request.headers

if !request.headers.contains(name: "host") {
let port = request.port
var host = request.host
if !(port == 80 && request.scheme == "http"), !(port == 443 && request.scheme == "https") {
host += ":\(port)"
}
headers.add(name: "host", value: host)
}
let head: HTTPRequestHead
let metadata: RequestFramingMetadata

do {
try headers.validate(method: request.method, body: request.body)
(head, metadata) = try request.createRequestHead()
} catch {
self.errorCaught(context: context, error: error)
promise?.fail(error)
return
}

head.headers = headers

if head.headers[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") {
self.closing = true
}
// This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
// in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
assert(head.version == HTTPVersion(major: 1, minor: 1),
assert(head.version == .http1_1,
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")

let contentLengths = head.headers[canonicalForm: "content-length"]
assert(contentLengths.count <= 1)

self.expectedBodyLength = contentLengths.first.flatMap { Int($0) }
if case .fixedSize(let length) = metadata.body {
self.expectedBodyLength = length
}
self.closing = metadata.connectionClose

context.write(wrapOutboundOut(.head(head))).map {
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
Expand Down
12 changes: 5 additions & 7 deletions Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
let idleReadTimeout: TimeAmount?

let requestHead: HTTPRequestHead
let requestFramingMetadata: RequestFramingMetadata

let eventLoopPreference: HTTPClient.EventLoopPreference

Expand All @@ -50,7 +51,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
redirectHandler: RedirectHandler<Delegate.Response>?,
connectionDeadline: NIODeadline,
idleReadTimeout: TimeAmount?,
delegate: Delegate) {
delegate: Delegate) throws {
self.eventLoopPreference = eventLoopPreference
self.task = task
self.state = .init(redirectHandler: redirectHandler)
Expand All @@ -59,12 +60,9 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
self.idleReadTimeout = idleReadTimeout
self.delegate = delegate

self.requestHead = HTTPRequestHead(
version: .http1_1,
method: request.method,
uri: request.uri,
headers: request.headers
)
let (head, metadata) = try request.createRequestHead()
self.requestHead = head
self.requestFramingMetadata = metadata

// TODO: comment in once we switch to using the Request bag in AHC
// self.task.taskDelegate = self
Expand Down
18 changes: 14 additions & 4 deletions Sources/AsyncHTTPClient/RequestValidation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ import NIO
import NIOHTTP1

extension HTTPHeaders {
mutating func validate(method: HTTPMethod, body: HTTPClient.Body?) throws {
mutating func validate(method: HTTPMethod, body: HTTPClient.Body?) throws -> RequestFramingMetadata {
var metadata = RequestFramingMetadata(connectionClose: false, body: .none)

if self[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") {
metadata.connectionClose = true
}

// validate transfer encoding and content length (https://tools.ietf.org/html/rfc7230#section-3.3.1)
if self.contains(name: "Transfer-Encoding"), self.contains(name: "Content-Length") {
throw HTTPClientError.incompatibleHeaders
Expand All @@ -43,13 +49,13 @@ extension HTTPHeaders {
// A user agent SHOULD NOT send a Content-Length header field when the request
// message does not contain a payload body and the method semantics do not
// anticipate such a body.
return
return metadata
default:
// A user agent SHOULD send a Content-Length in a request message when
// no Transfer-Encoding is sent and the request method defines a meaning
// for an enclosed payload body.
self.add(name: "Content-Length", value: "0")
return
return metadata
}
}

Expand Down Expand Up @@ -85,14 +91,18 @@ extension HTTPHeaders {
// add headers if required
if let enc = transferEncoding {
self.add(name: "Transfer-Encoding", value: enc)
metadata.body = .stream
} else if let length = contentLength {
// A sender MUST NOT send a Content-Length header field in any message
// that contains a Transfer-Encoding header field.
self.add(name: "Content-Length", value: String(length))
metadata.body = .fixedSize(length)
}

return metadata
}

func validateFieldNames() throws {
private func validateFieldNames() throws {
let invalidFieldNames = self.compactMap { (name, _) -> String? in
let satisfy = name.utf8.allSatisfy { (char) -> Bool in
switch char {
Expand Down
39 changes: 26 additions & 13 deletions Tests/AsyncHTTPClientTests/RequestBagTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,20 @@ final class RequestBagTests: XCTestCase {
var maybeRequest: HTTPClient.Request?
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody))
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(

var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }

XCTAssert(bag.task.eventLoop === embeddedEventLoop)

let executor = MockRequestExecutor(pauseRequestBodyPartStreamAfterASingleWrite: true)
Expand Down Expand Up @@ -161,15 +164,17 @@ final class RequestBagTests: XCTestCase {
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(
var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }
XCTAssert(bag.task.eventLoop === embeddedEventLoop)

let executor = MockRequestExecutor()
Expand Down Expand Up @@ -202,15 +207,17 @@ final class RequestBagTests: XCTestCase {
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(
var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }
XCTAssert(bag.eventLoop === embeddedEventLoop)

let executor = MockRequestExecutor()
Expand All @@ -233,15 +240,17 @@ final class RequestBagTests: XCTestCase {
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(
var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }
XCTAssert(bag.eventLoop === embeddedEventLoop)

let executor = MockRequestExecutor()
Expand Down Expand Up @@ -273,15 +282,17 @@ final class RequestBagTests: XCTestCase {
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(
var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }

let queuer = MockTaskQueuer()
bag.requestWasQueued(queuer)
Expand Down Expand Up @@ -328,15 +339,17 @@ final class RequestBagTests: XCTestCase {
guard let request = maybeRequest else { return XCTFail("Expected to have a request") }

let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop)
let bag = RequestBag(
var maybeRequestBag: RequestBag<UploadCountingDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embeddedEventLoop),
task: .init(eventLoop: embeddedEventLoop, logger: logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
idleReadTimeout: nil,
delegate: delegate
)
))
guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") }

let executor = MockRequestExecutor()
bag.willExecuteRequest(executor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ extension RequestValidationTests {
("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody),
("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames),
("testValidHeaderFieldNames", testValidHeaderFieldNames),
("testMetadataDetectConnectionClose", testMetadataDetectConnectionClose),
("testMetadataDefaultIsConnectionCloseIsFalse", testMetadataDefaultIsConnectionCloseIsFalse),
("testNoHeadersNoBody", testNoHeadersNoBody),
("testNoHeadersHasBody", testNoHeadersHasBody),
("testContentLengthHeaderNoBody", testContentLengthHeaderNoBody),
Expand Down
Loading