Skip to content

Revert "[Fix] Query Hangs if Connection is Closed (#487)" #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

return promise.futureResult
}
Expand All @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
return promise.futureResult.map { rowDescription in
PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription)
}
Expand All @@ -256,17 +255,15 @@ public final class PostgresConnection: @unchecked Sendable {
logger: logger,
promise: promise)

self.write(.extendedQuery(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
return promise.futureResult
}

func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
let context = CloseCommandContext(target: target, logger: logger, promise: promise)

self.write(.closeCommand(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.closeCommand(context), promise: nil)
return promise.futureResult
}

Expand Down Expand Up @@ -429,7 +426,7 @@ extension PostgresConnection {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
Expand Down Expand Up @@ -458,11 +455,7 @@ extension PostgresConnection {

let task = HandlerTask.startListening(listener)

let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.whenFailure { error in
listener.failed(error)
}
self.channel.write(task, promise: nil)
}
} onCancel: {
let task = HandlerTask.cancelListening(channel, id)
Expand All @@ -487,9 +480,7 @@ extension PostgresConnection {
logger: logger,
promise: promise
))

self.write(task, cascadingFailureTo: promise)

self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.asyncSequence() }
Expand Down Expand Up @@ -524,9 +515,7 @@ extension PostgresConnection {
logger: logger,
promise: promise
))

self.write(task, cascadingFailureTo: promise)

self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.commandTag }
Expand All @@ -541,12 +530,6 @@ extension PostgresConnection {
throw error // rethrow with more metadata
}
}

private func write<T>(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise<T>) {
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.cascadeFailure(to: promise)
}
}

// MARK: EventLoopFuture interface
Expand Down Expand Up @@ -691,7 +674,7 @@ internal enum PostgresCommands: PostgresRequest {

/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support.
public final class PostgresListenContext: Sendable {
let promise: EventLoopPromise<Void>
private let promise: EventLoopPromise<Void>

var future: EventLoopFuture<Void> {
self.promise.futureResult
Expand Down Expand Up @@ -730,7 +713,8 @@ extension PostgresConnection {
closure: notificationHandler
)

self.write(.startListening(listener), cascadingFailureTo: listenContext.promise)
let task = HandlerTask.startListening(listener)
self.channel.write(task, promise: nil)

listenContext.future.whenComplete { _ in
let task = HandlerTask.cancelListening(channel, id)
Expand Down Expand Up @@ -777,4 +761,3 @@ extension PostgresConnection {
#endif
}
}

1 change: 1 addition & 0 deletions Tests/IntegrationTests/PSQLIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,5 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual(obj?.bar, 2)
}
}

}
169 changes: 0 additions & 169 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,63 +224,6 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testSimpleListenFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.listen("test_channel")
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
let events = try await connection.listen("foo")
var iterator = events.makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first?.payload, "wooohooo")
do {
_ = try await iterator.next()
XCTFail("Did not expect to not throw")
} catch let error as PSQLError {
XCTAssertEqual(error.code, .clientClosedConnection)
}
}

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#)

try await channel.writeInbound(PostgresBackendMessage.parseComplete)
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
try await channel.writeInbound(PostgresBackendMessage.noData)
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN"))
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo")))

try await connection.close()

XCTAssertEqual(channel.isActive, false)

switch await taskGroup.nextResult()! {
case .success:
break
case .failure(let failure):
XCTFail("Unexpected error: \(failure)")
}
}
}

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
Expand Down Expand Up @@ -695,118 +638,6 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testQueryFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.query("SELECT version;", logger: self.logger)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testPrepareStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecuteFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil)
_ = try await connection.execute(statement, logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
typealias Row = (Int, String)

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
try row.decode(Row.self)
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1"
typealias Row = ()

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
()
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
let eventLoop = NIOAsyncTestingEventLoop()
let channel = await NIOAsyncTestingChannel(handlers: [
Expand Down
Loading