Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run queries directly on PostgresClient #456

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ final class PSQLRowStream: @unchecked Sendable {
case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource)
case consumed(Result<String, Error>)
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource)
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ())
}

internal let rowDescription: [RowDescription.Column]
Expand Down Expand Up @@ -75,7 +75,7 @@ final class PSQLRowStream: @unchecked Sendable {

// MARK: Async Sequence

func asyncSequence() -> PostgresRowSequence {
func asyncSequence(onFinish: @escaping @Sendable () -> () = {}) -> PostgresRowSequence {
self.eventLoop.preconditionInEventLoop()

guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
Expand All @@ -95,13 +95,13 @@ final class PSQLRowStream: @unchecked Sendable {
switch bufferState {
case .streaming(let bufferedRows, let dataSource):
let yieldResult = source.yield(contentsOf: bufferedRows)
self.downstreamState = .asyncSequence(source, dataSource)

self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish)
self.executeActionBasedOnYieldResult(yieldResult, source: dataSource)

case .finished(let buffer, let commandTag):
_ = source.yield(contentsOf: buffer)
source.finish()
onFinish()
self.downstreamState = .consumed(.success(commandTag))

case .failure(let error):
Expand Down Expand Up @@ -130,7 +130,7 @@ final class PSQLRowStream: @unchecked Sendable {
case .consumed:
break

case .asyncSequence(_, let dataSource):
case .asyncSequence(_, let dataSource, _):
dataSource.request(for: self)
}
}
Expand All @@ -147,9 +147,10 @@ final class PSQLRowStream: @unchecked Sendable {

private func cancel0() {
switch self.downstreamState {
case .asyncSequence(_, let dataSource):
case .asyncSequence(_, let dataSource, let onFinish):
self.downstreamState = .consumed(.failure(CancellationError()))
dataSource.cancel(for: self)
onFinish()

case .consumed:
return
Expand Down Expand Up @@ -320,7 +321,7 @@ final class PSQLRowStream: @unchecked Sendable {
// immediately request more
dataSource.request(for: self)

case .asyncSequence(let consumer, let source):
case .asyncSequence(let consumer, let source, _):
let yieldResult = consumer.yield(contentsOf: newRows)
self.executeActionBasedOnYieldResult(yieldResult, source: source)

Expand Down Expand Up @@ -359,10 +360,11 @@ final class PSQLRowStream: @unchecked Sendable {
self.downstreamState = .consumed(.success(commandTag))
promise.succeed(rows)

case .asyncSequence(let source, _):
source.finish()
case .asyncSequence(let source, _, let onFinish):
self.downstreamState = .consumed(.success(commandTag))

source.finish()
onFinish()

case .consumed:
break
}
Expand All @@ -384,9 +386,10 @@ final class PSQLRowStream: @unchecked Sendable {
self.downstreamState = .consumed(.failure(error))
promise.fail(error)

case .asyncSequence(let consumer, _):
consumer.finish(error)
case .asyncSequence(let consumer, _, let onFinish):
self.downstreamState = .consumed(.failure(error))
consumer.finish(error)
onFinish()

case .consumed:
break
Expand Down
52 changes: 52 additions & 0 deletions Sources/PostgresNIO/Pool/PostgresClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,58 @@ public final class PostgresClient: Sendable {
return try await closure(connection)
}

/// Run a query on the Postgres server the client is connected to.
///
/// - Parameters:
/// - query: The ``PostgresQuery`` to run
/// - logger: The `Logger` to log into for the query
/// - file: The file, the query was started in. Used for better error reporting.
/// - line: The line, the query was started in. Used for better error reporting.
/// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result.
/// The sequence be discarded.
@discardableResult
public func query(
_ query: PostgresQuery,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> PostgresRowSequence {
do {
guard query.binds.count <= Int(UInt16.max) else {
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
}

let connection = try await self.leaseConnection()

var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(connection.id)"

let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise
)

connection.channel.write(HandlerTask.extendedQuery(context), promise: nil)

promise.futureResult.whenFailure { _ in
self.pool.releaseConnection(connection)
}

return try await promise.futureResult.map {
$0.asyncSequence(onFinish: {
self.pool.releaseConnection(connection)
})
}.get()
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = query
throw error // rethrow with more metadata
}
}

/// The client's run method. Users must call this function in order to start the client's background task processing
/// like creating and destroying connections and running timers.
///
Expand Down
37 changes: 37 additions & 0 deletions Tests/IntegrationTests/PostgresClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,43 @@ final class PostgresClientTests: XCTestCase {
taskGroup.cancelAll()
}
}

func testQueryDirectly() async throws {
var mlogger = Logger(label: "test")
mlogger.logLevel = .debug
let logger = mlogger
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
self.addTeardownBlock {
try await eventLoopGroup.shutdownGracefully()
}

let clientConfig = PostgresClient.Configuration.makeTestConfiguration()
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)

await withThrowingTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
await client.run()
}

for i in 0..<10000 {
taskGroup.addTask {
do {
try await client.query("SELECT 1", logger: logger)
logger.info("Success", metadata: ["run": "\(i)"])
} catch {
XCTFail("Unexpected error: \(error)")
}
}
}

for _ in 0..<10000 {
_ = await taskGroup.nextResult()!
}

taskGroup.cancelAll()
}
}

}

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
Expand Down
Loading