From f808615ba94089c77f4400a1a7b52b19085db5e3 Mon Sep 17 00:00:00 2001 From: "Brian J. Tarricone" Date: Thu, 7 Mar 2024 00:55:48 -0800 Subject: [PATCH] wip: integrate binary data into payloads Previously, the non-binary part of a message and the binary payloads in a message were represented separately: the non-binary portion was represented by a serde_json::Value, and could be converted to an arbitrary data structure. That data structure would not include the binary data or any indication that there is any binary data at all. The binary data would be provided in a Vec>. There were a few problems with this: 1. The original code only supported cases where the payload was a flat array with some binary payloads in the root of the array, or a flat object where the root of the object was a binary payload. Objects with more complicated structure and binary data embedded in various places in the structure were not supported. 2. Adding support for the above turned out to not be possible in a useful way, because the ordering of the Vec> matters, and it could never be clear where exactly in the possibly-complex structure each binary payload belonged. 3. One of the useful features of the socket.io protocol is that it lets users efficiently transmit binary data in addition to textual/numeric data, and have that handled transparently by the protocol, with either end of the connection believing that they just sent or received a single mixed textual/numeric/binary payload. Separating the non-binary from the binary negates that benefit. This introduces a new type, PayloadValue, that behaves similarly to serde_json::Value. The main difference is that it has a Binary variant, which holds a numeric index and a Vec. This allows us to include the binary data where the sender of that data intended it to be. There is currently one wrinkle: serde_json does not appear to consistently handle binary data; when serializing a struct with Vec, I believe it will serialize it as an array of numbers, rather than recognize that it's binary data. For now, I've included a Binary struct that wraps a Vec, which can be included as the type of a binary member, instead of using a Vec directly. Hopefully I'll be able to figure out a better way to do this. Unfinished tasks: * Testing: I have no idea if this even works yet. All I've done is get it to compile. * Benchmarking: I've tried to ensure that I don't copy data any more than the existing library does, but it's possible I've introduced some performance regressions, so I'll have to look into that. * Documentation: the documentation still references the old way of doing things and needs to be updated. Closes #276. --- e2e/socketioxide/socketioxide.rs | 26 +- socketioxide/src/ack.rs | 59 ++-- socketioxide/src/client.rs | 5 +- socketioxide/src/errors.rs | 10 +- socketioxide/src/handler/extract.rs | 64 ++--- socketioxide/src/handler/message.rs | 39 ++- socketioxide/src/io.rs | 28 -- socketioxide/src/lib.rs | 2 +- socketioxide/src/operators.rs | 66 +---- socketioxide/src/packet.rs | 401 +++++++++++++++++++--------- socketioxide/src/socket.rs | 92 +++---- 11 files changed, 425 insertions(+), 367 deletions(-) diff --git a/e2e/socketioxide/socketioxide.rs b/e2e/socketioxide/socketioxide.rs index 890adc38..80c06636 100644 --- a/e2e/socketioxide/socketioxide.rs +++ b/e2e/socketioxide/socketioxide.rs @@ -4,46 +4,44 @@ use std::time::Duration; use hyper::server::conn::http1; use hyper_util::rt::TokioIo; -use serde_json::Value; use socketioxide::{ - extract::{AckSender, Bin, Data, SocketRef}, - SocketIo, + extract::{AckSender, Data, SocketRef}, + PayloadValue, SocketIo, }; use tokio::net::TcpListener; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; -fn on_connect(socket: SocketRef, Data(data): Data) { +fn on_connect(socket: SocketRef, Data(data): Data) { info!("Socket.IO connected: {:?} {:?}", socket.ns(), socket.id); socket.emit("auth", data).ok(); socket.on( "message", - |socket: SocketRef, Data::(data), Bin(bin)| { - info!("Received event: {:?} {:?}", data, bin); - socket.bin(bin).emit("message-back", data).ok(); + |socket: SocketRef, Data::(data)| { + info!("Received event: {:?}", data); + socket.emit("message-back", data).ok(); }, ); // keep this handler async to test async message handlers socket.on( "message-with-ack", - |Data::(data), ack: AckSender, Bin(bin)| async move { - info!("Received event: {:?} {:?}", data, bin); - ack.bin(bin).send(data).ok(); + |Data::(data), ack: AckSender| async move { + info!("Received event: {:?}", data); + ack.send(data).ok(); }, ); socket.on( "emit-with-ack", - |s: SocketRef, Data::(data), Bin(bin)| async move { + |s: SocketRef, Data::(data)| async move { let ack = s - .bin(bin) - .emit_with_ack::<_, Value>("emit-with-ack", data) + .emit_with_ack::<_, PayloadValue>("emit-with-ack", data) .unwrap() .await .unwrap(); - s.bin(ack.binary).emit("emit-with-ack", ack.data).unwrap(); + s.emit("emit-with-ack", ack.data).unwrap(); }, ); } diff --git a/socketioxide/src/ack.rs b/socketioxide/src/ack.rs index 02f1cc98..52757013 100644 --- a/socketioxide/src/ack.rs +++ b/socketioxide/src/ack.rs @@ -18,10 +18,12 @@ use futures::{ Future, Stream, }; use serde::de::DeserializeOwned; -use serde_json::Value; use tokio::{sync::oneshot::Receiver, time::Timeout}; -use crate::{adapter::Adapter, errors::AckError, extract::SocketRef, packet::Packet, SocketError}; +use crate::{ + adapter::Adapter, errors::AckError, extract::SocketRef, packet::Packet, + payload_value::PayloadValue, SocketError, +}; /// An acknowledgement sent by the client. /// It contains the data sent by the client and the binary payloads if there are any. @@ -29,12 +31,9 @@ use crate::{adapter::Adapter, errors::AckError, extract::SocketRef, packet::Pack pub struct AckResponse { /// The data returned by the client pub data: T, - /// Optional binary payloads. - /// If there is no binary payload, the `Vec` will be empty - pub binary: Vec>, } -pub(crate) type AckResult = Result, AckError<()>>; +pub(crate) type AckResult = Result, AckError<()>>; pin_project_lite::pin_project! { /// A [`Future`] of [`AckResponse`] received from the client with its corresponding [`Sid`]. @@ -127,12 +126,12 @@ pin_project_lite::pin_project! { pub enum AckInnerStream { Stream { #[pin] - rxs: FuturesUnordered>, + rxs: FuturesUnordered>, }, Fut { #[pin] - rx: AckResultWithId, + rx: AckResultWithId, polled: bool, }, } @@ -171,7 +170,7 @@ impl AckInnerStream { /// Creates a new [`AckInnerStream`] from a [`oneshot::Receiver`](tokio) corresponding to the acknowledgement /// of a single socket. - pub fn send(rx: Receiver>, duration: Duration, id: Sid) -> Self { + pub fn send(rx: Receiver>, duration: Duration, id: Sid) -> Self { AckInnerStream::Fut { polled: false, rx: AckResultWithId { @@ -183,7 +182,7 @@ impl AckInnerStream { } impl Stream for AckInnerStream { - type Item = (Sid, AckResult); + type Item = (Sid, AckResult); fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use InnerProj::*; @@ -221,7 +220,7 @@ impl FusedStream for AckInnerStream { } impl Future for AckInnerStream { - type Output = AckResult; + type Output = AckResult; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().poll_next(cx) { @@ -295,13 +294,11 @@ impl From for AckStream { } } -fn map_ack_response(ack: AckResult) -> AckResult { +fn map_ack_response(ack: AckResult) -> AckResult { ack.and_then(|v| { - serde_json::from_value(v.data) - .map(|data| AckResponse { - data, - binary: v.binary, - }) + v.data + .into_data::() + .map(|data| AckResponse { data }) .map_err(|e| e.into()) }) } @@ -328,12 +325,12 @@ mod test { async fn broadcast_ack() { let socket = create_socket(); let socket2 = create_socket(); - let mut packet = Packet::event("/", "test", "test".into()); + let mut packet = Packet::event("/", "test", PayloadValue::from_data("test").unwrap()); packet.inner.set_ack_id(1); let socks = vec![socket.clone().into(), socket2.clone().into()]; let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); - let res_packet = Packet::ack("test", "test".into(), 1); + let res_packet = Packet::ack("test", PayloadValue::from_data("test").unwrap(), 1); socket.recv(res_packet.inner.clone()).unwrap(); socket2.recv(res_packet.inner).unwrap(); @@ -351,8 +348,7 @@ mod test { let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(AckResponse { - data: Value::String("test".into()), - binary: vec![], + data: PayloadValue::String("test".into()), })) .unwrap(); @@ -372,8 +368,7 @@ mod test { let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(AckResponse { - data: Value::String("test".into()), - binary: vec![], + data: PayloadValue::String("test".into()), })) .unwrap(); @@ -384,12 +379,12 @@ mod test { async fn broadcast_ack_with_deserialize_error() { let socket = create_socket(); let socket2 = create_socket(); - let mut packet = Packet::event("/", "test", "test".into()); + let mut packet = Packet::event("/", "test", PayloadValue::from_data("test").unwrap()); packet.inner.set_ack_id(1); let socks = vec![socket.clone().into(), socket2.clone().into()]; let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); - let res_packet = Packet::ack("test", 132.into(), 1); + let res_packet = Packet::ack("test", PayloadValue::from_data(132).unwrap(), 1); socket.recv(res_packet.inner.clone()).unwrap(); socket2.recv(res_packet.inner).unwrap(); @@ -413,8 +408,7 @@ mod test { let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(AckResponse { - data: Value::Bool(true), - binary: vec![], + data: PayloadValue::Bool(true), })) .unwrap(); assert_eq!(stream.size_hint().0, 1); @@ -436,8 +430,7 @@ mod test { let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(AckResponse { - data: Value::Bool(true), - binary: vec![], + data: PayloadValue::Bool(true), })) .unwrap(); @@ -448,12 +441,12 @@ mod test { async fn broadcast_ack_with_closed_socket() { let socket = create_socket(); let socket2 = create_socket(); - let mut packet = Packet::event("/", "test", "test".into()); + let mut packet = Packet::event("/", "test", PayloadValue::from_data("test").unwrap()); packet.inner.set_ack_id(1); let socks = vec![socket.clone().into(), socket2.clone().into()]; let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); - let res_packet = Packet::ack("test", "test".into(), 1); + let res_packet = Packet::ack("test", PayloadValue::from_data("test").unwrap(), 1); socket.clone().recv(res_packet.inner.clone()).unwrap(); futures::pin_mut!(stream); @@ -503,14 +496,14 @@ mod test { async fn broadcast_ack_with_timeout() { let socket = create_socket(); let socket2 = create_socket(); - let mut packet = Packet::event("/", "test", "test".into()); + let mut packet = Packet::event("/", "test", PayloadValue::from_data("test").unwrap()); packet.inner.set_ack_id(1); let socks = vec![socket.clone().into(), socket2.clone().into()]; let stream: AckStream = AckInnerStream::broadcast(packet, socks, Some(Duration::from_millis(10))).into(); socket - .recv(Packet::ack("test", "test".into(), 1).inner) + .recv(Packet::ack("test", PayloadValue::from_data("test").unwrap(), 1).inner) .unwrap(); futures::pin_mut!(stream); diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index ffb9fbb1..51281ba3 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; +use bytes::Bytes; use engineioxide::handler::EngineIoHandler; use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket}; use futures::TryFutureExt; @@ -242,7 +243,7 @@ impl EngineIoHandler for Client { /// /// If the packet is complete, it is propagated to the namespace fn on_binary(&self, data: Vec, socket: Arc>) { - if apply_payload_on_packet(data, &socket) { + if apply_payload_on_packet(data.into(), &socket) { if let Some(packet) = socket.data.partial_bin_packet.lock().unwrap().take() { if let Err(ref err) = self.sock_propagate_packet(packet, socket.id) { #[cfg(feature = "tracing")] @@ -264,7 +265,7 @@ impl EngineIoHandler for Client { /// waiting to be filled with all the payloads /// /// Returns true if the packet is complete and should be processed -fn apply_payload_on_packet(data: Vec, socket: &EIoSocket) -> bool { +fn apply_payload_on_packet(data: Bytes, socket: &EIoSocket) -> bool { #[cfg(feature = "tracing")] tracing::debug!("[sid={}] applying payload on packet", socket.id); if let Some(ref mut packet) = *socket.data.partial_bin_packet.lock().unwrap() { diff --git a/socketioxide/src/errors.rs b/socketioxide/src/errors.rs index f7294789..2466e096 100644 --- a/socketioxide/src/errors.rs +++ b/socketioxide/src/errors.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("invalid packet type")] InvalidPacketType, + #[error("invalid binary payload count")] + InvalidPayloadCount, + #[error("invalid event name")] InvalidEventName, @@ -165,9 +168,10 @@ impl From<&Error> for Option { use EIoDisconnectReason::*; match value { Error::SocketGone(_) => Some(TransportClose), - Error::Serialize(_) | Error::InvalidPacketType | Error::InvalidEventName => { - Some(PacketParsingError) - } + Error::Serialize(_) + | Error::InvalidPacketType + | Error::InvalidEventName + | Error::InvalidPayloadCount => Some(PacketParsingError), Error::Adapter(_) | Error::InvalidNamespace => None, } } diff --git a/socketioxide/src/handler/extract.rs b/socketioxide/src/handler/extract.rs index 7ee9c16e..594b4bc6 100644 --- a/socketioxide/src/handler/extract.rs +++ b/socketioxide/src/handler/extract.rs @@ -25,6 +25,7 @@ //! # use socketioxide::handler::{FromConnectParts, FromMessageParts}; //! # use socketioxide::adapter::Adapter; //! # use socketioxide::socket::Socket; +//! # use socketioxide::PayloadValue; //! # use std::sync::Arc; //! # use std::convert::Infallible; //! # use socketioxide::SocketIo; @@ -61,8 +62,7 @@ //! //! fn from_message_parts( //! s: &Arc>, -//! _: &mut serde_json::Value, -//! _: &mut Vec>, +//! _: &mut PayloadValue, //! _: &Option, //! ) -> Result { //! // In a real app it would be better to parse the query params with a crate like `url` @@ -89,6 +89,7 @@ use super::message::FromMessageParts; use super::FromDisconnectParts; use super::{connect::FromConnectParts, message::FromMessage}; use crate::errors::{DisconnectError, SendError}; +use crate::payload_value::PayloadValue; use crate::socket::DisconnectReason; use crate::{ adapter::{Adapter, LocalAdapter}, @@ -96,16 +97,15 @@ use crate::{ socket::Socket, }; use serde::{de::DeserializeOwned, Serialize}; -use serde_json::Value; #[cfg(feature = "state")] #[cfg_attr(docsrs, doc(cfg(feature = "state")))] pub use state_extract::*; /// Utility function to unwrap an array with a single element -fn upwrap_array(v: &mut Value) { +fn upwrap_array(v: &mut PayloadValue) { match v { - Value::Array(vec) if vec.len() == 1 => { + PayloadValue::Array(vec) if vec.len() == 1 => { *v = vec.pop().unwrap(); } _ => (), @@ -137,12 +137,11 @@ where type Error = serde_json::Error; fn from_message_parts( _: &Arc>, - v: &mut serde_json::Value, - _: &mut Vec>, + v: &mut PayloadValue, _: &Option, ) -> Result { upwrap_array(v); - serde_json::from_value(v.clone()).map(Data) + v.clone().into_data::().map(Data) } } @@ -171,12 +170,11 @@ where type Error = Infallible; fn from_message_parts( _: &Arc>, - v: &mut serde_json::Value, - _: &mut Vec>, + v: &mut PayloadValue, _: &Option, ) -> Result { upwrap_array(v); - Ok(TryData(serde_json::from_value(v.clone()))) + Ok(TryData(v.clone().into_data::())) } } /// An Extractor that returns a reference to a [`Socket`]. @@ -193,8 +191,7 @@ impl FromMessageParts for SocketRef { type Error = Infallible; fn from_message_parts( s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec>, + _: &mut PayloadValue, _: &Option, ) -> Result { Ok(SocketRef(s.clone())) @@ -244,11 +241,10 @@ impl FromMessage for Bin { type Error = Infallible; fn from_message( _: Arc>, - _: serde_json::Value, - bin: Vec>, + mut v: PayloadValue, _: Option, ) -> Result { - Ok(Bin(bin)) + Ok(Bin(v.extract_binary_payloads())) } } @@ -256,7 +252,6 @@ impl FromMessage for Bin { /// If the client sent a normal message without expecting an ack, the ack callback will do nothing. #[derive(Debug)] pub struct AckSender { - binary: Vec>, socket: Arc>, ack_id: Option, } @@ -264,8 +259,7 @@ impl FromMessageParts for AckSender { type Error = Infallible; fn from_message_parts( s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec>, + _: &mut PayloadValue, ack_id: &Option, ) -> Result { Ok(Self::new(s.clone(), *ack_id)) @@ -273,24 +267,16 @@ impl FromMessageParts for AckSender { } impl AckSender { pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { - Self { - binary: vec![], - socket, - ack_id, - } - } - - /// Add binary data to the ack response. - pub fn bin(mut self, bin: Vec>) -> Self { - self.binary = bin; - self + Self { socket, ack_id } } /// Send the ack response to the client. - pub fn send(self, data: T) -> Result<(), SendError> { + pub fn send(self, data: T) -> Result<(), SendError> { use crate::socket::PermitIteratorExt; if let Some(ack_id) = self.ack_id { - let permits = match self.socket.reserve(1 + self.binary.len()) { + let data = PayloadValue::from_data(data)?; + let payload_count = data.count_payloads(); + let permits = match self.socket.reserve(1 + payload_count) { Ok(permits) => permits, Err(e) => { #[cfg(feature = "tracing")] @@ -299,11 +285,10 @@ impl AckSender { } }; let ns = self.socket.ns(); - let data = serde_json::to_value(data)?; - let packet = if self.binary.is_empty() { + let packet = if payload_count == 0 { Packet::ack(ns, data, ack_id) } else { - Packet::bin_ack(ns, data, self.binary, ack_id) + Packet::bin_ack(ns, data, ack_id) }; permits.emit(packet); Ok(()) @@ -323,8 +308,7 @@ impl FromMessageParts for crate::ProtocolVersion { type Error = Infallible; fn from_message_parts( s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec>, + _: &mut PayloadValue, _: &Option, ) -> Result { Ok(s.protocol()) @@ -347,8 +331,7 @@ impl FromMessageParts for crate::TransportType { type Error = Infallible; fn from_message_parts( s: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec>, + _: &mut PayloadValue, _: &Option, ) -> Result { Ok(s.transport_type()) @@ -442,8 +425,7 @@ mod state_extract { type Error = StateNotFound; fn from_message_parts( _: &Arc>, - _: &mut serde_json::Value, - _: &mut Vec>, + _: &mut PayloadValue, _: &Option, ) -> Result { get_state::().map(State).ok_or(StateNotFound) diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index a4501e14..8f37092b 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -60,7 +60,7 @@ //! # use serde_json::Error; //! # use socketioxide::extract::*; //! // async named event handler -//! async fn on_event(s: SocketRef, Data(data): Data, ack: AckSender) { +//! async fn on_event(s: SocketRef, Data(data): Data, ack: AckSender) { //! tokio::time::sleep(std::time::Duration::from_secs(1)).await; //! ack.send("Here is my acknowledgment!").ok(); //! } @@ -74,9 +74,9 @@ use std::sync::Arc; use futures::Future; -use serde_json::Value; use crate::adapter::Adapter; +use crate::payload_value::PayloadValue; use crate::socket::Socket; use super::MakeErasedHandler; @@ -85,7 +85,7 @@ use super::MakeErasedHandler; pub(crate) type BoxedMessageHandler = Box>; pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option); + fn call(&self, s: Arc>, v: PayloadValue, ack_id: Option); } /// Define a handler for the connect event. @@ -101,7 +101,7 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { )] pub trait MessageHandler: Send + Sync + 'static { /// Call the handler with the given arguments - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option); + fn call(&self, s: Arc>, v: PayloadValue, ack_id: Option); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -127,8 +127,8 @@ where A: Adapter, { #[inline(always)] - fn call(&self, s: Arc>, v: Value, p: Vec>, ack_id: Option) { - self.handler.call(s, v, p, ack_id); + fn call(&self, s: Arc>, v: PayloadValue, ack_id: Option) { + self.handler.call(s, v, ack_id); } } @@ -164,8 +164,7 @@ pub trait FromMessageParts: Sized { /// If it fails, the handler is not called. fn from_message_parts( s: &Arc>, - v: &mut Value, - p: &mut Vec>, + v: &mut PayloadValue, ack_id: &Option, ) -> Result; } @@ -189,8 +188,7 @@ pub trait FromMessage: Sized { /// If it fails, the handler is not called fn from_message( s: Arc>, - v: Value, - p: Vec>, + v: PayloadValue, ack_id: Option, ) -> Result; } @@ -204,11 +202,10 @@ where type Error = T::Error; fn from_message( s: Arc>, - mut v: Value, - mut p: Vec>, + mut v: PayloadValue, ack_id: Option, ) -> Result { - Self::from_message_parts(&s, &mut v, &mut p, &ack_id) + Self::from_message_parts(&s, &mut v, &ack_id) } } @@ -219,7 +216,7 @@ where Fut: Future + Send + 'static, A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option) { + fn call(&self, _: Arc>, _: PayloadValue, _: Option) { let fut = (self.clone())(); tokio::spawn(fut); } @@ -231,7 +228,7 @@ where F: FnOnce() + Send + Sync + Clone + 'static, A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec>, _: Option) { + fn call(&self, _: Arc>, _: PayloadValue, _: Option) { (self.clone())(); } } @@ -249,9 +246,9 @@ macro_rules! impl_async_handler { $( $ty: FromMessageParts + Send, )* $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec>, ack_id: Option) { + fn call(&self, s: Arc>, mut v: PayloadValue, ack_id: Option) { $( - let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { + let $ty = match $ty::from_message_parts(&s, &mut v, &ack_id) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -260,7 +257,7 @@ macro_rules! impl_async_handler { }, }; )* - let last = match $last::from_message(s, v, p, ack_id) { + let last = match $last::from_message(s, v, ack_id) { Ok(v) => v, Err(_e) => { #[cfg(feature = "tracing")] @@ -287,14 +284,14 @@ macro_rules! impl_handler { $( $ty: FromMessageParts + Send, )* $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec>, ack_id: Option) { + fn call(&self, s: Arc>, mut v: PayloadValue, ack_id: Option) { $( - let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { + let $ty = match $ty::from_message_parts(&s, &mut v, &ack_id) { Ok(v) => v, Err(_) => return, }; )* - let last = match $last::from_message(s, v, p, ack_id) { + let last = match $last::from_message(s, v, ack_id) { Ok(v) => v, Err(_) => return, }; diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 21dfe6e7..08712c81 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -535,34 +535,6 @@ impl SocketIo { self.get_default_op().timeout(timeout) } - /// Adds a binary payload to the message. - /// - /// Alias for `io.of("/").unwrap().bin(binary_payload)` - /// - /// ## Panics - /// If the **default namespace "/" is not found** this fn will panic! - /// - /// ## Example - /// ``` - /// # use socketioxide::{SocketIo, extract::SocketRef}; - /// # use serde_json::Value; - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// println!("Socket connected on / namespace with id: {}", socket.id); - /// }); - /// - /// // Later in your code you can emit a test message on the root namespace in the room1 and room3 rooms, - /// // except for the room2 with a binary payload - /// io.to("room1") - /// .to("room3") - /// .except("room2") - /// .bin(vec![vec![1, 2, 3, 4]]) - /// .emit("test", ()); - #[inline] - pub fn bin(&self, binary: Vec>) -> BroadcastOperators { - self.get_default_op().bin(binary) - } - /// Emits a message to all sockets selected with the previous operators. /// /// Alias for `io.of("/").unwrap().emit(event, data)` diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index 6368ae6c..55b5cf8b 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -199,7 +199,6 @@ //! * rooms: emit, join, leave to specific rooms //! * namespace: emit to a specific namespace (only from the [`SocketIo`] handle) //! * timeout: set a custom timeout when waiting for an ack -//! * binary: emit a binary payload with the message //! * local: broadcast only to the current node (in case of a cluster) //! //! Check the [`operators`] module doc for more details on operators. @@ -271,6 +270,7 @@ pub mod packet; pub mod service; pub mod socket; +pub use bytes::Bytes; pub use engineioxide::TransportType; pub use errors::{AckError, AdapterError, BroadcastError, DisconnectError, SendError, SocketError}; pub use handler::extract; diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index b3c35ec7..a24a37b5 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -15,6 +15,7 @@ use crate::ack::{AckInnerStream, AckStream}; use crate::adapter::LocalAdapter; use crate::errors::{BroadcastError, DisconnectError}; use crate::extract::SocketRef; +use crate::payload_value::PayloadValue; use crate::socket::Socket; use crate::SendError; use crate::{ @@ -103,13 +104,11 @@ impl RoomParam for Sid { /// Chainable operators to configure the message to be sent. pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { - binary: Vec>, timeout: Option, socket: &'a Socket, } /// Chainable operators to select sockets to send a message to and to configure the message to be sent. pub struct BroadcastOperators { - binary: Vec>, timeout: Option, ns: Arc>, opts: BroadcastOptions, @@ -122,7 +121,6 @@ impl From> for BroadcastOperators { ..Default::default() }; Self { - binary: conf.binary, timeout: conf.timeout, ns: conf.socket.ns.clone(), opts, @@ -134,7 +132,6 @@ impl From> for BroadcastOperators { impl<'a, A: Adapter> ConfOperators<'a, A> { pub(crate) fn new(sender: &'a Socket) -> Self { Self { - binary: vec![], timeout: None, socket: sender, } @@ -283,23 +280,6 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { self.timeout = Some(timeout); self } - - /// Adds a binary payload to the message. - /// #### Example - /// ``` - /// # use socketioxide::{SocketIo, extract::*}; - /// # use serde_json::Value; - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// socket.on("test", |socket: SocketRef, Data::(data), Bin(bin)| async move { - /// // This will send the binary payload received to all sockets in this namespace with the test message - /// socket.bin(bin).emit("test", data); - /// }); - /// }); - pub fn bin(mut self, binary: Vec>) -> Self { - self.binary = binary; - self - } } // ==== impl ConfOperators consume fns ==== @@ -344,9 +324,10 @@ impl ConfOperators<'_, A> { mut self, event: impl Into>, data: T, - ) -> Result<(), SendError> { + ) -> Result<(), SendError> { use crate::socket::PermitIteratorExt; - let permits = match self.socket.reserve(1 + self.binary.len()) { + let data = PayloadValue::from_data(data)?; + let permits = match self.socket.reserve(1 + data.count_payloads()) { Ok(permits) => permits, Err(e) => { #[cfg(feature = "tracing")] @@ -415,8 +396,10 @@ impl ConfOperators<'_, A> { mut self, event: impl Into>, data: T, - ) -> Result, SendError> { - let permits = match self.socket.reserve(1 + self.binary.len()) { + ) -> Result, SendError> { + let data = PayloadValue::from_data(data)?; + let payload_count = data.count_payloads(); + let permits = match self.socket.reserve(1 + payload_count) { Ok(permits) => permits, Err(e) => { #[cfg(feature = "tracing")] @@ -475,12 +458,11 @@ impl ConfOperators<'_, A> { data: impl serde::Serialize, ) -> Result, serde_json::Error> { let ns = self.socket.ns.path.clone(); - let data = serde_json::to_value(data)?; - let packet = if self.binary.is_empty() { + let data = PayloadValue::from_data(data)?; + let packet = if !data.has_binary() { Packet::event(ns, event.into(), data) } else { - let binary = std::mem::take(&mut self.binary); - Packet::bin_event(ns, event.into(), data, binary) + Packet::bin_event(ns, event.into(), data) }; Ok(packet) } @@ -489,7 +471,6 @@ impl ConfOperators<'_, A> { impl BroadcastOperators { pub(crate) fn new(ns: Arc>) -> Self { Self { - binary: vec![], timeout: None, ns, opts: BroadcastOptions::default(), @@ -497,7 +478,6 @@ impl BroadcastOperators { } pub(crate) fn from_sock(ns: Arc>, sid: Sid) -> Self { Self { - binary: vec![], timeout: None, ns, opts: BroadcastOptions { @@ -655,23 +635,6 @@ impl BroadcastOperators { self.timeout = Some(timeout); self } - - /// Adds a binary payload to the message. - /// #### Example - /// ``` - /// # use socketioxide::{SocketIo, extract::*}; - /// # use serde_json::Value; - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// socket.on("test", |socket: SocketRef, Data::(data), Bin(bin)| async move { - /// // This will send the binary payload received to all sockets in this namespace with the test message - /// socket.bin(bin).emit("test", data); - /// }); - /// }); - pub fn bin(mut self, binary: Vec>) -> Self { - self.binary = binary; - self - } } // ==== impl BroadcastOperators consume fns ==== @@ -886,12 +849,11 @@ impl BroadcastOperators { data: impl serde::Serialize, ) -> Result, serde_json::Error> { let ns = self.ns.path.clone(); - let data = serde_json::to_value(data)?; - let packet = if self.binary.is_empty() { + let data = PayloadValue::from_data(data)?; + let packet = if !data.has_binary() { Packet::event(ns, event.into(), data) } else { - let binary = std::mem::take(&mut self.binary); - Packet::bin_event(ns, event.into(), data, binary) + Packet::bin_event(ns, event.into(), data) }; Ok(packet) } diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 8c8b45c3..79748222 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -3,9 +3,9 @@ //! It should not be used directly except when implementing the [`Adapter`](crate::adapter::Adapter) trait. use std::borrow::Cow; -use crate::ProtocolVersion; +use crate::{payload_value::PayloadValue, ProtocolVersion}; +use bytes::Bytes; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::{json, Value}; use crate::errors::Error; use engineioxide::sid::Sid; @@ -78,7 +78,11 @@ impl<'a> Packet<'a> { } /// Create an event packet for the given namespace - pub fn event(ns: impl Into>, e: impl Into>, data: Value) -> Self { + pub fn event( + ns: impl Into>, + e: impl Into>, + data: PayloadValue, + ) -> Self { Self { inner: PacketData::Event(e.into(), data, None), ns: ns.into(), @@ -89,12 +93,9 @@ impl<'a> Packet<'a> { pub fn bin_event( ns: impl Into>, e: impl Into>, - data: Value, - bin: Vec>, + data: PayloadValue, ) -> Self { - debug_assert!(!bin.is_empty()); - - let packet = BinaryPacket::outgoing(data, bin); + let packet = BinaryPacket::outgoing(data); Self { inner: PacketData::BinaryEvent(e.into(), packet, None), ns: ns.into(), @@ -102,7 +103,7 @@ impl<'a> Packet<'a> { } /// Create an ack packet for the given namespace - pub fn ack(ns: &'a str, data: Value, ack: i64) -> Self { + pub fn ack(ns: &'a str, data: PayloadValue, ack: i64) -> Self { Self { inner: PacketData::EventAck(data, ack), ns: Cow::Borrowed(ns), @@ -110,9 +111,8 @@ impl<'a> Packet<'a> { } /// Create a binary ack packet for the given namespace - pub fn bin_ack(ns: &'a str, data: Value, bin: Vec>, ack: i64) -> Self { - debug_assert!(!bin.is_empty()); - let packet = BinaryPacket::outgoing(data, bin); + pub fn bin_ack(ns: &'a str, data: PayloadValue, ack: i64) -> Self { + let packet = BinaryPacket::outgoing(data); Self { inner: PacketData::BinaryAck(packet, ack), ns: Cow::Borrowed(ns), @@ -184,9 +184,9 @@ pub enum PacketData<'a> { /// Disconnect packet, used to disconnect from a namespace Disconnect, /// Event packet with optional ack id, to request an ack from the other side - Event(Cow<'a, str>, Value, Option), + Event(Cow<'a, str>, PayloadValue, Option), /// Event ack packet, to acknowledge an event - EventAck(Value, i64), + EventAck(PayloadValue, i64), /// Connect error packet, sent when the namespace is invalid ConnectError, /// Binary event packet with optional ack id, to request an ack from the other side @@ -199,11 +199,11 @@ pub enum PacketData<'a> { #[derive(Debug, Clone, PartialEq, Eq)] pub struct BinaryPacket { /// Data related to the packet - pub data: Value, - /// Binary payload - pub bin: Vec>, + pub(crate) data: PayloadValue, /// The number of expected payloads (used when receiving data) payload_count: usize, + /// A place to receive binary payloads until the packet is complete + payloads: Vec, } impl<'a> PacketData<'a> { @@ -248,63 +248,88 @@ impl<'a> PacketData<'a> { } impl BinaryPacket { - /// Create a binary packet from incoming data and remove all placeholders and get the payload count - pub fn incoming(mut data: Value) -> Self { - let payload_count = match &mut data { - Value::Array(ref mut v) => { - let count = v.len(); - v.retain(|v| v.as_object().and_then(|o| o.get("_placeholder")).is_none()); - count - v.len() - } - val => { - if val - .as_object() - .and_then(|o| o.get("_placeholder")) - .is_some() - { - data = Value::Array(vec![]); - 1 - } else { - 0 - } - } - }; + /// Create a binary packet from incoming data + fn incoming(data: PayloadValue, payload_count: usize) -> Self { + let actual_payload_count = data.count_payloads(); + if payload_count != actual_payload_count { + #[cfg(feature = "tracing")] + tracing::warn!( + "Binary packet header claimed {} payloads, but found {} placeholders", + payload_count, + actual_payload_count + ); + } Self { data, - bin: Vec::new(), - payload_count, + payload_count: actual_payload_count, + payloads: vec![], } } - /// Create a binary packet from outgoing data and a payload - pub fn outgoing(data: Value, bin: Vec>) -> Self { - let mut data = match data { - Value::Array(v) => Value::Array(v), - d => Value::Array(vec![d]), + /// Create a binary packet from outgoing data + fn outgoing(data: PayloadValue) -> Self { + let data = match data { + arr @ PayloadValue::Array(_) => arr, + d => PayloadValue::Array(vec![d]), }; - let payload_count = bin.len(); - (0..payload_count).for_each(|i| { - data.as_array_mut().unwrap().push(json!({ - "_placeholder": true, - "num": i - })) - }); + let payload_count = data.count_payloads(); + Self { data, - bin, payload_count, + payloads: vec![], } } /// Add a payload to the binary packet, when all payloads are added, /// the packet is complete and can be further processed - pub fn add_payload(&mut self, payload: Vec) { - self.bin.push(payload); + pub fn add_payload(&mut self, payload: Bytes) { + if self.is_complete() { + #[cfg(feature = "tracing")] + tracing::warn!("Attempt to add payload to already-complete binary packet"); + } else { + self.payloads.push(payload); + + if self.is_complete() { + integrate_payloads(&mut self.data, &mut self.payloads); + } + } } + /// Check if the binary packet is complete, it means that all payloads have been received pub fn is_complete(&self) -> bool { - self.payload_count == self.bin.len() + self.payload_count == self.payloads.len() + } + + pub(crate) fn extract_payloads(&mut self) -> Vec { + self.data.get_binary_payloads() + } +} + +fn integrate_payloads(data: &mut PayloadValue, payloads: &mut Vec) { + match data { + PayloadValue::Binary(n, ref mut data) => { + if let Some(payload) = payloads.get_mut(*n) { + std::mem::swap(payload, data); + } else { + #[cfg(feature = "tracing")] + tracing::warn!( + "Binary packet structure included placeholder with num out of range" + ); + } + } + PayloadValue::Object(o) => { + for value in o.values_mut() { + integrate_payloads(value, payloads); + } + } + PayloadValue::Array(a) => { + for value in a.iter_mut() { + integrate_payloads(value, payloads); + } + } + _ => (), } } @@ -318,12 +343,12 @@ impl<'a> From> for String { Event(e, data, _) | BinaryEvent(e, BinaryPacket { data, .. }, _) => { // Expand the packet if it is an array with data -> ["event", ...data] let packet = match data { - Value::Array(ref mut v) if !v.is_empty() => { - v.insert(0, Value::String((*e).to_string())); - serde_json::to_string(&v) + PayloadValue::Array(ref mut v) if !v.is_empty() => { + v.insert(0, PayloadValue::String((*e).to_string())); + data.to_json_string() } - Value::Array(_) => serde_json::to_string::<(_, [(); 0])>(&(e, [])), - _ => serde_json::to_string(&(e, data)), + PayloadValue::Array(_) => serde_json::to_string::<(_, [(); 0])>(&(e, [])), + _ => serde_json::to_string(&(e, data.as_json())), } .unwrap(); Some(packet) @@ -331,9 +356,9 @@ impl<'a> From> for String { EventAck(data, _) | BinaryAck(BinaryPacket { data, .. }, _) => { // Enforce that the packet is an array -> [data] let packet = match data { - Value::Array(_) => serde_json::to_string(&data), - Value::Null => Ok("[]".to_string()), - _ => serde_json::to_string(&[data]), + PayloadValue::Array(_) => data.to_json_string(), + PayloadValue::Null => Ok("[]".to_string()), + _ => serde_json::to_string(&[data.as_json()]), } .unwrap(); Some(packet) @@ -409,11 +434,11 @@ impl<'a> From> for String { /// ```text /// ["", ...] /// ``` -fn deserialize_event_packet(data: &str) -> Result<(String, Value), Error> { +fn deserialize_event_packet(data: &str) -> Result<(String, PayloadValue), Error> { #[cfg(feature = "tracing")] tracing::debug!("Deserializing event packet: {:?}", data); - let packet = match serde_json::from_str::(data)? { - Value::Array(packet) => packet, + let packet = match serde_json::from_str::(data)? { + PayloadValue::Array(packet) => packet, _ => return Err(Error::InvalidEventName), }; @@ -423,7 +448,7 @@ fn deserialize_event_packet(data: &str) -> Result<(String, Value), Error> { .as_str() .ok_or(Error::InvalidEventName)? .to_string(); - let payload = Value::from_iter(packet.into_iter().skip(1)); + let payload = PayloadValue::from_iter(packet.into_iter().skip(1)); Ok((event, payload)) } @@ -457,12 +482,18 @@ impl<'a> TryFrom for Packet<'a> { .ok_or(Error::InvalidPacketType)?; // Move the cursor to skip the payload count if it is a binary packet - if index == b'5' || index == b'6' { + let payload_count = if index == b'5' || index == b'6' { while chars.get(i) != Some(&b'-') { i += 1; } i += 1; - } + + std::str::from_utf8(&chars[1..(i - 1)]) + .map_err(|_| Error::InvalidPayloadCount) + .and_then(|s| s.parse::().map_err(|_| Error::InvalidPayloadCount))? + } else { + 0 + }; let start_index = i; // Custom nsps will start with a slash @@ -509,12 +540,16 @@ impl<'a> TryFrom for Packet<'a> { } b'5' => { let (event, payload) = deserialize_event_packet(data)?; - PacketData::BinaryEvent(event.into(), BinaryPacket::incoming(payload), ack) + PacketData::BinaryEvent( + event.into(), + BinaryPacket::incoming(payload, payload_count), + ack, + ) } b'6' => { let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; PacketData::BinaryAck( - BinaryPacket::incoming(packet), + BinaryPacket::incoming(packet, payload_count), ack.ok_or(Error::InvalidPacketType)?, ) } @@ -537,6 +572,21 @@ mod test { use super::*; + fn packet_data_map(key: &'static str, value: &'static str) -> PayloadValue { + PayloadValue::Object( + [(key.to_string(), PayloadValue::String(value.to_string()))] + .into_iter() + .collect(), + ) + } + + fn wrapped_packet_data_map(key: &'static str, value: &'static str) -> PayloadValue { + PayloadValue::Array(vec![ + packet_data_map(key, value), + PayloadValue::Binary(0, Bytes::from_static(&[1])), + ]) + } + #[test] fn packet_decode_connect() { let sid = Sid::new(); @@ -598,7 +648,11 @@ mod test { let packet = Packet::try_from(payload).unwrap(); assert_eq!( - Packet::event("/", "event", json!([{"data": "value"}])), + Packet::event( + "/", + "event", + PayloadValue::from_data(json!([{"data": "value"}])).unwrap() + ), packet ); @@ -606,7 +660,11 @@ mod test { let payload = format!("21{}", json!(["event", { "data": "value" }])); let packet = Packet::try_from(payload).unwrap(); - let mut comparison_packet = Packet::event("/", "event", json!([{"data": "value"}])); + let mut comparison_packet = Packet::event( + "/", + "event", + PayloadValue::from_data(json!([{"data": "value"}])).unwrap(), + ); comparison_packet.inner.set_ack_id(1); assert_eq!(packet, comparison_packet); @@ -615,7 +673,11 @@ mod test { let packet = Packet::try_from(payload).unwrap(); assert_eq!( - Packet::event("/admin™", "event", json!([{"data": "value™"}])), + Packet::event( + "/admin™", + "event", + PayloadValue::from_data(json!([{"data": "value™"}])).unwrap() + ), packet ); @@ -624,7 +686,11 @@ mod test { let mut packet = Packet::try_from(payload).unwrap(); packet.inner.set_ack_id(1); - let mut comparison_packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); + let mut comparison_packet = Packet::event( + "/admin™", + "event", + PayloadValue::from_data(json!([{"data": "value™"}])).unwrap(), + ); comparison_packet.inner.set_ack_id(1); assert_eq!(packet, comparison_packet); @@ -633,21 +699,32 @@ mod test { #[test] fn packet_encode_event() { let payload = format!("2{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event("/", "event", json!({ "data": "value™" })) - .try_into() - .unwrap(); + let packet: String = Packet::event( + "/", + "event", + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode empty data let payload = format!("2{}", json!(["event", []])); - let packet: String = Packet::event("/", "event", json!([])).try_into().unwrap(); + let packet: String = + Packet::event("/", "event", PayloadValue::from_data(json!([])).unwrap()) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with ack ID let payload = format!("21{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event("/", "event", json!({ "data": "value™" })); + let mut packet = Packet::event( + "/", + "event", + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), + ); packet.inner.set_ack_id(1); let packet: String = packet.try_into().unwrap(); @@ -655,15 +732,23 @@ mod test { // Encode with NS let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event("/admin™", "event", json!({"data": "value™"})) - .try_into() - .unwrap(); + let packet: String = Packet::event( + "/admin™", + "event", + PayloadValue::from_data(json!({"data": "value™"})).unwrap(), + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with NS and ack ID let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); + let mut packet = Packet::event( + "/admin™", + "event", + PayloadValue::from_data(json!([{"data": "value™"}])).unwrap(), + ); packet.inner.set_ack_id(1); let packet: String = packet.try_into().unwrap(); assert_eq!(packet, payload); @@ -675,24 +760,40 @@ mod test { let payload = "354[\"data\"]".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::ack("/", json!(["data"]), 54), packet); + assert_eq!( + Packet::ack("/", PayloadValue::from_data(json!(["data"])).unwrap(), 54), + packet + ); let payload = "3/admin™,54[\"data\"]".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::ack("/admin™", json!(["data"]), 54), packet); + assert_eq!( + Packet::ack( + "/admin™", + PayloadValue::from_data(json!(["data"])).unwrap(), + 54 + ), + packet + ); } #[test] fn packet_encode_event_ack() { let payload = "354[\"data\"]".to_string(); - let packet: String = Packet::ack("/", json!("data"), 54).try_into().unwrap(); + let packet: String = Packet::ack("/", PayloadValue::from_data(json!("data")).unwrap(), 54) + .try_into() + .unwrap(); assert_eq!(packet, payload); let payload = "3/admin™,54[\"data\"]".to_string(); - let packet: String = Packet::ack("/admin™", json!("data"), 54) - .try_into() - .unwrap(); + let packet: String = Packet::ack( + "/admin™", + PayloadValue::from_data(json!("data")).unwrap(), + 54, + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); } @@ -713,17 +814,29 @@ mod test { let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); let payload = format!("51-{}", json); - let packet: String = - Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]) - .try_into() - .unwrap(); + let packet: String = Packet::bin_event( + "/", + "event", + PayloadValue::from_data( + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with ack ID let payload = format!("51-254{}", json); - let mut packet = - Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]); + let mut packet = Packet::bin_event( + "/", + "event", + PayloadValue::from_data( + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), + ); packet.inner.set_ack_id(254); let packet: String = packet.try_into().unwrap(); @@ -734,8 +847,10 @@ mod test { let packet: String = Packet::bin_event( "/admin™", "event", - json!([{"data": "value™"}]), - vec![vec![1]], + PayloadValue::from_data( + json!([{"data": "value™"}, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), ) .try_into() .unwrap(); @@ -747,8 +862,10 @@ mod test { let mut packet = Packet::bin_event( "/admin™", "event", - json!([{"data": "value™"}]), - vec![vec![1]], + PayloadValue::from_data( + json!([{"data": "value™"}, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), ); packet.inner.set_ack_id(254); let packet: String = packet.try_into().unwrap(); @@ -757,14 +874,16 @@ mod test { #[test] fn packet_decode_binary_event() { + let binary_payload = Bytes::from_static(&[1]); + let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); let comparison_packet = |ack, ns: &'static str| Packet { inner: PacketData::BinaryEvent( "event".into(), BinaryPacket { - bin: vec![vec![1]], - data: json!([{"data": "value™"}]), + data: wrapped_packet_data_map("data", "value™"), payload_count: 1, + payloads: vec![Bytes::new()], }, ack, ), @@ -774,7 +893,7 @@ mod test { let payload = format!("51-{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } @@ -784,7 +903,7 @@ mod test { let payload = format!("51-254{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } @@ -794,7 +913,7 @@ mod test { let payload = format!("51-/admin™,{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } @@ -804,7 +923,7 @@ mod test { let payload = format!("51-/admin™,254{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryEvent(_, ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } assert_eq!(packet, comparison_packet(Some(254), "/admin™")); @@ -816,31 +935,46 @@ mod test { let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); let payload = format!("61-54{}", json); - let packet: String = Packet::bin_ack("/", json!({ "data": "value™" }), vec![vec![1]], 54) - .try_into() - .unwrap(); + let packet: String = Packet::bin_ack( + "/", + PayloadValue::from_data( + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), + 54, + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with NS let payload = format!("61-/admin™,54{}", json); - let packet: String = - Packet::bin_ack("/admin™", json!({ "data": "value™" }), vec![vec![1]], 54) - .try_into() - .unwrap(); + let packet: String = Packet::bin_ack( + "/admin™", + PayloadValue::from_data( + json!([{ "data": "value™" }, { "_placeholder": true, "num": 0 }]), + ) + .unwrap(), + 54, + ) + .try_into() + .unwrap(); assert_eq!(packet, payload); } #[test] fn packet_decode_binary_ack() { + let binary_payload = Bytes::from_static(&[1]); + let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); let comparison_packet = |ack, ns: &'static str| Packet { inner: PacketData::BinaryAck( BinaryPacket { - bin: vec![vec![1]], - data: json!([{"data": "value™"}]), + data: wrapped_packet_data_map("data", "value™"), payload_count: 1, + payloads: vec![Bytes::new()], }, ack, ), @@ -850,7 +984,7 @@ mod test { let payload = format!("61-54{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryAck(ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryAck(ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } @@ -860,7 +994,7 @@ mod test { let payload = format!("61-/admin™,54{}", json); let mut packet = Packet::try_from(payload).unwrap(); match packet.inner { - PacketData::BinaryAck(ref mut bin, _) => bin.add_payload(vec![1]), + PacketData::BinaryAck(ref mut bin, _) => bin.add_payload(binary_payload.clone()), _ => (), } @@ -886,30 +1020,45 @@ mod test { let packet = Packet::disconnect("/admin"); assert_eq!(packet.get_size_hint(), 8); - let packet = Packet::event("/", "event", json!({ "data": "value™" })); + let packet = Packet::event( + "/", + "event", + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), + ); assert_eq!(packet.get_size_hint(), 1); - let packet = Packet::event("/admin", "event", json!({ "data": "value™" })); + let packet = Packet::event( + "/admin", + "event", + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), + ); assert_eq!(packet.get_size_hint(), 8); - let packet = Packet::ack("/", json!("data"), 54); + let packet = Packet::ack("/", PayloadValue::from_data(json!("data")).unwrap(), 54); assert_eq!(packet.get_size_hint(), 3); - let packet = Packet::ack("/admin", json!("data"), 54); + let packet = Packet::ack( + "/admin", + PayloadValue::from_data(json!("data")).unwrap(), + 54, + ); assert_eq!(packet.get_size_hint(), 10); - let packet = Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]); + let packet = Packet::bin_event( + "/", + "event", + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), + ); assert_eq!(packet.get_size_hint(), 3); let packet = Packet::bin_event( "/admin", "event", - json!({ "data": "value™" }), - vec![vec![1]], + PayloadValue::from_data(json!({ "data": "value™" })).unwrap(), ); assert_eq!(packet.get_size_hint(), 10); - let packet = Packet::bin_ack("/", json!("data"), vec![vec![1]], 54); + let packet = Packet::bin_ack("/", PayloadValue::from_data(json!("data")).unwrap(), 54); assert_eq!(packet.get_size_hint(), 5); } } diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index eac68f6d..6e5969a1 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -14,7 +14,6 @@ use std::{ use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Permit, PermitIterator}; use serde::{de::DeserializeOwned, Serialize}; -use serde_json::Value; use tokio::sync::oneshot::{self, Receiver}; #[cfg(feature = "extensions")] @@ -31,6 +30,7 @@ use crate::{ ns::Namespace, operators::{BroadcastOperators, ConfOperators, RoomParam}, packet::{BinaryPacket, Packet, PacketData}, + payload_value::PayloadValue, AckError, SocketIoConfig, }; use crate::{ @@ -107,9 +107,9 @@ pub(crate) trait PermitIteratorExt<'a>: fn emit(mut self, mut packet: Packet<'_>) { debug_assert!(self.len() > 0, "No permits available to send the message"); - let bin_payloads = match packet.inner { + let bin_payloads = match &mut packet.inner { PacketData::BinaryEvent(_, ref mut bin, _) | PacketData::BinaryAck(ref mut bin, _) => { - Some(std::mem::take(&mut bin.bin)) + Some(bin.extract_payloads()) } _ => None, }; @@ -123,7 +123,7 @@ pub(crate) trait PermitIteratorExt<'a>: "Not enough permits available to send the message with the binary payload" ); for bin in bin_payloads { - self.next().unwrap().emit_binary(bin); + self.next().unwrap().emit_binary(bin.into()); } } } @@ -138,7 +138,7 @@ pub struct Socket { pub(crate) ns: Arc>, message_handlers: RwLock, BoxedMessageHandler>>, disconnect_handler: Mutex>>, - ack_message: Mutex>>>, + ack_message: Mutex>>>, ack_counter: AtomicI64, /// The socket id pub id: Sid, @@ -312,8 +312,11 @@ impl Socket { &self, event: impl Into>, data: T, - ) -> Result<(), SendError> { - let permits = match self.reserve(1) { + ) -> Result<(), SendError> { + let data = PayloadValue::from_data(data)?; + let payload_count = data.count_payloads(); + + let permits = match self.reserve(1 + payload_count) { Ok(permits) => permits, Err(e) => { #[cfg(feature = "tracing")] @@ -323,8 +326,12 @@ impl Socket { }; let ns = self.ns(); - let data = serde_json::to_value(data)?; - permits.emit(Packet::event(ns, event.into(), data)); + let packet = if payload_count > 0 { + Packet::bin_event(ns, event.into(), data) + } else { + Packet::event(ns, event.into(), data) + }; + permits.emit(packet); Ok(()) } @@ -383,8 +390,11 @@ impl Socket { &self, event: impl Into>, data: T, - ) -> Result, SendError> { - let permits = match self.reserve(1) { + ) -> Result, SendError> { + let data = PayloadValue::from_data(data)?; + let payload_count = data.count_payloads(); + + let permits = match self.reserve(1 + payload_count) { Ok(permits) => permits, Err(e) => { #[cfg(feature = "tracing")] @@ -392,8 +402,13 @@ impl Socket { return Err(e.with_value(data).into()); } }; - let data = serde_json::to_value(data)?; - let packet = Packet::event(self.ns(), event.into(), data); + + let ns = self.ns(); + let packet = if payload_count > 0 { + Packet::bin_event(ns, event.into(), data) + } else { + Packet::event(ns, event.into(), data) + }; let rx = self.send_with_ack_permit(packet, permits); let stream = AckInnerStream::send(rx, self.config.ack_timeout, self.id); Ok(AckStream::::from(stream)) @@ -569,23 +584,6 @@ impl Socket { ConfOperators::new(self).timeout(timeout) } - /// Adds a binary payload to the message. - /// # Example - /// ``` - /// # use socketioxide::{SocketIo, extract::*}; - /// # use serde_json::Value; - /// # use std::sync::Arc; - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// socket.on("test", |socket: SocketRef, Data::(data), Bin(bin)| async move { - /// // This will send the binary payload received to all clients in this namespace with the test message - /// socket.bin(bin).emit("test", data); - /// }); - /// }); - pub fn bin(&self, binary: Vec>) -> ConfOperators<'_, A> { - ConfOperators::new(self).bin(binary) - } - /// Broadcasts to all clients without any filtering (except the current socket). /// # Example /// ``` @@ -647,7 +645,7 @@ impl Socket { &self, mut packet: Packet<'_>, permits: PermitIterator<'_>, - ) -> Receiver> { + ) -> Receiver> { let (tx, rx) = oneshot::channel(); let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1; @@ -657,7 +655,10 @@ impl Socket { rx } - pub(crate) fn send_with_ack(&self, mut packet: Packet<'_>) -> Receiver> { + pub(crate) fn send_with_ack( + &self, + mut packet: Packet<'_>, + ) -> Receiver> { let (tx, rx) = oneshot::channel(); let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1; @@ -735,9 +736,14 @@ impl Socket { self.esocket.protocol.into() } - fn recv_event(self: Arc, e: &str, data: Value, ack: Option) -> Result<(), Error> { + fn recv_event( + self: Arc, + e: &str, + data: PayloadValue, + ack: Option, + ) -> Result<(), Error> { if let Some(handler) = self.message_handlers.read().unwrap().get(e) { - handler.call(self.clone(), data, vec![], ack); + handler.call(self.clone(), data, ack); } Ok(()) } @@ -749,17 +755,14 @@ impl Socket { ack: Option, ) -> Result<(), Error> { if let Some(handler) = self.message_handlers.read().unwrap().get(e) { - handler.call(self.clone(), packet.data, packet.bin, ack); + handler.call(self.clone(), packet.data, ack); } Ok(()) } - fn recv_ack(self: Arc, data: Value, ack: i64) -> Result<(), Error> { + fn recv_ack(self: Arc, data: PayloadValue, ack: i64) -> Result<(), Error> { if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) { - let res = AckResponse { - data, - binary: vec![], - }; + let res = AckResponse { data }; tx.send(Ok(res)).ok(); } Ok(()) @@ -767,10 +770,7 @@ impl Socket { fn recv_bin_ack(self: Arc, packet: BinaryPacket, ack: i64) -> Result<(), Error> { if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) { - let res = AckResponse { - data: packet.data, - binary: packet.bin, - }; + let res = AckResponse { data: packet.data }; tx.send(Ok(res)).ok(); } Ok(()) @@ -818,11 +818,11 @@ mod test { // Saturate the channel for _ in 0..200 { socket - .send(Packet::event("test", "test", Value::Null)) + .send(Packet::event("test", "test", PayloadValue::Null)) .unwrap(); } - let ack = socket.emit_with_ack::<_, Value>("test", Value::Null); + let ack = socket.emit_with_ack::<_, PayloadValue>("test", PayloadValue::Null); assert!(matches!( ack, Err(SendError::Socket(SocketError::InternalChannelFull(_)))