Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Crash fix: Multiple bad messages could trigger reentrancy issue #379

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ struct ConnectionStateMachine {
.forwardStreamComplete,
.wait,
.read:
preconditionFailure("Expecting only failure actions if an error happened")
preconditionFailure("Invalid state: \(self.state)")
case .evaluateErrorAtConnectionLevel:
return .closeConnectionAndCleanup(cleanupContext)
case .failQuery(let queryContext, with: let error):
Expand All @@ -951,7 +951,7 @@ struct ConnectionStateMachine {
.succeedPreparedStatementCreation,
.read,
.wait:
preconditionFailure("Expecting only failure actions if an error happened")
preconditionFailure("Invalid state: \(self.state)")
case .failPreparedStatementCreation(let preparedStatementContext, with: let error):
return .failPreparedStatementCreation(preparedStatementContext, with: error, cleanupContext: cleanupContext)
}
Expand All @@ -970,22 +970,20 @@ struct ConnectionStateMachine {
.succeedClose,
.read,
.wait:
preconditionFailure("Expecting only failure actions if an error happened")
preconditionFailure("Invalid state: \(self.state)")
case .failClose(let closeCommandContext, with: let error):
return .failClose(closeCommandContext, with: error, cleanupContext: cleanupContext)
}
case .error:
// TBD: this is an interesting case. why would this case happen?
let cleanupContext = self.setErrorAndCreateCleanupContext(error)
return .closeConnectionAndCleanup(cleanupContext)

case .closing:
let cleanupContext = self.setErrorAndCreateCleanupContext(error)
return .closeConnectionAndCleanup(cleanupContext)
case .closed:
preconditionFailure("How can an error occur if the connection is already closed?")
case .error, .closing, .closed:
// We might run into this case because of reentrancy. For example: After we received an
// backend unexpected message, that we read of the wire, we bring this connection into
// the error state and will try to close the connection. However the server might have
// send further follow up messages. In those cases we will run into this method again
// and again. We should just ignore those events.
return .wait

case .modifying:
preconditionFailure("Invalid state")
preconditionFailure("Invalid state: \(self.state)")
}
}

Expand Down
107 changes: 61 additions & 46 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
}

func channelInactive(context: ChannelHandlerContext) {
do {
try self.decoder.finishProcessing(seenEOF: true) { message in
self.handleMessage(message, context: context)
}
} catch let error as PostgresMessageDecodingError {
let action = self.state.errorHappened(.messageDecodingFailure(error))
self.run(action, with: context)
} catch {
preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.")
}

self.logger.trace("Channel inactive.")
let action = self.state.closed()
self.run(action, with: context)
Expand All @@ -100,51 +111,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {

do {
try self.decoder.process(buffer: buffer) { message in
self.logger.trace("Backend message received", metadata: [.message: "\(message)"])
let action: ConnectionStateMachine.ConnectionAction

switch message {
case .authentication(let authentication):
action = self.state.authenticationMessageReceived(authentication)
case .backendKeyData(let keyData):
action = self.state.backendKeyDataReceived(keyData)
case .bindComplete:
action = self.state.bindCompleteReceived()
case .closeComplete:
action = self.state.closeCompletedReceived()
case .commandComplete(let commandTag):
action = self.state.commandCompletedReceived(commandTag)
case .dataRow(let dataRow):
action = self.state.dataRowReceived(dataRow)
case .emptyQueryResponse:
action = self.state.emptyQueryResponseReceived()
case .error(let errorResponse):
action = self.state.errorReceived(errorResponse)
case .noData:
action = self.state.noDataReceived()
case .notice(let noticeResponse):
action = self.state.noticeReceived(noticeResponse)
case .notification(let notification):
action = self.state.notificationReceived(notification)
case .parameterDescription(let parameterDescription):
action = self.state.parameterDescriptionReceived(parameterDescription)
case .parameterStatus(let parameterStatus):
action = self.state.parameterStatusReceived(parameterStatus)
case .parseComplete:
action = self.state.parseCompleteReceived()
case .portalSuspended:
action = self.state.portalSuspendedReceived()
case .readyForQuery(let transactionState):
action = self.state.readyForQueryReceived(transactionState)
case .rowDescription(let rowDescription):
action = self.state.rowDescriptionReceived(rowDescription)
case .sslSupported:
action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes)
case .sslUnsupported:
action = self.state.sslUnsupportedReceived()
}

self.run(action, with: context)
self.handleMessage(message, context: context)
}
} catch let error as PostgresMessageDecodingError {
let action = self.state.errorHappened(.messageDecodingFailure(error))
Expand All @@ -153,7 +120,55 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.")
}
}


