From 0e9d9ac8b056f03ccc8534671e38286022cd2ec2 Mon Sep 17 00:00:00 2001 From: Rasmus Melchior Jacobsen Date: Tue, 31 Oct 2023 10:39:45 +0100 Subject: [PATCH 1/3] Use BufferedRead when reading --- Cargo.toml | 2 +- src/reader.rs | 87 ++++++++++++--------------------------------------- 2 files changed, 21 insertions(+), 68 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 25c4d4e..9990714 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["embedded", "async", "http", "no_std"] exclude = [".github"] [dependencies] -buffered-io = { version = "0.4.0", features = ["async"] } +buffered-io = { version = "0.4.2", features = ["async"] } embedded-io = { version = "0.6" } embedded-io-async = { version = "0.6" } embedded-nal-async = "0.6.0" diff --git a/src/reader.rs b/src/reader.rs index be19719..f3a1b82 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,53 +1,14 @@ +use buffered_io::asynch::BufferedRead; use embedded_io::{Error, ErrorKind, ErrorType}; use embedded_io_async::{BufRead, Read, Write}; use crate::client::HttpConnection; -struct ReadBuffer<'buf> { - buffer: &'buf mut [u8], - loaded: usize, -} - -impl<'buf> ReadBuffer<'buf> { - fn new(buffer: &'buf mut [u8], loaded: usize) -> Self { - Self { buffer, loaded } - } -} - -impl ReadBuffer<'_> { - fn is_empty(&self) -> bool { - self.loaded == 0 - } - - fn read(&mut self, buf: &mut [u8]) -> Result { - let amt = self.loaded.min(buf.len()); - buf[..amt].copy_from_slice(&self.buffer[0..amt]); - - self.consume(amt); - - Ok(amt) - } - - fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> { - Ok(&self.buffer[..self.loaded]) - } - - fn consume(&mut self, amt: usize) -> usize { - let to_consume = amt.min(self.loaded); - - self.buffer.copy_within(to_consume..self.loaded, 0); - self.loaded -= to_consume; - - amt - to_consume - } -} - pub struct BufferingReader<'resp, 'buf, B> where B: Read, { - buffer: ReadBuffer<'buf>, - stream: &'resp mut B, + buffered: BufferedRead<'buf, &'resp mut B>, } impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B> @@ -56,8 +17,7 @@ where { pub fn new(buffer: &'buf mut [u8], loaded: usize, stream: &'resp mut B) -> Self { Self { - buffer: ReadBuffer::new(buffer, loaded), - stream, + buffered: BufferedRead::new_with_data(stream, buffer, 0, loaded), } } } @@ -74,12 +34,7 @@ where C: Read, { async fn read(&mut self, buf: &mut [u8]) -> Result { - if !self.buffer.is_empty() { - let amt = self.buffer.read(buf)?; - return Ok(amt); - } - - self.stream.read(buf).await.map_err(|e| e.kind()) + self.buffered.read(buf).await.map_err(|e| e.kind()) } } @@ -88,31 +43,29 @@ where C: Read + Write, { async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> { - // We need to consume the loaded bytes before we read mode. - if self.buffer.is_empty() { - // embedded-tls has its own internal buffer, let's prefer that if we can - #[cfg(feature = "embedded-tls")] - if let HttpConnection::Tls(ref mut tls) = self.stream { + // The call to buffered.bypass() will only return Ok(...) if the buffer is empty. + // This ensures that we completely drain the possibly pre-filled buffer before we try + // to use the embedded-tls buffer directly. + // The matches/if let dance is to fix lifetime of the borrowed inner connection. + #[cfg(feature = "embedded-tls")] + if matches!(self.buffered.bypass(), Ok(HttpConnection::Tls(_))) { + if let HttpConnection::Tls(ref mut tls) = self.buffered.bypass().unwrap() { return tls.fill_buf().await.map_err(|e| e.kind()); } - - self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?; + unreachable!(); } - self.buffer.fill_buf() + self.buffered.fill_buf().await } fn consume(&mut self, amt: usize) { - // It's possible that the user requested more bytes to be consumed than loaded. Especially - // since it's also possible that nothing is loaded, after we consumed all and are using - // embedded-tls's buffering. - let unconsumed = self.buffer.consume(amt); - - if unconsumed > 0 { - #[cfg(feature = "embedded-tls")] - if let HttpConnection::Tls(tls) = &mut self.stream { - tls.consume(unconsumed); - } + // The call to buffered.bypass() will only return Ok(...) if the buffer is empty. + #[cfg(feature = "embedded-tls")] + if let Ok(HttpConnection::Tls(tls)) = self.buffered.bypass() { + tls.consume(amt); + return; } + + self.buffered.consume(amt); } } From e1b6eae849b1e5f2ba41ff5f8d6605d7685494c9 Mon Sep 17 00:00:00 2001 From: Rasmus Melchior Jacobsen Date: Tue, 31 Oct 2023 11:24:40 +0100 Subject: [PATCH 2/3] Do not unwrap call to bypass() --- src/reader.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/reader.rs b/src/reader.rs index f3a1b82..6152ce7 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -49,7 +49,7 @@ where // The matches/if let dance is to fix lifetime of the borrowed inner connection. #[cfg(feature = "embedded-tls")] if matches!(self.buffered.bypass(), Ok(HttpConnection::Tls(_))) { - if let HttpConnection::Tls(ref mut tls) = self.buffered.bypass().unwrap() { + if let Ok(HttpConnection::Tls(ref mut tls)) = self.buffered.bypass() { return tls.fill_buf().await.map_err(|e| e.kind()); } unreachable!(); From 196a14075533371aa17173bfc9241b94c153b080 Mon Sep 17 00:00:00 2001 From: Rasmus Melchior Jacobsen Date: Tue, 31 Oct 2023 12:36:36 +0100 Subject: [PATCH 3/3] Add preload test --- src/response.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/response.rs b/src/response.rs index f56b3f8..c682950 100644 --- a/src/response.rs +++ b/src/response.rs @@ -807,6 +807,25 @@ mod tests { assert_eq!(11, response.body().discard().await.unwrap()); } + #[tokio::test] + async fn chunked_body_reader_can_read_preloaded() { + let mut read_buffer: Vec = Vec::new(); + read_buffer.extend_from_slice(b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n"); + let preloaded = read_buffer.len(); + let mut empty_body = [0; 0].as_slice(); + let mut reader = ChunkedBodyReader { + raw_body: BufferingReader::new(&mut read_buffer, preloaded, &mut empty_body), + chunk_remaining: ChunkState::NoChunk, + }; + + let mut body = [0; 17]; + reader.read_exact(&mut body).await.unwrap(); + + assert_eq!(0, reader.read(&mut body).await.unwrap()); + assert_eq!(0, reader.read(&mut body).await.unwrap()); + assert_eq!(b"XYYYYYYYYYYYYYYYY", &body); + } + #[tokio::test] async fn chunked_body_reader_can_read_with_large_buffer() { let mut raw_body = b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_slice();