Skip to content

Commit

Permalink
feat: add support for direct ssl (#189)
Browse files Browse the repository at this point in the history
* feat: add support for direct ssl

* feat: require ALPN check for direct ssl

* test: add tests for direct ssl
  • Loading branch information
sunng87 committed Jul 24, 2024
1 parent 18a6adb commit f232a77
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 69 deletions.
4 changes: 3 additions & 1 deletion examples/secure_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ fn setup_tls() -> Result<TlsAcceptor, IOError> {
.collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
.remove(0);

let config = ServerConfig::builder()
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;

config.alpn_protocols = vec![b"postgresql".to_vec()];

Ok(TlsAcceptor::from(Arc::new(config)))
}

Expand Down
180 changes: 117 additions & 63 deletions src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::io::Error as IOError;
use std::sync::Arc;

use bytes::BytesMut;
use futures::future::poll_fn;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::codec::{Decoder, Encoder, Framed};

Expand Down Expand Up @@ -249,41 +249,111 @@ where
Ok(())
}

async fn is_sslrequest_pending(tcp_socket: &TcpStream) -> Result<bool, IOError> {
#[derive(Debug, PartialEq, Eq)]
enum SslNegotiationType {
Postgres,
Direct,
None,
}

async fn check_ssl_negotiation(tcp_socket: &TcpStream) -> Result<SslNegotiationType, IOError> {
let mut buf = [0u8; SslRequest::BODY_SIZE];
let mut buf = ReadBuf::new(&mut buf);
while buf.filled().len() < SslRequest::BODY_SIZE {
if poll_fn(|cx| tcp_socket.poll_peek(cx, &mut buf)).await? == 0 {
// the tcp_stream has ended
return Ok(false);
loop {
let n = tcp_socket.peek(&mut buf).await?;
if n >= SslRequest::BODY_SIZE {
break;
}
}
if buf[0] == 0x16 {
return Ok(SslNegotiationType::Direct);
}

let mut buf = BytesMut::from(buf.filled());
let mut buf = BytesMut::from(buf.as_slice());
if let Ok(Some(_)) = SslRequest::decode(&mut buf) {
return Ok(true);
return Ok(SslNegotiationType::Postgres);
}
Ok(false)
Ok(SslNegotiationType::None)
}

async fn peek_for_sslrequest<ST>(
socket: &mut Framed<TcpStream, PgWireMessageServerCodec<ST>>,
ssl_supported: bool,
) -> Result<bool, IOError> {
let mut ssl = false;
if is_sslrequest_pending(socket.get_ref()).await? {
// consume request
socket.next().await;

let response = if ssl_supported {
ssl = true;
PgWireBackendMessage::SslResponse(SslResponse::Accept)
} else {
PgWireBackendMessage::SslResponse(SslResponse::Refuse)
) -> Result<SslNegotiationType, IOError> {
let mut negotiation_type = check_ssl_negotiation(socket.get_ref()).await?;
match negotiation_type {
SslNegotiationType::Postgres => {
// consume request
socket.next().await;

let response = if ssl_supported {
PgWireBackendMessage::SslResponse(SslResponse::Accept)
} else {
negotiation_type = SslNegotiationType::None;
PgWireBackendMessage::SslResponse(SslResponse::Refuse)
};
socket.send(response).await?;
}
SslNegotiationType::Direct => {}
SslNegotiationType::None => {}
}

Ok(negotiation_type)
}

async fn do_process_socket<S, A, Q, EQ, C>(
socket: &mut Framed<S, PgWireMessageServerCodec<EQ::Statement>>,
startup_handler: Arc<A>,
simple_query_handler: Arc<Q>,
extended_query_handler: Arc<EQ>,
copy_handler: Arc<C>,
) -> Result<(), IOError>
where
S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
A: StartupHandler,
Q: SimpleQueryHandler,
EQ: ExtendedQueryHandler,
C: CopyHandler,
{
while let Some(Ok(msg)) = socket.next().await {
let is_extended_query = match socket.state() {
PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
_ => msg.is_extended_query(),
};
socket.send(response).await?;
if let Err(e) = process_message(
msg,
socket,
startup_handler.clone(),
simple_query_handler.clone(),
extended_query_handler.clone(),
copy_handler.clone(),
)
.await
{
process_error(socket, e, is_extended_query).await?;
}
}

Ok(())
}

fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), IOError> {
let (_, the_conn) = tls_socket.get_ref();
let mut accept = false;

if let Some(alpn) = the_conn.alpn_protocol() {
if alpn == b"postgresql" {
accept = true;
}
}

if !accept {
Err(IOError::new(
std::io::ErrorKind::InvalidData,
"received direct SSL connection request without ALPN protocol negotiation extension",
))
} else {
Ok(())
}
Ok(ssl)
}

pub async fn process_socket<H>(
Expand All @@ -306,28 +376,18 @@ where
let extended_query_handler = handlers.extended_query_handler();
let copy_handler = handlers.copy_handler();

if !ssl {
if ssl == SslNegotiationType::None {
// use an already configured socket.
let mut socket = tcp_socket;

while let Some(Ok(msg)) = socket.next().await {
let is_extended_query = match socket.state() {
PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
_ => msg.is_extended_query(),
};
if let Err(e) = process_message(
msg,
&mut socket,
startup_handler.clone(),
simple_query_handler.clone(),
extended_query_handler.clone(),
copy_handler.clone(),
)
.await
{
process_error(&mut socket, e, is_extended_query).await?;
}
}
do_process_socket(
&mut socket,
startup_handler,
simple_query_handler,
extended_query_handler,
copy_handler,
)
.await
} else {
// mention the use of ssl
let client_info = DefaultClient::new(addr, true);
Expand All @@ -336,27 +396,21 @@ where
.unwrap()
.accept(tcp_socket.into_inner())
.await?;
let mut socket = Framed::new(ssl_socket, PgWireMessageServerCodec::new(client_info));

while let Some(Ok(msg)) = socket.next().await {
let is_extended_query = match socket.state() {
PgWireConnectionState::CopyInProgress(is_extended_query) => is_extended_query,
_ => msg.is_extended_query(),
};
if let Err(e) = process_message(
msg,
&mut socket,
startup_handler.clone(),
simple_query_handler.clone(),
extended_query_handler.clone(),
copy_handler.clone(),
)
.await
{
process_error(&mut socket, e, is_extended_query).await?;
}
// check alpn for direct ssl connection
if ssl == SslNegotiationType::Direct {
check_alpn_for_direct_ssl(&ssl_socket)?;
}
}

Ok(())
let mut socket = Framed::new(ssl_socket, PgWireMessageServerCodec::new(client_info));

do_process_socket(
&mut socket,
startup_handler,
simple_query_handler,
extended_query_handler,
copy_handler,
)
.await
}
}
4 changes: 3 additions & 1 deletion tests-integration/rust-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ publish = false

[dependencies]
tokio = { version = "1", features = ["full"] }
postgres = { version = "0.19" }
openssl = "0.10"
postgres = { git = "https://github.com/sunng87/rust-postgres.git", rev = "629991beed1e689bd8ec79e1fd83aed03d049eef" }
postgres-openssl = { git = "https://github.com/sunng87/rust-postgres.git", rev = "629991beed1e689bd8ec79e1fd83aed03d049eef" }
12 changes: 9 additions & 3 deletions tests-integration/rust-client/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use std::time::SystemTime;

use postgres::{Client, NoTls, SimpleQueryMessage};
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
use postgres::{Client, SimpleQueryMessage};
use postgres_openssl::MakeTlsConnector;

fn main() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE);
postgres_openssl::set_postgresql_alpn(&mut builder).unwrap();
let connector = MakeTlsConnector::new(builder.build());
let mut client = Client::connect(
"host=localhost port=5432 user=postgres password=pencil dbname=localdb keepalives=0",
NoTls,
"host=localhost port=5432 user=postgres password=pencil dbname=localdb keepalives=0 sslmode=require sslnegotiation=direct",
connector,
)
.unwrap();

Expand Down
3 changes: 3 additions & 0 deletions tests-integration/test-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ pgwire = { path = "../../", features = ["scram"] }
async-trait = "0.1"
futures = "0.3"
tokio = { version = "1", features = ["full"] }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12"]}
rustls-pemfile = "2.0"
rustls-pki-types = "1.0"
35 changes: 34 additions & 1 deletion tests-integration/test-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use std::fs::File;
use std::io::{BufReader, Error as IOError, ErrorKind};
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use futures::stream;
use futures::StreamExt;
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;

use pgwire::api::auth::scram::{gen_salted_password, SASLScramAuthStartupHandler};
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
Expand Down Expand Up @@ -229,17 +235,44 @@ impl PgWireHandlerFactory for DummyDatabaseFactory {
}
}

fn setup_tls() -> Result<TlsAcceptor, IOError> {
let cert = certs(&mut BufReader::new(File::open(
"../../examples/ssl/server.crt",
)?))
.collect::<Result<Vec<CertificateDer>, IOError>>()?;

let key = pkcs8_private_keys(&mut BufReader::new(File::open(
"../../examples/ssl/server.key",
)?))
.map(|key| key.map(PrivateKeyDer::from))
.collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
.remove(0);

let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;

config.alpn_protocols = vec![b"postgresql".to_vec()];

Ok(TlsAcceptor::from(Arc::new(config)))
}

#[tokio::main]
pub async fn main() {
let factory = Arc::new(DummyDatabaseFactory(Arc::new(DummyDatabase::default())));

let server_addr = "127.0.0.1:5432";
let tls_acceptor = Arc::new(setup_tls().unwrap());
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
let incoming_socket = listener.accept().await.unwrap();
let tls_acceptor_ref = tls_acceptor.clone();
let factory_ref = factory.clone();

tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
tokio::spawn(async move {
process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await
});
}
}

0 comments on commit f232a77

Please sign in to comment.