Skip to content

Commit

Permalink
Revert "[Fix] Query Hangs if Connection is Closed (#487)" (#501)
Browse files Browse the repository at this point in the history
This reverts commit f55caa7.
  • Loading branch information
fabianfett committed Aug 20, 2024
1 parent d18b137 commit cd5318a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 197 deletions.
39 changes: 11 additions & 28 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

return promise.futureResult
}
Expand All @@ -239,8 +239,7 @@ public final class PostgresConnection: @unchecked Sendable {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
return promise.futureResult.map { rowDescription in
PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription)
}
Expand All @@ -256,17 +255,15 @@ public final class PostgresConnection: @unchecked Sendable {
logger: logger,
promise: promise)

self.write(.extendedQuery(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
return promise.futureResult
}

func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
let context = CloseCommandContext(target: target, logger: logger, promise: promise)

self.write(.closeCommand(context), cascadingFailureTo: promise)

self.channel.write(HandlerTask.closeCommand(context), promise: nil)
return promise.futureResult
}

Expand Down Expand Up @@ -429,7 +426,7 @@ extension PostgresConnection {
promise: promise
)

self.write(.extendedQuery(context), cascadingFailureTo: promise)
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
Expand Down Expand Up @@ -458,11 +455,7 @@ extension PostgresConnection {

let task = HandlerTask.startListening(listener)

let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.whenFailure { error in
listener.failed(error)
}
self.channel.write(task, promise: nil)
}
} onCancel: {
let task = HandlerTask.cancelListening(channel, id)
Expand All @@ -487,9 +480,7 @@ extension PostgresConnection {
logger: logger,
promise: promise
))

self.write(task, cascadingFailureTo: promise)

self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.asyncSequence() }
Expand Down Expand Up @@ -524,9 +515,7 @@ extension PostgresConnection {
logger: logger,
promise: promise
))

self.write(task, cascadingFailureTo: promise)

self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.commandTag }
Expand All @@ -541,12 +530,6 @@ extension PostgresConnection {
throw error // rethrow with more metadata
}
}

private func write<T>(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise<T>) {
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
self.channel.write(task, promise: writePromise)
writePromise.futureResult.cascadeFailure(to: promise)
}
}

// MARK: EventLoopFuture interface
Expand Down Expand Up @@ -691,7 +674,7 @@ internal enum PostgresCommands: PostgresRequest {

/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support.
public final class PostgresListenContext: Sendable {
let promise: EventLoopPromise<Void>
private let promise: EventLoopPromise<Void>

var future: EventLoopFuture<Void> {
self.promise.futureResult
Expand Down Expand Up @@ -730,7 +713,8 @@ extension PostgresConnection {
closure: notificationHandler
)

self.write(.startListening(listener), cascadingFailureTo: listenContext.promise)
let task = HandlerTask.startListening(listener)
self.channel.write(task, promise: nil)

listenContext.future.whenComplete { _ in
let task = HandlerTask.cancelListening(channel, id)
Expand Down Expand Up @@ -777,4 +761,3 @@ extension PostgresConnection {
#endif
}
}

1 change: 1 addition & 0 deletions Tests/IntegrationTests/PSQLIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,5 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual(obj?.bar, 2)
}
}

}
169 changes: 0 additions & 169 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,63 +224,6 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testSimpleListenFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.listen("test_channel")
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
let events = try await connection.listen("foo")
var iterator = events.makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first?.payload, "wooohooo")
do {
_ = try await iterator.next()
XCTFail("Did not expect to not throw")
} catch let error as PSQLError {
XCTAssertEqual(error.code, .clientClosedConnection)
}
}

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#)

try await channel.writeInbound(PostgresBackendMessage.parseComplete)
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
try await channel.writeInbound(PostgresBackendMessage.noData)
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN"))
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo")))

try await connection.close()

XCTAssertEqual(channel.isActive, false)

switch await taskGroup.nextResult()! {
case .success:
break
case .failure(let failure):
XCTFail("Unexpected error: \(failure)")
}
}
}

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
Expand Down Expand Up @@ -695,118 +638,6 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testQueryFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.query("SELECT version;", logger: self.logger)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testPrepareStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
_ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecuteFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

do {
let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil)
_ = try await connection.execute(statement, logger: .psqlTest).get()
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
typealias Row = (Int, String)

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
try row.decode(Row.self)
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await connection.closeGracefully()

XCTAssertEqual(channel.isActive, false)

struct TestPreparedStatement: PostgresPreparedStatement {
static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1"
typealias Row = ()

var state: String

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.state)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
()
}
}

do {
let preparedStatement = TestPreparedStatement(state: "active")
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
XCTFail("Expected to fail")
} catch let error as ChannelError {
XCTAssertEqual(error, .ioOnClosedChannel)
}
}

func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
let eventLoop = NIOAsyncTestingEventLoop()
let channel = await NIOAsyncTestingChannel(handlers: [
Expand Down

0 comments on commit cd5318a

Please sign in to comment.