diff --git a/socketioxide/src/handler/extract.rs b/socketioxide/src/handler/extract.rs index 568cf8b3..c0ba36bb 100644 --- a/socketioxide/src/handler/extract.rs +++ b/socketioxide/src/handler/extract.rs @@ -144,11 +144,11 @@ where fn from_message_parts( _: &Arc>, v: &mut serde_json::Value, - _: &mut Vec, + b: &mut Vec, _: &Option, ) -> Result { upwrap_array(v); - serde_json::from_value(v.clone()).map(Data) + crate::from_value::(v.clone(), b).map(Data) } } @@ -178,11 +178,11 @@ where fn from_message_parts( _: &Arc>, v: &mut serde_json::Value, - _: &mut Vec, + b: &mut Vec, _: &Option, ) -> Result { upwrap_array(v); - Ok(TryData(serde_json::from_value(v.clone()))) + Ok(TryData(crate::from_value::(v.clone(), b))) } } /// An Extractor that returns a reference to a [`Socket`]. diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index cc9ae9f6..7b12b49f 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -292,11 +292,13 @@ pub use engineioxide::TransportType; pub use errors::{AckError, AdapterError, BroadcastError, DisconnectError, SendError, SocketError}; pub use handler::extract; pub use io::{SocketIo, SocketIoBuilder, SocketIoConfig}; +pub use value::{de::from_value, ser::to_value}; mod client; mod errors; mod io; mod ns; +mod value; /// Socket.IO protocol version. /// It is accessible with the [`Socket::protocol`](socket::Socket) method or as an extractor diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 890e7300..38d127fe 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use crate::ProtocolVersion; use bytes::Bytes; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; use crate::errors::Error; use engineioxide::sid::Sid; @@ -243,27 +243,13 @@ impl<'a> PacketData<'a> { } impl BinaryPacket { - /// Create a binary packet from incoming data and remove all placeholders and get the payload count - pub fn incoming(mut data: Value) -> Self { - let payload_count = match &mut data { - Value::Array(ref mut v) => { - let count = v.len(); - v.retain(|v| v.as_object().and_then(|o| o.get("_placeholder")).is_none()); - count - v.len() - } - val => { - if val - .as_object() - .and_then(|o| o.get("_placeholder")) - .is_some() - { - data = Value::Array(vec![]); - 1 - } else { - 0 - } - } + /// Create a binary packet from incoming data and gets the payload count + pub fn incoming(data: Value) -> Self { + let data = match data { + v @ Value::Array(_) => v, + v => Value::Array(vec![v]), }; + let payload_count = count_binary_payloads(&data); Self { data, @@ -273,18 +259,38 @@ impl BinaryPacket { } /// Create a binary packet from outgoing data and a payload + /// + /// The outgoing data should include numbered placeholder objects for each binary, like so: + /// ```json + /// { + /// "_placeholder": true, + /// "num": 0 + /// } + /// ``` + ///The value of the "num" field should correspond to its index in the `bin` argument. pub fn outgoing(data: Value, bin: Vec) -> Self { let mut data = match data { Value::Array(v) => Value::Array(v), d => Value::Array(vec![d]), }; - let payload_count = bin.len(); - (0..payload_count).for_each(|i| { - data.as_array_mut().unwrap().push(json!({ - "_placeholder": true, - "num": i - })) - }); + + let payload_count = count_binary_payloads(&data); + let bin_count = bin.len(); + + // TODO: if payload_count > bin_count, maybe should return an error here? data has more + // placeholders than bin has payloads. but maybe some payloads are reused and it's ok? at + // any rate, serialization will fail later if there are placeholders that reference + // payloads that don't exist. + + if bin_count > payload_count { + (payload_count..bin_count).for_each(|i| { + data.as_array_mut().unwrap().push(json!({ + "_placeholder": true, + "num": i + })) + }); + } + Self { data, bin, @@ -303,6 +309,23 @@ impl BinaryPacket { } } +fn is_placeholder(o: &Map) -> bool { + o.len() == 2 + && o.get("_placeholder") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + && o.get("num").and_then(|v| v.as_u64()).is_some() +} + +fn count_binary_payloads(data: &Value) -> usize { + match data { + Value::Array(a) => a.iter().map(count_binary_payloads).sum(), + Value::Object(o) if is_placeholder(o) => 1, + Value::Object(o) => o.values().map(count_binary_payloads).sum(), + _ => 0, + } +} + impl<'a> From> for String { fn from(mut packet: Packet<'a>) -> String { use PacketData::*; @@ -715,7 +738,7 @@ mod test { let packet: String = Packet::bin_event( "/", "event", - json!({ "data": "value™" }), + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], ) .try_into() @@ -728,7 +751,7 @@ mod test { let mut packet = Packet::bin_event( "/", "event", - json!({ "data": "value™" }), + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], ); packet.inner.set_ack_id(254); @@ -741,7 +764,7 @@ mod test { let packet: String = Packet::bin_event( "/admin™", "event", - json!([{"data": "value™"}]), + json!([{"data": "value™"}, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], ) .try_into() @@ -754,7 +777,7 @@ mod test { let mut packet = Packet::bin_event( "/admin™", "event", - json!([{"data": "value™"}]), + json!([{"data": "value™"}, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], ); packet.inner.set_ack_id(254); @@ -770,7 +793,7 @@ mod test { "event".into(), BinaryPacket { bin: vec![Bytes::from_static(&[1])], - data: json!([{"data": "value™"}]), + data: json!([{"data": "value™"}, {"_placeholder": true, "num": 0}]), payload_count: 1, }, ack, @@ -825,7 +848,7 @@ mod test { let payload = format!("61-54{}", json); let packet: String = Packet::bin_ack( "/", - json!({ "data": "value™" }), + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], 54, ) @@ -838,7 +861,7 @@ mod test { let payload = format!("61-/admin™,54{}", json); let packet: String = Packet::bin_ack( "/admin™", - json!({ "data": "value™" }), + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), vec![Bytes::from_static(&[1])], 54, ) @@ -855,7 +878,7 @@ mod test { inner: PacketData::BinaryAck( BinaryPacket { bin: vec![Bytes::from_static(&[1])], - data: json!([{"data": "value™"}]), + data: json!([{"data": "value™"}, {"_placeholder": true, "num": 0}]), payload_count: 1, }, ack, diff --git a/socketioxide/src/value/de.rs b/socketioxide/src/value/de.rs new file mode 100644 index 00000000..de65611d --- /dev/null +++ b/socketioxide/src/value/de.rs @@ -0,0 +1,841 @@ +use std::{borrow::Cow, marker::PhantomData}; + +use bytes::Bytes; +use serde::{ + de::{ + self, DeserializeSeed, EnumAccess, Error, IntoDeserializer, MapAccess, SeqAccess, + Unexpected, VariantAccess, + }, + forward_to_deserialize_any, +}; +use serde_json::{Map, Value}; + +// For Deserializer impl for Value Deserializer wrapper +macro_rules! forward_deser { + ($method:ident) => { + #[inline] + fn $method(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.value.$method(visitor) + } + }; +} + +// For Deserializer impl for MapKeyDeserializer +macro_rules! impl_deser_numeric_key { + ($method:ident) => { + fn $method(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + // This is less efficient than it could be, but `Deserializer::from_str()` (used + // internally by `serde_json`) wants to hold a reference to `self.key` longer than is + // permitted. The methods on `Deserializer` for deserializing into a `Number` without + // that constraint are not exposed publicly. + let reader = VecRead::from(self.key); + let mut deser = serde_json::Deserializer::from_reader(reader); + let number = deser.$method(visitor)?; + let _ = deser + .end() + .map_err(|_| serde_json::Error::custom("expected numeric map key"))?; + Ok(number) + } + }; +} + +struct Deserializer<'a, T> { + value: serde_json::Value, + binary_payloads: &'a [Bytes], + _phantom: PhantomData, +} + +impl<'a, 'de, T: serde::Deserialize<'de>> serde::de::Deserializer<'de> for Deserializer<'a, T> { + type Error = serde_json::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Array(a) => visit_value_array::<'a, 'de, V, T>(a, self.binary_payloads, visitor), + Value::Object(o) => { + visit_value_object::<'a, 'de, V, T>(o, self.binary_payloads, visitor) + } + other => other.deserialize_any(visitor), + } + } + + forward_deser!(deserialize_unit); + forward_deser!(deserialize_bool); + forward_deser!(deserialize_u8); + forward_deser!(deserialize_i8); + forward_deser!(deserialize_u16); + forward_deser!(deserialize_i16); + forward_deser!(deserialize_u32); + forward_deser!(deserialize_i32); + forward_deser!(deserialize_u64); + forward_deser!(deserialize_i64); + forward_deser!(deserialize_u128); + forward_deser!(deserialize_i128); + forward_deser!(deserialize_f32); + forward_deser!(deserialize_f64); + forward_deser!(deserialize_char); + forward_deser!(deserialize_str); + forward_deser!(deserialize_string); + forward_deser!(deserialize_identifier); + + #[inline] + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.value.deserialize_unit_struct(name, visitor) + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::String(s) => visitor.visit_bytes(s.as_bytes()), + Value::Object(o) => visit_value_object_for_bytes(o, self.binary_payloads, visitor), + Value::Array(a) => visit_value_array_for_bytes(a, visitor), + _ => Err(serde_json::Error::invalid_type( + unexpected_value(&self.value), + &"byte array or binary payload", + )), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + #[inline] + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + #[inline] + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Null => visitor.visit_none(), + Value::Array(a) => visit_value_array::<'a, 'de, V, T>(a, self.binary_payloads, visitor), + Value::Object(o) => { + visit_value_object::<'a, 'de, V, T>(o, self.binary_payloads, visitor) + } + other => other.deserialize_option(visitor), + } + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Array(a) => visit_value_array::<'a, 'de, V, T>(a, self.binary_payloads, visitor), + Value::Object(o) => { + visit_value_object::<'a, 'de, V, T>(o, self.binary_payloads, visitor) + } + _ => Err(serde_json::Error::invalid_type( + unexpected_value(&self.value), + &visitor, + )), + } + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Object(o) => { + visit_value_object::<'a, 'de, V, T>(o, self.binary_payloads, visitor) + } + _ => Err(serde_json::Error::invalid_type( + unexpected_value(&self.value), + &visitor, + )), + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Array(a) => visit_value_array::<'a, 'de, V, T>(a, self.binary_payloads, visitor), + _ => Err(serde_json::Error::invalid_type( + unexpected_value(&self.value), + &visitor, + )), + } + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::Object(o) => { + // From serde_json: enums are encoded as a map with _only_ a single key-value pair + if o.len() == 1 { + let (variant, value) = o.into_iter().next().unwrap(); + visitor.visit_enum(EnumDeserializer { + variant, + value: Some(value), + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }) + } else { + Err(serde_json::Error::invalid_value( + Unexpected::Map, + &"a map with a single key-value pair", + )) + } + } + Value::String(s) => visitor.visit_enum(EnumDeserializer { + variant: s, + value: None, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }), + _ => Err(serde_json::Error::invalid_type( + unexpected_value(&self.value), + &"map or string", + )), + } + } + + #[inline] + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + drop(self); + visitor.visit_unit() + } +} + +fn visit_value_object_for_bytes<'a: 'a, 'de, V>( + o: Map, + binary_payloads: &'a [Bytes], + visitor: V, +) -> Result +where + V: de::Visitor<'de>, +{ + if !o.len() == 2 + || !o + .get("_placeholder") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + Err(serde_json::Error::invalid_type( + Unexpected::Map, + &"binary payload placeholder object", + )) + } else if let Some(num) = o.get("num").and_then(|v| v.as_u64()) { + if let Some(payload) = binary_payloads.get(num as usize) { + visitor.visit_bytes(payload) + } else { + Err(serde_json::Error::invalid_value( + Unexpected::Unsigned(num), + &"a payload number in range", + )) + } + } else { + Err(serde_json::Error::invalid_value( + Unexpected::Map, + &"binary payload placeholder without valid num", + )) + } +} + +fn visit_value_array_for_bytes<'de, V>( + a: Vec, + visitor: V, +) -> Result +where + V: de::Visitor<'de>, +{ + let bytes = a + .into_iter() + .map(|v| match v { + Value::Number(n) => n + .as_u64() + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| { + serde_json::Error::invalid_value( + Unexpected::Other("non-u8 number"), + &"number that fits in a u8", + ) + }), + _ => Err(serde_json::Error::invalid_value( + unexpected_value(&v), + &"number that fits in a u8", + )), + }) + .collect::, _>>()?; + visitor.visit_bytes(&bytes) +} + +fn visit_value_object<'a: 'a, 'de, V, T>( + o: Map, + binary_payloads: &'a [Bytes], + visitor: V, +) -> Result +where + V: de::Visitor<'de>, + T: serde::Deserialize<'de>, +{ + let len = o.len(); + + let mut deser = MapDeserializer { + iter: o.into_iter(), + binary_payloads, + value: None, + _phantom: PhantomData::, + }; + let map = visitor.visit_map(&mut deser)?; + if deser.iter.len() == 0 { + Ok(map) + } else { + Err(serde_json::Error::invalid_length( + len, + &"fewer elements in map", + )) + } +} + +fn visit_value_array<'a: 'a, 'de, V, T>( + a: Vec, + binary_payloads: &'a [Bytes], + visitor: V, +) -> Result +where + V: de::Visitor<'de>, + T: serde::Deserialize<'de>, +{ + let len = a.len(); + let mut deser = SeqDeserializer { + iter: a.into_iter(), + binary_payloads, + _phantom: PhantomData::, + }; + let seq = visitor.visit_seq(&mut deser)?; + if deser.iter.len() == 0 { + Ok(seq) + } else { + Err(serde_json::Error::invalid_length( + len, + &"fewer elements in seq", + )) + } +} + +struct MapDeserializer<'a, T> { + iter: as IntoIterator>::IntoIter, + binary_payloads: &'a [Bytes], + value: Option, + _phantom: PhantomData, +} + +impl<'a, 'de, T: serde::Deserialize<'de>> MapAccess<'de> for MapDeserializer<'a, T> { + type Error = serde_json::Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + match self.iter.next() { + Some((key, value)) => { + self.value = Some(value); + let key_deser = MapKeyDeserializer::from(key); + seed.deserialize(key_deser).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + match self.value.take() { + Some(value) => { + let payload = Deserializer { + value, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }; + seed.deserialize(payload) + } + None => Err(serde_json::Error::custom("value is missing")), + } + } + + fn size_hint(&self) -> Option { + match self.iter.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct SeqDeserializer<'a, T> { + iter: as IntoIterator>::IntoIter, + binary_payloads: &'a [Bytes], + _phantom: PhantomData, +} + +impl<'a, 'de, T: serde::Deserialize<'de>> SeqAccess<'de> for SeqDeserializer<'a, T> { + type Error = serde_json::Error; + + fn next_element_seed(&mut self, seed: S) -> Result, Self::Error> + where + S: DeserializeSeed<'de>, + { + match self.iter.next() { + Some(value) => { + let payload = Deserializer { + value, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }; + seed.deserialize(payload).map(Some) + } + None => Ok(None), + } + } + + fn size_hint(&self) -> Option { + match self.iter.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct EnumDeserializer<'a, T> { + variant: String, + value: Option, + binary_payloads: &'a [Bytes], + _phantom: PhantomData, +} + +impl<'a, 'de, T: serde::Deserialize<'de>> EnumAccess<'de> for EnumDeserializer<'a, T> { + type Variant = VariantDeserializer<'a, T>; + type Error = serde_json::Error; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + let variant = self.variant.into_deserializer(); + let visitor = VariantDeserializer { + value: self.value, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }; + seed.deserialize(variant).map(|v| (v, visitor)) + } +} + +struct VariantDeserializer<'a, T> { + value: Option, + binary_payloads: &'a [Bytes], + _phantom: PhantomData, +} + +impl<'a, 'de, T: serde::Deserialize<'de>> VariantAccess<'de> for VariantDeserializer<'a, T> { + type Error = serde_json::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + match self.value { + Some(value) => { + let deser = Deserializer { + value, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }; + serde::Deserialize::deserialize(deser) + } + None => Ok(()), + } + } + + fn newtype_variant_seed(self, seed: S) -> Result + where + S: DeserializeSeed<'de>, + { + match self.value { + Some(value) => { + let deser = Deserializer { + value, + binary_payloads: self.binary_payloads, + _phantom: PhantomData::, + }; + seed.deserialize(deser) + } + None => Err(serde_json::Error::invalid_type( + Unexpected::Unit, + &"newtype variant", + )), + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Some(Value::Array(a)) => { + if a.is_empty() { + visitor.visit_unit() + } else { + visit_value_array::<'a, 'de, V, T>(a, self.binary_payloads, visitor) + } + } + Some(other) => Err(serde_json::Error::invalid_type( + unexpected_value(&other), + &"tuple variant", + )), + None => Err(serde_json::Error::invalid_type( + Unexpected::UnitVariant, + &"tuple variant", + )), + } + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Some(Value::Object(o)) => { + visit_value_object::<'a, 'de, V, T>(o, self.binary_payloads, visitor) + } + Some(other) => Err(serde_json::Error::invalid_type( + unexpected_value(&other), + &"struct variant", + )), + None => Err(serde_json::Error::invalid_type( + Unexpected::UnitVariant, + &"struct variant", + )), + } + } +} + +/// Helper struct that implements `std::io::Read`, which allows us to use +/// `serde_json::Deserializer` to deserialize a string into a `serde_json::Number`. +struct VecRead { + vec: Vec, + pos: usize, +} + +impl<'any> From> for VecRead { + fn from(value: Cow<'any, str>) -> Self { + Self { + vec: value.to_string().into_bytes(), + pos: 0, + } + } +} + +impl std::io::Read for VecRead { + fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result { + use std::io::Write; + + let to_write = std::cmp::min(buf.len(), self.vec.len() - self.pos); + if to_write > 0 { + let written = buf.write(&self.vec[self.pos..to_write])?; + self.pos += to_write; + Ok(written) + } else { + Ok(0) + } + } +} + +struct MapKeyDeserializer<'de> { + key: Cow<'de, str>, +} + +impl<'de> From for MapKeyDeserializer<'de> { + fn from(value: String) -> Self { + MapKeyDeserializer { + key: Cow::Owned(value), + } + } +} + +impl<'de> serde::Deserializer<'de> for MapKeyDeserializer<'de> { + type Error = serde_json::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + MapKeyStrDeserializer::from(self.key).deserialize_any(visitor) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.key.as_ref() { + "true" => visitor.visit_bool(true), + "false" => visitor.visit_bool(false), + _ => Err(serde_json::Error::invalid_type( + Unexpected::Str(&self.key), + &visitor, + )), + } + } + + impl_deser_numeric_key!(deserialize_i8); + impl_deser_numeric_key!(deserialize_i16); + impl_deser_numeric_key!(deserialize_i32); + impl_deser_numeric_key!(deserialize_i64); + impl_deser_numeric_key!(deserialize_u8); + impl_deser_numeric_key!(deserialize_u16); + impl_deser_numeric_key!(deserialize_u32); + impl_deser_numeric_key!(deserialize_u64); + impl_deser_numeric_key!(deserialize_f64); + impl_deser_numeric_key!(deserialize_f32); + impl_deser_numeric_key!(deserialize_i128); + impl_deser_numeric_key!(deserialize_u128); + + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.key + .into_deserializer() + .deserialize_enum(name, variants, visitor) + } + + forward_to_deserialize_any! { + char str string bytes byte_buf unit unit_struct seq tuple tuple_struct + map struct identifier ignored_any + } +} + +struct MapKeyStrDeserializer<'de> { + key: Cow<'de, str>, +} + +impl<'de> From> for MapKeyStrDeserializer<'de> { + fn from(value: Cow<'de, str>) -> Self { + Self { key: value } + } +} + +impl<'de> serde::Deserializer<'de> for MapKeyStrDeserializer<'de> { + type Error = serde_json::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.key { + Cow::Owned(s) => visitor.visit_string(s), + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + } + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_enum(self) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct identifier ignored_any + } +} + +impl<'de> EnumAccess<'de> for MapKeyStrDeserializer<'de> { + type Variant = UnitEnum; + type Error = serde_json::Error; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + Ok((seed.deserialize(self)?, UnitEnum)) + } +} + +struct UnitEnum; + +impl<'de> VariantAccess<'de> for UnitEnum { + type Error = serde_json::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + Err(serde_json::Error::invalid_type( + Unexpected::UnitVariant, + &"tuple variant", + )) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + Err(serde_json::Error::invalid_type( + Unexpected::UnitVariant, + &"struct variant", + )) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Err(serde_json::Error::invalid_type( + Unexpected::UnitVariant, + &"newtype variant", + )) + } +} + +fn unexpected_value(v: &Value) -> Unexpected<'_> { + match v { + Value::Null => Unexpected::Unit, + Value::Bool(b) => Unexpected::Bool(*b), + Value::Number(n) => { + if let Some(n) = n.as_u64() { + Unexpected::Unsigned(n) + } else if let Some(n) = n.as_i64() { + Unexpected::Signed(n) + } else if let Some(n) = n.as_f64() { + Unexpected::Float(n) + } else { + Unexpected::Other("non-unsigned, non-signed, non-float number") + } + } + Value::String(s) => Unexpected::Str(s), + Value::Array(_) => Unexpected::Seq, + Value::Object(_) => Unexpected::Map, + } +} + +/// Converts a [`serde_json::Value`], with optional binary payloads into an arbitrary data type. +/// +/// # Arguments +/// +/// - `value` - a [`serde_json::Value`], with any binary blobs replaced with socket.io placeholder +/// objects +/// - `binary_payloads` - a [`Vec`] of binary payloads, in the order specified by the `num` fields +/// of the placeholder objecst in `value` +pub fn from_value<'de, 'a, T: serde::Deserialize<'de> + 'a>( + value: Value, + binary_payloads: &'a [Bytes], +) -> Result { + let payload = Deserializer { + value, + binary_payloads, + _phantom: PhantomData::, + }; + + T::deserialize(payload) +} diff --git a/socketioxide/src/value/mod.rs b/socketioxide/src/value/mod.rs new file mode 100644 index 00000000..5eec7e6d --- /dev/null +++ b/socketioxide/src/value/mod.rs @@ -0,0 +1,179 @@ +pub(crate) mod de; +pub(crate) mod ser; + +#[cfg(test)] +mod test { + use bytes::Bytes; + use serde::{Deserialize, Serialize}; + use serde_json::{json, Value}; + + use super::de::from_value; + use super::ser::to_value; + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestSubPayload { + more_binary: Bytes, + opt_int: Option, + opt_float: Option, + opt_string: Option, + opt_boolean: Option, + } + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestPayload { + uint: u32, + float: f64, + binary: Bytes, + string: String, + boolean: bool, + array: Vec, + sub_payload: TestSubPayload, + } + + const BINARY_PAYLOAD: &[u8] = &[1, 2, 3, 4, 5, 6, 7]; + const MORE_BINARY_PAYLOAD: &[u8] = &[10, 9, 8, 7, 6, 5]; + + fn build_bins() -> Vec { + vec![ + Bytes::from_static(BINARY_PAYLOAD), + Bytes::from_static(MORE_BINARY_PAYLOAD), + ] + } + + fn build_test_payload(fill_options: bool) -> TestPayload { + let fill_options = fill_options.then_some(true); + TestPayload { + uint: 42, + float: 1.75, + binary: Bytes::from_static(BINARY_PAYLOAD), + string: "test string".to_string(), + boolean: true, + array: ["one", "two", "three"] + .into_iter() + .map(ToString::to_string) + .collect(), + sub_payload: TestSubPayload { + more_binary: Bytes::from_static(MORE_BINARY_PAYLOAD), + opt_int: fill_options.map(|_| 99), + opt_float: fill_options.map(|_| 2.5), + opt_string: fill_options.map(|_| "another test string".to_string()), + opt_boolean: fill_options, + }, + } + } + + fn build_test_value(fill_options: bool) -> Value { + let sub_payload = if fill_options { + json!({ + "more_binary": { + "_placeholder": true, + "num": 1 + }, + "opt_int": 99, + "opt_float": 2.5, + "opt_string": "another test string", + "opt_boolean": true + }) + } else { + json!({ + "more_binary": { + "_placeholder": true, + "num": 1 + }, + "opt_int": null, + "opt_float": null, + "opt_string": null, + "opt_boolean": null + }) + }; + + let mut main_payload = json!({ + "uint": 42, + "float": 1.75, + "binary": { + "_placeholder": true, + "num": 0 + }, + "string": "test string", + "boolean": true, + "array": [ + "one", + "two", + "three" + ] + }); + + if let Value::Object(ref mut o) = &mut main_payload { + o.insert("sub_payload".to_string(), sub_payload); + } else { + panic!("test bug: not an object"); + } + + main_payload + } + + #[test] + pub fn test_value_from_data() { + let (value, bins) = to_value(build_test_payload(true)).unwrap(); + assert_eq!(build_test_value(true), value); + assert_eq!(bins.len(), 2); + assert_eq!(bins[0], Bytes::from_static(BINARY_PAYLOAD)); + assert_eq!(bins[1], Bytes::from_static(MORE_BINARY_PAYLOAD)); + } + + #[test] + pub fn test_value_into_data() { + let data: TestPayload = + from_value(build_test_value(true), build_bins().as_slice()).unwrap(); + assert_eq!(data, build_test_payload(true)); + } + + /* + + #[test] + pub fn test_payload_value_to_json_value() { + let payload_value = build_test_payload_value(true, false); + let json = payload_value.to_value().unwrap(); + assert_eq!(json, build_test_payload_json_value(true)); + + let payload_value = build_test_payload_value(false, false); + let json = payload_value.to_value().unwrap(); + assert_eq!(json, build_test_payload_json_value(false)); + } + + #[test] + pub fn test_payload_value_from_json_value() { + let json = build_test_payload_json_value(true); + let payload_value: PayloadValue = serde_json::from_value(json).unwrap(); + assert_eq!(payload_value, build_test_payload_value(true, false)); + + let json = build_test_payload_json_value(false); + let payload_value: PayloadValue = serde_json::from_value(json).unwrap(); + assert_eq!(payload_value, build_test_payload_value(false, false)); + } + + #[test] + pub fn test_count_payloads() { + let payload_value = build_test_payload_value(true, false); + assert_eq!(payload_value.count_payloads(), 2); + } + + #[test] + pub fn test_extract_binary_payloads() { + let test_payload = build_test_payload(true); + let payload_value = build_test_payload_value(true, true); + let bins = payload_value.get_binary_payloads(); + + assert_eq!(bins.len(), 2); + assert_eq!(bins[0], *test_payload.binary); + assert_eq!(bins[1], *test_payload.sub_payload.more_binary); + } + + #[test] + pub fn test_payload_value_redeser() { + let payload_value_again: PayloadValue = + build_test_payload_value(true, true).into_data().unwrap(); + assert_eq!(build_test_payload_value(true, true), payload_value_again); + } + */ +} diff --git a/socketioxide/src/value/ser.rs b/socketioxide/src/value/ser.rs new file mode 100644 index 00000000..fac77a30 --- /dev/null +++ b/socketioxide/src/value/ser.rs @@ -0,0 +1,547 @@ +use bytes::Bytes; +use serde::ser::{Error, Impossible, SerializeSeq}; +use serde_json::{Map, Value}; + +const KEY_STRING_ERROR: &str = "key must be a string"; + +macro_rules! forward_ser_impl { + ($method:ident, $ty:ty) => { + fn $method(self, v: $ty) -> Result { + serde_json::value::Serializer.$method(v) + } + }; +} + +#[derive(Default)] +struct Serializer { + binary_payloads: Vec<(usize, Bytes)>, + next_binary_payload_num: usize, +} + +impl<'a> serde::Serializer for &'a mut Serializer { + type Ok = Value; + type Error = serde_json::Error; + + type SerializeSeq = SerializeVec<'a>; + type SerializeTuple = SerializeVec<'a>; + type SerializeTupleVariant = SerializeTupleVariant<'a>; + type SerializeMap = SerializeMap<'a>; + type SerializeStruct = SerializeMap<'a>; + type SerializeStructVariant = SerializeStructVariant<'a>; + type SerializeTupleStruct = SerializeVec<'a>; + + fn serialize_unit(self) -> Result { + serde_json::value::Serializer.serialize_unit() + } + + forward_ser_impl!(serialize_bool, bool); + forward_ser_impl!(serialize_i8, i8); + forward_ser_impl!(serialize_u8, u8); + forward_ser_impl!(serialize_i16, i16); + forward_ser_impl!(serialize_u16, u16); + forward_ser_impl!(serialize_i32, i32); + forward_ser_impl!(serialize_u32, u32); + forward_ser_impl!(serialize_i64, i64); + forward_ser_impl!(serialize_u64, u64); + forward_ser_impl!(serialize_i128, i128); + forward_ser_impl!(serialize_u128, u128); + forward_ser_impl!(serialize_f32, f32); + forward_ser_impl!(serialize_f64, f64); + forward_ser_impl!(serialize_char, char); + forward_ser_impl!(serialize_str, &str); + + fn serialize_bytes(self, v: &[u8]) -> Result { + let num = self.next_binary_payload_num; + self.next_binary_payload_num += 1; + self.binary_payloads.push((num, Bytes::copy_from_slice(v))); + Ok(Value::Object( + [ + ("_placeholder".to_string(), Value::Bool(true)), + ("num".to_string(), Value::Number(num.into())), + ] + .into_iter() + .collect(), + )) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + serde_json::value::Serializer.serialize_unit_struct(name) + } + + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result { + serde_json::value::Serializer.serialize_unit_variant(name, variant_index, variant) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + fn serialize_none(self) -> Result { + serde_json::value::Serializer.serialize_none() + } + + fn serialize_some( + self, + value: &T, + ) -> Result { + value.serialize(self) + } + + fn serialize_seq(self, len: Option) -> Result { + Ok(SerializeVec { + root_ser: self, + elements: Vec::with_capacity(len.unwrap_or(0)), + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(SerializeTupleVariant { + root_ser: self, + name: String::from(variant), + elements: Vec::with_capacity(len), + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(SerializeMap { + root_ser: self, + map: Map::new(), + next_key: None, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Ok(SerializeStructVariant { + root_ser: self, + name: String::from(variant), + map: Map::new(), + }) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + _variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result { + let mut obj = Map::new(); + obj.insert(name.to_string(), value.serialize(self)?); + Ok(Value::Object(obj)) + } + + fn collect_str( + self, + value: &T, + ) -> Result { + Ok(Value::String(value.to_string())) + } +} + +struct SerializeVec<'a> { + root_ser: &'a mut Serializer, + elements: Vec, +} + +impl<'a> SerializeSeq for SerializeVec<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_element( + &mut self, + value: &T, + ) -> Result<(), Self::Error> { + let element = value.serialize(&mut *self.root_ser)?; + self.elements.push(element); + Ok(()) + } + + fn end(self) -> Result { + Ok(Value::Array(self.elements)) + } +} + +impl<'a> serde::ser::SerializeTuple for SerializeVec<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_element( + &mut self, + value: &T, + ) -> Result<(), Self::Error> { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'a> serde::ser::SerializeTupleStruct for SerializeVec<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_field( + &mut self, + value: &T, + ) -> Result<(), Self::Error> { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + serde::ser::SerializeSeq::end(self) + } +} + +struct SerializeTupleVariant<'a> { + root_ser: &'a mut Serializer, + name: String, + elements: Vec, +} + +impl<'a> serde::ser::SerializeTupleVariant for SerializeTupleVariant<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_field( + &mut self, + value: &T, + ) -> Result<(), Self::Error> { + let element = value.serialize(&mut *self.root_ser)?; + self.elements.push(element); + Ok(()) + } + + fn end(self) -> Result { + let mut obj = Map::new(); + obj.insert(self.name, Value::Array(self.elements)); + Ok(Value::Object(obj)) + } +} + +struct SerializeMap<'a> { + root_ser: &'a mut Serializer, + map: Map, + next_key: Option, +} + +impl<'a> serde::ser::SerializeMap for SerializeMap<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> { + self.next_key = Some(key.serialize(MapKeySerializer)?); + Ok(()) + } + + fn serialize_value( + &mut self, + value: &T, + ) -> Result<(), Self::Error> { + if let Some(key) = self.next_key.take() { + self.map.insert(key, value.serialize(&mut *self.root_ser)?); + Ok(()) + } else { + panic!("serialize_value() called before serialize_key()"); + } + } + + fn serialize_entry( + &mut self, + key: &K, + value: &V, + ) -> Result<(), Self::Error> { + self.map.insert( + key.serialize(MapKeySerializer)?, + value.serialize(&mut *self.root_ser)?, + ); + Ok(()) + } + + fn end(self) -> Result { + Ok(Value::Object(self.map)) + } +} + +impl<'a> serde::ser::SerializeStruct for SerializeMap<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + serde::ser::SerializeMap::serialize_entry(self, key, value) + } + + fn end(self) -> Result { + serde::ser::SerializeMap::end(self) + } +} + +struct SerializeStructVariant<'a> { + root_ser: &'a mut Serializer, + name: String, + map: Map, +} + +impl<'a> serde::ser::SerializeStructVariant for SerializeStructVariant<'a> { + type Ok = Value; + type Error = serde_json::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + self.map + .insert(key.to_string(), value.serialize(&mut *self.root_ser)?); + Ok(()) + } + + fn end(self) -> Result { + let mut obj = Map::new(); + obj.insert(self.name, Value::Object(self.map)); + Ok(Value::Object(obj)) + } +} + +struct MapKeySerializer; + +impl serde::Serializer for MapKeySerializer { + type Ok = String; + type Error = serde_json::Error; + + type SerializeSeq = Impossible; + type SerializeMap = Impossible; + type SerializeTuple = Impossible; + type SerializeStruct = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + type SerializeStructVariant = Impossible; + + fn serialize_bool(self, v: bool) -> Result { + Ok(v.to_string()) + } + + fn serialize_i8(self, v: i8) -> Result { + Ok(v.to_string()) + } + + fn serialize_i16(self, v: i16) -> Result { + Ok(v.to_string()) + } + + fn serialize_i32(self, v: i32) -> Result { + Ok(v.to_string()) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(v.to_string()) + } + + fn serialize_i128(self, v: i128) -> Result { + Ok(v.to_string()) + } + + fn serialize_u8(self, v: u8) -> Result { + Ok(v.to_string()) + } + + fn serialize_u16(self, v: u16) -> Result { + Ok(v.to_string()) + } + + fn serialize_u32(self, v: u32) -> Result { + Ok(v.to_string()) + } + + fn serialize_u64(self, v: u64) -> Result { + Ok(v.to_string()) + } + + fn serialize_u128(self, v: u128) -> Result { + Ok(v.to_string()) + } + + fn serialize_f32(self, v: f32) -> Result { + Ok(v.to_string()) + } + + fn serialize_f64(self, v: f64) -> Result { + Ok(v.to_string()) + } + + fn serialize_char(self, v: char) -> Result { + Ok(v.to_string()) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(v.to_string()) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_none(self) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_some( + self, + _value: &T, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_unit(self) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result { + Err(Self::Error::custom(KEY_STRING_ERROR)) + } +} + +/// Converts an arbitrary data type to a [`serde_json::Value`], extracting any binary payloads. +/// +/// Binary data should be represented using the type [`bytes::Bytes`], or any other type that +/// implements [`serde::Serialize`] by serializing to bytes. You can also use +/// `#[serde(serialize_with = "...")]` on a custom or existing type. +/// +/// # Arguments +/// +/// - `data` - a data type that implements [`serde::Serialize`] +pub fn to_value(data: T) -> Result<(Value, Vec), serde_json::Error> { + let mut ser = Serializer::default(); + let value = data.serialize(&mut ser)?; + + ser.binary_payloads + .sort_by(|(num_a, _), (num_b, _)| num_a.cmp(num_b)); + + Ok(( + value, + ser.binary_payloads + .into_iter() + .map(|(_, bin)| bin) + .collect(), + )) +}