Skip to content

Commit

Permalink
Expose underlying connection errors to user of AsyncPgConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
banool committed Oct 5, 2023
1 parent 1e18b37 commit 4438e23
Showing 1 changed file with 45 additions and 12 deletions.
57 changes: 45 additions & 12 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pub struct AsyncPgConnection {
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
}

#[async_trait::async_trait]
Expand All @@ -124,12 +125,19 @@ impl AsyncConnection for AsyncPgConnection {
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
.await
.map_err(ErrorHelper)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {e}");
// If there is a connection error, we capture it in this channel and make when
// the user next calls one of the functions on the connection in this trait, we
// return the error instead of the inner result.
let (sender, receiver) = tokio::sync::oneshot::channel();
tokio::spawn(async {
if let Err(connection_error) = connection.await {
let connection_error = diesel::result::Error::from(ErrorHelper(connection_error));
if let Err(send_error) = sender.send(connection_error) {
eprintln!("Failed to send connection error through channel, connection must have been dropped: {}", send_error);
}
}
});
Self::try_from(client).await
Self::try_from_with_error_receiver(client, receiver).await
}

fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
Expand All @@ -138,15 +146,19 @@ impl AsyncConnection for AsyncPgConnection {
T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
{
let query = source.as_query();
self.with_prepared_statement(query, |conn, stmt, binds| async move {
let f = self.with_prepared_statement(query, |conn, stmt, binds| async move {
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

Ok(res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new)
.boxed())
})
.boxed()
});

match self.error_receiver.try_recv() {
Ok(e) => Box::pin(async move { Err(e) }),
Err(_) => f,
}
}

fn execute_returning_count<'conn, 'query, T>(
Expand All @@ -156,7 +168,7 @@ impl AsyncConnection for AsyncPgConnection {
where
T: QueryFragment<Self::Backend> + QueryId + 'query,
{
self.with_prepared_statement(source, |conn, stmt, binds| async move {
let f = self.with_prepared_statement(source, |conn, stmt, binds| async move {
let binds = binds
.iter()
.map(|b| b as &(dyn ToSql + Sync))
Expand All @@ -166,8 +178,12 @@ impl AsyncConnection for AsyncPgConnection {
.await
.map_err(ErrorHelper)?;
Ok(res as usize)
})
.boxed()
});

match self.error_receiver.try_recv() {
Ok(e) => Box::pin(async move { Err(e) }),
Err(_) => f,
}
}

fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
Expand Down Expand Up @@ -270,11 +286,24 @@ impl AsyncPgConnection {

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
// We create a dummy receiver here. If the user is calling this, they have
// created their own client and connection and are handling any error in
// the latter themselves.
Self::try_from_with_error_receiver(conn, tokio::sync::oneshot::channel().1).await
}

/// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
/// and a [`tokio::sync::oneshot::Receiver`] for receiving an error from the connection.
async fn try_from_with_error_receiver(
conn: tokio_postgres::Client,
error_receiver: tokio::sync::oneshot::Receiver<diesel::result::Error>,
) -> ConnectionResult<Self> {
let mut conn = Self {
conn: Arc::new(conn),
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
error_receiver,
};
conn.set_config_options()
.await
Expand Down Expand Up @@ -340,7 +369,7 @@ impl AsyncPgConnection {
let metadata_cache = self.metadata_cache.clone();
let tm = self.transaction_state.clone();

async move {
let f = async move {
let sql = sql?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
collect_bind_result?;
Expand Down Expand Up @@ -411,8 +440,12 @@ impl AsyncPgConnection {
let res = callback(raw_connection, stmt.clone(), binds).await;
let mut tm = tm.lock().await;
update_transaction_manager_status(res, &mut tm)
};

match self.error_receiver.try_recv() {
Ok(e) => Box::pin(async move { Err(e) }),
Err(_) => f.boxed(),
}
.boxed()
}
}

Expand Down

0 comments on commit 4438e23

Please sign in to comment.