Skip to content

Commit

Permalink
Buffer writes before chunks are written to connection
Browse files Browse the repository at this point in the history
This commit:
* Moves the responsibility for writing the request body from `Request` to `HttpConnection`.
* Uses the buffer provided when calling `into_buffered()` to buffer writes before they are passed on to the `ChunkedBufferWriter`

This fixes drogue-iot#71
  • Loading branch information
rmja committed May 17, 2024
1 parent 12c5ab7 commit 9790bc8
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 59 deletions.
164 changes: 159 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,68 @@ where
/// The response is returned.
pub async fn send<'req, 'buf, B: RequestBody>(
&'req mut self,
request: Request<'conn, B>,
request: Request<'req, B>,
rx_buf: &'buf mut [u8],
) -> Result<Response<'req, 'buf, HttpConnection<'conn, T>>, Error> {
request.write(self).await?;
self.write_request(&request).await?;
self.flush().await?;
Response::read(self, request.method, rx_buf).await
}

async fn write_request<'req, B: RequestBody>(&mut self, request: &Request<'req, B>) -> Result<(), Error> {
request.write_header(self).await?;

if let Some(body) = request.body.as_ref() {
match body.len() {
Some(0) => {
// Empty body
}
Some(len) => {
trace!("Writing not-chunked body");
let mut writer = FixedBodyWriter::new(self);
body.write(&mut writer).await.map_err(|e| e.kind())?;

if writer.written() != len {
return Err(Error::IncorrectBodyWritten);
}
}
None => {
trace!("Writing chunked body");
match self {
HttpConnection::Plain(c) => {
let mut writer = ChunkedBodyWriter::new(c);
body.write(&mut writer).await?;
writer.write_empty_chunk().await.map_err(|e| e.kind())?;
}
HttpConnection::PlainBuffered(buffered_conn) => {
// Flush the buffered connection so that we can bypass it and rent its buffer
buffered_conn.flush().await.map_err(|e| e.kind())?;
let (conn, buf) = buffered_conn.bypass_with_buf().unwrap();

// Construct a new buffered writer that buffers _before_ the chunked body writer
let mut writer = BufferedWrite::new(ChunkedBodyWriter::new(conn), buf);
body.write(&mut writer).await?;

// Flush the buffered writer and write the empty chunk to the chunked body writer
writer.flush().await.map_err(|e| e.kind())?;
writer
.bypass()
.unwrap()
.write_empty_chunk()
.await
.map_err(|e| e.kind())?;
}
HttpConnection::Tls(c) => {
let mut writer = ChunkedBodyWriter::new(c);
body.write(&mut writer).await?;
writer.write_empty_chunk().await.map_err(|e| e.kind())?;
}
};
}
}
}
Ok(())
}
}

impl<T> ErrorType for HttpConnection<'_, T>
Expand Down Expand Up @@ -379,7 +435,8 @@ where
rx_buf: &'buf mut [u8],
) -> Result<Response<'req, 'buf, HttpConnection<'conn, C>>, Error> {
let request = self.request.take().ok_or(Error::AlreadySent)?.build();
request.write(&mut self.conn).await?;
self.conn.write_request(&request).await?;
self.conn.flush().await?;
Response::read(&mut self.conn, request.method, rx_buf).await
}
}
Expand Down Expand Up @@ -508,7 +565,8 @@ where
rx_buf: &'buf mut [u8],
) -> Result<Response<'req, 'buf, HttpConnection<'res, C>>, Error> {
request.base_path = Some(self.base_path);
request.write(&mut self.conn).await?;
self.conn.write_request(&request).await?;
self.conn.flush().await?;
Response::read(&mut self.conn, request.method, rx_buf).await
}
}
Expand Down Expand Up @@ -541,7 +599,8 @@ where
let conn = self.conn;
let mut request = self.request.build();
request.base_path = Some(self.base_path);
request.write(conn).await?;
conn.write_request(&request).await?;
conn.flush().await?;
Response::read(conn, request.method, rx_buf).await
}
}
Expand Down Expand Up @@ -590,3 +649,98 @@ where
self.request.build()
}
}

