Skip to content

Commit 61da70d

Browse files
committed
implement PostgresConnection.query and .execute with metadata
1 parent 5d817be commit 61da70d

File tree

6 files changed

+239
-9
lines changed

6 files changed

+239
-9
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.DS_Store
22
/.build
3+
/.index-build
34
/Packages
45
/*.xcodeproj
56
DerivedData

Diff for: Sources/PostgresNIO/Connection/PostgresConnection.swift

+81
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,48 @@ extension PostgresConnection {
438438
}
439439
}
440440

441+
// use this for queries where you want to consume the rows.
442+
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
443+
public func query<Result>(
444+
_ query: PostgresQuery,
445+
logger: Logger,
446+
file: String = #fileID,
447+
line: Int = #line,
448+
_ consume: (PostgresRowSequence) async throws -> Result
449+
) async throws -> (Result, PostgresQueryMetadata) {
450+
var logger = logger
451+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
452+
453+
guard query.binds.count <= Int(UInt16.max) else {
454+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
455+
}
456+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
457+
let context = ExtendedQueryContext(
458+
query: query,
459+
logger: logger,
460+
promise: promise
461+
)
462+
463+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
464+
465+
do {
466+
let (rowStream, rowSequence) = try await promise.futureResult.map { rowStream in
467+
(rowStream, rowStream.asyncSequence())
468+
}.get()
469+
let result = try await consume(rowSequence)
470+
try await rowStream.drain().get()
471+
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
472+
throw PSQLError.invalidCommandTag(rowStream.commandTag)
473+
}
474+
return (result, metadata)
475+
} catch var error as PSQLError {
476+
error.file = file
477+
error.line = line
478+
error.query = query
479+
throw error // rethrow with more metadata
480+
}
481+
}
482+
441483
/// Start listening for a channel
442484
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
443485
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
@@ -531,6 +573,45 @@ extension PostgresConnection {
531573
}
532574
}
533575

576+
// use this for queries where you want to consume the rows.
577+
// we can use the `consume` scope to better ensure structured concurrency when consuming the rows.
578+
@discardableResult
579+
public func execute(
580+
_ query: PostgresQuery,
581+
logger: Logger,
582+
file: String = #fileID,
583+
line: Int = #line
584+
) async throws -> PostgresQueryMetadata {
585+
var logger = logger
586+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
587+
588+
guard query.binds.count <= Int(UInt16.max) else {
589+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
590+
}
591+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
592+
let context = ExtendedQueryContext(
593+
query: query,
594+
logger: logger,
595+
promise: promise
596+
)
597+
598+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
599+
600+
do {
601+
let rowStream = try await promise.futureResult.get()
602+
try await rowStream.drain().get()
603+
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
604+
throw PSQLError.invalidCommandTag(rowStream.commandTag)
605+
}
606+
return metadata
607+
} catch var error as PSQLError {
608+
error.file = file
609+
error.line = line
610+
error.query = query
611+
throw error // rethrow with more metadata
612+
}
613+
}
614+
534615
#if compiler(>=6.0)
535616
/// Puts the connection into an open transaction state, for the provided `closure`'s lifetime.
536617
///

Diff for: Sources/PostgresNIO/New/PSQLRowStream.swift

+57-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,63 @@ final class PSQLRowStream: @unchecked Sendable {
276276
return self.eventLoop.makeFailedFuture(error)
277277
}
278278
}
279-
279+
280+
// MARK: Drain on EventLoop
281+
282+
func drain() -> EventLoopFuture<Void> {
283+
if self.eventLoop.inEventLoop {
284+
return self.drain0()
285+
} else {
286+
return self.eventLoop.flatSubmit {
287+
self.drain0()
288+
}
289+
}
290+
}
291+
292+
private func drain0() -> EventLoopFuture<Void> {
293+
self.eventLoop.preconditionInEventLoop()
294+
295+
switch self.downstreamState {
296+
case .waitingForConsumer(let bufferState):
297+
switch bufferState {
298+
case .streaming(var buffer, let dataSource):
299+
let promise = self.eventLoop.makePromise(of: Void.self)
300+
301+
buffer.removeAll()
302+
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
303+
// immediately request more
304+
dataSource.request(for: self)
305+
306+
return promise.futureResult
307+
308+
case .finished(_, let summary):
309+
self.downstreamState = .consumed(.success(summary))
310+
return self.eventLoop.makeSucceededVoidFuture()
311+
312+
case .failure(let error):
313+
self.downstreamState = .consumed(.failure(error))
314+
return self.eventLoop.makeFailedFuture(error)
315+
}
316+
case .asyncSequence(let consumer, let dataSource, _):
317+
consumer.finish()
318+
319+
let promise = self.eventLoop.makePromise(of: Void.self)
320+
321+
self.downstreamState = .iteratingRows(onRow: { _ in }, promise, dataSource)
322+
// immediately request more
323+
dataSource.request(for: self)
324+
325+
return promise.futureResult
326+
case .consumed(.success):
327+
// already drained
328+
return self.eventLoop.makeSucceededVoidFuture()
329+
case .consumed(let .failure(error)):
330+
return self.eventLoop.makeFailedFuture(error)
331+
default:
332+
preconditionFailure("Invalid state: \(self.downstreamState)")
333+
}
334+
}
335+
280336
internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) {
281337
self.logger.debug("Notice Received", metadata: [
282338
.notice: "\(notice)"

Diff for: Sources/PostgresNIO/New/PostgresRowSequence.swift

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ extension PostgresRowSequence {
6060
extension PostgresRowSequence.AsyncIterator: Sendable {}
6161

6262
extension PostgresRowSequence {
63+
/// Collects all rows into an array.
64+
/// - Returns: The rows.
6365
public func collect() async throws -> [PostgresRow] {
6466
var result = [PostgresRow]()
6567
for try await row in self {

Diff for: Tests/IntegrationTests/AsyncTests.swift

+95-8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,98 @@ final class AsyncPostgresConnectionTests: XCTestCase {
4646
}
4747
}
4848

49+
func testSelect10kRowsWithMetadata() async throws {
50+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
51+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
52+
let eventLoop = eventLoopGroup.next()
53+
54+
let start = 1
55+
let end = 10000
56+
57+
try await withTestConnection(on: eventLoop) { connection in
58+
let (result, metadata) = try await connection.query(
59+
"SELECT generate_series(\(start), \(end));",
60+
logger: .psqlTest
61+
) { rows in
62+
var counter = 0
63+
for try await row in rows {
64+
let element = try row.decode(Int.self)
65+
XCTAssertEqual(element, counter + 1)
66+
counter += 1
67+
}
68+
return counter
69+
}
70+
71+
XCTAssertEqual(metadata.command, "SELECT")
72+
XCTAssertEqual(metadata.oid, nil)
73+
XCTAssertEqual(metadata.rows, end)
74+
75+
XCTAssertEqual(result, end)
76+
}
77+
}
78+
79+
func testSelectRowsWithMetadataNotConsumedAtAll() async throws {
80+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
81+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
82+
let eventLoop = eventLoopGroup.next()
83+
84+
let start = 1
85+
let end = 10000
86+
87+
try await withTestConnection(on: eventLoop) { connection in
88+
let (_, metadata) = try await connection.query(
89+
"SELECT generate_series(\(start), \(end));",
90+
logger: .psqlTest
91+
) { _ in }
92+
93+
XCTAssertEqual(metadata.command, "SELECT")
94+
XCTAssertEqual(metadata.oid, nil)
95+
XCTAssertEqual(metadata.rows, end)
96+
}
97+
}
98+
99+
func testSelectRowsWithMetadataNotFullyConsumed() async throws {
100+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
101+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
102+
let eventLoop = eventLoopGroup.next()
103+
104+
try await withTestConnection(on: eventLoop) { connection in
105+
do {
106+
_ = try await connection.query(
107+
"SELECT generate_series(1, 10000);",
108+
logger: .psqlTest
109+
) { rows in
110+
for try await _ in rows { break }
111+
}
112+
// This path is also fine
113+
} catch is CancellationError {
114+
// Expected
115+
} catch {
116+
XCTFail("Expected 'CancellationError', got: \(String(reflecting: error))")
117+
}
118+
}
119+
}
120+
121+
func testExecuteRowsWithMetadata() async throws {
122+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
123+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
124+
let eventLoop = eventLoopGroup.next()
125+
126+
let start = 1
127+
let end = 10000
128+
129+
try await withTestConnection(on: eventLoop) { connection in
130+
let metadata = try await connection.execute(
131+
"SELECT generate_series(\(start), \(end));",
132+
logger: .psqlTest
133+
)
134+
135+
XCTAssertEqual(metadata.command, "SELECT")
136+
XCTAssertEqual(metadata.oid, nil)
137+
XCTAssertEqual(metadata.rows, end)
138+
}
139+
}
140+
49141
func testSelectActiveConnection() async throws {
50142
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
51143
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -207,7 +299,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
207299

208300
try await withTestConnection(on: eventLoop) { connection in
209301
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
210-
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
302+
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
211303
// Then we will insert 3 * 17 rows
212304
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
213305
// If the test is successful, it means Postgres supports UInt16.max bindings
@@ -241,13 +333,8 @@ final class AsyncPostgresConnectionTests: XCTestCase {
241333
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
242334
binds: binds
243335
)
244-
try await connection.query(insertionQuery, logger: .psqlTest)
245-
246-
let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
247-
let countRows = try await connection.query(countQuery, logger: .psqlTest)
248-
var countIterator = countRows.makeAsyncIterator()
249-
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
250-
XCTAssertEqual(rowsCount, insertedRowsCount)
336+
let metadata = try await connection.execute(insertionQuery, logger: .psqlTest)
337+
XCTAssertEqual(metadata.rows, rowsCount)
251338

252339
let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
253340
try await connection.query(dropQuery, logger: .psqlTest)

Diff for: docker-compose.yml

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ x-shared-config: &shared_config
1010
- 5432:5432
1111

1212
services:
13+
psql-17:
14+
image: postgres:17
15+
<<: *shared_config
1316
psql-16:
1417
image: postgres:16
1518
<<: *shared_config

0 commit comments

Comments
 (0)