From 4954cff5a4040d979334665e97e53b2d3869afea Mon Sep 17 00:00:00 2001
From: Georg Semmler
Date: Fri, 14 Jul 2023 14:02:57 +0200
Subject: [PATCH] Introduce an `AsyncConnectionWrapper` type

This type turns a `diesel_async::AsyncConnection` into a
`diesel::Connection`. I see the following use-cases for this:

* Having a pure Rust sync diesel connection implementation for postgres
  and mysql can simplify the setup of new diesel projects
* Allowing projects depending on `diesel_async` to use
  `diesel_migrations` without depending on `libpq`/`libmysqlclient`

This change requires restructuring the implementation of
`AsyncPgConnection` a bit so that the returned future is `Send`
independently of whether the query parameters are `Send`. This is
possible by serialising the bind parameters before actually
constructing the future. It also refactors the `TransactionManager`
implementation to share more code with diesel itself.
---
 CHANGELOG.md                      |   6 +-
 Cargo.toml                        |  11 +-
 src/async_connection_wrapper.rs   | 313 ++++++++++++++++++++++
 src/doctest_setup.rs              |  64 ++---
 src/lib.rs                        |  23 +-
 src/mysql/mod.rs                  |  73 +++--
 src/pg/mod.rs                     | 306 ++++++++++-----------
 src/pg/transaction_builder.rs     |   5 +-
 src/pooled_connection/bb8.rs      |   3 +-
 src/pooled_connection/deadpool.rs |   3 +-
 src/pooled_connection/mobc.rs     |   3 +-
 src/pooled_connection/mod.rs      |  22 +-
 src/run_query_dsl/mod.rs          |  20 ++
 src/stmt_cache.rs                 |  34 +--
 src/transaction_manager.rs        | 430 ++++++++++++++++--------------
 tests/lib.rs                      |   6 +-
 tests/sync_wrapper.rs             |  26 ++
 17 files changed, 884 insertions(+), 464 deletions(-)
 create mode 100644 src/async_connection_wrapper.rs
 create mode 100644 tests/sync_wrapper.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c336bc3..84d4776 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@ All user visible changes to this project will be documented in this file.
 This project adheres to [Semantic Versioning](http://semver.org/), as described
 for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md)
 
+## Unreleased
+
+* Add an `AsyncConnectionWrapper` type to turn a `diesel_async::AsyncConnection` into a `diesel::Connection`. This might be used to execute migrations via `diesel_migrations`.
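+
+  A minimal sketch of running embedded migrations through this wrapper
+  (the `MIGRATIONS` constant, the `run_migrations` helper and the
+  `migrations/` directory are illustrative assumptions, not part of this
+  crate; requires the `async-connection-wrapper` and `postgres` features
+  plus `diesel_migrations`):
+
+  ```rust
+  use diesel::prelude::*;
+  use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
+  use diesel_async::AsyncPgConnection;
+  use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
+
+  const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
+
+  fn run_migrations(database_url: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+      // `AsyncConnectionWrapper` exposes a blocking `diesel::Connection`,
+      // which is what `MigrationHarness` expects.
+      let mut conn = AsyncConnectionWrapper::<AsyncPgConnection>::establish(database_url)?;
+      conn.run_pending_migrations(MIGRATIONS)?;
+      Ok(())
+  }
+  ```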
+
 ## [0.3.2] - 2023-07-24
 
 * Fix `TinyInt` serialization
@@ -52,5 +56,3 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
 [0.2.1]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.2.1
 [0.2.2]: https://github.com/weiznich/diesel_async/compare/v0.2.1...v0.2.2
 [0.3.0]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.3.0
-[0.3.1]: https://github.com/weiznich/diesel_async/compare/v0.3.0...v0.3.1
-[0.3.2]: https://github.com/weiznich/diesel_async/compare/v0.3.1...v0.3.2
diff --git a/Cargo.toml b/Cargo.toml
index 3732630..14dca0b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,11 +13,11 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
 rust-version = "1.65.0"
 
 [dependencies]
-diesel = { version = "~2.1.0", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
+diesel = { version = "~2.1.1", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
 async-trait = "0.1.66"
 futures-channel = { version = "0.3.17", default-features = false, features = ["std", "sink"], optional = true }
 futures-util = { version = "0.3.17", default-features = false, features = ["std", "sink"] }
-tokio-postgres = { version = "0.7.2", optional = true}
+tokio-postgres = { version = "0.7.10", optional = true}
 tokio = { version = "1.26", optional = true}
 mysql_async = { version = ">=0.30.0,<0.33", optional = true}
 mysql_common = {version = ">=0.29.0,<0.31.0", optional = true}
@@ -31,12 +31,14 @@ scoped-futures = {version = "0.1", features = ["std"]}
 tokio = {version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"]}
 cfg-if = "1"
 chrono = "0.4"
-diesel = { version = "2.0.0", default-features = false, features = ["chrono"]}
+diesel = { version = "2.1.0", default-features = false, features = ["chrono"]}
 
 [features]
 default = []
-mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel"]
+mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel", "tokio"]
 postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/rt"]
+async-connection-wrapper = []
+r2d2 = ["diesel/r2d2"]
 
 [[test]]
 name = "integration_tests"
@@ -54,3 +56,4 @@ members = [
 ".",
 "examples/postgres/pooled-with-rustls"
 ]
+
diff --git a/src/async_connection_wrapper.rs b/src/async_connection_wrapper.rs
new file mode 100644
index 0000000..f93a77d
--- /dev/null
+++ b/src/async_connection_wrapper.rs
@@ -0,0 +1,313 @@
+//! This module contains a wrapper type
+//! that provides a [`diesel::Connection`]
+//! implementation for types that implement
+//! [`crate::AsyncConnection`]. Using this type
+//! might be useful for the following use cases:
+//!
+//! * Executing migrations on application startup
+//! * Using a pure Rust diesel connection implementation
+//!   as a replacement for the existing connection
+//!   implementations provided by diesel
+
+use futures_util::Future;
+use futures_util::Stream;
+use futures_util::StreamExt;
+use std::pin::Pin;
+
+/// This is a helper trait that allows customizing the
+/// async runtime used to execute futures as part of the
+/// [`AsyncConnectionWrapper`] type. By default a
+/// tokio runtime is used.
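+///
+/// A minimal sketch of a custom runtime implementation (the `MyRuntime`
+/// name is purely illustrative and not part of this crate):
+///
+/// ```rust,ignore
+/// use diesel_async::async_connection_wrapper::BlockOn;
+/// use futures_util::Future;
+///
+/// struct MyRuntime(tokio::runtime::Runtime);
+///
+/// impl BlockOn for MyRuntime {
+///     fn block_on<F>(&self, f: F) -> F::Output
+///     where
+///         F: Future,
+///     {
+///         // drive the future to completion on the wrapped runtime
+///         self.0.block_on(f)
+///     }
+///
+///     fn get_runtime() -> Self {
+///         let runtime = tokio::runtime::Builder::new_current_thread()
+///             .enable_io()
+///             .build()
+///             .expect("failed to build a tokio runtime");
+///         MyRuntime(runtime)
+///     }
+/// }
+/// ```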
+pub trait BlockOn { + /// This function should allow to execute a + /// given future to get the result + fn block_on(&self, f: F) -> F::Output + where + F: Future; + + /// This function should be used to construct + /// a new runtime instance + fn get_runtime() -> Self; +} + +/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to +/// provide a sync [`diesel::Connection`] implementation. +/// +/// Internally this wrapper type will use `block_on` to wait for +/// the execution of futures from the inner connection. This implies you +/// cannot use functions of this type in a scope with an already existing +/// tokio runtime. If you are in a situation where you want to use this +/// connection wrapper in the scope of an existing tokio runtime (for example +/// for running migrations via `diesel_migration`) you need to wrap +/// the relevant code block into a `tokio::task::spawn_blocking` task. +/// +/// # Examples +/// +/// ```rust +/// # include!("doctest_setup.rs"); +/// use schema::users; +/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; +/// # +/// # fn main() -> Result<(), Box> { +/// use diesel::prelude::{RunQueryDsl, Connection}; +/// # let database_url = database_url(); +/// let mut conn = AsyncConnectionWrapper::::establish(&database_url)?; +/// +/// let all_users = users::table.load::<(i32, String)>(&mut conn)?; +/// # assert_eq!(all_users.len(), 0); +/// # Ok(()) +/// # } +/// ``` +/// +/// If you are in the scope of an existing tokio runtime you need to use +/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks +/// ```rust +/// # include!("doctest_setup.rs"); +/// use schema::users; +/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; +/// +/// async fn some_async_fn() { +/// # let database_url = database_url(); +/// // need to use `spawn_blocking` to execute +/// // a blocking task in the scope of an existing runtime +/// let res = tokio::task::spawn_blocking(move || { +/// use diesel::prelude::{RunQueryDsl, Connection}; +/// let mut conn = AsyncConnectionWrapper::::establish(&database_url)?; +/// +/// let all_users = users::table.load::<(i32, String)>(&mut conn)?; +/// # assert_eq!(all_users.len(), 0); +/// Ok::<_, Box>(()) +/// }).await; +/// +/// # res.unwrap().unwrap(); +/// } +/// +/// # #[tokio::main] +/// # async fn main() { +/// # some_async_fn().await; +/// # } +/// ``` +#[cfg(feature = "tokio")] +pub type AsyncConnectionWrapper = + self::implementation::AsyncConnectionWrapper; + +/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to +/// provide a sync [`diesel::Connection`] implementation. +/// +/// Internally this wrapper type will use `block_on` to wait for +/// the execution of futures from the inner connection. 
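+///
+/// Without the `tokio` feature there is no default runtime, so a custom
+/// [`BlockOn`] implementation has to be supplied as the second type
+/// parameter. A rough sketch (`MyRuntime` is a hypothetical type
+/// implementing [`BlockOn`], not something provided by this crate):
+///
+/// ```rust,ignore
+/// use diesel::prelude::*;
+/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
+/// use diesel_async::AsyncPgConnection;
+///
+/// let database_url = std::env::var("DATABASE_URL").unwrap();
+/// let mut conn =
+///     AsyncConnectionWrapper::<AsyncPgConnection, MyRuntime>::establish(&database_url).unwrap();
+/// ```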
+#[cfg(not(feature = "tokio"))] +pub use self::implementation::AsyncConnectionWrapper; + +mod implementation { + use super::*; + + pub struct AsyncConnectionWrapper { + inner: C, + runtime: B, + } + + impl diesel::connection::SimpleConnection for AsyncConnectionWrapper + where + C: crate::SimpleAsyncConnection, + B: BlockOn, + { + fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> { + let f = self.inner.batch_execute(query); + self.runtime.block_on(f) + } + } + + impl diesel::connection::ConnectionSealed for AsyncConnectionWrapper {} + + impl diesel::connection::Connection for AsyncConnectionWrapper + where + C: crate::AsyncConnection, + B: BlockOn + Send, + { + type Backend = C::Backend; + + type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper; + + fn establish(database_url: &str) -> diesel::ConnectionResult { + let runtime = B::get_runtime(); + let f = C::establish(database_url); + let inner = runtime.block_on(f)?; + Ok(Self { inner, runtime }) + } + + fn execute_returning_count(&mut self, source: &T) -> diesel::QueryResult + where + T: diesel::query_builder::QueryFragment + diesel::query_builder::QueryId, + { + let f = self.inner.execute_returning_count(source); + self.runtime.block_on(f) + } + + fn transaction_state( + &mut self, + ) -> &mut >::TransactionStateData{ + self.inner.transaction_state() + } + } + + impl diesel::connection::LoadConnection for AsyncConnectionWrapper + where + C: crate::AsyncConnection, + B: BlockOn + Send, + { + type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B> + where + Self: 'conn; + + type Row<'conn, 'query> = C::Row<'conn, 'query> + where + Self: 'conn; + + fn load<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> diesel::QueryResult> + where + T: diesel::query_builder::Query + + diesel::query_builder::QueryFragment + + diesel::query_builder::QueryId + + 'query, + Self::Backend: diesel::expression::QueryMetadata, + { + let f = self.inner.load(source); + let stream = self.runtime.block_on(f)?; + + Ok(AsyncCursorWrapper { + stream: Box::pin(stream), + runtime: &self.runtime, + }) + } + } + + pub struct AsyncCursorWrapper<'a, S, B> { + stream: Pin>, + runtime: &'a B, + } + + impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B> + where + S: Stream, + B: BlockOn, + { + type Item = S::Item; + + fn next(&mut self) -> Option { + let f = self.stream.next(); + self.runtime.block_on(f) + } + } + + pub struct AsyncConnectionWrapperTransactionManagerWrapper; + + impl diesel::connection::TransactionManager> + for AsyncConnectionWrapperTransactionManagerWrapper + where + C: crate::AsyncConnection, + B: BlockOn + Send, + { + type TransactionStateData = + >::TransactionStateData; + + fn begin_transaction(conn: &mut AsyncConnectionWrapper) -> diesel::QueryResult<()> { + let f = >::begin_transaction( + &mut conn.inner, + ); + conn.runtime.block_on(f) + } + + fn rollback_transaction( + conn: &mut AsyncConnectionWrapper, + ) -> diesel::QueryResult<()> { + let f = >::rollback_transaction( + &mut conn.inner, + ); + conn.runtime.block_on(f) + } + + fn commit_transaction(conn: &mut AsyncConnectionWrapper) -> diesel::QueryResult<()> { + let f = >::commit_transaction( + &mut conn.inner, + ); + conn.runtime.block_on(f) + } + + fn transaction_manager_status_mut( + conn: &mut AsyncConnectionWrapper, + ) -> &mut diesel::connection::TransactionManagerStatus { + >::transaction_manager_status_mut( + &mut conn.inner, + ) + } + + fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper) -> bool { + 
>::is_broken_transaction_manager( + &mut conn.inner, + ) + } + } + + #[cfg(feature = "r2d2")] + impl diesel::r2d2::R2D2Connection for AsyncConnectionWrapper + where + B: BlockOn, + Self: diesel::Connection, + C: crate::AsyncConnection::Backend> + + crate::pooled_connection::PoolableConnection, + { + fn ping(&mut self) -> diesel::QueryResult<()> { + diesel::Connection::execute_returning_count(self, &C::make_ping_query()).map(|_| ()) + } + + fn is_broken(&mut self) -> bool { + >::is_broken_transaction_manager( + &mut self.inner, + ) + } + } + + #[cfg(feature = "tokio")] + pub struct Tokio { + handle: Option, + runtime: Option, + } + + #[cfg(feature = "tokio")] + impl BlockOn for Tokio { + fn block_on(&self, f: F) -> F::Output + where + F: Future, + { + if let Some(handle) = &self.handle { + handle.block_on(f) + } else if let Some(runtime) = &self.runtime { + runtime.block_on(f) + } else { + unreachable!() + } + } + + fn get_runtime() -> Self { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + Self { + handle: Some(handle), + runtime: None, + } + } else { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); + Self { + handle: None, + runtime: Some(runtime), + } + } + } + } +} diff --git a/src/doctest_setup.rs b/src/doctest_setup.rs index cc73b3d..b970a0b 100644 --- a/src/doctest_setup.rs +++ b/src/doctest_setup.rs @@ -1,33 +1,37 @@ -use diesel_async::*; -use diesel::prelude::*; +#[allow(unused_imports)] +use diesel::prelude::{ + AsChangeset, ExpressionMethods, Identifiable, IntoSql, QueryDsl, QueryResult, Queryable, + QueryableByName, +}; cfg_if::cfg_if! { if #[cfg(feature = "postgres")] { + use diesel_async::AsyncPgConnection; #[allow(dead_code)] type DB = diesel::pg::Pg; + #[allow(dead_code)] + type DbConnection = AsyncPgConnection; - async fn connection_no_transaction() -> AsyncPgConnection { - let connection_url = database_url_from_env("PG_DATABASE_URL"); - AsyncPgConnection::establish(&connection_url).await.unwrap() + fn database_url() -> String { + database_url_from_env("PG_DATABASE_URL") } - async fn clear_tables(connection: &mut AsyncPgConnection) { - diesel::sql_query("DROP TABLE IF EXISTS users CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS animals CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS posts CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS comments CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS brands CASCADE").execute(connection).await.unwrap(); + async fn connection_no_transaction() -> AsyncPgConnection { + use diesel_async::AsyncConnection; + let connection_url = database_url(); + AsyncPgConnection::establish(&connection_url).await.unwrap() } async fn connection_no_data() -> AsyncPgConnection { + use diesel_async::AsyncConnection; let mut connection = connection_no_transaction().await; connection.begin_test_transaction().await.unwrap(); - clear_tables(&mut connection).await; connection } async fn create_tables(connection: &mut AsyncPgConnection) { - diesel::sql_query("CREATE TABLE IF NOT EXISTS users ( + use diesel_async::RunQueryDsl; + diesel::sql_query("CREATE TEMPORARY TABLE users ( id SERIAL PRIMARY KEY, name VARCHAR NOT NULL )").execute(connection).await.unwrap(); @@ -36,7 +40,7 @@ cfg_if::cfg_if! 
{ ).execute(connection).await.unwrap(); diesel::sql_query( - "CREATE TABLE IF NOT EXISTS animals ( + "CREATE TEMPORARY TABLE animals ( id SERIAL PRIMARY KEY, species VARCHAR NOT NULL, legs INTEGER NOT NULL, @@ -50,7 +54,7 @@ cfg_if::cfg_if! { .await.unwrap(); diesel::sql_query( - "CREATE TABLE IF NOT EXISTS posts ( + "CREATE TEMPORARY TABLE posts ( id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, title VARCHAR NOT NULL @@ -61,7 +65,7 @@ cfg_if::cfg_if! { (1, 'About Rust'), (2, 'My first post too')").execute(connection).await.unwrap(); - diesel::sql_query("CREATE TABLE IF NOT EXISTS comments ( + diesel::sql_query("CREATE TEMPORARY TABLE comments ( id SERIAL PRIMARY KEY, post_id INTEGER NOT NULL, body VARCHAR NOT NULL @@ -71,7 +75,7 @@ cfg_if::cfg_if! { (2, 'Yay! I am learning Rust'), (3, 'I enjoyed your post')").execute(connection).await.unwrap(); - diesel::sql_query("CREATE TABLE IF NOT EXISTS brands ( + diesel::sql_query("CREATE TEMPORARY TABLE brands ( id SERIAL PRIMARY KEY, color VARCHAR NOT NULL DEFAULT 'Green', accent VARCHAR DEFAULT 'Blue' @@ -85,28 +89,26 @@ cfg_if::cfg_if! { connection } } else if #[cfg(feature = "mysql")] { + use diesel_async::AsyncMysqlConnection; #[allow(dead_code)] type DB = diesel::mysql::Mysql; + #[allow(dead_code)] + type DbConnection = AsyncMysqlConnection; - async fn clear_tables(connection: &mut AsyncMysqlConnection) { - diesel::sql_query("SET FOREIGN_KEY_CHECKS=0;").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS users CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS animals CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS posts CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS comments CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("DROP TABLE IF EXISTS brands CASCADE").execute(connection).await.unwrap(); - diesel::sql_query("SET FOREIGN_KEY_CHECKS=1;").execute(connection).await.unwrap(); + fn database_url() -> String { + database_url_from_env("MYSQL_UNIT_TEST_DATABASE_URL") } async fn connection_no_data() -> AsyncMysqlConnection { - let connection_url = database_url_from_env("MYSQL_UNIT_TEST_DATABASE_URL"); - let mut connection = AsyncMysqlConnection::establish(&connection_url).await.unwrap(); - clear_tables(&mut connection).await; - connection + use diesel_async::AsyncConnection; + let connection_url = database_url(); + AsyncMysqlConnection::establish(&connection_url).await.unwrap() } async fn create_tables(connection: &mut AsyncMysqlConnection) { - diesel::sql_query("CREATE TEMPORARY TABLE IF NOT EXISTS users ( + use diesel_async::RunQueryDsl; + use diesel_async::AsyncConnection; + diesel::sql_query("CREATE TEMPORARY TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTO_INCREMENT, name TEXT NOT NULL ) CHARACTER SET utf8mb4").execute(connection).await.unwrap(); @@ -173,8 +175,6 @@ cfg_if::cfg_if! 
{ fn database_url_from_env(backend_specific_env_var: &str) -> String { use std::env; - //dotenv().ok(); - env::var(backend_specific_env_var) .or_else(|_| env::var("DATABASE_URL")) .expect("DATABASE_URL must be set in order to run tests") diff --git a/src/lib.rs b/src/lib.rs index a2124c2..b86393b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,11 +78,18 @@ use std::fmt::Debug; pub use scoped_futures; use scoped_futures::{ScopedBoxFuture, ScopedFutureExt}; +#[cfg(feature = "async-connection-wrapper")] +pub mod async_connection_wrapper; #[cfg(feature = "mysql")] mod mysql; #[cfg(feature = "postgres")] pub mod pg; -#[cfg(any(feature = "deadpool", feature = "bb8", feature = "mobc"))] +#[cfg(any( + feature = "deadpool", + feature = "bb8", + feature = "mobc", + feature = "r2d2" +))] pub mod pooled_connection; mod run_query_dsl; mod stmt_cache; @@ -98,9 +105,7 @@ pub use self::pg::AsyncPgConnection; pub use self::run_query_dsl::*; #[doc(inline)] -pub use self::transaction_manager::{ - AnsiTransactionManager, TransactionManager, TransactionManagerStatus, -}; +pub use self::transaction_manager::{AnsiTransactionManager, TransactionManager}; /// Perform simple operations on a backend. /// @@ -187,6 +192,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// # include!("doctest_setup.rs"); /// use diesel::result::Error; /// use scoped_futures::ScopedFutureExt; + /// use diesel_async::{RunQueryDsl, AsyncConnection}; /// /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -240,7 +246,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// tests. Panics if called while inside of a transaction or /// if called with a connection containing a broken transaction async fn begin_test_transaction(&mut self) -> QueryResult<()> { - use crate::transaction_manager::TransactionManagerStatus; + use diesel::connection::TransactionManagerStatus; match Self::TransactionManager::transaction_manager_status_mut(self) { TransactionManagerStatus::Valid(valid_status) => { @@ -266,6 +272,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// # include!("doctest_setup.rs"); /// use diesel::result::Error; /// use scoped_futures::ScopedFutureExt; + /// use diesel_async::{RunQueryDsl, AsyncConnection}; /// /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -319,8 +326,8 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { #[doc(hidden)] fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where - T: AsQuery + Send + 'query, - T::Query: QueryFragment + QueryId + Send + 'query; + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query; #[doc(hidden)] fn execute_returning_count<'conn, 'query, T>( @@ -328,7 +335,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { source: T, ) -> Self::ExecuteFuture<'conn, 'query> where - T: QueryFragment + QueryId + Send + 'query; + T: QueryFragment + QueryId + 'query; #[doc(hidden)] fn transaction_state( diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 14d2279..f460c8d 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,13 +1,14 @@ use crate::stmt_cache::{PrepareCallback, StmtCache}; use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::MaybeCached; -use diesel::mysql::{Mysql, MysqlType}; +use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey}; +use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType}; +use 
diesel::query_builder::QueryBuilder; use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId}; use diesel::result::{ConnectionError, ConnectionResult}; use diesel::QueryResult; -use futures_util::future::{self, BoxFuture}; +use futures_util::future::BoxFuture; use futures_util::stream::{self, BoxStream}; -use futures_util::{Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +use futures_util::{Future, FutureExt, StreamExt, TryStreamExt}; use mysql_async::prelude::Queryable; use mysql_async::{Opts, OptsBuilder, Statement}; @@ -69,10 +70,9 @@ impl AsyncConnection for AsyncMysqlConnection { fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where - T: diesel::query_builder::AsQuery + Send, + T: diesel::query_builder::AsQuery, T::Query: diesel::query_builder::QueryFragment + diesel::query_builder::QueryId - + Send + 'query, { self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move { @@ -126,7 +126,6 @@ impl AsyncConnection for AsyncMysqlConnection { where T: diesel::query_builder::QueryFragment + diesel::query_builder::QueryId - + Send + 'query, { self.with_prepared_statement(source, |conn, stmt, binds| async move { @@ -166,7 +165,7 @@ fn update_transaction_manager_status( { transaction_manager .status - .set_top_level_transaction_requires_rollback() + .set_requires_rollback_maybe_up_to_top_level(true) } query_result } @@ -216,16 +215,13 @@ impl AsyncMysqlConnection { ) -> BoxFuture<'conn, QueryResult> where R: Send + 'conn, - T: QueryFragment + QueryId + Send, + T: QueryFragment + QueryId, F: Future> + Send, { let mut bind_collector = RawBytesBindCollector::::new(); - if let Err(e) = query.collect_binds(&mut bind_collector, &mut (), &Mysql) { - return future::ready(Err(e)).boxed(); - } - - let binds = bind_collector.binds; - let metadata = bind_collector.metadata; + let bind_collector = query + .collect_binds(&mut bind_collector, &mut (), &Mysql) + .map(|()| bind_collector); let AsyncMysqlConnection { ref mut conn, @@ -234,14 +230,40 @@ impl AsyncMysqlConnection { .. } = self; - let stmt = stmt_cache.cached_prepared_statement(query, &metadata, conn, &Mysql); + let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql); + let mut qb = MysqlQueryBuilder::new(); + let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish()); + let query_id = T::query_id(); + + async move { + let RawBytesBindCollector { + metadata, binds, .. 
+ } = bind_collector?; + let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; + let sql = sql?; + let cache_key = if let Some(query_id) = query_id { + StatementCacheKey::Type(query_id) + } else { + StatementCacheKey::Sql { + sql: sql.clone(), + bind_types: metadata.clone(), + } + }; - stmt.and_then(|(stmt, conn)| async move { + let (stmt, conn) = stmt_cache + .cached_prepared_statement( + cache_key, + sql, + is_safe_to_cache_prepared, + &metadata, + conn, + ) + .await?; update_transaction_manager_status( callback(conn, stmt, ToSqlHelper { metadata, binds }).await, transaction_manager, ) - }) + } .boxed() } @@ -279,8 +301,19 @@ impl AsyncMysqlConnection { } } -#[cfg(any(feature = "deadpool", feature = "bb8", feature = "mobc"))] -impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {} +#[cfg(any( + feature = "deadpool", + feature = "bb8", + feature = "mobc", + feature = "r2d2" +))] +impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection { + type PingQuery = crate::pooled_connection::CheckConnectionQuery; + + fn make_ping_query() -> Self::PingQuery { + crate::pooled_connection::CheckConnectionQuery + } +} #[cfg(test)] mod tests { diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 6a4832b..0de6fa4 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -9,19 +9,20 @@ use self::row::PgRow; use self::serialize::ToSqlHelper; use crate::stmt_cache::{PrepareCallback, StmtCache}; use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::PrepareForCache; +use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey}; use diesel::pg::{ - FailedToLookupTypeError, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgTypeMetadata, + FailedToLookupTypeError, Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, + PgQueryBuilder, PgTypeMetadata, }; use diesel::query_builder::bind_collector::RawBytesBindCollector; -use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; +use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId}; use diesel::{ConnectionError, ConnectionResult, QueryResult}; use futures_util::future::BoxFuture; -use futures_util::lock::Mutex; use futures_util::stream::{BoxStream, TryStreamExt}; use futures_util::{Future, FutureExt, StreamExt}; use std::borrow::Cow; use std::sync::Arc; +use tokio::sync::Mutex; use tokio_postgres::types::ToSql; use tokio_postgres::types::Type; use tokio_postgres::Statement; @@ -71,6 +72,8 @@ mod transaction_builder; /// /// ```rust /// # include!("../doctest_setup.rs"); +/// use diesel_async::RunQueryDsl; +/// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -98,7 +101,7 @@ pub struct AsyncPgConnection { conn: Arc, stmt_cache: Arc>>, transaction_state: Arc>, - metadata_cache: Arc>>, + metadata_cache: Arc>, } #[async_trait::async_trait] @@ -131,29 +134,18 @@ impl AsyncConnection for AsyncPgConnection { fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where - T: AsQuery + Send + 'query, - T::Query: QueryFragment + QueryId + Send + 'query, + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, { - let conn = self.conn.clone(); - let stmt_cache = self.stmt_cache.clone(); - let metadata_cache = self.metadata_cache.clone(); - let tm = self.transaction_state.clone(); let query = source.as_query(); - Self::with_prepared_statement( - conn, - stmt_cache, - metadata_cache, - tm, - query, - |conn, stmt, binds| async move { - let res = 
conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?; - - Ok(res - .map_err(|e| diesel::result::Error::from(ErrorHelper(e))) - .map_ok(PgRow::new) - .boxed()) - }, - ) + self.with_prepared_statement(query, |conn, stmt, binds| async move { + let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?; + + Ok(res + .map_err(|e| diesel::result::Error::from(ErrorHelper(e))) + .map_ok(PgRow::new) + .boxed()) + }) .boxed() } @@ -162,26 +154,19 @@ impl AsyncConnection for AsyncPgConnection { source: T, ) -> Self::ExecuteFuture<'conn, 'query> where - T: QueryFragment + QueryId + Send + 'query, + T: QueryFragment + QueryId + 'query, { - Self::with_prepared_statement( - self.conn.clone(), - self.stmt_cache.clone(), - self.metadata_cache.clone(), - self.transaction_state.clone(), - source, - |conn, stmt, binds| async move { - let binds = binds - .iter() - .map(|b| b as &(dyn ToSql + Sync)) - .collect::>(); - - let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_]) - .await - .map_err(ErrorHelper)?; - Ok(res as usize) - }, - ) + self.with_prepared_statement(source, |conn, stmt, binds| async move { + let binds = binds + .iter() + .map(|b| b as &(dyn ToSql + Sync)) + .collect::>(); + + let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_]) + .await + .map_err(ErrorHelper)?; + Ok(res as usize) + }) .boxed() } @@ -209,7 +194,7 @@ fn update_transaction_manager_status( { transaction_manager .status - .set_top_level_transaction_requires_rollback() + .set_requires_rollback_maybe_up_to_top_level(true) } query_result } @@ -226,6 +211,7 @@ impl PrepareCallback for Arc .iter() .map(type_from_oid) .collect::>>()?; + let stmt = self .prepare_typed(sql, &bind_types) .await @@ -288,7 +274,7 @@ impl AsyncPgConnection { conn: Arc::new(conn), stmt_cache: Arc::new(Mutex::new(StmtCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), - metadata_cache: Arc::new(Mutex::new(Some(PgMetadataCache::new()))), + metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), }; conn.set_config_options() .await @@ -313,116 +299,131 @@ impl AsyncPgConnection { Ok(()) } - async fn with_prepared_statement<'a, T, F, R>( - raw_connection: Arc, - stmt_cache: Arc>>, - metadata_cache: Arc>>, - tm: Arc>, + fn with_prepared_statement<'a, T, F, R>( + &mut self, query: T, - callback: impl FnOnce(Arc, Statement, Vec) -> F, - ) -> QueryResult + callback: impl FnOnce(Arc, Statement, Vec) -> F + Send + 'a, + ) -> BoxFuture<'a, QueryResult> where - T: QueryFragment + QueryId + Send, - F: Future>, + T: QueryFragment + QueryId, + F: Future> + Send, + R: Send, { - let mut bind_collector; - { - loop { - // we need a new bind collector per iteration here - bind_collector = RawBytesBindCollector::::new(); - - let (res, unresolved_types) = { - let mut metadata_cache_lock = metadata_cache.lock().await; - let mut metadata_lookup = - PgAsyncMetadataLookup::new(metadata_cache_lock.take().unwrap_or_default()); - - let res = query.collect_binds( - &mut bind_collector, - &mut metadata_lookup, - &diesel::pg::Pg, - ); - - let PgAsyncMetadataLookup { - unresolved_types, - metadata_cache, - } = metadata_lookup; - *metadata_cache_lock = Some(metadata_cache); - (res, unresolved_types) - }; - - if !unresolved_types.is_empty() { - for (schema, lookup_type_name) in unresolved_types { - // as this is an async call and we don't want to infect the whole diesel serialization - // api with async we just error out in the `PgMetadataLookup` implementation below if we encounter - // a type that is 
not cached yet - // If that's the case we will do the lookup here and try again as the - // type is now cached. - let type_metadata = - lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection) - .await?; - let mut metadata_cache_lock = metadata_cache.lock().await; - let metadata_cache = - if let Some(ref mut metadata_cache) = *metadata_cache_lock { - metadata_cache + // we explicilty descruct the query here before going into the async block + // + // That's required to remove the send bound from `T` as we have translated + // the query type to just a string (for the SQL) and a bunch of bytes (for the binds) + // which both are `Send`. + // We also collect the query id (essentially an integer) and the safe_to_cache flag here + // so there is no need to even access the query in the async block below + let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&diesel::pg::Pg); + let mut query_builder = PgQueryBuilder::default(); + let sql = query + .to_sql(&mut query_builder, &Pg) + .map(|_| query_builder.finish()); + + let mut bind_collector = RawBytesBindCollector::::new(); + let query_id = T::query_id(); + + // we don't resolve custom types here yet, we do that later + // in the async block below as we might need to perform lookup + // queries for that. + // + // We apply this workaround to prevent requiring all the diesel + // serialization code to beeing async + let mut metadata_lookup = PgAsyncMetadataLookup::new(); + let collect_bind_result = + query.collect_binds(&mut bind_collector, &mut metadata_lookup, &Pg); + + let raw_connection = self.conn.clone(); + let stmt_cache = self.stmt_cache.clone(); + let metadata_cache = self.metadata_cache.clone(); + let tm = self.transaction_state.clone(); + + async move { + let sql = sql?; + let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; + collect_bind_result?; + // Check whether we need to resolve some types at all + // + // If the user doesn't use custom types there is no need + // to borther with that at all + if !metadata_lookup.unresolved_types.is_empty() { + let metadata_cache = &mut *metadata_cache.lock().await; + let mut next_unresolved = metadata_lookup.unresolved_types.into_iter(); + for m in &mut bind_collector.metadata { + // for each unresolved item + // we check whether it's arleady in the cache + // or perform a lookup and insert it into the cache + if m.oid().is_err() { + if let Some((ref schema, ref lookup_type_name)) = next_unresolved.next() { + let cache_key = PgMetadataCacheKey::new( + schema.as_ref().map(Into::into), + lookup_type_name.into(), + ); + if let Some(entry) = metadata_cache.lookup_type(&cache_key) { + *m = entry; } else { - *metadata_cache_lock = Some(Default::default()); - metadata_cache_lock.as_mut().expect("We set it above") - }; - - metadata_cache.store_type( - PgMetadataCacheKey::new( - schema.map(Cow::Owned), - Cow::Owned(lookup_type_name), - ), - type_metadata, - ); - // just try again to get the binds, now that we've inserted the - // type into the lookup list + let type_metadata = lookup_type( + schema.clone(), + lookup_type_name.clone(), + &raw_connection, + ) + .await?; + *m = PgTypeMetadata::from_result(Ok(type_metadata)); + + metadata_cache.store_type(cache_key, type_metadata); + } + } else { + break; + } } - } else { - // bubble up any error as soon as we have done all lookups - res?; - break; } } + let key = match query_id { + Some(id) => StatementCacheKey::Type(id), + None => StatementCacheKey::Sql { + sql: sql.clone(), + bind_types: bind_collector.metadata.clone(), + 
}, + }; + let stmt = { + let mut stmt_cache = stmt_cache.lock().await; + stmt_cache + .cached_prepared_statement( + key, + sql, + is_safe_to_cache_prepared, + &bind_collector.metadata, + raw_connection.clone(), + ) + .await? + .0 + .clone() + }; + + let binds = bind_collector + .metadata + .into_iter() + .zip(bind_collector.binds) + .map(|(meta, bind)| ToSqlHelper(meta, bind)) + .collect::>(); + let res = callback(raw_connection, stmt.clone(), binds).await; + let mut tm = tm.lock().await; + update_transaction_manager_status(res, &mut tm) } - - let stmt = { - let mut stmt_cache = stmt_cache.lock().await; - stmt_cache - .cached_prepared_statement( - query, - &bind_collector.metadata, - raw_connection.clone(), - &diesel::pg::Pg, - ) - .await? - .0 - .clone() - }; - - let binds = bind_collector - .metadata - .into_iter() - .zip(bind_collector.binds) - .map(|(meta, bind)| ToSqlHelper(meta, bind)) - .collect::>(); - let res = callback(raw_connection, stmt.clone(), binds).await; - let mut tm = tm.lock().await; - update_transaction_manager_status(res, &mut tm) + .boxed() } } struct PgAsyncMetadataLookup { unresolved_types: Vec<(Option, String)>, - metadata_cache: PgMetadataCache, } impl PgAsyncMetadataLookup { - fn new(metadata_cache: PgMetadataCache) -> Self { + fn new() -> Self { Self { unresolved_types: Vec::new(), - metadata_cache, } } } @@ -432,14 +433,10 @@ impl PgMetadataLookup for PgAsyncMetadataLookup { let cache_key = PgMetadataCacheKey::new(schema.map(Cow::Borrowed), Cow::Borrowed(type_name)); - if let Some(metadata) = self.metadata_cache.lookup_type(&cache_key) { - metadata - } else { - let cache_key = cache_key.into_owned(); - self.unresolved_types - .push((schema.map(ToOwned::to_owned), type_name.to_owned())); - PgTypeMetadata::from_result(Err(FailedToLookupTypeError::new(cache_key))) - } + let cache_key = cache_key.into_owned(); + self.unresolved_types + .push((schema.map(ToOwned::to_owned), type_name.to_owned())); + PgTypeMetadata::from_result(Err(FailedToLookupTypeError::new(cache_key))) } } @@ -473,8 +470,19 @@ async fn lookup_type( Ok((r.get(0), r.get(1))) } -#[cfg(any(feature = "deadpool", feature = "bb8", feature = "mobc"))] -impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {} +#[cfg(any( + feature = "deadpool", + feature = "bb8", + feature = "mobc", + feature = "r2d2" +))] +impl crate::pooled_connection::PoolableConnection for AsyncPgConnection { + type PingQuery = crate::pooled_connection::CheckConnectionQuery; + + fn make_ping_query() -> Self::PingQuery { + crate::pooled_connection::CheckConnectionQuery + } +} #[cfg(test)] pub mod tests { diff --git a/src/pg/transaction_builder.rs b/src/pg/transaction_builder.rs index fa52dfa..1096433 100644 --- a/src/pg/transaction_builder.rs +++ b/src/pg/transaction_builder.rs @@ -43,13 +43,14 @@ where /// ```rust /// # include!("../doctest_setup.rs"); /// # use diesel::sql_query; + /// use diesel_async::RunQueryDsl; /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { /// # run_test().await.unwrap(); /// # } /// # - /// # table! { + /// # diesel::table! 
{ /// # users_for_read_only { /// # id -> Integer, /// # name -> Text, @@ -98,6 +99,8 @@ where /// # include!("../doctest_setup.rs"); /// # use diesel::result::Error::RollbackTransaction; /// # use diesel::sql_query; + /// use diesel_async::RunQueryDsl; + /// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { diff --git a/src/pooled_connection/bb8.rs b/src/pooled_connection/bb8.rs index efd87f6..c456b58 100644 --- a/src/pooled_connection/bb8.rs +++ b/src/pooled_connection/bb8.rs @@ -6,7 +6,7 @@ //! use futures_util::FutureExt; //! use diesel_async::pooled_connection::AsyncDieselConnectionManager; //! use diesel_async::pooled_connection::bb8::Pool; -//! use diesel_async::RunQueryDsl; +//! use diesel_async::{RunQueryDsl, AsyncConnection}; //! //! # #[tokio::main(flavor = "current_thread")] //! # async fn main() { @@ -33,7 +33,6 @@ //! let pool = Pool::builder().build(config).await?; //! let mut conn = pool.get().await?; //! # conn.begin_test_transaction(); -//! # clear_tables(&mut conn).await; //! # create_tables(&mut conn).await; //! # #[cfg(feature = "mysql")] //! # conn.begin_test_transaction(); diff --git a/src/pooled_connection/deadpool.rs b/src/pooled_connection/deadpool.rs index 296fb56..8914ec7 100644 --- a/src/pooled_connection/deadpool.rs +++ b/src/pooled_connection/deadpool.rs @@ -6,7 +6,7 @@ //! use futures_util::FutureExt; //! use diesel_async::pooled_connection::AsyncDieselConnectionManager; //! use diesel_async::pooled_connection::deadpool::Pool; -//! use diesel_async::RunQueryDsl; +//! use diesel_async::{RunQueryDsl, AsyncConnection}; //! //! # #[tokio::main(flavor = "current_thread")] //! # async fn main() { @@ -33,7 +33,6 @@ //! let pool = Pool::builder(config).build()?; //! let mut conn = pool.get().await?; //! # conn.begin_test_transaction(); -//! # clear_tables(&mut conn).await; //! # create_tables(&mut conn).await; //! # conn.begin_test_transaction(); //! let res = users.load::<(i32, String)>(&mut conn).await?; diff --git a/src/pooled_connection/mobc.rs b/src/pooled_connection/mobc.rs index dbe2270..bde77f2 100644 --- a/src/pooled_connection/mobc.rs +++ b/src/pooled_connection/mobc.rs @@ -6,7 +6,7 @@ //! use futures_util::FutureExt; //! use diesel_async::pooled_connection::AsyncDieselConnectionManager; //! use diesel_async::pooled_connection::mobc::Pool; -//! use diesel_async::RunQueryDsl; +//! use diesel_async::{RunQueryDsl, AsyncConnection}; //! //! # #[tokio::main(flavor = "current_thread")] //! # async fn main() { @@ -33,7 +33,6 @@ //! let pool = Pool::new(config); //! let mut conn = pool.get().await?; //! # conn.begin_test_transaction(); -//! # clear_tables(&mut conn).await; //! # create_tables(&mut conn).await; //! # conn.begin_test_transaction(); //! 
let res = users.load::<(i32, String)>(&mut conn).await?; diff --git a/src/pooled_connection/mod.rs b/src/pooled_connection/mod.rs index 1824702..1ab0ebe 100644 --- a/src/pooled_connection/mod.rs +++ b/src/pooled_connection/mod.rs @@ -8,6 +8,7 @@ use crate::{AsyncConnection, SimpleAsyncConnection}; use crate::{TransactionManager, UpdateAndFetchResults}; use diesel::associations::HasTable; +use diesel::query_builder::{QueryFragment, QueryId}; use diesel::QueryResult; use futures_util::{future, FutureExt}; use std::fmt; @@ -132,10 +133,9 @@ where fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where - T: diesel::query_builder::AsQuery + Send + 'query, + T: diesel::query_builder::AsQuery + 'query, T::Query: diesel::query_builder::QueryFragment + diesel::query_builder::QueryId - + Send + 'query, { let conn = self.deref_mut(); @@ -149,7 +149,6 @@ where where T: diesel::query_builder::QueryFragment + diesel::query_builder::QueryId - + Send + 'query, { let conn = self.deref_mut(); @@ -195,13 +194,17 @@ where fn transaction_manager_status_mut( conn: &mut C, - ) -> &mut crate::transaction_manager::TransactionManagerStatus { + ) -> &mut diesel::connection::TransactionManagerStatus { TM::transaction_manager_status_mut(&mut **conn) } + + fn is_broken_transaction_manager(conn: &mut C) -> bool { + TM::is_broken_transaction_manager(&mut **conn) + } } #[async_trait::async_trait] -impl<'b, Changes, Output, Conn> UpdateAndFetchResults for Conn +impl UpdateAndFetchResults for Conn where Conn: DerefMut + Send, Changes: diesel::prelude::Identifiable + HasTable + Send, @@ -215,8 +218,9 @@ where } } +#[doc(hidden)] #[derive(diesel::query_builder::QueryId)] -struct CheckConnectionQuery; +pub struct CheckConnectionQuery; impl diesel::query_builder::QueryFragment for CheckConnectionQuery where @@ -240,6 +244,10 @@ impl diesel::query_dsl::RunQueryDsl for CheckConnectionQuery {} #[doc(hidden)] #[async_trait::async_trait] pub trait PoolableConnection: AsyncConnection { + type PingQuery: QueryFragment + QueryId + Send; + + fn make_ping_query() -> Self::PingQuery; + /// Check if a connection is still valid /// /// The default implementation performs a `SELECT 1` query @@ -248,7 +256,7 @@ pub trait PoolableConnection: AsyncConnection { for<'a> Self: 'a, { use crate::RunQueryDsl; - CheckConnectionQuery.execute(self).await.map(|_| ()) + Self::make_ping_query().execute(self).await.map(|_| ()) } /// Checks if the connection is broken and should not be reused diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index 580bf01..6e12f02 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -191,6 +191,8 @@ pub trait RunQueryDsl: Sized { /// ```rust /// # include!("../doctest_setup.rs"); /// # + /// use diesel_async::RunQueryDsl; + /// /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { /// # run_test().await; @@ -245,6 +247,9 @@ pub trait RunQueryDsl: Sized { /// ```rust /// # include!("../doctest_setup.rs"); /// # + /// use diesel_async::{RunQueryDsl, AsyncConnection}; + /// + /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { /// # run_test().await; @@ -266,6 +271,8 @@ pub trait RunQueryDsl: Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; + /// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -292,6 +299,8 @@ pub trait RunQueryDsl: Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; 
+ /// /// # /// #[derive(Queryable, PartialEq, Debug)] /// struct User { @@ -364,6 +373,8 @@ pub trait RunQueryDsl: Sized { /// ```rust /// # include!("../doctest_setup.rs"); /// # + /// use diesel_async::RunQueryDsl; + /// /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { /// # run_test().await; @@ -391,6 +402,7 @@ pub trait RunQueryDsl: Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -424,6 +436,8 @@ pub trait RunQueryDsl: Sized { /// ```rust /// # include!("../doctest_setup.rs"); /// # + /// use diesel_async::RunQueryDsl; + /// /// #[derive(Queryable, PartialEq, Debug)] /// struct User { /// id: i32, @@ -482,6 +496,8 @@ pub trait RunQueryDsl: Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; + /// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -577,6 +593,8 @@ pub trait RunQueryDsl: Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// use diesel_async::RunQueryDsl; + /// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { @@ -634,6 +652,8 @@ impl RunQueryDsl for T {} /// # include!("../doctest_setup.rs"); /// # use schema::animals; /// # +/// use diesel_async::{SaveChangesDsl, AsyncConnection}; +/// /// #[derive(Queryable, Debug, PartialEq)] /// struct Animal { /// id: i32, diff --git a/src/stmt_cache.rs b/src/stmt_cache.rs index 9f0040e..53a7bac 100644 --- a/src/stmt_cache.rs +++ b/src/stmt_cache.rs @@ -3,7 +3,6 @@ use std::hash::Hash; use diesel::backend::Backend; use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey}; -use diesel::query_builder::{QueryFragment, QueryId}; use diesel::QueryResult; use futures_util::{future, FutureExt}; @@ -18,15 +17,13 @@ type PrepareFuture<'a, F, S> = future::Either< >; #[async_trait::async_trait] -pub trait PrepareCallback { +pub trait PrepareCallback: Sized { async fn prepare( self, sql: &str, metadata: &[M], is_for_cache: PrepareForCache, - ) -> QueryResult<(S, Self)> - where - Self: Sized; + ) -> QueryResult<(S, Self)>; } impl StmtCache { @@ -36,39 +33,24 @@ impl StmtCache { } } - pub fn cached_prepared_statement<'a, T, F>( + pub fn cached_prepared_statement<'a, F>( &'a mut self, - query: T, + cache_key: StatementCacheKey, + sql: String, + is_query_safe_to_cache: bool, metadata: &[DB::TypeMetadata], prepare_fn: F, - backend: &DB, ) -> PrepareFuture<'a, F, S> where S: Send, DB::QueryBuilder: Default, DB::TypeMetadata: Clone + Send + Sync, - T: QueryFragment + QueryId + Send, F: PrepareCallback + Send + 'a, StatementCacheKey: Hash + Eq, { use std::collections::hash_map::Entry::{Occupied, Vacant}; - let cache_key = match StatementCacheKey::for_source(&query, metadata, backend) { - Ok(key) => key, - Err(e) => return future::Either::Left(future::ready(Err(e))), - }; - - let is_query_safe_to_cache = match query.is_safe_to_cache_prepared(backend) { - Ok(is_safe_to_cache) => is_safe_to_cache, - Err(e) => return future::Either::Left(future::ready(Err(e))), - }; - if !is_query_safe_to_cache { - let sql = match cache_key.sql(&query, backend) { - Ok(sql) => sql.into_owned(), - Err(e) => return future::Either::Left(future::ready(Err(e))), - }; - let metadata = metadata.to_vec(); let f = async move { let stmt = prepare_fn @@ -86,10 +68,6 @@ impl StmtCache { prepare_fn, )))), Vacant(entry) => { - let sql = match entry.key().sql(&query, backend) { - 
Ok(sql) => sql.into_owned(), - Err(e) => return future::Either::Left(future::ready(Err(e))), - }; let metadata = metadata.to_vec(); let f = async move { let statement = prepare_fn diff --git a/src/transaction_manager.rs b/src/transaction_manager.rs index c789261..dbb5d5a 100644 --- a/src/transaction_manager.rs +++ b/src/transaction_manager.rs @@ -1,3 +1,7 @@ +use diesel::connection::TransactionManagerStatus; +use diesel::connection::{ + InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus, +}; use diesel::result::Error; use diesel::QueryResult; use scoped_futures::ScopedBoxFuture; @@ -88,6 +92,7 @@ pub trait TransactionManager: Send { // so we don't consider this connection broken Ok(ValidTransactionManagerStatus { in_transaction: None, + .. }) => false, // The transaction manager is in an error state // Therefore we consider this connection broken @@ -97,6 +102,7 @@ pub trait TransactionManager: Send { // if that transaction was not opened by `begin_test_transaction` Ok(ValidTransactionManagerStatus { in_transaction: Some(s), + .. }) => !s.test_transaction, } } @@ -109,144 +115,144 @@ pub struct AnsiTransactionManager { pub(crate) status: TransactionManagerStatus, } -/// Status of the transaction manager -#[derive(Debug)] -pub enum TransactionManagerStatus { - /// Valid status, the manager can run operations - Valid(ValidTransactionManagerStatus), - /// Error status, probably following a broken connection. The manager will no longer run operations - InError, -} - -impl Default for TransactionManagerStatus { - fn default() -> Self { - TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default()) - } -} - -impl TransactionManagerStatus { - /// Returns the transaction depth if the transaction manager's status is valid, or returns - /// [`Error::BrokenTransactionManager`] if the transaction manager is in error. - pub fn transaction_depth(&self) -> QueryResult> { - match self { - TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()), - TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager), - } - } - - /// If in transaction and transaction manager is not broken, registers that the - /// connection can not be used anymore until top-level transaction is rolled back - pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) { - if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { - in_transaction: - Some(InTransactionStatus { - top_level_transaction_requires_rollback, - .. - }), - }) = self - { - *top_level_transaction_requires_rollback = true; - } - } - - /// Sets the transaction manager status to InError - /// - /// Subsequent attempts to use transaction-related features will result in a - /// [`Error::BrokenTransactionManager`] error - pub fn set_in_error(&mut self) { - *self = TransactionManagerStatus::InError - } - - fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> { - match self { - TransactionManagerStatus::Valid(valid_status) => Ok(valid_status), - TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager), - } - } - - pub(crate) fn set_test_transaction_flag(&mut self) { - if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { - in_transaction: Some(s), - }) = self - { - s.test_transaction = true; - } - } -} - -/// Valid transaction status for the manager. 
Can return the current transaction depth -#[allow(missing_copy_implementations)] -#[derive(Debug, Default)] -pub struct ValidTransactionManagerStatus { - in_transaction: Option, -} - -#[allow(missing_copy_implementations)] -#[derive(Debug)] -struct InTransactionStatus { - transaction_depth: NonZeroU32, - top_level_transaction_requires_rollback: bool, - test_transaction: bool, -} - -impl ValidTransactionManagerStatus { - /// Return the current transaction depth - /// - /// This value is `None` if no current transaction is running - /// otherwise the number of nested transactions is returned. - pub fn transaction_depth(&self) -> Option { - self.in_transaction.as_ref().map(|it| it.transaction_depth) - } - - /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is - /// `Ok(())` - pub fn change_transaction_depth( - &mut self, - transaction_depth_change: TransactionDepthChange, - ) -> QueryResult<()> { - match (&mut self.in_transaction, transaction_depth_change) { - (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => { - // Can be replaced with saturating_add directly on NonZeroU32 once - // is stable - in_transaction.transaction_depth = - NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1)) - .expect("nz + nz is always non-zero"); - Ok(()) - } - (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => { - // This sets `transaction_depth` to `None` as soon as we reach zero - match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) { - Some(depth) => in_transaction.transaction_depth = depth, - None => self.in_transaction = None, - } - Ok(()) - } - (None, TransactionDepthChange::IncreaseDepth) => { - self.in_transaction = Some(InTransactionStatus { - transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"), - top_level_transaction_requires_rollback: false, - test_transaction: false, - }); - Ok(()) - } - (None, TransactionDepthChange::DecreaseDepth) => { - // We screwed up something somewhere - // we cannot decrease the transaction count if - // we are not inside a transaction - Err(Error::NotInTransaction) - } - } - } -} - -/// Represents a change to apply to the depth of a transaction -#[derive(Debug, Clone, Copy)] -pub enum TransactionDepthChange { - /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`) - IncreaseDepth, - /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`) - DecreaseDepth, -} +// /// Status of the transaction manager +// #[derive(Debug)] +// pub enum TransactionManagerStatus { +// /// Valid status, the manager can run operations +// Valid(ValidTransactionManagerStatus), +// /// Error status, probably following a broken connection. The manager will no longer run operations +// InError, +// } + +// impl Default for TransactionManagerStatus { +// fn default() -> Self { +// TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default()) +// } +// } + +// impl TransactionManagerStatus { +// /// Returns the transaction depth if the transaction manager's status is valid, or returns +// /// [`Error::BrokenTransactionManager`] if the transaction manager is in error. 
+// pub fn transaction_depth(&self) -> QueryResult> { +// match self { +// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()), +// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager), +// } +// } + +// /// If in transaction and transaction manager is not broken, registers that the +// /// connection can not be used anymore until top-level transaction is rolled back +// pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) { +// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { +// in_transaction: +// Some(InTransactionStatus { +// top_level_transaction_requires_rollback, +// .. +// }), +// }) = self +// { +// *top_level_transaction_requires_rollback = true; +// } +// } + +// /// Sets the transaction manager status to InError +// /// +// /// Subsequent attempts to use transaction-related features will result in a +// /// [`Error::BrokenTransactionManager`] error +// pub fn set_in_error(&mut self) { +// *self = TransactionManagerStatus::InError +// } + +// fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> { +// match self { +// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status), +// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager), +// } +// } + +// pub(crate) fn set_test_transaction_flag(&mut self) { +// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { +// in_transaction: Some(s), +// }) = self +// { +// s.test_transaction = true; +// } +// } +// } + +// /// Valid transaction status for the manager. Can return the current transaction depth +// #[allow(missing_copy_implementations)] +// #[derive(Debug, Default)] +// pub struct ValidTransactionManagerStatus { +// in_transaction: Option, +// } + +// #[allow(missing_copy_implementations)] +// #[derive(Debug)] +// struct InTransactionStatus { +// transaction_depth: NonZeroU32, +// top_level_transaction_requires_rollback: bool, +// test_transaction: bool, +// } + +// impl ValidTransactionManagerStatus { +// /// Return the current transaction depth +// /// +// /// This value is `None` if no current transaction is running +// /// otherwise the number of nested transactions is returned. 
+// pub fn transaction_depth(&self) -> Option { +// self.in_transaction.as_ref().map(|it| it.transaction_depth) +// } + +// /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is +// /// `Ok(())` +// pub fn change_transaction_depth( +// &mut self, +// transaction_depth_change: TransactionDepthChange, +// ) -> QueryResult<()> { +// match (&mut self.in_transaction, transaction_depth_change) { +// (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => { +// // Can be replaced with saturating_add directly on NonZeroU32 once +// // is stable +// in_transaction.transaction_depth = +// NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1)) +// .expect("nz + nz is always non-zero"); +// Ok(()) +// } +// (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => { +// // This sets `transaction_depth` to `None` as soon as we reach zero +// match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) { +// Some(depth) => in_transaction.transaction_depth = depth, +// None => self.in_transaction = None, +// } +// Ok(()) +// } +// (None, TransactionDepthChange::IncreaseDepth) => { +// self.in_transaction = Some(InTransactionStatus { +// transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"), +// top_level_transaction_requires_rollback: false, +// test_transaction: false, +// }); +// Ok(()) +// } +// (None, TransactionDepthChange::DecreaseDepth) => { +// // We screwed up something somewhere +// // we cannot decrease the transaction count if +// // we are not inside a transaction +// Err(Error::NotInTransaction) +// } +// } +// } +// } + +// /// Represents a change to apply to the depth of a transaction +// #[derive(Debug, Clone, Copy)] +// pub enum TransactionDepthChange { +// /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`) +// IncreaseDepth, +// /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`) +// DecreaseDepth, +// } impl AnsiTransactionManager { fn get_transaction_state( @@ -305,40 +311,38 @@ where async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> { let transaction_state = Self::get_transaction_state(conn)?; - let rollback_sql = match transaction_state.in_transaction { - Some(ref mut in_transaction) => { + let ( + (rollback_sql, rolling_back_top_level), + requires_rollback_maybe_up_to_top_level_before_execute, + ) = match transaction_state.in_transaction { + Some(ref in_transaction) => ( match in_transaction.transaction_depth.get() { - 1 => Cow::Borrowed("ROLLBACK"), - depth_gt1 => { - if in_transaction.top_level_transaction_requires_rollback { - // There's no point in *actually* rolling back this one - // because we won't be able to do anything until top-level - // is rolled back. 
- - // To make it easier on the user (that they don't have to really look - // at actual transaction depth and can just rely on the number of - // times they have called begin/commit/rollback) we don't mark the - // transaction manager as out of the savepoints as soon as we - // realize there is that issue, but instead we still decrement here: - in_transaction.transaction_depth = NonZeroU32::new(depth_gt1 - 1) - .expect("Depth was checked to be > 1"); - return Ok(()); - } else { - Cow::Owned(format!( - "ROLLBACK TO SAVEPOINT diesel_savepoint_{}", - depth_gt1 - 1 - )) - } - } - } - } + 1 => (Cow::Borrowed("ROLLBACK"), true), + depth_gt1 => ( + Cow::Owned(format!( + "ROLLBACK TO SAVEPOINT diesel_savepoint_{}", + depth_gt1 - 1 + )), + false, + ), + }, + in_transaction.requires_rollback_maybe_up_to_top_level, + ), None => return Err(Error::NotInTransaction), }; match conn.batch_execute(&rollback_sql).await { Ok(()) => { - Self::get_transaction_state(conn)? - .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?; + match Self::get_transaction_state(conn)? + .change_transaction_depth(TransactionDepthChange::DecreaseDepth) + { + Ok(()) => {} + Err(Error::NotInTransaction) if rolling_back_top_level => { + // Transaction exit may have already been detected by connection + // implementation. It's fine. + } + Err(e) => return Err(e), + } Ok(()) } Err(rollback_error) => { @@ -348,17 +352,35 @@ where in_transaction: Some(InTransactionStatus { transaction_depth, - top_level_transaction_requires_rollback, + requires_rollback_maybe_up_to_top_level, .. }), - }) if transaction_depth.get() > 1 - && !*top_level_transaction_requires_rollback => - { + .. + }) if transaction_depth.get() > 1 => { // A savepoint failed to rollback - we may still attempt to repair - // the connection by rolling back top-level transaction. + // the connection by rolling back higher levels. + + // To make it easier on the user (that they don't have to really + // look at actual transaction depth and can just rely on the number + // of times they have called begin/commit/rollback) we still + // decrement here: *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1) .expect("Depth was checked to be > 1"); - *top_level_transaction_requires_rollback = true; + *requires_rollback_maybe_up_to_top_level = true; + if requires_rollback_maybe_up_to_top_level_before_execute { + // In that case, we tolerate that savepoint releases fail + // -> we should ignore errors + return Ok(()); + } + } + TransactionManagerStatus::Valid(ValidTransactionManagerStatus { + in_transaction: None, + .. 
+ }) => { + // we would have returned `NotInTransaction` if that was already the state + // before we made our call + // => Transaction manager status has been fixed by the underlying connection + // so we don't need to set_in_error } _ => tm_status.set_in_error(), } @@ -375,53 +397,51 @@ where async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> { let transaction_state = Self::get_transaction_state(conn)?; let transaction_depth = transaction_state.transaction_depth(); - let commit_sql = match transaction_depth { + let (commit_sql, committing_top_level) = match transaction_depth { None => return Err(Error::NotInTransaction), - Some(transaction_depth) if transaction_depth.get() == 1 => Cow::Borrowed("COMMIT"), - Some(transaction_depth) => Cow::Owned(format!( - "RELEASE SAVEPOINT diesel_savepoint_{}", - transaction_depth.get() - 1 - )), + Some(transaction_depth) if transaction_depth.get() == 1 => { + (Cow::Borrowed("COMMIT"), true) + } + Some(transaction_depth) => ( + Cow::Owned(format!( + "RELEASE SAVEPOINT diesel_savepoint_{}", + transaction_depth.get() - 1 + )), + false, + ), }; match conn.batch_execute(&commit_sql).await { Ok(()) => { - Self::get_transaction_state(conn)? - .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?; + match Self::get_transaction_state(conn)? + .change_transaction_depth(TransactionDepthChange::DecreaseDepth) + { + Ok(()) => {} + Err(Error::NotInTransaction) if committing_top_level => { + // Transaction exit may have already been detected by connection. + // It's fine + } + Err(e) => return Err(e), + } Ok(()) } Err(commit_error) => { if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { in_transaction: Some(InTransactionStatus { - ref mut transaction_depth, - top_level_transaction_requires_rollback: true, + requires_rollback_maybe_up_to_top_level: true, .. }), + .. }) = conn.transaction_state().status { - match transaction_depth.get() { - 1 => match Self::rollback_transaction(conn).await { - Ok(()) => {} - Err(rollback_error) => { - conn.transaction_state().status.set_in_error(); - return Err(Error::RollbackErrorOnCommit { - rollback_error: Box::new(rollback_error), - commit_error: Box::new(commit_error), - }); - } - }, - depth_gt1 => { - // There's no point in *actually* rolling back this one - // because we won't be able to do anything until top-level - // is rolled back. 
- - // To make it easier on the user (that they don't have to really look - // at actual transaction depth and can just rely on the number of - // times they have called begin/commit/rollback) we don't mark the - // transaction manager as out of the savepoints as soon as we - // realize there is that issue, but instead we still decrement here: - *transaction_depth = NonZeroU32::new(depth_gt1 - 1) - .expect("Depth was checked to be > 1"); + match Self::rollback_transaction(conn).await { + Ok(()) => {} + Err(rollback_error) => { + conn.transaction_state().status.set_in_error(); + return Err(Error::RollbackErrorOnCommit { + rollback_error: Box::new(rollback_error), + commit_error: Box::new(commit_error), + }); } } } diff --git a/tests/lib.rs b/tests/lib.rs index 7c9bce8..27dfde1 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,5 +1,5 @@ use diesel::prelude::{ExpressionMethods, OptionalExtension, QueryDsl}; -use diesel::{sql_function, QueryResult}; +use diesel::QueryResult; use diesel_async::*; use scoped_futures::ScopedFutureExt; use std::fmt::Debug; @@ -9,6 +9,8 @@ use std::pin::Pin; mod custom_types; #[cfg(any(feature = "bb8", feature = "deadpool", feature = "mobc"))] mod pooling; +#[cfg(feature = "async-connection-wrapper")] +mod sync_wrapper; mod type_check; async fn transaction_test(conn: &mut TestConnection) -> QueryResult<()> { @@ -121,7 +123,7 @@ async fn setup(connection: &mut TestConnection) { } #[cfg(feature = "postgres")] -sql_function!(fn pg_sleep(interval: diesel::sql_types::Double)); +diesel::sql_function!(fn pg_sleep(interval: diesel::sql_types::Double)); #[cfg(feature = "postgres")] #[tokio::test] diff --git a/tests/sync_wrapper.rs b/tests/sync_wrapper.rs new file mode 100644 index 0000000..024afe9 --- /dev/null +++ b/tests/sync_wrapper.rs @@ -0,0 +1,26 @@ +use diesel::prelude::*; +use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; + +#[test] +fn test_sync_wrapper() { + let db_url = std::env::var("DATABASE_URL").unwrap(); + let mut conn = AsyncConnectionWrapper::::establish(&db_url).unwrap(); + + let res = + diesel::select(1.into_sql::()).get_result::(&mut conn); + assert_eq!(Ok(1), res); +} + +#[tokio::test] +async fn test_sync_wrapper_under_runtime() { + let db_url = std::env::var("DATABASE_URL").unwrap(); + tokio::task::spawn_blocking(move || { + let mut conn = AsyncConnectionWrapper::::establish(&db_url).unwrap(); + + let res = + diesel::select(1.into_sql::()).get_result::(&mut conn); + assert_eq!(Ok(1), res); + }) + .await + .unwrap(); +}