Skip to content

Commit

Permalink
Implement SimpleQuery + Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Aug 25, 2024
1 parent 9f84290 commit 94eeb8a
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 31 deletions.
39 changes: 39 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,45 @@ extension PostgresConnection {
}
}

/// Run a query on the Postgres server the connection is connected to.
///
/// - Parameters:
/// - query: The simple query to run
/// - logger: The `Logger` to log into for the query
/// - file: The file, the query was started in. Used for better error reporting.
/// - line: The line, the query was started in. Used for better error reporting.
/// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result.
/// The sequence be discarded.
@discardableResult
public func simpleQuery(
_ query: String,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> PostgresRowSequence {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.id)"

let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
simpleQuery: query,
logger: logger,
promise: promise
)

self.channel.write(HandlerTask.extendedQuery(context), promise: nil)

do {
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
} catch var error as PSQLError {
error.file = file
error.line = line
// FIXME: just pass the string as a simple query, instead of acting like this is a PostgresQuery.
error.query = PostgresQuery(unsafeSQL: query)
throw error // rethrow with more metadata
}
}

/// Start listening for a channel
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ struct ConnectionStateMachine {
// --- general actions
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case sendQuery(String)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)

Expand Down Expand Up @@ -537,7 +538,7 @@ struct ConnectionStateMachine {

self.state = .readyForQuery(connectionContext)
return self.executeNextQueryFromQueue()

default:
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState)))
}
Expand Down Expand Up @@ -585,7 +586,7 @@ struct ConnectionStateMachine {
switch task {
case .extendedQuery(let queryContext):
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
case .executeStatement(_, let promise), .unnamed(_, let promise), .simpleQuery(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
Expand Down Expand Up @@ -745,7 +746,7 @@ struct ConnectionStateMachine {
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag)))
}

self.state = .modifying // avoid CoW
let action = queryState.commandCompletedReceived(commandTag)
self.state = .extendedQuery(queryState, connectionContext)
Expand Down Expand Up @@ -855,6 +856,7 @@ struct ConnectionStateMachine {
case .sendParseDescribeBindExecuteSync,
.sendParseDescribeSync,
.sendBindExecuteSync,
.sendQuery,
.succeedQuery,
.succeedPreparedStatementCreation,
.forwardRows,
Expand Down Expand Up @@ -1035,6 +1037,8 @@ extension ConnectionStateMachine {
return .sendParseDescribeBindExecuteSync(query)
case .sendBindExecuteSync(let executeStatement):
return .sendBindExecuteSync(executeStatement)
case .sendQuery(let query):
return .sendQuery(query)
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ struct ExtendedQueryStateMachine {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)

case sendQuery(String)

// --- general actions
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
Expand Down Expand Up @@ -85,6 +86,12 @@ struct ExtendedQueryStateMachine {
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
}

case .simpleQuery(let query, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendQuery(query)
}
}
}

Expand All @@ -105,7 +112,7 @@ struct ExtendedQueryStateMachine {

self.isCancelled = true
switch queryContext.query {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: .queryCancelled)

case .prepareStatement(_, _, _, let eventLoopPromise):
Expand Down Expand Up @@ -171,11 +178,19 @@ struct ExtendedQueryStateMachine {
state = .noDataMessageReceived(queryContext)
return .succeedPreparedStatementCreation(promise, with: nil)
}

case .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.noData))
}
}

mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action {
guard case .parameterDescriptionReceived(let queryContext) = self.state else {
let queryContext: ExtendedQueryContext
switch self.state {
case .messagesSent(let extendedQueryContext),
.parameterDescriptionReceived(let extendedQueryContext):
queryContext = extendedQueryContext
default:
return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription)))
}

Expand All @@ -198,7 +213,7 @@ struct ExtendedQueryStateMachine {
}

switch queryContext.query {
case .unnamed, .executeStatement:
case .unnamed, .executeStatement, .simpleQuery:
return .wait

case .prepareStatement(_, _, _, let eventLoopPromise):
Expand All @@ -219,6 +234,9 @@ struct ExtendedQueryStateMachine {

case .prepareStatement:
return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete))

case .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))
}

case .noDataMessageReceived(let queryContext):
Expand Down Expand Up @@ -258,20 +276,40 @@ struct ExtendedQueryStateMachine {
return .wait
}

