diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index 538424538..ad49332c0 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -356,6 +356,7 @@ extension Transaction { // response body stream. let body = TransactionBody.makeSequence( backPressureStrategy: .init(lowWatermark: 1, highWatermark: 1), + finishOnDeinit: true, delegate: AnyAsyncSequenceProducerDelegate(delegate) ) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 63cb70b99..ba3a09cc9 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -42,9 +42,17 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) } + + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.logger = self.backgroundLogger self.idleReadTimeoutStateMachine = nil + self.idleWriteTimeoutStateMachine = nil } } } @@ -57,6 +65,14 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. private var currentIdleReadTimeoutTimerID: Int = 0 + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + + /// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions. + /// We therefore give each timer an ID and increase the ID every time we reset or cancel it. + /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. + private var currentIdleWriteTimeoutTimerID: Int = 0 + private let backgroundLogger: Logger private var logger: Logger private let eventLoop: EventLoop @@ -106,6 +122,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { "ahc-channel-writable": "\(context.channel.isWritable)", ]) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) context.fireChannelWritabilityChanged() @@ -150,6 +170,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.request = req self.logger.debug("Request was scheduled on connection") + + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + req.willExecuteRequest(self) let action = self.state.runNewRequest( @@ -196,8 +221,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { request.resumeRequestBodyStream() } if startIdleTimer { - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } } case .sendBodyPart(let part, let writePromise): @@ -206,8 +235,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendRequestEnd(let writePromise): context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .pauseRequestBodyStream: @@ -380,6 +413,40 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } } + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { + oldTimer.cancel() + } + + self.currentIdleWriteTimeoutTimerID &+= 1 + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + self.currentIdleWriteTimeoutTimerID &+= 1 + oldTimer.cancel() + } + case .none: + break + } + } + // MARK: Private HTTPRequestExecutor private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { @@ -393,6 +460,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { return } + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } @@ -428,6 +499,10 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.logger.trace("Request was cancelled") + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled(closeConnection: true) self.run(action, context: context) } @@ -540,3 +615,87 @@ struct IdleReadStateMachine { } } } + +struct IdleWriteStateMachine { + enum Action { + case startIdleWriteTimeoutTimer(TimeAmount) + case resetIdleWriteTimeoutTimer(TimeAmount) + case clearIdleWriteTimeoutTimer + case none + } + + enum State { + case waitingForRequestEnd + case waitingForWritabilityEnabled + case requestEndSent + } + + private var state: State + private let timeAmount: TimeAmount + + init(timeAmount: TimeAmount, isWritabilityEnabled: Bool) { + self.timeAmount = timeAmount + if isWritabilityEnabled { + self.state = .waitingForRequestEnd + } else { + self.state = .waitingForWritabilityEnabled + } + } + + mutating func cancelRequest() -> Action { + switch self.state { + case .waitingForRequestEnd, .waitingForWritabilityEnabled: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .requestEndSent: + return .none + } + } + + mutating func write() -> Action { + switch self.state { + case .waitingForRequestEnd: + return .resetIdleWriteTimeoutTimer(self.timeAmount) + case .waitingForWritabilityEnabled: + return .none + case .requestEndSent: + preconditionFailure("If the request end has been sent, we can't write more data.") + } + } + + mutating func requestEndSent() -> Action { + switch self.state { + case .waitingForRequestEnd: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + preconditionFailure("If the channel is not writable, we can't have sent the request end.") + case .requestEndSent: + return .none + } + } + + mutating func channelWritabilityChanged(context: ChannelHandlerContext) -> Action { + if context.channel.isWritable { + switch self.state { + case .waitingForRequestEnd: + preconditionFailure("If waiting for more data, the channel was already writable.") + case .waitingForWritabilityEnabled: + self.state = .waitingForRequestEnd + return .startIdleWriteTimeoutTimer(self.timeAmount) + case .requestEndSent: + return .none + } + } else { + switch self.state { + case .waitingForRequestEnd: + self.state = .waitingForWritabilityEnabled + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + preconditionFailure("If the channel was writable before, then we should have been waiting for more data.") + case .requestEndSent: + return .none + } + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index eb4182593..ed4594183 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -355,6 +355,18 @@ struct HTTP1ConnectionStateMachine { } } + mutating func idleWriteTimeoutTriggered() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + preconditionFailure("Invalid state: \(self.state)") + } + + return self.avoidingStateMachineCoW { state -> Action in + let action = requestStateMachine.idleWriteTimeoutTriggered() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } + mutating func headSent() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { return .wait diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 0e8e819e8..4c69bc5dd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -35,8 +35,16 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var request: HTTPExecutableRequest? { didSet { - if let newRequest = self.request, let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { - self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + if let newRequest = self.request { + if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.idleReadTimeoutStateMachine = nil } @@ -46,6 +54,9 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var idleReadTimeoutStateMachine: IdleReadStateMachine? private var idleReadTimeoutTimer: Scheduled? + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + init(eventLoop: EventLoop) { self.eventLoop = eventLoop } @@ -77,6 +88,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) } @@ -110,6 +125,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // a single request. self.request = request + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + request.willExecuteRequest(self) let action = self.state.startRequest( @@ -153,8 +172,12 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { request.resumeRequestBodyStream() } if startIdleTimer { - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } } case .pauseRequestBodyStream: @@ -168,8 +191,12 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .sendRequestEnd(let writePromise): context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .read: @@ -295,6 +322,36 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.idleWriteTimeoutTimer != nil else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { + oldTimer.cancel() + } + + self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + guard self.idleWriteTimeoutTimer != nil else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + oldTimer.cancel() + } + case .none: + break + } + } + // MARK: Private HTTPRequestExecutor private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { @@ -308,6 +365,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { return } + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } @@ -338,6 +399,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { return } + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled() self.run(action, context: context) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index 4835feac3..b575ae094 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -704,6 +704,28 @@ struct HTTPRequestStateMachine { } } + mutating func idleWriteTimeoutTriggered() -> Action { + switch self.state { + case .initialized, + .waitForChannelToBecomeWritable: + preconditionFailure("We only schedule idle write timeouts while the request is being sent. Invalid state: \(self.state)") + + case .running(.streaming, _): + let error = HTTPClientError.writeTimeout + self.state = .failed(error) + return .failRequest(error, .close(nil)) + + case .running(.endSent, _): + preconditionFailure("Invalid state. This state should be: .finished") + + case .finished, .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { let length = metadata.body.expectedLength if length == 0 { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift index c46f1289c..903f962e5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift @@ -17,11 +17,18 @@ import NIOCore struct RequestOptions { /// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel. var idleReadTimeout: TimeAmount? - + /// The maximal `TimeAmount` that is allowed to pass between `write`s into the Channel. + var idleWriteTimeout: TimeAmount? + /// DNS overrides. var dnsOverride: [String: String] - init(idleReadTimeout: TimeAmount?, dnsOverride: [String: String]) { + init( + idleReadTimeout: TimeAmount?, + idleWriteTimeout: TimeAmount?, + dnsOverride: [String: String] + ) { self.idleReadTimeout = idleReadTimeout + self.idleWriteTimeout = idleWriteTimeout self.dnsOverride = dnsOverride } } @@ -30,6 +37,7 @@ extension RequestOptions { static func fromClientConfiguration(_ configuration: HTTPClient.Configuration) -> Self { RequestOptions( idleReadTimeout: configuration.timeout.read, + idleWriteTimeout: configuration.timeout.write, dnsOverride: configuration.dnsOverride ) } diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index db8ed3d97..6fc94de5c 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -79,7 +79,7 @@ public class HTTPClient { private var state: State private let stateLock = NIOLock() - internal static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) + static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. /// @@ -184,7 +184,7 @@ public class HTTPClient { /// throw the appropriate error if needed. For instance, if its internal connection pool has any non-released connections, /// this indicate shutdown was called too early before tasks were completed or explicitly canceled. /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. - internal func syncShutdown(requiresCleanClose: Bool) throws { + func syncShutdown(requiresCleanClose: Bool) throws { if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { preconditionFailure(""" BUG DETECTED: syncShutdown() must not be called when on an EventLoop. @@ -927,8 +927,10 @@ extension HTTPClient.Configuration { public var connect: TimeAmount? /// Specifies read timeout. public var read: TimeAmount? + /// Specifies the maximum amount of time without bytes being written by the client before closing the connection. + public var write: TimeAmount? - /// internal connection creation timeout. Defaults the connect timeout to always contain a value. + /// Internal connection creation timeout. Defaults the connect timeout to always contain a value. var connectionCreationTimeout: TimeAmount { self.connect ?? .seconds(10) } @@ -938,7 +940,25 @@ extension HTTPClient.Configuration { /// - parameters: /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. /// - read: `read` timeout. - public init(connect: TimeAmount? = nil, read: TimeAmount? = nil) { + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil + ) { + self.connect = connect + self.read = read + } + + /// Create timeout. + /// + /// - parameters: + /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. + /// - read: `read` timeout. + /// - write: `write` timeout. + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil, + write: TimeAmount + ) { self.connect = connect self.read = read } @@ -1007,7 +1027,7 @@ extension HTTPClient.Configuration { } public struct HTTPVersion: Sendable, Hashable { - internal enum Configuration { + enum Configuration { case http1Only case automatic } @@ -1018,7 +1038,7 @@ extension HTTPClient.Configuration { /// HTTP/2 is used if we connect to a server with HTTPS and the server supports HTTP/2, otherwise we use HTTP/1 public static let automatic: Self = .init(configuration: .automatic) - internal var configuration: Configuration + var configuration: Configuration } } @@ -1032,6 +1052,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case emptyScheme case unsupportedScheme(String) case readTimeout + case writeTimeout case remoteConnectionClosed case cancelled case identityCodingIncorrectlyPresent @@ -1090,6 +1111,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { return "Unsupported scheme" case .readTimeout: return "Read timeout" + case .writeTimeout: + return "Write timeout" case .remoteConnectionClosed: return "Remote connection closed" case .cancelled: @@ -1155,8 +1178,10 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let emptyScheme = HTTPClientError(code: .emptyScheme) /// Provided URL scheme is not supported, supported schemes are: `http` and `https` public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { return HTTPClientError(code: .unsupportedScheme(scheme)) } - /// Request timed out. + /// Request timed out while waiting for response. public static let readTimeout = HTTPClientError(code: .readTimeout) + /// Request timed out. + public static let writeTimeout = HTTPClientError(code: .writeTimeout) /// Remote connection was closed unexpectedly. public static let remoteConnectionClosed = HTTPClientError(code: .remoteConnectionClosed) /// Request was cancelled. diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 2aa010491..f6a2840d9 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -337,6 +337,135 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + })) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + })) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCancelledIfRequestIsCancelled() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 1) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 2) { writer in + return testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + })) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle write timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() { let embedded = EmbeddedChannel() var maybeTestUtils: HTTP1TestTools? @@ -576,7 +705,7 @@ class TestBackpressureWriter { self.finishPromise = eventLoop.makePromise(of: Void.self) } - func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture { func recursive() { XCTAssert(self.eventLoop.inEventLoop) XCTAssert(self.channelIsWritable) @@ -591,7 +720,15 @@ class TestBackpressureWriter { case .success: recursive() case .failure(let error): - XCTFail("Unexpected error: \(error)") + let isExpectedError = expectedErrors.contains { httpError in + if let castError = error as? HTTPClientError { + return castError == httpError + } + return false + } + if !isExpectedError { + XCTFail("Unexpected error: \(error)") + } } } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 2b68fceb3..545ba1e3c 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -286,6 +286,139 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + })) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + })) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCanceledIfRequestIsCanceled() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 2) { writer in + return testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + })) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle read timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + func testWriteHTTPHeadFails() { struct WriteError: Error, Equatable {} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index e2a959589..610e429f5 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -979,10 +979,12 @@ final class MockTaskQueuer: HTTPRequestScheduler { extension RequestOptions { static func forTests( idleReadTimeout: TimeAmount? = nil, + idleWriteTimeout: TimeAmount? = nil, dnsOverride: [String: String] = [:] ) -> Self { RequestOptions( idleReadTimeout: idleReadTimeout, + idleWriteTimeout: idleWriteTimeout, dnsOverride: dnsOverride ) }