Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow read_to_end with ChunkedEncoding #61

Merged
merged 11 commits into from
Nov 28, 2023
33 changes: 33 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,36 @@ impl From<nourl::Error> for Error {
Error::InvalidUrl(e)
}
}

/// Trait for types that may optionally implement [`embedded_io_async::BufRead`]
///
/// Both methods have default bodies, so an implementor that has no internal
/// buffer can use an empty `impl` block.
pub trait TryBufRead: embedded_io_async::Read {
/// Attempts a buffered fill.
///
/// The default implementation returns `None`, signalling that no internal
/// buffer is available. Implementors that do have one return `Some` with the
/// outcome of filling it.
async fn try_fill_buf(&mut self) -> Option<Result<&[u8], Self::Error>> {
None
}

/// Counterpart of [`TryBufRead::try_fill_buf`] for marking bytes as consumed.
///
/// The default implementation does nothing.
fn try_consume(&mut self, _amt: usize) {}
}

/// Buffered-read passthrough for [`crate::client::HttpConnection`].
///
/// Only the TLS variant (available with the `embedded-tls` feature) is
/// forwarded to; every other case reports that no internal buffer exists.
impl<C> TryBufRead for crate::client::HttpConnection<'_, C>
where
    C: embedded_io_async::Read + embedded_io_async::Write,
{
    /// Returns `Some` with the TLS session's `fill_buf` result (error mapped
    /// through `kind()`) when this is a TLS connection, `None` otherwise.
    async fn try_fill_buf(&mut self) -> Option<Result<&[u8], Self::Error>> {
        // embedded-tls has its own internal buffer, let's prefer that if we can
        #[cfg(feature = "embedded-tls")]
        {
            use embedded_io_async::{BufRead, Error};

            if let Self::Tls(tls) = self {
                let filled = tls.fill_buf().await;
                return Some(filled.map_err(|e| e.kind()));
            }
        }

        None
    }

    /// Forwards `amt` to the TLS session's `consume`; a no-op for any other
    /// connection or when the `embedded-tls` feature is disabled.
    fn try_consume(&mut self, amt: usize) {
        #[cfg(feature = "embedded-tls")]
        {
            use embedded_io_async::BufRead;

            if let Self::Tls(tls) = self {
                tls.consume(amt);
            }
        }
    }
}
35 changes: 17 additions & 18 deletions src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use embedded_io::{Error, ErrorKind, ErrorType};
use embedded_io_async::{BufRead, Read, Write};
use embedded_io_async::{BufRead, Read};

use crate::client::HttpConnection;
use crate::TryBufRead;

struct ReadBuffer<'buf> {
buffer: &'buf mut [u8],
loaded: usize,
pub(crate) struct ReadBuffer<'buf> {
pub buffer: &'buf mut [u8],
pub loaded: usize,
}

impl<'buf> ReadBuffer<'buf> {
Expand Down Expand Up @@ -46,8 +46,8 @@ pub struct BufferingReader<'resp, 'buf, B>
where
B: Read,
{
buffer: ReadBuffer<'buf>,
stream: &'resp mut B,
pub(crate) buffer: ReadBuffer<'buf>,
pub(crate) stream: &'resp mut B,
}

impl<'resp, 'buf, B> BufferingReader<'resp, 'buf, B>
Expand Down Expand Up @@ -83,20 +83,22 @@ where
}
}

impl<C> BufRead for BufferingReader<'_, '_, HttpConnection<'_, C>>
impl<C> BufRead for BufferingReader<'_, '_, C>
where
C: Read + Write,
C: TryBufRead,
{
async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> {
// We need to consume the loaded bytes before we read more.
if self.buffer.is_empty() {
// embedded-tls has its own internal buffer, let's prefer that if we can
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(ref mut tls) = self.stream {
return tls.fill_buf().await.map_err(|e| e.kind());
// The matches/if let dance is to fix lifetime of the borrowed inner connection.
if self.stream.try_fill_buf().await.is_some() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rmja This got handy here, too :)

if let Some(result) = self.stream.try_fill_buf().await {
return result.map_err(|e| e.kind());
}
unreachable!()
}

self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?;
self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await.map_err(|e| e.kind())?;
}

self.buffer.fill_buf()
Expand All @@ -109,10 +111,7 @@ where
let unconsumed = self.buffer.consume(amt);

if unconsumed > 0 {
#[cfg(feature = "embedded-tls")]
if let HttpConnection::Tls(tls) = &mut self.stream {
tls.consume(unconsumed);
}
self.stream.try_consume(unconsumed);
}
}
}
235 changes: 235 additions & 0 deletions src/response/chunked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
use embedded_io_async::{BufRead, Error as _, ErrorType, Read};

