Skip to content

Commit

Permalink
Add JsonDeserializer extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
future-highway committed Dec 16, 2023
1 parent 2b486ea commit 6eaa657
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 28 deletions.
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub use self::connect_info::ConnectInfo;

#[doc(no_inline)]
#[cfg(feature = "json")]
pub use crate::Json;
pub use crate::{json::JsonDeserializer, Json};

#[doc(no_inline)]
pub use crate::Extension;
Expand Down
270 changes: 243 additions & 27 deletions axum/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ use http::{
header::{self, HeaderMap, HeaderValue},
StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::marker::PhantomData;

/// JSON Extractor / Response.
///
/// When used as an extractor, it can deserialize request bodies into some type that
/// implements [`serde::Deserialize`]. The request will be rejected (and a [`JsonRejection`] will
/// implements [`serde::de::DeserializeOwned`]. The request will be rejected (and a [`JsonRejection`] will
/// be returned) if:
///
/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
/// - The body doesn't contain syntactically valid JSON.
/// - The body contains syntactically valid JSON but it couldn't be deserialized into the target
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
/// type.
/// - Buffering the request body fails.
///
Expand Down Expand Up @@ -135,6 +136,32 @@ fn json_content_type(headers: &HeaderMap) -> bool {
is_json_content_type
}

fn json_from_bytes<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> Result<T, JsonRejection> {
let deserializer = &mut serde_json::Deserializer::from_slice(bytes);

match serde_path_to_error::deserialize(deserializer) {
Ok(value) => Ok(value),
Err(err) => {
let rejection = match err.inner().classify() {
serde_json::error::Category::Data => JsonDataError::from_err(err).into(),
serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
JsonSyntaxError::from_err(err).into()
}
serde_json::error::Category::Io => {
if cfg!(debug_assertions) {
// we don't use `serde_json::from_reader` and instead always buffer
// bodies first, so we shouldn't encounter any IO errors
unreachable!()
} else {
JsonSyntaxError::from_err(err).into()
}
}
};
Err(rejection)
}
}
}

axum_core::__impl_deref!(Json);

impl<T> From<T> for Json<T> {
Expand All @@ -151,30 +178,7 @@ where
/// but special cases may require first extracting a `Request` into `Bytes` then optionally
/// constructing a `Json<T>`.
pub fn from_bytes(bytes: &[u8]) -> Result<Self, JsonRejection> {
let deserializer = &mut serde_json::Deserializer::from_slice(bytes);

let value = match serde_path_to_error::deserialize(deserializer) {
Ok(value) => value,
Err(err) => {
let rejection = match err.inner().classify() {
serde_json::error::Category::Data => JsonDataError::from_err(err).into(),
serde_json::error::Category::Syntax | serde_json::error::Category::Eof => {
JsonSyntaxError::from_err(err).into()
}
serde_json::error::Category::Io => {
if cfg!(debug_assertions) {
// we don't use `serde_json::from_reader` and instead always buffer
// bodies first, so we shouldn't encounter any IO errors
unreachable!()
} else {
JsonSyntaxError::from_err(err).into()
}
}
};
return Err(rejection);
}
};

let value = json_from_bytes(bytes)?;
Ok(Json(value))
}
}
Expand Down Expand Up @@ -209,12 +213,119 @@ where
}
}

/// JSON Extractor for zero-copy deserialization.
///
/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`].
/// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called.
/// If the type implements [`serde::de::DeserializeOwned`], the [`Json`] extractor should
/// be preferred.
///
/// The request will be rejected (and a [`JsonRejection`] will be returned) if:
///
/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
/// - Buffering the request body fails.
///
/// Additionally, a `JsonRejection` error will be returned, when calling `deserialize` if:
///
/// - The body doesn't contain syntactically valid JSON.
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
/// type.
/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`).
///
/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the
/// input contains escaped characters. Use `Cow<'a, str>` or `Cow<'a, [u8]>`, with the
/// `#[serde(borrow)]` attribute, to allow serde to fall back to an owned type when encountering
/// escaped characters.
///
/// ⚠️ Since parsing JSON requires consuming the request body, the `Json` extractor must be
/// *last* if there are multiple extractors in a handler.
/// See ["the order of extractors"][order-of-extractors]
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
/// extract,
/// routing::post,
/// Router,
/// response::{IntoResponse, Response}
/// };
/// use serde::Deserialize;
/// use std::borrow::Cow;
/// use http::StatusCode;
///
/// #[derive(Deserialize)]
/// struct Data<'a> {
/// #[serde(borrow)]
/// borrow_text: Cow<'a, str>,
/// #[serde(borrow)]
/// borrow_bytes: Cow<'a, [u8]>,
/// borrow_dangerous: &'a str,
/// not_borrowed: String,
/// }
///
/// async fn upload(deserializer: extract::JsonDeserializer<Data<'_>>) -> Response {
/// let data = match deserializer.deserialize() {
/// Ok(data) => data,
/// Err(e) => return e.into_response(),
/// };
///
/// // payload is a `Data` with borrowed data from `deserializer`,
/// // which owns the request body (`Bytes`).
///
/// StatusCode::OK.into_response()
/// }
///
/// let app = Router::new().route("/upload", post(upload));
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub struct JsonDeserializer<T> {
bytes: Bytes,
_marker: PhantomData<T>,
}

