Skip to content

Commit

Permalink
Introduce an AsyncConnectionWrapper type
Browse files Browse the repository at this point in the history
This type turns a `diesel_async::AsyncConnection` into a
`diesel::Conenction`. I see the following use-cases for this:

* Having a pure rust sync diesel connection implementation for postgres
and mysql can simplify the setup of new diesel projects
* Allowing projects depending on `diesel_async` to use
`diesel_migrations` without depending on `libpq`/`libmysqlclient`

This change requires restructuring the implementation of
`AsyncPgConnection` a bit so that we make the returned future `Send`
independently of whether or not the query parameters are `Send`. This is
possible by serialising the bind parameters before actually constructing
the future.

It also refactors the `TransactionManager` implementation to share more
code with diesel itself.
  • Loading branch information
weiznich committed Aug 31, 2023
1 parent 5ba4375 commit 4954cff
Show file tree
Hide file tree
Showing 17 changed files with 884 additions and 464 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All user visible changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/), as described
for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md)

## Unreleased

* Add a `AsyncConnectionWrapper` type to turn a `diesel_async::AsyncConnection` into a `diesel::Connection`. This might be used to execute migrations via `diesel_migrations`.

## [0.3.2] - 2023-07-24

* Fix `TinyInt` serialization
Expand Down Expand Up @@ -52,5 +56,3 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
[0.2.1]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.2.1
[0.2.2]: https://github.com/weiznich/diesel_async/compare/v0.2.1...v0.2.2
[0.3.0]: https://github.com/weiznich/diesel_async/compare/v0.2.0...v0.3.0
[0.3.1]: https://github.com/weiznich/diesel_async/compare/v0.3.0...v0.3.1
[0.3.2]: https://github.com/weiznich/diesel_async/compare/v0.3.1...v0.3.2
11 changes: 7 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
rust-version = "1.65.0"

[dependencies]
diesel = { version = "~2.1.0", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
diesel = { version = "~2.1.1", default-features = false, features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"]}
async-trait = "0.1.66"
futures-channel = { version = "0.3.17", default-features = false, features = ["std", "sink"], optional = true }
futures-util = { version = "0.3.17", default-features = false, features = ["std", "sink"] }
tokio-postgres = { version = "0.7.2", optional = true}
tokio-postgres = { version = "0.7.10", optional = true}
tokio = { version = "1.26", optional = true}
mysql_async = { version = ">=0.30.0,<0.33", optional = true}
mysql_common = {version = ">=0.29.0,<0.31.0", optional = true}
Expand All @@ -31,12 +31,14 @@ scoped-futures = {version = "0.1", features = ["std"]}
tokio = {version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"]}
cfg-if = "1"
chrono = "0.4"
diesel = { version = "2.0.0", default-features = false, features = ["chrono"]}
diesel = { version = "2.1.0", default-features = false, features = ["chrono"]}

[features]
default = []
mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel"]
mysql = ["diesel/mysql_backend", "mysql_async", "mysql_common", "futures-channel", "tokio"]
postgres = ["diesel/postgres_backend", "tokio-postgres", "tokio", "tokio/rt"]
async-connection-wrapper = []
r2d2 = ["diesel/r2d2"]

[[test]]
name = "integration_tests"
Expand All @@ -54,3 +56,4 @@ members = [
".",
"examples/postgres/pooled-with-rustls"
]

313 changes: 313 additions & 0 deletions src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
//! This module contains an wrapper type
//! that provides a [`diesel::Connection`]
//! implementation for types that implement
//! [`crate::AsyncConnection`]. Using this type
//! might be useful for the following usecases:
//!
//! * Executing migrations on application startup
//! * Using a pure rust diesel connection implementation
//! as replacement for the existing connection
//! implementations provided by diesel

use futures_util::Future;
use futures_util::Stream;
use futures_util::StreamExt;
use std::pin::Pin;

/// This is a helper trait that allows to customize the
/// async runtime used to execute futures as part of the
/// [`AsyncConnectionWrapper`] type. By default a
/// tokio runtime is used.
pub trait BlockOn {
/// This function should allow to execute a
/// given future to get the result
fn block_on<F>(&self, f: F) -> F::Output
where
F: Future;

/// This function should be used to construct
/// a new runtime instance
fn get_runtime() -> Self;
}

/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
/// provide a sync [`diesel::Connection`] implementation.
///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection. This implies you
/// cannot use functions of this type in a scope with an already existing
/// tokio runtime. If you are in a situation where you want to use this
/// connection wrapper in the scope of an existing tokio runtime (for example
/// for running migrations via `diesel_migration`) you need to wrap
/// the relevant code block into a `tokio::task::spawn_blocking` task.
///
/// # Examples
///
/// ```rust
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
/// #
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
/// use diesel::prelude::{RunQueryDsl, Connection};
/// # let database_url = database_url();
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// # assert_eq!(all_users.len(), 0);
/// # Ok(())
/// # }
/// ```
///
/// If you are in the scope of an existing tokio runtime you need to use
/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks
/// ```rust
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
///
/// async fn some_async_fn() {
/// # let database_url = database_url();
/// // need to use `spawn_blocking` to execute
/// // a blocking task in the scope of an existing runtime
/// let res = tokio::task::spawn_blocking(move || {
/// use diesel::prelude::{RunQueryDsl, Connection};
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// # assert_eq!(all_users.len(), 0);
/// Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
/// }).await;
///
/// # res.unwrap().unwrap();
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// # some_async_fn().await;
/// # }
/// ```
#[cfg(feature = "tokio")]
pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
self::implementation::AsyncConnectionWrapper<C, B>;

/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
/// provide a sync [`diesel::Connection`] implementation.
///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection.
#[cfg(not(feature = "tokio"))]
pub use self::implementation::AsyncConnectionWrapper;

mod implementation {
use super::*;

pub struct AsyncConnectionWrapper<C, B> {
inner: C,
runtime: B,
}

impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
where
C: crate::SimpleAsyncConnection,
B: BlockOn,
{
fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
let f = self.inner.batch_execute(query);
self.runtime.block_on(f)
}
}

impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}

impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type Backend = C::Backend;

type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;

fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let runtime = B::get_runtime();
let f = C::establish(database_url);
let inner = runtime.block_on(f)?;
Ok(Self { inner, runtime })
}

fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
where
T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
{
let f = self.inner.execute_returning_count(source);
self.runtime.block_on(f)
}

fn transaction_state(
&mut self,
) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
self.inner.transaction_state()
}
}

impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
where
Self: 'conn;

type Row<'conn, 'query> = C::Row<'conn, 'query>
where
Self: 'conn;

fn load<'conn, 'query, T>(
&'conn mut self,
source: T,
) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
where
T: diesel::query_builder::Query
+ diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ 'query,
Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
{
let f = self.inner.load(source);
let stream = self.runtime.block_on(f)?;

Ok(AsyncCursorWrapper {
stream: Box::pin(stream),
runtime: &self.runtime,
})
}
}

pub struct AsyncCursorWrapper<'a, S, B> {
stream: Pin<Box<S>>,
runtime: &'a B,
}

impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B>
where
S: Stream,
B: BlockOn,
{
type Item = S::Item;

fn next(&mut self) -> Option<Self::Item> {
let f = self.stream.next();
self.runtime.block_on(f)
}
}

pub struct AsyncConnectionWrapperTransactionManagerWrapper;

impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
for AsyncConnectionWrapperTransactionManagerWrapper
where
C: crate::AsyncConnection,
B: BlockOn + Send,
{
type TransactionStateData =
<C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;

fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}

fn rollback_transaction(
conn: &mut AsyncConnectionWrapper<C, B>,
) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}

fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
&mut conn.inner,
);
conn.runtime.block_on(f)
}

fn transaction_manager_status_mut(
conn: &mut AsyncConnectionWrapper<C, B>,
) -> &mut diesel::connection::TransactionManagerStatus {
<C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
&mut conn.inner,
)
}

fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
<C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
&mut conn.inner,
)
}
}

#[cfg(feature = "r2d2")]
impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
where
B: BlockOn,
Self: diesel::Connection,
C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
+ crate::pooled_connection::PoolableConnection,
{
fn ping(&mut self) -> diesel::QueryResult<()> {
diesel::Connection::execute_returning_count(self, &C::make_ping_query()).map(|_| ())
}

fn is_broken(&mut self) -> bool {
<C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
&mut self.inner,
)
}
}

#[cfg(feature = "tokio")]
pub struct Tokio {
handle: Option<tokio::runtime::Handle>,
runtime: Option<tokio::runtime::Runtime>,
}

#[cfg(feature = "tokio")]
impl BlockOn for Tokio {
fn block_on<F>(&self, f: F) -> F::Output
where
F: Future,
{
if let Some(handle) = &self.handle {
handle.block_on(f)
} else if let Some(runtime) = &self.runtime {
runtime.block_on(f)
} else {
unreachable!()
}
}

fn get_runtime() -> Self {
if let Ok(handle) = tokio::runtime::Handle::try_current() {
Self {
handle: Some(handle),
runtime: None,
}
} else {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.unwrap();
Self {
handle: None,
runtime: Some(runtime),
}
}
}
}
}
Loading

0 comments on commit 4954cff

Please sign in to comment.