diff --git a/Cargo.toml b/Cargo.toml index 8e903ec..ad2bfd3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ defmt = { version = "0.3", optional = true } embedded-tls = { version = "0.17", default-features = false, optional = true } rand_chacha = { version = "0.3", default-features = false } nourl = "0.1.1" +esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = ["async"], optional = true } [dev-dependencies] hyper = { version = "0.14.23", features = ["full"] } @@ -50,3 +51,10 @@ defmt = [ "embedded-tls?/defmt", "nourl/defmt", ] + +# For esp32, re-export those features as required by Cargo. +# This will also automatically enable `esp-mbedtls`. +esp32 = ["esp-mbedtls/esp32"] +esp32c3 = ["esp-mbedtls/esp32c3"] +esp32s2 = ["esp-mbedtls/esp32s2"] +esp32s3 = ["esp-mbedtls/esp32s3"] diff --git a/src/client.rs b/src/client.rs index b479d15..86db2ce 100644 --- a/src/client.rs +++ b/src/client.rs @@ -20,10 +20,23 @@ where { client: &'a T, dns: &'a D, - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] tls: Option>, } +/// Type for TLS configuration of HTTP client. +#[cfg(feature = "esp-mbedtls")] +pub struct TlsConfig<'a> { + /// Minimum TLS version for the connection + version: crate::TlsVersion, + + /// Client certificates. See [esp_mbedtls::Certificates] + certificates: crate::Certificates<'a>, + + /// Will use hardware acceleration on the ESP32 if it contains the RSA peripheral. + rsa: Option<&'a mut esp_mbedtls::Rsa<'a>>, +} + /// Type for TLS configuration of HTTP client. #[cfg(feature = "embedded-tls")] pub struct TlsConfig<'a> { @@ -54,6 +67,21 @@ impl<'a> TlsConfig<'a> { } } +#[cfg(feature = "esp-mbedtls")] +impl<'a> TlsConfig<'a> { + pub fn new( + version: crate::TlsVersion, + certificates: crate::Certificates<'a>, + rsa: Option<&'a mut esp_mbedtls::Rsa<'a>>, + ) -> Self { + Self { + version, + certificates, + rsa, + } + } +} + impl<'a, T, D> HttpClient<'a, T, D> where T: TcpConnect + 'a, @@ -64,13 +92,13 @@ where Self { client, dns, - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] tls: None, } } /// Create a new HTTP client for a given connection handle and a target host. - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] pub fn new_with_tls(client: &'a T, dns: &'a D, tls: TlsConfig<'a>) -> Self { Self { client, @@ -99,6 +127,24 @@ where .map_err(|e| e.kind())?; if url.scheme() == UrlScheme::HTTPS { + #[cfg(feature = "esp-mbedtls")] + if let Some(tls) = self.tls.as_mut() { + let session = esp_mbedtls::asynch::Session::new( + conn, + host, + esp_mbedtls::Mode::Client, + tls.version, + tls.certificates, + // Create a inner Some(&mut Rsa) because Rsa doesn't implement Copy + tls.rsa.as_mut().map(|inner| inner as &mut esp_mbedtls::Rsa), + )? + .connect() + .await?; + Ok(HttpConnection::Tls(session)) + } else { + Ok(HttpConnection::Plain(conn)) + } + #[cfg(feature = "embedded-tls")] if let Some(tls) = self.tls.as_mut() { use embedded_tls::{TlsConfig, TlsContext}; @@ -118,7 +164,7 @@ where } else { Ok(HttpConnection::Plain(conn)) } - #[cfg(not(feature = "embedded-tls"))] + #[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))] Err(Error::InvalidUrl(nourl::Error::UnsupportedScheme)) } else { #[cfg(feature = "embedded-tls")] @@ -172,9 +218,11 @@ where { Plain(C), PlainBuffered(BufferedWrite<'conn, C>), + #[cfg(feature = "esp-mbedtls")] + Tls(esp_mbedtls::asynch::AsyncConnectedSession), #[cfg(feature = "embedded-tls")] Tls(embedded_tls::TlsConnection<'conn, C, embedded_tls::Aes128GcmSha256>), - #[cfg(not(feature = "embedded-tls"))] + #[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))] Tls((&'conn mut (), core::convert::Infallible)), // Variant is impossible to create, but we need it to avoid "unused lifetime" warning } @@ -255,9 +303,9 @@ where match self { Self::Plain(conn) => conn.read(buf).await.map_err(|e| e.kind()), Self::PlainBuffered(conn) => conn.read(buf).await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] Self::Tls(conn) => conn.read(buf).await.map_err(|e| e.kind()), - #[cfg(not(feature = "embedded-tls"))] + #[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))] _ => unreachable!(), } } @@ -271,9 +319,9 @@ where match self { Self::Plain(conn) => conn.write(buf).await.map_err(|e| e.kind()), Self::PlainBuffered(conn) => conn.write(buf).await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] Self::Tls(conn) => conn.write(buf).await.map_err(|e| e.kind()), - #[cfg(not(feature = "embedded-tls"))] + #[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))] _ => unreachable!(), } } @@ -282,9 +330,9 @@ where match self { Self::Plain(conn) => conn.flush().await.map_err(|e| e.kind()), Self::PlainBuffered(conn) => conn.flush().await.map_err(|e| e.kind()), - #[cfg(feature = "embedded-tls")] + #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] Self::Tls(conn) => conn.flush().await.map_err(|e| e.kind()), - #[cfg(not(feature = "embedded-tls"))] + #[cfg(not(any(feature = "embedded-tls", feature = "esp-mbedtls")))] _ => unreachable!(), } } diff --git a/src/lib.rs b/src/lib.rs index 74bdba9..e1f7740 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,9 @@ pub enum Error { /// Tls Error #[cfg(feature = "embedded-tls")] Tls(embedded_tls::TlsError), + /// Tls Error + #[cfg(feature = "esp-mbedtls")] + Tls(esp_mbedtls::TlsError), /// The provided buffer is too small BufferTooSmall, /// The request is already sent @@ -70,6 +73,17 @@ impl From for Error { } } +/// Re-export those members since they're used for [client::TlsConfig]. +#[cfg(feature = "esp-mbedtls")] +pub use esp_mbedtls::{Certificates, Rsa, TlsVersion, X509}; + +#[cfg(feature = "esp-mbedtls")] +impl From for Error { + fn from(e: esp_mbedtls::TlsError) -> Error { + Error::Tls(e) + } +} + impl From for Error { fn from(_: ParseIntError) -> Error { Error::Codec