Skip to content

Commit 3d48ccc

Browse files
committed
Prevent TaskHandler state change after .endOrError
Motivation: Right now if task handler encounters an error, it changes state to `.endOrError`. We gate on that state to make sure that we do not process errors in the pipeline twice. Unfortunately, that state can be reset when we upload body or receive response parts. Modifications: Adds state validation before state is updated to a new value Adds a test Result: Fixes swift-server#297
1 parent 8c99c41 commit 3d48ccc

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+21-13
Original file line numberDiff line numberDiff line change
@@ -840,17 +840,27 @@ extension TaskHandler: ChannelDuplexHandler {
840840
self.writeBody(request: request, context: context)
841841
}.flatMap {
842842
context.eventLoop.assertInEventLoop()
843-
self.state = .bodySent
844-
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
845-
let error = HTTPClientError.bodyLengthMismatch
846-
self.errorCaught(context: context, error: error)
847-
return context.eventLoop.makeFailedFuture(error)
843+
switch self.state {
844+
case .endOrError:
845+
return context.eventLoop.makeSucceededFuture(())
846+
default:
847+
self.state = .bodySent
848+
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
849+
let error = HTTPClientError.bodyLengthMismatch
850+
self.errorCaught(context: context, error: error)
851+
return context.eventLoop.makeFailedFuture(error)
852+
}
853+
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
848854
}
849-
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
850855
}.map {
851856
context.eventLoop.assertInEventLoop()
852-
self.state = .sent
853-
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
857+
switch self.state {
858+
case .endOrError:
859+
break
860+
default:
861+
self.state = .sent
862+
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
863+
}
854864
}.flatMapErrorThrowing { error in
855865
context.eventLoop.assertInEventLoop()
856866
self.errorCaught(context: context, error: error)
@@ -926,7 +936,7 @@ extension TaskHandler: ChannelDuplexHandler {
926936
case .head(let head):
927937
switch self.state {
928938
case .endOrError:
929-
preconditionFailure("unexpected state on .head")
939+
break
930940
default:
931941
if !head.isKeepAlive {
932942
self.closing = true
@@ -945,10 +955,8 @@ extension TaskHandler: ChannelDuplexHandler {
945955
}
946956
case .body(let body):
947957
switch self.state {
948-
case .redirected:
958+
case .redirected, .endOrError:
949959
break
950-
case .endOrError:
951-
preconditionFailure("unexpected state on .body")
952960
default:
953961
self.state = .body
954962
self.mayRead = false
@@ -960,7 +968,7 @@ extension TaskHandler: ChannelDuplexHandler {
960968
case .end:
961969
switch self.state {
962970
case .endOrError:
963-
preconditionFailure("unexpected state on .end")
971+
break
964972
case .redirected(let head, let redirectURL):
965973
self.state = .endOrError
966974
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+44
Original file line numberDiff line numberDiff line change
@@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase {
11191119

11201120
XCTAssertEqual(delegate.count, 1)
11211121
}
1122+
1123+
func testTaskHandlerStateChangeAfterError() throws {
1124+
let channel = EmbeddedChannel()
1125+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
1126+
1127+
let handler = TaskHandler(task: task,
1128+
kind: .host,
1129+
delegate: TestHTTPDelegate(),
1130+
redirectHandler: nil,
1131+
ignoreUncleanSSLShutdown: false,
1132+
logger: HTTPClient.loggingDisabled)
1133+
1134+
try channel.pipeline.addHandler(handler).wait()
1135+
1136+
var request = try Request(url: "http://localhost:8080/get")
1137+
request.headers.add(name: "X-Test-Header", value: "X-Test-Value")
1138+
request.body = .stream(length: 4) { writer in
1139+
writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map {
1140+
handler.state = .endOrError
1141+
}
1142+
}
1143+
1144+
XCTAssertNoThrow(try channel.writeOutbound(request))
1145+
1146+
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok)))
1147+
XCTAssertTrue(handler.state.isEndOrError)
1148+
1149+
try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234")))
1150+
XCTAssertTrue(handler.state.isEndOrError)
1151+
1152+
try channel.writeInbound(HTTPClientResponsePart.end(nil))
1153+
XCTAssertTrue(handler.state.isEndOrError)
1154+
}
1155+
}
1156+
1157+
extension TaskHandler.State {
1158+
var isEndOrError: Bool {
1159+
switch self {
1160+
case .endOrError:
1161+
return true
1162+
default:
1163+
return false
1164+
}
1165+
}
11221166
}

0 commit comments

Comments
 (0)