Skip to content

Make RequestBag conform to Sendable #837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 69 additions & 64 deletions Sources/AsyncHTTPClient/RequestBag.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import NIOCore
import NIOHTTP1
import NIOSSL

final class RequestBag<Delegate: HTTPClientResponseDelegate> {
@preconcurrency
final class RequestBag<Delegate: HTTPClientResponseDelegate & Sendable>: Sendable {
/// Defends against the call stack getting too large when consuming body parts.
///
/// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users
Expand All @@ -35,16 +36,23 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private let delegate: Delegate
private var request: HTTPClient.Request

// the request state is synchronized on the task eventLoop
private var state: StateMachine

// the consume body part stack depth is synchronized on the task event loop.
private var consumeBodyPartStackDepth: Int
struct LoopBoundState: @unchecked Sendable {
// The 'StateMachine' *isn't* Sendable (it holds various objects which aren't). This type
// needs to be sendable so that we can construct a loop bound box off of the event loop
// to hold this state and then subsequently only access it from the event loop. This needs
// to happen so that the request bag can be constructed off of the event loop. If it's
// constructed on the event loop then there's a timing window between users issuing
// a request and calling shutdown where the underlying pool doesn't know about the request
// so the shutdown call may cancel it.
var request: HTTPClient.Request
var state: StateMachine
var consumeBodyPartStackDepth: Int
// if a redirect occurs, we store the task for it so we can propagate cancellation
var redirectTask: HTTPClient.Task<Delegate.Response>? = nil
}

// if a redirect occurs, we store the task for it so we can propagate cancellation
private var redirectTask: HTTPClient.Task<Delegate.Response>? = nil
private let loopBoundState: NIOLoopBoundBox<LoopBoundState>

// MARK: HTTPClientTask properties

Expand All @@ -61,6 +69,8 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {

let eventLoopPreference: HTTPClient.EventLoopPreference

let tlsConfiguration: TLSConfiguration?

init(
request: HTTPClient.Request,
eventLoopPreference: HTTPClient.EventLoopPreference,
Expand All @@ -73,9 +83,13 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride)
self.eventLoopPreference = eventLoopPreference
self.task = task
self.state = .init(redirectHandler: redirectHandler)
self.consumeBodyPartStackDepth = 0
self.request = request

let loopBoundState = LoopBoundState(
request: request,
state: StateMachine(redirectHandler: redirectHandler),
consumeBodyPartStackDepth: 0
)
self.loopBoundState = NIOLoopBoundBox.makeBoxSendingValue(loopBoundState, eventLoop: task.eventLoop)
self.connectionDeadline = connectionDeadline
self.requestOptions = requestOptions
self.delegate = delegate
Expand All @@ -84,6 +98,8 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
self.requestHead = head
self.requestFramingMetadata = metadata

self.tlsConfiguration = request.tlsConfiguration

self.task.taskDelegate = self
self.task.futureResult.whenComplete { _ in
self.task.taskDelegate = nil
Expand All @@ -92,16 +108,13 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {

private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) {
self.logger.debug("Request was queued (waiting for a connection to become available)")

self.task.eventLoop.assertInEventLoop()
self.state.requestWasQueued(scheduler)
self.loopBoundState.value.state.requestWasQueued(scheduler)
}

// MARK: - Request -

private func willExecuteRequest0(_ executor: HTTPRequestExecutor) {
self.task.eventLoop.assertInEventLoop()
let action = self.state.willExecuteRequest(executor)
let action = self.loopBoundState.value.state.willExecuteRequest(executor)
switch action {
case .cancelExecuter(let executor):
executor.cancelRequest(self)
Expand All @@ -115,26 +128,22 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func requestHeadSent0() {
self.task.eventLoop.assertInEventLoop()

self.delegate.didSendRequestHead(task: self.task, self.requestHead)

if self.request.body == nil {
if self.loopBoundState.value.request.body == nil {
self.delegate.didSendRequest(task: self.task)
}
}

private func resumeRequestBodyStream0() {
self.task.eventLoop.assertInEventLoop()

let produceAction = self.state.resumeRequestBodyStream()
let produceAction = self.loopBoundState.value.state.resumeRequestBodyStream()

switch produceAction {
case .startWriter:
guard let body = self.request.body else {
guard let body = self.loopBoundState.value.request.body else {
preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream")
}
self.request.body = nil
self.loopBoundState.value.request.body = nil

let writer = HTTPClient.Body.StreamWriter {
self.writeNextRequestPart($0)
Expand All @@ -153,9 +162,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func pauseRequestBodyStream0() {
self.task.eventLoop.assertInEventLoop()

self.state.pauseRequestBodyStream()
self.loopBoundState.value.state.pauseRequestBodyStream()
}

private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture<Void> {
Expand All @@ -169,9 +176,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture<Void> {
self.eventLoop.assertInEventLoop()

let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop)
let action = self.loopBoundState.value.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop)

switch action {
case .failTask(let error):
Expand All @@ -193,9 +198,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func finishRequestBodyStream(_ result: Result<Void, Error>) {
self.task.eventLoop.assertInEventLoop()

let action = self.state.finishRequestBodyStream(result)
let action = self.loopBoundState.value.state.finishRequestBodyStream(result)

switch action {
case .none:
Expand Down Expand Up @@ -226,20 +229,22 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
// MARK: - Response -

private func receiveResponseHead0(_ head: HTTPResponseHead) {
self.task.eventLoop.assertInEventLoop()

self.delegate.didVisitURL(task: self.task, self.request, head)
self.delegate.didVisitURL(task: self.task, self.loopBoundState.value.request, head)

// runs most likely on channel eventLoop
switch self.state.receiveResponseHead(head) {
switch self.loopBoundState.value.state.receiveResponseHead(head) {
case .none:
break

case .signalBodyDemand(let executor):
executor.demandResponseBodyStream(self)

case .redirect(let executor, let handler, let head, let newURL):
self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise)
self.loopBoundState.value.redirectTask = handler.redirect(
status: head.status,
to: newURL,
promise: self.task.promise
)
executor.cancelRequest(self)

case .forwardResponseHead(let head):
Expand All @@ -253,17 +258,19 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func receiveResponseBodyParts0(_ buffer: CircularBuffer<ByteBuffer>) {
self.task.eventLoop.assertInEventLoop()

switch self.state.receiveResponseBodyParts(buffer) {
switch self.loopBoundState.value.state.receiveResponseBodyParts(buffer) {
case .none:
break

case .signalBodyDemand(let executor):
executor.demandResponseBodyStream(self)

case .redirect(let executor, let handler, let head, let newURL):
self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise)
self.loopBoundState.value.redirectTask = handler.redirect(
status: head.status,
to: newURL,
promise: self.task.promise
)
executor.cancelRequest(self)

case .forwardResponsePart(let part):
Expand All @@ -277,8 +284,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func succeedRequest0(_ buffer: CircularBuffer<ByteBuffer>?) {
self.task.eventLoop.assertInEventLoop()
let action = self.state.succeedRequest(buffer)
let action = self.loopBoundState.value.state.succeedRequest(buffer)

switch action {
case .none:
Expand All @@ -299,13 +305,15 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

case .redirect(let handler, let head, let newURL):
self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise)
self.loopBoundState.value.redirectTask = handler.redirect(
status: head.status,
to: newURL,
promise: self.task.promise
)
}
}

private func consumeMoreBodyData0(resultOfPreviousConsume result: Result<Void, Error>) {
self.task.eventLoop.assertInEventLoop()

// We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart`
// future to be returned to us completed. If it is, we will recurse back into this method. To
// break that recursion we have a max stack depth which we increment and decrement in this method:
Expand All @@ -316,24 +324,27 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
// that risk ending up in this loop. That's because we don't need an accurate count: our limit is
// a best-effort target anyway, one stack frame here or there does not put us at risk. We're just
// trying to prevent ourselves looping out of control.
self.consumeBodyPartStackDepth += 1
self.loopBoundState.value.consumeBodyPartStackDepth += 1
defer {
self.consumeBodyPartStackDepth -= 1
assert(self.consumeBodyPartStackDepth >= 0)
self.loopBoundState.value.consumeBodyPartStackDepth -= 1
assert(self.loopBoundState.value.consumeBodyPartStackDepth >= 0)
}

let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result)
let consumptionAction = self.loopBoundState.value.state.consumeMoreBodyData(
resultOfPreviousConsume: result
)

switch consumptionAction {
case .consume(let byteBuffer):
self.delegate.didReceiveBodyPart(task: self.task, byteBuffer)
.hop(to: self.task.eventLoop)
.assumeIsolated()
.whenComplete { result in
if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth {
if self.loopBoundState.value.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth {
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
} else {
// We need to unwind the stack, let's take a break.
self.task.eventLoop.execute {
self.task.eventLoop.assumeIsolated().execute {
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
}
}
Expand All @@ -344,7 +355,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
case .finishStream:
do {
let response = try self.delegate.didFinishRequest(task: self.task)
self.task.promise.succeed(response)
self.task.promise.assumeIsolated().succeed(response)
} catch {
self.task.promise.fail(error)
}
Expand All @@ -358,13 +369,11 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

private func fail0(_ error: Error) {
self.task.eventLoop.assertInEventLoop()

let action = self.state.fail(error)
let action = self.loopBoundState.value.state.fail(error)

self.executeFailAction0(action)

self.redirectTask?.fail(reason: error)
self.loopBoundState.value.redirectTask?.fail(reason: error)
}

private func executeFailAction0(_ action: RequestBag<Delegate>.StateMachine.FailAction) {
Expand All @@ -381,8 +390,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

func deadlineExceeded0() {
self.task.eventLoop.assertInEventLoop()
let action = self.state.deadlineExceeded()
let action = self.loopBoundState.value.state.deadlineExceeded()

switch action {
case .cancelScheduler(let scheduler):
Expand All @@ -404,9 +412,6 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
}

extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate {
var tlsConfiguration: TLSConfiguration? {
self.request.tlsConfiguration
}

func requestWasQueued(_ scheduler: HTTPRequestScheduler) {
if self.task.eventLoop.inEventLoop {
Expand Down
Loading