Skip to content

Commit

Permalink
Add BufferedChunkedBodyWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
rmja committed May 17, 2024
1 parent bb13c91 commit 20b0514
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 40 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ keywords = ["embedded", "async", "http", "no_std"]
exclude = [".github"]

[dependencies]
buffered-io = { version = "0.5.1" }
buffered-io = { version = "0.5.2" }
embedded-io = { version = "0.6" }
embedded-io-async = { version = "0.6" }
embedded-nal-async = "0.7.0"
Expand All @@ -27,7 +27,9 @@ defmt = { version = "0.3", optional = true }
embedded-tls = { version = "0.17", default-features = false, optional = true }
rand_chacha = { version = "0.3", default-features = false }
nourl = "0.1.1"
esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = ["async"], optional = true }
esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = [
"async",
], optional = true }

[dev-dependencies]
hyper = { version = "0.14.23", features = ["full"] }
Expand Down
24 changes: 6 additions & 18 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,31 +311,19 @@ where
HttpConnection::Plain(c) => {
let mut writer = ChunkedBodyWriter::new(c);
body.write(&mut writer).await?;
writer.write_empty_chunk().await.map_err(|e| e.kind())?;
writer.terminate().await.map_err(|e| e.kind())?;
}
HttpConnection::PlainBuffered(buffered_conn) => {
// Flush the buffered connection so that we can bypass it and rent its buffer
buffered_conn.flush().await.map_err(|e| e.kind())?;
let (conn, buf) = buffered_conn.bypass_with_buf().unwrap();

// Construct a new buffered writer that buffers _before_ the chunked body writer
let mut writer = BufferedWrite::new(ChunkedBodyWriter::new(conn), buf);
HttpConnection::PlainBuffered(buffered) => {
let (conn, buf, unwritten) = buffered.split();
let mut writer = BufferedChunkedBodyWriter::new_with_data(conn, buf, unwritten);
body.write(&mut writer).await?;

// Flush the buffered writer and write the empty chunk to the chunked body writer
writer.flush().await.map_err(|e| e.kind())?;
writer
.bypass()
.unwrap()
.write_empty_chunk()
.await
.map_err(|e| e.kind())?;
writer.terminate().await.map_err(|e| e.kind())?;
}
#[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))]
HttpConnection::Tls(c) => {
let mut writer = ChunkedBodyWriter::new(c);
body.write(&mut writer).await?;
writer.write_empty_chunk().await.map_err(|e| e.kind())?;
writer.terminate().await.map_err(|e| e.kind())?;
}
#[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))]
HttpConnection::Tls(_) => unreachable!(),
Expand Down
176 changes: 156 additions & 20 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl Method {
}

async fn write_str<C: Write>(c: &mut C, data: &str) -> Result<(), Error> {
c.write_all(data.as_bytes()).await.map_err(to_errorkind)?;
c.write_all(data.as_bytes()).await.map_err(|e| e.kind())?;
Ok(())
}

Expand Down Expand Up @@ -355,17 +355,35 @@ where
}
}

pub struct ChunkedBodyWriter<C: Write>(C, usize);
const fn hex_chars(number: usize) -> u32 {
if number == 0 {
1
} else {
(usize::BITS - number.leading_zeros()).div_ceil(4)
}
}

fn write_chunked_header(buf: &mut [u8], chunk_len: usize) -> usize {
let mut hex = [0; 2 * size_of::<usize>()];
hex::encode_to_slice(chunk_len.to_be_bytes(), &mut hex).unwrap();
let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default();
let hex_chars = hex.len() - leading_zeros;
buf[..hex_chars].copy_from_slice(&hex[leading_zeros..]);
buf[hex_chars..hex_chars + 2].copy_from_slice(b"\r\n");
hex_chars + 2
}

pub struct ChunkedBodyWriter<C: Write>(C);

impl<C> ChunkedBodyWriter<C>
where
C: Write,
{
pub fn new(conn: C) -> Self {
Self(conn, 0)
Self(conn)
}

pub async fn write_empty_chunk(&mut self) -> Result<(), C::Error> {
pub async fn terminate(&mut self) -> Result<(), C::Error> {
self.0.write_all(b"0\r\n\r\n").await
}
}
Expand All @@ -377,21 +395,16 @@ where
type Error = embedded_io::ErrorKind;
}

fn to_errorkind<E: embedded_io::Error>(e: E) -> embedded_io::ErrorKind {
e.kind()
}

impl<C> Write for ChunkedBodyWriter<C>
where
C: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
self.write_all(buf).await.map_err(to_errorkind)?;
self.write_all(buf).await.map_err(|e| e.kind())?;
Ok(buf.len())
}

async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
// Write chunk header
let len = buf.len();

// Do not write an empty chunk as that will terminate the body
Expand All @@ -400,19 +413,19 @@ where
return Ok(());
}

