Skip to content

Fix bodyLengthMissmatch error handling #490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,42 @@ struct HTTPRequestStateMachine {
}

mutating func errorHappened(_ error: Error) -> Action {
if let error = error as? NIOSSLError,
error == .uncleanShutdown,
let action = self.handleNIOSSLUncleanShutdownError() {
return action
}
switch self.state {
case .initialized:
preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable")
case .waitForChannelToBecomeWritable:
// the request failed, before it was sent onto the wire.
self.state = .failed(error)
return .failRequest(error, .none)
case .running:
self.state = .failed(error)
return .failRequest(error, .close)

case .finished, .failed:
// ignore error
return .wait

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}

private mutating func handleNIOSSLUncleanShutdownError() -> Action? {
switch self.state {
case .running(.streaming, .waitingForHead),
.running(.endSent, .waitingForHead) where error as? NIOSSLError == .uncleanShutdown:
.running(.endSent, .waitingForHead):
// if we received a NIOSSL.uncleanShutdown before we got an answer we should handle
// this like a normal connection close. We will receive a call to channelInactive after
// this error.
return .wait

case .running(.streaming, .receivingBody(let responseHead, _)),
.running(.endSent, .receivingBody(let responseHead, _)) where error as? NIOSSLError == .uncleanShutdown:
.running(.endSent, .receivingBody(let responseHead, _)):
// This code is only reachable for request and responses, which we expect to have a body.
// We depend on logic from the HTTPResponseDecoder here. The decoder will emit an
// HTTPResponsePart.end right after the HTTPResponsePart.head, for every request with a
Expand All @@ -226,19 +245,11 @@ struct HTTPRequestStateMachine {
// If the response is EOF terminated, we need to rely on a clean tls shutdown to be sure
// we have received all necessary bytes. For this reason we forward the uncleanShutdown
// error to the user.
self.state = .failed(error)
return .failRequest(error, .close)

case .running:
self.state = .failed(error)
return .failRequest(error, .close)

case .finished, .failed:
// ignore error
return .wait
self.state = .failed(NIOSSLError.uncleanShutdown)
return .failRequest(NIOSSLError.uncleanShutdown, .close)

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
case .waitForChannelToBecomeWritable, .running, .finished, .failed, .initialized, .modifying:
return nil
}
}

Expand Down Expand Up @@ -270,7 +281,7 @@ struct HTTPRequestStateMachine {

if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected {
let error = HTTPClientError.bodyLengthMismatch

self.state = .failed(error)
return .failRequest(error, .close)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ extension HTTPRequestStateMachineTests {
("testCanReadHTTP1_0ResponseWithBody", testCanReadHTTP1_0ResponseWithBody),
("testFailHTTP1_0RequestThatIsStillUploading", testFailHTTP1_0RequestThatIsStillUploading),
("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown),
("testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState", testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState),
("testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState", testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState),
("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead),
Expand Down
85 changes: 59 additions & 26 deletions Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,10 @@ class HTTPRequestStateMachineTests: XCTestCase {
let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3]))
XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0))

let failAction = state.requestStreamPartReceived(part1)
guard case .failRequest(let error, .close) = failAction else {
return XCTFail("Unexpected action: \(failAction)")
}
state.requestStreamPartReceived(part1).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close)

XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch)
// if another error happens the new one is ignored
XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait)
}

func testPOSTContentLengthIsTooShort() {
Expand All @@ -92,12 +90,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3]))
XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0))

let failAction = state.requestStreamFinished()
guard case .failRequest(let error, .close) = failAction else {
return XCTFail("Unexpected action: \(failAction)")
}

XCTAssertEqual(error as? HTTPClientError, .bodyLengthMismatch)
state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close)
}

func testRequestBodyStreamIsCancelledIfServerRespondsWith301() {
Expand Down Expand Up @@ -206,7 +199,7 @@ class HTTPRequestStateMachineTests: XCTestCase {

let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7))
XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1))
XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close))
state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close)
XCTAssertEqual(state.channelInactive(), .wait)
}

Expand All @@ -224,7 +217,7 @@ class HTTPRequestStateMachineTests: XCTestCase {

let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7))
XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1))
XCTAssertEqual(state.requestStreamFinished(), .failRequest(HTTPClientError.bodyLengthMismatch, .close))
state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close)
XCTAssertEqual(state.channelRead(.end(nil)), .wait)
}

