Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report aborted connection from Pool::poll #2369

Closed
wants to merge 8 commits into from
42 changes: 16 additions & 26 deletions core/src/connection/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ struct PendingConnectionInfo<THandler> {
handler: THandler,
endpoint: PendingPoint,
/// When dropped, notifies the task which then knows to terminate.
_drop_notifier: oneshot::Sender<Void>,
abort_notifier: Option<oneshot::Sender<Void>>,
}

impl<THandler: IntoConnectionHandler, TTrans: Transport> fmt::Debug for Pool<THandler, TTrans> {
Expand Down Expand Up @@ -341,10 +341,7 @@ where
/// Returns `None` if the pool has no connection with the given ID.
pub fn get(&mut self, id: ConnectionId) -> Option<PoolConnection<'_, THandler>> {
if let hash_map::Entry::Occupied(entry) = self.pending.entry(id) {
Some(PoolConnection::Pending(PendingConnection {
entry,
counters: &mut self.counters,
}))
Some(PoolConnection::Pending(PendingConnection { entry }))
} else {
self.established
.iter_mut()
Expand All @@ -371,10 +368,7 @@ where
/// Gets a pending outgoing connection by ID.
pub fn get_outgoing(&mut self, id: ConnectionId) -> Option<PendingConnection<'_, THandler>> {
match self.pending.entry(id) {
hash_map::Entry::Occupied(entry) => Some(PendingConnection {
entry,
counters: &mut self.counters,
}),
hash_map::Entry::Occupied(entry) => Some(PendingConnection { entry }),
hash_map::Entry::Vacant(_) => None,
}
}
Expand Down Expand Up @@ -418,11 +412,7 @@ where
.entry(pending_connection)
.expect_occupied("Iterating pending connections");

PendingConnection {
entry,
counters: &mut self.counters,
}
.abort();
PendingConnection { entry }.abort();
}
}

Expand Down Expand Up @@ -548,13 +538,13 @@ where

let connection_id = self.next_connection_id();

let (drop_notifier, drop_receiver) = oneshot::channel();
let (abort_notifier, abort_receiver) = oneshot::channel();

self.spawn(
task::new_for_pending_outgoing_connection(
connection_id,
dial,
drop_receiver,
abort_receiver,
self.pending_connection_events_tx.clone(),
)
.boxed(),
Expand All @@ -567,7 +557,7 @@ where
peer_id: peer,
handler,
endpoint: PendingPoint::Dialer,
_drop_notifier: drop_notifier,
abort_notifier: Some(abort_notifier),
},
);
Ok(connection_id)
Expand Down Expand Up @@ -595,13 +585,13 @@ where

let connection_id = self.next_connection_id();

let (drop_notifier, drop_receiver) = oneshot::channel();
let (abort_notifier, abort_receiver) = oneshot::channel();

self.spawn(
task::new_for_pending_incoming_connection(
connection_id,
future,
drop_receiver,
abort_receiver,
self.pending_connection_events_tx.clone(),
)
.boxed(),
Expand All @@ -614,7 +604,7 @@ where
peer_id: None,
handler,
endpoint: endpoint.into(),
_drop_notifier: drop_notifier,
abort_notifier: Some(abort_notifier),
},
);
Ok(connection_id)
Expand Down Expand Up @@ -730,7 +720,7 @@ where
peer_id: expected_peer_id,
handler,
endpoint,
_drop_notifier,
abort_notifier: _,
} = self
.pending
.remove(&id)
Expand Down Expand Up @@ -898,7 +888,7 @@ where
peer_id,
handler,
endpoint,
_drop_notifier,
abort_notifier: _,
}) = self.pending.remove(&id)
{
self.counters.dec_pending(&endpoint);
Expand Down Expand Up @@ -955,7 +945,6 @@ pub enum PoolConnection<'a, THandler: IntoConnectionHandler> {
/// A pending connection in a pool.
pub struct PendingConnection<'a, THandler: IntoConnectionHandler> {
entry: hash_map::OccupiedEntry<'a, ConnectionId, PendingConnectionInfo<THandler>>,
counters: &'a mut ConnectionCounters,
}

