From d0f96464a08aca10bd065f7d9b13e00379a9ff86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Sat, 11 Nov 2023 13:19:36 +0100 Subject: [PATCH 01/11] Improve tests --- src/response.rs | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/response.rs b/src/response.rs index 5fbab03..9556377 100644 --- a/src/response.rs +++ b/src/response.rs @@ -749,7 +749,7 @@ mod tests { #[tokio::test] async fn can_read_with_chunked_encoding() { let mut conn = FakeSingleReadConnection::new( - b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\nB\r\nHELLO WORLD\r\n0\r\n\r\n", + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHELLO\r\n6\r\n WORLD\r\n0\r\n\r\n", ); let mut header_buf = [0; 200]; let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); @@ -761,6 +761,22 @@ mod tests { assert!(conn.is_exhausted()); } + #[tokio::test] + async fn can_read_chunked_with_preloaded() { + let mut conn = FakeSingleReadConnection::new( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHELLO\r\n6\r\n WORLD\r\n0\r\n\r\n", + ); + conn.read_length = 100; + let mut header_buf = [0; 200]; + let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); + + let mut body_buf = [0; 200]; + let len = response.body().reader().read_to_end(&mut body_buf).await.unwrap(); + + assert_eq!(b"HELLO WORLD", &body_buf[..len]); + assert!(conn.is_exhausted()); + } + #[tokio::test] async fn can_read_with_chunked_encoding_empty_body() { let mut conn = FakeSingleReadConnection::new(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"); @@ -863,11 +879,17 @@ mod tests { struct FakeSingleReadConnection { response: &'static [u8], offset: usize, + /// The fake connection will provide at most this many bytes per read + read_length: usize, } impl FakeSingleReadConnection { pub fn new(response: &'static [u8]) -> Self { - Self { response, offset: 0 } + Self { + response, + offset: 0, + read_length: 1, + } } pub fn is_exhausted(&self) -> bool { @@ -885,9 +907,12 @@ mod tests { return Ok(0); } - buf[0] = self.response[self.offset]; - self.offset += 1; - return Ok(1); + let loaded = &self.response[self.offset..]; + let len = self.read_length.min(buf.len()).min(loaded.len()); + buf[..len].copy_from_slice(&loaded[..len]); + self.offset += len; + + Ok(len) } } } From 72937f1b9d090cad04a191c0b779cbabc0beed77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Sat, 11 Nov 2023 10:17:33 +0100 Subject: [PATCH 02/11] Reorganize response module --- src/response/chunked.rs | 178 +++++++++++++++++++ src/response/fixed_length.rs | 59 +++++++ src/{response.rs => response/mod.rs} | 244 ++------------------------- 3 files changed, 247 insertions(+), 234 deletions(-) create mode 100644 src/response/chunked.rs create mode 100644 src/response/fixed_length.rs rename src/{response.rs => response/mod.rs} (78%) diff --git a/src/response/chunked.rs b/src/response/chunked.rs new file mode 100644 index 0000000..efbe6cc --- /dev/null +++ b/src/response/chunked.rs @@ -0,0 +1,178 @@ +use embedded_io_async::{BufRead, Error as _, ErrorType, Read}; + +use crate::Error; + +#[derive(Clone, Copy, PartialEq, Eq)] +enum ChunkState { + NoChunk, + NotEmpty(u32), + Empty, +} + +impl ChunkState { + fn consume(&mut self, amt: usize) -> usize { + if let ChunkState::NotEmpty(remaining) = self { + let consumed = (amt as u32).min(*remaining); + *remaining -= consumed; + consumed as usize + } else { + 0 + } + } + + fn len(self) -> usize { + if let ChunkState::NotEmpty(len) = self { + len as usize + } else { + 0 + } + } +} + +/// Chunked response body reader +pub struct ChunkedBodyReader { + raw_body: B, + chunk_remaining: ChunkState, +} + +impl ChunkedBodyReader +where + C: Read, +{ + pub fn new(raw_body: C) -> Self { + Self { + raw_body, + chunk_remaining: ChunkState::NoChunk, + } + } + + pub fn is_done(&self) -> bool { + self.chunk_remaining == ChunkState::Empty + } + + async fn read_next_chunk_length(&mut self) -> Result<(), Error> { + let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n + let mut total_read = 0; + + 'read_size: loop { + let mut byte = 0; + self.raw_body + .read_exact(core::slice::from_mut(&mut byte)) + .await + .map_err(|e| Error::from(e).kind())?; + + if byte != b'\n' { + header_buf[total_read] = byte; + total_read += 1; + + if total_read == header_buf.len() { + return Err(Error::Codec); + } + } else { + if total_read == 0 || header_buf[total_read - 1] != b'\r' { + return Err(Error::Codec); + } + break 'read_size; + } + } + + let hex_digits = total_read - 1; + + // Prepend hex with zeros + let mut hex = [b'0'; 8]; + hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]); + + let mut bytes = [0; 4]; + hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; + + let chunk_length = u32::from_be_bytes(bytes); + + debug!("Chunk length: {}", chunk_length); + + self.chunk_remaining = match chunk_length { + 0 => ChunkState::Empty, + other => ChunkState::NotEmpty(other), + }; + + Ok(()) + } + + async fn read_chunk_end(&mut self) -> Result<(), Error> { + // All chunks are terminated with a \r\n + let mut newline_buf = [0; 2]; + self.raw_body.read_exact(&mut newline_buf).await?; + + if newline_buf != [b'\r', b'\n'] { + return Err(Error::Codec); + } + Ok(()) + } + + /// Handles chunk boundary and returns the number of bytes in the current (or new) chunk. + async fn handle_chunk_boundary(&mut self) -> Result { + match self.chunk_remaining { + ChunkState::NoChunk => self.read_next_chunk_length().await?, + + ChunkState::NotEmpty(0) => { + // The current chunk is currently empty, advance into a new chunk... + self.read_chunk_end().await?; + self.read_next_chunk_length().await?; + } + + ChunkState::NotEmpty(_) => {} + + ChunkState::Empty => return Ok(0), + } + + if self.chunk_remaining == ChunkState::Empty { + // Read final chunk termination + self.read_chunk_end().await?; + } + + Ok(self.chunk_remaining.len()) + } +} + +impl ErrorType for ChunkedBodyReader { + type Error = Error; +} + +impl Read for ChunkedBodyReader +where + C: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + let remaining = self.handle_chunk_boundary().await?; + let max_len = buf.len().min(remaining); + + let len = self + .raw_body + .read(&mut buf[..max_len]) + .await + .map_err(|e| Error::Network(e.kind()))?; + + self.chunk_remaining.consume(len); + + Ok(len) + } +} + +impl BufRead for ChunkedBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + let remaining = self.handle_chunk_boundary().await?; + + let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?; + + let len = buf.len().min(remaining); + + Ok(&buf[..len]) + } + + fn consume(&mut self, amt: usize) { + let consumed = self.chunk_remaining.consume(amt); + self.raw_body.consume(consumed); + } +} diff --git a/src/response/fixed_length.rs b/src/response/fixed_length.rs new file mode 100644 index 0000000..b653a9d --- /dev/null +++ b/src/response/fixed_length.rs @@ -0,0 +1,59 @@ +use embedded_io_async::{BufRead, Error as _, ErrorType, Read}; + +use crate::Error; + +/// Fixed length response body reader +pub struct FixedLengthBodyReader { + pub raw_body: B, + pub remaining: usize, +} + +impl ErrorType for FixedLengthBodyReader { + type Error = Error; +} + +impl Read for FixedLengthBodyReader +where + C: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + if self.remaining == 0 { + return Ok(0); + } + + let read = self.raw_body.read(buf).await.map_err(|e| Error::Network(e.kind()))?; + self.remaining -= read; + + Ok(read) + } +} + +impl BufRead for FixedLengthBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + if self.remaining == 0 { + return Ok(&[]); + } + + let loaded = self + .raw_body + .fill_buf() + .await + .map_err(|e| Error::Network(e.kind())) + .map(|data| &data[..data.len().min(self.remaining)])?; + + if loaded.is_empty() { + return Err(Error::ConnectionAborted); + } + + Ok(loaded) + } + + fn consume(&mut self, amt: usize) { + let amt = amt.min(self.remaining); + self.remaining -= amt; + self.raw_body.consume(amt) + } +} diff --git a/src/response.rs b/src/response/mod.rs similarity index 78% rename from src/response.rs rename to src/response/mod.rs index 9556377..e95b393 100644 --- a/src/response.rs +++ b/src/response/mod.rs @@ -5,8 +5,13 @@ use heapless::Vec; use crate::headers::{ContentType, KeepAlive, TransferEncoding}; use crate::reader::BufferingReader; use crate::request::Method; +use crate::response::chunked::ChunkedBodyReader; +use crate::response::fixed_length::FixedLengthBodyReader; use crate::Error; +mod chunked; +mod fixed_length; + /// Type representing a parsed HTTP response. #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -211,10 +216,7 @@ where raw_body, remaining: content_length, }), - ReaderHint::Chunked => BodyReader::Chunked(ChunkedBodyReader { - raw_body, - chunk_remaining: ChunkState::NoChunk, - }), + ReaderHint::Chunked => BodyReader::Chunked(ChunkedBodyReader::new(raw_body)), ReaderHint::ToEnd => BodyReader::ToEnd(raw_body), } } @@ -303,7 +305,7 @@ where } reader.remaining == 0 } - BodyReader::Chunked(reader) => reader.chunk_remaining == ChunkState::Empty, + BodyReader::Chunked(reader) => reader.is_done(), BodyReader::ToEnd(_) => true, }; @@ -370,226 +372,6 @@ where } } -/// Fixed length response body reader -pub struct FixedLengthBodyReader { - raw_body: B, - remaining: usize, -} - -impl ErrorType for FixedLengthBodyReader { - type Error = Error; -} - -impl Read for FixedLengthBodyReader -where - C: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - if self.remaining == 0 { - return Ok(0); - } - - let read = self.raw_body.read(buf).await.map_err(|e| Error::Network(e.kind()))?; - self.remaining -= read; - - Ok(read) - } -} - -impl BufRead for FixedLengthBodyReader -where - C: BufRead + Read, -{ - async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { - if self.remaining == 0 { - return Ok(&[]); - } - - let loaded = self - .raw_body - .fill_buf() - .await - .map_err(|e| Error::Network(e.kind())) - .map(|data| &data[..data.len().min(self.remaining)])?; - - if loaded.is_empty() { - return Err(Error::ConnectionAborted); - } - - Ok(loaded) - } - - fn consume(&mut self, amt: usize) { - let amt = amt.min(self.remaining); - self.remaining -= amt; - self.raw_body.consume(amt) - } -} - -#[derive(Clone, Copy, PartialEq, Eq)] -enum ChunkState { - NoChunk, - NotEmpty(u32), - Empty, -} - -impl ChunkState { - fn consume(&mut self, amt: usize) -> usize { - if let ChunkState::NotEmpty(remaining) = self { - let consumed = (amt as u32).min(*remaining); - *remaining -= consumed; - consumed as usize - } else { - 0 - } - } - - fn len(self) -> usize { - if let ChunkState::NotEmpty(len) = self { - len as usize - } else { - 0 - } - } -} - -/// Chunked response body reader -pub struct ChunkedBodyReader { - raw_body: B, - chunk_remaining: ChunkState, -} - -impl ChunkedBodyReader -where - C: Read, -{ - async fn read_next_chunk_length(&mut self) -> Result<(), Error> { - let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n - let mut total_read = 0; - - 'read_size: loop { - let mut byte = 0; - self.raw_body - .read_exact(core::slice::from_mut(&mut byte)) - .await - .map_err(|e| Error::from(e).kind())?; - - if byte != b'\n' { - header_buf[total_read] = byte; - total_read += 1; - - if total_read == header_buf.len() { - return Err(Error::Codec); - } - } else { - if total_read == 0 || header_buf[total_read - 1] != b'\r' { - return Err(Error::Codec); - } - break 'read_size; - } - } - - let hex_digits = total_read - 1; - - // Prepend hex with zeros - let mut hex = [b'0'; 8]; - hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]); - - let mut bytes = [0; 4]; - hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; - - let chunk_length = u32::from_be_bytes(bytes); - - debug!("Chunk length: {}", chunk_length); - - self.chunk_remaining = match chunk_length { - 0 => ChunkState::Empty, - other => ChunkState::NotEmpty(other), - }; - - Ok(()) - } - - async fn read_chunk_end(&mut self) -> Result<(), Error> { - // All chunks are terminated with a \r\n - let mut newline_buf = [0; 2]; - self.raw_body.read_exact(&mut newline_buf).await?; - - if newline_buf != [b'\r', b'\n'] { - return Err(Error::Codec); - } - Ok(()) - } - - /// Handles chunk boundary and returns the number of bytes in the current (or new) chunk. - async fn handle_chunk_boundary(&mut self) -> Result { - match self.chunk_remaining { - ChunkState::NoChunk => self.read_next_chunk_length().await?, - - ChunkState::NotEmpty(0) => { - // The current chunk is currently empty, advance into a new chunk... - self.read_chunk_end().await?; - self.read_next_chunk_length().await?; - } - - ChunkState::NotEmpty(_) => {} - - ChunkState::Empty => return Ok(0), - } - - if self.chunk_remaining == ChunkState::Empty { - // Read final chunk termination - self.read_chunk_end().await?; - } - - Ok(self.chunk_remaining.len()) - } -} - -impl ErrorType for ChunkedBodyReader { - type Error = Error; -} - -impl Read for ChunkedBodyReader -where - C: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - let remaining = self.handle_chunk_boundary().await?; - let max_len = buf.len().min(remaining); - - let len = self - .raw_body - .read(&mut buf[..max_len]) - .await - .map_err(|e| Error::Network(e.kind()))?; - - self.chunk_remaining.consume(len); - - Ok(len) - } -} - -impl BufRead for ChunkedBodyReader -where - C: BufRead + Read, -{ - async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { - let remaining = self.handle_chunk_boundary().await?; - - let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?; - - let len = buf.len().min(remaining); - - Ok(&buf[..len]) - } - - fn consume(&mut self, amt: usize) { - let consumed = self.chunk_remaining.consume(amt); - self.raw_body.consume(consumed); - } -} - /// HTTP status types #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -693,7 +475,7 @@ mod tests { use crate::{ reader::BufferingReader, request::Method, - response::{ChunkState, ChunkedBodyReader, Response}, + response::{chunked::ChunkedBodyReader, Response}, Error, }; @@ -841,10 +623,7 @@ mod tests { async fn chunked_body_reader_can_read_with_large_buffer() { let mut raw_body = b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_slice(); let mut read_buffer = [0; 128]; - let mut reader = ChunkedBodyReader { - raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), - chunk_remaining: ChunkState::NoChunk, - }; + let mut reader = ChunkedBodyReader::new(BufferingReader::new(&mut read_buffer, 0, &mut raw_body)); let mut body = [0; 17]; reader.read_exact(&mut body).await.unwrap(); @@ -858,10 +637,7 @@ mod tests { async fn chunked_body_reader_can_read_with_tiny_buffer() { let mut raw_body = b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_slice(); let mut read_buffer = [0; 128]; - let mut reader = ChunkedBodyReader { - raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), - chunk_remaining: ChunkState::NoChunk, - }; + let mut reader = ChunkedBodyReader::new(BufferingReader::new(&mut read_buffer, 0, &mut raw_body)); let mut body = heapless::Vec::::new(); for _ in 0..17 { From d96c0f54a2752bed42cc8d96afe1c3cfc20059e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Sat, 11 Nov 2023 11:49:59 +0100 Subject: [PATCH 03/11] Introduce TryBufRead, clean up --- src/lib.rs | 33 +++++++++++++++ src/reader.rs | 25 ++++++----- src/response/chunked.rs | 2 +- src/response/mod.rs | 55 ++++++++++++++---------- tests/client.rs | 76 +++------------------------------ tests/connection.rs | 93 +++++++++++++++++++++++++++++++++++++++++ tests/request.rs | 6 ++- 7 files changed, 181 insertions(+), 109 deletions(-) create mode 100644 tests/connection.rs diff --git a/src/lib.rs b/src/lib.rs index 0923b60..b2c5ecd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,3 +90,36 @@ impl From for Error { Error::InvalidUrl(e) } } + +/// Trait for types that may optionally implement [`embedded_io_async::BufRead`] +pub trait TryBufRead: embedded_io_async::Read { + async fn try_fill_buf(&mut self) -> Option> { + None + } + + fn try_consume(&mut self, _amt: usize) {} +} + +impl TryBufRead for crate::client::HttpConnection<'_, C> +where + C: embedded_io_async::Read + embedded_io_async::Write, +{ + async fn try_fill_buf(&mut self) -> Option> { + // embedded-tls has its own internal buffer, let's prefer that if we can + #[cfg(feature = "embedded-tls")] + if let Self::Tls(ref mut tls) = self { + use embedded_io_async::{BufRead, Error}; + return Some(tls.fill_buf().await.map_err(|e| e.kind())); + } + + None + } + + fn try_consume(&mut self, amt: usize) { + #[cfg(feature = "embedded-tls")] + if let Self::Tls(tls) = self { + use embedded_io_async::BufRead; + tls.consume(amt); + } + } +} diff --git a/src/reader.rs b/src/reader.rs index be19719..098c784 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,7 +1,7 @@ use embedded_io::{Error, ErrorKind, ErrorType}; -use embedded_io_async::{BufRead, Read, Write}; +use embedded_io_async::{BufRead, Read}; -use crate::client::HttpConnection; +use crate::TryBufRead; struct ReadBuffer<'buf> { buffer: &'buf mut [u8], @@ -83,20 +83,22 @@ where } } -impl BufRead for BufferingReader<'_, '_, HttpConnection<'_, C>> +impl BufRead for BufferingReader<'_, '_, C> where - C: Read + Write, + C: TryBufRead, { async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> { // We need to consume the loaded bytes before we read mode. if self.buffer.is_empty() { - // embedded-tls has its own internal buffer, let's prefer that if we can - #[cfg(feature = "embedded-tls")] - if let HttpConnection::Tls(ref mut tls) = self.stream { - return tls.fill_buf().await.map_err(|e| e.kind()); + // The matches/if let dance is to fix lifetime of the borrowed inner connection. + if self.stream.try_fill_buf().await.is_some() { + if let Some(result) = self.stream.try_fill_buf().await { + return result.map_err(|e| e.kind()); + } + unreachable!() } - self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?; + self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await.map_err(|e| e.kind())?; } self.buffer.fill_buf() @@ -109,10 +111,7 @@ where let unconsumed = self.buffer.consume(amt); if unconsumed > 0 { - #[cfg(feature = "embedded-tls")] - if let HttpConnection::Tls(tls) = &mut self.stream { - tls.consume(unconsumed); - } + self.stream.try_consume(unconsumed); } } } diff --git a/src/response/chunked.rs b/src/response/chunked.rs index efbe6cc..085cfec 100644 --- a/src/response/chunked.rs +++ b/src/response/chunked.rs @@ -31,7 +31,7 @@ impl ChunkState { /// Chunked response body reader pub struct ChunkedBodyReader { - raw_body: B, + pub raw_body: B, chunk_remaining: ChunkState, } diff --git a/src/response/mod.rs b/src/response/mod.rs index e95b393..6008d43 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -7,7 +7,7 @@ use crate::reader::BufferingReader; use crate::request::Method; use crate::response::chunked::ChunkedBodyReader; use crate::response::fixed_length::FixedLengthBodyReader; -use crate::Error; +use crate::{Error, TryBufRead}; mod chunked; mod fixed_length; @@ -196,6 +196,7 @@ where pub body_buf: &'buf mut [u8], } +#[derive(Clone, Copy)] enum ReaderHint { Empty, FixedLength(usize), @@ -203,14 +204,9 @@ enum ReaderHint { ToEnd, // https://www.rfc-editor.org/rfc/rfc7230#section-3.3.3 pt. 7: Until end of connection } -impl<'resp, 'buf, C> ResponseBody<'resp, 'buf, C> -where - C: Read, -{ - pub fn reader(self) -> BodyReader> { - let raw_body = BufferingReader::new(self.body_buf, self.raw_body_read, self.conn); - - match self.reader_hint { +impl ReaderHint { + fn reader(self, raw_body: R) -> BodyReader { + match self { ReaderHint::Empty => BodyReader::Empty, ReaderHint::FixedLength(content_length) => BodyReader::FixedLength(FixedLengthBodyReader { raw_body, @@ -225,6 +221,17 @@ where impl<'resp, 'buf, C> ResponseBody<'resp, 'buf, C> where C: Read, +{ + pub fn reader(self) -> BodyReader> { + let raw_body = BufferingReader::new(self.body_buf, self.raw_body_read, self.conn); + + self.reader_hint.reader(raw_body) + } +} + +impl<'resp, 'buf, C> ResponseBody<'resp, 'buf, C> +where + C: Read + TryBufRead, { /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. @@ -286,6 +293,15 @@ impl BodyReader where B: Read, { + fn is_done(&self) -> bool { + match self { + BodyReader::Empty => true, + BodyReader::FixedLength(reader) => reader.remaining == 0, + BodyReader::Chunked(reader) => reader.is_done(), + BodyReader::ToEnd(_) => true, + } + } + /// Read the entire body pub async fn read_to_end(&mut self, buf: &mut [u8]) -> Result { let mut len = 0; @@ -297,21 +313,12 @@ where } } - let is_done = match self { - BodyReader::Empty => true, - BodyReader::FixedLength(reader) => { - if reader.remaining > 0 { - warn!("FixedLength: {} bytes remained", reader.remaining); - } - reader.remaining == 0 - } - BodyReader::Chunked(reader) => reader.is_done(), - BodyReader::ToEnd(_) => true, - }; - - if is_done { + if self.is_done() { Ok(len) } else { + if let BodyReader::FixedLength(reader) = self { + warn!("FixedLength: {} bytes remained", reader.remaining); + } Err(Error::BufferTooSmall) } } @@ -476,7 +483,7 @@ mod tests { reader::BufferingReader, request::Method, response::{chunked::ChunkedBodyReader, Response}, - Error, + Error, TryBufRead, }; #[tokio::test] @@ -691,4 +698,6 @@ mod tests { Ok(len) } } + + impl TryBufRead for FakeSingleReadConnection {} } diff --git a/tests/client.rs b/tests/client.rs index 40baa0c..3dca76e 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,8 +1,6 @@ #![feature(async_fn_in_trait)] #![allow(incomplete_features)] -use embedded_io_adapters::tokio_1::FromTokio; use embedded_io_async::BufRead; -use embedded_nal_async::{AddrType, IpAddr, Ipv4Addr}; use hyper::server::conn::Http; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Server}; @@ -12,14 +10,17 @@ use reqwless::client::HttpClient; use reqwless::headers::ContentType; use reqwless::request::{Method, RequestBuilder}; use reqwless::response::Status; -use std::net::{SocketAddr, ToSocketAddrs}; +use std::net::SocketAddr; use std::sync::Once; use tokio::net::TcpListener; -use tokio::net::TcpStream; use tokio::sync::oneshot; use tokio_rustls::rustls; use tokio_rustls::TlsAcceptor; +mod connection; + +use connection::*; + static INIT: Once = Once::new(); fn setup() { @@ -323,73 +324,6 @@ fn load_private_key(filename: &std::path::PathBuf) -> rustls::PrivateKey { panic!("no keys found in {:?} (encrypted keys not supported)", filename); } -struct LoopbackDns; -impl embedded_nal_async::Dns for LoopbackDns { - type Error = TestError; - - async fn get_host_by_name(&self, _: &str, _: AddrType) -> Result { - Ok(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) - } - - async fn get_host_by_address(&self, _: IpAddr, _: &mut [u8]) -> Result { - Err(TestError) - } -} - -struct StdDns; - -impl embedded_nal_async::Dns for StdDns { - type Error = std::io::Error; - - async fn get_host_by_name(&self, host: &str, addr_type: AddrType) -> Result { - for address in (host, 0).to_socket_addrs()? { - match address { - SocketAddr::V4(a) if addr_type == AddrType::IPv4 || addr_type == AddrType::Either => { - return Ok(IpAddr::V4(a.ip().octets().into())) - } - SocketAddr::V6(a) if addr_type == AddrType::IPv6 || addr_type == AddrType::Either => { - return Ok(IpAddr::V6(a.ip().octets().into())) - } - _ => {} - } - } - Err(std::io::ErrorKind::AddrNotAvailable.into()) - } - - async fn get_host_by_address(&self, _: IpAddr, _: &mut [u8]) -> Result { - todo!() - } -} - -struct TokioTcp; -#[derive(Debug)] -struct TestError; - -impl embedded_io::Error for TestError { - fn kind(&self) -> embedded_io::ErrorKind { - embedded_io::ErrorKind::Other - } -} - -impl embedded_nal_async::TcpConnect for TokioTcp { - type Error = std::io::Error; - type Connection<'m> = FromTokio; - - async fn connect<'m>( - &'m self, - remote: embedded_nal_async::SocketAddr, - ) -> Result, Self::Error> { - let ip = match remote { - embedded_nal_async::SocketAddr::V4(a) => a.ip().octets().into(), - embedded_nal_async::SocketAddr::V6(a) => a.ip().octets().into(), - }; - let remote = SocketAddr::new(ip, remote.port()); - let stream = TcpStream::connect(remote).await?; - let stream = FromTokio::new(stream); - Ok(stream) - } -} - async fn echo(req: hyper::Request) -> Result, hyper::Error> { match (req.method(), req.uri().path()) { _ => Ok(hyper::Response::new(req.into_body())), diff --git a/tests/connection.rs b/tests/connection.rs new file mode 100644 index 0000000..07f8e00 --- /dev/null +++ b/tests/connection.rs @@ -0,0 +1,93 @@ +use embedded_io_adapters::tokio_1::FromTokio; +use embedded_io_async::{ErrorType, Read, Write}; +use embedded_nal_async::{AddrType, IpAddr, Ipv4Addr}; +use reqwless::TryBufRead; +use std::net::{SocketAddr, ToSocketAddrs}; +use tokio::net::TcpStream; + +#[derive(Debug)] +pub struct TestError; + +impl embedded_io::Error for TestError { + fn kind(&self) -> embedded_io::ErrorKind { + embedded_io::ErrorKind::Other + } +} + +pub struct LoopbackDns; +impl embedded_nal_async::Dns for LoopbackDns { + type Error = TestError; + + async fn get_host_by_name(&self, _: &str, _: AddrType) -> Result { + Ok(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) + } + + async fn get_host_by_address(&self, _: IpAddr, _: &mut [u8]) -> Result { + Err(TestError) + } +} + +pub struct StdDns; + +impl embedded_nal_async::Dns for StdDns { + type Error = std::io::Error; + + async fn get_host_by_name(&self, host: &str, addr_type: AddrType) -> Result { + for address in (host, 0).to_socket_addrs()? { + match address { + SocketAddr::V4(a) if addr_type == AddrType::IPv4 || addr_type == AddrType::Either => { + return Ok(IpAddr::V4(a.ip().octets().into())) + } + SocketAddr::V6(a) if addr_type == AddrType::IPv6 || addr_type == AddrType::Either => { + return Ok(IpAddr::V6(a.ip().octets().into())) + } + _ => {} + } + } + Err(std::io::ErrorKind::AddrNotAvailable.into()) + } + + async fn get_host_by_address(&self, _: IpAddr, _: &mut [u8]) -> Result { + todo!() + } +} + +pub struct TokioTcp; +pub struct TokioStream(pub(crate) FromTokio); + +impl TryBufRead for TokioStream {} + +impl ErrorType for TokioStream { + type Error = as ErrorType>::Error; +} + +impl Read for TokioStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + self.0.read(buf).await + } +} + +impl Write for TokioStream { + async fn write(&mut self, buf: &[u8]) -> Result { + self.0.write(buf).await + } +} + +impl embedded_nal_async::TcpConnect for TokioTcp { + type Error = std::io::Error; + type Connection<'m> = TokioStream; + + async fn connect<'m>( + &'m self, + remote: embedded_nal_async::SocketAddr, + ) -> Result, Self::Error> { + let ip = match remote { + embedded_nal_async::SocketAddr::V4(a) => a.ip().octets().into(), + embedded_nal_async::SocketAddr::V6(a) => a.ip().octets().into(), + }; + let remote = SocketAddr::new(ip, remote.port()); + let stream = TcpStream::connect(remote).await?; + let stream = FromTokio::new(stream); + Ok(TokioStream(stream)) + } +} diff --git a/tests/request.rs b/tests/request.rs index d9a504e..7f7f081 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -9,6 +9,10 @@ use std::sync::Once; use tokio::net::TcpStream; use tokio::sync::oneshot; +mod connection; + +use connection::*; + static INIT: Once = Once::new(); fn setup() { @@ -36,7 +40,7 @@ async fn test_request_response() { }); let stream = TcpStream::connect(addr).await.unwrap(); - let mut stream = HttpConnection::Plain(FromTokio::new(stream)); + let mut stream = HttpConnection::Plain(TokioStream(FromTokio::new(stream))); let request = Request::post("/") .body(b"PING".as_slice()) From 45d9a4728e37f465f775fc4f042e276c5f636c2e Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Thu, 9 Nov 2023 21:52:04 +0100 Subject: [PATCH 04/11] unit test reproducing google panic Add failing test case Introduce TryBufRead, clean up --- src/response/mod.rs | 14 ++++++++++++++ tests/request.rs | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/response/mod.rs b/src/response/mod.rs index 6008d43..80c12be 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -591,6 +591,20 @@ mod tests { assert!(conn.is_exhausted()); } + #[tokio::test] + async fn can_read_to_end_with_chunked_encoding() { + let mut conn = FakeSingleReadConnection::new( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHELLO\r\n6\r\n WORLD\r\n0\r\n\r\n", + ); + let mut header_buf = [0; 200]; + let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); + + let body = response.body().read_to_end().await.unwrap(); + + assert_eq!(b"HELLO WORLD", body); + assert!(conn.is_exhausted()); + } + #[tokio::test] async fn can_read_to_end_of_connection_with_same_buffer() { let mut conn = FakeSingleReadConnection::new(b"HTTP/1.1 200 OK\r\n\r\nHELLO WORLD"); diff --git a/tests/request.rs b/tests/request.rs index 7f7f081..8eb102b 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -73,3 +73,24 @@ async fn write_without_base_path() { assert!(from_utf8(&buf).unwrap().starts_with("GET /hello HTTP/1.1")); } + +#[tokio::test] +async fn google_panic() { + use std::net::SocketAddr; + let google_ip = [142, 250, 74, 110]; + let addr = SocketAddr::from((google_ip, 80)); + + let conn = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut conn = TokioStream(FromTokio::new(conn)); + + let request = Request::get("/") + .host("www.google.com") + .content_type(ContentType::TextPlain) + .build(); + request.write(&mut conn).await.unwrap(); + + let mut rx_buf = [0; 8 * 1024]; + let resp = Response::read(&mut conn, Method::GET, &mut rx_buf).await.unwrap(); + let body = resp.body().read_to_end().await.unwrap(); + println!("{} -> {}", body.len(), core::str::from_utf8(&body).unwrap()); +} From 14fa5f2ce53f578f48b4e23ecaa5c92c3a74d7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Sat, 11 Nov 2023 13:14:52 +0100 Subject: [PATCH 05/11] Fix read_to_end with chunked encoding --- src/reader.rs | 10 ++++---- src/response/chunked.rs | 57 +++++++++++++++++++++++++++++++++++++++-- src/response/mod.rs | 6 ++++- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/reader.rs b/src/reader.rs index 098c784..f3e29c3 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -3,9 +3,9 @@ use embedded_io_async::{BufRead, Read}; use crate::TryBufRead; -struct ReadBuffer<'buf> { - buffer: &'buf mut [u8], - loaded: usize, +pub struct ReadBuffer<'buf> { + pub buffer: &'buf mut [u8], + pub loaded: usize, } impl<'buf> ReadBuffer<'buf> { @@ -46,8 +46,8 @@ pub struct BufferingReader<'resp, 'buf, B> where B: Read, { - buffer: ReadBuffer<'buf>, - stream: &'resp mut B, + pub buffer: ReadBuffer<'buf>, + pub stream: &'resp mut B, } impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B> diff --git a/src/response/chunked.rs b/src/response/chunked.rs index 085cfec..47f00b7 100644 --- a/src/response/chunked.rs +++ b/src/response/chunked.rs @@ -1,8 +1,11 @@ use embedded_io_async::{BufRead, Error as _, ErrorType, Read}; -use crate::Error; +use crate::{ + reader::{BufferingReader, ReadBuffer}, + Error, TryBufRead, +}; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] enum ChunkState { NoChunk, NotEmpty(u32), @@ -133,6 +136,56 @@ where } } +impl<'conn, 'buf, C> ChunkedBodyReader> +where + C: Read + TryBufRead, +{ + pub(crate) async fn read_to_end(self) -> Result<&'buf mut [u8], Error> { + let buffer = self.raw_body.buffer.buffer; + + // We reconstruct the reader to change the 'buf lifetime. + let mut reader = ChunkedBodyReader { + raw_body: BufferingReader { + buffer: ReadBuffer { + buffer: &mut buffer[..], + loaded: self.raw_body.buffer.loaded, + }, + stream: self.raw_body.stream, + }, + chunk_remaining: self.chunk_remaining, + }; + + let mut len = 0; + while len < reader.raw_body.buffer.buffer.len() { + // Read some + let read = reader.fill_buf().await?.len(); + len += read; + + // Make sure we don't erase the newly read data + let was_loaded = reader.raw_body.buffer.loaded; + let fake_loaded = read.min(was_loaded); + reader.raw_body.buffer.loaded = fake_loaded; + + // Consume the returned buffer + reader.consume(read); + + if reader.is_done() { + // If we're done, we don't care about the rest of the housekeeping. + break; + } + + // How many bytes were actually consumed from the preloaded buffer? + let consumed_from_buffer = fake_loaded - reader.raw_body.buffer.loaded; + + // ... move the buffer by that many bytes to avoid overwriting in the next iteration. + reader.raw_body.buffer.loaded = was_loaded - consumed_from_buffer; + reader.raw_body.buffer.buffer = &mut reader.raw_body.buffer.buffer[consumed_from_buffer..]; + } + + Ok(&mut buffer[..len]) + } +} + impl ErrorType for ChunkedBodyReader { type Error = Error; } diff --git a/src/response/mod.rs b/src/response/mod.rs index 80c12be..52c9d26 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -253,7 +253,10 @@ where Ok(&mut self.body_buf[..content_length]) } - ReaderHint::Chunked => Err(Error::Codec), + ReaderHint::Chunked => { + let raw_body = BufferingReader::new(self.body_buf, self.raw_body_read, self.conn); + ChunkedBodyReader::new(raw_body).read_to_end().await + } ReaderHint::ToEnd => { let mut body_len = self.raw_body_read; loop { @@ -596,6 +599,7 @@ mod tests { let mut conn = FakeSingleReadConnection::new( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHELLO\r\n6\r\n WORLD\r\n0\r\n\r\n", ); + conn.read_length = 10; let mut header_buf = [0; 200]; let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); From 841af88edfa0c60cd0eed6b27588e593ca6e03a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Mon, 20 Nov 2023 16:03:03 +0100 Subject: [PATCH 06/11] Remove outdated documentation --- src/response/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/response/mod.rs b/src/response/mod.rs index 52c9d26..573d181 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -235,11 +235,6 @@ where { /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. - /// - /// This is not valid for chunked responses as it requires that the body bytes over-read - /// while parsing the http response header would be available for the body reader. - /// For this case, or if the original buffer is not large enough, use - /// [`BodyReader::read_to_end()`] instead from the reader returned by [`ResponseBody::reader()`]. pub async fn read_to_end(self) -> Result<&'buf mut [u8], Error> { // We can only read responses with Content-Length header to end using the body_buf buffer, // as any other response would require the body reader to know the entire body. From 3f68168f385c0be80ddc30626fbb5f5e8581984e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Mon, 20 Nov 2023 16:04:59 +0100 Subject: [PATCH 07/11] Hide ReadBuffer --- src/reader.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/reader.rs b/src/reader.rs index f3e29c3..bb22d13 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -3,7 +3,7 @@ use embedded_io_async::{BufRead, Read}; use crate::TryBufRead; -pub struct ReadBuffer<'buf> { +pub(crate) struct ReadBuffer<'buf> { pub buffer: &'buf mut [u8], pub loaded: usize, } @@ -46,8 +46,8 @@ pub struct BufferingReader<'resp, 'buf, B> where B: Read, { - pub buffer: ReadBuffer<'buf>, - pub stream: &'resp mut B, + pub(crate) buffer: ReadBuffer<'buf>, + pub(crate) stream: &'resp mut B, } impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B> From b982905a8a22c1ee1afa8ebe36f5ddb3b487216e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Tue, 21 Nov 2023 09:29:06 +0100 Subject: [PATCH 08/11] Fix using the entire available buffer --- src/response/chunked.rs | 2 +- src/response/mod.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/response/chunked.rs b/src/response/chunked.rs index 47f00b7..e786319 100644 --- a/src/response/chunked.rs +++ b/src/response/chunked.rs @@ -156,7 +156,7 @@ where }; let mut len = 0; - while len < reader.raw_body.buffer.buffer.len() { + while !reader.raw_body.buffer.buffer.is_empty() { // Read some let read = reader.fill_buf().await?.len(); len += read; diff --git a/src/response/mod.rs b/src/response/mod.rs index 573d181..90f24db 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -604,6 +604,21 @@ mod tests { assert!(conn.is_exhausted()); } + #[tokio::test] + async fn can_read_to_end_into_a_small_buffer() { + let mut conn = FakeSingleReadConnection::new( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nHELLO\r\n6\r\n WORLD\r\n1\r\n \r\n5\r\nHELLO\r\n6\r\n WORLD\r\n1\r\n \r\n5\r\nHELLO\r\n6\r\n WORLD\r\n0\r\n\r\n", + ); + conn.read_length = 10; + let mut header_buf = [0; 50]; // buffer is long enough to hold the complete response + let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); + + let body = response.body().read_to_end().await.unwrap(); + + assert_eq!(b"HELLO WORLD HELLO WORLD HELLO WORLD", body); + assert!(conn.is_exhausted()); + } + #[tokio::test] async fn can_read_to_end_of_connection_with_same_buffer() { let mut conn = FakeSingleReadConnection::new(b"HTTP/1.1 200 OK\r\n\r\nHELLO WORLD"); From 4f9853c50cdd30e1d60982bbfceb7b06f4611d4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Wed, 22 Nov 2023 08:33:11 +0100 Subject: [PATCH 09/11] Return BufferTooSmall when reading to end --- src/response/chunked.rs | 4 ++++ src/response/mod.rs | 24 +++++++++++++++++++++++- tests/request.rs | 12 ++++++++++-- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/response/chunked.rs b/src/response/chunked.rs index e786319..29e3594 100644 --- a/src/response/chunked.rs +++ b/src/response/chunked.rs @@ -182,6 +182,10 @@ where reader.raw_body.buffer.buffer = &mut reader.raw_body.buffer.buffer[consumed_from_buffer..]; } + if !reader.is_done() { + return Err(Error::BufferTooSmall); + } + Ok(&mut buffer[..len]) } } diff --git a/src/response/mod.rs b/src/response/mod.rs index 90f24db..cfa1535 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -242,10 +242,16 @@ where ReaderHint::Empty => Ok(&mut []), ReaderHint::FixedLength(content_length) => { // Read into the buffer after the portion that was already received when parsing the header + let to_read = self.body_buf.len().min(content_length); self.conn - .read_exact(&mut self.body_buf[self.raw_body_read..content_length]) + .read_exact(&mut self.body_buf[self.raw_body_read..to_read]) .await?; + if content_length > self.body_buf.len() { + warn!("FixedLength: {} bytes remained", content_length - self.body_buf.len()); + return Err(Error::BufferTooSmall); + } + Ok(&mut self.body_buf[..content_length]) } ReaderHint::Chunked => { @@ -509,6 +515,22 @@ mod tests { assert!(conn.is_exhausted()); } + #[tokio::test] + async fn read_to_end_with_content_length_with_small_buffer() { + let mut conn = FakeSingleReadConnection::new( + b"HTTP/1.1 200 OK\r\nContent-Length: 52\r\n\r\nHELLO WORLD this is some longer response for testing", + ); + let mut header_buf = [0; 40]; + let response = Response::read(&mut conn, Method::GET, &mut header_buf).await.unwrap(); + + let body = response.body().read_to_end().await.expect_err("Failure expected"); + + match body { + Error::BufferTooSmall => {} + e => panic!("Unexpected error: {e:?}"), + } + } + #[tokio::test] async fn can_discard_with_content_length() { let mut conn = FakeSingleReadConnection::new(b"HTTP/1.1 200 OK\r\nContent-Length: 11\r\n\r\nHELLO WORLD"); diff --git a/tests/request.rs b/tests/request.rs index 8eb102b..111ccc5 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -3,6 +3,7 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Server}; use reqwless::client::HttpConnection; use reqwless::request::{Method, RequestBuilder}; +use reqwless::Error; use reqwless::{headers::ContentType, request::Request, response::Response}; use std::str::from_utf8; use std::sync::Once; @@ -91,6 +92,13 @@ async fn google_panic() { let mut rx_buf = [0; 8 * 1024]; let resp = Response::read(&mut conn, Method::GET, &mut rx_buf).await.unwrap(); - let body = resp.body().read_to_end().await.unwrap(); - println!("{} -> {}", body.len(), core::str::from_utf8(&body).unwrap()); + let result = resp.body().read_to_end().await; + + match result { + Ok(body) => { + println!("{} -> {}", body.len(), core::str::from_utf8(&body).unwrap()); + } + Err(Error::BufferTooSmall) => println!("Buffer too small"), + Err(e) => panic!("Unexpected error: {e:?}"), + } } From bf5f221b992c669c216a33499362d8b22b5ac09e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Wed, 22 Nov 2023 08:40:06 +0100 Subject: [PATCH 10/11] Reuse body readers --- src/response/mod.rs | 41 +++++++++++++---------------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/src/response/mod.rs b/src/response/mod.rs index cfa1535..d3a4ed3 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -236,43 +236,28 @@ where /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. pub async fn read_to_end(self) -> Result<&'buf mut [u8], Error> { - // We can only read responses with Content-Length header to end using the body_buf buffer, - // as any other response would require the body reader to know the entire body. match self.reader_hint { ReaderHint::Empty => Ok(&mut []), ReaderHint::FixedLength(content_length) => { - // Read into the buffer after the portion that was already received when parsing the header - let to_read = self.body_buf.len().min(content_length); - self.conn - .read_exact(&mut self.body_buf[self.raw_body_read..to_read]) - .await?; - - if content_length > self.body_buf.len() { - warn!("FixedLength: {} bytes remained", content_length - self.body_buf.len()); - return Err(Error::BufferTooSmall); - } - - Ok(&mut self.body_buf[..content_length]) + let read = BodyReader::FixedLength(FixedLengthBodyReader { + raw_body: self.conn, + remaining: content_length - self.raw_body_read, + }) + .read_to_end(&mut self.body_buf[self.raw_body_read..]) + .await?; + + Ok(&mut self.body_buf[..read + self.raw_body_read]) } ReaderHint::Chunked => { let raw_body = BufferingReader::new(self.body_buf, self.raw_body_read, self.conn); ChunkedBodyReader::new(raw_body).read_to_end().await } ReaderHint::ToEnd => { - let mut body_len = self.raw_body_read; - loop { - let len = self - .conn - .read(&mut self.body_buf[body_len..]) - .await - .map_err(|e| e.kind())?; - if len == 0 { - break; - } - body_len += len; - } - - Ok(&mut self.body_buf[..body_len]) + let read = BodyReader::ToEnd(self.conn) + .read_to_end(&mut self.body_buf[self.raw_body_read..]) + .await?; + + Ok(&mut self.body_buf[..read + self.raw_body_read]) } } } From 9c2e257c85ba8e9b2ab37a3979529b6eb4a48b5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Wed, 22 Nov 2023 10:19:55 +0100 Subject: [PATCH 11/11] Detect BufferTooSmall error for ToEnd bodies --- src/response/mod.rs | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/response/mod.rs b/src/response/mod.rs index d3a4ed3..2661add 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -235,7 +235,7 @@ where { /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. - pub async fn read_to_end(self) -> Result<&'buf mut [u8], Error> { + pub async fn read_to_end(mut self) -> Result<&'buf mut [u8], Error> { match self.reader_hint { ReaderHint::Empty => Ok(&mut []), ReaderHint::FixedLength(content_length) => { @@ -253,7 +253,7 @@ where ChunkedBodyReader::new(raw_body).read_to_end().await } ReaderHint::ToEnd => { - let read = BodyReader::ToEnd(self.conn) + let read = BodyReader::ToEnd(&mut self.conn) .read_to_end(&mut self.body_buf[self.raw_body_read..]) .await?; @@ -287,7 +287,7 @@ where BodyReader::Empty => true, BodyReader::FixedLength(reader) => reader.remaining == 0, BodyReader::Chunked(reader) => reader.is_done(), - BodyReader::ToEnd(_) => true, + BodyReader::ToEnd(_) => false, } } @@ -302,14 +302,29 @@ where } } - if self.is_done() { - Ok(len) - } else { - if let BodyReader::FixedLength(reader) = self { - warn!("FixedLength: {} bytes remained", reader.remaining); + if !self.is_done() { + let more = match self { + BodyReader::FixedLength(reader) => { + warn!("FixedLength: {} bytes remained", reader.remaining); + true + } + BodyReader::ToEnd(reader) if len == buf.len() => { + warn!("ToEnd: Buffer full, waiting to see if there is unread data."); + + let mut b = [0]; + matches!(reader.read(&mut b).await, Ok(1)) + } + + BodyReader::ToEnd(_) => false, + _ => true, + }; + + if more { + return Err(Error::BufferTooSmall); } - Err(Error::BufferTooSmall) } + + Ok(len) } async fn discard(&mut self) -> Result {