From 8c99c41b8d8c0a57d80803d24c9fe31d367fcfad Mon Sep 17 00:00:00 2001
From: Artem Redkin <aredkin@apple.com>
Date: Fri, 21 Aug 2020 14:56:37 +0100
Subject: [PATCH 1/2] fail if we get part when state is endOrError

---
 Sources/AsyncHTTPClient/HTTPHandler.swift | 35 ++++++++++++++---------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift
index 12e6a4fc4..084a66b0b 100644
--- a/Sources/AsyncHTTPClient/HTTPHandler.swift
+++ b/Sources/AsyncHTTPClient/HTTPHandler.swift
@@ -839,8 +839,8 @@ extension TaskHandler: ChannelDuplexHandler {
         }.flatMap {
             self.writeBody(request: request, context: context)
         }.flatMap {
-            self.state = .bodySent
             context.eventLoop.assertInEventLoop()
+            self.state = .bodySent
             if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
                 let error = HTTPClientError.bodyLengthMismatch
                 self.errorCaught(context: context, error: error)
@@ -924,24 +924,31 @@ extension TaskHandler: ChannelDuplexHandler {
         let response = self.unwrapInboundIn(data)
         switch response {
         case .head(let head):
-            if !head.isKeepAlive {
-                self.closing = true
-            }
+            switch self.state {
+            case .endOrError:
+                preconditionFailure("unexpected state on .head")
+            default:
+                if !head.isKeepAlive {
+                    self.closing = true
+                }
 
-            if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
-                self.state = .redirected(head, redirectURL)
-            } else {
-                self.state = .head
-                self.mayRead = false
-                self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
-                    .whenComplete { result in
-                        self.handleBackpressureResult(context: context, result: result)
-                    }
+                if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
+                    self.state = .redirected(head, redirectURL)
+                } else {
+                    self.state = .head
+                    self.mayRead = false
+                    self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
+                        .whenComplete { result in
+                            self.handleBackpressureResult(context: context, result: result)
+                        }
+                }
             }
         case .body(let body):
             switch self.state {
             case .redirected:
                 break
+            case .endOrError:
+                preconditionFailure("unexpected state on .body")
             default:
                 self.state = .body
                 self.mayRead = false
@@ -952,6 +959,8 @@ extension TaskHandler: ChannelDuplexHandler {
             }
         case .end:
             switch self.state {
+            case .endOrError:
+                preconditionFailure("unexpected state on .end")
             case .redirected(let head, let redirectURL):
                 self.state = .endOrError
                 self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {

From c7b67e033d72467045a887ec1c9a14ef985c49ee Mon Sep 17 00:00:00 2001
From: Artem Redkin <aredkin@apple.com>
Date: Fri, 21 Aug 2020 15:39:02 +0100
Subject: [PATCH 2/2] Prevent TaskHandler state change after `.endOrError`

Motivation:
Right now if task handler encounters an error, it changes state to
`.endOrError`. We gate on that state to make sure that we do not
process errors in the pipeline twice. Unfortunately, that state
can be reset when we upload body or receive response parts.

Modifications:
Adds state validation before state is updated to a new value
Adds a test

Result:
Fixes #297
---
 Sources/AsyncHTTPClient/HTTPHandler.swift     | 48 ++++++++++---------
 .../HTTPClientInternalTests+XCTest.swift      |  1 +
 .../HTTPClientInternalTests.swift             | 44 +++++++++++++++++
 3 files changed, 71 insertions(+), 22 deletions(-)

diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift
index 084a66b0b..361f61159 100644
--- a/Sources/AsyncHTTPClient/HTTPHandler.swift
+++ b/Sources/AsyncHTTPClient/HTTPHandler.swift
@@ -840,15 +840,22 @@ extension TaskHandler: ChannelDuplexHandler {
             self.writeBody(request: request, context: context)
         }.flatMap {
             context.eventLoop.assertInEventLoop()
+            if case .endOrError = self.state {
+                return context.eventLoop.makeSucceededFuture(())
+            }
+
             self.state = .bodySent
             if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
                 let error = HTTPClientError.bodyLengthMismatch
-                self.errorCaught(context: context, error: error)
                 return context.eventLoop.makeFailedFuture(error)
             }
             return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
         }.map {
             context.eventLoop.assertInEventLoop()
+            if case .endOrError = self.state {
+                return
+            }
+
             self.state = .sent
             self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
         }.flatMapErrorThrowing { error in
@@ -924,31 +931,28 @@ extension TaskHandler: ChannelDuplexHandler {
         let response = self.unwrapInboundIn(data)
         switch response {
         case .head(let head):
-            switch self.state {
-            case .endOrError:
-                preconditionFailure("unexpected state on .head")
-            default:
-                if !head.isKeepAlive {
-                    self.closing = true
-                }
+            if case .endOrError = self.state {
+                return
+            }
 
-                if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
-                    self.state = .redirected(head, redirectURL)
-                } else {
-                    self.state = .head
-                    self.mayRead = false
-                    self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
-                        .whenComplete { result in
-                            self.handleBackpressureResult(context: context, result: result)
-                        }
-                }
+            if !head.isKeepAlive {
+                self.closing = true
+            }
+
+            if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
+                self.state = .redirected(head, redirectURL)
+            } else {
+                self.state = .head
+                self.mayRead = false
+                self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
+                    .whenComplete { result in
+                        self.handleBackpressureResult(context: context, result: result)
+                    }
             }
         case .body(let body):
             switch self.state {
-            case .redirected:
+            case .redirected, .endOrError:
                 break
-            case .endOrError:
-                preconditionFailure("unexpected state on .body")
             default:
                 self.state = .body
                 self.mayRead = false
@@ -960,7 +964,7 @@ extension TaskHandler: ChannelDuplexHandler {
         case .end:
             switch self.state {
             case .endOrError:
-                preconditionFailure("unexpected state on .end")
+                break
             case .redirected(let head, let redirectURL):
                 self.state = .endOrError
                 self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift
index 648eb8078..839a68460 100644
--- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift
+++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift
@@ -47,6 +47,7 @@ extension HTTPClientInternalTests {
             ("testInternalRequestURI", testInternalRequestURI),
             ("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification),
             ("testHandlerDoubleError", testHandlerDoubleError),
+            ("testTaskHandlerStateChangeAfterError", testTaskHandlerStateChangeAfterError),
         ]
     }
 }
diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
index 706a3bbd7..803824a0c 100644
--- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
+++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
@@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase {
 
         XCTAssertEqual(delegate.count, 1)
     }
+
+    func testTaskHandlerStateChangeAfterError() throws {
+        let channel = EmbeddedChannel()
+        let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)
+
+        let handler = TaskHandler(task: task,
+                                  kind: .host,
+                                  delegate: TestHTTPDelegate(),
+                                  redirectHandler: nil,
+                                  ignoreUncleanSSLShutdown: false,
+                                  logger: HTTPClient.loggingDisabled)
+
+        try channel.pipeline.addHandler(handler).wait()
+
+        var request = try Request(url: "http://localhost:8080/get")
+        request.headers.add(name: "X-Test-Header", value: "X-Test-Value")
+        request.body = .stream(length: 4) { writer in
+            writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map {
+                handler.state = .endOrError
+            }
+        }
+
+        XCTAssertNoThrow(try channel.writeOutbound(request))
+
+        try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok)))
+        XCTAssertTrue(handler.state.isEndOrError)
+
+        try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234")))
+        XCTAssertTrue(handler.state.isEndOrError)
+
+        try channel.writeInbound(HTTPClientResponsePart.end(nil))
+        XCTAssertTrue(handler.state.isEndOrError)
+    }
+}
+
+extension TaskHandler.State {
+    var isEndOrError: Bool {
+        switch self {
+        case .endOrError:
+            return true
+        default:
+            return false
+        }
+    }
 }