Skip to content

Commit

Permalink
Auto-buffer HTTP connections when TLS is set up
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Oct 17, 2023
1 parent c24a8fc commit 15ad1ee
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ where
#[cfg(not(feature = "embedded-tls"))]
Err(Error::InvalidUrl(nourl::Error::UnsupportedScheme))
} else {
#[cfg(feature = "embedded-tls")]
match self.tls.as_mut() {
Some(tls) => Ok(HttpConnection::PlainBuffered(BufferedWrite::new(
buffered_io_adapter::ConnErrorAdapter(conn),
tls.write_buffer,
))),
None => Ok(HttpConnection::Plain(conn)),
}
#[cfg(not(feature = "embedded-tls"))]
Ok(HttpConnection::Plain(conn))
}
}
Expand Down Expand Up @@ -155,15 +164,17 @@ where

/// Represents a HTTP connection that may be encrypted or unencrypted.
#[allow(clippy::large_enum_variant)]
pub enum HttpConnection<'m, T>
pub enum HttpConnection<'m, C>
where
T: Read + Write,
C: Read + Write,
{
Plain(T),
Plain(C),
#[cfg(feature = "embedded-tls")]
PlainBuffered(BufferedWrite<'m, buffered_io_adapter::ConnErrorAdapter<C>>),
#[cfg(feature = "embedded-tls")]
Tls(embedded_tls::TlsConnection<'m, T, embedded_tls::Aes128GcmSha256>),
Tls(embedded_tls::TlsConnection<'m, C, embedded_tls::Aes128GcmSha256>),
#[cfg(not(feature = "embedded-tls"))]
Tls(&'m mut T), // Variant is never actually created, but we need it to avoid "unused lifetime" warning
Tls((&'m mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning
}

impl<'conn, T> HttpConnection<'conn, T>
Expand Down Expand Up @@ -200,7 +211,12 @@ where
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
match self {
Self::Plain(conn) => conn.read(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::PlainBuffered(conn) => conn.read(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::Tls(conn) => conn.read(buf).await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
_ => unreachable!(),
}
}
}
Expand All @@ -212,14 +228,24 @@ where
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
match self {
Self::Plain(conn) => conn.write(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::PlainBuffered(conn) => conn.write(buf).await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::Tls(conn) => conn.write(buf).await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
_ => unreachable!(),
}
}

async fn flush(&mut self) -> Result<(), Self::Error> {
match self {
Self::Plain(conn) => conn.flush().await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::PlainBuffered(conn) => conn.flush().await.map_err(|e| e.kind()),
#[cfg(feature = "embedded-tls")]
Self::Tls(conn) => conn.flush().await.map_err(|e| e.kind()),
#[cfg(not(feature = "embedded-tls"))]
_ => unreachable!(),
}
}
}
Expand All @@ -241,9 +267,10 @@ where
C: Read + Write,
B: RequestBody,
{
/// Turn the request into a buffered request
/// Turn the request into a buffered request.
///
/// This is most likely only relevant for non-tls endpoints, as `embedded-tls` buffers internally.
/// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse
/// its buffer for non-TLS connections.
pub fn into_buffered<'buf>(
self,
tx_buf: &'buf mut [u8],
Expand Down Expand Up @@ -328,7 +355,8 @@ where
{
/// Turn the resource into a buffered resource
///
/// This is most likely only relevant for non-tls endpoints, as `embedded-tls` buffers internally.
/// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse
/// its buffer for non-TLS connections.
pub fn into_buffered<'buf>(
self,
tx_buf: &'buf mut [u8],
Expand Down

0 comments on commit 15ad1ee

Please sign in to comment.