#[cfg(test)]
mod tests {
use core::convert::Infallible;

use super::*;

#[derive(Default)]
struct VecBuffer(Vec<u8>);

impl ErrorType for VecBuffer {
type Error = Infallible;
}

impl Read for VecBuffer {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
unreachable!()
}
}

impl Write for VecBuffer {
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
self.0.extend_from_slice(buf);
Ok(buf.len())
}
}

#[tokio::test]
async fn with_empty_body() {
let mut buffer = VecBuffer::default();
let mut conn = HttpConnection::Plain(&mut buffer);

let request = Request::new(Method::POST, "/").body([].as_slice()).build();
conn.write_request(&request).await.unwrap();

assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 0\r\n\r\n", buffer.0.as_slice());
}

#[tokio::test]
async fn with_known_body() {
let mut buffer = VecBuffer::default();
let mut conn = HttpConnection::Plain(&mut buffer);

let request = Request::new(Method::POST, "/").body(b"BODY".as_slice()).build();
conn.write_request(&request).await.unwrap();

assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nBODY", buffer.0.as_slice());
}

struct ChunkedBody(&'static [&'static [u8]]);

impl RequestBody for ChunkedBody {
fn len(&self) -> Option<usize> {
None // Unknown length: triggers chunked body
}

async fn write<W: Write>(&self, writer: &mut W) -> Result<(), W::Error> {
for chunk in self.0 {
writer.write_all(chunk).await?;
}
Ok(())
}
}

#[tokio::test]
async fn with_unknown_body_unbuffered() {
let mut buffer = VecBuffer::default();
let mut conn = HttpConnection::Plain(&mut buffer);

static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"];
let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build();
conn.write_request(&request).await.unwrap();

assert_eq!(
b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nPART1\r\n5\r\nPART2\r\n0\r\n\r\n",
buffer.0.as_slice()
);
}

#[tokio::test]
async fn with_unknown_body_buffered() {
let mut buffer = VecBuffer::default();
let mut tx_buf = [0; 1024];
let mut conn = HttpConnection::Plain(&mut buffer).into_buffered(&mut tx_buf);

static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"];
let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build();
conn.write_request(&request).await.unwrap();

assert_eq!(
b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\na\r\nPART1PART2\r\n0\r\n\r\n",
buffer.0.as_slice()
);
}
}
100 changes: 54 additions & 46 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ impl<'req, B> Request<'req, B>
where
B: RequestBody,
{
/// Write request to the I/O stream
pub async fn write<C>(&self, c: &mut C) -> Result<(), Error>
/// Write request header to the I/O stream
pub async fn write_header<C>(&self, c: &mut C) -> Result<(), Error>
where
C: Write,
{
Expand Down Expand Up @@ -161,31 +161,6 @@ where
}
write_str(c, "\r\n").await?;
trace!("Header written");
if let Some(body) = self.body.as_ref() {
match body.len() {
Some(0) => {
// Empty body
}
Some(len) => {
trace!("Writing not-chunked body");
let mut writer = FixedBodyWriter(c, 0);
body.write(&mut writer).await.map_err(to_errorkind)?;

if writer.1 != len {
return Err(Error::IncorrectBodyWritten);
}
}
None => {
trace!("Writing chunked body");
let mut writer = ChunkedBodyWriter(c, 0);
body.write(&mut writer).await?;

write_str(c, "0\r\n\r\n").await?;
}
}
}

c.flush().await.map_err(|e| e.kind())?;
Ok(())
}
}
Expand Down Expand Up @@ -337,16 +312,29 @@ where
}
}

pub struct FixedBodyWriter<'a, C: Write>(&'a mut C, usize);
pub struct FixedBodyWriter<C: Write>(C, usize);

