diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 085721c..792ce69 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -36,4 +36,6 @@ jobs: run: cargo clippy - name: Test - run: cargo test + run: | + cargo test + cargo test --no-default-features diff --git a/README.md b/README.md index 337e63b..fae0a17 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,5 @@ This enables `alloc` on `embedded-tls` which in turn enables RSA signature algor `reqwless` requires a feature from `nightly` to compile `embedded-io` with async support: * `async_fn_in_trait` -* `impl_trait_projections` This feature is complete, but is not yet merged to `stable`. diff --git a/src/client.rs b/src/client.rs index ec23515..2e81c30 100644 --- a/src/client.rs +++ b/src/client.rs @@ -79,7 +79,10 @@ where } } - async fn connect<'m>(&'m mut self, url: &Url<'m>) -> Result>, Error> { + async fn connect<'conn>( + &'conn mut self, + url: &Url<'_>, + ) -> Result>, Error> { let host = url.host(); let port = url.port_or_default(); @@ -107,7 +110,7 @@ where if let TlsVerify::Psk { identity, psk } = tls.verify { config = config.with_psk(psk, &[identity]); } - let mut conn: embedded_tls::TlsConnection<'m, T::Connection<'m>, embedded_tls::Aes128GcmSha256> = + let mut conn: embedded_tls::TlsConnection<'conn, T::Connection<'conn>, embedded_tls::Aes128GcmSha256> = embedded_tls::TlsConnection::new(conn, tls.read_buffer, tls.write_buffer); conn.open::<_, embedded_tls::NoVerify>(TlsContext::new(&config, &mut rng)) .await?; @@ -121,7 +124,7 @@ where #[cfg(feature = "embedded-tls")] match self.tls.as_mut() { Some(tls) => Ok(HttpConnection::PlainBuffered(BufferedWrite::new( - buffered_io_adapter::ConnErrorAdapter(conn), + conn, tls.write_buffer, ))), None => Ok(HttpConnection::Plain(conn)), @@ -132,11 +135,11 @@ where } /// Create a single http request. - pub async fn request<'m>( - &'m mut self, + pub async fn request<'conn>( + &'conn mut self, method: Method, - url: &'m str, - ) -> Result>, ()>, Error> { + url: &'conn str, + ) -> Result, ()>, Error> { let url = Url::parse(url)?; let conn = self.connect(&url).await?; Ok(HttpRequestHandle { @@ -150,7 +153,7 @@ where pub async fn resource<'res>( &'res mut self, resource_url: &'res str, - ) -> Result>>, Error> { + ) -> Result>, Error> { let resource_url = Url::parse(resource_url)?; let conn = self.connect(&resource_url).await?; Ok(HttpResource { @@ -163,23 +166,64 @@ where /// Represents a HTTP connection that may be encrypted or unencrypted. #[allow(clippy::large_enum_variant)] -pub enum HttpConnection<'m, C> +pub enum HttpConnection<'conn, C> where C: Read + Write, { Plain(C), + PlainBuffered(BufferedWrite<'conn, C>), #[cfg(feature = "embedded-tls")] - PlainBuffered(BufferedWrite<'m, buffered_io_adapter::ConnErrorAdapter>), - #[cfg(feature = "embedded-tls")] - Tls(embedded_tls::TlsConnection<'m, C, embedded_tls::Aes128GcmSha256>), + Tls(embedded_tls::TlsConnection<'conn, C, embedded_tls::Aes128GcmSha256>), #[cfg(not(feature = "embedded-tls"))] - Tls((&'m mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning + Tls((&'conn mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning +} + +#[cfg(feature = "defmt")] +impl defmt::Format for HttpConnection<'_, C> +where + C: Read + Write, +{ + fn format(&self, fmt: defmt::Formatter) { + match self { + HttpConnection::Plain(_) => defmt::write!(fmt, "Plain"), + HttpConnection::PlainBuffered(_) => defmt::write!(fmt, "PlainBuffered"), + HttpConnection::Tls(_) => defmt::write!(fmt, "Tls"), + } + } +} + +impl core::fmt::Debug for HttpConnection<'_, C> +where + C: Read + Write, +{ + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + HttpConnection::Plain(_) => f.debug_tuple("Plain").finish(), + HttpConnection::PlainBuffered(_) => f.debug_tuple("PlainBuffered").finish(), + HttpConnection::Tls(_) => f.debug_tuple("Tls").finish(), + } + } } impl<'conn, T> HttpConnection<'conn, T> where T: Read + Write, { + /// Turn the request into a buffered request. + /// + /// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse + /// its buffer for non-TLS connections. + pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpConnection<'buf, T> + where + 'conn: 'buf, + { + match self { + HttpConnection::Plain(conn) => HttpConnection::PlainBuffered(BufferedWrite::new(conn, tx_buf)), + HttpConnection::PlainBuffered(conn) => HttpConnection::PlainBuffered(conn), + HttpConnection::Tls(tls) => HttpConnection::Tls(tls), + } + } + /// Send a request on an established connection. /// /// The request is sent in its raw form without any base path from the resource. @@ -187,10 +231,10 @@ where /// /// The response is returned. pub async fn send<'buf, B: RequestBody>( - &'conn mut self, + &'buf mut self, request: Request<'conn, B>, rx_buf: &'buf mut [u8], - ) -> Result>, Error> { + ) -> Result>, Error> { request.write(self).await?; Response::read(self, request.method, rx_buf).await } @@ -210,7 +254,6 @@ where async fn read(&mut self, buf: &mut [u8]) -> Result { match self { Self::Plain(conn) => conn.read(buf).await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] Self::PlainBuffered(conn) => conn.read(buf).await.map_err(|e| e.kind()), #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.read(buf).await.map_err(|e| e.kind()), @@ -227,7 +270,6 @@ where async fn write(&mut self, buf: &[u8]) -> Result { match self { Self::Plain(conn) => conn.write(buf).await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] Self::PlainBuffered(conn) => conn.write(buf).await.map_err(|e| e.kind()), #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.write(buf).await.map_err(|e| e.kind()), @@ -239,7 +281,6 @@ where async fn flush(&mut self) -> Result<(), Self::Error> { match self { Self::Plain(conn) => conn.flush().await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] Self::PlainBuffered(conn) => conn.flush().await.map_err(|e| e.kind()), #[cfg(feature = "embedded-tls")] Self::Tls(conn) => conn.flush().await.map_err(|e| e.kind()), @@ -252,16 +293,16 @@ where /// A HTTP request handle /// /// The underlying connection is closed when drop'ed. -pub struct HttpRequestHandle<'m, C, B> +pub struct HttpRequestHandle<'conn, C, B> where C: Read + Write, B: RequestBody, { - pub conn: C, - request: Option>, + pub conn: HttpConnection<'conn, C>, + request: Option>, } -impl<'m, C, B> HttpRequestHandle<'m, C, B> +impl<'conn, C, B> HttpRequestHandle<'conn, C, B> where C: Read + Write, B: RequestBody, @@ -270,12 +311,12 @@ where /// /// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse /// its buffer for non-TLS connections. - pub fn into_buffered<'buf>( - self, - tx_buf: &'buf mut [u8], - ) -> HttpRequestHandle<'m, BufferedWrite<'buf, buffered_io_adapter::ConnErrorAdapter>, B> { + pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpRequestHandle<'buf, C, B> + where + 'conn: 'buf, + { HttpRequestHandle { - conn: BufferedWrite::new(buffered_io_adapter::ConnErrorAdapter(self.conn), tx_buf), + conn: self.conn.into_buffered(tx_buf), request: self.request, } } @@ -285,7 +326,10 @@ where /// The response headers are stored in the provided rx_buf, which should be sized to contain at least the response headers. /// /// The response is returned. - pub async fn send<'buf, 'conn>(&'conn mut self, rx_buf: &'buf mut [u8]) -> Result, Error> { + pub async fn send<'buf>( + &'buf mut self, + rx_buf: &'buf mut [u8], + ) -> Result>, Error> { let request = self.request.take().ok_or(Error::AlreadySent)?.build(); request.write(&mut self.conn).await?; Response::read(&mut self.conn, request.method, rx_buf).await @@ -343,7 +387,7 @@ pub struct HttpResource<'res, C> where C: Read + Write, { - pub conn: C, + pub conn: HttpConnection<'res, C>, pub host: &'res str, pub base_path: &'res str, } @@ -356,25 +400,22 @@ where /// /// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse /// its buffer for non-TLS connections. - pub fn into_buffered<'buf>( - self, - tx_buf: &'buf mut [u8], - ) -> HttpResource<'res, BufferedWrite<'buf, buffered_io_adapter::ConnErrorAdapter>> { + pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpResource<'buf, C> + where + 'res: 'buf, + { HttpResource { - conn: BufferedWrite::new(buffered_io_adapter::ConnErrorAdapter(self.conn), tx_buf), + conn: self.conn.into_buffered(tx_buf), host: self.host, base_path: self.base_path, } } - pub fn request<'conn, 'm>( - &'conn mut self, + pub fn request<'req>( + &'req mut self, method: Method, - path: &'m str, - ) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + path: &'req str, + ) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { HttpResourceRequestBuilder { conn: &mut self.conn, request: Request::new(method, path).host(self.host), @@ -383,42 +424,27 @@ where } /// Create a new scoped GET http request. - pub fn get<'conn, 'm>(&'conn mut self, path: &'m str) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + pub fn get<'req>(&'req mut self, path: &'req str) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { self.request(Method::GET, path) } /// Create a new scoped POST http request. - pub fn post<'conn, 'm>(&'conn mut self, path: &'m str) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + pub fn post<'req>(&'req mut self, path: &'req str) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { self.request(Method::POST, path) } /// Create a new scoped PUT http request. - pub fn put<'conn, 'm>(&'conn mut self, path: &'m str) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + pub fn put<'req>(&'req mut self, path: &'req str) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { self.request(Method::PUT, path) } /// Create a new scoped DELETE http request. - pub fn delete<'conn, 'm>(&'conn mut self, path: &'m str) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + pub fn delete<'req>(&'req mut self, path: &'req str) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { self.request(Method::DELETE, path) } /// Create a new scoped HEAD http request. - pub fn head<'conn, 'm>(&'conn mut self, path: &'m str) -> HttpResourceRequestBuilder<'conn, 'res, 'm, C, ()> - where - 'res: 'm, - { + pub fn head<'req>(&'req mut self, path: &'req str) -> HttpResourceRequestBuilder<'req, 'res, C, ()> { self.request(Method::HEAD, path) } @@ -428,28 +454,28 @@ where /// The response headers are stored in the provided rx_buf, which should be sized to contain at least the response headers. /// /// The response is returned. - pub async fn send<'buf, 'conn, B: RequestBody>( - &'conn mut self, - mut request: Request<'res, B>, - rx_buf: &'buf mut [u8], - ) -> Result, Error> { + pub async fn send<'req, B: RequestBody>( + &'req mut self, + mut request: Request<'req, B>, + rx_buf: &'req mut [u8], + ) -> Result>, Error> { request.base_path = Some(self.base_path); request.write(&mut self.conn).await?; Response::read(&mut self.conn, request.method, rx_buf).await } } -pub struct HttpResourceRequestBuilder<'conn, 'res, 'm, C, B> +pub struct HttpResourceRequestBuilder<'req, 'conn, C, B> where C: Read + Write, B: RequestBody, { - conn: &'conn mut C, - base_path: &'res str, - request: DefaultRequestBuilder<'m, B>, + conn: &'req mut HttpConnection<'conn, C>, + base_path: &'req str, + request: DefaultRequestBuilder<'req, B>, } -impl<'conn, 'res, 'm, C, B> HttpResourceRequestBuilder<'conn, 'res, 'm, C, B> +impl<'req, 'conn, C, B> HttpResourceRequestBuilder<'req, 'conn, C, B> where C: Read + Write, B: RequestBody, @@ -460,7 +486,11 @@ where /// The response headers are stored in the provided rx_buf, which should be sized to contain at least the response headers. /// /// The response is returned. - pub async fn send<'buf>(self, rx_buf: &'buf mut [u8]) -> Result, Error> { + pub async fn send<'buf>(self, rx_buf: &'buf mut [u8]) -> Result>, Error> + where + 'conn: 'req + 'buf, + 'req: 'buf, + { let conn = self.conn; let mut request = self.request.build(); request.base_path = Some(self.base_path); @@ -469,19 +499,19 @@ where } } -impl<'conn, 'res, 'm, C, B> RequestBuilder<'m, B> for HttpResourceRequestBuilder<'conn, 'res, 'm, C, B> +impl<'req, 'conn, C, B> RequestBuilder<'req, B> for HttpResourceRequestBuilder<'req, 'conn, C, B> where C: Read + Write, B: RequestBody, { - type WithBody = HttpResourceRequestBuilder<'conn, 'res, 'm, C, T>; + type WithBody = HttpResourceRequestBuilder<'req, 'conn, C, T>; - fn headers(mut self, headers: &'m [(&'m str, &'m str)]) -> Self { + fn headers(mut self, headers: &'req [(&'req str, &'req str)]) -> Self { self.request = self.request.headers(headers); self } - fn path(mut self, path: &'m str) -> Self { + fn path(mut self, path: &'req str) -> Self { self.request = self.request.path(path); self } @@ -494,7 +524,7 @@ where } } - fn host(mut self, host: &'m str) -> Self { + fn host(mut self, host: &'req str) -> Self { self.request = self.request.host(host); self } @@ -504,70 +534,12 @@ where self } - fn basic_auth(mut self, username: &'m str, password: &'m str) -> Self { + fn basic_auth(mut self, username: &'req str, password: &'req str) -> Self { self.request = self.request.basic_auth(username, password); self } - fn build(self) -> Request<'m, B> { + fn build(self) -> Request<'req, B> { self.request.build() } } - -mod buffered_io_adapter { - use embedded_io::{Error as _, ErrorType, ReadExactError}; - use embedded_io_async::{Read, Write}; - - pub struct Error(embedded_io::ErrorKind); - - impl core::fmt::Debug for Error { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.fmt(f) - } - } - - impl embedded_io_async::Error for Error { - fn kind(&self) -> embedded_io::ErrorKind { - self.0 - } - } - - pub struct ConnErrorAdapter(pub C); - - impl ErrorType for ConnErrorAdapter { - type Error = Error; - } - - impl Write for ConnErrorAdapter - where - C: Write, - { - async fn write(&mut self, buf: &[u8]) -> Result { - self.0.write(buf).await.map_err(|e| Error(e.kind())) - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - self.0.flush().await.map_err(|e| Error(e.kind())) - } - - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - self.0.write_all(buf).await.map_err(|e| Error(e.kind())) - } - } - - impl Read for ConnErrorAdapter - where - C: Read, - { - async fn read(&mut self, buf: &mut [u8]) -> Result { - self.0.read(buf).await.map_err(|e| Error(e.kind())) - } - - async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ReadExactError> { - self.0.read_exact(buf).await.map_err(|e| match e { - ReadExactError::UnexpectedEof => ReadExactError::UnexpectedEof, - ReadExactError::Other(e) => ReadExactError::Other(Error(e.kind())), - }) - } - } -} diff --git a/src/concat.rs b/src/concat.rs deleted file mode 100644 index 0d0853f..0000000 --- a/src/concat.rs +++ /dev/null @@ -1,88 +0,0 @@ -use embedded_io::{ErrorKind, ErrorType}; -use embedded_io_async::Read; - -pub struct ConcatReader -where - A: Read, - B: Read, -{ - first: A, - last: B, - first_exhausted: bool, -} - -impl ConcatReader -where - A: Read, - B: Read, -{ - pub const fn new(first: A, last: B) -> Self { - Self { - first, - last, - first_exhausted: false, - } - } -} - -pub enum ConcatReaderError -where - A: Read, - B: Read, -{ - First(A::Error), - Last(B::Error), -} - -impl core::fmt::Debug for ConcatReaderError -where - A: Read, - B: Read, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::First(arg0) => f.debug_tuple("First").field(arg0).finish(), - Self::Last(arg0) => f.debug_tuple("Last").field(arg0).finish(), - } - } -} - -impl embedded_io::Error for ConcatReaderError -where - A: Read, - B: Read, -{ - fn kind(&self) -> ErrorKind { - match self { - ConcatReaderError::First(a) => a.kind(), - ConcatReaderError::Last(b) => b.kind(), - } - } -} - -impl ErrorType for ConcatReader -where - A: Read, - B: Read, -{ - type Error = ConcatReaderError; -} - -impl Read for ConcatReader -where - A: Read, - B: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - if !self.first_exhausted { - let len = self.first.read(buf).await.map_err(ConcatReaderError::First)?; - if len > 0 { - return Ok(len); - } - - self.first_exhausted = true; - } - - self.last.read(buf).await.map_err(ConcatReaderError::Last) - } -} diff --git a/src/lib.rs b/src/lib.rs index cf47a6d..2a13cce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,8 @@ use embedded_io_async::ReadExactError; mod fmt; pub mod client; -mod concat; pub mod headers; +mod reader; pub mod request; pub mod response; diff --git a/src/reader.rs b/src/reader.rs new file mode 100644 index 0000000..62cfb83 --- /dev/null +++ b/src/reader.rs @@ -0,0 +1,118 @@ +use embedded_io::{Error, ErrorKind, ErrorType}; +use embedded_io_async::{BufRead, Read, Write}; + +use crate::client::HttpConnection; + +struct ReadBuffer<'buf> { + buffer: &'buf mut [u8], + loaded: usize, +} + +impl<'buf> ReadBuffer<'buf> { + fn new(buffer: &'buf mut [u8], loaded: usize) -> Self { + Self { buffer, loaded } + } +} + +impl ReadBuffer<'_> { + fn is_empty(&self) -> bool { + self.loaded == 0 + } + + fn read(&mut self, buf: &mut [u8]) -> Result { + let amt = self.loaded.min(buf.len()); + buf[..amt].copy_from_slice(&self.buffer[0..amt]); + + self.consume(amt); + + Ok(amt) + } + + fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> { + Ok(&self.buffer[..self.loaded]) + } + + fn consume(&mut self, amt: usize) -> usize { + let to_consume = amt.min(self.loaded); + + self.buffer.copy_within(to_consume..self.loaded, 0); + self.loaded -= to_consume; + + amt - to_consume + } +} + +pub struct BufferingReader<'buf, B> +where + B: Read, +{ + buffer: ReadBuffer<'buf>, + stream: &'buf mut B, +} + +impl<'buf, 'conn, B> BufferingReader<'buf, B> +where + B: Read, +{ + pub fn new(buffer: &'buf mut [u8], loaded: usize, stream: &'buf mut B) -> Self { + Self { + buffer: ReadBuffer::new(buffer, loaded), + stream, + } + } +} + +impl ErrorType for BufferingReader<'_, C> +where + C: Read, +{ + type Error = ErrorKind; +} + +impl Read for BufferingReader<'_, C> +where + C: Read, +{ + async fn read(&mut self, buf: &mut [u8]) -> Result { + if !self.buffer.is_empty() { + let amt = self.buffer.read(buf)?; + return Ok(amt); + } + + self.stream.read(buf).await.map_err(|e| e.kind()) + } +} + +impl BufRead for BufferingReader<'_, HttpConnection<'_, C>> +where + C: Read + Write, +{ + async fn fill_buf(&mut self) -> Result<&[u8], ErrorKind> { + // We need to consume the loaded bytes before we read mode. + if self.buffer.is_empty() { + // embedded-tls has its own internal buffer, let's prefer that if we can + #[cfg(feature = "embedded-tls")] + if let HttpConnection::Tls(ref mut tls) = self.stream { + return tls.fill_buf().await.map_err(|e| e.kind()); + } + + self.buffer.loaded = self.stream.read(&mut self.buffer.buffer).await?; + } + + self.buffer.fill_buf() + } + + fn consume(&mut self, amt: usize) { + // It's possible that the user requested more bytes to be consumed than loaded. Especially + // since it's also possible that nothing is loaded, after we consumed all and are using + // embedded-tls's buffering. + let unconsumed = self.buffer.consume(amt); + + if unconsumed > 0 { + #[cfg(feature = "embedded-tls")] + if let HttpConnection::Tls(tls) = &mut self.stream { + tls.consume(unconsumed); + } + } + } +} diff --git a/src/response.rs b/src/response.rs index 8b10017..3132865 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,20 +1,20 @@ use embedded_io::{Error as _, ErrorType}; -use embedded_io_async::Read; +use embedded_io_async::{BufRead, Read}; use heapless::Vec; -use crate::concat::ConcatReader; use crate::headers::{ContentType, KeepAlive, TransferEncoding}; +use crate::reader::BufferingReader; use crate::request::Method; use crate::Error; /// Type representing a parsed HTTP response. #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct Response<'buf, 'conn, C> +pub struct Response<'buf, C> where C: Read, { - conn: &'conn mut C, + conn: &'buf mut C, /// The method used to create the response. method: Method, /// The HTTP response status code. @@ -32,16 +32,12 @@ where raw_body_read: usize, } -impl<'buf, 'conn, C> Response<'buf, 'conn, C> +impl<'buf, C> Response<'buf, C> where C: Read, { // Read at least the headers from the connection. - pub async fn read( - conn: &'conn mut C, - method: Method, - header_buf: &'buf mut [u8], - ) -> Result, Error> { + pub async fn read(conn: &'buf mut C, method: Method, header_buf: &'buf mut [u8]) -> Result { let mut header_len = 0; let mut pos = 0; while pos < header_buf.len() { @@ -139,7 +135,7 @@ where } /// Get the response body - pub fn body(self) -> ResponseBody<'buf, 'conn, C> { + pub fn body(self) -> ResponseBody<'buf, C> { let reader_hint = if self.method == Method::HEAD { // Head requests does not have a body so we return an empty reader ReaderHint::Empty @@ -151,18 +147,14 @@ where ReaderHint::ToEnd }; - // Move the body part of the bytes in the header buffer to the beginning of the buffer - let header_buf = self.header_buf; - for i in 0..self.raw_body_read { - header_buf[i] = header_buf[self.header_len + i]; - } - // From now on, the header buffer is now the body buffer as all header bytes have been overwritten - let body_buf = header_buf; + // Move the body part of the bytes in the header buffer to the beginning of the buffer. + self.header_buf + .copy_within(self.header_len..self.header_len + self.raw_body_read, 0); ResponseBody { conn: self.conn, reader_hint, - body_buf, + body_buf: self.header_buf, raw_body_read: self.raw_body_read, } } @@ -187,11 +179,11 @@ impl<'a> Iterator for HeaderIterator<'a> { /// This type contains the original header buffer provided to `read_headers`, /// now renamed to `body_buf`, the number of read body bytes that are available /// in `body_buf`, and a reader to be used for reading the remaining body. -pub struct ResponseBody<'buf, 'conn, C> +pub struct ResponseBody<'buf, C> where C: Read, { - conn: &'conn mut C, + conn: &'buf mut C, reader_hint: ReaderHint, /// The number of raw bytes read from the body and available in the beginning of `body_buf`. raw_body_read: usize, @@ -206,12 +198,12 @@ enum ReaderHint { ToEnd, // https://www.rfc-editor.org/rfc/rfc7230#section-3.3.3 pt. 7: Until end of connection } -impl<'buf, 'conn, C> ResponseBody<'buf, 'conn, C> +impl<'buf, C> ResponseBody<'buf, C> where C: Read, { - pub fn reader(self) -> BodyReader> { - let raw_body = ConcatReader::new(&self.body_buf[..self.raw_body_read], self.conn); + pub fn reader(self) -> BodyReader> { + let raw_body = BufferingReader::new(self.body_buf, self.raw_body_read, self.conn); match self.reader_hint { ReaderHint::Empty => BodyReader::Empty, @@ -221,21 +213,23 @@ where }), ReaderHint::Chunked => BodyReader::Chunked(ChunkedBodyReader { raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + chunk_remaining: ChunkState::NoChunk, }), ReaderHint::ToEnd => BodyReader::ToEnd(raw_body), } } } -impl<'buf, 'conn, C: Read> ResponseBody<'buf, 'conn, C> { +impl<'buf, C> ResponseBody<'buf, C> +where + C: Read, +{ /// Read the entire body into the buffer originally provided [`Response::read()`]. /// This requires that this original buffer is large enough to contain the entire body. /// /// This is not valid for chunked responses as it requires that the body bytes over-read /// while parsing the http response header would be available for the body reader. - /// For this case, of if the original buffer is not large enough, use + /// For this case, or if the original buffer is not large enough, use /// [`BodyReader::read_to_end()`] instead from the reader returned by [`ResponseBody::reader()`]. pub async fn read_to_end(self) -> Result<&'buf mut [u8], Error> { // We can only read responses with Content-Length header to end using the body_buf buffer, @@ -279,10 +273,7 @@ impl<'buf, 'conn, C: Read> ResponseBody<'buf, 'conn, C> { } /// A body reader -pub enum BodyReader -where - B: Read, -{ +pub enum BodyReader { Empty, FixedLength(FixedLengthBodyReader), Chunked(ChunkedBodyReader), @@ -306,8 +297,13 @@ where let is_done = match self { BodyReader::Empty => true, - BodyReader::FixedLength(reader) => reader.remaining == 0, - BodyReader::Chunked(reader) => reader.empty_chunk_received, + BodyReader::FixedLength(reader) => { + if reader.remaining > 0 { + warn!("FixedLength: {} bytes remained", reader.remaining); + } + reader.remaining == 0 + } + BodyReader::Chunked(reader) => reader.chunk_remaining == ChunkState::Empty, BodyReader::ToEnd(_) => true, }; @@ -320,23 +316,20 @@ where async fn discard(&mut self) -> Result { let mut body_len = 0; + let mut buf = [0; 128]; loop { - let mut trash = [0; 256]; - let len = self.read(&mut trash).await?; - if len == 0 { + let buf = self.read(&mut buf).await?; + if buf == 0 { break; } - body_len += len; + body_len += buf; } Ok(body_len) } } -impl ErrorType for BodyReader -where - B: Read, -{ +impl ErrorType for BodyReader { type Error = Error; } @@ -354,43 +347,170 @@ where } } +impl BufRead for BodyReader +where + B: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + match self { + BodyReader::Empty => Ok(&[]), + BodyReader::FixedLength(reader) => reader.fill_buf().await, + BodyReader::Chunked(reader) => reader.fill_buf().await, + BodyReader::ToEnd(conn) => conn.fill_buf().await.map_err(|e| Error::Network(e.kind())), + } + } + + fn consume(&mut self, amt: usize) { + match self { + BodyReader::Empty => {} + BodyReader::FixedLength(reader) => reader.consume(amt), + BodyReader::Chunked(reader) => reader.consume(amt), + BodyReader::ToEnd(conn) => conn.consume(amt), + } + } +} + /// Fixed length response body reader -pub struct FixedLengthBodyReader { +pub struct FixedLengthBodyReader { raw_body: B, remaining: usize, } -impl ErrorType for FixedLengthBodyReader { +impl ErrorType for FixedLengthBodyReader { type Error = Error; } -impl Read for FixedLengthBodyReader { +impl Read for FixedLengthBodyReader +where + C: Read, +{ async fn read(&mut self, buf: &mut [u8]) -> Result { if self.remaining == 0 { return Ok(0); } - let to_read = usize::min(self.remaining, buf.len()); - let len = self.raw_body.read(&mut buf[..to_read]).await.map_err(|e| e.kind())?; - if len > 0 { - self.remaining -= len; - Ok(len) + + let read = self.raw_body.read(buf).await.map_err(|e| Error::Network(e.kind()))?; + self.remaining -= read; + + Ok(read) + } +} + +impl BufRead for FixedLengthBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + if self.remaining == 0 { + return Ok(&[]); + } + + let loaded = self + .raw_body + .fill_buf() + .await + .map_err(|e| Error::Network(e.kind())) + .map(|data| &data[..data.len().min(self.remaining)])?; + + if loaded.is_empty() { + return Err(Error::ConnectionClosed); + } + + Ok(loaded) + } + + fn consume(&mut self, amt: usize) { + let amt = amt.min(self.remaining); + self.remaining -= amt; + self.raw_body.consume(amt) + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum ChunkState { + NoChunk, + NotEmpty(u32), + Empty, +} + +impl ChunkState { + fn consume(&mut self, amt: usize) -> usize { + if let ChunkState::NotEmpty(remaining) = self { + let consumed = (amt as u32).min(*remaining); + *remaining -= consumed; + consumed as usize } else { - Err(Error::ConnectionClosed) + 0 + } + } + + fn len(self) -> usize { + if let ChunkState::NotEmpty(len) = self { + len as usize + } else { + 0 } } } /// Chunked response body reader -pub struct ChunkedBodyReader -where - B: Read, -{ +pub struct ChunkedBodyReader { raw_body: B, - chunk_remaining: u32, - empty_chunk_received: bool, + chunk_remaining: ChunkState, } -impl ChunkedBodyReader { +impl ChunkedBodyReader +where + C: Read, +{ + async fn read_next_chunk_length(&mut self) -> Result<(), Error> { + let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n + let mut total_read = 0; + + 'read_size: loop { + let mut byte = 0; + self.raw_body + .read_exact(core::slice::from_mut(&mut byte)) + .await + .map_err(|e| Error::from(e).kind())?; + + if byte != b'\n' { + header_buf[total_read] = byte; + total_read += 1; + + if total_read == header_buf.len() { + return Err(Error::Codec); + } + } else { + break 'read_size; + } + } + + if header_buf[total_read - 1] != b'\r' { + return Err(Error::Codec); + } + + let hex_digits = total_read - 1; + + // Prepend hex with zeros + let mut hex = [b'0'; 8]; + hex[8 - hex_digits..].copy_from_slice(&header_buf[..hex_digits]); + + let mut bytes = [0; 4]; + hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; + + let chunk_length = u32::from_be_bytes(bytes); + + debug!("Chunk length: {}", chunk_length); + + self.chunk_remaining = match chunk_length { + 0 => ChunkState::Empty, + other => ChunkState::NotEmpty(other), + }; + + Ok(()) + } + async fn read_chunk_end(&mut self) -> Result<(), Error> { // All chunks are terminated with a \r\n let mut newline_buf = [0; 2]; @@ -401,108 +521,68 @@ impl ChunkedBodyReader { } Ok(()) } + + /// Handles chunk boundary and returns the number of bytes in the current (or new) chunk. + async fn handle_chunk_boundary(&mut self) -> Result { + match self.chunk_remaining { + ChunkState::NoChunk => self.read_next_chunk_length().await?, + + ChunkState::NotEmpty(0) => { + // The current chunk is currently empty, advance into a new chunk... + self.read_chunk_end().await?; + self.read_next_chunk_length().await?; + } + + ChunkState::NotEmpty(_) => {} + + ChunkState::Empty => return Ok(0), + } + + Ok(self.chunk_remaining.len()) + } } -impl ErrorType for ChunkedBodyReader { +impl ErrorType for ChunkedBodyReader { type Error = Error; } -impl Read for ChunkedBodyReader { +impl Read for ChunkedBodyReader +where + C: Read, +{ async fn read(&mut self, buf: &mut [u8]) -> Result { - if buf.is_empty() || self.empty_chunk_received { - return Ok(0); - } + let remaining = self.handle_chunk_boundary().await?; + let max_len = buf.len().min(remaining); - if self.chunk_remaining == 0 { - // The current chunk is currently empty, advance into a new chunk... - - let mut header_buf = [0; 8 + 2]; // 32 bit hex + \r + \n - let mut total_read = 0; - - // For now, limit the number of bytes that we can read to avoid reading into a header after the current - let mut max_read = 3; // Single hex digit + \r + \n - loop { - let read = self - .raw_body - .read(&mut header_buf[total_read..max_read]) - .await - .map_err(|e| e.kind())?; - if read == 0 { - return Err(Error::ConnectionClosed); - } - total_read += read; - - // Decode the chunked header - let header_and_body = &header_buf[..total_read]; - if let Some(nl) = header_and_body.iter().position(|x| *x == b'\n') { - let header = &header_and_body[..nl + 1]; - if nl == 0 || header[nl - 1] != b'\r' { - return Err(Error::Codec); - } - let hex_digits = nl - 1; - // Prepend hex with zeros - let mut hex = [b'0'; 8]; - hex[8 - hex_digits..].copy_from_slice(&header[..hex_digits]); - let mut bytes = [0; 4]; - hex::decode_to_slice(hex, &mut bytes).map_err(|_| Error::Codec)?; - self.chunk_remaining = u32::from_be_bytes(bytes); - - if self.chunk_remaining == 0 { - self.empty_chunk_received = true; - } + let len = self + .raw_body + .read(&mut buf[..max_len]) + .await + .map_err(|e| Error::Network(e.kind()))?; - // Return the excess body bytes read during the header, if any - let excess_body_read = header_and_body.len() - header.len(); - if excess_body_read > 0 { - if excess_body_read > self.chunk_remaining as usize { - // We have read chunk bytes that exceed the size of the chunk - return Err(Error::Codec); - } - - buf[..excess_body_read].copy_from_slice(&header_and_body[header.len()..]); - self.chunk_remaining -= excess_body_read as u32; - return Ok(excess_body_read); - } + self.chunk_remaining.consume(len); - break; - } + Ok(len) + } +} - if total_read >= 3 { - // At least three bytes were read and a \n was not found - // This means that the chunk length is at least double-digit hex - // which in turn means that it is impossible for another header to - // be present within the 10 bytes header buffer. - // 10 is the length of the max header "ffffffff\r\n". - // For example, 10\r\nXXXXXXYYYYYYYYYY is more than 10 bytes - // - 10\r\n is the header - // - XXXXXX are the excess body 6 bytes that we may read - // - YYYYYYYYYY are the remaining unread chunk bytes. - // However, for reading these excess bytes into the actual chunk payload, - // the user buffer must be large enough to actually contain the excess read bytes. - // A \n was not found, and we can read that + buf.len(). - max_read = core::cmp::min(total_read + 1 + buf.len(), 10); - } - } - } +impl BufRead for ChunkedBodyReader +where + C: BufRead + Read, +{ + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + let remaining = self.handle_chunk_boundary().await?; - if self.empty_chunk_received { - self.read_chunk_end().await?; - Ok(0) - } else { - let max_len = usize::min(self.chunk_remaining as usize, buf.len()); - let len = self.raw_body.read(&mut buf[..max_len]).await.map_err(|e| e.kind())?; - if len == 0 { - return Err(Error::ConnectionClosed); - } + let buf = self.raw_body.fill_buf().await.map_err(|e| Error::Network(e.kind()))?; - self.chunk_remaining -= len as u32; + let len = buf.len().min(remaining); - if self.chunk_remaining == 0 { - self.read_chunk_end().await?; - } + Ok(&buf[..len]) + } - Ok(len) - } + fn consume(&mut self, amt: usize) { + let consumed = self.chunk_remaining.consume(amt); + self.raw_body.consume(consumed); } } @@ -601,7 +681,13 @@ impl From for Status { #[cfg(test)] mod tests { - use super::*; + use embedded_io_async::Read; + + use crate::{ + reader::BufferingReader, + request::Method, + response::{ChunkState, ChunkedBodyReader, Response}, + }; #[tokio::test] async fn can_read_with_content_length_with_same_buffer() { @@ -708,11 +794,11 @@ mod tests { #[tokio::test] async fn chunked_body_reader_can_read_with_large_buffer() { - let raw_body = "1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_bytes(); + let mut raw_body = b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_slice(); + let mut read_buffer = [0; 128]; let mut reader = ChunkedBodyReader { - raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), + chunk_remaining: ChunkState::NoChunk, }; let mut body = [0; 17]; @@ -725,14 +811,14 @@ mod tests { #[tokio::test] async fn chunked_body_reader_can_read_with_tiny_buffer() { - let raw_body = "1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_bytes(); + let mut raw_body = b"1\r\nX\r\n10\r\nYYYYYYYYYYYYYYYY\r\n0\r\n\r\n".as_slice(); + let mut read_buffer = [0; 128]; let mut reader = ChunkedBodyReader { - raw_body, - chunk_remaining: 0, - empty_chunk_received: false, + raw_body: BufferingReader::new(&mut read_buffer, 0, &mut raw_body), + chunk_remaining: ChunkState::NoChunk, }; - let mut body = Vec::::new(); + let mut body = heapless::Vec::::new(); for _ in 0..17 { let mut buf = [0; 1]; assert_eq!(1, reader.read(&mut buf).await.unwrap()); diff --git a/tests/client.rs b/tests/client.rs index d284899..77cf375 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,6 +1,7 @@ #![feature(async_fn_in_trait)] #![allow(incomplete_features)] use embedded_io_adapters::tokio_1::FromTokio; +use embedded_io_async::BufRead; use embedded_nal_async::{AddrType, IpAddr, Ipv4Addr}; use hyper::server::conn::Http; use hyper::service::{make_service_fn, service_fn}; @@ -52,15 +53,17 @@ async fn test_request_response_notls() { let url = format!("http://127.0.0.1:{}", addr.port()); let mut client = HttpClient::new(&TCP, &LOOPBACK_DNS); let mut rx_buf = [0; 4096]; - let mut request = client - .request(Method::POST, &url) - .await - .unwrap() - .body(b"PING".as_slice()) - .content_type(ContentType::TextPlain); - let response = request.send(&mut rx_buf).await.unwrap(); - let body = response.body().read_to_end().await; - assert_eq!(body.unwrap(), b"PING"); + for _ in 0..2 { + let mut request = client + .request(Method::POST, &url) + .await + .unwrap() + .body(b"PING".as_slice()) + .content_type(ContentType::TextPlain); + let response = request.send(&mut rx_buf).await.unwrap(); + let body = response.body().read_to_end().await; + assert_eq!(body.unwrap(), b"PING"); + } tx.send(()).unwrap(); t.await.unwrap(); @@ -104,6 +107,56 @@ async fn test_resource_notls() { t.await.unwrap(); } +#[tokio::test] +async fn test_resource_notls_bufread() { + setup(); + let addr = ([127, 0, 0, 1], 0).into(); + + let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) }); + + let server = Server::bind(&addr).serve(service); + let addr = server.local_addr(); + + let (tx, rx) = oneshot::channel(); + let t = tokio::spawn(async move { + tokio::select! { + _ = server => {} + _ = rx => {} + } + }); + + let url = format!("http://127.0.0.1:{}", addr.port()); + let mut client = HttpClient::new(&TCP, &LOOPBACK_DNS); + let mut rx_buf = [0; 4096]; + let mut resource = client.resource(&url).await.unwrap(); + for _ in 0..2 { + let response = resource + .post("/") + .body(b"PING".as_slice()) + .content_type(ContentType::TextPlain) + .send(&mut rx_buf) + .await + .unwrap(); + let mut body_reader = response.body().reader(); + + let mut body = vec![]; + loop { + let buf = body_reader.fill_buf().await.unwrap(); + if buf.is_empty() { + break; + } + body.extend_from_slice(buf); + let buf_len = buf.len(); + body_reader.consume(buf_len); + } + + assert_eq!(body, b"PING"); + } + + tx.send(()).unwrap(); + t.await.unwrap(); +} + #[tokio::test] #[cfg(feature = "embedded-tls")] async fn test_resource_rustls() { @@ -207,6 +260,43 @@ async fn test_resource_drogue_cloud_sandbox() { } } +#[tokio::test] +async fn test_request_response_notls_buffered() { + setup(); + let addr = ([127, 0, 0, 1], 0).into(); + + let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) }); + + let server = Server::bind(&addr).serve(service); + let addr = server.local_addr(); + + let (tx, rx) = oneshot::channel(); + let t = tokio::spawn(async move { + tokio::select! { + _ = server => {} + _ = rx => {} + } + }); + + let url = format!("http://127.0.0.1:{}", addr.port()); + let mut client = HttpClient::new(&TCP, &LOOPBACK_DNS); + let mut tx_buf = [0; 4096]; + let mut rx_buf = [0; 4096]; + let mut request = client + .request(Method::POST, &url) + .await + .unwrap() + .into_buffered(&mut tx_buf) + .body(b"PING".as_slice()) + .content_type(ContentType::TextPlain); + let response = request.send(&mut rx_buf).await.unwrap(); + let body = response.body().read_to_end().await; + assert_eq!(body.unwrap(), b"PING"); + + tx.send(()).unwrap(); + t.await.unwrap(); +} + fn load_certs(filename: &std::path::PathBuf) -> Vec { let certfile = std::fs::File::open(filename).expect("cannot open certificate file"); let mut reader = std::io::BufReader::new(certfile); diff --git a/tests/request.rs b/tests/request.rs index 5c43e21..d9a504e 100644 --- a/tests/request.rs +++ b/tests/request.rs @@ -1,6 +1,7 @@ use embedded_io_adapters::tokio_1::FromTokio; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Server}; +use reqwless::client::HttpConnection; use reqwless::request::{Method, RequestBuilder}; use reqwless::{headers::ContentType, request::Request, response::Response}; use std::str::from_utf8; @@ -35,7 +36,7 @@ async fn test_request_response() { }); let stream = TcpStream::connect(addr).await.unwrap(); - let mut stream = FromTokio::new(stream); + let mut stream = HttpConnection::Plain(FromTokio::new(stream)); let request = Request::post("/") .body(b"PING".as_slice())