From 738d0b7c90f113858fe22571a294ddfc8996a625 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 21 Feb 2024 13:04:37 +0100 Subject: [PATCH] Fix prepared statements --- .../Connection/PostgresConnection.swift | 9 ++- .../ConnectionStateMachine.swift | 8 +- .../ExtendedQueryStateMachine.swift | 14 ++-- Sources/PostgresNIO/New/PSQLTask.swift | 14 +++- .../New/PostgresChannelHandler.swift | 10 ++- .../PostgresNIO/New/PreparedStatement.swift | 23 +++++- Tests/IntegrationTests/AsyncTests.swift | 81 +++++++++++++++++++ .../PrepareStatementStateMachineTests.swift | 12 +-- .../PreparedStatementStateMachineTests.swift | 1 + .../ConnectionAction+TestUtils.swift | 4 +- .../New/PostgresConnectionTests.swift | 8 +- 11 files changed, 150 insertions(+), 34 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index f79a5555..eb9dc791 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -234,6 +234,7 @@ public final class PostgresConnection: @unchecked Sendable { let context = ExtendedQueryContext( name: name, query: query, + bindingDataTypes: [], logger: logger, promise: promise ) @@ -472,9 +473,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) @@ -493,10 +495,10 @@ extension PostgresConnection { ) throw error // rethrow with more metadata } - } /// Execute a prepared statement, taking care of the preparation when necessary + @_disfavoredOverload public func execute( _ preparedStatement: Statement, logger: Logger, @@ -506,9 +508,10 @@ extension PostgresConnection { let bindings = try preparedStatement.makeBindings() let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( - name: String(reflecting: Statement.self), + name: Statement.name, sql: Statement.sql, bindings: bindings, + bindingDataTypes: Statement.bindingDataTypes, logger: logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 8c3252de..9d264bcc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -97,7 +97,7 @@ struct ConnectionStateMachine { case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) @@ -587,7 +587,7 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } case .closeCommand(let closeContext): @@ -1057,8 +1057,8 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - case .sendParseDescribeSync(name: let name, query: let query): - return .sendParseDescribeSync(name: name, query: query) + case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes): + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) case .succeedPreparedStatementCreation(let promise, with: let rowDescription): return .succeedPreparedStatementCreation(promise, with: rowDescription) case .failPreparedStatementCreation(let promise, with: let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 3a84031b..78f0d202 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -26,7 +26,7 @@ struct ExtendedQueryStateMachine { enum Action { case sendParseDescribeBindExecuteSync(PostgresQuery) - case sendParseDescribeSync(name: String, query: String) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case sendBindExecuteSync(PSQLExecuteStatement) // --- general actions @@ -79,10 +79,10 @@ struct ExtendedQueryStateMachine { return .sendBindExecuteSync(prepared) } - case .prepareStatement(let name, let query, _): + case .prepareStatement(let name, let query, let bindingDataTypes, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) - return .sendParseDescribeSync(name: name, query: query) + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) } } } @@ -107,7 +107,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } @@ -165,7 +165,7 @@ struct ExtendedQueryStateMachine { return .wait } - case .prepareStatement(_, _, let promise): + case .prepareStatement(_, _, _, let promise): return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .succeedPreparedStatementCreation(promise, with: nil) @@ -200,7 +200,7 @@ struct ExtendedQueryStateMachine { case .unnamed, .executeStatement: return .wait - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription) } } @@ -477,7 +477,7 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6308a5b3..363f9394 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -21,7 +21,7 @@ enum PSQLTask { eventLoopPromise.fail(error) case .executeStatement(_, let eventLoopPromise): eventLoopPromise.fail(error) - case .prepareStatement(_, _, let eventLoopPromise): + case .prepareStatement(_, _, _, let eventLoopPromise): eventLoopPromise.fail(error) } @@ -35,7 +35,7 @@ final class ExtendedQueryContext { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) - case prepareStatement(name: String, query: String, EventLoopPromise) + case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } let query: Query @@ -62,10 +62,11 @@ final class ExtendedQueryContext { init( name: String, query: String, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { - self.query = .prepareStatement(name: name, query: query, promise) + self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise) self.logger = logger } } @@ -73,6 +74,7 @@ final class ExtendedQueryContext { final class PreparedStatementContext: Sendable { let name: String let sql: String + let bindingDataTypes: [PostgresDataType] let bindings: PostgresBindings let logger: Logger let promise: EventLoopPromise @@ -81,12 +83,18 @@ final class PreparedStatementContext: Sendable { name: String, sql: String, bindings: PostgresBindings, + bindingDataTypes: [PostgresDataType], logger: Logger, promise: EventLoopPromise ) { self.name = name self.sql = sql self.bindings = bindings + if bindingDataTypes.isEmpty { + self.bindingDataTypes = bindings.metadata.map(\.dataType) + } else { + self.bindingDataTypes = bindingDataTypes + } self.logger = logger self.promise = promise } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 54ae0fc9..32dea4a5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -345,8 +345,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: context.fireChannelInactive() - case .sendParseDescribeSync(let name, let query): - self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context) + case .sendParseDescribeSync(let name, let query, let bindingDataTypes): + self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context) case .sendBindExecuteSync(let executeStatement): self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): @@ -489,13 +489,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func sendParseDecribeAndSyncMessage( + private func sendParseDescribeAndSyncMessage( statementName: String, query: String, + bindingDataTypes: [PostgresDataType], context: ChannelHandlerContext ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") - self.encoder.parse(preparedStatementName: statementName, query: query, parameters: []) + self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes) self.encoder.describePreparedStatement(statementName) self.encoder.sync() context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) @@ -724,6 +725,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return .extendedQuery(.init( name: preparedStatement.name, query: preparedStatement.sql, + bindingDataTypes: preparedStatement.bindingDataTypes, logger: preparedStatement.logger, promise: promise )) diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift index 1e0b5d5a..21165388 100644 --- a/Sources/PostgresNIO/New/PreparedStatement.swift +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -26,15 +26,36 @@ /// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, /// which will take care of preparing the statement on the server side and executing it. public protocol PostgresPreparedStatement: Sendable { + /// The prepared statements name. + /// + /// > Note: There is a default implementation that returns the implementor's name. + static var name: String { get } + /// The type rows returned by the statement will be decoded into associatedtype Row /// The SQL statement to prepare on the database server. static var sql: String { get } - /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + /// The postgres data types of the values that are bind when this statement is executed. + /// + /// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned + /// from ``PostgresPreparedStatement/makeBindings()``. + /// + /// > Note: There is a default implementation that returns an empty array, which will lead to + /// automatic inference. + static var bindingDataTypes: [PostgresDataType] { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement. + /// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``. func makeBindings() throws -> PostgresBindings /// Decode a row returned by the database into an instance of `Row` func decodeRow(_ row: PostgresRow) throws -> Row } + +extension PostgresPreparedStatement { + public static var name: String { String(reflecting: self) } + + public static var bindingDataTypes: [PostgresDataType] { [] } +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 91b5656c..75e5b6ba 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -358,6 +358,87 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } } + + static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable" + func testPreparedStatementWithIntegerBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(uuid)" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 6a08afeb..547f5cdf 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -12,11 +12,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"SELECT id FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -38,11 +38,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -60,11 +60,11 @@ class PrepareStatementStateMachineTests: XCTestCase { let name = "haha" let query = #"DELETE FROM users WHERE id = $1 "# let prepareStatementContext = ExtendedQueryContext( - name: name, query: query, logger: .psqlTest, promise: promise + name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query)) + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index ab77a57c..f6c1ddf7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -152,6 +152,7 @@ class PreparedStatementStateMachineTests: XCTestCase { name: "test", sql: "INSERT INTO test_table (column1) VALUES (1)", bindings: PostgresBindings(), + bindingDataTypes: [], logger: .psqlTest, promise: promise ) diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index d20032a8..9a1224d8 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -36,8 +36,8 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext - case (.sendParseDescribeSync(let lhsName, let lhsQuery), .sendParseDescribeSync(let rhsName, let rhsQuery)): - return lhsName == rhsName && lhsQuery == rhsQuery + case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes)): + return lhsName == rhsName && lhsQuery == rhsQuery && lhsDataTypes == rhsDataTypes case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription case (.fireChannelInactive, .fireChannelInactive): diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 82baf914..a773cf2c 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -337,7 +337,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -393,7 +393,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -487,7 +487,7 @@ class PostgresConnectionTests: XCTestCase { // The channel deduplicates prepare requests, we're going to see only one of them let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") } @@ -555,7 +555,7 @@ class PostgresConnectionTests: XCTestCase { let prepareRequest = try await channel.waitForPrepareRequest() XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + XCTAssertEqual(prepareRequest.parse.parameters.first, .text) guard case .preparedStatement(let name) = prepareRequest.describe else { fatalError("Describe should contain a prepared statement") }