use crate::{
reader::{BufferingReader, ReadBuffer},
Error, TryBufRead,
};

/// Position of a reader within a chunked body.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ChunkState {
    /// No chunk header has been parsed yet.
    NoChunk,
    /// Inside a chunk; the payload still has this many unread bytes.
    /// The count may drop to `0` once the chunk has been fully consumed.
    NotEmpty(u32),
    /// A chunk header announcing a length of zero has been parsed.
    Empty,
}

impl ChunkState {
    /// Marks up to `amt` bytes of the current chunk as consumed.
    ///
    /// Returns the number of bytes actually deducted, which is never more than
    /// what is left in the chunk. States other than [`ChunkState::NotEmpty`]
    /// carry no payload and always yield `0`.
    fn consume(&mut self, amt: usize) -> usize {
        if let ChunkState::NotEmpty(remaining) = self {
            // Clamp in a wider type: a plain `amt as u32` silently truncates
            // when `amt` exceeds `u32::MAX` and would under-report.
            let consumed = (amt as u64).min(u64::from(*remaining)) as u32;
            *remaining -= consumed;
            consumed as usize
        } else {
            0
        }
    }

    /// Number of payload bytes left in the current chunk, or `0` when not
    /// inside a chunk.
    fn len(self) -> usize {
        if let ChunkState::NotEmpty(len) = self {
            len as usize
        } else {
            0
        }
    }
}

/// Chunked response body reader
///
/// Wraps a raw body reader and tracks how far into the current chunk the
/// caller has read.
pub struct ChunkedBodyReader<B> {
/// Underlying reader that chunk headers, payload and terminators are read from.
pub raw_body: B,
/// Current position within the chunk framing.
chunk_remaining: ChunkState,
}

impl<C> ChunkedBodyReader<C>
where
    C: Read,
{
    /// Creates a reader over `raw_body`; no chunk header has been read yet.
    pub fn new(raw_body: C) -> Self {
        Self {
            raw_body,
            chunk_remaining: ChunkState::NoChunk,
        }
    }

    /// Returns `true` once a zero-length chunk header has been parsed.
    pub fn is_done(&self) -> bool {
        self.chunk_remaining == ChunkState::Empty
    }

    /// Reads one chunk-size line (`<hex digits>\r\n`) byte by byte and stores
    /// the parsed length in `chunk_remaining`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Codec`] when the line has no hex digits, has more than
    /// eight of them, is not terminated by `\r\n`, or fails hex decoding.
    /// Read failures from `raw_body` are propagated.
    async fn read_next_chunk_length(&mut self) -> Result<(), Error> {
        let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n
        let mut total_read = 0;

        loop {
            let mut byte = 0;
            self.raw_body
                .read_exact(core::slice::from_mut(&mut byte))
                .await
                .map_err(|e| Error::from(e).kind())?;

            if byte != b'\n' {
                header_buf[total_read] = byte;
                total_read += 1;

                // Filling the whole scratch buffer without seeing `\n` means
                // the size does not fit into 32 bits.
                if total_read == header_buf.len() {
                    return Err(Error::Codec);
                }
            } else {
                if total_read == 0 || header_buf[total_read - 1] != b'\r' {
                    return Err(Error::Codec);
                }
                break;
            }
        }

        // Everything before the trailing `\r` is the hex-encoded size.
        let hex_digits = total_read - 1;

        // A size line needs at least one digit. Without this check a bare
        // `\r\n` would decode as length 0 and be taken for the final chunk.
        if hex_digits == 0 {
            return Err(Error::Codec);
        }

        // Prepend hex with zeros
        let mut hex = [b'0'; 8];
        hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]);

        let mut bytes = [0; 4];
        hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?;

        let chunk_length = u32::from_be_bytes(bytes);

        debug!("Chunk length: {}", chunk_length);

        self.chunk_remaining = match chunk_length {
            0 => ChunkState::Empty,
            other => ChunkState::NotEmpty(other),
        };

        Ok(())
    }

    /// Reads the two bytes that follow a chunk and checks that they are `\r\n`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Codec`] for any other byte pair; read failures from
    /// `raw_body` are propagated.
    async fn read_chunk_end(&mut self) -> Result<(), Error> {
        // All chunks are terminated with a \r\n
        let mut newline_buf = [0; 2];
        self.raw_body.read_exact(&mut newline_buf).await?;

        if newline_buf != [b'\r', b'\n'] {
            return Err(Error::Codec);
        }
        Ok(())
    }

    /// Handles chunk boundary and returns the number of bytes in the current (or new) chunk.
    ///
    /// Returns `Ok(0)` once the zero-length chunk has been reached.
    async fn handle_chunk_boundary(&mut self) -> Result<usize, Error> {
        match self.chunk_remaining {
            ChunkState::NoChunk => self.read_next_chunk_length().await?,

            ChunkState::NotEmpty(0) => {
                // The current chunk has been fully consumed, advance into a new chunk...
                self.read_chunk_end().await?;
                self.read_next_chunk_length().await?;
            }

            ChunkState::NotEmpty(_) => {}

            ChunkState::Empty => return Ok(0),
        }

        if self.chunk_remaining == ChunkState::Empty {
            // Read final chunk termination
            self.read_chunk_end().await?;
        }

        Ok(self.chunk_remaining.len())
    }
}

