Skip to content

Commit 21c826f

Browse files
committed
Prevent TaskHandler state change after .endOrError
Motivation: Right now if task handler encounters an error, it changes state to `.endOrError`. We gate on that state to make sure that we do not process errors in the pipeline twice. Unfortunately, that state can be reset when we upload body or receive response parts. Modifications: Adds state validation before state is updated to a new value Adds a test Result: Fixes swift-server#297
1 parent 8c99c41 commit 21c826f

File tree

3 files changed

+56
-6
lines changed

3 files changed

+56
-6
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+11-6
Original file line numberDiff line numberDiff line change
@@ -840,15 +840,22 @@ extension TaskHandler: ChannelDuplexHandler {
840840
self.writeBody(request: request, context: context)
841841
}.flatMap {
842842
context.eventLoop.assertInEventLoop()
843+
if case .endOrError = self.state {
844+
return context.eventLoop.makeSucceededFuture(())
845+
}
846+
843847
self.state = .bodySent
844848
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
845849
let error = HTTPClientError.bodyLengthMismatch
846-
self.errorCaught(context: context, error: error)
847850
return context.eventLoop.makeFailedFuture(error)
848851
}
849852
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
850853
}.map {
851854
context.eventLoop.assertInEventLoop()
855+
if case .endOrError = self.state {
856+
return
857+
}
858+
852859
self.state = .sent
853860
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
854861
}.flatMapErrorThrowing { error in
@@ -926,7 +933,7 @@ extension TaskHandler: ChannelDuplexHandler {
926933
case .head(let head):
927934
switch self.state {
928935
case .endOrError:
929-
preconditionFailure("unexpected state on .head")
936+
break
930937
default:
931938
if !head.isKeepAlive {
932939
self.closing = true
@@ -945,10 +952,8 @@ extension TaskHandler: ChannelDuplexHandler {
945952
}
946953
case .body(let body):
947954
switch self.state {
948-
case .redirected:
955+
case .redirected, .endOrError:
949956
break
950-
case .endOrError:
951-
preconditionFailure("unexpected state on .body")
952957
default:
953958
self.state = .body
954959
self.mayRead = false
@@ -960,7 +965,7 @@ extension TaskHandler: ChannelDuplexHandler {
960965
case .end:
961966
switch self.state {
962967
case .endOrError:
963-
preconditionFailure("unexpected state on .end")
968+
break
964969
case .redirected(let head, let redirectURL):
965970
self.state = .endOrError
966971
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ extension HTTPClientInternalTests {
4747
("testInternalRequestURI", testInternalRequestURI),
4848
("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification),
4949
("testHandlerDoubleError", testHandlerDoubleError),
50+
("testTaskHandlerStateChangeAfterError", testTaskHandlerStateChangeAfterError),
5051
]
5152
}
5253
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+44
Original file line numberDiff line numberDiff line change
@@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase {
11191119

11201120
XCTAssertEqual(delegate.count, 1)
11211121
}
1122+
1123+
func testTaskHandlerStateChangeAfterError() throws {
1124+
let channel = EmbeddedChannel()
1125+
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
1126+
1127+
let handler = TaskHandler(task: task,
1128+
kind: .host,
1129+
delegate: TestHTTPDelegate(),
1130+
redirectHandler: nil,
1131+
ignoreUncleanSSLShutdown: false,
1132+
logger: HTTPClient.loggingDisabled)
1133+
1134+
try channel.pipeline.addHandler(handler).wait()
1135+
1136+
var request = try Request(url: "http://localhost:8080/get")
1137+
request.headers.add(name: "X-Test-Header", value: "X-Test-Value")
1138+
request.body = .stream(length: 4) { writer in
1139+
writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map {
1140+
handler.state = .endOrError
1141+
}
1142+
}
1143+
1144+
XCTAssertNoThrow(try channel.writeOutbound(request))
1145+
1146+
try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok)))
1147+
XCTAssertTrue(handler.state.isEndOrError)
1148+
1149+
try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234")))
1150+
XCTAssertTrue(handler.state.isEndOrError)
1151+
1152+
try channel.writeInbound(HTTPClientResponsePart.end(nil))
1153+
XCTAssertTrue(handler.state.isEndOrError)
1154+
}
1155+
}
1156+
1157+
extension TaskHandler.State {
1158+
var isEndOrError: Bool {
1159+
switch self {
1160+
case .endOrError:
1161+
return true
1162+
default:
1163+
return false
1164+
}
1165+
}
11221166
}

0 commit comments

Comments
 (0)