diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 654874d..a2c8fbb 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -102,6 +102,7 @@ pub struct AsyncPgConnection { stmt_cache: Arc>>, transaction_state: Arc>, metadata_cache: Arc>, + error_receiver: tokio::sync::oneshot::Receiver, } #[async_trait::async_trait] @@ -124,12 +125,19 @@ impl AsyncConnection for AsyncPgConnection { let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls) .await .map_err(ErrorHelper)?; - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {e}"); + // If there is a connection error, we capture it in this channel and make when + // the user next calls one of the functions on the connection in this trait, we + // return the error instead of the inner result. + let (sender, receiver) = tokio::sync::oneshot::channel(); + tokio::spawn(async { + if let Err(connection_error) = connection.await { + let connection_error = diesel::result::Error::from(ErrorHelper(connection_error)); + if let Err(send_error) = sender.send(connection_error) { + eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error); + } } }); - Self::try_from(client).await + Self::try_from_with_error_receiver(client, receiver).await } fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> @@ -138,15 +146,19 @@ impl AsyncConnection for AsyncPgConnection { T::Query: QueryFragment + QueryId + 'query, { let query = source.as_query(); - self.with_prepared_statement(query, |conn, stmt, binds| async move { + let f = self.with_prepared_statement(query, |conn, stmt, binds| async move { let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?; Ok(res .map_err(|e| diesel::result::Error::from(ErrorHelper(e))) .map_ok(PgRow::new) .boxed()) - }) - .boxed() + }); + + match self.error_receiver.try_recv() { + Ok(e) => Box::pin(async move { Err(e) }), + Err(_) => f, + } } fn execute_returning_count<'conn, 'query, T>( @@ -156,7 +168,7 @@ impl AsyncConnection for AsyncPgConnection { where T: QueryFragment + QueryId + 'query, { - self.with_prepared_statement(source, |conn, stmt, binds| async move { + let f = self.with_prepared_statement(source, |conn, stmt, binds| async move { let binds = binds .iter() .map(|b| b as &(dyn ToSql + Sync)) @@ -166,8 +178,12 @@ impl AsyncConnection for AsyncPgConnection { .await .map_err(ErrorHelper)?; Ok(res as usize) - }) - .boxed() + }); + + match self.error_receiver.try_recv() { + Ok(e) => Box::pin(async move { Err(e) }), + Err(_) => f, + } } fn transaction_state(&mut self) -> &mut AnsiTransactionManager { @@ -270,11 +286,24 @@ impl AsyncPgConnection { /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`] pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult { + // We create a dummy receiver here. If the user is calling this, they have + // created their own client and connection and are handling any error in + // the latter themselves. + Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await + } + + /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`] + /// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection. + async fn try_from_with_error_receiver( + conn: tokio_postgres::Client, + error_receiver: tokio::sync::oneshot::Receiver, + ) -> ConnectionResult { let mut conn = Self { conn: Arc::new(conn), stmt_cache: Arc::new(Mutex::new(StmtCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), + error_receiver, }; conn.set_config_options() .await @@ -340,7 +369,7 @@ impl AsyncPgConnection { let metadata_cache = self.metadata_cache.clone(); let tm = self.transaction_state.clone(); - async move { + let f = async move { let sql = sql?; let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; collect_bind_result?; @@ -411,8 +440,12 @@ impl AsyncPgConnection { let res = callback(raw_connection, stmt.clone(), binds).await; let mut tm = tm.lock().await; update_transaction_manager_status(res, &mut tm) + }; + + match self.error_receiver.try_recv() { + Ok(e) => Box::pin(async move { Err(e) }), + Err(_) => f.boxed(), } - .boxed() } }