From 1e8ad1038322d37f2eee90f879d4e2f61b05077a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 26 Dec 2023 18:04:49 +0000 Subject: [PATCH 1/6] Remove getset getset & its proc-macro-error dependencies seem unmaintained In particular, I'm not confident I can get upstream to update to syn 2 I also don't think getter/setter patterns make particular sense in Rust Use `_hidden: ()` in places to maintain capability to add fields without breaking semver --- Cargo.toml | 1 - src/api/auth/cleartext.rs | 2 +- src/api/auth/md5pass.rs | 6 +-- src/api/auth/mod.rs | 45 ++++++++++++++----- src/api/auth/scram.rs | 8 ++-- src/api/mod.rs | 48 +++++++++++++++----- src/api/portal.rs | 25 +++++------ src/api/query.rs | 24 +++++----- src/api/results.rs | 64 +++++++++++++++++++------- src/api/stmt.rs | 18 ++++---- src/api/store.rs | 4 +- src/error.rs | 37 +++++++-------- src/lib.rs | 2 - src/messages/copy.rs | 20 +++------ src/messages/data.rs | 47 ++++++++++---------- src/messages/extendedquery.rs | 84 +++++++++++++++++------------------ src/messages/mod.rs | 49 ++++++++++---------- src/messages/response.rs | 32 ++++++------- src/messages/simplequery.rs | 7 +-- src/messages/startup.rs | 58 ++++++++++++------------ src/tokio.rs | 27 +++++------ 21 files changed, 339 insertions(+), 269 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2373495..a39c24c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ rust-version = "1.67" [dependencies] log = "0.4" -getset = "0.1.2" derive-new = "0.6" bytes = "1.1.0" time = "0.3" diff --git a/src/api/auth/cleartext.rs b/src/api/auth/cleartext.rs index 593f944..1b5dc00 100644 --- a/src/api/auth/cleartext.rs +++ b/src/api/auth/cleartext.rs @@ -46,7 +46,7 @@ impl StartupHandler let pwd = pwd.into_password()?; let login_info = LoginInfo::from_client_info(client); let pass = self.auth_source.get_password(&login_info).await?; - if pass.password() == pwd.password().as_bytes() { + if pass.password == pwd.password.as_bytes() { super::finish_authentication(client, &self.parameter_provider).await } else { let error_info = ErrorInfo::new( diff --git a/src/api/auth/md5pass.rs b/src/api/auth/md5pass.rs index 57c51ec..a4742d1 100644 --- a/src/api/auth/md5pass.rs +++ b/src/api/auth/md5pass.rs @@ -44,11 +44,11 @@ impl StartupHandler let salt_and_pass = self.auth_source.get_password(&login_info).await?; let salt = salt_and_pass - .salt() + .salt .as_ref() .expect("Salt is required for Md5Password authentication"); - *self.cached_password.lock().await = salt_and_pass.password().clone(); + *self.cached_password.lock().await = salt_and_pass.password.clone(); client .send(PgWireBackendMessage::Authentication( @@ -60,7 +60,7 @@ impl StartupHandler let pwd = pwd.into_password()?; let cached_pass = self.cached_password.lock().await; - if pwd.password().as_bytes() == *cached_pass { + if pwd.password.as_bytes() == *cached_pass { super::finish_authentication(client, self.parameter_provider.as_ref()).await } else { let error_info = ErrorInfo::new( diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 631c6b2..1b3c93c 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -44,14 +44,14 @@ pub trait ServerParameterProvider: Send + Sync { /// - `client_encoding: UTF8` /// - `integer_datetimes: on`: /// -#[derive(Debug, Getters, Setters)] -#[getset(get = "pub", set = "pub")] +#[derive(Debug)] pub struct DefaultServerParameterProvider { - server_version: String, - server_encoding: String, - client_encoding: String, - date_style: String, - integer_datetimes: String, + pub server_version: String, + pub server_encoding: String, + pub client_encoding: String, + pub date_style: String, + pub integer_datetimes: String, + _hidden: (), } impl Default for DefaultServerParameterProvider { @@ -62,6 +62,7 @@ impl Default for DefaultServerParameterProvider { client_encoding: "UTF8".to_owned(), date_style: "ISO YMD".to_owned(), integer_datetimes: "on".to_owned(), + _hidden: (), } } } @@ -85,15 +86,23 @@ impl ServerParameterProvider for DefaultServerParameterProvider { } } -#[derive(Debug, new, Getters, Clone)] -#[getset(get = "pub")] +#[derive(Debug, new, Clone)] pub struct Password { salt: Option>, password: Vec, } -#[derive(Debug, new, Getters)] -#[getset(get = "pub")] +impl Password { + pub fn salt(&self) -> Option<&[u8]> { + return self.salt.as_deref(); + } + + pub fn password(&self) -> &[u8] { + return &self.password; + } +} + +#[derive(Debug, new)] pub struct LoginInfo<'a> { user: Option<&'a String>, database: Option<&'a String>, @@ -101,6 +110,18 @@ pub struct LoginInfo<'a> { } impl<'a> LoginInfo<'a> { + pub fn user(&self) -> Option<&str> { + return self.user.map(|u| u.as_str()); + } + + pub fn database(&self) -> Option<&str> { + return self.database.map(|db| db.as_str()); + } + + pub fn host(&self) -> &str { + return &self.host; + } + pub fn from_client_info(client: &'a C) -> LoginInfo where C: ClientInfo, @@ -135,7 +156,7 @@ where { client.metadata_mut().extend( startup_message - .parameters() + .parameters .iter() .map(|(k, v)| (k.to_owned(), v.to_owned())), ); diff --git a/src/api/auth/scram.rs b/src/api/auth/scram.rs index 9ee7d04..fde52b8 100644 --- a/src/api/auth/scram.rs +++ b/src/api/auth/scram.rs @@ -137,7 +137,7 @@ impl StartupHandler let resp = msg.into_sasl_initial_response()?; // parse into client_first let client_first = resp - .data() + .data .as_ref() .ok_or_else(|| { PgWireError::InvalidScramMessage( @@ -157,7 +157,7 @@ impl StartupHandler new_nonce, STANDARD.encode( salt_and_salted_pass - .salt() + .salt .as_ref() .expect("Salt required for SCRAM auth source"), ), @@ -180,7 +180,7 @@ impl StartupHandler // second response, client_final let resp = msg.into_sasl_response()?; let client_final = ClientFinal::try_new( - String::from_utf8_lossy(resp.data().as_ref()).as_ref(), + String::from_utf8_lossy(&resp.data).as_ref(), )?; // dbg!(&client_final); @@ -188,7 +188,7 @@ impl StartupHandler self.compute_channel_binding(channel_binding_prefix); client_final.validate_channel_binding(&channel_binding)?; - let salted_password = salt_and_salted_pass.password(); + let salted_password = salt_and_salted_pass.password; let client_key = hmac(salted_password.as_ref(), b"Client Key"); let stored_key = h(client_key.as_ref()); let auth_msg = diff --git a/src/api/mod.rs b/src/api/mod.rs index b51ff30..7ae39b6 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -15,7 +15,7 @@ pub mod store; pub const DEFAULT_NAME: &str = "POSTGRESQL_DEFAULT_NAME"; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Copy, Default)] pub enum PgWireConnectionState { #[default] AwaitingStartup, @@ -26,11 +26,11 @@ pub enum PgWireConnectionState { /// Describe a client information holder pub trait ClientInfo { - fn socket_addr(&self) -> &SocketAddr; + fn socket_addr(&self) -> SocketAddr; fn is_secure(&self) -> bool; - fn state(&self) -> &PgWireConnectionState; + fn state(&self) -> PgWireConnectionState; fn set_state(&mut self, new_state: PgWireConnectionState); @@ -49,14 +49,41 @@ pub trait ClientPortalStore { pub const METADATA_USER: &str = "user"; pub const METADATA_DATABASE: &str = "database"; -#[derive(Debug, Getters, Setters, MutGetters)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(Debug)] pub struct DefaultClient { - socket_addr: SocketAddr, - is_secure: bool, - state: PgWireConnectionState, - metadata: HashMap, - portal_store: store::MemPortalStore, + pub socket_addr: SocketAddr, + pub is_secure: bool, + pub state: PgWireConnectionState, + pub metadata: HashMap, + pub portal_store: store::MemPortalStore, + _hidden: (), +} + +impl ClientInfo for DefaultClient { + fn socket_addr(&self) -> SocketAddr { + return self.socket_addr; + } + + fn is_secure(&self) -> bool { + return self.is_secure; + } + + fn state(&self) -> PgWireConnectionState { + return self.state; + } + + fn set_state(&mut self, new_state: PgWireConnectionState) { + self.state = new_state; + } + + fn metadata(&self) -> &HashMap { + return &self.metadata; + } + + fn metadata_mut(&mut self) -> &mut HashMap { + return &mut self.metadata; + } + } impl DefaultClient { @@ -67,6 +94,7 @@ impl DefaultClient { state: PgWireConnectionState::default(), metadata: HashMap::new(), portal_store: store::MemPortalStore::new(), + _hidden: (), } } } diff --git a/src/api/portal.rs b/src/api/portal.rs index 1a963ac..aea9ca3 100644 --- a/src/api/portal.rs +++ b/src/api/portal.rs @@ -13,14 +13,13 @@ use super::{results::FieldFormat, stmt::StoredStatement, DEFAULT_NAME}; /// Represent a prepared sql statement and its parameters bound by a `Bind` /// request. -#[derive(Debug, CopyGetters, Default, Getters, Setters, Clone)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(Debug, Default, Clone)] pub struct Portal { - name: String, - statement: Arc>, - parameter_format: Format, - parameters: Vec>, - result_column_format: Format, + pub name: String, + pub statement: Arc>, + pub parameter_format: Format, + pub parameters: Vec>, + pub result_column_format: Format, } #[derive(Debug, Clone, Default)] @@ -76,21 +75,21 @@ impl Portal { /// Try to create portal from bind command and current client state pub fn try_new(bind: &Bind, statement: Arc>) -> PgWireResult { let portal_name = bind - .portal_name() + .portal_name .clone() .unwrap_or_else(|| DEFAULT_NAME.to_owned()); // param format - let param_format = Format::from_codes(bind.parameter_format_codes()); + let param_format = Format::from_codes(&bind.parameter_format_codes); // format - let result_format = Format::from_codes(bind.result_column_format_codes()); + let result_format = Format::from_codes(&bind.result_column_format_codes); Ok(Portal { name: portal_name, statement, parameter_format: param_format, - parameters: bind.parameters().clone(), + parameters: bind.parameters.clone(), result_column_format: result_format, }) } @@ -113,11 +112,11 @@ impl Portal { } let param = self - .parameters() + .parameters .get(idx) .ok_or_else(|| PgWireError::ParameterIndexOutOfBound(idx))?; - let _format = self.parameter_format().format_for(idx); + let _format = self.parameter_format.format_for(idx); if let Some(ref param) = param { // TODO: from_sql only works with binary format diff --git a/src/api/query.rs b/src/api/query.rs index f41ef74..2ac2f97 100644 --- a/src/api/query.rs +++ b/src/api/query.rs @@ -42,13 +42,13 @@ pub trait SimpleQueryHandler: Send + Sync { PgWireError: From<>::Error>, { client.set_state(super::PgWireConnectionState::QueryInProgress); - let query_string = query.query(); - if is_empty_query(query_string) { + let query_string = query.query; + if is_empty_query(&query_string) { client .feed(PgWireBackendMessage::EmptyQueryResponse(EmptyQueryResponse)) .await?; } else { - let resp = self.do_query(client, query.query()).await?; + let resp = self.do_query(client, &query_string).await?; for r in resp { match r { Response::EmptyQuery => { @@ -133,7 +133,7 @@ pub trait ExtendedQueryHandler: Send + Sync { C::Error: Debug, PgWireError: From<>::Error>, { - let statement_name = message.statement_name().as_deref().unwrap_or(DEFAULT_NAME); + let statement_name = message.statement_name.as_deref().unwrap_or(DEFAULT_NAME); if let Some(statement) = client.portal_store().get_statement(statement_name) { let portal = Portal::try_new(&message, statement)?; @@ -162,10 +162,10 @@ pub trait ExtendedQueryHandler: Send + Sync { C::Error: Debug, PgWireError: From<>::Error>, { - let portal_name = message.name().as_deref().unwrap_or(DEFAULT_NAME); + let portal_name = message.name.as_deref().unwrap_or(DEFAULT_NAME); if let Some(portal) = client.portal_store().get_portal(portal_name) { match self - .do_query(client, portal.as_ref(), *message.max_rows() as usize) + .do_query(client, portal.as_ref(), message.max_rows as usize) .await? { Response::EmptyQuery => { @@ -202,8 +202,8 @@ pub trait ExtendedQueryHandler: Send + Sync { C::Error: Debug, PgWireError: From<>::Error>, { - let name = message.name().as_deref().unwrap_or(DEFAULT_NAME); - match message.target_type() { + let name = message.name.as_deref().unwrap_or(DEFAULT_NAME); + match message.target_type { TARGET_TYPE_BYTE_STATEMENT => { if let Some(stmt) = client.portal_store().get_statement(name) { let describe_response = self @@ -224,7 +224,7 @@ pub trait ExtendedQueryHandler: Send + Sync { return Err(PgWireError::PortalNotFound(name.to_owned())); } } - _ => return Err(PgWireError::InvalidTargetType(message.target_type())), + _ => return Err(PgWireError::InvalidTargetType(message.target_type)), } Ok(()) @@ -259,8 +259,8 @@ pub trait ExtendedQueryHandler: Send + Sync { C::Error: Debug, PgWireError: From<>::Error>, { - let name = message.name().as_deref().unwrap_or(DEFAULT_NAME); - match message.target_type() { + let name = message.name.as_deref().unwrap_or(DEFAULT_NAME); + match message.target_type { TARGET_TYPE_BYTE_STATEMENT => { client.portal_store().rm_statement(name); } @@ -377,7 +377,7 @@ where client.send(PgWireBackendMessage::NoData(NoData)).await?; } else { if include_parameters { - if let Some(parameter_types) = describe_response.parameters() { + if let Some(parameter_types) = describe_response.parameters.as_deref() { // parameter type inference client .send(PgWireBackendMessage::ParameterDescription( diff --git a/src/api/results.rs b/src/api/results.rs index 08fdab1..2e667cd 100644 --- a/src/api/results.rs +++ b/src/api/results.rs @@ -78,8 +78,7 @@ impl FieldFormat { } } -#[derive(Debug, new, Eq, PartialEq, Clone, Getters)] -#[getset(get = "pub")] +#[derive(Debug, new, Eq, PartialEq, Clone)] pub struct FieldInfo { name: String, table_id: Option, @@ -88,6 +87,28 @@ pub struct FieldInfo { format: FieldFormat, } +impl FieldInfo { + pub fn name(&self) -> &str { + return &self.name + } + + pub fn table_id(&self) -> Option { + return self.table_id; + } + + pub fn column_id(&self) -> Option { + return self.column_id; + } + + pub fn datatype(&self) -> &Type { + return &self.datatype; + } + + pub fn format(&self) -> FieldFormat { + return self.format; + } +} + impl From<&FieldInfo> for FieldDescription { fn from(fi: &FieldInfo) -> Self { FieldDescription::new( @@ -175,9 +196,9 @@ impl DataRowEncoder { if let IsNull::No = is_null { let buf = self.field_buffer.split().freeze(); - self.buffer.fields_mut().push(Some(buf)); + self.buffer.fields.push(Some(buf)); } else { - self.buffer.fields_mut().push(None); + self.buffer.fields.push(None); } self.col_index += 1; @@ -195,7 +216,7 @@ impl DataRowEncoder { let data_type = self.schema[self.col_index].datatype(); let format = self.schema[self.col_index].format(); - let is_null = if *format == FieldFormat::Text { + let is_null = if format == FieldFormat::Text { value.to_sql_text(data_type, &mut self.field_buffer)? } else { value.to_sql(data_type, &mut self.field_buffer)? @@ -203,9 +224,9 @@ impl DataRowEncoder { if let IsNull::No = is_null { let buf = self.field_buffer.split().freeze(); - self.buffer.fields_mut().push(Some(buf)); + self.buffer.fields.push(Some(buf)); } else { - self.buffer.fields_mut().push(None); + self.buffer.fields.push(None); } self.col_index += 1; @@ -225,20 +246,31 @@ impl DataRowEncoder { /// statement, frontend expects parameter types inferenced by server. And both /// describe messages will require column definitions for resultset being /// returned. -#[derive(Debug, Getters, new)] -#[getset(get = "pub")] +#[derive(Debug, new)] pub struct DescribeResponse { - parameters: Option>, - fields: Vec, + pub parameters: Option>, + pub fields: Vec, + #[new(default)] + _hidden: (), } impl DescribeResponse { + pub fn parameters(&self) -> Option<&[Type]> { + return self.parameters.as_deref() + + } + + pub fn fields(&self) -> &[FieldInfo] { + return &self.fields; + } + /// Create an no_data instance of `DescribeResponse`. This is typically used /// when client tries to describe an empty query. pub fn no_data() -> Self { DescribeResponse { parameters: None, fields: vec![], + _hidden: (), } } @@ -271,7 +303,7 @@ mod test { let tag = Tag::new_for_execution("INSERT", Some(100)); let cc = CommandComplete::from(tag); - assert_eq!(cc.tag(), "INSERT 100"); + assert_eq!(cc.tag, "INSERT 100"); } #[test] @@ -288,9 +320,9 @@ mod test { let row = encoder.finish().unwrap(); - assert_eq!(row.fields().len(), 3); - assert_eq!(row.fields()[0].as_ref().unwrap().len(), 4); - assert_eq!(row.fields()[1].as_ref().unwrap().len(), 4); - assert_eq!(row.fields()[2].as_ref().unwrap().len(), 26); + assert_eq!(row.fields.len(), 3); + assert_eq!(row.fields[0].as_ref().unwrap().len(), 4); + assert_eq!(row.fields[1].as_ref().unwrap().len(), 4); + assert_eq!(row.fields[2].as_ref().unwrap().len(), 26); } } diff --git a/src/api/stmt.rs b/src/api/stmt.rs index 1482f7d..379b427 100644 --- a/src/api/stmt.rs +++ b/src/api/stmt.rs @@ -8,16 +8,17 @@ use crate::messages::extendedquery::Parse; use super::DEFAULT_NAME; -#[derive(Debug, Default, new, Getters, Setters)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(Debug, Default, new)] pub struct StoredStatement { /// name of the statement - id: String, + pub id: String, /// parsed query statement - statement: S, + pub statement: S, /// type ids of query parameters, can be empty if frontend asks backend for /// type inference - parameter_types: Vec, + pub parameter_types: Vec, + #[new(default)] + _hidden: (), } impl StoredStatement { @@ -26,18 +27,19 @@ impl StoredStatement { Q: QueryParser, { let types = parse - .type_oids() + .type_oids .iter() .map(|oid| Type::from_oid(*oid).ok_or_else(|| PgWireError::UnknownTypeId(*oid))) .collect::>>()?; - let statement = parser.parse_sql(parse.query(), &types).await?; + let statement = parser.parse_sql(&parse.query, &types).await?; Ok(StoredStatement { id: parse - .name() + .name .clone() .unwrap_or_else(|| DEFAULT_NAME.to_owned()), statement, parameter_types: types, + _hidden: (), }) } } diff --git a/src/api/store.rs b/src/api/store.rs index e396186..fd7dd4f 100644 --- a/src/api/store.rs +++ b/src/api/store.rs @@ -33,7 +33,7 @@ impl PortalStore for MemPortalStore { fn put_statement(&self, statement: Arc>) { let mut guard = self.statements.write().unwrap(); - guard.insert(statement.id().to_owned(), statement); + guard.insert(statement.id.to_owned(), statement); } fn rm_statement(&self, name: &str) { @@ -48,7 +48,7 @@ impl PortalStore for MemPortalStore { fn put_portal(&self, portal: Arc>) { let mut guard = self.portals.write().unwrap(); - guard.insert(portal.name().to_owned(), portal); + guard.insert(portal.name.to_owned(), portal); } fn rm_portal(&self, name: &str) { diff --git a/src/error.rs b/src/error.rs index f840aa9..46316e7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -54,50 +54,51 @@ pub type PgWireResult = Result; // Postgres error and notice message fields // This part of protocol is defined in // https://www.postgresql.org/docs/8.2/protocol-error-fields.html -#[derive(new, Setters, Getters, Debug)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(new, Debug)] pub struct ErrorInfo { // severity can be one of `ERROR`, `FATAL`, or `PANIC` (in an error // message), or `WARNING`, `NOTICE`, `DEBUG`, `INFO`, or `LOG` (in a notice // message), or a localized translation of one of these. - severity: String, + pub severity: String, // error code defined in // https://www.postgresql.org/docs/current/errcodes-appendix.html - code: String, + pub code: String, // readable message - message: String, + pub message: String, // optional secondary message #[new(default)] - detail: Option, + pub detail: Option, // optional suggestion for fixing the issue #[new(default)] - hint: Option, + pub hint: Option, // Position: the field value is a decimal ASCII integer, indicating an error // cursor position as an index into the original query string. #[new(default)] - position: Option, + pub position: Option, // Internal position: this is defined the same as the P field, but it is // used when the cursor position refers to an internally generated command // rather than the one submitted by the client #[new(default)] - internal_position: Option, + pub internal_position: Option, // Internal query: the text of a failed internally-generated command. #[new(default)] - internal_query: Option, + pub internal_query: Option, // Where: an indication of the context in which the error occurred. #[new(default)] - where_context: Option, + pub where_context: Option, // File: the file name of the source-code location where the error was // reported. #[new(default)] - file_name: Option, + pub file_name: Option, // Line: the line number of the source-code location where the error was // reported. #[new(default)] - line: Option, + pub line: Option, // Routine: the name of the source-code routine reporting the error. #[new(default)] - routine: Option, + pub routine: Option, + #[new(default)] + _hidden: (), } impl ErrorInfo { @@ -162,9 +163,9 @@ mod test { "28P01".to_owned(), "Password authentication failed".to_owned(), ); - assert_eq!("FATAL", error_info.severity()); - assert_eq!("28P01", error_info.code()); - assert_eq!("Password authentication failed", error_info.message()); - assert!(error_info.file_name().is_none()); + assert_eq!("FATAL", error_info.severity); + assert_eq!("28P01", error_info.code); + assert_eq!("Password authentication failed", error_info.message); + assert!(error_info.file_name.is_none()); } } diff --git a/src/lib.rs b/src/lib.rs index 3048c5f..77697cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,8 +53,6 @@ //! usages. //! -#[macro_use] -extern crate getset; #[macro_use] extern crate derive_new; diff --git a/src/messages/copy.rs b/src/messages/copy.rs index 77d4710..2bd2c26 100644 --- a/src/messages/copy.rs +++ b/src/messages/copy.rs @@ -6,10 +6,9 @@ use crate::error::PgWireResult; pub const MESSAGE_TYPE_BYTE_COPY_DATA: u8 = b'd'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyData { - data: Bytes, + pub data: Bytes, } impl Message for CopyData { @@ -35,8 +34,7 @@ impl Message for CopyData { pub const MESSAGE_TYPE_BYTE_COPY_DONE: u8 = b'c'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyDone; impl Message for CopyDone { @@ -60,8 +58,7 @@ impl Message for CopyDone { pub const MESSAGE_TYPE_BYTE_COPY_FAIL: u8 = b'f'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyFail { message: String, } @@ -89,8 +86,7 @@ impl Message for CopyFail { pub const MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE: u8 = b'G'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyInResponse { format: i8, columns: i16, @@ -130,8 +126,7 @@ impl Message for CopyInResponse { pub const MESSAGE_TYPE_BYTE_COPY_OUT_RESPONSE: u8 = b'H'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyOutResponse { format: i8, columns: i16, @@ -171,8 +166,7 @@ impl Message for CopyOutResponse { pub const MESSAGE_TYPE_BYTE_COPY_BOTH_RESPONSE: u8 = b'W'; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyBothResponse { format: i8, columns: i16, diff --git a/src/messages/data.rs b/src/messages/data.rs index 4814435..0960e07 100644 --- a/src/messages/data.rs +++ b/src/messages/data.rs @@ -8,29 +8,29 @@ use crate::error::PgWireResult; pub(crate) const FORMAT_CODE_TEXT: i16 = 0; pub(crate) const FORMAT_CODE_BINARY: i16 = 1; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct FieldDescription { // the field name - name: String, + pub name: String, // the object ID of table, default to 0 if not a table - table_id: i32, + pub table_id: i32, // the attribute number of the column, default to 0 if not a column from table - column_id: i16, + pub column_id: i16, // the object ID of the data type - type_id: Oid, + pub type_id: Oid, // the size of data type, negative values denote variable-width types - type_size: i16, + pub type_size: i16, // the type modifier - type_modifier: i32, + pub type_modifier: i32, // the format code being used for the filed, will be 0 or 1 for now - format_code: i16, + pub format_code: i16, } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct RowDescription { - fields: Vec, + pub fields: Vec, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_ROW_DESCRITION: u8 = b'T'; @@ -83,16 +83,17 @@ impl Message for RowDescription { fields.push(field); } - Ok(RowDescription { fields }) + Ok(RowDescription { fields, _hidden: () }) } } /// Data structure returned when frontend describes a statement -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new, Clone)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new, Clone)] pub struct ParameterDescription { /// parameter types - types: Vec, + pub types: Vec, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION: u8 = b't'; @@ -124,7 +125,7 @@ impl Message for ParameterDescription { types.push(buf.get_i32() as Oid); } - Ok(ParameterDescription { types }) + Ok(ParameterDescription { types, _hidden: () }) } } @@ -132,10 +133,11 @@ impl Message for ParameterDescription { /// /// Data can be represented as text or binary format as specified by format /// codes from previous `RowDescription` message. -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new, Clone)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new, Clone)] pub struct DataRow { - fields: Vec>, + pub fields: Vec>, + #[new(default)] + _hidden: (), } impl DataRow {} @@ -184,14 +186,13 @@ impl Message for DataRow { } } - Ok(DataRow { fields }) + Ok(DataRow { fields, _hidden: () }) } } /// postgres response when query returns no data, sent from backend to frontend /// in extended query -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct NoData; pub const MESSAGE_TYPE_BYTE_NO_DATA: u8 = b'n'; diff --git a/src/messages/extendedquery.rs b/src/messages/extendedquery.rs index be33e28..898945e 100644 --- a/src/messages/extendedquery.rs +++ b/src/messages/extendedquery.rs @@ -5,12 +5,13 @@ use super::{codec, Message}; use crate::error::PgWireResult; /// Request from frontend to parse a prepared query string -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Parse { - name: Option, - query: String, - type_oids: Vec, + pub name: Option, + pub query: String, + pub type_oids: Vec, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_PARSE: u8 = b'P'; @@ -53,13 +54,13 @@ impl Message for Parse { name, query, type_oids, + _hidden: (), }) } } /// Response for Parse command, sent from backend to frontend -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct ParseComplete; pub const MESSAGE_TYPE_BYTE_PARSE_COMPLETE: u8 = b'1'; @@ -87,13 +88,12 @@ impl Message for ParseComplete { } /// Closing the prepared statement or portal -#[derive(Getters, CopyGetters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Close { - #[getset(skip)] - #[getset(get_copy = "pub", set = "pub")] - target_type: u8, - name: Option, + pub target_type: u8, + pub name: Option, + #[new(default)] + _hidden: (), } pub const TARGET_TYPE_BYTE_STATEMENT: u8 = b'S'; @@ -121,13 +121,12 @@ impl Message for Close { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Close { target_type, name }) + Ok(Close { target_type, name, _hidden: () }) } } /// Response for Close command, sent from backend to frontend -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct CloseComplete; pub const MESSAGE_TYPE_BYTE_CLOSE_COMPLETE: u8 = b'3'; @@ -155,17 +154,18 @@ impl Message for CloseComplete { } /// Bind command, for executing prepared statement -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Bind { - portal_name: Option, - statement_name: Option, - parameter_format_codes: Vec, + pub portal_name: Option, + pub statement_name: Option, + pub parameter_format_codes: Vec, // None for Null data, TODO: consider wrapping this together with DataRow in // data.rs - parameters: Vec>, + pub parameters: Vec>, - result_column_format_codes: Vec, + pub result_column_format_codes: Vec, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_BIND: u8 = b'B'; @@ -251,13 +251,14 @@ impl Message for Bind { parameters, result_column_format_codes, + + _hidden: (), }) } } /// Success response for `Bind` -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct BindComplete; pub const MESSAGE_TYPE_BYTE_BIND_COMPLETE: u8 = b'2'; @@ -286,13 +287,12 @@ impl Message for BindComplete { /// Describe command fron frontend to backend. For getting information of /// particular portal or statement -#[derive(Getters, Setters, CopyGetters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Describe { - #[getset(skip)] - #[getset(get_copy = "pub", set = "pub")] - target_type: u8, - name: Option, + pub target_type: u8, + pub name: Option, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_DESCRIBE: u8 = b'D'; @@ -317,16 +317,17 @@ impl Message for Describe { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Describe { target_type, name }) + Ok(Describe { target_type, name, _hidden: () }) } } /// Execute portal by its name -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Execute { - name: Option, - max_rows: i32, + pub name: Option, + pub max_rows: i32, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_EXECUTE: u8 = b'E'; @@ -351,12 +352,11 @@ impl Message for Execute { let name = codec::get_cstring(buf); let max_rows = buf.get_i32(); - Ok(Execute { name, max_rows }) + Ok(Execute { name, max_rows, _hidden: () }) } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Flush; pub const MESSAGE_TYPE_BYTE_FLUSH: u8 = b'H'; @@ -382,8 +382,7 @@ impl Message for Flush { } /// Execute portal by its name -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Sync; pub const MESSAGE_TYPE_BYTE_SYNC: u8 = b'S'; @@ -408,8 +407,7 @@ impl Message for Sync { } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct PortalSuspended; pub const MESSAGE_TYPE_BYTE_PORTAL_SUSPENDED: u8 = b's'; diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 69e0567..f1f94ff 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -369,8 +369,7 @@ mod test { #[test] fn test_startup() { let mut s = Startup::default(); - s.parameters_mut() - .insert("user".to_owned(), "tomcat".to_owned()); + s.parameters.insert("user".to_owned(), "tomcat".to_owned()); roundtrip!(s, Startup); } @@ -423,8 +422,8 @@ mod test { #[test] fn test_error_response() { let mut error = ErrorResponse::default(); - error.fields_mut().push((b'R', "ERROR".to_owned())); - error.fields_mut().push((b'K', "cli".to_owned())); + error.fields.push((b'R', "ERROR".to_owned())); + error.fields.push((b'K', "cli".to_owned())); roundtrip!(error, ErrorResponse); } @@ -432,8 +431,8 @@ mod test { #[test] fn test_notice_response() { let mut error = NoticeResponse::default(); - error.fields_mut().push((b'R', "NOTICE".to_owned())); - error.fields_mut().push((b'K', "cli".to_owned())); + error.fields.push((b'R', "NOTICE".to_owned())); + error.fields.push((b'K', "cli".to_owned())); roundtrip!(error, NoticeResponse); } @@ -443,24 +442,24 @@ mod test { let mut row_description = RowDescription::default(); let mut f1 = FieldDescription::default(); - f1.set_name("id".into()); - f1.set_table_id(1001); - f1.set_column_id(10001); - f1.set_type_id(1083); - f1.set_type_size(4); - f1.set_type_modifier(-1); - f1.set_format_code(FORMAT_CODE_TEXT); - row_description.fields_mut().push(f1); + f1.name = "id".into(); + f1.table_id = 1001; + f1.column_id = 10001; + f1.type_id = 1083; + f1.type_size = 4; + f1.type_modifier = -1; + f1.format_code = FORMAT_CODE_TEXT; + row_description.fields.push(f1); let mut f2 = FieldDescription::default(); - f2.set_name("name".into()); - f2.set_table_id(1001); - f2.set_column_id(10001); - f2.set_type_id(1099); - f2.set_type_size(-1); - f2.set_type_modifier(-1); - f2.set_format_code(FORMAT_CODE_TEXT); - row_description.fields_mut().push(f2); + f2.name = "name".into(); + f2.table_id = 1001; + f2.column_id = 10001; + f2.type_id = 1099; + f2.type_size = -1; + f2.type_modifier = -1; + f2.format_code = FORMAT_CODE_TEXT; + row_description.fields.push(f2); roundtrip!(row_description, RowDescription); } @@ -468,9 +467,9 @@ mod test { #[test] fn test_data_row() { let mut row0 = DataRow::default(); - row0.fields_mut().push(Some(Bytes::from_static(b"1"))); - row0.fields_mut().push(Some(Bytes::from_static(b"abc"))); - row0.fields_mut().push(None); + row0.fields.push(Some(Bytes::from_static(b"1"))); + row0.fields.push(Some(Bytes::from_static(b"abc"))); + row0.fields.push(None); roundtrip!(row0, DataRow); } diff --git a/src/messages/response.rs b/src/messages/response.rs index 1dcec53..1f50ec5 100644 --- a/src/messages/response.rs +++ b/src/messages/response.rs @@ -4,10 +4,9 @@ use super::codec; use super::Message; use crate::error::PgWireResult; -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct CommandComplete { - tag: String, + pub tag: String, } pub const MESSAGE_TYPE_BYTE_COMMAND_COMPLETE: u8 = b'C'; @@ -35,8 +34,7 @@ impl Message for CommandComplete { } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct EmptyQueryResponse; pub const MESSAGE_TYPE_BYTE_EMPTY_QUERY_RESPONSE: u8 = b'I'; @@ -59,10 +57,9 @@ impl Message for EmptyQueryResponse { } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct ReadyForQuery { - status: u8, + pub status: u8, } pub const READY_STATUS_IDLE: u8 = b'I'; @@ -95,10 +92,9 @@ impl Message for ReadyForQuery { } /// postgres error response, sent from backend to frontend -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct ErrorResponse { - fields: Vec<(u8, String)>, + pub fields: Vec<(u8, String)>, } pub const MESSAGE_TYPE_BYTE_ERROR_RESPONSE: u8 = b'E'; @@ -145,10 +141,9 @@ impl Message for ErrorResponse { } /// postgres error response, sent from backend to frontend -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct NoticeResponse { - fields: Vec<(u8, String)>, + pub fields: Vec<(u8, String)>, } pub const MESSAGE_TYPE_BYTE_NOTICE_RESPONSE: u8 = b'N'; @@ -252,12 +247,11 @@ impl Message for SslResponse { } /// NotificationResponse -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, Default, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, Default, new)] pub struct NotificationResponse { - pid: i32, - channel: String, - payload: String, + pub pid: i32, + pub channel: String, + pub payload: String, } pub const MESSAGE_TYPE_BYTE_NOTIFICATION_RESPONSE: u8 = b'A'; diff --git a/src/messages/simplequery.rs b/src/messages/simplequery.rs index 5896541..62b2a42 100644 --- a/src/messages/simplequery.rs +++ b/src/messages/simplequery.rs @@ -5,10 +5,11 @@ use super::Message; use crate::error::PgWireResult; /// A sql query sent from frontend to backend. -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Query { - query: String, + pub query: String, + #[new(default)] + _hidden: (), } pub const MESSAGE_TYPE_BYTE_QUERY: u8 = b'Q'; diff --git a/src/messages/startup.rs b/src/messages/startup.rs index 620d78f..3286167 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -11,15 +11,16 @@ use crate::error::{PgWireError, PgWireResult}; /// terminated by a zero byte. /// the key-value parameter pairs are terminated by a zero byte, too. /// -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Startup { #[new(value = "3")] - protocol_number_major: u16, + pub protocol_number_major: u16, #[new(value = "0")] - protocol_number_minor: u16, + pub protocol_number_minor: u16, #[new(default)] - parameters: BTreeMap, + pub parameters: BTreeMap, + #[new(default)] + _hidden: (), } impl Default for Startup { @@ -88,13 +89,13 @@ impl Message for Startup { let mut msg = Startup::default(); // parse - msg.set_protocol_number_major(buf.get_u16()); - msg.set_protocol_number_minor(buf.get_u16()); + msg.protocol_number_major = buf.get_u16(); + msg.protocol_number_minor = buf.get_u16(); // end by reading the last \0 while let Some(key) = codec::get_cstring(buf) { let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); - msg.parameters_mut().insert(key, value); + msg.parameters.insert(key, value); } Ok(msg) @@ -319,10 +320,9 @@ impl PasswordMessageFamily { } /// password packet sent from frontend -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct Password { - password: String, + pub password: String, } impl Message for Password { @@ -349,8 +349,7 @@ impl Message for Password { } /// parameter ack sent from backend after authentication success -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct ParameterStatus { name: String, value: String, @@ -385,8 +384,7 @@ impl Message for ParameterStatus { /// `BackendKeyData` message, sent from backend to frontend for issuing /// `CancelRequestMessage` -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct BackendKeyData { pid: i32, secret_key: i32, @@ -427,9 +425,11 @@ impl Message for BackendKeyData { /// The backend sents a single byte 'S' or 'N' to indicate its support. Upon 'S' /// the frontend should close the connection and reinitialize a new TLS /// connection. -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] -pub struct SslRequest {} +#[derive(PartialEq, Eq, Debug, new)] +pub struct SslRequest { + #[new(default)] + _hidden: () +} impl SslRequest { pub const BODY_MAGIC_NUMBER: i32 = 80877103; @@ -460,18 +460,19 @@ impl Message for SslRequest { fn decode(buf: &mut BytesMut) -> PgWireResult> { if buf.remaining() >= 8 && (&buf[4..8]).get_i32() == Self::BODY_MAGIC_NUMBER { buf.advance(8); - Ok(Some(SslRequest {})) + Ok(Some(SslRequest { _hidden: () })) } else { Ok(None) } } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct SASLInitialResponse { - auth_method: String, - data: Option, + pub auth_method: String, + pub data: Option, + #[new(default)] + _hidden: (), } impl Message for SASLInitialResponse { @@ -508,14 +509,15 @@ impl Message for SASLInitialResponse { Some(buf.split_to(data_len as usize).freeze()) }; - Ok(SASLInitialResponse { auth_method, data }) + Ok(SASLInitialResponse { auth_method, data, _hidden: () }) } } -#[derive(Getters, Setters, MutGetters, PartialEq, Eq, Debug, new)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(PartialEq, Eq, Debug, new)] pub struct SASLResponse { - data: Bytes, + pub data: Bytes, + #[new(default)] + _hidden: (), } impl Message for SASLResponse { @@ -536,6 +538,6 @@ impl Message for SASLResponse { fn decode_body(buf: &mut BytesMut, full_len: usize) -> PgWireResult { let data = buf.split_to(full_len - 4).freeze(); - Ok(SASLResponse { data }) + Ok(SASLResponse { data, _hidden: () }) } } diff --git a/src/tokio.rs b/src/tokio.rs index fca04ac..45a5476 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -19,10 +19,11 @@ use crate::messages::response::{SslResponse, READY_STATUS_IDLE}; use crate::messages::startup::{SslRequest, Startup}; use crate::messages::{Message, PgWireBackendMessage, PgWireFrontendMessage}; -#[derive(Debug, new, Getters, Setters, MutGetters)] -#[getset(get = "pub", set = "pub", get_mut = "pub")] +#[derive(Debug, new)] pub struct PgWireMessageServerCodec { - client_info: DefaultClient, + pub client_info: DefaultClient, + #[new(default)] + _hidden: (), } impl Decoder for PgWireMessageServerCodec { @@ -60,28 +61,28 @@ impl Encoder for PgWireMessageServerCodec { } impl ClientInfo for Framed> { - fn socket_addr(&self) -> &std::net::SocketAddr { - self.codec().client_info().socket_addr() + fn socket_addr(&self) -> std::net::SocketAddr { + self.codec().client_info.socket_addr } fn is_secure(&self) -> bool { - *self.codec().client_info().is_secure() + self.codec().client_info.is_secure } - fn state(&self) -> &PgWireConnectionState { - self.codec().client_info().state() + fn state(&self) -> PgWireConnectionState { + self.codec().client_info.state } fn set_state(&mut self, new_state: PgWireConnectionState) { - self.codec_mut().client_info_mut().set_state(new_state); + self.codec_mut().client_info.set_state(new_state); } fn metadata(&self) -> &std::collections::HashMap { - self.codec().client_info().metadata() + self.codec().client_info.metadata() } fn metadata_mut(&mut self) -> &mut std::collections::HashMap { - self.codec_mut().client_info_mut().metadata_mut() + self.codec_mut().client_info.metadata_mut() } } @@ -89,7 +90,7 @@ impl ClientPortalStore for Framed> { type PortalStore = as ClientPortalStore>::PortalStore; fn portal_store(&self) -> &Self::PortalStore { - self.codec().client_info().portal_store() + self.codec().client_info.portal_store() } } @@ -106,7 +107,7 @@ where Q: SimpleQueryHandler + 'static, EQ: ExtendedQueryHandler + 'static, { - match socket.codec().client_info().state() { + match socket.codec().client_info.state() { PgWireConnectionState::AwaitingStartup | PgWireConnectionState::AuthenticationInProgress => { authenticator.on_startup(socket, message).await?; From 4a00ddc8e4ae23b20dee8876a687ad2a13582987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 26 Dec 2023 18:10:54 +0000 Subject: [PATCH 2/6] cargo fmt --- src/api/auth/scram.rs | 5 ++--- src/api/mod.rs | 1 - src/api/results.rs | 5 ++--- src/messages/data.rs | 10 ++++++++-- src/messages/extendedquery.rs | 18 +++++++++++++++--- src/messages/startup.rs | 8 ++++++-- 6 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/api/auth/scram.rs b/src/api/auth/scram.rs index fde52b8..7f1538c 100644 --- a/src/api/auth/scram.rs +++ b/src/api/auth/scram.rs @@ -179,9 +179,8 @@ impl StartupHandler ) => { // second response, client_final let resp = msg.into_sasl_response()?; - let client_final = ClientFinal::try_new( - String::from_utf8_lossy(&resp.data).as_ref(), - )?; + let client_final = + ClientFinal::try_new(String::from_utf8_lossy(&resp.data).as_ref())?; // dbg!(&client_final); let channel_binding = diff --git a/src/api/mod.rs b/src/api/mod.rs index 7ae39b6..f106b18 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -83,7 +83,6 @@ impl ClientInfo for DefaultClient { fn metadata_mut(&mut self) -> &mut HashMap { return &mut self.metadata; } - } impl DefaultClient { diff --git a/src/api/results.rs b/src/api/results.rs index 2e667cd..6b70a06 100644 --- a/src/api/results.rs +++ b/src/api/results.rs @@ -89,7 +89,7 @@ pub struct FieldInfo { impl FieldInfo { pub fn name(&self) -> &str { - return &self.name + return &self.name; } pub fn table_id(&self) -> Option { @@ -256,8 +256,7 @@ pub struct DescribeResponse { impl DescribeResponse { pub fn parameters(&self) -> Option<&[Type]> { - return self.parameters.as_deref() - + return self.parameters.as_deref(); } pub fn fields(&self) -> &[FieldInfo] { diff --git a/src/messages/data.rs b/src/messages/data.rs index 0960e07..74569f9 100644 --- a/src/messages/data.rs +++ b/src/messages/data.rs @@ -83,7 +83,10 @@ impl Message for RowDescription { fields.push(field); } - Ok(RowDescription { fields, _hidden: () }) + Ok(RowDescription { + fields, + _hidden: (), + }) } } @@ -186,7 +189,10 @@ impl Message for DataRow { } } - Ok(DataRow { fields, _hidden: () }) + Ok(DataRow { + fields, + _hidden: (), + }) } } diff --git a/src/messages/extendedquery.rs b/src/messages/extendedquery.rs index 898945e..77b0472 100644 --- a/src/messages/extendedquery.rs +++ b/src/messages/extendedquery.rs @@ -121,7 +121,11 @@ impl Message for Close { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Close { target_type, name, _hidden: () }) + Ok(Close { + target_type, + name, + _hidden: (), + }) } } @@ -317,7 +321,11 @@ impl Message for Describe { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Describe { target_type, name, _hidden: () }) + Ok(Describe { + target_type, + name, + _hidden: (), + }) } } @@ -352,7 +360,11 @@ impl Message for Execute { let name = codec::get_cstring(buf); let max_rows = buf.get_i32(); - Ok(Execute { name, max_rows, _hidden: () }) + Ok(Execute { + name, + max_rows, + _hidden: (), + }) } } diff --git a/src/messages/startup.rs b/src/messages/startup.rs index 3286167..d57c79f 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -428,7 +428,7 @@ impl Message for BackendKeyData { #[derive(PartialEq, Eq, Debug, new)] pub struct SslRequest { #[new(default)] - _hidden: () + _hidden: (), } impl SslRequest { @@ -509,7 +509,11 @@ impl Message for SASLInitialResponse { Some(buf.split_to(data_len as usize).freeze()) }; - Ok(SASLInitialResponse { auth_method, data, _hidden: () }) + Ok(SASLInitialResponse { + auth_method, + data, + _hidden: (), + }) } } From c701a17713dfcddbbfb281a3b602ac73f9667c9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 26 Dec 2023 18:28:19 +0000 Subject: [PATCH 3/6] replace _hidden with non_exhaustive --- src/api/auth/mod.rs | 3 +-- src/api/mod.rs | 3 +-- src/api/results.rs | 4 +--- src/api/stmt.rs | 4 +--- src/error.rs | 3 +-- src/messages/data.rs | 21 ++++++-------------- src/messages/extendedquery.rs | 36 ++++++++--------------------------- src/messages/simplequery.rs | 3 +-- src/messages/startup.rs | 25 ++++++++---------------- src/tokio.rs | 3 +-- 10 files changed, 29 insertions(+), 76 deletions(-) diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index 1b3c93c..cab1e3b 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -44,6 +44,7 @@ pub trait ServerParameterProvider: Send + Sync { /// - `client_encoding: UTF8` /// - `integer_datetimes: on`: /// +#[non_exhaustive] #[derive(Debug)] pub struct DefaultServerParameterProvider { pub server_version: String, @@ -51,7 +52,6 @@ pub struct DefaultServerParameterProvider { pub client_encoding: String, pub date_style: String, pub integer_datetimes: String, - _hidden: (), } impl Default for DefaultServerParameterProvider { @@ -62,7 +62,6 @@ impl Default for DefaultServerParameterProvider { client_encoding: "UTF8".to_owned(), date_style: "ISO YMD".to_owned(), integer_datetimes: "on".to_owned(), - _hidden: (), } } } diff --git a/src/api/mod.rs b/src/api/mod.rs index f106b18..f9df42f 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -49,6 +49,7 @@ pub trait ClientPortalStore { pub const METADATA_USER: &str = "user"; pub const METADATA_DATABASE: &str = "database"; +#[non_exhaustive] #[derive(Debug)] pub struct DefaultClient { pub socket_addr: SocketAddr, @@ -56,7 +57,6 @@ pub struct DefaultClient { pub state: PgWireConnectionState, pub metadata: HashMap, pub portal_store: store::MemPortalStore, - _hidden: (), } impl ClientInfo for DefaultClient { @@ -93,7 +93,6 @@ impl DefaultClient { state: PgWireConnectionState::default(), metadata: HashMap::new(), portal_store: store::MemPortalStore::new(), - _hidden: (), } } } diff --git a/src/api/results.rs b/src/api/results.rs index 6b70a06..182097c 100644 --- a/src/api/results.rs +++ b/src/api/results.rs @@ -246,12 +246,11 @@ impl DataRowEncoder { /// statement, frontend expects parameter types inferenced by server. And both /// describe messages will require column definitions for resultset being /// returned. +#[non_exhaustive] #[derive(Debug, new)] pub struct DescribeResponse { pub parameters: Option>, pub fields: Vec, - #[new(default)] - _hidden: (), } impl DescribeResponse { @@ -269,7 +268,6 @@ impl DescribeResponse { DescribeResponse { parameters: None, fields: vec![], - _hidden: (), } } diff --git a/src/api/stmt.rs b/src/api/stmt.rs index 379b427..ce0672f 100644 --- a/src/api/stmt.rs +++ b/src/api/stmt.rs @@ -8,6 +8,7 @@ use crate::messages::extendedquery::Parse; use super::DEFAULT_NAME; +#[non_exhaustive] #[derive(Debug, Default, new)] pub struct StoredStatement { /// name of the statement @@ -17,8 +18,6 @@ pub struct StoredStatement { /// type ids of query parameters, can be empty if frontend asks backend for /// type inference pub parameter_types: Vec, - #[new(default)] - _hidden: (), } impl StoredStatement { @@ -39,7 +38,6 @@ impl StoredStatement { .unwrap_or_else(|| DEFAULT_NAME.to_owned()), statement, parameter_types: types, - _hidden: (), }) } } diff --git a/src/error.rs b/src/error.rs index 46316e7..1e7d991 100644 --- a/src/error.rs +++ b/src/error.rs @@ -54,6 +54,7 @@ pub type PgWireResult = Result; // Postgres error and notice message fields // This part of protocol is defined in // https://www.postgresql.org/docs/8.2/protocol-error-fields.html +#[non_exhaustive] #[derive(new, Debug)] pub struct ErrorInfo { // severity can be one of `ERROR`, `FATAL`, or `PANIC` (in an error @@ -97,8 +98,6 @@ pub struct ErrorInfo { // Routine: the name of the source-code routine reporting the error. #[new(default)] pub routine: Option, - #[new(default)] - _hidden: (), } impl ErrorInfo { diff --git a/src/messages/data.rs b/src/messages/data.rs index 74569f9..b53cf4c 100644 --- a/src/messages/data.rs +++ b/src/messages/data.rs @@ -26,11 +26,10 @@ pub struct FieldDescription { pub format_code: i16, } +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, Default, new)] pub struct RowDescription { pub fields: Vec, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_ROW_DESCRITION: u8 = b'T'; @@ -83,20 +82,16 @@ impl Message for RowDescription { fields.push(field); } - Ok(RowDescription { - fields, - _hidden: (), - }) + Ok(RowDescription { fields }) } } /// Data structure returned when frontend describes a statement +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, Default, new, Clone)] pub struct ParameterDescription { /// parameter types pub types: Vec, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_PARAMETER_DESCRITION: u8 = b't'; @@ -128,7 +123,7 @@ impl Message for ParameterDescription { types.push(buf.get_i32() as Oid); } - Ok(ParameterDescription { types, _hidden: () }) + Ok(ParameterDescription { types }) } } @@ -136,11 +131,10 @@ impl Message for ParameterDescription { /// /// Data can be represented as text or binary format as specified by format /// codes from previous `RowDescription` message. +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, Default, new, Clone)] pub struct DataRow { pub fields: Vec>, - #[new(default)] - _hidden: (), } impl DataRow {} @@ -189,10 +183,7 @@ impl Message for DataRow { } } - Ok(DataRow { - fields, - _hidden: (), - }) + Ok(DataRow { fields }) } } diff --git a/src/messages/extendedquery.rs b/src/messages/extendedquery.rs index 77b0472..736d6dc 100644 --- a/src/messages/extendedquery.rs +++ b/src/messages/extendedquery.rs @@ -5,13 +5,12 @@ use super::{codec, Message}; use crate::error::PgWireResult; /// Request from frontend to parse a prepared query string +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Parse { pub name: Option, pub query: String, pub type_oids: Vec, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_PARSE: u8 = b'P'; @@ -54,7 +53,6 @@ impl Message for Parse { name, query, type_oids, - _hidden: (), }) } } @@ -88,12 +86,11 @@ impl Message for ParseComplete { } /// Closing the prepared statement or portal +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Close { pub target_type: u8, pub name: Option, - #[new(default)] - _hidden: (), } pub const TARGET_TYPE_BYTE_STATEMENT: u8 = b'S'; @@ -121,11 +118,7 @@ impl Message for Close { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Close { - target_type, - name, - _hidden: (), - }) + Ok(Close { target_type, name }) } } @@ -158,6 +151,7 @@ impl Message for CloseComplete { } /// Bind command, for executing prepared statement +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Bind { pub portal_name: Option, @@ -168,8 +162,6 @@ pub struct Bind { pub parameters: Vec>, pub result_column_format_codes: Vec, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_BIND: u8 = b'B'; @@ -255,8 +247,6 @@ impl Message for Bind { parameters, result_column_format_codes, - - _hidden: (), }) } } @@ -291,12 +281,11 @@ impl Message for BindComplete { /// Describe command fron frontend to backend. For getting information of /// particular portal or statement +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Describe { pub target_type: u8, pub name: Option, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_DESCRIBE: u8 = b'D'; @@ -321,21 +310,16 @@ impl Message for Describe { let target_type = buf.get_u8(); let name = codec::get_cstring(buf); - Ok(Describe { - target_type, - name, - _hidden: (), - }) + Ok(Describe { target_type, name }) } } /// Execute portal by its name +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Execute { pub name: Option, pub max_rows: i32, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_EXECUTE: u8 = b'E'; @@ -360,11 +344,7 @@ impl Message for Execute { let name = codec::get_cstring(buf); let max_rows = buf.get_i32(); - Ok(Execute { - name, - max_rows, - _hidden: (), - }) + Ok(Execute { name, max_rows }) } } diff --git a/src/messages/simplequery.rs b/src/messages/simplequery.rs index 62b2a42..943fff9 100644 --- a/src/messages/simplequery.rs +++ b/src/messages/simplequery.rs @@ -5,11 +5,10 @@ use super::Message; use crate::error::PgWireResult; /// A sql query sent from frontend to backend. +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Query { pub query: String, - #[new(default)] - _hidden: (), } pub const MESSAGE_TYPE_BYTE_QUERY: u8 = b'Q'; diff --git a/src/messages/startup.rs b/src/messages/startup.rs index d57c79f..81f8b2d 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -11,6 +11,7 @@ use crate::error::{PgWireError, PgWireResult}; /// terminated by a zero byte. /// the key-value parameter pairs are terminated by a zero byte, too. /// +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct Startup { #[new(value = "3")] @@ -19,8 +20,6 @@ pub struct Startup { pub protocol_number_minor: u16, #[new(default)] pub parameters: BTreeMap, - #[new(default)] - _hidden: (), } impl Default for Startup { @@ -425,11 +424,9 @@ impl Message for BackendKeyData { /// The backend sents a single byte 'S' or 'N' to indicate its support. Upon 'S' /// the frontend should close the connection and reinitialize a new TLS /// connection. +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] -pub struct SslRequest { - #[new(default)] - _hidden: (), -} +pub struct SslRequest; impl SslRequest { pub const BODY_MAGIC_NUMBER: i32 = 80877103; @@ -460,19 +457,18 @@ impl Message for SslRequest { fn decode(buf: &mut BytesMut) -> PgWireResult> { if buf.remaining() >= 8 && (&buf[4..8]).get_i32() == Self::BODY_MAGIC_NUMBER { buf.advance(8); - Ok(Some(SslRequest { _hidden: () })) + Ok(Some(SslRequest)) } else { Ok(None) } } } +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct SASLInitialResponse { pub auth_method: String, pub data: Option, - #[new(default)] - _hidden: (), } impl Message for SASLInitialResponse { @@ -509,19 +505,14 @@ impl Message for SASLInitialResponse { Some(buf.split_to(data_len as usize).freeze()) }; - Ok(SASLInitialResponse { - auth_method, - data, - _hidden: (), - }) + Ok(SASLInitialResponse { auth_method, data }) } } +#[non_exhaustive] #[derive(PartialEq, Eq, Debug, new)] pub struct SASLResponse { pub data: Bytes, - #[new(default)] - _hidden: (), } impl Message for SASLResponse { @@ -542,6 +533,6 @@ impl Message for SASLResponse { fn decode_body(buf: &mut BytesMut, full_len: usize) -> PgWireResult { let data = buf.split_to(full_len - 4).freeze(); - Ok(SASLResponse { data, _hidden: () }) + Ok(SASLResponse { data }) } } diff --git a/src/tokio.rs b/src/tokio.rs index 45a5476..597a803 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -19,11 +19,10 @@ use crate::messages::response::{SslResponse, READY_STATUS_IDLE}; use crate::messages::startup::{SslRequest, Startup}; use crate::messages::{Message, PgWireBackendMessage, PgWireFrontendMessage}; +#[non_exhaustive] #[derive(Debug, new)] pub struct PgWireMessageServerCodec { pub client_info: DefaultClient, - #[new(default)] - _hidden: (), } impl Decoder for PgWireMessageServerCodec { From 077b6f982355f6013ff6e1aac329a3e625f8d58e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 26 Dec 2023 18:35:52 +0000 Subject: [PATCH 4/6] appease clippy --- src/api/auth/mod.rs | 18 +++++++++--------- src/api/mod.rs | 10 +++++----- src/api/results.rs | 14 +++++++------- src/messages/startup.rs | 15 +++++++++------ 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs index cab1e3b..82cf194 100644 --- a/src/api/auth/mod.rs +++ b/src/api/auth/mod.rs @@ -93,32 +93,32 @@ pub struct Password { impl Password { pub fn salt(&self) -> Option<&[u8]> { - return self.salt.as_deref(); + self.salt.as_deref() } pub fn password(&self) -> &[u8] { - return &self.password; + &self.password } } #[derive(Debug, new)] pub struct LoginInfo<'a> { - user: Option<&'a String>, - database: Option<&'a String>, + user: Option<&'a str>, + database: Option<&'a str>, host: String, } impl<'a> LoginInfo<'a> { pub fn user(&self) -> Option<&str> { - return self.user.map(|u| u.as_str()); + self.user } pub fn database(&self) -> Option<&str> { - return self.database.map(|db| db.as_str()); + self.database } pub fn host(&self) -> &str { - return &self.host; + &self.host } pub fn from_client_info(client: &'a C) -> LoginInfo @@ -126,8 +126,8 @@ impl<'a> LoginInfo<'a> { C: ClientInfo, { LoginInfo { - user: client.metadata().get(METADATA_USER), - database: client.metadata().get(METADATA_DATABASE), + user: client.metadata().get(METADATA_USER).map(|s| s.as_str()), + database: client.metadata().get(METADATA_DATABASE).map(|s| s.as_str()), host: client.socket_addr().ip().to_string(), } } diff --git a/src/api/mod.rs b/src/api/mod.rs index f9df42f..8b4c40a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -61,15 +61,15 @@ pub struct DefaultClient { impl ClientInfo for DefaultClient { fn socket_addr(&self) -> SocketAddr { - return self.socket_addr; + self.socket_addr } fn is_secure(&self) -> bool { - return self.is_secure; + self.is_secure } fn state(&self) -> PgWireConnectionState { - return self.state; + self.state } fn set_state(&mut self, new_state: PgWireConnectionState) { @@ -77,11 +77,11 @@ impl ClientInfo for DefaultClient { } fn metadata(&self) -> &HashMap { - return &self.metadata; + &self.metadata } fn metadata_mut(&mut self) -> &mut HashMap { - return &mut self.metadata; + &mut self.metadata } } diff --git a/src/api/results.rs b/src/api/results.rs index 182097c..2f26339 100644 --- a/src/api/results.rs +++ b/src/api/results.rs @@ -89,23 +89,23 @@ pub struct FieldInfo { impl FieldInfo { pub fn name(&self) -> &str { - return &self.name; + &self.name } pub fn table_id(&self) -> Option { - return self.table_id; + self.table_id } pub fn column_id(&self) -> Option { - return self.column_id; + self.column_id } pub fn datatype(&self) -> &Type { - return &self.datatype; + &self.datatype } pub fn format(&self) -> FieldFormat { - return self.format; + self.format } } @@ -255,11 +255,11 @@ pub struct DescribeResponse { impl DescribeResponse { pub fn parameters(&self) -> Option<&[Type]> { - return self.parameters.as_deref(); + self.parameters.as_deref() } pub fn fields(&self) -> &[FieldInfo] { - return &self.fields; + &self.fields } /// Create an no_data instance of `DescribeResponse`. This is typically used diff --git a/src/messages/startup.rs b/src/messages/startup.rs index 81f8b2d..4a9138e 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -85,19 +85,22 @@ impl Message for Startup { return Err(PgWireError::InvalidStartupMessage); } - let mut msg = Startup::default(); - // parse - msg.protocol_number_major = buf.get_u16(); - msg.protocol_number_minor = buf.get_u16(); + let protocol_number_major = buf.get_u16(); + let protocol_number_minor = buf.get_u16(); // end by reading the last \0 + let mut parameters = BTreeMap::new(); while let Some(key) = codec::get_cstring(buf) { let value = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()); - msg.parameters.insert(key, value); + parameters.insert(key, value); } - Ok(msg) + Ok(Startup { + protocol_number_major, + protocol_number_minor, + parameters, + }) } } From 223139507d7f52c1a84dcc03e06b791ca37006b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 26 Dec 2023 18:45:25 +0000 Subject: [PATCH 5/6] fix test --- examples/sqlite.rs | 16 ++++++++-------- tests-integration/test-server/src/main.rs | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/sqlite.rs b/examples/sqlite.rs index 5c98a31..5d21b20 100644 --- a/examples/sqlite.rs +++ b/examples/sqlite.rs @@ -145,7 +145,7 @@ fn encode_row_data( fn get_params(portal: &Portal) -> Vec> { let mut results = Vec::with_capacity(portal.parameter_len()); for i in 0..portal.parameter_len() { - let param_type = portal.statement().parameter_types().get(i).unwrap(); + let param_type = portal.statement.parameter_types.get(i).unwrap(); // we only support a small amount of types for demo match param_type { &Type::BOOL => { @@ -204,7 +204,7 @@ impl ExtendedQueryHandler for SqliteBackend { C: ClientInfo + Unpin + Send + Sync, { let conn = self.conn.lock().unwrap(); - let query = portal.statement().statement(); + let query = &portal.statement.statement; let mut stmt = conn .prepare_cached(query) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -215,7 +215,7 @@ impl ExtendedQueryHandler for SqliteBackend { .collect::>(); if query.to_uppercase().starts_with("SELECT") { - let header = Arc::new(row_desc_from_stmt(&stmt, portal.result_column_format())?); + let header = Arc::new(row_desc_from_stmt(&stmt, &portal.result_column_format)?); stmt.query::<&[&dyn rusqlite::ToSql]>(params_ref.as_ref()) .map(|rows| { let s = encode_row_data(rows, header.clone()); @@ -242,18 +242,18 @@ impl ExtendedQueryHandler for SqliteBackend { let conn = self.conn.lock().unwrap(); match target { StatementOrPortal::Statement(stmt) => { - let param_types = Some(stmt.parameter_types().clone()); + let param_types = Some(stmt.parameter_types.clone()); let stmt = conn - .prepare_cached(stmt.statement()) + .prepare_cached(&stmt.statement) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; row_desc_from_stmt(&stmt, &Format::UnifiedBinary) .map(|fields| DescribeResponse::new(param_types, fields)) } StatementOrPortal::Portal(portal) => { let stmt = conn - .prepare_cached(portal.statement().statement()) + .prepare_cached(&portal.statement.statement) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - row_desc_from_stmt(&stmt, portal.result_column_format()) + row_desc_from_stmt(&stmt, &portal.result_column_format) .map(|fields| DescribeResponse::new(None, fields)) } } @@ -290,7 +290,7 @@ impl MakeHandler for MakeSqliteBackend { #[tokio::main] pub async fn main() { let mut parameters = DefaultServerParameterProvider::default(); - parameters.set_server_version(rusqlite::version().to_owned()); + parameters.server_version = rusqlite::version().to_owned(); let authenticator = Arc::new(MakeMd5PasswordAuthStartupHandler::new( Arc::new(DummyAuthSource), diff --git a/tests-integration/test-server/src/main.rs b/tests-integration/test-server/src/main.rs index f90fe1f..77b3771 100644 --- a/tests-integration/test-server/src/main.rs +++ b/tests-integration/test-server/src/main.rs @@ -118,7 +118,7 @@ impl ExtendedQueryHandler for DummyDatabase { where C: ClientInfo + Unpin + Send + Sync, { - let query = portal.statement().statement(); + let query = &portal.statement.statement; println!("extended query: {:?}", query); if query.starts_with("SELECT") { let data = vec![ @@ -130,7 +130,7 @@ impl ExtendedQueryHandler for DummyDatabase { ), (Some(2), None, None), ]; - let schema = Arc::new(self.schema(portal.result_column_format())); + let schema = Arc::new(self.schema(&portal.result_column_format)); let schema_ref = schema.clone(); let data_row_stream = stream::iter(data.into_iter()).map(move |r| { let mut encoder = DataRowEncoder::new(schema_ref.clone()); @@ -164,7 +164,7 @@ impl ExtendedQueryHandler for DummyDatabase { Ok(DescribeResponse::new(param_types, schema)) } StatementOrPortal::Portal(portal) => { - let schema = self.schema(portal.result_column_format()); + let schema = self.schema(&portal.result_column_format); Ok(DescribeResponse::new(None, schema)) } } From ff3768bf0814eed748572362058ecdeb9cfc1cb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 27 Dec 2023 04:51:34 +0000 Subject: [PATCH 6/6] Missed some pubs earlier --- src/messages/copy.rs | 8 ++++---- src/messages/startup.rs | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/messages/copy.rs b/src/messages/copy.rs index 2bd2c26..adeba2b 100644 --- a/src/messages/copy.rs +++ b/src/messages/copy.rs @@ -60,7 +60,7 @@ pub const MESSAGE_TYPE_BYTE_COPY_FAIL: u8 = b'f'; #[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyFail { - message: String, + pub message: String, } impl Message for CopyFail { @@ -88,9 +88,9 @@ pub const MESSAGE_TYPE_BYTE_COPY_IN_RESPONSE: u8 = b'G'; #[derive(PartialEq, Eq, Debug, Default, new)] pub struct CopyInResponse { - format: i8, - columns: i16, - column_formats: Vec, + pub format: i8, + pub columns: i16, + pub column_formats: Vec, } impl Message for CopyInResponse { diff --git a/src/messages/startup.rs b/src/messages/startup.rs index 4a9138e..2d8b175 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -353,8 +353,8 @@ impl Message for Password { /// parameter ack sent from backend after authentication success #[derive(PartialEq, Eq, Debug, new)] pub struct ParameterStatus { - name: String, - value: String, + pub name: String, + pub value: String, } pub const MESSAGE_TYPE_BYTE_PARAMETER_STATUS: u8 = b'S'; @@ -388,8 +388,8 @@ impl Message for ParameterStatus { /// `CancelRequestMessage` #[derive(PartialEq, Eq, Debug, new)] pub struct BackendKeyData { - pid: i32, - secret_key: i32, + pub pid: i32, + pub secret_key: i32, } pub const MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA: u8 = b'K'; @@ -424,7 +424,7 @@ impl Message for BackendKeyData { /// backend supports secure connection. The packet has no message type and /// contains only a length(4) and an i32 value. /// -/// The backend sents a single byte 'S' or 'N' to indicate its support. Upon 'S' +/// The backend sends a single byte 'S' or 'N' to indicate its support. Upon 'S' /// the frontend should close the connection and reinitialize a new TLS /// connection. #[non_exhaustive]