Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove getset #144

Merged
merged 7 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ rust-version = "1.67"

[dependencies]
log = "0.4"
getset = "0.1.2"
derive-new = "0.6"
bytes = "1.1.0"
time = "0.3"
Expand Down
16 changes: 8 additions & 8 deletions examples/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ fn encode_row_data(
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
let mut results = Vec::with_capacity(portal.parameter_len());
for i in 0..portal.parameter_len() {
let param_type = portal.statement().parameter_types().get(i).unwrap();
let param_type = portal.statement.parameter_types.get(i).unwrap();
// we only support a small amount of types for demo
match param_type {
&Type::BOOL => {
Expand Down Expand Up @@ -204,7 +204,7 @@ impl ExtendedQueryHandler for SqliteBackend {
C: ClientInfo + Unpin + Send + Sync,
{
let conn = self.conn.lock().unwrap();
let query = portal.statement().statement();
let query = &portal.statement.statement;
let mut stmt = conn
.prepare_cached(query)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Expand All @@ -215,7 +215,7 @@ impl ExtendedQueryHandler for SqliteBackend {
.collect::<Vec<&dyn rusqlite::ToSql>>();

if query.to_uppercase().starts_with("SELECT") {
let header = Arc::new(row_desc_from_stmt(&stmt, portal.result_column_format())?);
let header = Arc::new(row_desc_from_stmt(&stmt, &portal.result_column_format)?);
stmt.query::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref())
.map(|rows| {
let s = encode_row_data(rows, header.clone());
Expand All @@ -242,18 +242,18 @@ impl ExtendedQueryHandler for SqliteBackend {
let conn = self.conn.lock().unwrap();
match target {
StatementOrPortal::Statement(stmt) => {
let param_types = Some(stmt.parameter_types().clone());
let param_types = Some(stmt.parameter_types.clone());
let stmt = conn
.prepare_cached(stmt.statement())
.prepare_cached(&stmt.statement)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
row_desc_from_stmt(&stmt, &Format::UnifiedBinary)
.map(|fields| DescribeResponse::new(param_types, fields))
}
StatementOrPortal::Portal(portal) => {
let stmt = conn
.prepare_cached(portal.statement().statement())
.prepare_cached(&portal.statement.statement)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
row_desc_from_stmt(&stmt, portal.result_column_format())
row_desc_from_stmt(&stmt, &portal.result_column_format)
.map(|fields| DescribeResponse::new(None, fields))
}
}
Expand Down Expand Up @@ -290,7 +290,7 @@ impl MakeHandler for MakeSqliteBackend {
#[tokio::main]
pub async fn main() {
let mut parameters = DefaultServerParameterProvider::default();
parameters.set_server_version(rusqlite::version().to_owned());
parameters.server_version = rusqlite::version().to_owned();

let authenticator = Arc::new(MakeMd5PasswordAuthStartupHandler::new(
Arc::new(DummyAuthSource),
Expand Down
2 changes: 1 addition & 1 deletion src/api/auth/cleartext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<V: AuthSource, P: ServerParameterProvider> StartupHandler
let pwd = pwd.into_password()?;
let login_info = LoginInfo::from_client_info(client);
let pass = self.auth_source.get_password(&login_info).await?;
if pass.password() == pwd.password().as_bytes() {
if pass.password == pwd.password.as_bytes() {
super::finish_authentication(client, &self.parameter_provider).await
} else {
let error_info = ErrorInfo::new(
Expand Down
6 changes: 3 additions & 3 deletions src/api/auth/md5pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
let salt_and_pass = self.auth_source.get_password(&login_info).await?;

let salt = salt_and_pass
.salt()
.salt
.as_ref()
.expect("Salt is required for Md5Password authentication");

*self.cached_password.lock().await = salt_and_pass.password().clone();
*self.cached_password.lock().await = salt_and_pass.password.clone();

client
.send(PgWireBackendMessage::Authentication(
Expand All @@ -60,7 +60,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
let pwd = pwd.into_password()?;
let cached_pass = self.cached_password.lock().await;

if pwd.password().as_bytes() == *cached_pass {
if pwd.password.as_bytes() == *cached_pass {
super::finish_authentication(client, self.parameter_provider.as_ref()).await
} else {
let error_info = ErrorInfo::new(
Expand Down
52 changes: 36 additions & 16 deletions src/api/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ pub trait ServerParameterProvider: Send + Sync {
/// - `client_encoding: UTF8`
/// - `integer_datetimes: on`:
///
#[derive(Debug, Getters, Setters)]
#[getset(get = "pub", set = "pub")]
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultServerParameterProvider {
server_version: String,
server_encoding: String,
client_encoding: String,
date_style: String,
integer_datetimes: String,
pub server_version: String,
pub server_encoding: String,
pub client_encoding: String,
pub date_style: String,
pub integer_datetimes: String,
}

impl Default for DefaultServerParameterProvider {
Expand Down Expand Up @@ -85,29 +85,49 @@ impl ServerParameterProvider for DefaultServerParameterProvider {
}
}

#[derive(Debug, new, Getters, Clone)]
#[getset(get = "pub")]
#[derive(Debug, new, Clone)]
pub struct Password {
salt: Option<Vec<u8>>,
password: Vec<u8>,
}

#[derive(Debug, new, Getters)]
#[getset(get = "pub")]
impl Password {
pub fn salt(&self) -> Option<&[u8]> {
self.salt.as_deref()
}

pub fn password(&self) -> &[u8] {
&self.password
}
}

#[derive(Debug, new)]
pub struct LoginInfo<'a> {
user: Option<&'a String>,
database: Option<&'a String>,
user: Option<&'a str>,
database: Option<&'a str>,
host: String,
}

impl<'a> LoginInfo<'a> {
pub fn user(&self) -> Option<&str> {
self.user
}

pub fn database(&self) -> Option<&str> {
self.database
}

pub fn host(&self) -> &str {
&self.host
}

pub fn from_client_info<C>(client: &'a C) -> LoginInfo
where
C: ClientInfo,
{
LoginInfo {
user: client.metadata().get(METADATA_USER),
database: client.metadata().get(METADATA_DATABASE),
user: client.metadata().get(METADATA_USER).map(|s| s.as_str()),
database: client.metadata().get(METADATA_DATABASE).map(|s| s.as_str()),
host: client.socket_addr().ip().to_string(),
}
}
Expand Down Expand Up @@ -135,7 +155,7 @@ where
{
client.metadata_mut().extend(
startup_message
.parameters()
.parameters
.iter()
.map(|(k, v)| (k.to_owned(), v.to_owned())),
);
Expand Down
11 changes: 5 additions & 6 deletions src/api/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
let resp = msg.into_sasl_initial_response()?;
// parse into client_first
let client_first = resp
.data()
.data
.as_ref()
.ok_or_else(|| {
PgWireError::InvalidScramMessage(
Expand All @@ -157,7 +157,7 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
new_nonce,
STANDARD.encode(
salt_and_salted_pass
.salt()
.salt
.as_ref()
.expect("Salt required for SCRAM auth source"),
),
Expand All @@ -179,16 +179,15 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
) => {
// second response, client_final
let resp = msg.into_sasl_response()?;
let client_final = ClientFinal::try_new(
String::from_utf8_lossy(resp.data().as_ref()).as_ref(),
)?;
let client_final =
ClientFinal::try_new(String::from_utf8_lossy(&resp.data).as_ref())?;
// dbg!(&client_final);

let channel_binding =
self.compute_channel_binding(channel_binding_prefix);
client_final.validate_channel_binding(&channel_binding)?;

let salted_password = salt_and_salted_pass.password();
let salted_password = salt_and_salted_pass.password;
let client_key = hmac(salted_password.as_ref(), b"Client Key");
let stored_key = h(client_key.as_ref());
let auth_msg =
Expand Down
46 changes: 36 additions & 10 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub mod store;

pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME";

#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone, Copy, Default)]
pub enum PgWireConnectionState {
#[default]
AwaitingStartup,
Expand All @@ -26,11 +26,11 @@ pub enum PgWireConnectionState {

/// Describe a client information holder
pub trait ClientInfo {
fn socket_addr(&self) -> &SocketAddr;
fn socket_addr(&self) -> SocketAddr;

fn is_secure(&self) -> bool;

fn state(&self) -> &PgWireConnectionState;
fn state(&self) -> PgWireConnectionState;

fn set_state(&mut self, new_state: PgWireConnectionState);

Expand All @@ -49,14 +49,40 @@ pub trait ClientPortalStore {
pub const METADATA_USER: &str = "user";
pub const METADATA_DATABASE: &str = "database";

#[derive(Debug, Getters, Setters, MutGetters)]
#[getset(get = "pub", set = "pub", get_mut = "pub")]
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultClient<S> {
socket_addr: SocketAddr,
is_secure: bool,
state: PgWireConnectionState,
metadata: HashMap<String, String>,
portal_store: store::MemPortalStore<S>,
pub socket_addr: SocketAddr,
pub is_secure: bool,
pub state: PgWireConnectionState,
pub metadata: HashMap<String, String>,
pub portal_store: store::MemPortalStore<S>,
}

impl<S> ClientInfo for DefaultClient<S> {
fn socket_addr(&self) -> SocketAddr {
self.socket_addr
}

fn is_secure(&self) -> bool {
self.is_secure
}

fn state(&self) -> PgWireConnectionState {
self.state
}

fn set_state(&mut self, new_state: PgWireConnectionState) {
self.state = new_state;
}

fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}

fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
&mut self.metadata
}
}

impl<S> DefaultClient<S> {
Expand Down
25 changes: 12 additions & 13 deletions src/api/portal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ use super::{results::FieldFormat, stmt::StoredStatement, DEFAULT_NAME};

/// Represent a prepared sql statement and its parameters bound by a `Bind`
/// request.
#[derive(Debug, CopyGetters, Default, Getters, Setters, Clone)]
#[getset(get = "pub", set = "pub", get_mut = "pub")]
#[derive(Debug, Default, Clone)]
pub struct Portal<S> {
name: String,
statement: Arc<StoredStatement<S>>,
parameter_format: Format,
parameters: Vec<Option<Bytes>>,
result_column_format: Format,
pub name: String,
pub statement: Arc<StoredStatement<S>>,
pub parameter_format: Format,
pub parameters: Vec<Option<Bytes>>,
pub result_column_format: Format,
}

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -76,21 +75,21 @@ impl<S: Clone> Portal<S> {
/// Try to create portal from bind command and current client state
pub fn try_new(bind: &Bind, statement: Arc<StoredStatement<S>>) -> PgWireResult<Self> {
let portal_name = bind
.portal_name()
.portal_name
.clone()
.unwrap_or_else(|| DEFAULT_NAME.to_owned());

// param format
let param_format = Format::from_codes(bind.parameter_format_codes());
let param_format = Format::from_codes(&bind.parameter_format_codes);

// format
let result_format = Format::from_codes(bind.result_column_format_codes());
let result_format = Format::from_codes(&bind.result_column_format_codes);

Ok(Portal {
name: portal_name,
statement,
parameter_format: param_format,
parameters: bind.parameters().clone(),
parameters: bind.parameters.clone(),
result_column_format: result_format,
})
}
Expand All @@ -113,11 +112,11 @@ impl<S: Clone> Portal<S> {
}

let param = self
.parameters()
.parameters
.get(idx)
.ok_or_else(|| PgWireError::ParameterIndexOutOfBound(idx))?;

let _format = self.parameter_format().format_for(idx);
let _format = self.parameter_format.format_for(idx);

if let Some(ref param) = param {
// TODO: from_sql only works with binary format
Expand Down
Loading
Loading