Expand All @@ -249,7 +242,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none)
}

func testResponseReadingWithBackpressure() {
Expand Down Expand Up @@ -339,15 +332,15 @@ class HTTPRequestStateMachineTests: XCTestCase {

func testCancellingARequestInStateInitializedKeepsTheConnectionAlive() {
var state = HTTPRequestStateMachine(isChannelWritable: false)
XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none))
state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .none)
}

func testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive() {
var state = HTTPRequestStateMachine(isChannelWritable: false)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait)
XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none))
state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .none)
}

func testConnectionBecomesWritableBeforeFirstRequest() {
Expand All @@ -373,15 +366,15 @@ class HTTPRequestStateMachineTests: XCTestCase {
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))
XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .close))
state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close)
}

func testRemoteSuddenlyClosesTheConnection() {
var state = HTTPRequestStateMachine(isChannelWritable: true)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")]))
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true))
XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.remoteConnectionClosed, .close))
state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close)
XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait)
}

Expand All @@ -395,7 +388,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
let part0 = ByteBuffer(bytes: 0...3)
XCTAssertEqual(state.channelRead(.body(part0)), .wait)
XCTAssertEqual(state.idleReadTimeoutTriggered(), .failRequest(HTTPClientError.readTimeout, .close))
state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close)
XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 4...7))), .wait)
XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 8...11))), .wait)
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
Expand Down Expand Up @@ -448,7 +441,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

XCTAssertEqual(state.errorHappened(HTTPParserError.invalidChunkSize), .failRequest(HTTPParserError.invalidChunkSize, .close))
state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close)
XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored")
}

Expand Down Expand Up @@ -502,7 +495,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
XCTAssertEqual(state.read(), .read)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelRead(.end(nil)), .failRequest(HTTPClientError.remoteConnectionClosed, .close))
state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close)
XCTAssertEqual(state.channelInactive(), .wait)
}

Expand All @@ -517,11 +510,32 @@ class HTTPRequestStateMachineTests: XCTestCase {
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .failRequest(NIOSSLError.uncleanShutdown, .close))
state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close)
XCTAssertEqual(state.channelRead(.end(nil)), .wait)
XCTAssertEqual(state.channelInactive(), .wait)
}

func testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState() {
var state = HTTPRequestStateMachine(isChannelWritable: true)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait)
state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none)
}

func testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState() {
struct ArbitraryError: Error, Equatable {}
var state = HTTPRequestStateMachine(isChannelWritable: true)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0))
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close)
XCTAssertEqual(state.channelInactive(), .wait)
}

func testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt() {
var state = HTTPRequestStateMachine(isChannelWritable: true)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
Expand All @@ -536,7 +550,7 @@ class HTTPRequestStateMachineTests: XCTestCase {
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body]))
XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait)
XCTAssertEqual(state.errorHappened(HTTPParserError.invalidEOFState), .failRequest(HTTPParserError.invalidEOFState, .close))
state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close)
XCTAssertEqual(state.channelInactive(), .wait)
}

Expand All @@ -558,7 +572,7 @@ class HTTPRequestStateMachineTests: XCTestCase {

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none)
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead() {
Expand All @@ -579,7 +593,7 @@ class HTTPRequestStateMachineTests: XCTestCase {

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none)
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand() {
Expand All @@ -599,7 +613,7 @@ class HTTPRequestStateMachineTests: XCTestCase {

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none)
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes() {
Expand Down Expand Up @@ -676,3 +690,22 @@ extension HTTPRequestStateMachine.Action: Equatable {
}
}
}

extension HTTPRequestStateMachine.Action {
fileprivate func assertFailRequest<Error>(
_ expectedError: Error,
_ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalStreamAction,
file: StaticString = #file,
line: UInt = #line
) where Error: Swift.Error & Equatable {
guard case .failRequest(let actualError, let actualFinalStreamAction) = self else {
return XCTFail("expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", file: file, line: line)
}
if let actualError = actualError as? Error {
XCTAssertEqual(actualError, expectedError, file: file, line: line)
} else {
XCTFail("\(actualError) is not equal to \(expectedError)", file: file, line: line)
}
XCTAssertEqual(actualFinalStreamAction, expectedFinalStreamAction, file: file, line: line)
}
}