diff --git a/Cargo.toml b/Cargo.toml index 31975090..75d7ddfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.20.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.20.0" +version = "0.20.1" edition = "2018" rust-version = "1.51" include = ["benches/**/*", "src/**/*", "examples/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"] diff --git a/src/error.rs b/src/error.rs index a7b33545..faea80bf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -59,6 +59,9 @@ pub enum Error { /// UTF coding error. #[error("UTF-8 encoding error")] Utf8, + /// Attack attempt detected. + #[error("Attack attempt detected")] + AttackAttempt, /// Invalid URL. #[error("URL error: {0}")] Url(#[from] UrlError), diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index eacb4bfe..ee602a9c 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -20,7 +20,7 @@ pub struct HandshakeMachine { impl HandshakeMachine { /// Start reading data from the peer. pub fn start_read(stream: Stream) -> Self { - HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) } + Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) } } /// Start writing data to the peer. pub fn start_write>>(stream: Stream, data: D) -> Self { @@ -41,25 +41,31 @@ impl HandshakeMachine { pub fn single_round(mut self) -> Result> { trace!("Doing handshake round."); match self.state { - HandshakeState::Reading(mut buf) => { + HandshakeState::Reading(mut buf, mut attack_check) => { let read = buf.read_from(&mut self.stream).no_block()?; match read { Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), - Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { - buf.advance(size); - RoundResult::StageFinished(StageResult::DoneReading { - result: obj, - stream: self.stream, - tail: buf.into_vec(), + Some(count) => { + attack_check.check_incoming_packet_size(count)?; + // TODO: this is slow for big headers with too many small packets. + // The parser has to be reworked in order to work on streams instead + // of buffers. + Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { + buf.advance(size); + RoundResult::StageFinished(StageResult::DoneReading { + result: obj, + stream: self.stream, + tail: buf.into_vec(), + }) + } else { + RoundResult::Incomplete(HandshakeMachine { + state: HandshakeState::Reading(buf, attack_check), + ..self + }) }) - } else { - RoundResult::Incomplete(HandshakeMachine { - state: HandshakeState::Reading(buf), - ..self - }) - }), + } None => Ok(RoundResult::WouldBlock(HandshakeMachine { - state: HandshakeState::Reading(buf), + state: HandshakeState::Reading(buf, attack_check), ..self })), } @@ -119,7 +125,54 @@ pub trait TryParse: Sized { #[derive(Debug)] enum HandshakeState { /// Reading data from the peer. - Reading(ReadBuffer), + Reading(ReadBuffer, AttackCheck), /// Sending data to the peer. Writing(Cursor>), } + +/// Attack mitigation. Contains counters needed to prevent DoS attacks +/// and reject valid but useless headers. +#[derive(Debug)] +pub(crate) struct AttackCheck { + /// Number of HTTP header successful reads (TCP packets). + number_of_packets: usize, + /// Total number of bytes in HTTP header. + number_of_bytes: usize, +} + +impl AttackCheck { + /// Initialize attack checking for incoming buffer. + fn new() -> Self { + Self { number_of_packets: 0, number_of_bytes: 0 } + } + + /// Check the size of an incoming packet. To be called immediately after `read()` + /// passing its returned bytes count as `size`. + fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> { + self.number_of_packets += 1; + self.number_of_bytes += size; + + // TODO: these values are hardcoded. Instead of making them configurable, + // rework the way HTTP header is parsed to remove this check at all. + const MAX_BYTES: usize = 65536; + const MAX_PACKETS: usize = 512; + const MIN_PACKET_SIZE: usize = 128; + const MIN_PACKET_CHECK_THRESHOLD: usize = 64; + + if self.number_of_bytes > MAX_BYTES { + return Err(Error::AttackAttempt); + } + + if self.number_of_packets > MAX_PACKETS { + return Err(Error::AttackAttempt); + } + + if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD { + if self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes { + return Err(Error::AttackAttempt); + } + } + + Ok(()) + } +}