diff --git a/examples/bench.rs b/examples/bench.rs index 2765d9e..3f3d25a 100644 --- a/examples/bench.rs +++ b/examples/bench.rs @@ -42,21 +42,19 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); let server_addr = "127.0.0.1:5433"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); tokio::spawn(async move { process_socket( incoming_socket.0, None, - authenticator_ref, + NoopStartupHandler, processor_ref, placeholder_ref, ) diff --git a/examples/gluesql.rs b/examples/gluesql.rs index 0bbdcf2..f43dedf 100644 --- a/examples/gluesql.rs +++ b/examples/gluesql.rs @@ -182,21 +182,19 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); tokio::spawn(async move { process_socket( incoming_socket.0, None, - authenticator_ref, + NoopStartupHandler, processor_ref, placeholder_ref, ) diff --git a/examples/scram.rs b/examples/scram.rs index 01438b8..fcacd79 100644 --- a/examples/scram.rs +++ b/examples/scram.rs @@ -81,15 +81,16 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); - let mut authenticator = MakeSASLScramAuthStartupHandler::new( + let mut authenticator_maker = MakeSASLScramAuthStartupHandler::new( Arc::new(DummyAuthDB), Arc::new(DefaultServerParameterProvider), ); - authenticator.set_iterations(ITERATIONS); + authenticator_maker.set_iterations(ITERATIONS); let cert = fs::read("examples/ssl/server.crt").unwrap(); - authenticator.configure_certificate(cert.as_ref()).unwrap(); - let authenticator = Arc::new(authenticator); + authenticator_maker + .configure_certificate(cert.as_ref()) + .unwrap(); let server_addr = "127.0.0.1:5432"; let tls_acceptor = Arc::new(setup_tls().unwrap()); @@ -98,14 +99,14 @@ pub async fn main() { loop { let incoming_socket = listener.accept().await.unwrap(); let tls_acceptor_ref = tls_acceptor.clone(); - let authenticator_ref = authenticator.make(); + let authenticator = authenticator_maker.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); tokio::spawn(async move { process_socket( incoming_socket.0, Some(tls_acceptor_ref), - authenticator_ref, + authenticator, processor_ref, placeholder_ref, ) diff --git a/examples/secure_server.rs b/examples/secure_server.rs index 22bd60b..da66714 100644 --- a/examples/secure_server.rs +++ b/examples/secure_server.rs @@ -81,7 +81,6 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); let server_addr = "127.0.0.1:5433"; let tls_acceptor = Arc::new(setup_tls().unwrap()); @@ -91,14 +90,13 @@ pub async fn main() { loop { let incoming_socket = listener.accept().await.unwrap(); let tls_acceptor_ref = tls_acceptor.clone(); - let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); tokio::spawn(async move { process_socket( incoming_socket.0, Some(tls_acceptor_ref), - authenticator_ref, + NoopStartupHandler, processor_ref, placeholder_ref, ) diff --git a/examples/server.rs b/examples/server.rs index 07940a5..0acf184 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -59,21 +59,19 @@ pub async fn main() { let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( PlaceholderExtendedQueryHandler, ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); let server_addr = "127.0.0.1:5432"; let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); let processor_ref = processor.make(); let placeholder_ref = placeholder.make(); tokio::spawn(async move { process_socket( incoming_socket.0, None, - authenticator_ref, + NoopStartupHandler, processor_ref, placeholder_ref, ) diff --git a/examples/sqlite.rs b/examples/sqlite.rs index 5d3daf4..68f7e6d 100644 --- a/examples/sqlite.rs +++ b/examples/sqlite.rs @@ -323,7 +323,7 @@ impl MakeHandler for MakeSqliteBackend { #[tokio::main] pub async fn main() { - let authenticator = Arc::new(MakeMd5PasswordAuthStartupHandler::new( + let authenticator_maker = Arc::new(MakeMd5PasswordAuthStartupHandler::new( Arc::new(DummyAuthSource), Arc::new(SqliteParameters::new()), )); @@ -334,13 +334,13 @@ pub async fn main() { println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); + let mut authenticator = authenticator_maker.make(); let processor_ref = processor.make(); tokio::spawn(async move { process_socket( incoming_socket.0, None, - authenticator_ref, + authenticator, processor_ref.clone(), processor_ref, ) diff --git a/src/api/auth/cleartext.rs b/src/api/auth/cleartext.rs index 593f944..ce1b84b 100644 --- a/src/api/auth/cleartext.rs +++ b/src/api/auth/cleartext.rs @@ -23,7 +23,7 @@ impl StartupHandler for CleartextPasswordAuthStartupHandler { async fn on_startup( - &self, + &mut self, client: &mut C, message: PgWireFrontendMessage, ) -> PgWireResult<()> diff --git a/src/api/auth/md5pass.rs b/src/api/auth/md5pass.rs index 712b602..404ebd8 100644 --- a/src/api/auth/md5pass.rs +++ b/src/api/auth/md5pass.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use async_trait::async_trait; use futures::sink::{Sink, SinkExt}; -use tokio::sync::Mutex; use super::{ AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider, @@ -18,7 +17,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; pub struct Md5PasswordAuthStartupHandler { auth_source: Arc, parameter_provider: Arc

, - cached_password: Mutex>, + cached_password: Vec, } #[async_trait] @@ -26,7 +25,7 @@ impl StartupHandler for Md5PasswordAuthStartupHandler { async fn on_startup( - &self, + &mut self, client: &mut C, message: PgWireFrontendMessage, ) -> PgWireResult<()> @@ -48,7 +47,7 @@ impl StartupHandler .as_ref() .expect("Salt is required for Md5Password authentication"); - *self.cached_password.lock().await = salt_and_pass.password().clone(); + self.cached_password = salt_and_pass.password().clone(); client .send(PgWireBackendMessage::Authentication( @@ -57,10 +56,12 @@ impl StartupHandler .await?; } PgWireFrontendMessage::PasswordMessageFamily(pwd) => { - let pwd = pwd.into_password()?; - let cached_pass = self.cached_password.lock().await; + let pass_match = { + let pwd = pwd.into_password()?; + pwd.password().as_bytes() == self.cached_password + }; - if pwd.password().as_bytes() == *cached_pass { + if pass_match { super::finish_authentication(client, self.parameter_provider.as_ref()).await } else { let error_info = ErrorInfo::new( @@ -104,15 +105,19 @@ pub struct MakeMd5PasswordAuthStartupHandler { parameter_provider: Arc

, } -impl MakeHandler for MakeMd5PasswordAuthStartupHandler { - type Handler = Arc>; +impl MakeHandler for MakeMd5PasswordAuthStartupHandler +where + V: AuthSource, + P: ServerParameterProvider, +{ + type Handler = Md5PasswordAuthStartupHandler; fn make(&self) -> Self::Handler { - Arc::new(Md5PasswordAuthStartupHandler { + Md5PasswordAuthStartupHandler { auth_source: self.auth_source.clone(), parameter_provider: self.parameter_provider.clone(), - cached_password: Mutex::new(vec![]), - }) + cached_password: vec![], + } } } diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index c422535..d4faacf 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -17,7 +17,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; pub trait StartupHandler: Send + Sync { /// A generic frontend message callback during startup phase. async fn on_startup( - &self, + &mut self, client: &mut C, message: PgWireFrontendMessage, ) -> PgWireResult<()> diff --git a/src/api/auth/noop.rs b/src/api/auth/noop.rs index 5a2aa5a..6c924b3 100644 --- a/src/api/auth/noop.rs +++ b/src/api/auth/noop.rs @@ -12,7 +12,7 @@ pub struct NoopStartupHandler; #[async_trait] impl StartupHandler for NoopStartupHandler { async fn on_startup( - &self, + &mut self, client: &mut C, message: PgWireFrontendMessage, ) -> PgWireResult<()> diff --git a/src/api/auth/scram.rs b/src/api/auth/scram.rs index 93e5e90..4c3c5f2 100644 --- a/src/api/auth/scram.rs +++ b/src/api/auth/scram.rs @@ -12,7 +12,6 @@ use futures::{Sink, SinkExt}; use ring::digest; use ring::hmac; use ring::pbkdf2; -use tokio::sync::Mutex; use x509_certificate::certificate::CapturedX509Certificate; use x509_certificate::SignatureAlgorithm; @@ -36,7 +35,7 @@ pub struct SASLScramAuthStartupHandler { auth_db: Arc, parameter_provider: Arc

, /// state of the client-server communication - state: Mutex, + state: ScramState, /// base64 encoded certificate signature for tls-server-end-point channel binding server_cert_sig: Option>, /// iterations @@ -91,7 +90,7 @@ impl StartupHandler for SASLScramAuthStartupHandler { async fn on_startup( - &self, + &mut self, client: &mut C, message: PgWireFrontendMessage, ) -> PgWireResult<()> @@ -117,8 +116,7 @@ impl StartupHandler } PgWireFrontendMessage::PasswordMessageFamily(msg) => { let salt_and_salted_pass = { - let state = self.state.lock().await; - match *state { + match self.state { ScramState::Initial => { let login_info = LoginInfo::from_client_info(client); self.auth_db.get_password(&login_info).await? @@ -129,9 +127,7 @@ impl StartupHandler let mut success = false; let resp = { - // this should never block - let mut state = self.state.lock().await; - match *state { + match self.state { ScramState::Initial => { // initial response, client_first let resp = msg.into_sasl_initial_response()?; @@ -165,7 +161,7 @@ impl StartupHandler ); let server_first_message = server_first.message(); - *state = ScramState::ServerFirstSent( + self.state = ScramState::ServerFirstSent( salt_and_salted_pass, client_first.channel_binding(), format!("{},{}", client_first.bare(), &server_first_message), @@ -266,17 +262,21 @@ impl MakeSASLScramAuthStartupHandler { } } -impl MakeHandler for MakeSASLScramAuthStartupHandler { - type Handler = Arc>; +impl MakeHandler for MakeSASLScramAuthStartupHandler +where + A: AuthSource, + P: ServerParameterProvider, +{ + type Handler = SASLScramAuthStartupHandler; fn make(&self) -> Self::Handler { - Arc::new(SASLScramAuthStartupHandler { + SASLScramAuthStartupHandler { auth_db: self.auth_db.clone(), parameter_provider: self.parameter_provider.clone(), - state: Mutex::new(ScramState::Initial), + state: ScramState::Initial, server_cert_sig: self.server_cert_sig.clone(), iterations: self.iterations, - }) + } } } diff --git a/src/tokio.rs b/src/tokio.rs index 1b7b9f6..c09f791 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -84,7 +84,7 @@ impl ClientInfo for Framed { async fn process_message( message: PgWireFrontendMessage, socket: &mut Framed, - authenticator: Arc, + authenticator: &mut A, query_handler: Arc, extended_query_handler: Arc, ) -> PgWireResult<()> @@ -216,7 +216,7 @@ async fn peek_for_sslrequest( pub async fn process_socket( mut tcp_socket: TcpStream, tls_acceptor: Option>, - startup_handler: Arc, + mut startup_handler: A, query_handler: Arc, extended_query_handler: Arc, ) -> Result<(), IOError> @@ -239,7 +239,7 @@ where if let Err(e) = process_message( msg, &mut socket, - startup_handler.clone(), + &mut startup_handler, query_handler.clone(), extended_query_handler.clone(), ) @@ -255,7 +255,7 @@ where if let Err(e) = process_message( msg, &mut socket, - startup_handler.clone(), + &mut startup_handler, query_handler.clone(), extended_query_handler.clone(), ) diff --git a/tests-integration/test-server/src/main.rs b/tests-integration/test-server/src/main.rs index b402c5e..960ff62 100644 --- a/tests-integration/test-server/src/main.rs +++ b/tests-integration/test-server/src/main.rs @@ -219,13 +219,13 @@ pub async fn main() { println!("Listening to {}", server_addr); loop { let incoming_socket = listener.accept().await.unwrap(); - let authenticator_ref = authenticator.make(); + let authenticator = authenticator.make(); let processor_ref = processor.make(); tokio::spawn(async move { process_socket( incoming_socket.0, None, - authenticator_ref, + authenticator, processor_ref.clone(), processor_ref, )