impl<THandler: IntoConnectionHandler> PendingConnection<'_, THandler> {
Expand All @@ -975,9 +964,10 @@ impl<THandler: IntoConnectionHandler> PendingConnection<'_, THandler> {
}

/// Aborts the connection attempt, closing the connection.
pub fn abort(self) {
self.counters.dec_pending(&self.entry.get().endpoint);
self.entry.remove();
pub fn abort(mut self) {
if let Some(notifier) = self.entry.get_mut().abort_notifier.take() {
drop(notifier);
}
}
}

Expand Down
10 changes: 5 additions & 5 deletions core/src/connection/pool/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use futures::{
use std::pin::Pin;
use void::Void;

/// Commands that can be sent to a task.
/// Commands that can be sent to a task driving an established connection.
#[derive(Debug)]
pub enum Command<T> {
/// Notify the connection handler of an event.
Expand Down Expand Up @@ -103,12 +103,12 @@ pub enum EstablishedConnectionEvent<THandler: IntoConnectionHandler> {
pub async fn new_for_pending_outgoing_connection<TTrans>(
connection_id: ConnectionId,
dial: ConcurrentDial<TTrans>,
drop_receiver: oneshot::Receiver<Void>,
abort_receiver: oneshot::Receiver<Void>,
mut events: mpsc::Sender<PendingConnectionEvent<TTrans>>,
) where
TTrans: Transport,
{
match futures::future::select(drop_receiver, Box::pin(dial)).await {
match futures::future::select(abort_receiver, Box::pin(dial)).await {
Either::Left((Err(oneshot::Canceled), _)) => {
let _ = events
.send(PendingConnectionEvent::PendingFailed {
Expand Down Expand Up @@ -141,13 +141,13 @@ pub async fn new_for_pending_outgoing_connection<TTrans>(
pub async fn new_for_pending_incoming_connection<TFut, TTrans>(
connection_id: ConnectionId,
future: TFut,
drop_receiver: oneshot::Receiver<Void>,
abort_receiver: oneshot::Receiver<Void>,
mut events: mpsc::Sender<PendingConnectionEvent<TTrans>>,
) where
TTrans: Transport,
TFut: Future<Output = Result<TTrans::Output, TTrans::Error>> + Send + 'static,
{
match futures::future::select(drop_receiver, Box::pin(future)).await {
match futures::future::select(abort_receiver, Box::pin(future)).await {
Either::Left((Err(oneshot::Canceled), _)) => {
let _ = events
.send(PendingConnectionEvent::PendingFailed {
Expand Down
50 changes: 50 additions & 0 deletions core/tests/aborted_connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
mod util;

use std::task::Poll;

use libp2p_core::{
connection::{self, PendingOutboundConnectionError},
network::{NetworkConfig, NetworkEvent},
transport::dummy::DummyTransport,
Multiaddr, Network, PeerId,
};

use futures::{executor::block_on, future::poll_fn};
use multihash::Multihash;

#[test]
fn aborting_pending_connection_surfaces_error() {
let mut network = Network::new(
DummyTransport::default(),
PeerId::random(),
NetworkConfig::default(),
);

let target_peer = PeerId::random();
let mut target_multiaddr = "/ip4/127.0.0.1/tcp/1234".parse::<Multiaddr>().unwrap();
target_multiaddr.push(multiaddr::Protocol::P2p(target_peer.into()));

let handler = util::TestHandler();
network
.dial(&target_multiaddr, handler)
.expect("dial failed");

let dialing_peer = network
.peer(target_peer)
.into_dialing()
.expect("peer should be dialing");

dialing_peer.disconnect();
block_on(poll_fn(|cx| match network.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(NetworkEvent::DialError {
error: PendingOutboundConnectionError::Aborted,
..
}) => {
return Poll::Ready(());
}
Poll::Ready(_) => {
panic!("We should see an aborted error, nothing else.")
}
}));
}