#[async_trait]
impl<T, S> FromRequest<S> for JsonDeserializer<T>
where
T: Deserialize<'static>,
S: Send + Sync,
{
type Rejection = JsonRejection;

async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
if json_content_type(req.headers()) {
let bytes = Bytes::from_request(req, state).await?;
Ok(Self {
bytes,
_marker: PhantomData,
})
} else {
Err(MissingJsonContentType.into())
}
}
}

impl<'de, 'a: 'de, T> JsonDeserializer<T>
where
T: Deserialize<'de>,
{
/// Deserialize the request body into the target type.
/// See [`JsonDeserializer`] for more details.
pub fn deserialize(&'a self) -> Result<T, JsonRejection> {
let value = json_from_bytes(&self.bytes)?;
Ok(value)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::post, test_helpers::*, Router};
use serde::Deserialize;
use serde_json::{json, Value};
use std::borrow::Cow;

#[crate::test]
async fn deserialize_body() {
Expand All @@ -232,6 +343,111 @@ mod tests {
assert_eq!(body, "bar");
}

#[crate::test]
async fn deserializer_deserialize_body() {
#[derive(Debug, Deserialize)]
struct Input<'a> {
#[serde(borrow)]
foo: Cow<'a, str>,
}

async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
match deserializer.deserialize() {
Ok(input) => {
assert!(matches!(input.foo, Cow::Borrowed(_)));
input.foo.into_owned().into_response()
}
Err(e) => e.into_response(),
}
}

let app = Router::new().route("/", post(handler));

let client = TestClient::new(app);
let res = client.post("/").json(&json!({ "foo": "bar" })).send().await;
let body = res.text().await;

assert_eq!(body, "bar");
}

#[crate::test]
async fn deserializer_deserialize_body_escaped_to_cow() {
#[derive(Debug, Deserialize)]
struct Input<'a> {
#[serde(borrow)]
foo: Cow<'a, str>,
}

async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response {
match deserializer.deserialize() {
Ok(Input { foo }) => {
let Cow::Owned(foo) = foo else {
panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters")
};

foo.into_response()
}
Err(e) => e.into_response(),
}
}

let app = Router::new().route("/", post(handler));

let client = TestClient::new(app);

// The escaped characters prevent serde_json from borrowing.
let res = client
.post("/")
.json(&json!({ "foo": "\"bar\"" }))
.send()
.await;

let body = res.text().await;

assert_eq!(body, r#""bar""#);
}

#[crate::test]
async fn deserializer_deserialize_body_escaped_to_str() {
#[derive(Debug, Deserialize)]
struct Input<'a> {
// Explicit `#[serde(borrow)]` attribute is not required for `&str` or &[u8].
// See: https://serde.rs/lifetimes.html#borrowing-data-in-a-derived-impl
foo: &'a str,
}

async fn route_fn(deserializer: JsonDeserializer<Input<'_>>) -> Response {
match deserializer.deserialize() {
Ok(Input { foo }) => foo.to_owned().into_response(),
Err(e) => e.into_response(),
}
}

let app = Router::new().route("/", post(route_fn));

let client = TestClient::new(app);

let res = client
.post("/")
.json(&json!({ "foo": "good" }))
.send()
.await;
let body = res.text().await;
assert_eq!(body, "good");

let res = client
.post("/")
.json(&json!({ "foo": "\"bad\"" }))
.send()
.await;
assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
let body_text = res.text().await;
assert_eq!(
body_text,
"Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16"
);
}

#[crate::test]
async fn consume_body_to_json_requires_json_content_type() {
#[derive(Debug, Deserialize)]
Expand Down

0 comments on commit 6eaa657

Please sign in to comment.