fix(swarm): prevent overflow in keep-alive computation #4644

Merged · 11 commits · Oct 18, 2023
2 changes: 2 additions & 0 deletions swarm/CHANGELOG.md
@@ -6,6 +6,8 @@
See [PR 4120].
- Make the `Debug` implementation of `StreamProtocol` more concise.
See [PR 4631](https://github.com/libp2p/rust-libp2p/pull/4631).
- Fix overflow in `KeepAlive` computation that could cause a panic in `Delay::new` if `SwarmBuilder::idle_connection_timeout` is configured too large.
See [PR 4644](https://github.com/libp2p/rust-libp2p/pull/4644).

[PR 4120]: https://github.com/libp2p/rust-libp2p/pull/4120

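For background on this entry: `Delay::new` (in `futures-timer`) computes a deadline of roughly `Instant::now() + duration`, and `Instant + Duration` panics when the sum is not representable on the platform. A minimal standalone sketch of the failure mode and the checked alternative the fix relies on (illustrative, not code from this PR):

use std::time::{Duration, Instant};

fn main() {
    let now = Instant::now();
    let huge = Duration::MAX; // stand-in for an absurdly large `idle_connection_timeout`

    // `Instant + Duration` panics when the result is not representable:
    // let _deadline = now + huge; // would panic on overflow

    // `Instant::checked_add` surfaces the overflow as `None` instead,
    // letting the caller shrink the duration to something representable.
    // On typical platforms, adding `Duration::MAX` overflows:
    assert!(now.checked_add(huge).is_none());
}
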
132 changes: 86 additions & 46 deletions swarm/src/connection.rs
@@ -57,12 +57,16 @@ use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Waker;
use std::time::Duration;
use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};

static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Counter of the number of active streams on a connection
type ActiveStreamCounter = Arc<()>;

/// Connection identifier.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);
@@ -157,6 +161,8 @@ where
local_supported_protocols: HashSet<StreamProtocol>,
remote_supported_protocols: HashSet<StreamProtocol>,
idle_timeout: Duration,
/// The counter of active streams
stream_counter: ActiveStreamCounter,
}

impl<THandler> fmt::Debug for Connection<THandler>
@@ -205,6 +211,7 @@ where
local_supported_protocols: initial_protocols,
remote_supported_protocols: Default::default(),
idle_timeout,
stream_counter: Arc::new(()),
}
}

@@ -237,6 +244,7 @@ where
local_supported_protocols: supported_protocols,
remote_supported_protocols,
idle_timeout,
stream_counter,
} = self.get_mut();

loop {
@@ -344,55 +352,17 @@ where
}
}

// Ask the handler whether it wants the connection (and the handler itself)
// to be kept alive, which determines the planned shutdown, if any.
let keep_alive = handler.connection_keep_alive();
match (&mut *shutdown, keep_alive) {
(Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => {
if *deadline != t {
*deadline = t;
if let Some(new_duration) = deadline.checked_duration_since(Instant::now())
{
let effective_keep_alive = max(new_duration, *idle_timeout);

timer.reset(effective_keep_alive)
}
}
}
(_, KeepAlive::Until(earliest_shutdown)) => {
let now = Instant::now();

if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
let effective_keep_alive = max(requested, *idle_timeout);

let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);

// Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch.
// This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See <https://github.com/libp2p/rust-libp2p/issues/3844>.
*shutdown = Shutdown::Later(Delay::new(safe_keep_alive), earliest_shutdown)
}
}
(_, KeepAlive::No) if idle_timeout == &Duration::ZERO => {
*shutdown = Shutdown::Asap;
}
(Shutdown::Later(_, _), KeepAlive::No) => {
// Do nothing, i.e. let the shutdown timer continue to tick.
}
(_, KeepAlive::No) => {
let now = Instant::now();
let safe_keep_alive = checked_add_fraction(now, *idle_timeout);

*shutdown = Shutdown::Later(Delay::new(safe_keep_alive), now + safe_keep_alive);
}
(_, KeepAlive::Yes) => *shutdown = Shutdown::None,
};

// Check if the connection (and handler) should be shut down.
// As long as we're still negotiating substreams, shutdown is always postponed.
if negotiating_in.is_empty()
&& negotiating_out.is_empty()
&& requested_substreams.is_empty()
&& Arc::strong_count(stream_counter) == 1
{
if let Some(new_timeout) = compute_new_shutdown(handler, shutdown, *idle_timeout) {
*shutdown = new_timeout;
}

match shutdown {
Shutdown::None => {}
Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
@@ -427,6 +397,7 @@ where
timeout,
upgrade,
*substream_upgrade_protocol_override,
stream_counter.clone(),
));

continue; // Go back to the top, handler can potentially make progress again.
@@ -440,7 +411,11 @@
Poll::Ready(substream) => {
let protocol = handler.listen_protocol();

negotiating_in.push(StreamUpgrade::new_inbound(substream, protocol));
negotiating_in.push(StreamUpgrade::new_inbound(
substream,
protocol,
stream_counter.clone(),
));

continue; // Go back to the top, handler can potentially make progress again.
}
@@ -481,6 +456,69 @@ fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<Strea
.collect()
}

fn compute_new_shutdown(
handler: &impl ConnectionHandler,
current_shutdown: &Shutdown,
idle_timeout: Duration,
) -> Option<Shutdown> {
// Ask the handler whether it wants the connection (and the handler itself)
// to be kept alive, which determines the planned shutdown, if any.
let keep_alive = handler.connection_keep_alive();
match (current_shutdown, keep_alive) {
(Shutdown::Later(_, deadline), KeepAlive::Until(t)) => {
let now = Instant::now();

if *deadline != t {
let new_deadline = t;
if let Some(new_duration) = new_deadline.checked_duration_since(now) {
let effective_keep_alive = max(new_duration, idle_timeout);

let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
return Some(Shutdown::Later(
Delay::new(safe_keep_alive),
new_deadline,
));
}
}

None
}
(_, KeepAlive::Until(earliest_shutdown)) => {
let now = Instant::now();

if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
let effective_keep_alive = max(requested, idle_timeout);

let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);

// Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch.
// This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See <https://github.com/libp2p/rust-libp2p/issues/3844>.
return Some(Shutdown::Later(
Delay::new(safe_keep_alive),
earliest_shutdown,
));
}

None
}
(_, KeepAlive::No) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
(Shutdown::Later(_, _), KeepAlive::No) => {
// Do nothing, i.e. let the shutdown timer continue to tick.
None
}
(_, KeepAlive::No) => {
let now = Instant::now();
let safe_keep_alive = checked_add_fraction(now, idle_timeout);

Some(Shutdown::Later(
Delay::new(safe_keep_alive),
now + safe_keep_alive,
))
}
(_, KeepAlive::Yes) => Some(Shutdown::None),
}
}

/// Repeatedly halves the [`Duration`] and adds it to the [`Instant`] until [`Instant::checked_add`] succeeds.
///
/// [`Instant`] depends on the underlying platform and has a limit of which points in time it can represent.
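
The function body is collapsed in this view; going by the doc comment, the halving strategy can be sketched as follows (a reconstruction, not necessarily the PR's exact code):

use std::time::{Duration, Instant};

fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
    while start.checked_add(duration).is_none() {
        // Not representable: retry with half the duration. This terminates
        // because `checked_add` always succeeds for `Duration::ZERO`.
        duration /= 2;
    }

    duration
}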
@@ -527,6 +565,7 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
timeout: Delay,
upgrade: Upgrade,
version_override: Option<upgrade::Version>,
counter: Arc<()>,
) -> Self
where
Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
@@ -558,7 +597,7 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
.map_err(to_stream_upgrade_error)?;

let output = upgrade
.upgrade_outbound(Stream::new(stream), info)
.upgrade_outbound(Stream::new(stream, counter), info)
.await
.map_err(StreamUpgradeError::Apply)?;

@@ -572,6 +611,7 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
fn new_inbound<Upgrade>(
substream: SubstreamBox,
protocol: SubstreamProtocol<Upgrade, UserData>,
counter: Arc<()>,
) -> Self
where
Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
@@ -590,7 +630,7 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
.map_err(to_stream_upgrade_error)?;

let output = upgrade
.upgrade_inbound(Stream::new(stream), info)
.upgrade_inbound(Stream::new(stream, counter), info)
.await
.map_err(StreamUpgradeError::Apply)?;

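The idleness check in `poll` above, `Arc::strong_count(stream_counter) == 1`, works because every live `Stream` holds a clone of the connection's `Arc<()>`. A standalone sketch of the counting pattern:

use std::sync::Arc;

fn main() {
    let counter: Arc<()> = Arc::new(()); // owned by the connection
    assert_eq!(Arc::strong_count(&counter), 1); // no active streams

    let held_by_stream = counter.clone(); // handed to a newly negotiated stream
    assert_eq!(Arc::strong_count(&counter), 2); // connection counts as busy

    drop(held_by_stream); // the stream is dropped
    assert_eq!(Arc::strong_count(&counter), 1); // connection is idle again
}
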
26 changes: 18 additions & 8 deletions swarm/src/handler.rs
@@ -125,10 +125,11 @@ pub trait ConnectionHandler: Send + 'static {

/// Returns until when the connection should be kept alive.
///
/// This method is called by the `Swarm` after each invocation of
/// [`ConnectionHandler::poll`] to determine if the connection and the associated
/// [`ConnectionHandler`]s should be kept alive as far as this handler is concerned
/// and if so, for how long.
/// After each invocation of [`ConnectionHandler::poll`], the `Swarm` checks
/// whether there are still active streams on this connection. If not, it
/// calls this method to determine whether the connection and the associated
/// [`ConnectionHandler`]s should be kept alive as far as this handler is
/// concerned and, if so, for how long.
///
/// Returning [`KeepAlive::No`] indicates that the connection should be
/// closed and this handler destroyed immediately.
@@ -139,10 +140,19 @@
/// Returning [`KeepAlive::Yes`] indicates that the connection should
/// be kept alive until the next call to this method.
///
/// > **Note**: The connection is always closed and the handler destroyed
/// > when [`ConnectionHandler::poll`] returns an error. Furthermore, the
/// > connection may be closed for reasons outside of the control
/// > of the handler.
/// By default, connections are considered active and thus kept alive whilst:
///
/// - There are streams currently being upgraded, see [`InboundUpgrade`](libp2p_core::upgrade::InboundUpgrade) and [`OutboundUpgrade`](libp2p_core::upgrade::OutboundUpgrade).
/// - There are still active streams, i.e. instances of [`Stream`](crate::stream::Stream) where the user did not call [`Stream::no_keep_alive`](crate::stream::Stream::no_keep_alive).
/// - The `ConnectionHandler` returns `Poll::Ready`.
///
/// Only once none of these conditions hold do we invoke this method to determine
/// whether the connection should be kept alive any further. Note that for most
/// protocols this is not necessary, as it represents a completely idle connection
/// with no active and no pending streams.
///
/// If you'd like to delay the shutdown of idle connections, consider configuring
/// [`SwarmBuilder::idle_connection_timeout`](crate::SwarmBuilder) in your application.
fn connection_keep_alive(&self) -> KeepAlive;

/// Should behave like `Stream::poll()`.
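To illustrate the contract described above, a handler might keep the connection alive while it has queued work and otherwise grant a short grace period. A hypothetical policy function (`keep_alive_policy`, `pending_work`, and the ten-second grace period are illustrative assumptions, not part of this PR):

use std::time::{Duration, Instant};

use libp2p_swarm::KeepAlive;

// Hypothetical keep-alive policy for a handler.
fn keep_alive_policy(pending_work: bool) -> KeepAlive {
    if pending_work {
        // Work is queued: keep the connection (and handler) alive.
        KeepAlive::Yes
    } else {
        // Idle: allow shutdown after a grace period. A small duration like
        // this cannot overflow `Instant` arithmetic, unlike the huge
        // idle timeouts this PR guards against.
        KeepAlive::Until(Instant::now() + Duration::from_secs(10))
    }
}
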
51 changes: 39 additions & 12 deletions swarm/src/stream.rs
@@ -1,16 +1,43 @@
use futures::{AsyncRead, AsyncWrite};
use libp2p_core::muxing::SubstreamBox;
use libp2p_core::Negotiated;
use std::io::{IoSlice, IoSliceMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
io::{IoSlice, IoSliceMut},
pin::Pin,
sync::{Arc, Weak},
task::{Context, Poll},
};

#[derive(Debug)]
pub struct Stream(Negotiated<SubstreamBox>);
pub struct Stream {
stream: Negotiated<SubstreamBox>,
counter: StreamCounter,
}

#[derive(Debug)]
enum StreamCounter {
Arc(Arc<()>),
Weak(Weak<()>),
}

impl Stream {
pub(crate) fn new(stream: Negotiated<SubstreamBox>) -> Self {
Self(stream)
pub(crate) fn new(stream: Negotiated<SubstreamBox>, counter: Arc<()>) -> Self {
let counter = StreamCounter::Arc(counter);
Self { stream, counter }
}

/// Opt this stream out of the [Swarm](crate::Swarm)'s connection keep-alive algorithm.
///
/// By default, any active stream keeps a connection alive. For most protocols,
/// this is a good default as it ensures that the protocol is completed before
/// a connection is shut down.
/// Some protocols, like libp2p's [ping](https://github.com/libp2p/specs/blob/master/ping/ping.md),
/// never complete and are of an auxiliary nature.
/// These protocols should opt out of the keep-alive algorithm using this method.
pub fn no_keep_alive(&mut self) {
if let StreamCounter::Arc(arc_counter) = &self.counter {
self.counter = StreamCounter::Weak(Arc::downgrade(arc_counter));
}
}
}

@@ -20,15 +47,15 @@ impl AsyncRead for Stream {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
}

fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.get_mut().0).poll_read_vectored(cx, bufs)
Pin::new(&mut self.get_mut().stream).poll_read_vectored(cx, bufs)
}
}

@@ -38,22 +65,22 @@ impl AsyncWrite for Stream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.get_mut().0).poll_write_vectored(cx, bufs)
Pin::new(&mut self.get_mut().stream).poll_write_vectored(cx, bufs)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().0).poll_flush(cx)
Pin::new(&mut self.get_mut().stream).poll_flush(cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().0).poll_close(cx)
Pin::new(&mut self.get_mut().stream).poll_close(cx)
}
}
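
A hypothetical usage of `no_keep_alive` in a protocol task, assuming the crate re-exports this type as `libp2p_swarm::Stream` (`drive_ping` is illustrative, not part of this PR):

use libp2p_swarm::Stream;

// An auxiliary protocol such as ping never completes, so its stream
// opts out of keep-alive; the connection may then idle-close even
// while this stream is still open.
async fn drive_ping(mut stream: Stream) {
    stream.no_keep_alive();

    // ... exchange ping payloads on `stream` ...
}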