diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 654874d..af1670d 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -21,6 +21,7 @@ use futures_util::future::BoxFuture; use futures_util::stream::{BoxStream, TryStreamExt}; use futures_util::{Future, FutureExt, StreamExt}; use std::borrow::Cow; +use std::ops::DerefMut; use std::sync::Arc; use tokio::sync::Mutex; use tokio_postgres::types::ToSql; @@ -102,6 +103,7 @@ pub struct AsyncPgConnection { stmt_cache: Arc>>, transaction_state: Arc>, metadata_cache: Arc>, + error_receiver: Arc>>, } #[async_trait::async_trait] @@ -124,12 +126,19 @@ impl AsyncConnection for AsyncPgConnection { let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls) .await .map_err(ErrorHelper)?; - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {e}"); + // If there is a connection error, we capture it in this channel and make when + // the user next calls one of the functions on the connection in this trait, we + // return the error instead of the inner result. + let (sender, receiver) = tokio::sync::oneshot::channel(); + tokio::spawn(async { + if let Err(connection_error) = connection.await { + let connection_error = diesel::result::Error::from(ErrorHelper(connection_error)); + if let Err(send_error) = sender.send(connection_error) { + eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error); + } } }); - Self::try_from(client).await + Self::try_from_with_error_receiver(client, receiver).await } fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> @@ -270,11 +279,24 @@ impl AsyncPgConnection { /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`] pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult { + // We create a dummy receiver here. If the user is calling this, they have + // created their own client and connection and are handling any error in + // the latter themselves. + Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await + } + + /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`] + /// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection. + async fn try_from_with_error_receiver( + conn: tokio_postgres::Client, + error_receiver: tokio::sync::oneshot::Receiver, + ) -> ConnectionResult { let mut conn = Self { conn: Arc::new(conn), stmt_cache: Arc::new(Mutex::new(StmtCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), + error_receiver: Arc::new(Mutex::new(error_receiver)), }; conn.set_config_options() .await @@ -340,7 +362,7 @@ impl AsyncPgConnection { let metadata_cache = self.metadata_cache.clone(); let tm = self.transaction_state.clone(); - async move { + let f = async move { let sql = sql?; let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; collect_bind_result?; @@ -411,6 +433,18 @@ impl AsyncPgConnection { let res = callback(raw_connection, stmt.clone(), binds).await; let mut tm = tm.lock().await; update_transaction_manager_status(res, &mut tm) + }; + + let er = self.error_receiver.clone(); + async move { + let mut error_receiver = er.lock().await; + // While the future (f) is running, at any await point we will check if + // there is an error in the channel from the connection. If there is, we + // will return that instead and f will get aborted. + tokio::select! { + error = error_receiver.deref_mut() => Err(error.unwrap()), + res = f => res, + } } .boxed() }