impl<'conn, 'buf, C> ChunkedBodyReader<BufferingReader<'conn, 'buf, C>>
where
C: Read + TryBufRead,
{
/// Consumes the reader and drains the chunked body, returning a slice of the
/// caller-provided buffer.
///
/// On success the returned slice is the first `len` bytes of the original
/// `'buf` buffer, where `len` is the sum of the lengths returned by
/// `fill_buf` across all loop iterations.
///
/// # Errors
///
/// Returns [`Error::BufferTooSmall`] when the loop ends before `is_done()`
/// reports the terminating chunk. Errors from `fill_buf` are propagated.
pub(crate) async fn read_to_end(self) -> Result<&'buf mut [u8], Error> {
// Take the full-lifetime buffer out of the reader so it can be returned at the end.
let buffer = self.raw_body.buffer.buffer;

// We reconstruct the reader to change the 'buf lifetime.
let mut reader = ChunkedBodyReader {
raw_body: BufferingReader {
buffer: ReadBuffer {
buffer: &mut buffer[..],
loaded: self.raw_body.buffer.loaded,
},
stream: self.raw_body.stream,
},
chunk_remaining: self.chunk_remaining,
};

let mut len = 0;
// NOTE(review): `is_done()` only turns true after a zero-length chunk header
// has been parsed during `fill_buf`. If the working slice becomes empty first,
// the loop exits and `BufferTooSmall` is returned — confirm this is intended
// for bodies that exactly fill the buffer.
while !reader.raw_body.buffer.buffer.is_empty() {
// Read some
let read = reader.fill_buf().await?.len();
len += read;

// Make sure we don't erase the newly read data
// (temporarily cap `loaded` at `read` so `consume` below cannot go past
// the bytes just returned; presumably `ReadBuffer::consume` shifts or
// discards data based on `loaded` — verify against `ReadBuffer`).
let was_loaded = reader.raw_body.buffer.loaded;
let fake_loaded = read.min(was_loaded);
reader.raw_body.buffer.loaded = fake_loaded;

// Consume the returned buffer
reader.consume(read);

if reader.is_done() {
// If we're done, we don't care about the rest of the housekeeping.
break;
}

// How many bytes were actually consumed from the preloaded buffer?
let consumed_from_buffer = fake_loaded - reader.raw_body.buffer.loaded;

// ... move the buffer by that many bytes to avoid overwriting in the next iteration.
// `loaded` is restored to the real count minus what was consumed, and the
// working slice is re-sliced to start after the consumed bytes.
reader.raw_body.buffer.loaded = was_loaded - consumed_from_buffer;
reader.raw_body.buffer.buffer = &mut reader.raw_body.buffer.buffer[consumed_from_buffer..];
}

if !reader.is_done() {
return Err(Error::BufferTooSmall);
}

Ok(&mut buffer[..len])
}
}

/// The chunked reader reports failures using the crate-level [`Error`],
/// regardless of the wrapped reader type.
impl<C> ErrorType for ChunkedBodyReader<C> {
type Error = Error;
}

impl<C> Read for ChunkedBodyReader<C>
where
C: Read,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
let remaining = self.handle_chunk_boundary().await?;
let max_len = buf.len().min(remaining);

let len = self
.raw_body
.read(&mut buf[..max_len])
.await
.map_err(|e| Error::Network(e.kind()))?;

self.chunk_remaining.consume(len);

Ok(len)
}
}

impl<C> BufRead for ChunkedBodyReader<C>
where
C: BufRead + Read,
{
async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
let remaining = self.handle_chunk_boundary().await?;

let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?;

let len = buf.len().min(remaining);

Ok(&buf[..len])
}

fn consume(&mut self, amt: usize) {
let consumed = self.chunk_remaining.consume(amt);
self.raw_body.consume(consumed);
}
}
Loading