private func handleMessage(_ message: PostgresBackendMessage, context: ChannelHandlerContext) {
self.logger.trace("Backend message received", metadata: [.message: "\(message)"])
let action: ConnectionStateMachine.ConnectionAction

switch message {
case .authentication(let authentication):
action = self.state.authenticationMessageReceived(authentication)
case .backendKeyData(let keyData):
action = self.state.backendKeyDataReceived(keyData)
case .bindComplete:
action = self.state.bindCompleteReceived()
case .closeComplete:
action = self.state.closeCompletedReceived()
case .commandComplete(let commandTag):
action = self.state.commandCompletedReceived(commandTag)
case .dataRow(let dataRow):
action = self.state.dataRowReceived(dataRow)
case .emptyQueryResponse:
action = self.state.emptyQueryResponseReceived()
case .error(let errorResponse):
action = self.state.errorReceived(errorResponse)
case .noData:
action = self.state.noDataReceived()
case .notice(let noticeResponse):
action = self.state.noticeReceived(noticeResponse)
case .notification(let notification):
action = self.state.notificationReceived(notification)
case .parameterDescription(let parameterDescription):
action = self.state.parameterDescriptionReceived(parameterDescription)
case .parameterStatus(let parameterStatus):
action = self.state.parameterStatusReceived(parameterStatus)
case .parseComplete:
action = self.state.parseCompleteReceived()
case .portalSuspended:
action = self.state.portalSuspendedReceived()
case .readyForQuery(let transactionState):
action = self.state.readyForQueryReceived(transactionState)
case .rowDescription(let rowDescription):
action = self.state.rowDescriptionReceived(rowDescription)
case .sslSupported:
action = self.state.sslSupportedReceived(unprocessedBytes: self.decoder.unprocessedBytes)
case .sslUnsupported:
action = self.state.sslUnsupportedReceived()
}

self.run(action, with: context)
}

func channelReadComplete(context: ChannelHandlerContext) {
let action = self.state.channelReadComplete()
self.run(action, with: context)
Expand Down
39 changes: 38 additions & 1 deletion Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,44 @@ class PostgresChannelHandlerTests: XCTestCase {

XCTAssertEqual(message, .password(.init(value: password)))
}


func testHandlerThatSendsMultipleWrongMessages() {
let config = self.testConnectionConfiguration()
let handler = PostgresChannelHandler(configuration: config, configureSSLCallback: nil)
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
handler
])

var maybeMessage: PostgresFrontendMessage?
XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil))
XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self))
guard case .startup(let startup) = maybeMessage else {
return XCTFail("Unexpected message")
}

XCTAssertEqual(startup.parameters.user, config.username)
XCTAssertEqual(startup.parameters.database, config.database)
XCTAssertEqual(startup.parameters.options, nil)
XCTAssertEqual(startup.parameters.replication, .false)

var buffer = ByteBuffer()
buffer.writeMultipleIntegers(UInt8(ascii: "R"), UInt32(8), Int32(0))
buffer.writeMultipleIntegers(UInt8(ascii: "K"), UInt32(12), Int32(1234), Int32(5678))
buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I"))
XCTAssertNoThrow(try embedded.writeInbound(buffer))
XCTAssertTrue(embedded.isActive)

buffer.clear()
buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I"))
buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I"))
buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I"))
buffer.writeMultipleIntegers(UInt8(ascii: "Z"), UInt32(5), UInt8(ascii: "I"))

XCTAssertThrowsError(try embedded.writeInbound(buffer))
XCTAssertFalse(embedded.isActive)
}

// MARK: Helpers

func testConnectionConfiguration(
Expand Down