From c15e6517b1e9c60689ce3287eb8b74d4bdd93c72 Mon Sep 17 00:00:00 2001 From: Victor Ermolaev <16148931+vnermolaev@users.noreply.github.com> Date: Mon, 30 Jan 2023 23:09:51 +0100 Subject: [PATCH] refactor(tcp): use SelectAll for driving listener streams (#3361) The PR optimizes polling of the listeners in the TCP transport by using `futures::SelectAll` instead of storing them in a queue and polling manually. Resolves #2781. --- transports/tcp/src/lib.rs | 258 +++++++++++++++++--------------------- 1 file changed, 118 insertions(+), 140 deletions(-) diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index 30a70e69b79..52d23d0a52d 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -39,6 +39,7 @@ pub use provider::tokio; use futures::{ future::{self, Ready}, prelude::*, + stream::SelectAll, }; use futures_timer::Delay; use if_watch::IfEvent; @@ -55,7 +56,7 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener}, pin::Pin, sync::{Arc, RwLock}, - task::{Context, Poll}, + task::{Context, Poll, Waker}, time::Duration, }; @@ -312,7 +313,7 @@ where /// All the active listeners. /// The [`ListenStream`] struct contains a stream that we want to be pinned. Since the `VecDeque` /// can be resized, the only way is to use a `Pin>`. - listeners: VecDeque>>>, + listeners: SelectAll>, /// Pending transport events to return from [`libp2p_core::Transport::poll`]. pending_events: VecDeque::ListenerUpgrade, io::Error>>, @@ -419,7 +420,7 @@ where Transport { port_reuse, config, - listeners: VecDeque::new(), + listeners: SelectAll::new(), pending_events: VecDeque::new(), } } @@ -447,18 +448,13 @@ where let listener = self .do_listen(id, socket_addr) .map_err(TransportError::Other)?; - self.listeners.push_back(Box::pin(listener)); + self.listeners.push(listener); Ok(id) } fn remove_listener(&mut self, id: ListenerId) -> bool { - if let Some(index) = self.listeners.iter().position(|l| l.listener_id == id) { - self.listeners.remove(index); - self.pending_events - .push_back(TransportEvent::ListenerClosed { - listener_id: id, - reason: Ok(()), - }); + if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) { + listener.close(Ok(())); true } else { false @@ -548,96 +544,14 @@ where if let Some(event) = self.pending_events.pop_front() { return Poll::Ready(event); } - // We remove each element from `listeners` one by one and add them back. - let mut remaining = self.listeners.len(); - while let Some(mut listener) = self.listeners.pop_back() { - match TryStream::try_poll_next(listener.as_mut(), cx) { - Poll::Pending => { - self.listeners.push_front(listener); - remaining -= 1; - if remaining == 0 { - break; - } - } - Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - }))) => { - let id = listener.listener_id; - self.listeners.push_front(listener); - return Poll::Ready(TransportEvent::Incoming { - listener_id: id, - upgrade, - local_addr, - send_back_addr: remote_addr, - }); - } - Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(a)))) => { - let id = listener.listener_id; - self.listeners.push_front(listener); - return Poll::Ready(TransportEvent::NewAddress { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(a)))) => { - let id = listener.listener_id; - self.listeners.push_front(listener); - return Poll::Ready(TransportEvent::AddressExpired { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(TcpListenerEvent::Error(error)))) => { - let id = listener.listener_id; - self.listeners.push_front(listener); - return Poll::Ready(TransportEvent::ListenerError { - listener_id: id, - error, - }); - } - Poll::Ready(None) => { - return Poll::Ready(TransportEvent::ListenerClosed { - listener_id: listener.listener_id, - reason: Ok(()), - }); - } - Poll::Ready(Some(Err(err))) => { - return Poll::Ready(TransportEvent::ListenerClosed { - listener_id: listener.listener_id, - reason: Err(err), - }); - } - } + + match self.listeners.poll_next_unpin(cx) { + Poll::Ready(Some(transport_event)) => Poll::Ready(transport_event), + _ => Poll::Pending, } - Poll::Pending } } -/// Event produced by a [`ListenStream`]. -#[derive(Debug)] -enum TcpListenerEvent { - /// The listener is listening on a new additional [`Multiaddr`]. - NewAddress(Multiaddr), - /// An upgrade, consisting of the upgrade future, the listener address and the remote address. - Upgrade { - /// The upgrade. - upgrade: Ready>, - /// The local address which produced this upgrade. - local_addr: Multiaddr, - /// The remote address which produced this upgrade. - remote_addr: Multiaddr, - }, - /// A [`Multiaddr`] is no longer used for listening. - AddressExpired(Multiaddr), - /// A non-fatal error has happened on the listener. - /// - /// This event should be generated in order to notify the user that something wrong has - /// happened. The listener, however, continues to run. - Error(io::Error), -} - /// A stream of incoming connections on one or more interfaces. struct ListenStream where @@ -669,6 +583,12 @@ where sleep_on_error: Duration, /// The current pause, if any. pause: Option, + /// Pending event to reported. + pending_event: Option<::Item>, + /// The listener can be manually closed with [`Transport::remove_listener`](libp2p_core::Transport::remove_listener). + is_closed: bool, + /// The stream must be awaken after it has been closed to deliver the last event. + close_listener_waker: Option, } impl ListenStream @@ -694,6 +614,9 @@ where if_watcher, pause: None, sleep_on_error: Duration::from_millis(100), + pending_event: None, + is_closed: false, + close_listener_waker: None, }) } @@ -716,6 +639,74 @@ where .unregister(self.listen_addr.ip(), self.listen_addr.port()), } } + + /// Close the listener. + /// + /// This will create a [`TransportEvent::ListenerClosed`] and + /// terminate the stream once the event has been reported. + fn close(&mut self, reason: Result<(), io::Error>) { + if self.is_closed { + return; + } + self.pending_event = Some(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + }); + self.is_closed = true; + + // Wake the stream to deliver the last event. + if let Some(waker) = self.close_listener_waker.take() { + waker.wake(); + } + } + + /// Poll for a next If Event. + fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<::Item> { + let if_watcher = match self.if_watcher.as_mut() { + Some(if_watcher) => if_watcher, + None => return Poll::Pending, + }; + + let my_listen_addr_port = self.listen_addr.port(); + + while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) { + match event { + Ok(IfEvent::Up(inet)) => { + let ip = inet.addr(); + if self.listen_addr.is_ipv4() == ip.is_ipv4() { + let ma = ip_to_multiaddr(ip, my_listen_addr_port); + log::debug!("New listen address: {}", ma); + self.port_reuse.register(ip, my_listen_addr_port); + return Poll::Ready(TransportEvent::NewAddress { + listener_id: self.listener_id, + listen_addr: ma, + }); + } + } + Ok(IfEvent::Down(inet)) => { + let ip = inet.addr(); + if self.listen_addr.is_ipv4() == ip.is_ipv4() { + let ma = ip_to_multiaddr(ip, my_listen_addr_port); + log::debug!("Expired listen address: {}", ma); + self.port_reuse.unregister(ip, my_listen_addr_port); + return Poll::Ready(TransportEvent::AddressExpired { + listener_id: self.listener_id, + listen_addr: ma, + }); + } + } + Err(error) => { + self.pause = Some(Delay::new(self.sleep_on_error)); + return Poll::Ready(TransportEvent::ListenerError { + listener_id: self.listener_id, + error, + }); + } + } + } + + Poll::Pending + } } impl Drop for ListenStream @@ -733,52 +724,34 @@ where T::Listener: Unpin, T::Stream: Unpin, { - type Item = Result, io::Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let me = Pin::into_inner(self); + type Item = TransportEvent>, io::Error>; - if let Some(mut pause) = me.pause.take() { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if let Some(mut pause) = self.pause.take() { match pause.poll_unpin(cx) { Poll::Ready(_) => {} Poll::Pending => { - me.pause = Some(pause); + self.pause = Some(pause); return Poll::Pending; } } } - if let Some(if_watcher) = me.if_watcher.as_mut() { - while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) { - match event { - Ok(IfEvent::Up(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("New listen address: {}", ma); - me.port_reuse.register(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(ma)))); - } - } - Ok(IfEvent::Down(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("Expired listen address: {}", ma); - me.port_reuse.unregister(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(ma)))); - } - } - Err(err) => { - me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(TcpListenerEvent::Error(err)))); - } - } - } + if let Some(event) = self.pending_event.take() { + return Poll::Ready(Some(event)); + } + + if self.is_closed { + // Terminate the stream if the listener closed and all remaining events have been reported. + return Poll::Ready(None); + } + + if let Poll::Ready(event) = self.poll_if_addr(cx) { + return Poll::Ready(Some(event)); } // Take the pending connection from the backlog. - match T::poll_accept(&mut me.listener, cx) { + match T::poll_accept(&mut self.listener, cx) { Poll::Ready(Ok(Incoming { local_addr, remote_addr, @@ -789,20 +762,25 @@ where log::debug!("Incoming connection from {} at {}", remote_addr, local_addr); - return Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade { + return Poll::Ready(Some(TransportEvent::Incoming { + listener_id: self.listener_id, upgrade: future::ok(stream), local_addr, - remote_addr, - }))); + send_back_addr: remote_addr, + })); } - Poll::Ready(Err(e)) => { + Poll::Ready(Err(error)) => { // These errors are non-fatal for the listener stream. - me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(TcpListenerEvent::Error(e)))); + self.pause = Some(Delay::new(self.sleep_on_error)); + return Poll::Ready(Some(TransportEvent::ListenerError { + listener_id: self.listener_id, + error, + })); } Poll::Pending => {} - }; + } + self.close_listener_waker = Some(cx.waker().clone()); Poll::Pending } } @@ -1119,7 +1097,7 @@ mod tests { match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { TransportEvent::NewAddress { .. } => { // Check that tcp and listener share the same port reuse SocketAddr - let listener = tcp.listeners.front().unwrap(); + let listener = tcp.listeners.iter().next().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip()); let port_reuse_listener = listener .port_reuse @@ -1188,7 +1166,7 @@ mod tests { TransportEvent::NewAddress { listen_addr: addr1, .. } => { - let listener1 = tcp.listeners.front().unwrap(); + let listener1 = tcp.listeners.iter().next().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip()); let port_reuse_listener1 = listener1