diff --git a/Cargo.toml b/Cargo.toml index c127c9d..fa666c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["embedded", "async", "http", "no_std"] exclude = [".github"] [dependencies] -buffered-io = { version = "0.5" } +buffered-io = { version = "0.5.3" } embedded-io = { version = "0.6" } embedded-io-async = { version = "0.6" } embedded-nal-async = "0.7.0" @@ -27,7 +27,9 @@ defmt = { version = "0.3", optional = true } embedded-tls = { version = "0.17", default-features = false, optional = true } rand_chacha = { version = "0.3", default-features = false } nourl = "0.1.1" -esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = ["async"], optional = true } +esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = [ + "async", +], optional = true } [dev-dependencies] hyper = { version = "0.14.23", features = ["full"] } diff --git a/src/body_writer.rs b/src/body_writer.rs new file mode 100644 index 0000000..c7aea69 --- /dev/null +++ b/src/body_writer.rs @@ -0,0 +1,371 @@ +use core::mem::size_of; + +use embedded_io::{Error as _, ErrorType}; +use embedded_io_async::Write; + +const NEWLINE: &[u8; 2] = b"\r\n"; +const EMPTY_CHUNK: &[u8; 5] = b"0\r\n\r\n"; + +pub struct FixedBodyWriter(C, usize); + +impl FixedBodyWriter +where + C: Write, +{ + pub fn new(conn: C) -> Self { + Self(conn, 0) + } + + pub fn written(&self) -> usize { + self.1 + } +} + +impl ErrorType for FixedBodyWriter +where + C: Write, +{ + type Error = C::Error; +} + +impl Write for FixedBodyWriter +where + C: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + let written = self.0.write(buf).await?; + self.1 += written; + Ok(written) + } + + async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + self.0.write_all(buf).await?; + self.1 += buf.len(); + Ok(()) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.0.flush().await + } +} + +pub struct ChunkedBodyWriter(C); + +impl ChunkedBodyWriter +where + C: Write, +{ + pub fn new(conn: C) -> Self { + Self(conn) + } + + /// Terminate the request body by writing an empty chunk + pub async fn terminate(&mut self) -> Result<(), C::Error> { + self.0.write_all(EMPTY_CHUNK).await + } +} + +impl ErrorType for ChunkedBodyWriter +where + C: Write, +{ + type Error = embedded_io::ErrorKind; +} + +impl Write for ChunkedBodyWriter +where + C: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + self.write_all(buf).await.map_err(|e| e.kind())?; + Ok(buf.len()) + } + + async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + let len = buf.len(); + + // Do not write an empty chunk as that will terminate the body + // Use `ChunkedBodyWriter.write_empty_chunk` instead if this is intended + if len == 0 { + return Ok(()); + } + + // Write chunk header + let mut header_buf = [0; 2 * size_of::() + 2]; + let header_len = write_chunked_header(&mut header_buf, len); + self.0 + .write_all(&header_buf[..header_len]) + .await + .map_err(|e| e.kind())?; + + // Write chunk + self.0.write_all(buf).await.map_err(|e| e.kind())?; + + // Write newline footer + self.0.write_all(NEWLINE).await.map_err(|e| e.kind())?; + Ok(()) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.0.flush().await.map_err(|e| e.kind()) + } +} + +pub struct BufferingChunkedBodyWriter<'a, C: Write> { + conn: C, + buf: &'a mut [u8], + /// The position where the allocated chunk header starts + header_pos: usize, + /// The size of the allocated header (the final header may be smaller) + allocated_header: usize, + /// The position of the data in the chunk + pos: usize, +} + +impl<'a, C> BufferingChunkedBodyWriter<'a, C> +where + C: Write, +{ + pub fn new_with_data(conn: C, buf: &'a mut [u8], written: usize) -> Self { + assert!(written <= buf.len()); + let allocated_header = get_max_chunk_header_size(buf.len() - written); + assert!(buf.len() > allocated_header + NEWLINE.len()); // There must be space for the chunk header and footer + Self { + conn, + buf, + header_pos: written, + pos: written + allocated_header, + allocated_header, + } + } + + /// Terminate the request body by writing an empty chunk + pub async fn terminate(&mut self) -> Result<(), C::Error> { + assert!(self.allocated_header > 0); + + if self.pos > self.header_pos + self.allocated_header { + // There are bytes written in the current chunk + self.finish_current_chunk(); + + if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() { + // There is not enough space to fit the empty chunk in the buffer + self.emit_finished_chunks().await?; + } + } + + self.buf[self.header_pos..self.header_pos + EMPTY_CHUNK.len()].copy_from_slice(EMPTY_CHUNK); + self.header_pos += EMPTY_CHUNK.len(); + self.allocated_header = 0; + self.pos = self.header_pos + self.allocated_header; + self.emit_finished_chunks().await + } + + /// Append to the buffer + fn append_current_chunk(&mut self, buf: &[u8]) -> usize { + let buffered = usize::min(buf.len(), self.buf.len() - NEWLINE.len() - self.pos); + if buffered > 0 { + self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]); + self.pos += buffered; + } + buffered + } + + /// Finish the current chunk by writing the header + fn finish_current_chunk(&mut self) { + // Write the header in the allocated position position + let chunk_len = self.pos - self.header_pos - self.allocated_header; + let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.allocated_header]; + let header_len = write_chunked_header(header_buf, chunk_len); + + // Move the payload if the header length was not as large as it could possibly be + let spacing = self.allocated_header - header_len; + if spacing > 0 { + self.buf.copy_within( + self.header_pos + self.allocated_header..self.pos, + self.header_pos + header_len, + ); + self.pos -= spacing + } + + // Write newline footer after chunk payload + self.buf[self.pos..self.pos + NEWLINE.len()].copy_from_slice(NEWLINE); + self.pos += 2; + + self.header_pos = self.pos; + self.allocated_header = get_max_chunk_header_size(self.buf.len() - self.header_pos); + self.pos = self.header_pos + self.allocated_header; + } + + async fn emit_finished_chunks(&mut self) -> Result<(), C::Error> { + self.conn.write_all(&self.buf[..self.header_pos]).await?; + self.header_pos = 0; + self.allocated_header = get_max_chunk_header_size(self.buf.len()); + self.pos = self.allocated_header; + Ok(()) + } +} + +impl ErrorType for BufferingChunkedBodyWriter<'_, C> +where + C: Write, +{ + type Error = embedded_io::ErrorKind; +} + +impl Write for BufferingChunkedBodyWriter<'_, C> +where + C: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + let written = self.append_current_chunk(buf); + if written < buf.len() { + self.finish_current_chunk(); + self.emit_finished_chunks().await.map_err(|e| e.kind())?; + } + Ok(written) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + if self.pos > self.header_pos + self.allocated_header { + // There are bytes written in the current chunk + self.finish_current_chunk(); + self.emit_finished_chunks().await.map_err(|e| e.kind())?; + } + self.conn.flush().await.map_err(|e| e.kind()) + } +} + +/// Get the number of hex characters for a number. +/// E.g. 0x0 => 1, 0x0F => 1, 0x10 => 2, 0x1234 => 4. +const fn get_num_hex_chars(number: usize) -> usize { + if number == 0 { + 1 + } else { + (usize::BITS - number.leading_zeros()).div_ceil(4) as usize + } +} + +const fn get_max_chunk_header_size(buffer_size: usize) -> usize { + if let Some(hex_chars_and_payload_size) = buffer_size.checked_sub(2 * NEWLINE.len()) { + get_num_hex_chars(hex_chars_and_payload_size) + NEWLINE.len() + } else { + // Not enough space in buffer to fit a header + footer + 0 + } +} + +fn write_chunked_header(buf: &mut [u8], chunk_len: usize) -> usize { + let mut hex = [0; 2 * size_of::()]; + hex::encode_to_slice(chunk_len.to_be_bytes(), &mut hex).unwrap(); + let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or(hex.len() - 1); + let hex_chars = hex.len() - leading_zeros; + buf[..hex_chars].copy_from_slice(&hex[leading_zeros..]); + buf[hex_chars..hex_chars + NEWLINE.len()].copy_from_slice(NEWLINE); + hex_chars + 2 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn can_get_hex_chars() { + assert_eq!(1, get_num_hex_chars(0)); + assert_eq!(1, get_num_hex_chars(1)); + assert_eq!(1, get_num_hex_chars(0xF)); + assert_eq!(2, get_num_hex_chars(0x10)); + assert_eq!(2, get_num_hex_chars(0xFF)); + assert_eq!(3, get_num_hex_chars(0x100)); + } + + #[test] + fn can_get_max_chunk_header_size() { + assert_eq!(0, get_max_chunk_header_size(3)); + assert_eq!(3, get_max_chunk_header_size(0x00 + 2 + 2)); + assert_eq!(3, get_max_chunk_header_size(0x01 + 2 + 2)); + assert_eq!(3, get_max_chunk_header_size(0x0F + 2 + 2)); + assert_eq!(4, get_max_chunk_header_size(0x10 + 2 + 2)); + assert_eq!(4, get_max_chunk_header_size(0x11 + 2 + 2)); + assert_eq!(4, get_max_chunk_header_size(0x12 + 2 + 2)); + } + + #[test] + fn can_write_chunked_header() { + let mut buf = [0; 4]; + + let len = write_chunked_header(&mut buf, 0x00); + assert_eq!(b"0\r\n", &buf[..len]); + + let len = write_chunked_header(&mut buf, 0x01); + assert_eq!(b"1\r\n", &buf[..len]); + + let len = write_chunked_header(&mut buf, 0x0F); + assert_eq!(b"f\r\n", &buf[..len]); + + let len = write_chunked_header(&mut buf, 0x10); + assert_eq!(b"10\r\n", &buf[..len]); + } + + #[tokio::test] + async fn preserves_already_written_bytes_in_the_buffer_without_any_chunks() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 1024]; + buf[..5].copy_from_slice(b"HELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5); + writer.terminate().await.unwrap(); + + // Then + assert_eq!(b"HELLO0\r\n\r\n", conn.as_slice()); + } + + #[tokio::test] + async fn preserves_already_written_bytes_in_the_buffer_with_chunks() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 1024]; + buf[..5].copy_from_slice(b"HELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5); + writer.write_all(b"BODY").await.unwrap(); + writer.terminate().await.unwrap(); + + // Then + assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice()); + } + + #[tokio::test] + async fn current_chunk_is_emitted_before_empty_chunk_is_emitted() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 14]; + buf[..5].copy_from_slice(b"HELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5); + writer.write_all(b"BODY").await.unwrap(); // Can fit + writer.terminate().await.unwrap(); // Cannot fit + + // Then + assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice()); + } + + #[tokio::test] + async fn write_emits_chunks() { + // Given + let mut conn = Vec::new(); + let mut buf = [0; 12]; + buf[..5].copy_from_slice(b"HELLO"); + + // When + let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5); + writer.write_all(b"BODY").await.unwrap(); // Only "BO" can fit first, then "DY" is written in a different chunk + writer.terminate().await.unwrap(); + + // Then + assert_eq!(b"HELLO2\r\nBO\r\n2\r\nDY\r\n0\r\n\r\n", conn.as_slice()); + } +} diff --git a/src/client.rs b/src/client.rs index 86db2ce..163ab50 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,6 @@ +use crate::body_writer::BufferingChunkedBodyWriter; +use crate::body_writer::ChunkedBodyWriter; +use crate::body_writer::FixedBodyWriter; /// Client using embedded-nal-async traits to establish connections and perform HTTP requests. /// use crate::headers::ContentType; @@ -280,12 +283,60 @@ where /// The response is returned. pub async fn send<'req, 'buf, B: RequestBody>( &'req mut self, - request: Request<'conn, B>, + request: Request<'req, B>, rx_buf: &'buf mut [u8], ) -> Result>, Error> { - request.write(self).await?; + self.write_request(&request).await?; + self.flush().await?; Response::read(self, request.method, rx_buf).await } + + async fn write_request<'req, B: RequestBody>(&mut self, request: &Request<'req, B>) -> Result<(), Error> { + request.write_header(self).await?; + + if let Some(body) = request.body.as_ref() { + match body.len() { + Some(0) => { + // Empty body + } + Some(len) => { + trace!("Writing not-chunked body"); + let mut writer = FixedBodyWriter::new(self); + body.write(&mut writer).await.map_err(|e| e.kind())?; + + if writer.written() != len { + return Err(Error::IncorrectBodyWritten); + } + } + None => { + trace!("Writing chunked body"); + match self { + HttpConnection::Plain(c) => { + let mut writer = ChunkedBodyWriter::new(c); + body.write(&mut writer).await?; + writer.terminate().await.map_err(|e| e.kind())?; + } + HttpConnection::PlainBuffered(buffered) => { + let (conn, buf, unwritten) = buffered.split(); + let mut writer = BufferingChunkedBodyWriter::new_with_data(conn, buf, unwritten); + body.write(&mut writer).await?; + writer.terminate().await.map_err(|e| e.kind())?; + buffered.clear(); + } + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] + HttpConnection::Tls(c) => { + let mut writer = ChunkedBodyWriter::new(c); + body.write(&mut writer).await?; + writer.terminate().await.map_err(|e| e.kind())?; + } + #[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))] + HttpConnection::Tls(_) => unreachable!(), + }; + } + } + } + Ok(()) + } } impl ErrorType for HttpConnection<'_, T> @@ -379,7 +430,8 @@ where rx_buf: &'buf mut [u8], ) -> Result>, Error> { let request = self.request.take().ok_or(Error::AlreadySent)?.build(); - request.write(&mut self.conn).await?; + self.conn.write_request(&request).await?; + self.conn.flush().await?; Response::read(&mut self.conn, request.method, rx_buf).await } } @@ -508,7 +560,8 @@ where rx_buf: &'buf mut [u8], ) -> Result>, Error> { request.base_path = Some(self.base_path); - request.write(&mut self.conn).await?; + self.conn.write_request(&request).await?; + self.conn.flush().await?; Response::read(&mut self.conn, request.method, rx_buf).await } } @@ -541,7 +594,8 @@ where let conn = self.conn; let mut request = self.request.build(); request.base_path = Some(self.base_path); - request.write(conn).await?; + conn.write_request(&request).await?; + conn.flush().await?; Response::read(conn, request.method, rx_buf).await } } @@ -590,3 +644,98 @@ where self.request.build() } } + +#[cfg(test)] +mod tests { + use core::convert::Infallible; + + use super::*; + + #[derive(Default)] + struct VecBuffer(Vec); + + impl ErrorType for VecBuffer { + type Error = Infallible; + } + + impl Read for VecBuffer { + async fn read(&mut self, _buf: &mut [u8]) -> Result { + unreachable!() + } + } + + impl Write for VecBuffer { + async fn write(&mut self, buf: &[u8]) -> Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + } + + #[tokio::test] + async fn with_empty_body() { + let mut buffer = VecBuffer::default(); + let mut conn = HttpConnection::Plain(&mut buffer); + + let request = Request::new(Method::POST, "/").body([].as_slice()).build(); + conn.write_request(&request).await.unwrap(); + + assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 0\r\n\r\n", buffer.0.as_slice()); + } + + #[tokio::test] + async fn with_known_body() { + let mut buffer = VecBuffer::default(); + let mut conn = HttpConnection::Plain(&mut buffer); + + let request = Request::new(Method::POST, "/").body(b"BODY".as_slice()).build(); + conn.write_request(&request).await.unwrap(); + + assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nBODY", buffer.0.as_slice()); + } + + struct ChunkedBody(&'static [&'static [u8]]); + + impl RequestBody for ChunkedBody { + fn len(&self) -> Option { + None // Unknown length: triggers chunked body + } + + async fn write(&self, writer: &mut W) -> Result<(), W::Error> { + for chunk in self.0 { + writer.write_all(chunk).await?; + } + Ok(()) + } + } + + #[tokio::test] + async fn with_unknown_body_unbuffered() { + let mut buffer = VecBuffer::default(); + let mut conn = HttpConnection::Plain(&mut buffer); + + static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"]; + let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build(); + conn.write_request(&request).await.unwrap(); + + assert_eq!( + b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nPART1\r\n5\r\nPART2\r\n0\r\n\r\n", + buffer.0.as_slice() + ); + } + + #[tokio::test] + async fn with_unknown_body_buffered() { + let mut buffer = VecBuffer::default(); + let mut tx_buf = [0; 1024]; + let mut conn = HttpConnection::Plain(&mut buffer).into_buffered(&mut tx_buf); + + static CHUNKS: [&'static [u8]; 2] = [b"PART1", b"PART2"]; + let request = Request::new(Method::POST, "/").body(ChunkedBody(&CHUNKS)).build(); + conn.write_request(&request).await.unwrap(); + + assert_eq!( + b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\na\r\nPART1PART2\r\n0\r\n\r\n", + buffer.0.as_slice() + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index e1f7740..f699415 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use embedded_io_async::ReadExactError; mod fmt; +mod body_writer; pub mod client; pub mod headers; mod reader; diff --git a/src/reader.rs b/src/reader.rs index bb22d13..6d1b8c6 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -98,7 +98,7 @@ where unreachable!() } - self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await.map_err(|e| e.kind())?; + self.buffer.loaded = self.stream.read(self.buffer.buffer).await.map_err(|e| e.kind())?; } self.buffer.fill_buf() diff --git a/src/request.rs b/src/request.rs index 416211e..a9238e5 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,9 +1,8 @@ -use crate::headers::ContentType; /// Low level API for encoding requests and decoding responses. +use crate::headers::ContentType; use crate::Error; use core::fmt::Write as _; -use core::mem::size_of; -use embedded_io::{Error as _, ErrorType}; +use embedded_io::Error as _; use embedded_io_async::Write; use heapless::String; @@ -106,8 +105,8 @@ impl<'req, B> Request<'req, B> where B: RequestBody, { - /// Write request to the I/O stream - pub async fn write(&self, c: &mut C) -> Result<(), Error> + /// Write request header to the I/O stream + pub async fn write_header(&self, c: &mut C) -> Result<(), Error> where C: Write, { @@ -161,31 +160,6 @@ where } write_str(c, "\r\n").await?; trace!("Header written"); - if let Some(body) = self.body.as_ref() { - match body.len() { - Some(0) => { - // Empty body - } - Some(len) => { - trace!("Writing not-chunked body"); - let mut writer = FixedBodyWriter(c, 0); - body.write(&mut writer).await.map_err(to_errorkind)?; - - if writer.1 != len { - return Err(Error::IncorrectBodyWritten); - } - } - None => { - trace!("Writing chunked body"); - let mut writer = ChunkedBodyWriter(c, 0); - body.write(&mut writer).await?; - - write_str(c, "0\r\n\r\n").await?; - } - } - } - - c.flush().await.map_err(|e| e.kind())?; Ok(()) } } @@ -273,7 +247,7 @@ impl Method { } async fn write_str(c: &mut C, data: &str) -> Result<(), Error> { - c.write_all(data.as_bytes()).await.map_err(to_errorkind)?; + c.write_all(data.as_bytes()).await.map_err(|e| e.kind())?; Ok(()) } @@ -337,82 +311,6 @@ where } } -pub struct FixedBodyWriter<'a, C: Write>(&'a mut C, usize); - -impl ErrorType for FixedBodyWriter<'_, C> -where - C: Write, -{ - type Error = C::Error; -} - -impl Write for FixedBodyWriter<'_, C> -where - C: Write, -{ - async fn write(&mut self, buf: &[u8]) -> Result { - let written = self.0.write(buf).await?; - self.1 += written; - Ok(written) - } - - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - self.0.write_all(buf).await?; - self.1 += buf.len(); - Ok(()) - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - self.0.flush().await - } -} - -pub struct ChunkedBodyWriter<'a, C: Write>(&'a mut C, usize); - -impl ErrorType for ChunkedBodyWriter<'_, C> -where - C: Write, -{ - type Error = embedded_io::ErrorKind; -} - -fn to_errorkind(e: E) -> embedded_io::ErrorKind { - e.kind() -} - -impl Write for ChunkedBodyWriter<'_, C> -where - C: Write, -{ - async fn write(&mut self, buf: &[u8]) -> Result { - self.write_all(buf).await.map_err(to_errorkind)?; - Ok(buf.len()) - } - - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - // Write chunk header - let len = buf.len(); - let mut hex = [0; 2 * size_of::()]; - hex::encode_to_slice(len.to_be_bytes(), &mut hex).unwrap(); - let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default(); - let (_, hex) = hex.split_at(leading_zeros); - self.0.write_all(hex).await.map_err(to_errorkind)?; - self.0.write_all(b"\r\n").await.map_err(to_errorkind)?; - - // Write chunk - self.0.write_all(buf).await.map_err(to_errorkind)?; - self.1 += len; - - // Write newline - self.0.write_all(b"\r\n").await.map_err(to_errorkind)?; - Ok(()) - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - self.0.flush().await.map_err(|e| e.kind()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -423,7 +321,7 @@ mod tests { Request::new(Method::GET, "/") .basic_auth("username", "password") .build() - .write(&mut buffer) + .write_header(&mut buffer) .await .unwrap(); @@ -439,7 +337,7 @@ mod tests { Request::new(Method::POST, "/") .body([].as_slice()) .build() - .write(&mut buffer) + .write_header(&mut buffer) .await .unwrap(); @@ -447,43 +345,43 @@ mod tests { } #[tokio::test] - async fn with_known_body() { + async fn with_known_body_adds_content_length_header() { let mut buffer = Vec::new(); Request::new(Method::POST, "/") .body(b"BODY".as_slice()) .build() - .write(&mut buffer) + .write_header(&mut buffer) .await .unwrap(); - assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nBODY", buffer.as_slice()); + assert_eq!(b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n", buffer.as_slice()); } - struct ChunkedBody<'a>(&'a [u8]); + struct ChunkedBody; - impl RequestBody for ChunkedBody<'_> { + impl RequestBody for ChunkedBody { fn len(&self) -> Option { None // Unknown length: triggers chunked body } - async fn write(&self, writer: &mut W) -> Result<(), W::Error> { - writer.write_all(self.0).await + async fn write(&self, _writer: &mut W) -> Result<(), W::Error> { + unreachable!() } } #[tokio::test] - async fn with_unknown_body() { + async fn with_unknown_body_adds_transfer_encoding_header() { let mut buffer = Vec::new(); Request::new(Method::POST, "/") - .body(ChunkedBody(b"BODY".as_slice())) + .body(ChunkedBody) .build() - .write(&mut buffer) + .write_header(&mut buffer) .await .unwrap(); assert_eq!( - b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nBODY\r\n0\r\n\r\n", + b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n", buffer.as_slice() ); } diff --git a/tests/request.rs b/tests/request.rs index 111ccc5..fa5fea9 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -2,9 +2,9 @@ use embedded_io_adapters::tokio_1::FromTokio; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Server}; use reqwless::client::HttpConnection; -use reqwless::request::{Method, RequestBuilder}; +use reqwless::request::RequestBuilder; use reqwless::Error; -use reqwless::{headers::ContentType, request::Request, response::Response}; +use reqwless::{headers::ContentType, request::Request}; use std::str::from_utf8; use std::sync::Once; use tokio::net::TcpStream; @@ -48,9 +48,8 @@ async fn test_request_response() { .content_type(ContentType::TextPlain) .build(); - request.write(&mut stream).await.unwrap(); let mut rx_buf = [0; 4096]; - let response = Response::read(&mut stream, Method::POST, &mut rx_buf).await.unwrap(); + let response = stream.send(request, &mut rx_buf).await.unwrap(); let body = response.body().read_to_end().await; assert_eq!(body.unwrap(), b"PING"); @@ -70,7 +69,7 @@ async fn write_without_base_path() { let request = Request::get("/hello").build(); let mut buf = Vec::new(); - request.write(&mut buf).await.unwrap(); + request.write_header(&mut buf).await.unwrap(); assert!(from_utf8(&buf).unwrap().starts_with("GET /hello HTTP/1.1")); } @@ -82,16 +81,15 @@ async fn google_panic() { let addr = SocketAddr::from((google_ip, 80)); let conn = tokio::net::TcpStream::connect(addr).await.unwrap(); - let mut conn = TokioStream(FromTokio::new(conn)); + let mut conn = HttpConnection::Plain(TokioStream(FromTokio::new(conn))); let request = Request::get("/") .host("www.google.com") .content_type(ContentType::TextPlain) .build(); - request.write(&mut conn).await.unwrap(); let mut rx_buf = [0; 8 * 1024]; - let resp = Response::read(&mut conn, Method::GET, &mut rx_buf).await.unwrap(); + let resp = conn.send(request, &mut rx_buf).await.unwrap(); let result = resp.body().read_to_end().await; match result {