Skip to content

Commit

Permalink
Use BufferedRead when reading
Browse files Browse the repository at this point in the history
  • Loading branch information
rmja committed Oct 31, 2023
1 parent ddabc48 commit 0e9d9ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ keywords = ["embedded", "async", "http", "no_std"]
exclude = [".github"]

[dependencies]
buffered-io = { version = "0.4.0", features = ["async"] }
buffered-io = { version = "0.4.2", features = ["async"] }
embedded-io = { version = "0.6" }
embedded-io-async = { version = "0.6" }
embedded-nal-async = "0.6.0"
Expand Down
87 changes: 20 additions & 67 deletions src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,14 @@
use buffered_io::asynch::BufferedRead;
use embedded_io::{Error, ErrorKind, ErrorType};
use embedded_io_async::{BufRead, Read, Write};

use crate::client::HttpConnection;

struct ReadBuffer<'buf> {
buffer: &'buf mut [u8],
loaded: usize,
}

impl<'buf> ReadBuffer<'buf> {
fn new(buffer: &'buf mut [u8], loaded: usize) -> Self {
Self { buffer, loaded }
}
}

impl ReadBuffer<'_> {
fn is_empty(&self) -> bool {
self.loaded == 0
}

fn read(&mut self, buf: &mut [u8]) -> Result<usize, ErrorKind> {
let amt = self.loaded.min(buf.len());
buf[..amt].copy_from_slice(&self.buffer[0..amt]);

self.consume(amt);

Ok(amt)
}

fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> {
Ok(&self.buffer[..self.loaded])
}

fn consume(&mut self, amt: usize) -> usize {
let to_consume = amt.min(self.loaded);

self.buffer.copy_within(to_consume..self.loaded, 0);
self.loaded -= to_consume;

amt - to_consume
}
}

pub struct BufferingReader<'resp, 'buf, B>
where
B: Read,
{
buffer: ReadBuffer<'buf>,
stream: &'resp mut B,
buffered: BufferedRead<'buf, &'resp mut B>,
}

impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B>
Expand All @@ -56,8 +17,7 @@ where
{
pub fn new(buffer: &'buf mut [u8], loaded: usize, stream: &'resp mut B) -> Self {
Self {
buffer: ReadBuffer::new(buffer, loaded),
stream,
buffered: BufferedRead::new_with_data(stream, buffer, 0, loaded),
}
}
}
Expand All @@ -74,12 +34,7 @@ where
C: Read,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if !self.buffer.is_empty() {
let amt = self.buffer.read(buf)?;
return Ok(amt);
}

self.stream.read(buf).await.map_err(|e| e.kind())
self.buffered.read(buf).await.map_err(|e| e.kind())
}
}

Expand All @@ -88,31 +43,29 @@ where
C: Read + Write,
{
async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> {
// We need to consume the loaded bytes before we read mode.
if self.buffer.is_empty() {
// embedded-tls has its own internal buffer, let's prefer that if we can
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(ref mut tls) = self.stream {
// The call to buffered.bypass() will only return Ok(...) if the buffer is empty.
// This ensures that we completely drain the possibly pre-filled buffer before we try
// to use the embedded-tls buffer directly.
// The matches/if let dance is to fix lifetime of the borrowed inner connection.
#[cfg(feature = "embedded-tls")]
if matches!(self.buffered.bypass(), Ok(HttpConnection::Tls(_))) {
if let HttpConnection::Tls(ref mut tls) = self.buffered.bypass().unwrap() {
return tls.fill_buf().await.map_err(|e| e.kind());
}

self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?;
unreachable!();
}

self.buffer.fill_buf()
self.buffered.fill_buf().await
}

fn consume(&mut self, amt: usize) {
// It's possible that the user requested more bytes to be consumed than loaded. Especially
// since it's also possible that nothing is loaded, after we consumed all and are using
// embedded-tls's buffering.
let unconsumed = self.buffer.consume(amt);

if unconsumed > 0 {
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(tls) = &mut self.stream {
tls.consume(unconsumed);
}
// The call to buffered.bypass() will only return Ok(...) if the buffer is empty.
#[cfg(feature = "embedded-tls")]
if let Ok(HttpConnection::Tls(tls)) = self.buffered.bypass() {
tls.consume(amt);
return;
}

self.buffered.consume(amt);
}
}

0 comments on commit 0e9d9ac

Please sign in to comment.