diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 719083d11f4..02131a2daf0 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -40,7 +40,7 @@ pub use self::connect_info::ConnectInfo; #[doc(no_inline)] #[cfg(feature = "json")] -pub use crate::Json; +pub use crate::{json::JsonDeserializer, Json}; #[doc(no_inline)] pub use crate::Extension; diff --git a/axum/src/json.rs b/axum/src/json.rs index ebff242dd42..8107771615b 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -7,17 +7,18 @@ use http::{ header::{self, HeaderMap, HeaderValue}, StatusCode, }; -use serde::{de::DeserializeOwned, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::marker::PhantomData; /// JSON Extractor / Response. /// /// When used as an extractor, it can deserialize request bodies into some type that -/// implements [`serde::Deserialize`]. The request will be rejected (and a [`JsonRejection`] will +/// implements [`serde::de::DeserializeOwned`]. The request will be rejected (and a [`JsonRejection`] will /// be returned) if: /// /// - The request doesn't have a `Content-Type: application/json` (or similar) header. /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON but it couldn't be deserialized into the target +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target /// type. /// - Buffering the request body fails. 
/// @@ -135,6 +136,32 @@ fn json_content_type(headers: &HeaderMap) -> bool { is_json_content_type } +fn json_from_bytes<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> Result { + let deserializer = &mut serde_json::Deserializer::from_slice(bytes); + + match serde_path_to_error::deserialize(deserializer) { + Ok(value) => Ok(value), + Err(err) => { + let rejection = match err.inner().classify() { + serde_json::error::Category::Data => JsonDataError::from_err(err).into(), + serde_json::error::Category::Syntax | serde_json::error::Category::Eof => { + JsonSyntaxError::from_err(err).into() + } + serde_json::error::Category::Io => { + if cfg!(debug_assertions) { + // we don't use `serde_json::from_reader` and instead always buffer + // bodies first, so we shouldn't encounter any IO errors + unreachable!() + } else { + JsonSyntaxError::from_err(err).into() + } + } + }; + Err(rejection) + } + } +} + axum_core::__impl_deref!(Json); impl From for Json { @@ -151,30 +178,7 @@ where /// but special cases may require first extracting a `Request` into `Bytes` then optionally /// constructing a `Json`. 
pub fn from_bytes(bytes: &[u8]) -> Result { - let deserializer = &mut serde_json::Deserializer::from_slice(bytes); - - let value = match serde_path_to_error::deserialize(deserializer) { - Ok(value) => value, - Err(err) => { - let rejection = match err.inner().classify() { - serde_json::error::Category::Data => JsonDataError::from_err(err).into(), - serde_json::error::Category::Syntax | serde_json::error::Category::Eof => { - JsonSyntaxError::from_err(err).into() - } - serde_json::error::Category::Io => { - if cfg!(debug_assertions) { - // we don't use `serde_json::from_reader` and instead always buffer - // bodies first, so we shouldn't encounter any IO errors - unreachable!() - } else { - JsonSyntaxError::from_err(err).into() - } - } - }; - return Err(rejection); - } - }; - + let value = json_from_bytes(bytes)?; Ok(Json(value)) } } @@ -209,12 +213,119 @@ where } } +/// JSON Extractor for zero-copy deserialization. +/// +/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`]. +/// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called. +/// If the type implements [`serde::de::DeserializeOwned`], the [`Json`] extractor should +/// be preferred. +/// +/// The request will be rejected (and a [`JsonRejection`] will be returned) if: +/// +/// - The request doesn't have a `Content-Type: application/json` (or similar) header. +/// - Buffering the request body fails. +/// +/// Additionally, a `JsonRejection` error will be returned, when calling `deserialize` if: +/// +/// - The body doesn't contain syntactically valid JSON. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target +/// type. +/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`). +/// +/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the +/// input contains escaped characters. 
Use `Cow<'a, str>` or `Cow<'a, [u8]>`, with the +/// `#[serde(borrow)]` attribute, to allow serde to fall back to an owned type when encountering +/// escaped characters. +/// +/// ⚠️ Since parsing JSON requires consuming the request body, the `JsonDeserializer` extractor must be +/// *last* if there are multiple extractors in a handler. +/// See ["the order of extractors"](crate::extract#the-order-of-extractors). +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{ +/// extract, +/// routing::post, +/// Router, +/// response::{IntoResponse, Response} +/// }; +/// use serde::Deserialize; +/// use std::borrow::Cow; +/// use http::StatusCode; +/// +/// #[derive(Deserialize)] +/// struct Data<'a> { +/// #[serde(borrow)] +/// borrow_text: Cow<'a, str>, +/// #[serde(borrow)] +/// borrow_bytes: Cow<'a, [u8]>, +/// borrow_dangerous: &'a str, +/// not_borrowed: String, +/// } +/// +/// async fn upload(deserializer: extract::JsonDeserializer>) -> Response { +/// let data = match deserializer.deserialize() { +/// Ok(data) => data, +/// Err(e) => return e.into_response(), +/// }; +/// +/// // `data` is a `Data` with borrowed data from `deserializer`, +/// // which owns the request body (`Bytes`). 
+/// +/// StatusCode::OK.into_response() +/// } +/// +/// let app = Router::new().route("/upload", post(upload)); +/// # let _: Router = app; +/// ``` +#[derive(Debug, Clone, Default)] +#[cfg_attr(docsrs, doc(cfg(feature = "json")))] +pub struct JsonDeserializer { + bytes: Bytes, + _marker: PhantomData, +} + +#[async_trait] +impl FromRequest for JsonDeserializer +where + T: Deserialize<'static>, + S: Send + Sync, +{ + type Rejection = JsonRejection; + + async fn from_request(req: Request, state: &S) -> Result { + if json_content_type(req.headers()) { + let bytes = Bytes::from_request(req, state).await?; + Ok(Self { + bytes, + _marker: PhantomData, + }) + } else { + Err(MissingJsonContentType.into()) + } + } +} + +impl<'de, 'a: 'de, T> JsonDeserializer +where + T: Deserialize<'de>, +{ + /// Deserialize the request body into the target type. + /// See [`JsonDeserializer`] for more details. + pub fn deserialize(&'a self) -> Result { + let value = json_from_bytes(&self.bytes)?; + Ok(value) + } +} + #[cfg(test)] mod tests { use super::*; use crate::{routing::post, test_helpers::*, Router}; use serde::Deserialize; use serde_json::{json, Value}; + use std::borrow::Cow; #[crate::test] async fn deserialize_body() { @@ -232,6 +343,111 @@ mod tests { assert_eq!(body, "bar"); } + #[crate::test] + async fn deserializer_deserialize_body() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + #[serde(borrow)] + foo: Cow<'a, str>, + } + + async fn handler(deserializer: JsonDeserializer>) -> Response { + match deserializer.deserialize() { + Ok(input) => { + assert!(matches!(input.foo, Cow::Borrowed(_))); + input.foo.into_owned().into_response() + } + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + let res = client.post("/").json(&json!({ "foo": "bar" })).send().await; + let body = res.text().await; + + assert_eq!(body, "bar"); + } + + #[crate::test] + async fn 
deserializer_deserialize_body_escaped_to_cow() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + #[serde(borrow)] + foo: Cow<'a, str>, + } + + async fn handler(deserializer: JsonDeserializer>) -> Response { + match deserializer.deserialize() { + Ok(Input { foo }) => { + let Cow::Owned(foo) = foo else { + panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters") + }; + + foo.into_response() + } + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + + // The escaped characters prevent serde_json from borrowing. + let res = client + .post("/") + .json(&json!({ "foo": "\"bar\"" })) + .send() + .await; + + let body = res.text().await; + + assert_eq!(body, r#""bar""#); + } + + #[crate::test] + async fn deserializer_deserialize_body_escaped_to_str() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + // Explicit `#[serde(borrow)]` attribute is not required for `&str` or &[u8]. 
+ // See: https://serde.rs/lifetimes.html#borrowing-data-in-a-derived-impl + foo: &'a str, + } + + async fn route_fn(deserializer: JsonDeserializer>) -> Response { + match deserializer.deserialize() { + Ok(Input { foo }) => foo.to_owned().into_response(), + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(route_fn)); + + let client = TestClient::new(app); + + let res = client + .post("/") + .json(&json!({ "foo": "good" })) + .send() + .await; + let body = res.text().await; + assert_eq!(body, "good"); + + let res = client + .post("/") + .json(&json!({ "foo": "\"bad\"" })) + .send() + .await; + assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + let body_text = res.text().await; + assert_eq!( + body_text, + "Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16" + ); + } + #[crate::test] async fn consume_body_to_json_requires_json_content_type() { #[derive(Debug, Deserialize)]