Skip to content

Commit

Permalink
Expose underlying connection errors to user of AsyncPgConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
banool committed Oct 6, 2023
1 parent 1e18b37 commit 97520ec
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use futures_util::future::BoxFuture;
use futures_util::stream::{BoxStream, TryStreamExt};
use futures_util::{Future, FutureExt, StreamExt};
use std::borrow::Cow;
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_postgres::types::ToSql;
Expand Down Expand Up @@ -102,6 +103,7 @@ pub struct AsyncPgConnection {
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
error_receiver: Arc<Mutex<tokio::sync::oneshot::Receiver<diesel::result::Error>>>,
}

#[async_trait::async_trait]
Expand All @@ -124,12 +126,19 @@ impl AsyncConnection for AsyncPgConnection {
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
.await
.map_err(ErrorHelper)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {e}");
// If there is a connection error, we capture it in this channel and make when
// the user next calls one of the functions on the connection in this trait, we
// return the error instead of the inner result.
let (sender, receiver) = tokio::sync::oneshot::channel();
tokio::spawn(async {
if let Err(connection_error) = connection.await {
let connection_error = diesel::result::Error::from(ErrorHelper(connection_error));
if let Err(send_error) = sender.send(connection_error) {
eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error);
}
}
});
Self::try_from(client).await
Self::try_from_with_error_receiver(client, receiver).await
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
Expand Down Expand Up @@ -270,11 +279,24 @@ impl AsyncPgConnection {

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
// We create a dummy receiver here. If the user is calling this, they have
// created their own client and connection and are handling any error in
// the latter themselves.
Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await
}

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
/// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
async fn try_from_with_error_receiver(
conn: tokio_postgres::Client,
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
) -> ConnectionResult<Self> {
let mut conn = Self {
conn: Arc::new(conn),
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
error_receiver: Arc::new(Mutex::new(error_receiver)),
};
conn.set_config_options()
.await
Expand Down Expand Up @@ -340,7 +362,7 @@ impl AsyncPgConnection {
let metadata_cache = self.metadata_cache.clone();
let tm = self.transaction_state.clone();

async move {
let f = async move {
let sql = sql?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
collect_bind_result?;
Expand Down Expand Up @@ -411,6 +433,18 @@ impl AsyncPgConnection {
let res = callback(raw_connection, stmt.clone(), binds).await;
let mut tm = tm.lock().await;
update_transaction_manager_status(res, &mut tm)
};

let er = self.error_receiver.clone();
async move {
let mut error_receiver = er.lock().await;
// While the future (f) is running, at any await point we will check if
// there is an error in the channel from the connection. If there is, we
// will return that instead and f will get aborted.
tokio::select! {
error = error_receiver.deref_mut() => Err(error.unwrap()),
res = f => res,
}
}
.boxed()
}
Expand Down

0 comments on commit 97520ec

Please sign in to comment.