Skip to content

Commit

Permalink
Add test case and fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Oct 20, 2023
1 parent 28637ef commit 0c9a66f
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 25 deletions.
76 changes: 51 additions & 25 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ where
&'m mut self,
method: Method,
url: &'m str,
) -> Result<HttpRequestHandle<'m, HttpConnection<'m, T::Connection<'m>>, ()>, Error> {
) -> Result<HttpRequestHandle<'m, T::Connection<'m>, ()>, Error> {
let url = Url::parse(url)?;
let conn = self.connect(&url).await?;
Ok(HttpRequestHandle {
Expand All @@ -150,7 +150,7 @@ where
pub async fn resource<'res>(
&'res mut self,
resource_url: &'res str,
) -> Result<HttpResource<'res, HttpConnection<'res, T::Connection<'res>>>, Error> {
) -> Result<HttpResource<'res, T::Connection<'res>>, Error> {
let resource_url = Url::parse(resource_url)?;
let conn = self.connect(&resource_url).await?;
Ok(HttpResource {
Expand Down Expand Up @@ -180,6 +180,23 @@ impl<'conn, T> HttpConnection<'conn, T>
where
T: Read + Write,
{
/// Turn the request into a buffered request.
///
/// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse
/// its buffer for non-TLS connections.
pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpConnection<'buf, T>
where
'conn: 'buf,
{
match self {
HttpConnection::Plain(conn) => {
HttpConnection::PlainBuffered(BufferedWrite::new(buffered_io_adapter::ConnErrorAdapter(conn), tx_buf))
}
HttpConnection::PlainBuffered(conn) => HttpConnection::PlainBuffered(conn),
HttpConnection::Tls(tls) => HttpConnection::Tls(tls),
}
}

/// Send a request on an established connection.
///
/// The request is sent in its raw form without any base path from the resource.
Expand Down Expand Up @@ -257,7 +274,7 @@ where
C: Read + Write,
B: RequestBody,
{
pub conn: C,
pub conn: HttpConnection<'m, C>,
request: Option<DefaultRequestBuilder<'m, B>>,
}

Expand All @@ -270,12 +287,12 @@ where
///
/// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse
/// its buffer for non-TLS connections.
pub fn into_buffered<'buf>(
self,
tx_buf: &'buf mut [u8],
) -> HttpRequestHandle<'m, BufferedWrite<'buf, buffered_io_adapter::ConnErrorAdapter<C>>, B> {
pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpRequestHandle<'buf, C, B>
where
'm: 'buf,
{
HttpRequestHandle {
conn: BufferedWrite::new(buffered_io_adapter::ConnErrorAdapter(self.conn), tx_buf),
conn: self.conn.into_buffered(tx_buf),
request: self.request,
}
}
Expand All @@ -285,7 +302,13 @@ where
/// The response headers are stored in the provided rx_buf, which should be sized to contain at least the response headers.
///
/// The response is returned.
pub async fn send<'buf, 'conn>(&'conn mut self, rx_buf: &'buf mut [u8]) -> Result<Response<'buf, 'conn, C>, Error> {
pub async fn send<'buf, 'conn>(
&'conn mut self,
rx_buf: &'buf mut [u8],
) -> Result<Response<'buf, 'conn, HttpConnection<'conn, C>>, Error>
where
'conn: 'm,
{
let request = self.request.take().ok_or(Error::AlreadySent)?.build();
request.write(&mut self.conn).await?;
Response::read(&mut self.conn, request.method, rx_buf).await
Expand Down Expand Up @@ -343,7 +366,7 @@ pub struct HttpResource<'res, C>
where
C: Read + Write,
{
pub conn: C,
pub conn: HttpConnection<'res, C>,
pub host: &'res str,
pub base_path: &'res str,
}
Expand All @@ -356,12 +379,12 @@ where
///
/// This is only relevant if no TLS is used, as `embedded-tls` buffers internally and we reuse
/// its buffer for non-TLS connections.
pub fn into_buffered<'buf>(
self,
tx_buf: &'buf mut [u8],
) -> HttpResource<'res, BufferedWrite<'buf, buffered_io_adapter::ConnErrorAdapter<C>>> {
pub fn into_buffered<'buf>(self, tx_buf: &'buf mut [u8]) -> HttpResource<'buf, C>
where
'res: 'buf,
{
HttpResource {
conn: BufferedWrite::new(buffered_io_adapter::ConnErrorAdapter(self.conn), tx_buf),
conn: self.conn.into_buffered(tx_buf),
host: self.host,
base_path: self.base_path,
}
Expand Down Expand Up @@ -432,7 +455,7 @@ where
&'conn mut self,
mut request: Request<'res, B>,
rx_buf: &'buf mut [u8],
) -> Result<Response<'buf, 'conn, C>, Error> {
) -> Result<Response<'buf, 'conn, HttpConnection<'res, C>>, Error> {
request.base_path = Some(self.base_path);
request.write(&mut self.conn).await?;
Response::read(&mut self.conn, request.method, rx_buf).await
Expand All @@ -444,7 +467,7 @@ where
C: Read + Write,
B: RequestBody,
{
conn: &'conn mut C,
conn: &'conn mut HttpConnection<'res, C>,
base_path: &'res str,
request: DefaultRequestBuilder<'m, B>,
}
Expand All @@ -460,7 +483,10 @@ where
/// The response headers are stored in the provided rx_buf, which should be sized to contain at least the response headers.
///
/// The response is returned.
pub async fn send<'buf>(self, rx_buf: &'buf mut [u8]) -> Result<Response<'buf, 'conn, C>, Error> {
pub async fn send<'buf>(
self,
rx_buf: &'buf mut [u8],
) -> Result<Response<'buf, 'conn, HttpConnection<'res, C>>, Error> {
let conn = self.conn;
let mut request = self.request.build();
request.base_path = Some(self.base_path);
Expand Down Expand Up @@ -515,7 +541,7 @@ where
}

mod buffered_io_adapter {
use embedded_io::{Error as _, ErrorType, ReadExactError};
use embedded_io::{Error as _, ErrorKind, ErrorType, ReadExactError};
use embedded_io_async::{Read, Write};

pub struct Error(embedded_io::ErrorKind);
Expand All @@ -535,23 +561,23 @@ mod buffered_io_adapter {
pub struct ConnErrorAdapter<C>(pub C);

impl<C> ErrorType for ConnErrorAdapter<C> {
type Error = Error;
type Error = ErrorKind;
}

impl<C> Write for ConnErrorAdapter<C>
where
C: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
self.0.write(buf).await.map_err(|e| Error(e.kind()))
self.0.write(buf).await.map_err(|e| e.kind())
}

async fn flush(&mut self) -> Result<(), Self::Error> {
self.0.flush().await.map_err(|e| Error(e.kind()))
self.0.flush().await.map_err(|e| e.kind())
}

async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
self.0.write_all(buf).await.map_err(|e| Error(e.kind()))
self.0.write_all(buf).await.map_err(|e| e.kind())
}
}

Expand All @@ -560,13 +586,13 @@ mod buffered_io_adapter {
C: Read,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
self.0.read(buf).await.map_err(|e| Error(e.kind()))
self.0.read(buf).await.map_err(|e| e.kind())
}

async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ReadExactError<Self::Error>> {
self.0.read_exact(buf).await.map_err(|e| match e {
ReadExactError::UnexpectedEof => ReadExactError::UnexpectedEof,
ReadExactError::Other(e) => ReadExactError::Other(Error(e.kind())),
ReadExactError::Other(e) => ReadExactError::Other(e.kind()),
})
}
}
Expand Down
37 changes: 37 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,43 @@ async fn test_resource_drogue_cloud_sandbox() {
}
}

#[tokio::test]
async fn test_request_response_notls_buffered() {
setup();
let addr = ([127, 0, 0, 1], 0).into();

let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) });

let server = Server::bind(&addr).serve(service);
let addr = server.local_addr();

let (tx, rx) = oneshot::channel();
let t = tokio::spawn(async move {
tokio::select! {
_ = server => {}
_ = rx => {}
}
});

let url = format!("http://127.0.0.1:{}", addr.port());
let mut client = HttpClient::new(&TCP, &LOOPBACK_DNS);
let mut tx_buf = [0; 4096];
let mut rx_buf = [0; 4096];
let mut request = client
.request(Method::POST, &url)
.await
.unwrap()
.into_buffered(&mut tx_buf)
.body(b"PING".as_slice())
.content_type(ContentType::TextPlain);
let response = request.send(&mut rx_buf).await.unwrap();
let body = response.body().read_to_end().await;
assert_eq!(body.unwrap(), b"PING");

tx.send(()).unwrap();
t.await.unwrap();
}

fn load_certs(filename: &std::path::PathBuf) -> Vec<rustls::Certificate> {
let certfile = std::fs::File::open(filename).expect("cannot open certificate file");
let mut reader = std::io::BufReader::new(certfile);
Expand Down

0 comments on commit 0c9a66f

Please sign in to comment.