Skip to content

Commit

Permalink
Add tests and optimize allocated header to match the remaining size o…
Browse files Browse the repository at this point in the history
…f the buffer
  • Loading branch information
rmja committed May 21, 2024
1 parent befcc10 commit d8d416a
Showing 1 changed file with 131 additions and 36 deletions.
167 changes: 131 additions & 36 deletions src/body_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use core::mem::size_of;
use embedded_io::{Error as _, ErrorType};
use embedded_io_async::Write;

const NEWLINE: &[u8; 2] = b"\r\n";
const EMPTY_CHUNK: &[u8; 5] = b"0\r\n\r\n";

pub struct FixedBodyWriter<C: Write>(C, usize);

impl<C> FixedBodyWriter<C>
Expand Down Expand Up @@ -56,8 +59,9 @@ where
Self(conn)
}

/// Terminate the request body by writing an empty chunk
pub async fn terminate(&mut self) -> Result<(), C::Error> {
self.0.write_all(b"0\r\n\r\n").await
self.0.write_all(EMPTY_CHUNK).await
}
}

Expand Down Expand Up @@ -98,7 +102,7 @@ where
self.0.write_all(buf).await.map_err(|e| e.kind())?;

// Write newline footer
self.0.write_all(b"\r\n").await.map_err(|e| e.kind())?;
self.0.write_all(NEWLINE).await.map_err(|e| e.kind())?;
Ok(())
}

Expand All @@ -110,83 +114,93 @@ where
pub struct BufferingChunkedBodyWriter<'a, C: Write> {
conn: C,
buf: &'a mut [u8],
/// The position where the allocated chunk header starts
header_pos: usize,
/// The size of the allocated header (the final header may be smaller)
allocated_header: usize,
/// The position of the data in the chunk
pos: usize,
max_header_size: usize,
max_footer_size: usize,
}

impl<'a, C> BufferingChunkedBodyWriter<'a, C>
where
C: Write,
{
pub fn new_with_data(conn: C, buf: &'a mut [u8], written: usize) -> Self {
let max_hex_chars = hex_chars(buf.len());
let max_header_size = max_hex_chars as usize + 2;
let max_footer_size = 2;
assert!(buf.len() > max_header_size + max_footer_size); // There must be space for the chunk header and footer
assert!(written <= buf.len());
let allocated_header = get_max_chunk_header_size(buf.len() - written);
assert!(buf.len() > allocated_header + NEWLINE.len()); // There must be space for the chunk header and footer
Self {
conn,
buf,
header_pos: written,
pos: written + max_header_size,
max_header_size,
max_footer_size,
pos: written + allocated_header,
allocated_header,
}
}

/// Terminate the request body by writing an empty chunk
pub async fn terminate(&mut self) -> Result<(), C::Error> {
if self.pos > self.header_pos + self.max_header_size {
assert!(self.allocated_header > 0);

if self.pos > self.header_pos + self.allocated_header {
// There are bytes written in the current chunk
self.finish_current_chunk();
}
const EMPTY: &[u8; 5] = b"0\r\n\r\n";
if self.header_pos + EMPTY.len() > self.buf.len() {
self.emit_finished_chunk().await?;

if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() {
// There is not enough space to fit the empty chunk in the buffer
self.emit_finished_chunk().await?;
}
}

self.buf[self.header_pos..self.header_pos + EMPTY.len()].copy_from_slice(EMPTY);
self.header_pos += EMPTY.len();
self.pos = self.header_pos + self.max_header_size;
self.buf[self.header_pos..self.header_pos + EMPTY_CHUNK.len()].copy_from_slice(EMPTY_CHUNK);
self.header_pos += EMPTY_CHUNK.len();
self.allocated_header = 0;
self.pos = self.header_pos + self.allocated_header;
self.emit_finished_chunk().await
}

/// Append to the buffer
fn append_current_chunk(&mut self, buf: &[u8]) -> usize {
let buffered = usize::min(buf.len(), self.buf.len() - self.max_footer_size - self.pos);
let buffered = usize::min(buf.len(), self.buf.len() - NEWLINE.len() - self.pos);
if buffered > 0 {
self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]);
self.pos += buffered;
}
buffered
}

/// Finish the current chunk by writing the header
fn finish_current_chunk(&mut self) {
// Write the header in the allocated position position
let chunk_len = self.pos - self.header_pos - self.max_header_size;
let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.max_header_size];
let chunk_len = self.pos - self.header_pos - self.allocated_header;
let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.allocated_header];
let header_len = write_chunked_header(header_buf, chunk_len);

