diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index a6efcfdf..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult.map { rowDescription in PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription) } @@ -256,8 +255,7 @@ public final class PostgresConnection: @unchecked Sendable { logger: logger, promise: promise) - self.write(.extendedQuery(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) return promise.futureResult } @@ -265,8 +263,7 @@ public final class PostgresConnection: @unchecked Sendable { let promise = self.channel.eventLoop.makePromise(of: Void.self) let context = CloseCommandContext(target: target, logger: logger, promise: promise) - self.write(.closeCommand(context), cascadingFailureTo: promise) - + self.channel.write(HandlerTask.closeCommand(context), promise: nil) return promise.futureResult } @@ -429,7 +426,7 @@ extension PostgresConnection { promise: promise ) - self.write(.extendedQuery(context), cascadingFailureTo: promise) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() @@ -458,11 +455,7 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.whenFailure { error in - listener.failed(error) - } + self.channel.write(task, promise: nil) } } onCancel: { let task = HandlerTask.cancelListening(channel, id) @@ -487,9 +480,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.asyncSequence() } @@ -524,9 +515,7 @@ extension PostgresConnection { logger: logger, promise: promise )) - - self.write(task, cascadingFailureTo: promise) - + self.channel.write(task, promise: nil) do { return try await promise.futureResult .map { $0.commandTag } @@ -541,12 +530,6 @@ extension PostgresConnection { throw error // rethrow with more metadata } } - - private func write(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise) { - let writePromise = self.channel.eventLoop.makePromise(of: Void.self) - self.channel.write(task, promise: writePromise) - writePromise.futureResult.cascadeFailure(to: promise) - } } // MARK: EventLoopFuture interface @@ -691,7 +674,7 @@ internal enum PostgresCommands: PostgresRequest { /// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support. public final class PostgresListenContext: Sendable { - let promise: EventLoopPromise + private let promise: EventLoopPromise var future: EventLoopFuture { self.promise.futureResult @@ -730,7 +713,8 @@ extension PostgresConnection { closure: notificationHandler ) - self.write(.startListening(listener), cascadingFailureTo: listenContext.promise) + let task = HandlerTask.startListening(listener) + self.channel.write(task, promise: nil) listenContext.future.whenComplete { _ in let task = HandlerTask.cancelListening(channel, id) @@ -777,4 +761,3 @@ extension PostgresConnection { #endif } } - diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 913d91b2..57939c06 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -359,4 +359,5 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(obj?.bar, 2) } } + } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 5c7d4c83..0bc61efd 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -224,63 +224,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testSimpleListenFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.listen("test_channel") - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch let error as PSQLError { - XCTAssertEqual(error.code, .clientClosedConnection) - } - } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - try await connection.close() - - XCTAssertEqual(channel.isActive, false) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -695,118 +638,6 @@ class PostgresConnectionTests: XCTestCase { } } - func testQueryFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.query("SELECT version;", logger: self.logger) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testPrepareStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - _ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecuteFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - do { - let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil) - _ = try await connection.execute(statement, logger: .psqlTest).get() - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" - typealias Row = (Int, String) - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - try row.decode(Row.self) - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - - func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await connection.closeGracefully() - - XCTAssertEqual(channel.isActive, false) - - struct TestPreparedStatement: PostgresPreparedStatement { - static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1" - typealias Row = () - - var state: String - - func makeBindings() -> PostgresBindings { - var bindings = PostgresBindings() - bindings.append(self.state) - return bindings - } - - func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { - () - } - } - - do { - let preparedStatement = TestPreparedStatement(state: "active") - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Expected to fail") - } catch let error as ChannelError { - XCTAssertEqual(error, .ioOnClosedChannel) - } - } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [