Skip to content

Commit

Permalink
refactor: remove unneeded arc for startup handler
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Jul 4, 2023
1 parent 8c353e7 commit 46fb90a
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 56 deletions.
4 changes: 1 addition & 3 deletions examples/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,19 @@ pub async fn main() {
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));

let server_addr = "127.0.0.1:5433";
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let authenticator_ref = authenticator.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
None,
authenticator_ref,
NoopStartupHandler,
processor_ref,
placeholder_ref,
)
Expand Down
4 changes: 1 addition & 3 deletions examples/gluesql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,19 @@ pub async fn main() {
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));

let server_addr = "127.0.0.1:5432";
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let authenticator_ref = authenticator.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
None,
authenticator_ref,
NoopStartupHandler,
processor_ref,
placeholder_ref,
)
Expand Down
13 changes: 7 additions & 6 deletions examples/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,16 @@ pub async fn main() {
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let mut authenticator = MakeSASLScramAuthStartupHandler::new(
let mut authenticator_maker = MakeSASLScramAuthStartupHandler::new(
Arc::new(DummyAuthDB),
Arc::new(DefaultServerParameterProvider),
);
authenticator.set_iterations(ITERATIONS);
authenticator_maker.set_iterations(ITERATIONS);

let cert = fs::read("examples/ssl/server.crt").unwrap();
authenticator.configure_certificate(cert.as_ref()).unwrap();
let authenticator = Arc::new(authenticator);
authenticator_maker
.configure_certificate(cert.as_ref())
.unwrap();

let server_addr = "127.0.0.1:5432";
let tls_acceptor = Arc::new(setup_tls().unwrap());
Expand All @@ -98,14 +99,14 @@ pub async fn main() {
loop {
let incoming_socket = listener.accept().await.unwrap();
let tls_acceptor_ref = tls_acceptor.clone();
let authenticator_ref = authenticator.make();
let authenticator = authenticator_maker.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
Some(tls_acceptor_ref),
authenticator_ref,
authenticator,
processor_ref,
placeholder_ref,
)
Expand Down
4 changes: 1 addition & 3 deletions examples/secure_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ pub async fn main() {
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));

let server_addr = "127.0.0.1:5433";
let tls_acceptor = Arc::new(setup_tls().unwrap());
Expand All @@ -91,14 +90,13 @@ pub async fn main() {
loop {
let incoming_socket = listener.accept().await.unwrap();
let tls_acceptor_ref = tls_acceptor.clone();
let authenticator_ref = authenticator.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
Some(tls_acceptor_ref),
authenticator_ref,
NoopStartupHandler,
processor_ref,
placeholder_ref,
)
Expand Down
4 changes: 1 addition & 3 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,19 @@ pub async fn main() {
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));

let server_addr = "127.0.0.1:5432";
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let authenticator_ref = authenticator.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
None,
authenticator_ref,
NoopStartupHandler,
processor_ref,
placeholder_ref,
)
Expand Down
6 changes: 3 additions & 3 deletions examples/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl MakeHandler for MakeSqliteBackend {

#[tokio::main]
pub async fn main() {
let authenticator = Arc::new(MakeMd5PasswordAuthStartupHandler::new(
let authenticator_maker = Arc::new(MakeMd5PasswordAuthStartupHandler::new(
Arc::new(DummyAuthSource),
Arc::new(SqliteParameters::new()),
));
Expand All @@ -334,13 +334,13 @@ pub async fn main() {
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let authenticator_ref = authenticator.make();
let mut authenticator = authenticator_maker.make();
let processor_ref = processor.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
None,
authenticator_ref,
authenticator,
processor_ref.clone(),
processor_ref,
)
Expand Down
2 changes: 1 addition & 1 deletion src/api/auth/cleartext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<V: AuthSource, P: ServerParameterProvider> StartupHandler
for CleartextPasswordAuthStartupHandler<V, P>
{
async fn on_startup<C>(
&self,
&mut self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
Expand Down
29 changes: 17 additions & 12 deletions src/api/auth/md5pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::Arc;

use async_trait::async_trait;
use futures::sink::{Sink, SinkExt};
use tokio::sync::Mutex;

use super::{
AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider,
Expand All @@ -18,15 +17,15 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
pub struct Md5PasswordAuthStartupHandler<A, P> {
auth_source: Arc<A>,
parameter_provider: Arc<P>,
cached_password: Mutex<Vec<u8>>,
cached_password: Vec<u8>,
}

#[async_trait]
impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
for Md5PasswordAuthStartupHandler<A, P>
{
async fn on_startup<C>(
&self,
&mut self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
Expand All @@ -48,7 +47,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
.as_ref()
.expect("Salt is required for Md5Password authentication");

*self.cached_password.lock().await = salt_and_pass.password().clone();
self.cached_password = salt_and_pass.password().clone();

client
.send(PgWireBackendMessage::Authentication(
Expand All @@ -57,10 +56,12 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
.await?;
}
PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
let pwd = pwd.into_password()?;
let cached_pass = self.cached_password.lock().await;
let pass_match = {
let pwd = pwd.into_password()?;
pwd.password().as_bytes() == self.cached_password
};

if pwd.password().as_bytes() == *cached_pass {
if pass_match {
super::finish_authentication(client, self.parameter_provider.as_ref()).await
} else {
let error_info = ErrorInfo::new(
Expand Down Expand Up @@ -104,15 +105,19 @@ pub struct MakeMd5PasswordAuthStartupHandler<A, P> {
parameter_provider: Arc<P>,
}

impl<V, P> MakeHandler for MakeMd5PasswordAuthStartupHandler<V, P> {
type Handler = Arc<Md5PasswordAuthStartupHandler<V, P>>;
impl<V, P> MakeHandler for MakeMd5PasswordAuthStartupHandler<V, P>
where
V: AuthSource,
P: ServerParameterProvider,
{
type Handler = Md5PasswordAuthStartupHandler<V, P>;

fn make(&self) -> Self::Handler {
Arc::new(Md5PasswordAuthStartupHandler {
Md5PasswordAuthStartupHandler {
auth_source: self.auth_source.clone(),
parameter_provider: self.parameter_provider.clone(),
cached_password: Mutex::new(vec![]),
})
cached_password: vec![],
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/api/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
pub trait StartupHandler: Send + Sync {
/// A generic frontend message callback during startup phase.
async fn on_startup<C>(
&self,
&mut self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
Expand Down
2 changes: 1 addition & 1 deletion src/api/auth/noop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct NoopStartupHandler;
#[async_trait]
impl StartupHandler for NoopStartupHandler {
async fn on_startup<C>(
&self,
&mut self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
Expand Down
28 changes: 14 additions & 14 deletions src/api/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use futures::{Sink, SinkExt};
use ring::digest;
use ring::hmac;
use ring::pbkdf2;
use tokio::sync::Mutex;
use x509_certificate::certificate::CapturedX509Certificate;
use x509_certificate::SignatureAlgorithm;

Expand All @@ -36,7 +35,7 @@ pub struct SASLScramAuthStartupHandler<A, P> {
auth_db: Arc<A>,
parameter_provider: Arc<P>,
/// state of the client-server communication
state: Mutex<ScramState>,
state: ScramState,
/// base64 encoded certificate signature for tls-server-end-point channel binding
server_cert_sig: Option<Arc<String>>,
/// iterations
Expand Down Expand Up @@ -91,7 +90,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
for SASLScramAuthStartupHandler<A, P>
{
async fn on_startup<C>(
&self,
&mut self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
Expand All @@ -117,8 +116,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
}
PgWireFrontendMessage::PasswordMessageFamily(msg) => {
let salt_and_salted_pass = {
let state = self.state.lock().await;
match *state {
match self.state {
ScramState::Initial => {
let login_info = LoginInfo::from_client_info(client);
self.auth_db.get_password(&login_info).await?
Expand All @@ -129,9 +127,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler

let mut success = false;
let resp = {
// this should never block
let mut state = self.state.lock().await;
match *state {
match self.state {
ScramState::Initial => {
// initial response, client_first
let resp = msg.into_sasl_initial_response()?;
Expand Down Expand Up @@ -165,7 +161,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
);
let server_first_message = server_first.message();

*state = ScramState::ServerFirstSent(
self.state = ScramState::ServerFirstSent(
salt_and_salted_pass,
client_first.channel_binding(),
format!("{},{}", client_first.bare(), &server_first_message),
Expand Down Expand Up @@ -266,17 +262,21 @@ impl<A, P> MakeSASLScramAuthStartupHandler<A, P> {
}
}

impl<A, P> MakeHandler for MakeSASLScramAuthStartupHandler<A, P> {
type Handler = Arc<SASLScramAuthStartupHandler<A, P>>;
impl<A, P> MakeHandler for MakeSASLScramAuthStartupHandler<A, P>
where
A: AuthSource,
P: ServerParameterProvider,
{
type Handler = SASLScramAuthStartupHandler<A, P>;

fn make(&self) -> Self::Handler {
Arc::new(SASLScramAuthStartupHandler {
SASLScramAuthStartupHandler {
auth_db: self.auth_db.clone(),
parameter_provider: self.parameter_provider.clone(),
state: Mutex::new(ScramState::Initial),
state: ScramState::Initial,
server_cert_sig: self.server_cert_sig.clone(),
iterations: self.iterations,
})
}
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<T> ClientInfo for Framed<T, PgWireMessageServerCodec> {
async fn process_message<S, A, Q, EQ>(
message: PgWireFrontendMessage,
socket: &mut Framed<S, PgWireMessageServerCodec>,
authenticator: Arc<A>,
authenticator: &mut A,
query_handler: Arc<Q>,
extended_query_handler: Arc<EQ>,
) -> PgWireResult<()>
Expand Down Expand Up @@ -216,7 +216,7 @@ async fn peek_for_sslrequest(
pub async fn process_socket<A, Q, EQ>(
mut tcp_socket: TcpStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
startup_handler: Arc<A>,
mut startup_handler: A,
query_handler: Arc<Q>,
extended_query_handler: Arc<EQ>,
) -> Result<(), IOError>
Expand All @@ -239,7 +239,7 @@ where
if let Err(e) = process_message(
msg,
&mut socket,
startup_handler.clone(),
&mut startup_handler,
query_handler.clone(),
extended_query_handler.clone(),
)
Expand All @@ -255,7 +255,7 @@ where
if let Err(e) = process_message(
msg,
&mut socket,
startup_handler.clone(),
&mut startup_handler,
query_handler.clone(),
extended_query_handler.clone(),
)
Expand Down
4 changes: 2 additions & 2 deletions tests-integration/test-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,13 @@ pub async fn main() {
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let authenticator_ref = authenticator.make();
let authenticator = authenticator.make();
let processor_ref = processor.make();
tokio::spawn(async move {
process_socket(
incoming_socket.0,
None,
authenticator_ref,
authenticator,
processor_ref.clone(),
processor_ref,
)
Expand Down

0 comments on commit 46fb90a

Please sign in to comment.