From 9cdc76d879e058b62ab0d1ccc2ef655bbf912691 Mon Sep 17 00:00:00 2001 From: Rasmus Melchior Jacobsen Date: Mon, 1 Jul 2024 08:41:05 +0200 Subject: [PATCH] Make an effort to ensure that write() does not return Ok(0) --- src/body_writer.rs | 92 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 7 deletions(-) diff --git a/src/body_writer.rs b/src/body_writer.rs index c7aea69..5cc4316 100644 --- a/src/body_writer.rs +++ b/src/body_writer.rs @@ -149,7 +149,7 @@ where if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() { // There is not enough space to fit the empty chunk in the buffer - self.emit_finished_chunks().await?; + self.emit_buffered().await?; } } @@ -157,12 +157,12 @@ where self.header_pos += EMPTY_CHUNK.len(); self.allocated_header = 0; self.pos = self.header_pos + self.allocated_header; - self.emit_finished_chunks().await + self.emit_buffered().await } /// Append to the buffer fn append_current_chunk(&mut self, buf: &[u8]) -> usize { - let buffered = usize::min(buf.len(), self.buf.len() - NEWLINE.len() - self.pos); + let buffered = usize::min(buf.len(), self.buf.len().saturating_sub(NEWLINE.len() + self.pos)); if buffered > 0 { self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]); self.pos += buffered; @@ -196,7 +196,7 @@ where self.pos = self.header_pos + self.allocated_header; } - async fn emit_finished_chunks(&mut self) -> Result<(), C::Error> { + async fn emit_buffered(&mut self) -> Result<(), C::Error> { self.conn.write_all(&self.buf[..self.header_pos]).await?; self.header_pos = 0; self.allocated_header = get_max_chunk_header_size(self.buf.len()); @@ -217,10 +217,20 @@ where C: Write, { async fn write(&mut self, buf: &[u8]) -> Result { - let written = self.append_current_chunk(buf); + if buf.is_empty() { + return Ok(0); + } + + let mut written = self.append_current_chunk(buf); + if written == 0 { + // Unable to append any data to the buffer + // This can happen if the writer was pre-loaded with data + self.emit_buffered().await.map_err(|e| e.kind())?; + written = self.append_current_chunk(buf); + } if written < buf.len() { self.finish_current_chunk(); - self.emit_finished_chunks().await.map_err(|e| e.kind())?; + self.emit_buffered().await.map_err(|e| e.kind())?; } Ok(written) } @@ -229,7 +239,10 @@ where if self.pos > self.header_pos + self.allocated_header { // There are bytes written in the current chunk self.finish_current_chunk(); - self.emit_finished_chunks().await.map_err(|e| e.kind())?; + self.emit_buffered().await.map_err(|e| e.kind())?; + } else if self.header_pos > 0 { + // There are pre-written bytes in the buffer + self.emit_buffered().await.map_err(|e| e.kind())?; } self.conn.flush().await.map_err(|e| e.kind()) } @@ -337,6 +350,71 @@ mod tests { assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice()); } + #[tokio::test] + async fn write_when_entire_buffer_is_prewritten() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 10]; + buf.copy_from_slice(b"HELLOHELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10); + writer.write_all(b"BODY").await.unwrap(); // Cannot fit + writer.terminate().await.unwrap(); + + // Then + print!("{:?}", conn.as_slice()); + assert_eq!(b"HELLOHELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice()); + } + + #[tokio::test] + async fn flush_when_entire_buffer_is_prewritten() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 10]; + buf.copy_from_slice(b"HELLOHELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10); + writer.flush().await.unwrap(); + + // Then + print!("{:?}", conn.as_slice()); + assert_eq!(b"HELLOHELLO", conn.as_slice()); + } + + #[tokio::test] + async fn flush_when_entire_buffer_is_nearly_prewritten() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 11]; + buf[..10].copy_from_slice(b"HELLOHELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10); + writer.flush().await.unwrap(); + + // Then + print!("{:?}", conn.as_slice()); + assert_eq!(b"HELLOHELLO", conn.as_slice()); + } + + #[tokio::test] + async fn flushes_already_written_bytes_if_first_cannot_fit() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 10]; + buf[..5].copy_from_slice(b"HELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5); + writer.write_all(b"BODY").await.unwrap(); // Cannot fit + writer.terminate().await.unwrap(); // Can fit + + // Then + assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice()); + } + #[tokio::test] async fn current_chunk_is_emitted_before_empty_chunk_is_emitted() { // Given