From 8c99c41b8d8c0a57d80803d24c9fe31d367fcfad Mon Sep 17 00:00:00 2001 From: Artem Redkin <aredkin@apple.com> Date: Fri, 21 Aug 2020 14:56:37 +0100 Subject: [PATCH 1/2] fail if we get part when state is endOrError --- Sources/AsyncHTTPClient/HTTPHandler.swift | 35 ++++++++++++++--------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 12e6a4fc4..084a66b0b 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -839,8 +839,8 @@ extension TaskHandler: ChannelDuplexHandler { }.flatMap { self.writeBody(request: request, context: context) }.flatMap { - self.state = .bodySent context.eventLoop.assertInEventLoop() + self.state = .bodySent if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { let error = HTTPClientError.bodyLengthMismatch self.errorCaught(context: context, error: error) @@ -924,24 +924,31 @@ extension TaskHandler: ChannelDuplexHandler { let response = self.unwrapInboundIn(data) switch response { case .head(let head): - if !head.isKeepAlive { - self.closing = true - } + switch self.state { + case .endOrError: + preconditionFailure("unexpected state on .head") + default: + if !head.isKeepAlive { + self.closing = true + } - if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { - self.state = .redirected(head, redirectURL) - } else { - self.state = .head - self.mayRead = false - self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) - .whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } + if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { + self.state = .redirected(head, redirectURL) + } else { + self.state = .head + self.mayRead = false + self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } + } } case .body(let body): switch self.state { case .redirected: break + case .endOrError: + preconditionFailure("unexpected state on .body") default: self.state = .body self.mayRead = false @@ -952,6 +959,8 @@ extension TaskHandler: ChannelDuplexHandler { } case .end: switch self.state { + case .endOrError: + preconditionFailure("unexpected state on .end") case .redirected(let head, let redirectURL): self.state = .endOrError self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess { From c7b67e033d72467045a887ec1c9a14ef985c49ee Mon Sep 17 00:00:00 2001 From: Artem Redkin <aredkin@apple.com> Date: Fri, 21 Aug 2020 15:39:02 +0100 Subject: [PATCH 2/2] Prevent TaskHandler state change after `.endOrError` Motivation: Right now if task handler encounters an error, it changes state to `.endOrError`. We gate on that state to make sure that we do not process errors in the pipeline twice. Unfortunately, that state can be reset when we upload body or receive response parts. Modifications: Adds state validation before state is updated to a new value Adds a test Result: Fixes #297 --- Sources/AsyncHTTPClient/HTTPHandler.swift | 48 ++++++++++--------- .../HTTPClientInternalTests+XCTest.swift | 1 + .../HTTPClientInternalTests.swift | 44 +++++++++++++++++ 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 084a66b0b..361f61159 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -840,15 +840,22 @@ extension TaskHandler: ChannelDuplexHandler { self.writeBody(request: request, context: context) }.flatMap { context.eventLoop.assertInEventLoop() + if case .endOrError = self.state { + return context.eventLoop.makeSucceededFuture(()) + } + self.state = .bodySent if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { let error = HTTPClientError.bodyLengthMismatch - self.errorCaught(context: context, error: error) return context.eventLoop.makeFailedFuture(error) } return context.writeAndFlush(self.wrapOutboundOut(.end(nil))) }.map { context.eventLoop.assertInEventLoop() + if case .endOrError = self.state { + return + } + self.state = .sent self.callOutToDelegateFireAndForget(self.delegate.didSendRequest) }.flatMapErrorThrowing { error in @@ -924,31 +931,28 @@ extension TaskHandler: ChannelDuplexHandler { let response = self.unwrapInboundIn(data) switch response { case .head(let head): - switch self.state { - case .endOrError: - preconditionFailure("unexpected state on .head") - default: - if !head.isKeepAlive { - self.closing = true - } + if case .endOrError = self.state { + return + } - if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { - self.state = .redirected(head, redirectURL) - } else { - self.state = .head - self.mayRead = false - self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) - .whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } - } + if !head.isKeepAlive { + self.closing = true + } + + if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { + self.state = .redirected(head, redirectURL) + } else { + self.state = .head + self.mayRead = false + self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } } case .body(let body): switch self.state { - case .redirected: + case .redirected, .endOrError: break - case .endOrError: - preconditionFailure("unexpected state on .body") default: self.state = .body self.mayRead = false @@ -960,7 +964,7 @@ extension TaskHandler: ChannelDuplexHandler { case .end: switch self.state { case .endOrError: - preconditionFailure("unexpected state on .end") + break case .redirected(let head, let redirectURL): self.state = .endOrError self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 648eb8078..839a68460 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -47,6 +47,7 @@ extension HTTPClientInternalTests { ("testInternalRequestURI", testInternalRequestURI), ("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification), ("testHandlerDoubleError", testHandlerDoubleError), + ("testTaskHandlerStateChangeAfterError", testTaskHandlerStateChangeAfterError), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 706a3bbd7..803824a0c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertEqual(delegate.count, 1) } + + func testTaskHandlerStateChangeAfterError() throws { + let channel = EmbeddedChannel() + let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled) + + let handler = TaskHandler(task: task, + kind: .host, + delegate: TestHTTPDelegate(), + redirectHandler: nil, + ignoreUncleanSSLShutdown: false, + logger: HTTPClient.loggingDisabled) + + try channel.pipeline.addHandler(handler).wait() + + var request = try Request(url: "http://localhost:8080/get") + request.headers.add(name: "X-Test-Header", value: "X-Test-Value") + request.body = .stream(length: 4) { writer in + writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map { + handler.state = .endOrError + } + } + + XCTAssertNoThrow(try channel.writeOutbound(request)) + + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok))) + XCTAssertTrue(handler.state.isEndOrError) + + try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234"))) + XCTAssertTrue(handler.state.isEndOrError) + + try channel.writeInbound(HTTPClientResponsePart.end(nil)) + XCTAssertTrue(handler.state.isEndOrError) + } +} + +extension TaskHandler.State { + var isEndOrError: Bool { + switch self { + case .endOrError: + return true + default: + return false + } + } }