case .rowDescriptionReceived(let queryContext, let columns):
switch queryContext.query {
case .simpleQuery(_, let eventLoopPromise):
// When receiving a data row, we must ensure that the data row column count
// matches the previously received row description column count.
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}

return self.avoidingStateMachineCoW { state -> Action in
var demandStateMachine = RowStreamStateMachine()
demandStateMachine.receivedRow(dataRow)
state = .streaming(columns, demandStateMachine)
let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}

case .drain(let columns):
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}
// we ignore all rows and wait for readyForQuery
return .wait

case .initialized,
.messagesSent,
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.bindCompleteReceived,
.commandComplete,
.error:
Expand All @@ -292,10 +330,36 @@ struct ExtendedQueryStateMachine {
return .succeedQuery(eventLoopPromise, with: result)
}

case .prepareStatement:
case .prepareStatement, .simpleQuery:
preconditionFailure("Invalid state: \(self.state)")
}


case .messagesSent(let context):
switch context.query {
case .simpleQuery(_, let eventLoopGroup):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
return .succeedQuery(eventLoopGroup, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
}

case .rowDescriptionReceived(let context, _):
switch context.query {
case .simpleQuery(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
}

case .streaming(_, var demandStateMachine):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
Expand All @@ -306,14 +370,12 @@ struct ExtendedQueryStateMachine {
precondition(self.isCancelled)
self.state = .commandComplete(commandTag: commandTag)
return .wait

case .initialized,
.messagesSent,
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.commandComplete,
.error:
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
Expand All @@ -323,20 +385,32 @@ struct ExtendedQueryStateMachine {
}

mutating func emptyQueryResponseReceived() -> Action {
guard case .bindCompleteReceived(let queryContext) = self.state else {
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
switch self.state {
case .bindCompleteReceived(let queryContext):
switch queryContext.query {
case .unnamed(_, let eventLoopPromise),
.executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

switch queryContext.query {
case .unnamed(_, let eventLoopPromise),
.executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
case .prepareStatement, .simpleQuery:
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}

case .prepareStatement(_, _, _, _):
case .messagesSent(let queryContext):
switch queryContext.query {
case .simpleQuery(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}
case .unnamed, .executeStatement, .prepareStatement:
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
default:
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
}
Expand Down Expand Up @@ -497,7 +571,7 @@ struct ExtendedQueryStateMachine {
return .evaluateErrorAtConnectionLevel(error)
} else {
switch context.query {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: error)
case .prepareStatement(_, _, _, let eventLoopPromise):
return .failPreparedStatementCreation(eventLoopPromise, with: error)
Expand Down Expand Up @@ -536,7 +610,7 @@ struct ExtendedQueryStateMachine {
switch context.query {
case .prepareStatement:
return true
case .unnamed, .executeStatement:
case .unnamed, .executeStatement, .simpleQuery:
return false
}

Expand Down
15 changes: 14 additions & 1 deletion Sources/PostgresNIO/New/PSQLTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum PSQLTask {
eventLoopPromise.fail(error)
case .prepareStatement(_, _, _, let eventLoopPromise):
eventLoopPromise.fail(error)
case .simpleQuery(_, let eventLoopPromise):
eventLoopPromise.fail(error)
}

case .closeCommand(let closeCommandContext):
Expand All @@ -31,16 +33,18 @@ enum PSQLTask {
}
}

// FIXME: Either rename all these `ExtendedQuery`s to just like `Query` or pull out `simpleQuery`
final class ExtendedQueryContext {
enum Query {
case unnamed(PostgresQuery, EventLoopPromise<PSQLRowStream>)
case executeStatement(PSQLExecuteStatement, EventLoopPromise<PSQLRowStream>)
case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise<RowDescription?>)
case simpleQuery(String, EventLoopPromise<PSQLRowStream>)
}

let query: Query
let logger: Logger

init(
query: PostgresQuery,
logger: Logger,
Expand Down Expand Up @@ -69,6 +73,15 @@ final class ExtendedQueryContext {
self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise)
self.logger = logger
}

init(
simpleQuery: String,
logger: Logger,
promise: EventLoopPromise<PSQLRowStream>
) {
self.query = .simpleQuery(simpleQuery, promise)
self.logger = logger
}
}

final class PreparedStatementContext: Sendable {
Expand Down
Loading

0 comments on commit 94eeb8a

Please sign in to comment.