let mut hex = [0; 2 * size_of::<usize>()];
hex::encode_to_slice(len.to_be_bytes(), &mut hex).unwrap();
let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default();
let (_, hex) = hex.split_at(leading_zeros);
self.0.write_all(hex).await.map_err(to_errorkind)?;
self.0.write_all(b"\r\n").await.map_err(to_errorkind)?;
// Write chunk header
let mut header_buf = [0; 2 * size_of::<usize>() + 2];
let header_len = write_chunked_header(&mut header_buf, len);
self.0
.write_all(&header_buf[..header_len])
.await
.map_err(|e| e.kind())?;

// Write chunk
self.0.write_all(buf).await.map_err(to_errorkind)?;
self.1 += len;
self.0.write_all(buf).await.map_err(|e| e.kind())?;

// Write newline
self.0.write_all(b"\r\n").await.map_err(to_errorkind)?;
// Write newline footer
self.0.write_all(b"\r\n").await.map_err(|e| e.kind())?;
Ok(())
}

Expand All @@ -421,10 +434,133 @@ where
}
}

pub struct BufferedChunkedBodyWriter<'a, C: Write> {
conn: C,
buf: &'a mut [u8],
header_pos: usize,
pos: usize,
max_header_size: usize,
max_footer_size: usize,
}

impl<'a, C> BufferedChunkedBodyWriter<'a, C>
where
C: Write,
{
pub fn new_with_data(conn: C, buf: &'a mut [u8], written: usize) -> Self {
let max_hex_chars = hex_chars(buf.len());
let max_header_size = max_hex_chars as usize + 2;
let max_footer_size = 2;
assert!(buf.len() > max_header_size + max_footer_size); // There must be space for the chunk header and footer
Self {
conn,
buf,
header_pos: written,
pos: written + max_header_size,
max_header_size,
max_footer_size,
}
}

pub async fn terminate(&mut self) -> Result<(), C::Error> {
if self.pos > self.header_pos + self.max_header_size {
self.finish_current_chunk();
}
const EMPTY: &[u8; 5] = b"0\r\n\r\n";
if self.header_pos + EMPTY.len() > self.buf.len() {
self.emit_finished_chunk().await?;
}

self.buf[self.header_pos..self.header_pos + EMPTY.len()].copy_from_slice(EMPTY);
self.header_pos += EMPTY.len();
self.pos = self.header_pos + self.max_header_size;
self.emit_finished_chunk().await
}

fn append_current_chunk(&mut self, buf: &[u8]) -> usize {
let buffered = usize::min(buf.len(), self.buf.len() - self.max_footer_size - self.pos);
if buffered > 0 {
self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]);
self.pos += buffered;
}
buffered
}

fn finish_current_chunk(&mut self) {
// Write the header in the allocated position position
let chunk_len = self.pos - self.header_pos - self.max_header_size;
let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.max_header_size];
let header_len = write_chunked_header(header_buf, chunk_len);

// Move the payload if the header length was not as large as it could possibly be
let spacing = self.max_header_size - header_len;
if spacing > 0 {
self.buf.copy_within(
self.header_pos + self.max_header_size..self.pos,
self.header_pos + header_len,
);
self.pos -= spacing
}

// Write newline footer after chunk payload
self.buf[self.pos..self.pos + 2].copy_from_slice(b"\r\n");
self.pos += 2;

self.header_pos = self.pos;
self.pos = self.header_pos + self.max_header_size;
}

async fn emit_finished_chunk(&mut self) -> Result<(), C::Error> {
self.conn.write_all(&self.buf[..self.header_pos]).await?;
self.header_pos = 0;
self.pos = self.max_header_size;
Ok(())
}
}

impl<C> ErrorType for BufferedChunkedBodyWriter<'_, C>
where
C: Write,
{
type Error = embedded_io::ErrorKind;
}

impl<C> Write for BufferedChunkedBodyWriter<'_, C>
where
C: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
let written = self.append_current_chunk(buf);
if written < buf.len() {
self.finish_current_chunk();
self.emit_finished_chunk().await.map_err(|e| e.kind())?;
}
Ok(written)
}

async fn flush(&mut self) -> Result<(), Self::Error> {
if self.header_pos > 0 {
self.finish_current_chunk();
self.emit_finished_chunk().await.map_err(|e| e.kind())?;
}
self.conn.flush().await.map_err(|e| e.kind())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn hex_chars_values() {
assert_eq!(1, hex_chars(0));
assert_eq!(1, hex_chars(1));
assert_eq!(1, hex_chars(0xF));
assert_eq!(2, hex_chars(0x10));
assert_eq!(2, hex_chars(0xFF));
assert_eq!(3, hex_chars(0x100));
}

#[tokio::test]
async fn basic_auth() {
let mut buffer: Vec<u8> = Vec::new();
Expand Down

0 comments on commit 20b0514

Please sign in to comment.