Skip to content

NFC: Add AsyncStream-based API to AsyncProcess #7830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 7, 2024
92 changes: 83 additions & 9 deletions Sources/Basics/AsyncProcess.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,31 @@ package final class AsyncProcess {
case stdinUnavailable
}

package typealias OutputStream = AsyncStream<[UInt8]>
package typealias ReadableStream = AsyncStream<[UInt8]>

package enum OutputRedirection {
package enum OutputRedirection: Sendable {
/// Do not redirect the output
case none
/// Collect stdout and stderr output and provide it back via ProcessResult object. If redirectStderr is true,
/// stderr be redirected to stdout.

/// Collect stdout and stderr output and provide it back via ``AsyncProcessResult`` object. If
/// `redirectStderr` is `true`, `stderr` be redirected to `stdout`.
case collect(redirectStderr: Bool)
/// Stream stdout and stderr via the corresponding closures. If redirectStderr is true, stderr be redirected to
/// stdout.

/// Stream `stdout` and `stderr` via the corresponding closures. If `redirectStderr` is `true`, `stderr` will
/// be redirected to `stdout`.
case stream(stdout: OutputClosure, stderr: OutputClosure, redirectStderr: Bool)

/// Stream stdout and stderr as `AsyncSequence` provided as an argument to closures passed to
/// ``AsyncProcess/launch(stdoutStream:stderrStream:)``.
case asyncStream(
stdoutStream: ReadableStream,
stdoutContinuation: ReadableStream.Continuation,
stderrStream: ReadableStream,
stderrContinuation: ReadableStream.Continuation
)

/// Default collect OutputRedirection that defaults to not redirect stderr. Provided for API compatibility.
package static let collect: OutputRedirection = .collect(redirectStderr: false)
package static let collect: Self = .collect(redirectStderr: false)

/// Default stream OutputRedirection that defaults to not redirect stderr. Provided for API compatibility.
package static func stream(stdout: @escaping OutputClosure, stderr: @escaping OutputClosure) -> Self {
Expand All @@ -197,15 +208,19 @@ package final class AsyncProcess {
switch self {
case .none:
false
case .collect, .stream:
case .collect, .stream, .asyncStream:
true
}
}

package var outputClosures: (stdoutClosure: OutputClosure, stderrClosure: OutputClosure)? {
switch self {
case .stream(let stdoutClosure, let stderrClosure, _):
case let .stream(stdoutClosure, stderrClosure, _):
(stdoutClosure: stdoutClosure, stderrClosure: stderrClosure)

case let .asyncStream(stdoutStream, stdoutContinuation, stderrStream, stderrContinuation):
(stdoutClosure: { stdoutContinuation.yield($0) }, stderrClosure: { stderrContinuation.yield($0) })

case .collect, .none:
nil
}
Expand Down Expand Up @@ -946,6 +961,65 @@ extension AsyncProcess {
try await self.popen(arguments: args, environment: environment, loggingHandler: loggingHandler)
}

package typealias DuplexStreamHandler =
@Sendable (_ stdinStream: WritableByteStream, _ stdoutStream: ReadableStream) async throws -> ()
package typealias ReadableStreamHandler =
@Sendable (_ stderrStream: ReadableStream) async throws -> ()

/// Launches a new `AsyncProcess` instances, allowing the caller to consume `stdout` and `stderr` output
/// with handlers that support structured concurrency.
/// - Parameters:
/// - arguments: CLI command used to launch the process.
/// - environment: environment variables passed to the launched process.
/// - loggingHandler: handler used for logging,
/// - stdoutHandler: asynchronous bidirectional handler closure that receives `stdin` and `stdout` streams as
/// arguments.
/// - stderrHandler: asynchronous unidirectional handler closure that receives `stderr` stream as an argument.
/// - Returns: ``AsyncProcessResult`` value as received from the underlying ``AsyncProcess/waitUntilExit()`` call
/// made on ``AsyncProcess`` instance.
package static func popen(
arguments: [String],
environment: Environment = .current,
loggingHandler: LoggingHandler? = .none,
stdoutHandler: @escaping DuplexStreamHandler,
stderrHandler: ReadableStreamHandler? = nil
) async throws -> AsyncProcessResult {
let (stdoutStream, stdoutContinuation) = ReadableStream.makeStream()
let (stderrStream, stderrContinuation) = ReadableStream.makeStream()

let process = AsyncProcess(
arguments: arguments,
environment: environment,
outputRedirection: .stream {
stdoutContinuation.yield($0)
} stderr: {
stderrContinuation.yield($0)
},
loggingHandler: loggingHandler
)

return try await withThrowingTaskGroup(of: Void.self) { group in
let stdinStream = try process.launch()

group.addTask {
try await stdoutHandler(stdinStream, stdoutStream)
}

if let stderrHandler {
group.addTask {
try await stderrHandler(stderrStream)
}
}

defer {
stdoutContinuation.finish()
stderrContinuation.finish()
}

return try await process.waitUntilExit()
}
}

/// Execute a subprocess and get its (UTF-8) output if it has a non zero exit.
///
/// - Parameters:
Expand Down
60 changes: 49 additions & 11 deletions Tests/BasicsTests/AsyncProcessTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ final class AsyncProcessTests: XCTestCase {
}

func testAsyncStream() async throws {
let (stdoutStream, stdoutContinuation) = AsyncProcess.OutputStream.makeStream()
let (stderrStream, stderrContinuation) = AsyncProcess.OutputStream.makeStream()
let (stdoutStream, stdoutContinuation) = AsyncProcess.ReadableStream.makeStream()
let (stderrStream, stderrContinuation) = AsyncProcess.ReadableStream.makeStream()

let process = AsyncProcess(
scriptName: "echo",
Expand All @@ -407,15 +407,15 @@ final class AsyncProcessTests: XCTestCase {
}
)

try await withThrowingTaskGroup(of: Void.self) { group in
let result = try await withThrowingTaskGroup(of: Void.self) { group in
let stdin = try process.launch()

group.addTask {
var counter = 0
stdin.write("Hello \(counter)\n")
stdin.flush()

for try await output in stdoutStream {
for await output in stdoutStream {
XCTAssertEqual(output, .init("Hello \(counter)\n".utf8))
counter += 1

Expand All @@ -430,9 +430,8 @@ final class AsyncProcessTests: XCTestCase {

group.addTask {
var counter = 0
for try await output in stderrStream {
for await output in stderrStream {
counter += 1
XCTAssertTrue(output.isEmpty)
}

XCTAssertEqual(counter, 0)
Expand All @@ -443,8 +442,43 @@ final class AsyncProcessTests: XCTestCase {
stderrContinuation.finish()
}

try await process.waitUntilExit()
return try await process.waitUntilExit()
}

XCTAssertEqual(result.exitStatus, .terminated(code: 0))
}

func testAsyncStreamHighLevelAPI() async throws {
let result = try await AsyncProcess.popen(
scriptName: "echo",
stdout: { stdin, stdout in
var counter = 0
stdin.write("Hello \(counter)\n")
stdin.flush()

for await output in stdout {
XCTAssertEqual(output, .init("Hello \(counter)\n".utf8))
counter += 1

stdin.write(.init("Hello \(counter)\n".utf8))
stdin.flush()
}

XCTAssertEqual(counter, 5)

try stdin.close()
},
stderr: { stderr in
var counter = 0
for await output in stderr {
counter += 1
}

XCTAssertEqual(counter, 0)
}
)

XCTAssertEqual(result.exitStatus, .terminated(code: 0))
}
}

Expand All @@ -465,9 +499,7 @@ extension AsyncProcess {
)
}

#if compiler(>=5.8)
@available(*, noasync)
#endif
fileprivate static func checkNonZeroExit(
scriptName: String,
environment: Environment = .current,
Expand All @@ -493,9 +525,7 @@ extension AsyncProcess {
)
}

#if compiler(>=5.8)
@available(*, noasync)
#endif
@discardableResult
fileprivate static func popen(
scriptName: String,
Expand All @@ -514,4 +544,12 @@ extension AsyncProcess {
) async throws -> AsyncProcessResult {
try await self.popen(arguments: [self.script(scriptName)], environment: .current, loggingHandler: loggingHandler)
}

fileprivate static func popen(
scriptName: String,
stdout: @escaping AsyncProcess.DuplexStreamHandler,
stderr: AsyncProcess.ReadableStreamHandler? = nil
) async throws -> AsyncProcessResult {
try await self.popen(arguments: [self.script(scriptName)], stdoutHandler: stdout, stderrHandler: stderr)
}
}