// Move the payload if the header length was not as large as it could possibly be
let spacing = self.max_header_size - header_len;
let spacing = self.allocated_header - header_len;
if spacing > 0 {
self.buf.copy_within(
self.header_pos + self.max_header_size..self.pos,
self.header_pos + self.allocated_header..self.pos,
self.header_pos + header_len,
);
self.pos -= spacing
}

// Write newline footer after chunk payload
self.buf[self.pos..self.pos + 2].copy_from_slice(b"\r\n");
self.buf[self.pos..self.pos + NEWLINE.len()].copy_from_slice(NEWLINE);
self.pos += 2;

self.header_pos = self.pos;
self.pos = self.header_pos + self.max_header_size;
self.allocated_header = get_max_chunk_header_size(self.buf.len() - self.header_pos);
self.pos = self.header_pos + self.allocated_header;
}

async fn emit_finished_chunk(&mut self) -> Result<(), C::Error> {
self.conn.write_all(&self.buf[..self.header_pos]).await?;
self.header_pos = 0;
self.pos = self.max_header_size;
self.allocated_header = get_max_chunk_header_size(self.buf.len());
self.pos = self.allocated_header;
Ok(())
}
}
Expand Down Expand Up @@ -220,21 +234,30 @@ where
}
}

const fn hex_chars(number: usize) -> u32 {
const fn get_hex_chars(number: usize) -> u32 {
if number == 0 {
1
} else {
(usize::BITS - number.leading_zeros()).div_ceil(4)
}
}

const fn get_max_chunk_header_size(buffer_size: usize) -> usize {
if buffer_size >= NEWLINE.len() + NEWLINE.len() {
get_hex_chars(buffer_size - NEWLINE.len() - NEWLINE.len()) as usize + NEWLINE.len()
} else {
// Not enough space in buffer to fit a header + footer
0
}
}

fn write_chunked_header(buf: &mut [u8], chunk_len: usize) -> usize {
let mut hex = [0; 2 * size_of::<usize>()];
hex::encode_to_slice(chunk_len.to_be_bytes(), &mut hex).unwrap();
let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default();
let hex_chars = hex.len() - leading_zeros;
buf[..hex_chars].copy_from_slice(&hex[leading_zeros..]);
buf[hex_chars..hex_chars + 2].copy_from_slice(b"\r\n");
buf[hex_chars..hex_chars + NEWLINE.len()].copy_from_slice(NEWLINE);
hex_chars + 2
}

Expand All @@ -243,12 +266,84 @@ mod tests {
use super::*;

#[test]
fn hex_chars_values() {
assert_eq!(1, hex_chars(0));
assert_eq!(1, hex_chars(1));
assert_eq!(1, hex_chars(0xF));
assert_eq!(2, hex_chars(0x10));
assert_eq!(2, hex_chars(0xFF));
assert_eq!(3, hex_chars(0x100));
fn can_get_hex_chars() {
assert_eq!(1, get_hex_chars(0));
assert_eq!(1, get_hex_chars(1));
assert_eq!(1, get_hex_chars(0xF));
assert_eq!(2, get_hex_chars(0x10));
assert_eq!(2, get_hex_chars(0xFF));
assert_eq!(3, get_hex_chars(0x100));
}

#[test]
fn can_get_max_chunk_header_size() {
assert_eq!(0, get_max_chunk_header_size(3));
assert_eq!(3, get_max_chunk_header_size(0x00 + 2 + 2));
assert_eq!(3, get_max_chunk_header_size(0x01 + 2 + 2));
assert_eq!(3, get_max_chunk_header_size(0x0F + 2 + 2));
assert_eq!(4, get_max_chunk_header_size(0x10 + 2 + 2));
}

#[tokio::test]
async fn preserves_already_written_bytes_in_the_buffer_without_any_chunks() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 1024];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.terminate().await.unwrap();

// Then
assert_eq!(b"HELLO0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn preserves_already_written_bytes_in_the_buffer_with_chunks() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 1024];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.write_all(b"BODY").await.unwrap();
writer.terminate().await.unwrap();

// Then
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn current_chunk_is_emitted_before_empty_chunk_is_emitted() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 14];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.write_all(b"BODY").await.unwrap(); // Can fit
writer.terminate().await.unwrap(); // Cannot fit

// Then
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn write_emits_chunks() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 12];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.write_all(b"BODY").await.unwrap(); // Only "BO" can fit first, then "DY" is written in a different chunk
writer.terminate().await.unwrap();

// Then
assert_eq!(b"HELLO2\r\nBO\r\n2\r\nDY\r\n0\r\n\r\n", conn.as_slice());
}
}

0 comments on commit d8d416a

Please sign in to comment.