diff --git a/src/error.rs b/src/error.rs index 00cb16a..f840aa9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,8 @@ pub enum PgWireError { InvalidMessageType(u8), #[error("Invalid target type, received {0}")] InvalidTargetType(u8), + #[error("Invalid startup message")] + InvalidStartupMessage, #[error(transparent)] IoError(#[from] std::io::Error), #[error("Portal not found for name: {0:?}")] diff --git a/src/messages/startup.rs b/src/messages/startup.rs index 0fa4bbd..620d78f 100644 --- a/src/messages/startup.rs +++ b/src/messages/startup.rs @@ -29,6 +29,8 @@ impl Default for Startup { } impl Startup { + const MINIMUM_STARTUP_MESSAGE_LEN: usize = 8; + fn is_protocol_version_supported(version: i32) -> bool { version == 196608 } @@ -64,7 +66,7 @@ impl Message for Startup { fn decode(buf: &mut BytesMut) -> PgWireResult> { // packet len + protocol version // check if packet is valid - if buf.remaining() >= 8 { + if buf.remaining() >= Self::MINIMUM_STARTUP_MESSAGE_LEN { let packet_version = (&buf[4..8]).get_i32(); if !Self::is_protocol_version_supported(packet_version) { return Err(PgWireError::InvalidProtocolVersion(packet_version)); @@ -74,8 +76,17 @@ impl Message for Startup { codec::decode_packet(buf, 0, Self::decode_body) } - fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { + fn decode_body(buf: &mut BytesMut, msg_len: usize) -> PgWireResult { + // double check to ensure that the packet has more than 8 bytes + // `codec::decode_packet` has its validation to ensure buf remaining is + // larger than `msg_len`. So with both checks, we should not have issue + // with reading protocol numbers. + if msg_len <= Self::MINIMUM_STARTUP_MESSAGE_LEN { + return Err(PgWireError::InvalidStartupMessage); + } + let mut msg = Startup::default(); + // parse msg.set_protocol_number_major(buf.get_u16()); msg.set_protocol_number_minor(buf.get_u16());