Make an effort to ensure that write() does not return Ok(0) #81

Merged
merged 2 commits into from
Jul 1, 2024
148 changes: 136 additions & 12 deletions src/body_writer.rs
@@ -111,15 +111,37 @@ where
}
}

/// A body writer that buffers internally and emits chunks as expected by the
/// `Transfer-Encoding: chunked` header specification.
///
/// Each emitted chunk has a header specifying the size of the chunk, and the
/// last chunk has a size of zero, indicating the end of the request body.
///
/// The writer can be seeded with a buffer that already contains pre-written data,
/// typically a request header that was written to the buffer beforehand. In that
/// case the writer appends chunks after the pre-written data, leaving the
/// pre-written data as-is.
///
/// To minimize the number of write calls to the underlying connection, the writer
/// pre-allocates space for the chunk header in the buffer and appends body data
/// after it. Depending on how many bytes are actually written to the current chunk
/// before the writer is terminated (indicating the end of the request body), the
/// pre-allocated header may turn out to be too large. In that case the chunk
/// payload is moved back into the pre-allocated header region so that the header
/// and payload can be written to the underlying connection in a single write.
///
pub struct BufferingChunkedBodyWriter<'a, C: Write> {
conn: C,
buf: &'a mut [u8],
/// The position where the allocated chunk header starts
header_pos: usize,
/// The size of the allocated header (the final header may be smaller)
    /// This may be 0 if the pre-written data in `buf` is too large to leave room for a header.
allocated_header: usize,
/// The position of the data in the chunk
pos: usize,
terminated: bool,
}
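For reference, the writer emits standard HTTP/1.1 chunked encoding: each chunk is its payload length in hexadecimal, CRLF, the payload, and CRLF, and the body ends with the empty chunk `0\r\n\r\n`. A minimal sketch of the framing (illustration only, not code from this diff; the real writer builds this in place in its caller-provided buffer instead of allocating):

```rust
/// Illustration only: encode a single chunk as `<len-hex>\r\n<payload>\r\n`.
fn encode_chunk(payload: &[u8]) -> Vec<u8> {
    let mut out = format!("{:X}\r\n", payload.len()).into_bytes();
    out.extend_from_slice(payload);
    out.extend_from_slice(b"\r\n");
    out
}

fn main() {
    // Matches the byte sequences asserted in the tests below.
    assert_eq!(b"4\r\nBODY\r\n", encode_chunk(b"BODY").as_slice());
}
```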

impl<'a, C> BufferingChunkedBodyWriter<'a, C>
@@ -136,33 +158,37 @@ where
header_pos: written,
pos: written + allocated_header,
allocated_header,
terminated: false,
}
}

/// Terminate the request body by writing an empty chunk
pub async fn terminate(&mut self) -> Result<(), C::Error> {
-        assert!(self.allocated_header > 0);
+        assert!(!self.terminated);

if self.pos > self.header_pos + self.allocated_header {
// There are bytes written in the current chunk
self.finish_current_chunk();
}

-        if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() {
-            // There is not enough space to fit the empty chunk in the buffer
-            self.emit_finished_chunks().await?;
-        }
+        if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() {
+            // There is not enough space to fit the empty chunk in the buffer
+            self.emit_buffered().await?;
+        }

self.buf[self.header_pos..self.header_pos + EMPTY_CHUNK.len()].copy_from_slice(EMPTY_CHUNK);
self.header_pos += EMPTY_CHUNK.len();
self.allocated_header = 0;
self.pos = self.header_pos + self.allocated_header;
-        self.emit_finished_chunks().await
+        self.emit_buffered().await?;
+        self.terminated = true;
+        Ok(())
}
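A note on the terminator: assuming `EMPTY_CHUNK` is `b"0\r\n\r\n"` (a zero-size chunk header followed by the blank line that ends a chunked body without trailers; every test expectation below ends with these bytes), termination always appends exactly five bytes:

```rust
fn main() {
    // Assumption: EMPTY_CHUNK == b"0\r\n\r\n", not confirmed by this diff
    // but consistent with every assert_eq! in the tests below.
    const EMPTY_CHUNK: &[u8] = b"0\r\n\r\n";
    assert_eq!(5, EMPTY_CHUNK.len());
    assert!(b"HELLO4\r\nBODY\r\n0\r\n\r\n".ends_with(EMPTY_CHUNK));
}
```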

-    /// Append to the buffer
+    /// Append data to the current chunk and return the number of bytes appended.
+    /// This returns 0 if there is no current chunk to append to.
fn append_current_chunk(&mut self, buf: &[u8]) -> usize {
-        let buffered = usize::min(buf.len(), self.buf.len() - NEWLINE.len() - self.pos);
+        let buffered = usize::min(buf.len(), self.buf.len().saturating_sub(NEWLINE.len() + self.pos));
if buffered > 0 {
self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]);
self.pos += buffered;
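The `saturating_sub` change above is what prevents an underflow when the buffer is entirely pre-written: `pos` can then equal `buf.len()`, and the old expression `self.buf.len() - NEWLINE.len() - self.pos` would panic in debug builds. A standalone illustration, assuming `NEWLINE` is `b"\r\n"` (consistent with the CRLF framing in the tests):

```rust
fn main() {
    const NEWLINE_LEN: usize = 2; // assumption: NEWLINE == b"\r\n"
    let buf_len = 10; // fully pre-written 10-byte buffer
    let pos = 10;     // write position already at the end

    // Old: buf_len - NEWLINE_LEN - pos underflows (panics in debug builds).
    // New: clamps the remaining capacity to zero instead.
    let capacity = buf_len.saturating_sub(NEWLINE_LEN + pos);
    assert_eq!(capacity, 0); // append_current_chunk then returns 0
}
```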
@@ -196,7 +222,7 @@ where
self.pos = self.header_pos + self.allocated_header;
}

-    async fn emit_finished_chunks(&mut self) -> Result<(), C::Error> {
+    async fn emit_buffered(&mut self) -> Result<(), C::Error> {
self.conn.write_all(&self.buf[..self.header_pos]).await?;
self.header_pos = 0;
self.allocated_header = get_max_chunk_header_size(self.buf.len());
@@ -217,10 +243,20 @@
C: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
-        let written = self.append_current_chunk(buf);
+        if buf.is_empty() {
+            return Ok(0);
+        }
+
+        let mut written = self.append_current_chunk(buf);
+        if written == 0 {
+            // Unable to append any data to the buffer.
+            // This can happen if the writer was pre-loaded with data.
+            self.emit_buffered().await.map_err(|e| e.kind())?;
+            written = self.append_current_chunk(buf);
+        }
if written < buf.len() {
self.finish_current_chunk();
-            self.emit_finished_chunks().await.map_err(|e| e.kind())?;
+            self.emit_buffered().await.map_err(|e| e.kind())?;
}
Ok(written)
}
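This guard-and-retry structure is the point of the PR title: for a non-empty `buf`, `write` now only returns `Ok(0)` when even an emptied buffer cannot accept data, rather than whenever the writer happened to be pre-loaded. The reason `Ok(0)` is harmful is that `write_all`-style helpers treat it as a fatal lack of progress; a sketch of std's loop (slightly simplified: the real one also retries on `ErrorKind::Interrupted`) shows why:

```rust
use std::io::{self, Write};

// A zero-byte write of a non-empty buffer is reported as WriteZero,
// because the loop could otherwise spin forever without making progress.
fn write_all_sketch<W: Write>(w: &mut W, mut buf: &[u8]) -> io::Result<()> {
    while !buf.is_empty() {
        match w.write(buf) {
            Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
            Ok(n) => buf = &buf[n..],
            Err(e) => return Err(e),
        }
    }
    Ok(())
}

fn main() -> io::Result<()> {
    let mut out = Vec::new();
    write_all_sketch(&mut out, b"BODY")?;
    assert_eq!(b"BODY", out.as_slice());
    Ok(())
}
```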
@@ -229,7 +265,11 @@ where
if self.pos > self.header_pos + self.allocated_header {
// There are bytes written in the current chunk
self.finish_current_chunk();
-            self.emit_finished_chunks().await.map_err(|e| e.kind())?;
+            self.emit_buffered().await.map_err(|e| e.kind())?;
+        } else if self.header_pos > 0 {
+            // There are pre-written bytes in the buffer but no current chunk
+            // (the pre-written data was too large for a chunk header to be allocated)
+            self.emit_buffered().await.map_err(|e| e.kind())?;
}
self.conn.flush().await.map_err(|e| e.kind())
}
@@ -280,6 +320,9 @@ mod tests {

#[test]
fn can_get_max_chunk_header_size() {
assert_eq!(0, get_max_chunk_header_size(0));
assert_eq!(0, get_max_chunk_header_size(1));
assert_eq!(0, get_max_chunk_header_size(2));
assert_eq!(0, get_max_chunk_header_size(3));
assert_eq!(3, get_max_chunk_header_size(0x00 + 2 + 2));
assert_eq!(3, get_max_chunk_header_size(0x01 + 2 + 2));
@@ -337,6 +380,87 @@ mod tests {
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn write_when_entire_buffer_is_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf.copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.write_all(b"BODY").await.unwrap(); // Cannot fit
writer.terminate().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn flush_empty_body_when_entire_buffer_is_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf.copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.flush().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO", conn.as_slice());
}

#[tokio::test]
async fn terminate_empty_body_when_entire_buffer_is_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf.copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.terminate().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn flush_when_entire_buffer_is_nearly_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 11];
buf[..10].copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.flush().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO", conn.as_slice());
}

#[tokio::test]
async fn flushes_already_written_bytes_if_first_cannot_fit() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.write_all(b"BODY").await.unwrap(); // Cannot fit
writer.terminate().await.unwrap(); // Can fit

// Then
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn current_chunk_is_emitted_before_empty_chunk_is_emitted() {
// Given