Skip to content

Commit

Permalink
Make forward progress when Query is cancelled (#261)
Browse files Browse the repository at this point in the history
Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>
  • Loading branch information
fabianfett and gwynne committed Mar 18, 2022
1 parent ab624e4 commit c1683ba
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,15 @@ struct ConnectionStateMachine {
// MARK: Consumer

mutating func cancelQueryStream() -> ConnectionAction {
preconditionFailure("Unimplemented")
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
preconditionFailure("Tried to cancel stream without active query")
}

return self.avoidingStateMachineCoW { machine -> ConnectionAction in
let action = queryState.cancel()
machine.state = .extendedQuery(queryState, connectionContext)
return machine.modify(with: action)
}
}

mutating func requestQueryRows() -> ConnectionAction {
Expand Down Expand Up @@ -1074,6 +1082,8 @@ extension ConnectionStateMachine {
return true
case .failedToAddSSLHandler:
return true
case .queryCancelled:
return false
case .server(let message):
guard let sqlState = message.fields[.sqlState] else {
// any error message that doesn't have a sql state field, is unexpected by default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

struct ExtendedQueryStateMachine {

enum State {
private enum State {
case initialized(ExtendedQueryContext)
case parseDescribeBindExecuteSyncSent(ExtendedQueryContext)

Expand All @@ -15,6 +15,8 @@ struct ExtendedQueryStateMachine {
/// used after receiving a `bindComplete` message
case bindCompleteReceived(ExtendedQueryContext)
case streaming([RowDescription.Column], RowStreamStateMachine)
/// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP
case drain([RowDescription.Column])

case commandComplete(commandTag: String)
case error(PSQLError)
Expand All @@ -41,9 +43,11 @@ struct ExtendedQueryStateMachine {
case wait
}

var state: State
private var state: State
private var isCancelled: Bool

init(queryContext: ExtendedQueryContext) {
self.isCancelled = false
self.state = .initialized(queryContext)
}

Expand Down Expand Up @@ -71,6 +75,44 @@ struct ExtendedQueryStateMachine {
}
}
}

mutating func cancel() -> Action {
switch self.state {
case .initialized:
preconditionFailure("Start must be called immediatly after the query was created")

case .parseDescribeBindExecuteSyncSent(let queryContext),
.parseCompleteReceived(let queryContext),
.parameterDescriptionReceived(let queryContext),
.rowDescriptionReceived(let queryContext, _),
.noDataMessageReceived(let queryContext),
.bindCompleteReceived(let queryContext):
guard !self.isCancelled else {
return .wait
}

self.isCancelled = true
return .failQuery(queryContext, with: .queryCancelled)

case .streaming(let columns, var streamStateMachine):
precondition(!self.isCancelled)
self.isCancelled = true
self.state = .drain(columns)
switch streamStateMachine.fail() {
case .wait:
return .forwardStreamError(.queryCancelled, read: false)
case .read:
return .forwardStreamError(.queryCancelled, read: true)
}

case .commandComplete, .error, .drain:
// the stream has already finished.
return .wait

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func parseCompletedReceived() -> Action {
guard case .parseDescribeBindExecuteSyncSent(let queryContext) = self.state else {
Expand Down Expand Up @@ -147,9 +189,11 @@ struct ExtendedQueryStateMachine {
.parameterDescriptionReceived,
.bindCompleteReceived,
.streaming,
.drain,
.commandComplete,
.error:
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))

case .modifying:
preconditionFailure("Invalid state")
}
Expand All @@ -169,6 +213,13 @@ struct ExtendedQueryStateMachine {
state = .streaming(columns, demandStateMachine)
return .wait
}

case .drain(let columns):
guard dataRow.columnCount == columns.count else {
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
}
// we ignore all rows and wait for readyForQuery
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
Expand Down Expand Up @@ -198,6 +249,11 @@ struct ExtendedQueryStateMachine {
state = .commandComplete(commandTag: commandTag)
return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag)
}

case .drain:
precondition(self.isCancelled)
self.state = .commandComplete(commandTag: commandTag)
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
Expand Down Expand Up @@ -229,7 +285,7 @@ struct ExtendedQueryStateMachine {
return self.setAndFireError(error)
case .rowDescriptionReceived, .noDataMessageReceived:
return self.setAndFireError(error)
case .streaming:
case .streaming, .drain:
return self.setAndFireError(error)
case .commandComplete:
return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage)))
Expand Down Expand Up @@ -269,6 +325,9 @@ struct ExtendedQueryStateMachine {
}
}

case .drain:
return .wait

case .initialized,
.parseDescribeBindExecuteSyncSent,
.parseCompleteReceived,
Expand All @@ -291,6 +350,7 @@ struct ExtendedQueryStateMachine {
switch self.state {
case .initialized,
.commandComplete,
.drain,
.error,
.parseDescribeBindExecuteSyncSent,
.parseCompleteReceived,
Expand Down Expand Up @@ -327,6 +387,7 @@ struct ExtendedQueryStateMachine {
.bindCompleteReceived:
return .read
case .streaming(let columns, var demandStateMachine):
precondition(!self.isCancelled)
return self.avoidingStateMachineCoW { state -> Action in
let action = demandStateMachine.read()
state = .streaming(columns, demandStateMachine)
Expand All @@ -339,6 +400,7 @@ struct ExtendedQueryStateMachine {
}
case .initialized,
.commandComplete,
.drain,
.error:
// we already have the complete stream received, now we are waiting for a
// `readyForQuery` package. To receive this we need to read!
Expand All @@ -361,11 +423,20 @@ struct ExtendedQueryStateMachine {
.bindCompleteReceived(let context):
self.state = .error(error)
return .failQuery(context, with: error)
case .streaming:

case .drain:
self.state = .error(error)
return .forwardStreamError(error, read: false)

case .streaming(_, var streamStateMachine):
self.state = .error(error)
switch streamStateMachine.fail() {
case .wait:
return .forwardStreamError(error, read: false)
case .read:
return .forwardStreamError(error, read: true)
}

case .commandComplete, .error:
preconditionFailure("""
This state must not be reached. If the query `.isComplete`, the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct RowStreamStateMachine {
/// preserved for performance reasons.
case waitingForDemand([DataRow])

case failed

case modifying
}

Expand Down Expand Up @@ -63,6 +65,11 @@ struct RowStreamStateMachine {
buffer.append(newRow)
self.state = .waitingForReadOrDemand(buffer)

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -86,6 +93,11 @@ struct RowStreamStateMachine {
.waitingForReadOrDemand:
preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)")

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -111,6 +123,11 @@ struct RowStreamStateMachine {
// the next `channelReadComplete` we will forward all buffered data
return .wait

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -136,6 +153,11 @@ struct RowStreamStateMachine {
// from the consumer
return .wait

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand All @@ -158,6 +180,33 @@ struct RowStreamStateMachine {
// receive a call to `end()`, when we don't expect it here.
return buffer

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
}

mutating func fail() -> Action {
switch self.state {
case .waitingForRows,
.waitingForReadOrDemand,
.waitingForRead:
self.state = .failed
return .wait

case .waitingForDemand:
self.state = .failed
return .read

case .failed:
// Once the row stream state machine is marked as failed, no further events must be
// forwarded to it.
preconditionFailure("Invalid state: \(self.state)")

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand Down
7 changes: 6 additions & 1 deletion Sources/PostgresNIO/New/PSQLError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ struct PSQLError: Error {
case unsupportedAuthMechanism(PSQLAuthScheme)
case authMechanismRequiresPassword
case saslError(underlyingError: Error)


case queryCancelled
case tooManyParameters
case connectionQuiescing
case connectionClosed
Expand Down Expand Up @@ -58,6 +59,10 @@ struct PSQLError: Error {
static func sasl(underlying: Error) -> PSQLError {
Self.init(.saslError(underlyingError: underlying))
}

static var queryCancelled: PSQLError {
Self.init(.queryCancelled)
}

static var tooManyParameters: PSQLError {
Self.init(.tooManyParameters)
Expand Down
19 changes: 12 additions & 7 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
/// A `ChannelHandlerContext` to be used for non channel related events. (for example: More rows needed).
///
/// The context is captured in `handlerAdded` and released` in `handlerRemoved`
private var handlerContext: ChannelHandlerContext!
private var handlerContext: ChannelHandlerContext?
private var rowStream: PSQLRowStream?
private var decoder: NIOSingleStepByteToMessageProcessor<PostgresBackendMessageDecoder>
private var encoder: BufferedMessageEncoder!
Expand Down Expand Up @@ -262,7 +262,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {

case .forwardStreamComplete(let buffer, let commandTag):
guard let rowStream = self.rowStream else {
preconditionFailure("Expected to have a row stream here.")
// if the stream was cancelled we don't have it here anymore.
return
}
self.rowStream = nil
if buffer.count > 0 {
Expand Down Expand Up @@ -499,18 +500,20 @@ final class PostgresChannelHandler: ChannelDuplexHandler {

extension PostgresChannelHandler: PSQLRowsDataSource {
func request(for stream: PSQLRowStream) {
guard self.rowStream === stream else {
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
return
}
let action = self.state.requestQueryRows()
self.run(action, with: self.handlerContext!)
self.run(action, with: handlerContext)
}

func cancel(for stream: PSQLRowStream) {
guard self.rowStream === stream else {
guard self.rowStream === stream, let handlerContext = self.handlerContext else {
return
}
// we ignore this right now :)
let action = self.state.cancelQueryStream()
self.run(action, with: handlerContext)
}
}

Expand All @@ -519,7 +522,8 @@ extension PostgresConnection.Configuration.Authentication {
AuthContext(
username: self.username,
password: self.password,
database: self.database)
database: self.database
)
}
}

Expand All @@ -529,7 +533,8 @@ extension AuthContext {
user: self.username,
database: self.database,
options: nil,
replication: .false)
replication: .false
)
}
}

Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/Postgres+PSQLCompat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import NIOCore
extension PSQLError {
func toPostgresError() -> Error {
switch self.base {
case .queryCancelled:
return self
case .server(let errorMessage):
var fields = [PostgresMessage.Error.Field: String]()
fields.reserveCapacity(errorMessage.fields.count)
Expand Down
Loading

0 comments on commit c1683ba

Please sign in to comment.