diff --git a/core/src/connection/pool.rs b/core/src/connection/pool.rs index 59d41908f3a..2f5f881b4b3 100644 --- a/core/src/connection/pool.rs +++ b/core/src/connection/pool.rs @@ -137,7 +137,7 @@ struct PendingConnectionInfo { handler: THandler, endpoint: PendingPoint, /// When dropped, notifies the task which then knows to terminate. - _drop_notifier: oneshot::Sender, + abort_notifier: Option>, } impl fmt::Debug for Pool { @@ -341,10 +341,7 @@ where /// Returns `None` if the pool has no connection with the given ID. pub fn get(&mut self, id: ConnectionId) -> Option> { if let hash_map::Entry::Occupied(entry) = self.pending.entry(id) { - Some(PoolConnection::Pending(PendingConnection { - entry, - counters: &mut self.counters, - })) + Some(PoolConnection::Pending(PendingConnection { entry })) } else { self.established .iter_mut() @@ -371,10 +368,7 @@ where /// Gets a pending outgoing connection by ID. pub fn get_outgoing(&mut self, id: ConnectionId) -> Option> { match self.pending.entry(id) { - hash_map::Entry::Occupied(entry) => Some(PendingConnection { - entry, - counters: &mut self.counters, - }), + hash_map::Entry::Occupied(entry) => Some(PendingConnection { entry }), hash_map::Entry::Vacant(_) => None, } } @@ -418,11 +412,7 @@ where .entry(pending_connection) .expect_occupied("Iterating pending connections"); - PendingConnection { - entry, - counters: &mut self.counters, - } - .abort(); + PendingConnection { entry }.abort(); } } @@ -548,13 +538,13 @@ where let connection_id = self.next_connection_id(); - let (drop_notifier, drop_receiver) = oneshot::channel(); + let (abort_notifier, abort_receiver) = oneshot::channel(); self.spawn( task::new_for_pending_outgoing_connection( connection_id, dial, - drop_receiver, + abort_receiver, self.pending_connection_events_tx.clone(), ) .boxed(), @@ -567,7 +557,7 @@ where peer_id: peer, handler, endpoint: PendingPoint::Dialer, - _drop_notifier: drop_notifier, + abort_notifier: Some(abort_notifier), }, ); Ok(connection_id) @@ -595,13 +585,13 @@ where let connection_id = self.next_connection_id(); - let (drop_notifier, drop_receiver) = oneshot::channel(); + let (abort_notifier, abort_receiver) = oneshot::channel(); self.spawn( task::new_for_pending_incoming_connection( connection_id, future, - drop_receiver, + abort_receiver, self.pending_connection_events_tx.clone(), ) .boxed(), @@ -614,7 +604,7 @@ where peer_id: None, handler, endpoint: endpoint.into(), - _drop_notifier: drop_notifier, + abort_notifier: Some(abort_notifier), }, ); Ok(connection_id) @@ -730,7 +720,7 @@ where peer_id: expected_peer_id, handler, endpoint, - _drop_notifier, + abort_notifier: _, } = self .pending .remove(&id) @@ -898,7 +888,7 @@ where peer_id, handler, endpoint, - _drop_notifier, + abort_notifier: _, }) = self.pending.remove(&id) { self.counters.dec_pending(&endpoint); @@ -955,7 +945,6 @@ pub enum PoolConnection<'a, THandler: IntoConnectionHandler> { /// A pending connection in a pool. pub struct PendingConnection<'a, THandler: IntoConnectionHandler> { entry: hash_map::OccupiedEntry<'a, ConnectionId, PendingConnectionInfo>, - counters: &'a mut ConnectionCounters, } impl PendingConnection<'_, THandler> { @@ -975,9 +964,10 @@ impl PendingConnection<'_, THandler> { } /// Aborts the connection attempt, closing the connection. - pub fn abort(self) { - self.counters.dec_pending(&self.entry.get().endpoint); - self.entry.remove(); + pub fn abort(mut self) { + if let Some(notifier) = self.entry.get_mut().abort_notifier.take() { + drop(notifier); + } } } diff --git a/core/src/connection/pool/task.rs b/core/src/connection/pool/task.rs index 9062583fd79..889847afdef 100644 --- a/core/src/connection/pool/task.rs +++ b/core/src/connection/pool/task.rs @@ -41,7 +41,7 @@ use futures::{ use std::pin::Pin; use void::Void; -/// Commands that can be sent to a task. +/// Commands that can be sent to a task driving an established connection. #[derive(Debug)] pub enum Command { /// Notify the connection handler of an event. @@ -103,12 +103,12 @@ pub enum EstablishedConnectionEvent { pub async fn new_for_pending_outgoing_connection( connection_id: ConnectionId, dial: ConcurrentDial, - drop_receiver: oneshot::Receiver, + abort_receiver: oneshot::Receiver, mut events: mpsc::Sender>, ) where TTrans: Transport, { - match futures::future::select(drop_receiver, Box::pin(dial)).await { + match futures::future::select(abort_receiver, Box::pin(dial)).await { Either::Left((Err(oneshot::Canceled), _)) => { let _ = events .send(PendingConnectionEvent::PendingFailed { @@ -141,13 +141,13 @@ pub async fn new_for_pending_outgoing_connection( pub async fn new_for_pending_incoming_connection( connection_id: ConnectionId, future: TFut, - drop_receiver: oneshot::Receiver, + abort_receiver: oneshot::Receiver, mut events: mpsc::Sender>, ) where TTrans: Transport, TFut: Future> + Send + 'static, { - match futures::future::select(drop_receiver, Box::pin(future)).await { + match futures::future::select(abort_receiver, Box::pin(future)).await { Either::Left((Err(oneshot::Canceled), _)) => { let _ = events .send(PendingConnectionEvent::PendingFailed { diff --git a/core/tests/aborted_connection.rs b/core/tests/aborted_connection.rs new file mode 100644 index 00000000000..5aaf91c547a --- /dev/null +++ b/core/tests/aborted_connection.rs @@ -0,0 +1,50 @@ +mod util; + +use std::task::Poll; + +use libp2p_core::{ + connection::{self, PendingOutboundConnectionError}, + network::{NetworkConfig, NetworkEvent}, + transport::dummy::DummyTransport, + Multiaddr, Network, PeerId, +}; + +use futures::{executor::block_on, future::poll_fn}; +use multihash::Multihash; + +#[test] +fn aborting_pending_connection_surfaces_error() { + let mut network = Network::new( + DummyTransport::default(), + PeerId::random(), + NetworkConfig::default(), + ); + + let target_peer = PeerId::random(); + let mut target_multiaddr = "/ip4/127.0.0.1/tcp/1234".parse::().unwrap(); + target_multiaddr.push(multiaddr::Protocol::P2p(target_peer.into())); + + let handler = util::TestHandler(); + network + .dial(&target_multiaddr, handler) + .expect("dial failed"); + + let dialing_peer = network + .peer(target_peer) + .into_dialing() + .expect("peer should be dialing"); + + dialing_peer.disconnect(); + block_on(poll_fn(|cx| match network.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(NetworkEvent::DialError { + error: PendingOutboundConnectionError::Aborted, + .. + }) => { + return Poll::Ready(()); + } + Poll::Ready(_) => { + panic!("We should see an aborted error, nothing else.") + } + })); +}