impl<C> ErrorType for FixedBodyWriter<'_, C>
impl<C> FixedBodyWriter<C>
where
C: Write,
{
pub fn new(conn: C) -> Self {
Self(conn, 0)
}

pub fn written(&self) -> usize {
self.1
}
}

impl<C> ErrorType for FixedBodyWriter<C>
where
C: Write,
{
type Error = C::Error;
}

impl<C> Write for FixedBodyWriter<'_, C>
impl<C> Write for FixedBodyWriter<C>
where
C: Write,
{
Expand All @@ -367,9 +355,22 @@ where
}
}

pub struct ChunkedBodyWriter<'a, C: Write>(&'a mut C, usize);
pub struct ChunkedBodyWriter<C: Write>(C, usize);

impl<C> ErrorType for ChunkedBodyWriter<'_, C>
impl<C> ChunkedBodyWriter<C>
where
C: Write,
{
pub fn new(conn: C) -> Self {
Self(conn, 0)
}

pub async fn write_empty_chunk(&mut self) -> Result<(), C::Error> {
self.0.write_all(b"0\r\n\r\n").await
}
}

impl<C> ErrorType for ChunkedBodyWriter<C>
where
C: Write,
{
Expand All @@ -380,7 +381,7 @@ fn to_errorkind<E: embedded_io::Error>(e: E) -> embedded_io::ErrorKind {
e.kind()
}

impl<C> Write for ChunkedBodyWriter<'_, C>
impl<C> Write for ChunkedBodyWriter<C>
where
C: Write,
{
Expand All @@ -392,6 +393,13 @@ where
async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
// Write chunk header
let len = buf.len();

// Do not write an empty chunk as that will terminate the body
// Use `ChunkedBodyWriter.write_empty_chunk` instead if this is intended
if len == 0 {
return Ok(());
}

let mut hex = [0; 2 * size_of::<usize>()];
hex::encode_to_slice(len.to_be_bytes(), &mut hex).unwrap();
let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default();
Expand Down Expand Up @@ -423,7 +431,7 @@ mod tests {
Request::new(Method::GET, "/")
.basic_auth("username", "password")
.build()
.write(&mut buffer)
.write_header(&mut buffer)
.await
.unwrap();

Expand All @@ -439,51 +447,51 @@ mod tests {
Request::new(Method::POST, "/")
.body([].as_slice())
.build()
.write(&mut buffer)
.write_header(&mut buffer)
.await
.unwrap();

assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 0\r\n\r\n", buffer.as_slice());
}

#[tokio::test]
async fn with_known_body() {
async fn with_known_body_adds_content_length_header() {
let mut buffer = Vec::new();
Request::new(Method::POST, "/")
.body(b"BODY".as_slice())
.build()
.write(&mut buffer)
.write_header(&mut buffer)
.await
.unwrap();

assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nBODY", buffer.as_slice());
assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n", buffer.as_slice());
}

struct ChunkedBody<'a>(&'a [u8]);
struct ChunkedBody;

impl RequestBody for ChunkedBody<'_> {
impl RequestBody for ChunkedBody {
fn len(&self) -> Option<usize> {
None // Unknown length: triggers chunked body
}

async fn write<W: Write>(&self, writer: &mut W) -> Result<(), W::Error> {
writer.write_all(self.0).await
async fn write<W: Write>(&self, _writer: &mut W) -> Result<(), W::Error> {
unreachable!()
}
}

#[tokio::test]
async fn with_unknown_body() {
async fn with_unknown_body_adds_transfer_encoding_header() {
let mut buffer = Vec::new();

Request::new(Method::POST, "/")
.body(ChunkedBody(b"BODY".as_slice()))
.body(ChunkedBody)
.build()
.write(&mut buffer)
.write_header(&mut buffer)
.await
.unwrap();

assert_eq!(
b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nBODY\r\n0\r\n\r\n",
b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n",
buffer.as_slice()
);
}
Expand Down
Loading

0 comments on commit 9790bc8

Please sign in to comment.