Skip to content

Commit e226032

Browse files
committed
Always access Task’s connection and cancelled property through a lock
1 parent 8e4d519 commit e226032

File tree

4 files changed

+96
-46
lines changed

4 files changed

+96
-46
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPClient.swift

+1-5
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,7 @@ public class HTTPClient {
579579

580580
task.setConnection(connection)
581581

582-
let isCancelled = task.lock.withLock {
583-
task.cancelled
584-
}
585-
586-
if !isCancelled {
582+
if !task.isCancelled {
587583
return channel.writeAndFlush(request).flatMapError { _ in
588584
// At this point the `TaskHandler` will already be present
589585
// to handle the failure and pass it to the `promise`

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+35-16
Original file line numberDiff line numberDiff line change
@@ -633,17 +633,25 @@ extension HTTPClient {
633633

634634
let promise: EventLoopPromise<Response>
635635
var completion: EventLoopFuture<Void>
636-
var connection: Connection?
637-
var cancelled: Bool
638-
let lock: Lock
636+
private let lock = Lock()
637+
// protected by lock
638+
private var _connection: Connection?
639+
// protected by lock
640+
private var _cancelled: Bool = false
639641
let logger: Logger // We are okay to store the logger here because a Task is for only one request.
640642

643+
var isCancelled: Bool {
644+
self.lock.withLock { self._cancelled }
645+
}
646+
647+
var connection: Connection? {
648+
self.lock.withLock { self._connection }
649+
}
650+
641651
init(eventLoop: EventLoop, logger: Logger) {
642652
self.eventLoop = eventLoop
643653
self.promise = eventLoop.makePromise()
644654
self.completion = self.promise.futureResult.map { _ in }
645-
self.cancelled = false
646-
self.lock = Lock()
647655
self.logger = logger
648656
}
649657

@@ -669,9 +677,9 @@ extension HTTPClient {
669677
/// Cancels the request execution.
670678
public func cancel() {
671679
let channel: Channel? = self.lock.withLock {
672-
if !self.cancelled {
673-
self.cancelled = true
674-
return self.connection?.channel
680+
if !self._cancelled {
681+
self._cancelled = true
682+
return self._connection?.channel
675683
} else {
676684
return nil
677685
}
@@ -681,13 +689,16 @@ extension HTTPClient {
681689

682690
@discardableResult
683691
func setConnection(_ connection: Connection) -> Connection {
684-
return self.lock.withLock {
685-
self.connection = connection
686-
if self.cancelled {
687-
connection.channel.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
688-
}
689-
return connection
692+
let cancelled = self.lock.withLock { () -> Bool in
693+
self._connection = connection
694+
return self._cancelled
690695
}
696+
697+
if cancelled {
698+
connection.channel.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
699+
}
700+
701+
return connection
691702
}
692703

693704
func succeed<Delegate: HTTPClientResponseDelegate>(promise: EventLoopPromise<Response>?,
@@ -702,7 +713,9 @@ extension HTTPClient {
702713

703714
func fail<Delegate: HTTPClientResponseDelegate>(with error: Error,
704715
delegateType: Delegate.Type) {
705-
if let connection = self.connection {
716+
let maybeConnection = self.lock.withLock { self._connection }
717+
718+
if let connection = maybeConnection {
706719
self.releaseAssociatedConnection(delegateType: delegateType, closing: true)
707720
.whenSuccess {
708721
self.promise.fail(error)
@@ -716,7 +729,13 @@ extension HTTPClient {
716729

717730
func releaseAssociatedConnection<Delegate: HTTPClientResponseDelegate>(delegateType: Delegate.Type,
718731
closing: Bool) -> EventLoopFuture<Void> {
719-
if let connection = self.connection {
732+
let maybeConnection = self.lock.withLock { () -> Connection? in
733+
let connection = self._connection
734+
self._connection = nil
735+
return connection
736+
}
737+
738+
if let connection = maybeConnection {
720739
// remove read timeout handler
721740
return connection.removeHandler(IdleStateHandler.self).flatMap {
722741
connection.removeHandler(TaskHandler<Delegate>.self)

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ extension HTTPClientInternalTests {
3434
("testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes", testRequestFinishesAfterRedirectIfServerRespondsBeforeClientFinishes),
3535
("testProxyStreaming", testProxyStreaming),
3636
("testProxyStreamingFailure", testProxyStreamingFailure),
37-
("testUploadStreamingBackpressure", testUploadStreamingBackpressure),
37+
("testDownloadStreamingBackpressure", testDownloadStreamingBackpressure),
3838
("testRequestURITrailingSlash", testRequestURITrailingSlash),
3939
("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops),
4040
("testResponseConnectionCloseGet", testResponseConnectionCloseGet),

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+59-24
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import NIO
1717
import NIOConcurrencyHelpers
1818
import NIOHTTP1
1919
import NIOTestUtils
20+
import Logging
2021
import XCTest
2122

2223
class HTTPClientInternalTests: XCTestCase {
@@ -323,21 +324,25 @@ class HTTPClientInternalTests: XCTestCase {
323324
// of 4 bytes. This will guarantee that if we see first byte of the message, other
324325
// bytes a ready to be read as well. This will allow us to test if subsequent reads
325326
// are waiting for backpressure promise.
326-
func testUploadStreamingBackpressure() throws {
327+
func testDownloadStreamingBackpressure() throws {
327328
class BackpressureTestDelegate: HTTPClientResponseDelegate {
328329
typealias Response = Void
329330

330331
var _reads = 0
331332
let lock: Lock
332-
let backpressurePromise: EventLoopPromise<Void>
333+
let channel: Channel
334+
333335
let optionsApplied: EventLoopPromise<Void>
334-
let messageReceived: EventLoopPromise<Void>
336+
let backpressureFuture: EventLoopFuture<Void>
337+
let firstBodyPartReceived: EventLoopPromise<Void>
335338

336-
init(eventLoop: EventLoop) {
339+
init(channel: Channel, writeBodyPromise: EventLoopPromise<Void>, writeEndFuture: EventLoopFuture<Void>) {
337340
self.lock = Lock()
338-
self.backpressurePromise = eventLoop.makePromise()
339-
self.optionsApplied = eventLoop.makePromise()
340-
self.messageReceived = eventLoop.makePromise()
341+
342+
self.channel = channel
343+
self.optionsApplied = writeBodyPromise
344+
self.backpressureFuture = writeEndFuture
345+
self.firstBodyPartReceived = channel.eventLoop.makePromise()
341346
}
342347

343348
var reads: Int {
@@ -348,8 +353,8 @@ class HTTPClientInternalTests: XCTestCase {
348353

349354
func didReceiveHead(task: HTTPClient.Task<Void>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
350355
// This is to force NIO to send only 1 byte at a time.
351-
let future = task.connection!.channel.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap {
352-
task.connection!.channel.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1))
356+
let future = self.channel.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap {
357+
self.channel.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1))
353358
}
354359
future.cascade(to: self.optionsApplied)
355360
return future
@@ -361,8 +366,8 @@ class HTTPClientInternalTests: XCTestCase {
361366
self._reads += 1
362367
}
363368
// We need to notify the test when first byte of the message is arrived.
364-
self.messageReceived.succeed(())
365-
return self.backpressurePromise.futureResult
369+
self.firstBodyPartReceived.succeed(())
370+
return self.self.backpressureFuture
366371
}
367372

368373
func didFinishRequest(task: HTTPClient.Task<Response>) throws {}
@@ -403,37 +408,67 @@ class HTTPClientInternalTests: XCTestCase {
403408

404409
// cannot test with NIOTS as `maxMessagesPerRead` is not supported
405410
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
406-
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
407-
let delegate = BackpressureTestDelegate(eventLoop: httpClient.eventLoopGroup.next())
411+
let eventLoop = eventLoopGroup.next()
412+
let writeBodyPromise = eventLoop.makePromise(of: Void.self)
413+
let writeEndPromise = eventLoop.makePromise(of: Void.self)
408414
let httpBin = HTTPBin { _ in
409415
WriteAfterFutureSucceedsHandler(
410-
bodyFuture: delegate.optionsApplied.futureResult,
411-
endFuture: delegate.backpressurePromise.futureResult
416+
bodyFuture: writeBodyPromise.futureResult,
417+
endFuture: writeEndPromise.futureResult
412418
)
413419
}
414-
415420
defer {
416-
XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true))
417421
XCTAssertNoThrow(try httpBin.shutdown())
418422
XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully())
419423
}
420-
424+
425+
421426
let request = try Request(url: "http://localhost:\(httpBin.port)/custom")
422-
423-
let requestFuture = httpClient.execute(request: request, delegate: delegate).futureResult
427+
let logger = Logger(label: "test-connection")
428+
429+
let clientFactory = HTTPConnectionPool.ConnectionFactory(
430+
key: .init(request),
431+
tlsConfiguration: nil,
432+
clientConfiguration: .init(),
433+
sslContextCache: .init())
434+
var maybeChannel: Channel?
435+
XCTAssertNoThrow(maybeChannel = try clientFactory.makeHTTP1Channel(
436+
connectionID: 1,
437+
deadline: .now() + .seconds(10),
438+
eventLoop: eventLoopGroup.next(),
439+
logger: logger).wait())
440+
441+
guard let channel = maybeChannel else { return XCTFail("Expected to have a channel at this point") }
442+
443+
let delegate = BackpressureTestDelegate(
444+
channel: channel,
445+
writeBodyPromise: writeBodyPromise,
446+
writeEndFuture: writeEndPromise.futureResult)
447+
let task = HTTPClient.Task<BackpressureTestDelegate.Response>(eventLoop: eventLoop, logger: logger)
448+
449+
450+
let taskHandler = TaskHandler(task: task,
451+
kind: request.kind,
452+
delegate: delegate,
453+
redirectHandler: nil,
454+
ignoreUncleanSSLShutdown: true,
455+
logger: logger)
456+
457+
XCTAssertNoThrow(try channel.pipeline.addHandler(taskHandler).wait())
458+
XCTAssertNoThrow(try channel.writeAndFlush(request).wait())
424459

425460
// We need to wait for channel options that limit NIO to sending only one byte at a time.
426-
try delegate.optionsApplied.futureResult.wait()
461+
XCTAssertNoThrow(try delegate.optionsApplied.futureResult.wait())
427462

428463
// Send 4 bytes, but only one should be received until the backpressure promise is succeeded.
429464

430465
// Now we wait until message is delivered to client channel pipeline
431-
try delegate.messageReceived.futureResult.wait()
466+
XCTAssertNoThrow(try delegate.firstBodyPartReceived.futureResult.wait())
432467
XCTAssertEqual(delegate.reads, 1)
433468

434469
// Succeed the backpressure promise.
435-
delegate.backpressurePromise.succeed(())
436-
try requestFuture.wait()
470+
writeEndPromise.succeed(())
471+
XCTAssertNoThrow(try task.futureResult.wait())
437472

438473
// At this point all other bytes should be delivered.
439474
XCTAssertEqual(delegate.reads, 4)

0 commit comments

Comments
 (0)