diff --git a/Cargo.lock b/Cargo.lock index b2fb6e87074..de08f42e91b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3040,7 +3040,7 @@ dependencies = [ [[package]] name = "libp2p-quic" -version = "0.8.0-alpha" +version = "0.9.0-alpha" dependencies = [ "async-std", "bytes", @@ -3058,7 +3058,7 @@ dependencies = [ "log", "parking_lot", "quickcheck", - "quinn-proto", + "quinn", "rand 0.8.5", "rustls 0.21.2", "thiserror", @@ -4313,6 +4313,26 @@ dependencies = [ "pin-project-lite 0.1.12", ] +[[package]] +name = "quinn" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21252f1c0fc131f1b69182db8f34837e8a69737b8251dff75636a9be0518c324" +dependencies = [ + "async-io", + "async-std", + "bytes", + "futures-io", + "pin-project-lite 0.2.9", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.21.2", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "quinn-proto" version = "0.10.1" @@ -4330,6 +4350,19 @@ dependencies = [ "tracing", ] +[[package]] +name = "quinn-udp" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6df19e284d93757a9fb91d63672f7741b129246a669db09d1c0063071debc0c0" +dependencies = [ + "bytes", + "libc", + "socket2 0.5.3", + "tracing", + "windows-sys 0.48.0", +] + [[package]] name = "quote" version = "1.0.32" diff --git a/Cargo.toml b/Cargo.toml index 66b2cfc58ee..fa8c70eb4ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,7 +83,7 @@ libp2p-perf = { version = "0.2.0", path = "protocols/perf" } libp2p-ping = { version = "0.43.0", path = "protocols/ping" } libp2p-plaintext = { version = "0.40.0", path = "transports/plaintext" } libp2p-pnet = { version = "0.23.0", path = "transports/pnet" } -libp2p-quic = { version = "0.8.0-alpha", path = "transports/quic" } +libp2p-quic = { version = "0.9.0-alpha", path = "transports/quic" } libp2p-relay = { version = "0.16.1", path = "protocols/relay" } libp2p-rendezvous = { version = "0.13.0", path = "protocols/rendezvous" } libp2p-request-response = { version = "0.25.1", path = "protocols/request-response" } diff --git a/transports/quic/CHANGELOG.md b/transports/quic/CHANGELOG.md index 6e3ce801a2c..2012a3caf94 100644 --- a/transports/quic/CHANGELOG.md +++ b/transports/quic/CHANGELOG.md @@ -1,3 +1,10 @@ +## 0.9.0-alpha - unreleased + +- Use `quinn` instead of `quinn-proto`. + See [PR 3454]. + +[PR 3454]: https://github.com/libp2p/rust-libp2p/pull/3454 + ## 0.8.0-alpha - Raise MSRV to 1.65. diff --git a/transports/quic/Cargo.toml b/transports/quic/Cargo.toml index a04cc8d48b5..5d810860912 100644 --- a/transports/quic/Cargo.toml +++ b/transports/quic/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libp2p-quic" -version = "0.8.0-alpha" +version = "0.9.0-alpha" authors = ["Parity Technologies "] edition = "2021" rust-version = { workspace = true } @@ -19,15 +19,15 @@ libp2p-tls = { workspace = true } libp2p-identity = { workspace = true } log = "0.4" parking_lot = "0.12.0" -quinn-proto = { version = "0.10.1", default-features = false, features = ["tls-rustls"] } +quinn = { version = "0.10.1", default-features = false, features = ["tls-rustls", "futures-io"] } rand = "0.8.5" rustls = { version = "0.21.2", default-features = false } thiserror = "1.0.44" tokio = { version = "1.29.1", default-features = false, features = ["net", "rt", "time"], optional = true } [features] -tokio = ["dep:tokio", "if-watch/tokio"] -async-std = ["dep:async-std", "if-watch/smol"] +tokio = ["dep:tokio", "if-watch/tokio", "quinn/runtime-tokio"] +async-std = ["dep:async-std", "if-watch/smol", "quinn/runtime-async-std"] # Passing arguments to the docsrs builder in order to properly document cfg's. # More information: https://docs.rs/about/builds#cross-compiling diff --git a/transports/quic/src/config.rs b/transports/quic/src/config.rs new file mode 100644 index 00000000000..201594e247c --- /dev/null +++ b/transports/quic/src/config.rs @@ -0,0 +1,142 @@ +// Copyright 2017-2020 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use quinn::VarInt; +use std::{sync::Arc, time::Duration}; + +/// Config for the transport. +#[derive(Clone)] +pub struct Config { + /// Timeout for the initial handshake when establishing a connection. + /// The actual timeout is the minimum of this and the [`Config::max_idle_timeout`]. + pub handshake_timeout: Duration, + /// Maximum duration of inactivity in ms to accept before timing out the connection. + pub max_idle_timeout: u32, + /// Period of inactivity before sending a keep-alive packet. + /// Must be set lower than the idle_timeout of both + /// peers to be effective. + /// + /// See [`quinn::TransportConfig::keep_alive_interval`] for more + /// info. + pub keep_alive_interval: Duration, + /// Maximum number of incoming bidirectional streams that may be open + /// concurrently by the remote peer. + pub max_concurrent_stream_limit: u32, + + /// Max unacknowledged data in bytes that may be send on a single stream. + pub max_stream_data: u32, + + /// Max unacknowledged data in bytes that may be send in total on all streams + /// of a connection. + pub max_connection_data: u32, + + /// Support QUIC version draft-29 for dialing and listening. + /// + /// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`] + /// is supported. + /// + /// If support for draft-29 is enabled servers support draft-29 and version 1 on all + /// QUIC listening addresses. + /// As client the version is chosen based on the remote's address. + pub support_draft_29: bool, + + /// TLS client config for the inner [`quinn::ClientConfig`]. + client_tls_config: Arc, + /// TLS server config for the inner [`quinn::ServerConfig`]. + server_tls_config: Arc, +} + +impl Config { + /// Creates a new configuration object with default values. + pub fn new(keypair: &libp2p_identity::Keypair) -> Self { + let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap()); + let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap()); + Self { + client_tls_config, + server_tls_config, + support_draft_29: false, + handshake_timeout: Duration::from_secs(5), + max_idle_timeout: 30 * 1000, + max_concurrent_stream_limit: 256, + keep_alive_interval: Duration::from_secs(15), + max_connection_data: 15_000_000, + + // Ensure that one stream is not consuming the whole connection. + max_stream_data: 10_000_000, + } + } +} + +/// Represents the inner configuration for [`quinn`]. +#[derive(Debug, Clone)] +pub(crate) struct QuinnConfig { + pub(crate) client_config: quinn::ClientConfig, + pub(crate) server_config: quinn::ServerConfig, + pub(crate) endpoint_config: quinn::EndpointConfig, +} + +impl From for QuinnConfig { + fn from(config: Config) -> QuinnConfig { + let Config { + client_tls_config, + server_tls_config, + max_idle_timeout, + max_concurrent_stream_limit, + keep_alive_interval, + max_connection_data, + max_stream_data, + support_draft_29, + handshake_timeout: _, + } = config; + let mut transport = quinn::TransportConfig::default(); + // Disable uni-directional streams. + transport.max_concurrent_uni_streams(0u32.into()); + transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into()); + // Disable datagrams. + transport.datagram_receive_buffer_size(None); + transport.keep_alive_interval(Some(keep_alive_interval)); + transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into())); + transport.allow_spin(false); + transport.stream_receive_window(max_stream_data.into()); + transport.receive_window(max_connection_data.into()); + let transport = Arc::new(transport); + + let mut server_config = quinn::ServerConfig::with_crypto(server_tls_config); + server_config.transport = Arc::clone(&transport); + // Disables connection migration. + // Long-term this should be enabled, however we then need to handle address change + // on connections in the `Connection`. + server_config.migration(false); + + let mut client_config = quinn::ClientConfig::new(client_tls_config); + client_config.transport_config(transport); + + let mut endpoint_config = quinn::EndpointConfig::default(); + if !support_draft_29 { + endpoint_config.supported_versions(vec![1]); + } + + QuinnConfig { + client_config, + server_config, + endpoint_config, + } + } +} diff --git a/transports/quic/src/connection.rs b/transports/quic/src/connection.rs index 0e5727dcf21..783258a0130 100644 --- a/transports/quic/src/connection.rs +++ b/transports/quic/src/connection.rs @@ -19,409 +19,113 @@ // DEALINGS IN THE SOFTWARE. mod connecting; -mod substream; +mod stream; -use crate::{ - endpoint::{self, ToEndpoint}, - Error, -}; pub use connecting::Connecting; -pub use substream::Substream; -use substream::{SubstreamState, WriteState}; +pub use stream::Stream; + +use crate::{ConnectionError, Error}; -use futures::{channel::mpsc, ready, FutureExt, StreamExt}; -use futures_timer::Delay; +use futures::{future::BoxFuture, FutureExt}; use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; -use parking_lot::Mutex; use std::{ - any::Any, - collections::HashMap, - net::SocketAddr, pin::Pin, - sync::Arc, - task::{Context, Poll, Waker}, - time::Instant, + task::{Context, Poll}, }; /// State for a single opened QUIC connection. -#[derive(Debug)] pub struct Connection { - /// State shared with the substreams. - state: Arc>, - /// Channel to the [`endpoint::Driver`] that drives the [`quinn_proto::Endpoint`] that - /// this connection belongs to. - endpoint_channel: endpoint::Channel, - /// Pending message to be sent to the [`quinn_proto::Endpoint`] in the [`endpoint::Driver`]. - pending_to_endpoint: Option, - /// Events that the [`quinn_proto::Endpoint`] will send in destination to our local - /// [`quinn_proto::Connection`]. - from_endpoint: mpsc::Receiver, - /// Identifier for this connection according to the [`quinn_proto::Endpoint`]. - /// Used when sending messages to the endpoint. - connection_id: quinn_proto::ConnectionHandle, - /// `Future` that triggers at the [`Instant`] that [`quinn_proto::Connection::poll_timeout`] - /// indicates. - next_timeout: Option<(Delay, Instant)>, + /// Underlying connection. + connection: quinn::Connection, + /// Future for accepting a new incoming bidirectional stream. + incoming: Option< + BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>, + >, + /// Future for opening a new outgoing bidirectional stream. + outgoing: Option< + BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>, + >, + /// Future to wait for the connection to be closed. + closing: Option>, } impl Connection { /// Build a [`Connection`] from raw components. /// - /// This function assumes that there exists a [`Driver`](super::endpoint::Driver) - /// that will process the messages sent to `EndpointChannel::to_endpoint` and send us messages - /// on `from_endpoint`. - /// - /// `connection_id` is used to identify the local connection in the messages sent to - /// `to_endpoint`. - /// - /// This function assumes that the [`quinn_proto::Connection`] is completely fresh and none of + /// This function assumes that the [`quinn::Connection`] is completely fresh and none of /// its methods has ever been called. Failure to comply might lead to logic errors and panics. - pub(crate) fn from_quinn_connection( - endpoint_channel: endpoint::Channel, - connection: quinn_proto::Connection, - connection_id: quinn_proto::ConnectionHandle, - from_endpoint: mpsc::Receiver, - ) -> Self { - let state = State { - connection, - substreams: HashMap::new(), - poll_connection_waker: None, - poll_inbound_waker: None, - poll_outbound_waker: None, - }; + fn new(connection: quinn::Connection) -> Self { Self { - endpoint_channel, - pending_to_endpoint: None, - next_timeout: None, - from_endpoint, - connection_id, - state: Arc::new(Mutex::new(state)), - } - } - - /// The address that the local socket is bound to. - pub(crate) fn local_addr(&self) -> &SocketAddr { - self.endpoint_channel.socket_addr() - } - - /// Returns the address of the node we're connected to. - pub(crate) fn remote_addr(&self) -> SocketAddr { - self.state.lock().connection.remote_address() - } - - /// Identity of the remote peer inferred from the handshake. - /// - /// `None` if the handshake is not complete yet, i.e. [`Self::poll_event`] - /// has not yet reported a [`quinn_proto::Event::Connected`] - fn peer_identity(&self) -> Option> { - self.state - .lock() - .connection - .crypto_session() - .peer_identity() - } - - /// Polls the connection for an event that happened on it. - /// - /// `quinn::proto::Connection` is polled in the order instructed in their docs: - /// 1. [`quinn_proto::Connection::poll_transmit`] - /// 2. [`quinn_proto::Connection::poll_timeout`] - /// 3. [`quinn_proto::Connection::poll_endpoint_events`] - /// 4. [`quinn_proto::Connection::poll`] - fn poll_event(&mut self, cx: &mut Context<'_>) -> Poll> { - let mut inner = self.state.lock(); - loop { - // Sending the pending event to the endpoint. If the endpoint is too busy, we just - // stop the processing here. - // We don't deliver substream-related events to the user as long as - // `to_endpoint` is full. This should propagate the back-pressure of `to_endpoint` - // being full to the user. - if let Some(to_endpoint) = self.pending_to_endpoint.take() { - match self.endpoint_channel.try_send(to_endpoint, cx) { - Ok(Ok(())) => {} - Ok(Err(to_endpoint)) => { - self.pending_to_endpoint = Some(to_endpoint); - return Poll::Pending; - } - Err(endpoint::Disconnected {}) => { - return Poll::Ready(None); - } - } - } - - match self.from_endpoint.poll_next_unpin(cx) { - Poll::Ready(Some(event)) => { - inner.connection.handle_event(event); - continue; - } - Poll::Ready(None) => { - return Poll::Ready(None); - } - Poll::Pending => {} - } - - // The maximum amount of segments which can be transmitted in a single Transmit - // if a platform supports Generic Send Offload (GSO). - // Set to 1 for now since not all platforms support GSO. - // TODO: Fix for platforms that support GSO. - let max_datagrams = 1; - // Poll the connection for packets to send on the UDP socket and try to send them on - // `to_endpoint`. - if let Some(transmit) = inner - .connection - .poll_transmit(Instant::now(), max_datagrams) - { - // TODO: ECN bits not handled - self.pending_to_endpoint = Some(ToEndpoint::SendUdpPacket(transmit)); - continue; - } - - match inner.connection.poll_timeout() { - Some(timeout) => match self.next_timeout { - Some((_, when)) if when == timeout => {} - _ => { - let now = Instant::now(); - // 0ns if now > when - let duration = timeout.duration_since(now); - let next_timeout = Delay::new(duration); - self.next_timeout = Some((next_timeout, timeout)) - } - }, - None => self.next_timeout = None, - } - - if let Some((timeout, when)) = self.next_timeout.as_mut() { - if timeout.poll_unpin(cx).is_ready() { - inner.connection.handle_timeout(*when); - continue; - } - } - - // The connection also needs to be able to send control messages to the endpoint. This is - // handled here, and we try to send them on `to_endpoint` as well. - if let Some(event) = inner.connection.poll_endpoint_events() { - let connection_id = self.connection_id; - self.pending_to_endpoint = Some(ToEndpoint::ProcessConnectionEvent { - connection_id, - event, - }); - continue; - } - - // The final step consists in returning the events related to the various substreams. - if let Some(ev) = inner.connection.poll() { - return Poll::Ready(Some(ev)); - } - - return Poll::Pending; + connection, + incoming: None, + outgoing: None, + closing: None, } } } impl StreamMuxer for Connection { - type Substream = Substream; + type Substream = Stream; type Error = Error; - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - while let Poll::Ready(event) = self.poll_event(cx) { - let mut inner = self.state.lock(); - let event = match event { - Some(event) => event, - None => return Poll::Ready(Err(Error::EndpointDriverCrashed)), - }; - match event { - quinn_proto::Event::Connected | quinn_proto::Event::HandshakeDataReady => { - debug_assert!( - false, - "Unexpected event {event:?} on established QUIC connection" - ); - } - quinn_proto::Event::ConnectionLost { reason } => { - inner - .connection - .close(Instant::now(), From::from(0u32), Default::default()); - inner.substreams.values_mut().for_each(|s| s.wake_all()); - return Poll::Ready(Err(Error::Connection(reason.into()))); - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened { - dir: quinn_proto::Dir::Bi, - }) => { - if let Some(waker) = inner.poll_outbound_waker.take() { - waker.wake(); - } - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available { - dir: quinn_proto::Dir::Bi, - }) => { - if let Some(waker) = inner.poll_inbound_waker.take() { - waker.wake(); - } - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Readable { id }) => { - if let Some(substream) = inner.substreams.get_mut(&id) { - if let Some(waker) = substream.read_waker.take() { - waker.wake(); - } - } - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Writable { id }) => { - if let Some(substream) = inner.substreams.get_mut(&id) { - if let Some(waker) = substream.write_waker.take() { - waker.wake(); - } - } - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Finished { id }) => { - if let Some(substream) = inner.substreams.get_mut(&id) { - if matches!( - substream.write_state, - WriteState::Open | WriteState::Closing - ) { - substream.write_state = WriteState::Closed; - } - if let Some(waker) = substream.write_waker.take() { - waker.wake(); - } - if let Some(waker) = substream.close_waker.take() { - waker.wake(); - } - } - } - quinn_proto::Event::Stream(quinn_proto::StreamEvent::Stopped { - id, - error_code: _, - }) => { - if let Some(substream) = inner.substreams.get_mut(&id) { - substream.write_state = WriteState::Stopped; - if let Some(waker) = substream.write_waker.take() { - waker.wake(); - } - if let Some(waker) = substream.close_waker.take() { - waker.wake(); - } - } - } - quinn_proto::Event::DatagramReceived - | quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available { - dir: quinn_proto::Dir::Uni, - }) - | quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened { - dir: quinn_proto::Dir::Uni, - }) => { - unreachable!("We don't use datagrams or unidirectional streams.") - } - } - } - // TODO: If connection migration is enabled (currently disabled) address - // change on the connection needs to be handled. - - self.state.lock().poll_connection_waker = Some(cx.waker().clone()); - Poll::Pending - } - fn poll_inbound( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut inner = self.state.lock(); + let this = self.get_mut(); - let substream_id = match inner.connection.streams().accept(quinn_proto::Dir::Bi) { - Some(id) => { - inner.poll_inbound_waker = None; - id - } - None => { - inner.poll_inbound_waker = Some(cx.waker().clone()); - return Poll::Pending; - } - }; - inner.substreams.insert(substream_id, Default::default()); - let substream = Substream::new(substream_id, self.state.clone()); + let incoming = this.incoming.get_or_insert_with(|| { + let connection = this.connection.clone(); + async move { connection.accept_bi().await }.boxed() + }); - Poll::Ready(Ok(substream)) + let (send, recv) = futures::ready!(incoming.poll_unpin(cx)).map_err(ConnectionError)?; + this.incoming.take(); + let stream = Stream::new(send, recv); + Poll::Ready(Ok(stream)) } fn poll_outbound( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut inner = self.state.lock(); - let substream_id = match inner.connection.streams().open(quinn_proto::Dir::Bi) { - Some(id) => { - inner.poll_outbound_waker = None; - id - } - None => { - inner.poll_outbound_waker = Some(cx.waker().clone()); - return Poll::Pending; - } - }; - inner.substreams.insert(substream_id, Default::default()); - let substream = Substream::new(substream_id, self.state.clone()); - Poll::Ready(Ok(substream)) - } + let this = self.get_mut(); - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = self.state.lock(); - if inner.connection.is_drained() { - return Poll::Ready(Ok(())); - } + let outgoing = this.outgoing.get_or_insert_with(|| { + let connection = this.connection.clone(); + async move { connection.open_bi().await }.boxed() + }); - for substream in inner.substreams.keys().cloned().collect::>() { - let _ = inner.connection.send_stream(substream).finish(); - } - - if inner.connection.streams().send_streams() == 0 && !inner.connection.is_closed() { - inner - .connection - .close(Instant::now(), From::from(0u32), Default::default()) - } - drop(inner); - - loop { - match ready!(self.poll_event(cx)) { - Some(quinn_proto::Event::ConnectionLost { .. }) => return Poll::Ready(Ok(())), - None => return Poll::Ready(Err(Error::EndpointDriverCrashed)), - _ => {} - } - } + let (send, recv) = futures::ready!(outgoing.poll_unpin(cx)).map_err(ConnectionError)?; + this.outgoing.take(); + let stream = Stream::new(send, recv); + Poll::Ready(Ok(stream)) } -} -impl Drop for Connection { - fn drop(&mut self) { - let to_endpoint = ToEndpoint::ProcessConnectionEvent { - connection_id: self.connection_id, - event: quinn_proto::EndpointEvent::drained(), - }; - self.endpoint_channel.send_on_drop(to_endpoint); + fn poll( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + // TODO: If connection migration is enabled (currently disabled) address + // change on the connection needs to be handled. + Poll::Pending } -} -/// Mutex-protected state of [`Connection`]. -#[derive(Debug)] -pub struct State { - /// The QUIC inner state machine for this specific connection. - connection: quinn_proto::Connection, + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); - /// State of all the substreams that the muxer reports as open. - pub substreams: HashMap, + let closing = this.closing.get_or_insert_with(|| { + this.connection.close(From::from(0u32), &[]); + let connection = this.connection.clone(); + async move { connection.closed().await }.boxed() + }); - /// Waker to wake if a new outbound substream is opened. - pub poll_outbound_waker: Option, - /// Waker to wake if a new inbound substream was happened. - pub poll_inbound_waker: Option, - /// Waker to wake if the connection should be polled again. - pub poll_connection_waker: Option, -} + match futures::ready!(closing.poll_unpin(cx)) { + // Expected error given that `connection.close` was called above. + quinn::ConnectionError::LocallyClosed => {} + error => return Poll::Ready(Err(Error::Connection(ConnectionError(error)))), + }; -impl State { - fn unchecked_substream_state(&mut self, id: quinn_proto::StreamId) -> &mut SubstreamState { - self.substreams - .get_mut(&id) - .expect("Substream should be known.") + Poll::Ready(Ok(())) } } diff --git a/transports/quic/src/connection/connecting.rs b/transports/quic/src/connection/connecting.rs index e9a7d3e5f2c..b911eaa7dfe 100644 --- a/transports/quic/src/connection/connecting.rs +++ b/transports/quic/src/connection/connecting.rs @@ -20,9 +20,12 @@ //! Future that drives a QUIC connection until is has performed its TLS handshake. -use crate::{Connection, Error}; +use crate::{Connection, ConnectionError, Error}; -use futures::prelude::*; +use futures::{ + future::{select, Either, FutureExt, Select}, + prelude::*, +}; use futures_timer::Delay; use libp2p_identity::PeerId; use std::{ @@ -34,64 +37,46 @@ use std::{ /// A QUIC connection currently being negotiated. #[derive(Debug)] pub struct Connecting { - connection: Option, - timeout: Delay, + connecting: Select, } impl Connecting { - pub(crate) fn new(connection: Connection, timeout: Duration) -> Self { + pub(crate) fn new(connection: quinn::Connecting, timeout: Duration) -> Self { Connecting { - connection: Some(connection), - timeout: Delay::new(timeout), + connecting: select(connection, Delay::new(timeout)), } } } +impl Connecting { + /// Returns the address of the node we're connected to. + /// Panics if the connection is still handshaking. + fn remote_peer_id(connection: &quinn::Connection) -> PeerId { + let identity = connection + .peer_identity() + .expect("connection got identity because it passed TLS handshake; qed"); + let certificates: Box> = + identity.downcast().expect("we rely on rustls feature; qed"); + let end_entity = certificates + .get(0) + .expect("there should be exactly one certificate; qed"); + let p2p_cert = libp2p_tls::certificate::parse(end_entity) + .expect("the certificate was validated during TLS handshake; qed"); + p2p_cert.peer_id() + } +} + impl Future for Connecting { type Output = Result<(PeerId, Connection), Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let connection = self - .connection - .as_mut() - .expect("Future polled after it has completed"); - - loop { - let event = match connection.poll_event(cx) { - Poll::Ready(Some(event)) => event, - Poll::Ready(None) => return Poll::Ready(Err(Error::EndpointDriverCrashed)), - Poll::Pending => { - return self - .timeout - .poll_unpin(cx) - .map(|()| Err(Error::HandshakeTimedOut)); - } - }; - match event { - quinn_proto::Event::Connected => { - // Parse the remote's Id identity from the certificate. - let identity = connection - .peer_identity() - .expect("connection got identity because it passed TLS handshake; qed"); - let certificates: Box> = - identity.downcast().expect("we rely on rustls feature; qed"); - let end_entity = certificates - .get(0) - .expect("there should be exactly one certificate; qed"); - let p2p_cert = libp2p_tls::certificate::parse(end_entity) - .expect("the certificate was validated during TLS handshake; qed"); - let peer_id = p2p_cert.peer_id(); + let connection = match futures::ready!(self.connecting.poll_unpin(cx)) { + Either::Right(_) => return Poll::Ready(Err(Error::HandshakeTimedOut)), + Either::Left((connection, _)) => connection.map_err(ConnectionError)?, + }; - return Poll::Ready(Ok((peer_id, self.connection.take().unwrap()))); - } - quinn_proto::Event::ConnectionLost { reason } => { - return Poll::Ready(Err(Error::Connection(reason.into()))) - } - quinn_proto::Event::HandshakeDataReady | quinn_proto::Event::Stream(_) => {} - quinn_proto::Event::DatagramReceived => { - debug_assert!(false, "Datagrams are not supported") - } - } - } + let peer_id = Self::remote_peer_id(&connection); + let muxer = Connection::new(connection); + Poll::Ready(Ok((peer_id, muxer))) } } diff --git a/transports/quic/src/connection/stream.rs b/transports/quic/src/connection/stream.rs new file mode 100644 index 00000000000..b0c505bf856 --- /dev/null +++ b/transports/quic/src/connection/stream.rs @@ -0,0 +1,86 @@ +// Copyright 2022 Protocol Labs. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::{ + io::{self}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{AsyncRead, AsyncWrite}; + +/// A single stream on a connection +pub struct Stream { + /// A send part of the stream + send: quinn::SendStream, + /// A receive part of the stream + recv: quinn::RecvStream, + /// Whether the stream is closed or not + close_result: Option>, +} + +impl Stream { + pub(super) fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self { + Self { + send, + recv, + close_result: None, + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + if let Some(close_result) = self.close_result { + if close_result.is_err() { + return Poll::Ready(Ok(0)); + } + } + Pin::new(&mut self.recv).poll_read(cx, buf) + } +} + +impl AsyncWrite for Stream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.send).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.send).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if let Some(close_result) = self.close_result { + // For some reason poll_close needs to be 'fuse'able + return Poll::Ready(close_result.map_err(Into::into)); + } + let close_result = futures::ready!(Pin::new(&mut self.send).poll_close(cx)); + self.close_result = Some(close_result.as_ref().map_err(|e| e.kind()).copied()); + Poll::Ready(close_result) + } +} diff --git a/transports/quic/src/connection/substream.rs b/transports/quic/src/connection/substream.rs deleted file mode 100644 index b3a82542e9c..00000000000 --- a/transports/quic/src/connection/substream.rs +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2022 Protocol Labs. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use std::{ - io::{self, Write}, - pin::Pin, - sync::Arc, - task::{Context, Poll, Waker}, -}; - -use futures::{AsyncRead, AsyncWrite}; -use parking_lot::Mutex; - -use super::State; - -/// Wakers for the [`AsyncRead`] and [`AsyncWrite`] on a substream. -#[derive(Debug, Default, Clone)] -pub struct SubstreamState { - /// Waker to wake if the substream becomes readable. - pub read_waker: Option, - /// Waker to wake if the substream becomes writable, closed or stopped. - pub write_waker: Option, - /// Waker to wake if the substream becomes closed or stopped. - pub close_waker: Option, - - pub write_state: WriteState, -} - -impl SubstreamState { - /// Wake all wakers for reading, writing and closed the stream. - pub fn wake_all(&mut self) { - if let Some(waker) = self.read_waker.take() { - waker.wake(); - } - if let Some(waker) = self.write_waker.take() { - waker.wake(); - } - if let Some(waker) = self.close_waker.take() { - waker.wake(); - } - } -} - -/// A single stream on a connection -#[derive(Debug)] -pub struct Substream { - /// The id of the stream. - id: quinn_proto::StreamId, - /// The state of the [`super::Connection`] this stream belongs to. - state: Arc>, -} - -impl Substream { - pub fn new(id: quinn_proto::StreamId, state: Arc>) -> Self { - Self { id, state } - } -} - -impl AsyncRead for Substream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - let mut state = self.state.lock(); - - let mut stream = state.connection.recv_stream(self.id); - let mut chunks = match stream.read(true) { - Ok(chunks) => chunks, - Err(quinn_proto::ReadableError::UnknownStream) => { - return Poll::Ready(Ok(0)); - } - Err(quinn_proto::ReadableError::IllegalOrderedRead) => { - unreachable!( - "Illegal ordered read can only happen if `stream.read(false)` is used." - ); - } - }; - - let mut bytes = 0; - let mut pending = false; - let mut error = None; - loop { - if buf.is_empty() { - // Chunks::next will continue returning `Ok(Some(_))` with an - // empty chunk if there is no space left in the buffer, so we - // break early here. - break; - } - let chunk = match chunks.next(buf.len()) { - Ok(Some(chunk)) => chunk, - Ok(None) => break, - Err(err @ quinn_proto::ReadError::Reset(_)) => { - error = Some(Err(io::Error::new(io::ErrorKind::ConnectionReset, err))); - break; - } - Err(quinn_proto::ReadError::Blocked) => { - pending = true; - break; - } - }; - - buf.write_all(&chunk.bytes).expect("enough buffer space"); - bytes += chunk.bytes.len(); - } - if chunks.finalize().should_transmit() { - if let Some(waker) = state.poll_connection_waker.take() { - waker.wake(); - } - } - if let Some(err) = error { - return Poll::Ready(err); - } - - if pending && bytes == 0 { - let substream_state = state.unchecked_substream_state(self.id); - substream_state.read_waker = Some(cx.waker().clone()); - return Poll::Pending; - } - - Poll::Ready(Ok(bytes)) - } -} - -impl AsyncWrite for Substream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut state = self.state.lock(); - - match state.connection.send_stream(self.id).write(buf) { - Ok(bytes) => { - if let Some(waker) = state.poll_connection_waker.take() { - waker.wake(); - } - Poll::Ready(Ok(bytes)) - } - Err(quinn_proto::WriteError::Blocked) => { - let substream_state = state.unchecked_substream_state(self.id); - substream_state.write_waker = Some(cx.waker().clone()); - Poll::Pending - } - Err(err @ quinn_proto::WriteError::Stopped(_)) => { - Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err))) - } - Err(quinn_proto::WriteError::UnknownStream) => { - Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - // quinn doesn't support flushing, calling close will flush all substreams. - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut inner = self.state.lock(); - - let substream_state = inner.unchecked_substream_state(self.id); - match substream_state.write_state { - WriteState::Open => {} - WriteState::Closing => { - substream_state.close_waker = Some(cx.waker().clone()); - return Poll::Pending; - } - WriteState::Closed => return Poll::Ready(Ok(())), - WriteState::Stopped => { - let err = quinn_proto::FinishError::Stopped(0u32.into()); - return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err))); - } - } - - match inner.connection.send_stream(self.id).finish() { - Ok(()) => { - let substream_state = inner.unchecked_substream_state(self.id); - substream_state.close_waker = Some(cx.waker().clone()); - substream_state.write_state = WriteState::Closing; - Poll::Pending - } - Err(err @ quinn_proto::FinishError::Stopped(_)) => { - Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err))) - } - Err(quinn_proto::FinishError::UnknownStream) => { - // We never make up IDs so the stream must have existed at some point if we get to here. - // `UnknownStream` is also emitted in case the stream is already finished, hence just - // return `Ok(())` here. - Poll::Ready(Ok(())) - } - } - } -} - -impl Drop for Substream { - fn drop(&mut self) { - let mut state = self.state.lock(); - state.substreams.remove(&self.id); - // Send `STOP_STREAM` if the remote did not finish the stream yet. - // We have to manually check the read stream since we might have - // received a `FIN` (without any other stream data) after the last - // time we tried to read. - let mut is_read_done = false; - if let Ok(mut chunks) = state.connection.recv_stream(self.id).read(true) { - if let Ok(chunk) = chunks.next(0) { - is_read_done = chunk.is_none(); - } - let _ = chunks.finalize(); - } - if !is_read_done { - let _ = state.connection.recv_stream(self.id).stop(0u32.into()); - } - // Close the writing side. - let mut send_stream = state.connection.send_stream(self.id); - match send_stream.finish() { - Ok(()) => {} - // Already finished or reset, which is fine. - Err(quinn_proto::FinishError::UnknownStream) => {} - Err(quinn_proto::FinishError::Stopped(reason)) => { - let _ = send_stream.reset(reason); - } - } - } -} - -#[derive(Debug, Default, Clone)] -pub enum WriteState { - /// The stream is open for writing. - #[default] - Open, - /// The writing side of the stream is closing. - Closing, - /// All data was successfully sent to the remote and the stream closed, - /// i.e. a [`quinn_proto::StreamEvent::Finished`] was reported for it. - Closed, - /// The stream was stopped by the remote before all data could be - /// sent. - Stopped, -} diff --git a/transports/quic/src/endpoint.rs b/transports/quic/src/endpoint.rs deleted file mode 100644 index bf69df50b62..00000000000 --- a/transports/quic/src/endpoint.rs +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright 2017-2020 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use crate::{ - provider::Provider, - transport::{ProtocolVersion, SocketFamily}, - ConnectError, Connection, Error, -}; - -use bytes::BytesMut; -use futures::{ - channel::{mpsc, oneshot}, - prelude::*, -}; -use quinn_proto::VarInt; -use std::{ - collections::HashMap, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - ops::ControlFlow, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -// The `Driver` drops packets if the channel to the connection -// or transport is full. -// Set capacity 10 to avoid unnecessary packet drops if the receiver -// is only very briefly busy, but not buffer a large amount of packets -// if it is blocked longer. -const CHANNEL_CAPACITY: usize = 10; - -/// Config for the transport. -#[derive(Clone)] -pub struct Config { - /// Timeout for the initial handshake when establishing a connection. - /// The actual timeout is the minimum of this an the [`Config::max_idle_timeout`]. - pub handshake_timeout: Duration, - /// Maximum duration of inactivity in ms to accept before timing out the connection. - pub max_idle_timeout: u32, - /// Period of inactivity before sending a keep-alive packet. - /// Must be set lower than the idle_timeout of both - /// peers to be effective. - /// - /// See [`quinn_proto::TransportConfig::keep_alive_interval`] for more - /// info. - pub keep_alive_interval: Duration, - /// Maximum number of incoming bidirectional streams that may be open - /// concurrently by the remote peer. - pub max_concurrent_stream_limit: u32, - - /// Max unacknowledged data in bytes that may be send on a single stream. - pub max_stream_data: u32, - - /// Max unacknowledged data in bytes that may be send in total on all streams - /// of a connection. - pub max_connection_data: u32, - - /// Support QUIC version draft-29 for dialing and listening. - /// - /// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`] - /// is supported. - /// - /// If support for draft-29 is enabled servers support draft-29 and version 1 on all - /// QUIC listening addresses. - /// As client the version is chosen based on the remote's address. - pub support_draft_29: bool, - - /// TLS client config for the inner [`quinn_proto::ClientConfig`]. - client_tls_config: Arc, - /// TLS server config for the inner [`quinn_proto::ServerConfig`]. - server_tls_config: Arc, -} - -impl Config { - /// Creates a new configuration object with default values. - pub fn new(keypair: &libp2p_identity::Keypair) -> Self { - let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap()); - let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap()); - Self { - client_tls_config, - server_tls_config, - support_draft_29: false, - handshake_timeout: Duration::from_secs(5), - max_idle_timeout: 30 * 1000, - max_concurrent_stream_limit: 256, - keep_alive_interval: Duration::from_secs(15), - max_connection_data: 15_000_000, - - // Ensure that one stream is not consuming the whole connection. - max_stream_data: 10_000_000, - } - } -} - -/// Represents the inner configuration for [`quinn_proto`]. -#[derive(Debug, Clone)] -pub(crate) struct QuinnConfig { - client_config: quinn_proto::ClientConfig, - server_config: Arc, - endpoint_config: Arc, -} - -impl From for QuinnConfig { - fn from(config: Config) -> QuinnConfig { - let Config { - client_tls_config, - server_tls_config, - max_idle_timeout, - max_concurrent_stream_limit, - keep_alive_interval, - max_connection_data, - max_stream_data, - support_draft_29, - handshake_timeout: _, - } = config; - let mut transport = quinn_proto::TransportConfig::default(); - // Disable uni-directional streams. - transport.max_concurrent_uni_streams(0u32.into()); - transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into()); - // Disable datagrams. - transport.datagram_receive_buffer_size(None); - transport.keep_alive_interval(Some(keep_alive_interval)); - transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into())); - transport.allow_spin(false); - transport.stream_receive_window(max_stream_data.into()); - transport.receive_window(max_connection_data.into()); - let transport = Arc::new(transport); - - let mut server_config = quinn_proto::ServerConfig::with_crypto(server_tls_config); - server_config.transport = Arc::clone(&transport); - // Disables connection migration. - // Long-term this should be enabled, however we then need to handle address change - // on connections in the `Connection`. - server_config.migration(false); - - let mut client_config = quinn_proto::ClientConfig::new(client_tls_config); - client_config.transport_config(transport); - - let mut endpoint_config = quinn_proto::EndpointConfig::default(); - if !support_draft_29 { - endpoint_config.supported_versions(vec![1]); - } - - QuinnConfig { - client_config, - server_config: Arc::new(server_config), - endpoint_config: Arc::new(endpoint_config), - } - } -} - -/// Channel used to send commands to the [`Driver`]. -#[derive(Debug, Clone)] -pub(crate) struct Channel { - /// Channel to the background of the endpoint. - to_endpoint: mpsc::Sender, - /// Address that the socket is bound to. - /// Note: this may be a wildcard ip address. - socket_addr: SocketAddr, -} - -impl Channel { - /// Builds a new endpoint that is listening on the [`SocketAddr`]. - pub(crate) fn new_bidirectional( - quinn_config: QuinnConfig, - socket_addr: SocketAddr, - ) -> Result<(Self, mpsc::Receiver), Error> { - // Channel for forwarding new inbound connections to the listener. - let (new_connections_tx, new_connections_rx) = mpsc::channel(CHANNEL_CAPACITY); - let endpoint = Self::new::

(quinn_config, socket_addr, Some(new_connections_tx))?; - Ok((endpoint, new_connections_rx)) - } - - /// Builds a new endpoint that only supports outbound connections. - pub(crate) fn new_dialer( - quinn_config: QuinnConfig, - socket_family: SocketFamily, - ) -> Result { - let socket_addr = match socket_family { - SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - }; - Self::new::

(quinn_config, socket_addr, None) - } - - /// Spawn a new [`Driver`] that runs in the background. - fn new( - quinn_config: QuinnConfig, - socket_addr: SocketAddr, - new_connections: Option>, - ) -> Result { - let socket = std::net::UdpSocket::bind(socket_addr)?; - // NOT blocking, as per man:bind(2), as we pass an IP address. - socket.set_nonblocking(true)?; - // Capacity 0 to back-pressure the rest of the application if - // the udp socket is busy. - let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(0); - - let channel = Self { - to_endpoint: to_endpoint_tx, - socket_addr: socket.local_addr()?, - }; - - let server_config = new_connections - .is_some() - .then_some(quinn_config.server_config); - - let provider_socket = P::from_socket(socket)?; - - let driver = Driver::

::new( - quinn_config.endpoint_config, - quinn_config.client_config, - new_connections, - server_config, - channel.clone(), - provider_socket, - to_endpoint_rx, - ); - - // Drive the endpoint future in the background. - P::spawn(driver); - - Ok(channel) - } - - pub(crate) fn socket_addr(&self) -> &SocketAddr { - &self.socket_addr - } - - /// Try to send a message to the background task without blocking. - /// - /// This first polls the channel for capacity. - /// If the channel is full, the message is returned in `Ok(Err(_))` - /// and the context's waker is registered for wake-up. - /// - /// If the background task crashed `Err` is returned. - pub(crate) fn try_send( - &mut self, - to_endpoint: ToEndpoint, - cx: &mut Context<'_>, - ) -> Result, Disconnected> { - match self.to_endpoint.poll_ready_unpin(cx) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(e)) => { - debug_assert!( - e.is_disconnected(), - "mpsc::Sender can only be disconnected when calling `poll_ready_unpin" - ); - - return Err(Disconnected {}); - } - Poll::Pending => return Ok(Err(to_endpoint)), - }; - - if let Err(e) = self.to_endpoint.start_send(to_endpoint) { - debug_assert!(e.is_disconnected(), "We called `Sink::poll_ready` so we are guaranteed to have a slot. If this fails, it means we are disconnected."); - - return Err(Disconnected {}); - } - - Ok(Ok(())) - } - - pub(crate) async fn send(&mut self, to_endpoint: ToEndpoint) -> Result<(), Disconnected> { - self.to_endpoint - .send(to_endpoint) - .await - .map_err(|_| Disconnected {}) - } - - /// Send a message to inform the [`Driver`] about an - /// event caused by the owner of this [`Channel`] dropping. - /// This clones the sender to the endpoint to guarantee delivery. - /// This should *not* be called for regular messages. - pub(crate) fn send_on_drop(&mut self, to_endpoint: ToEndpoint) { - let _ = self.to_endpoint.clone().try_send(to_endpoint); - } -} - -#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)] -#[error("Background task disconnected")] -pub(crate) struct Disconnected {} -/// Message sent to the endpoint background task. -#[derive(Debug)] -pub(crate) enum ToEndpoint { - /// Instruct the [`quinn_proto::Endpoint`] to start connecting to the given address. - Dial { - /// UDP address to connect to. - addr: SocketAddr, - /// Version to dial the remote on. - version: ProtocolVersion, - /// Channel to return the result of the dialing to. - result: oneshot::Sender>, - }, - /// Send by a [`quinn_proto::Connection`] when the endpoint needs to process an event generated - /// by a connection. The event itself is opaque to us. Only `quinn_proto` knows what is in - /// there. - ProcessConnectionEvent { - connection_id: quinn_proto::ConnectionHandle, - event: quinn_proto::EndpointEvent, - }, - /// Instruct the endpoint to send a packet of data on its UDP socket. - SendUdpPacket(quinn_proto::Transmit), - /// The [`GenTransport`][crate::GenTransport] dialer or listener coupled to this endpoint - /// was dropped. - /// Once all pending connections are closed, the [`Driver`] should shut down. - Decoupled, -} - -/// Driver that runs in the background for as long as the endpoint is alive. Responsible for -/// processing messages and the UDP socket. -/// -/// # Behaviour -/// -/// This background task is responsible for the following: -/// -/// - Sending packets on the UDP socket. -/// - Receiving packets from the UDP socket and feed them to the [`quinn_proto::Endpoint`] state -/// machine. -/// - Transmitting events generated by the [`quinn_proto::Endpoint`] to the corresponding -/// [`crate::Connection`]. -/// - Receiving messages from the `rx` and processing the requested actions. This includes -/// UDP packets to send and events emitted by the [`crate::Connection`] objects. -/// - Sending new connections on `new_connection_tx`. -/// -/// When it comes to channels, there exists three main multi-producer-single-consumer channels -/// in play: -/// -/// - One channel, represented by `EndpointChannel::to_endpoint` and `Driver::rx`, -/// that communicates messages from [`Channel`] to the [`Driver`]. -/// - One channel for each existing connection that communicates messages from the -/// [`Driver` to that [`crate::Connection`]. -/// - One channel for the [`Driver`] to send newly-opened connections to. The receiving -/// side is processed by the [`GenTransport`][crate::GenTransport]. -/// -/// -/// ## Back-pressure -/// -/// ### If writing to the UDP socket is blocked -/// -/// In order to avoid an unbounded buffering of events, we prioritize sending data on the UDP -/// socket over everything else. Messages from the rest of the application sent through the -/// [`Channel`] are only processed if the UDP socket is ready so that we propagate back-pressure -/// in case of a busy socket. For connections, thus this eventually also back-pressures the -/// `AsyncWrite`on substreams. -/// -/// -/// ### Back-pressuring the remote if the application is busy -/// -/// If the channel to a connection is full because the connection is busy, inbound datagrams -/// for that connection are dropped so that the remote is backpressured. -/// The same applies for new connections if the transport is too busy to received it. -/// -/// -/// # Shutdown -/// -/// The background task shuts down if an [`ToEndpoint::Decoupled`] event was received and the -/// last active connection has drained. -#[derive(Debug)] -pub(crate) struct Driver { - // The actual QUIC state machine. - endpoint: quinn_proto::Endpoint, - // QuinnConfig for client connections. - client_config: quinn_proto::ClientConfig, - // Copy of the channel to the endpoint driver that is passed to each new connection. - channel: Channel, - // Channel to receive messages from the transport or connections. - rx: mpsc::Receiver, - - // Socket for sending and receiving datagrams. - provider_socket: P, - // Future for writing the next packet to the socket. - next_packet_out: Option, - - // List of all active connections, with a sender to notify them of events. - alive_connections: - HashMap>, - // Channel to forward new inbound connections to the transport. - // `None` if server capabilities are disabled, i.e. the endpoint is only used for dialing. - new_connection_tx: Option>, - // Whether the transport dropped its handle for this endpoint. - is_decoupled: bool, -} - -impl Driver

{ - fn new( - endpoint_config: Arc, - client_config: quinn_proto::ClientConfig, - new_connection_tx: Option>, - server_config: Option>, - channel: Channel, - socket: P, - rx: mpsc::Receiver, - ) -> Self { - Driver { - endpoint: quinn_proto::Endpoint::new(endpoint_config, server_config, false), - client_config, - channel, - rx, - provider_socket: socket, - next_packet_out: None, - alive_connections: HashMap::new(), - new_connection_tx, - is_decoupled: false, - } - } - - /// Handle a message sent from either the [`GenTransport`](super::GenTransport) - /// or a [`crate::Connection`]. - fn handle_message( - &mut self, - to_endpoint: ToEndpoint, - ) -> ControlFlow<(), Option> { - match to_endpoint { - ToEndpoint::Dial { - addr, - result, - version, - } => { - let mut config = self.client_config.clone(); - if version == ProtocolVersion::Draft29 { - config.version(0xff00_001d); - } - // This `"l"` seems necessary because an empty string is an invalid domain - // name. While we don't use domain names, the underlying rustls library - // is based upon the assumption that we do. - let (connection_id, connection) = match self.endpoint.connect(config, addr, "l") { - Ok(c) => c, - Err(err) => { - let _ = result.send(Err(ConnectError::from(err).into())); - return ControlFlow::Continue(None); - } - }; - - debug_assert_eq!(connection.side(), quinn_proto::Side::Client); - let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY); - let connection = Connection::from_quinn_connection( - self.channel.clone(), - connection, - connection_id, - rx, - ); - self.alive_connections.insert(connection_id, tx); - let _ = result.send(Ok(connection)); - } - - // A connection wants to notify the endpoint of something. - ToEndpoint::ProcessConnectionEvent { - connection_id, - event, - } => { - let has_key = self.alive_connections.contains_key(&connection_id); - if !has_key { - return ControlFlow::Continue(None); - } - // We "drained" event indicates that the connection no longer exists and - // its ID can be reclaimed. - let is_drained_event = event.is_drained(); - if is_drained_event { - self.alive_connections.remove(&connection_id); - if self.is_decoupled && self.alive_connections.is_empty() { - log::debug!( - "Driver is decoupled and no active connections remain. Shutting down." - ); - return ControlFlow::Break(()); - } - } - - let event_back = self.endpoint.handle_event(connection_id, event); - - if let Some(event_back) = event_back { - debug_assert!(!is_drained_event); - if let Some(sender) = self.alive_connections.get_mut(&connection_id) { - // We clone the sender to guarantee that there will be at least one - // free slot to send the event. - // The channel can not grow out of bound because an `event_back` is - // only sent if we previously received an event from the same connection. - // If the connection is busy, it won't sent us any more events to handle. - let _ = sender.clone().start_send(event_back); - } else { - log::error!("State mismatch: event for closed connection"); - } - } - } - - // Data needs to be sent on the UDP socket. - ToEndpoint::SendUdpPacket(transmit) => return ControlFlow::Continue(Some(transmit)), - ToEndpoint::Decoupled => self.handle_decoupling()?, - } - ControlFlow::Continue(None) - } - - /// Handle an UDP datagram received on the socket. - /// The datagram content was written into the `socket_recv_buffer`. - fn handle_datagram(&mut self, packet: BytesMut, packet_src: SocketAddr) -> ControlFlow<()> { - let local_ip = self.channel.socket_addr.ip(); - // TODO: ECN bits aren't handled - let (connec_id, event) = - match self - .endpoint - .handle(Instant::now(), packet_src, Some(local_ip), None, packet) - { - Some(event) => event, - None => return ControlFlow::Continue(()), - }; - match event { - quinn_proto::DatagramEvent::ConnectionEvent(event) => { - // `event` has type `quinn_proto::ConnectionEvent`, which has multiple - // variants. `quinn_proto::Endpoint::handle` however only ever returns - // `ConnectionEvent::Datagram`. - debug_assert!(format!("{event:?}").contains("Datagram")); - - // Redirect the datagram to its connection. - if let Some(sender) = self.alive_connections.get_mut(&connec_id) { - match sender.try_send(event) { - Ok(()) => {} - Err(err) if err.is_disconnected() => { - // Connection was dropped by the user. - // Inform the endpoint that this connection is drained. - self.endpoint - .handle_event(connec_id, quinn_proto::EndpointEvent::drained()); - self.alive_connections.remove(&connec_id); - } - Err(err) if err.is_full() => { - // Connection is too busy. Drop the datagram to back-pressure the remote. - log::debug!( - "Dropping packet for connection {:?} because the connection's channel is full.", - connec_id - ); - } - Err(_) => unreachable!("Error is either `Full` or `Disconnected`."), - } - } else { - log::error!("State mismatch: event for closed connection"); - } - } - quinn_proto::DatagramEvent::NewConnection(connec) => { - // A new connection has been received. `connec_id` is a newly-allocated - // identifier. - debug_assert_eq!(connec.side(), quinn_proto::Side::Server); - let connection_tx = match self.new_connection_tx.as_mut() { - Some(tx) => tx, - None => { - debug_assert!(false, "Endpoint reported a new connection even though server capabilities are disabled."); - return ControlFlow::Continue(()); - } - }; - - let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY); - let connection = - Connection::from_quinn_connection(self.channel.clone(), connec, connec_id, rx); - match connection_tx.start_send(connection) { - Ok(()) => { - self.alive_connections.insert(connec_id, tx); - } - Err(e) if e.is_disconnected() => self.handle_decoupling()?, - Err(e) if e.is_full() => log::warn!( - "Dropping new incoming connection {:?} because the channel to the listener is full", - connec_id - ), - Err(_) => unreachable!("Error is either `Full` or `Disconnected`."), - } - } - } - ControlFlow::Continue(()) - } - - /// The transport dropped the channel to this [`Driver`]. - fn handle_decoupling(&mut self) -> ControlFlow<()> { - if self.alive_connections.is_empty() { - return ControlFlow::Break(()); - } - // Listener was closed. - self.endpoint.reject_new_connections(); - self.new_connection_tx = None; - self.is_decoupled = true; - ControlFlow::Continue(()) - } -} - -/// Future that runs until the [`Driver`] is decoupled and not active connections -/// remain -impl Future for Driver

{ - type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - // Flush any pending pocket so that the socket is reading to write an next - // packet. - match self.provider_socket.poll_send_flush(cx) { - // The pending packet was send or no packet was pending. - Poll::Ready(Ok(_)) => { - // Start sending a packet on the socket. - if let Some(transmit) = self.next_packet_out.take() { - self.provider_socket - .start_send(transmit.contents.into(), transmit.destination); - continue; - } - - // The endpoint might request packets to be sent out. This is handled in - // priority to avoid buffering up packets. - if let Some(transmit) = self.endpoint.poll_transmit() { - self.next_packet_out = Some(transmit); - continue; - } - - // Handle messages from transport and connections. - match self.rx.poll_next_unpin(cx) { - Poll::Ready(Some(to_endpoint)) => match self.handle_message(to_endpoint) { - ControlFlow::Continue(Some(transmit)) => { - self.next_packet_out = Some(transmit); - continue; - } - ControlFlow::Continue(None) => continue, - ControlFlow::Break(()) => break, - }, - Poll::Ready(None) => { - unreachable!("Sender side is never dropped or closed.") - } - Poll::Pending => {} - } - } - // Errors on the socket are expected to never happen, and we handle them by simply - // printing a log message. The packet gets discarded in case of error, but we are - // robust to packet losses and it is consequently not a logic error to proceed with - // normal operations. - Poll::Ready(Err(err)) => { - log::warn!("Error while sending on QUIC UDP socket: {:?}", err); - continue; - } - Poll::Pending => {} - } - - // Poll for new packets from the remote. - match self.provider_socket.poll_recv_from(cx) { - Poll::Ready(Ok((bytes, packet_src))) => { - let bytes_mut = bytes.as_slice().into(); - match self.handle_datagram(bytes_mut, packet_src) { - ControlFlow::Continue(()) => continue, - ControlFlow::Break(()) => break, - } - } - // Errors on the socket are expected to never happen, and we handle them by - // simply printing a log message. - Poll::Ready(Err(err)) => { - log::warn!("Error while receive on QUIC UDP socket: {:?}", err); - continue; - } - Poll::Pending => {} - } - - return Poll::Pending; - } - - Poll::Ready(()) - } -} diff --git a/transports/quic/src/hole_punching.rs b/transports/quic/src/hole_punching.rs index b9589dd17a0..874bc659b2e 100644 --- a/transports/quic/src/hole_punching.rs +++ b/transports/quic/src/hole_punching.rs @@ -1,19 +1,20 @@ -use std::{net::SocketAddr, time::Duration}; +use crate::{provider::Provider, Error}; use futures::future::Either; + use rand::{distributions, Rng}; -use crate::{ - endpoint::{self, ToEndpoint}, - Error, Provider, +use std::{ + net::{SocketAddr, UdpSocket}, + time::Duration, }; pub(crate) async fn hole_puncher( - endpoint_channel: endpoint::Channel, + socket: UdpSocket, remote_addr: SocketAddr, timeout_duration: Duration, ) -> Error { - let punch_holes_future = punch_holes::

(endpoint_channel, remote_addr); + let punch_holes_future = punch_holes::

(socket, remote_addr); futures::pin_mut!(punch_holes_future); match futures::future::select(P::sleep(timeout_duration), punch_holes_future).await { Either::Left(_) => Error::HandshakeTimedOut, @@ -21,27 +22,18 @@ pub(crate) async fn hole_puncher( } } -async fn punch_holes( - mut endpoint_channel: endpoint::Channel, - remote_addr: SocketAddr, -) -> Error { +async fn punch_holes(socket: UdpSocket, remote_addr: SocketAddr) -> Error { loop { let sleep_duration = Duration::from_millis(rand::thread_rng().gen_range(10..=200)); P::sleep(sleep_duration).await; - let random_udp_packet = ToEndpoint::SendUdpPacket(quinn_proto::Transmit { - destination: remote_addr, - ecn: None, - contents: rand::thread_rng() - .sample_iter(distributions::Standard) - .take(64) - .collect(), - segment_size: None, - src_ip: None, - }); + let contents: Vec = rand::thread_rng() + .sample_iter(distributions::Standard) + .take(64) + .collect(); - if endpoint_channel.send(random_udp_packet).await.is_err() { - return Error::EndpointDriverCrashed; + if let Err(e) = P::send_to(&socket, &contents, remote_addr).await { + return Error::Io(e); } } } diff --git a/transports/quic/src/lib.rs b/transports/quic/src/lib.rs index 945f5119c6e..494ecfdcddb 100644 --- a/transports/quic/src/lib.rs +++ b/transports/quic/src/lib.rs @@ -57,16 +57,17 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +mod config; mod connection; -mod endpoint; mod hole_punching; mod provider; mod transport; use std::net::SocketAddr; -pub use connection::{Connecting, Connection, Substream}; -pub use endpoint::Config; +pub use config::Config; +pub use connection::{Connecting, Connection, Stream}; + #[cfg(feature = "async-std")] pub use provider::async_std; #[cfg(feature = "tokio")] @@ -89,8 +90,7 @@ pub enum Error { #[error(transparent)] Io(#[from] std::io::Error), - /// The task spawned in [`Provider::spawn`] to drive - /// the quic endpoint has crashed. + /// The task to drive a quic endpoint has crashed. #[error("Endpoint driver crashed")] EndpointDriverCrashed, @@ -110,9 +110,9 @@ pub enum Error { /// Dialing a remote peer failed. #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct ConnectError(#[from] quinn_proto::ConnectError); +pub struct ConnectError(quinn::ConnectError); /// Error on an established [`Connection`]. #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct ConnectionError(#[from] quinn_proto::ConnectionError); +pub struct ConnectionError(quinn::ConnectionError); diff --git a/transports/quic/src/provider.rs b/transports/quic/src/provider.rs index c9401e9b99f..6f1122ee55f 100644 --- a/transports/quic/src/provider.rs +++ b/transports/quic/src/provider.rs @@ -18,11 +18,11 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::{future::BoxFuture, Future}; +use futures::future::BoxFuture; use if_watch::IfEvent; use std::{ io, - net::SocketAddr, + net::{SocketAddr, UdpSocket}, task::{Context, Poll}, time::Duration, }; @@ -32,40 +32,20 @@ pub mod async_std; #[cfg(feature = "tokio")] pub mod tokio; -/// Size of the buffer for reading data 0x10000. -#[cfg(any(feature = "async-std", feature = "tokio"))] -const RECEIVE_BUFFER_SIZE: usize = 65536; +pub enum Runtime { + #[cfg(feature = "tokio")] + Tokio, + #[cfg(feature = "async-std")] + AsyncStd, + Dummy, +} -/// Provider for non-blocking receiving and sending on a [`std::net::UdpSocket`] -/// and spawning tasks. +/// Provider for a corresponding quinn runtime and spawning tasks. pub trait Provider: Unpin + Send + Sized + 'static { type IfWatcher: Unpin + Send; - /// Create a new providing that is wrapping the socket. - /// - /// Note: The socket must be set to non-blocking. - fn from_socket(socket: std::net::UdpSocket) -> io::Result; - - /// Receive a single packet. - /// - /// Returns the message and the address the message came from. - fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll, SocketAddr)>>; - - /// Set sending a packet on the socket. - /// - /// Since only one packet can be sent at a time, this may only be called if a preceding - /// call to [`Provider::poll_send_flush`] returned [`Poll::Ready`]. - fn start_send(&mut self, data: Vec, addr: SocketAddr); - - /// Flush a packet send in [`Provider::start_send`]. - /// - /// If [`Poll::Ready`] is returned the socket is ready for sending a new packet. - fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll>; - - /// Run the given future in the background until it ends. - /// - /// This is used to spawn the task that is driving the endpoint. - fn spawn(future: impl Future + Send + 'static); + /// Run the corresponding runtime + fn runtime() -> Runtime; /// Create a new [`if_watch`] watcher that reports [`IfEvent`]s for network interface changes. fn new_if_watcher() -> io::Result; @@ -78,4 +58,11 @@ pub trait Provider: Unpin + Send + Sized + 'static { /// Sleep for specified amount of time. fn sleep(duration: Duration) -> BoxFuture<'static, ()>; + + /// Sends data on the socket to the given address. On success, returns the number of bytes written. + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: SocketAddr, + ) -> BoxFuture<'a, io::Result>; } diff --git a/transports/quic/src/provider/async_std.rs b/transports/quic/src/provider/async_std.rs index e593b2ed4f4..a110058108c 100644 --- a/transports/quic/src/provider/async_std.rs +++ b/transports/quic/src/provider/async_std.rs @@ -18,13 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use async_std::{net::UdpSocket, task::spawn}; -use futures::{future::BoxFuture, ready, Future, FutureExt, Stream, StreamExt}; +use futures::{future::BoxFuture, FutureExt}; use std::{ io, - net::SocketAddr, - pin::Pin, - sync::Arc, + net::UdpSocket, task::{Context, Poll}, time::Duration, }; @@ -34,65 +31,14 @@ use crate::GenTransport; /// Transport with [`async-std`] runtime. pub type Transport = GenTransport; -/// Provider for reading / writing to a sockets and spawning -/// tasks using [`async-std`]. -pub struct Provider { - socket: Arc, - // Future for sending a packet. - // This is needed since [`async_Std::net::UdpSocket`] does not - // provide a poll-style interface for sending a packet. - send_packet: Option>>, - recv_stream: ReceiveStream, -} +/// Provider for quinn runtime and spawning tasks using [`async-std`]. +pub struct Provider; impl super::Provider for Provider { type IfWatcher = if_watch::smol::IfWatcher; - fn from_socket(socket: std::net::UdpSocket) -> io::Result { - let socket = Arc::new(socket.into()); - let recv_stream = ReceiveStream::new(Arc::clone(&socket)); - Ok(Provider { - socket, - send_packet: None, - recv_stream, - }) - } - - fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll, SocketAddr)>> { - match self.recv_stream.poll_next_unpin(cx) { - Poll::Ready(ready) => { - Poll::Ready(ready.expect("ReceiveStream::poll_next never returns None.")) - } - Poll::Pending => Poll::Pending, - } - } - - fn start_send(&mut self, data: Vec, addr: SocketAddr) { - let socket = self.socket.clone(); - let send = async move { - socket.send_to(&data, addr).await?; - Ok(()) - } - .boxed(); - self.send_packet = Some(send) - } - - fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - let pending = match self.send_packet.as_mut() { - Some(pending) => pending, - None => return Poll::Ready(Ok(())), - }; - match pending.poll_unpin(cx) { - Poll::Ready(result) => { - self.send_packet = None; - Poll::Ready(result) - } - Poll::Pending => Poll::Pending, - } - } - - fn spawn(future: impl Future + Send + 'static) { - spawn(future); + fn runtime() -> super::Runtime { + super::Runtime::AsyncStd } fn new_if_watcher() -> io::Result { @@ -109,48 +55,16 @@ impl super::Provider for Provider { fn sleep(duration: Duration) -> BoxFuture<'static, ()> { async_std::task::sleep(duration).boxed() } -} - -type ReceiveStreamItem = ( - Result<(usize, SocketAddr), io::Error>, - Arc, - Vec, -); -/// Wrapper around the socket to implement `Stream` on it. -struct ReceiveStream { - /// Future for receiving a packet on the socket. - // This is needed since [`async_Std::net::UdpSocket`] does not - // provide a poll-style interface for receiving packets. - fut: BoxFuture<'static, ReceiveStreamItem>, -} - -impl ReceiveStream { - fn new(socket: Arc) -> Self { - let fut = ReceiveStream::next(socket, vec![0; super::RECEIVE_BUFFER_SIZE]).boxed(); - Self { fut: fut.boxed() } - } - - async fn next(socket: Arc, mut socket_recv_buffer: Vec) -> ReceiveStreamItem { - let recv = socket.recv_from(&mut socket_recv_buffer).await; - (recv, socket, socket_recv_buffer) - } -} - -impl Stream for ReceiveStream { - type Item = Result<(Vec, SocketAddr), io::Error>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let (result, socket, buffer) = ready!(self.fut.poll_unpin(cx)); - - let result = result.map(|(packet_len, packet_src)| { - debug_assert!(packet_len <= buffer.len()); - // Copies the bytes from the `socket_recv_buffer` they were written into. - (buffer[..packet_len].into(), packet_src) - }); - // Set the future for receiving the next packet on the stream. - self.fut = ReceiveStream::next(socket, buffer).boxed(); - - Poll::Ready(Some(result)) + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: std::net::SocketAddr, + ) -> BoxFuture<'a, io::Result> { + Box::pin(async move { + async_std::net::UdpSocket::from(udp_socket.try_clone()?) + .send_to(buf, target) + .await + }) } } diff --git a/transports/quic/src/provider/tokio.rs b/transports/quic/src/provider/tokio.rs index 77c9060e3c1..9cb148d6ef2 100644 --- a/transports/quic/src/provider/tokio.rs +++ b/transports/quic/src/provider/tokio.rs @@ -18,72 +18,27 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::{future::BoxFuture, ready, Future, FutureExt}; +use futures::{future::BoxFuture, FutureExt}; use std::{ io, - net::SocketAddr, + net::{SocketAddr, UdpSocket}, task::{Context, Poll}, time::Duration, }; -use tokio::{io::ReadBuf, net::UdpSocket}; use crate::GenTransport; /// Transport with [`tokio`] runtime. pub type Transport = GenTransport; -/// Provider for reading / writing to a sockets and spawning -/// tasks using [`tokio`]. -pub struct Provider { - socket: UdpSocket, - socket_recv_buffer: Vec, - next_packet_out: Option<(Vec, SocketAddr)>, -} +/// Provider for quinn runtime and spawning tasks using [`tokio`]. +pub struct Provider; impl super::Provider for Provider { type IfWatcher = if_watch::tokio::IfWatcher; - fn from_socket(socket: std::net::UdpSocket) -> std::io::Result { - let socket = UdpSocket::from_std(socket)?; - Ok(Provider { - socket, - socket_recv_buffer: vec![0; super::RECEIVE_BUFFER_SIZE], - next_packet_out: None, - }) - } - - fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - let (data, addr) = match self.next_packet_out.as_ref() { - Some(pending) => pending, - None => return Poll::Ready(Ok(())), - }; - match self.socket.poll_send_to(cx, data.as_slice(), *addr) { - Poll::Ready(result) => { - self.next_packet_out = None; - Poll::Ready(result.map(|_| ())) - } - Poll::Pending => Poll::Pending, - } - } - - fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll, SocketAddr)>> { - let Self { - socket, - socket_recv_buffer, - .. - } = self; - let mut read_buf = ReadBuf::new(socket_recv_buffer.as_mut_slice()); - let packet_src = ready!(socket.poll_recv_from(cx, &mut read_buf)?); - let bytes = read_buf.filled().to_vec(); - Poll::Ready(Ok((bytes, packet_src))) - } - - fn start_send(&mut self, data: Vec, addr: SocketAddr) { - self.next_packet_out = Some((data, addr)); - } - - fn spawn(future: impl Future + Send + 'static) { - tokio::spawn(future); + fn runtime() -> super::Runtime { + super::Runtime::Tokio } fn new_if_watcher() -> io::Result { @@ -100,4 +55,16 @@ impl super::Provider for Provider { fn sleep(duration: Duration) -> BoxFuture<'static, ()> { tokio::time::sleep(duration).boxed() } + + fn send_to<'a>( + udp_socket: &'a UdpSocket, + buf: &'a [u8], + target: SocketAddr, + ) -> BoxFuture<'a, io::Result> { + Box::pin(async move { + tokio::net::UdpSocket::from_std(udp_socket.try_clone()?)? + .send_to(buf, target) + .await + }) + } } diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index 84f9e479ee8..d4a1db35604 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -18,12 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::endpoint::{Config, QuinnConfig, ToEndpoint}; +use crate::config::{Config, QuinnConfig}; use crate::hole_punching::hole_puncher; use crate::provider::Provider; -use crate::{endpoint, Connecting, Connection, Error}; +use crate::{ConnectError, Connecting, Connection, Error}; -use futures::channel::{mpsc, oneshot}; +use futures::channel::oneshot; use futures::future::{BoxFuture, Either}; use futures::ready; use futures::stream::StreamExt; @@ -38,10 +38,10 @@ use libp2p_core::{ }; use libp2p_identity::PeerId; use std::collections::hash_map::{DefaultHasher, Entry}; -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::fmt; use std::hash::{Hash, Hasher}; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket}; use std::time::Duration; use std::{ net::SocketAddr, @@ -62,7 +62,7 @@ use std::{ /// See . #[derive(Debug)] pub struct GenTransport { - /// Config for the inner [`quinn_proto`] structs. + /// Config for the inner [`quinn`] structs. quinn_config: QuinnConfig, /// Timeout for the [`Connecting`] future. handshake_timeout: Duration, @@ -71,7 +71,7 @@ pub struct GenTransport { /// Streams of active [`Listener`]s. listeners: SelectAll>, /// Dialer for each socket family if no matching listener exists. - dialer: HashMap, + dialer: HashMap, /// Waker to poll the transport again when a new dialer or listener is added. waker: Option, /// Holepunching attempts @@ -95,21 +95,57 @@ impl GenTransport

{ } } + /// Create a new [`quinn::Endpoint`] with the given configs. + fn new_endpoint( + endpoint_config: quinn::EndpointConfig, + server_config: Option, + socket: UdpSocket, + ) -> Result { + use crate::provider::Runtime; + match P::runtime() { + #[cfg(feature = "tokio")] + Runtime::Tokio => { + let runtime = std::sync::Arc::new(quinn::TokioRuntime); + let endpoint = + quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?; + Ok(endpoint) + } + #[cfg(feature = "async-std")] + Runtime::AsyncStd => { + let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime); + let endpoint = + quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?; + Ok(endpoint) + } + Runtime::Dummy => { + let _ = endpoint_config; + let _ = server_config; + let _ = socket; + let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found"); + Err(Error::Io(err)) + } + } + } + + /// Extract the addr, quic version and peer id from the given [`Multiaddr`]. fn remote_multiaddr_to_socketaddr( &self, addr: Multiaddr, + check_unspecified_addr: bool, ) -> Result< (SocketAddr, ProtocolVersion, Option), TransportError<::Error>, > { let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29) .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?; - if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { + if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified()) + { return Err(TransportError::MultiaddrNotSupported(addr)); } Ok((socket_addr, version, peer_id)) } + /// Pick any listener to use for dialing. fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener

> { let mut listeners: Vec<_> = self .listeners @@ -118,7 +154,7 @@ impl GenTransport

{ if l.is_closed { return false; } - let listen_addr = l.endpoint_channel.socket_addr(); + let listen_addr = l.socket_addr(); SocketFamily::is_same(&listen_addr.ip(), &socket_addr.ip()) && listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback() }) @@ -149,13 +185,16 @@ impl Transport for GenTransport

{ listener_id: ListenerId, addr: Multiaddr, ) -> Result<(), TransportError> { - let (socket_addr, version, _peer_id) = - multiaddr_to_socketaddr(&addr, self.support_draft_29) - .ok_or(TransportError::MultiaddrNotSupported(addr))?; + let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?; + let endpoint_config = self.quinn_config.endpoint_config.clone(); + let server_config = self.quinn_config.server_config.clone(); + let socket = UdpSocket::bind(socket_addr).map_err(Self::Error::from)?; + let socket_c = socket.try_clone().map_err(Self::Error::from)?; + let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?; let listener = Listener::new( listener_id, - socket_addr, - self.quinn_config.clone(), + socket_c, + endpoint, self.handshake_timeout, version, )?; @@ -194,46 +233,68 @@ impl Transport for GenTransport

{ } fn dial(&mut self, addr: Multiaddr) -> Result> { - let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr)?; - - let handshake_timeout = self.handshake_timeout; + let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, true)?; - let dialer_state = match self.eligible_listener(&socket_addr) { + let endpoint = match self.eligible_listener(&socket_addr) { None => { // No listener. Get or create an explicit dialer. let socket_family = socket_addr.ip().into(); let dialer = match self.dialer.entry(socket_family) { - Entry::Occupied(occupied) => occupied.into_mut(), + Entry::Occupied(occupied) => occupied.get().clone(), Entry::Vacant(vacant) => { if let Some(waker) = self.waker.take() { waker.wake(); } - vacant.insert(Dialer::new::

(self.quinn_config.clone(), socket_family)?) + let listen_socket_addr = match socket_family { + SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), + }; + let socket = + UdpSocket::bind(listen_socket_addr).map_err(Self::Error::from)?; + let endpoint_config = self.quinn_config.endpoint_config.clone(); + let endpoint = Self::new_endpoint(endpoint_config, None, socket)?; + + vacant.insert(endpoint.clone()); + endpoint } }; - &mut dialer.state + dialer } - Some(listener) => &mut listener.dialer_state, + Some(listener) => listener.endpoint.clone(), }; - Ok(dialer_state.new_dial(socket_addr, handshake_timeout, version)) + let handshake_timeout = self.handshake_timeout; + let mut client_config = self.quinn_config.client_config.clone(); + if version == ProtocolVersion::Draft29 { + client_config.version(0xff00_001d); + } + Ok(Box::pin(async move { + // This `"l"` seems necessary because an empty string is an invalid domain + // name. While we don't use domain names, the underlying rustls library + // is based upon the assumption that we do. + let connecting = endpoint + .connect_with(client_config, socket_addr, "l") + .map_err(ConnectError)?; + Connecting::new(connecting, handshake_timeout).await + })) } fn dial_as_listener( &mut self, addr: Multiaddr, ) -> Result> { - let (socket_addr, _version, peer_id) = self.remote_multiaddr_to_socketaddr(addr.clone())?; + let (socket_addr, _version, peer_id) = + self.remote_multiaddr_to_socketaddr(addr.clone(), true)?; let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr))?; - let endpoint_channel = self + let socket = self .eligible_listener(&socket_addr) .ok_or(TransportError::Other( Error::NoActiveListenerForDialAsListener, ))? - .endpoint_channel - .clone(); + .try_clone_socket() + .map_err(Self::Error::from)?; - let hole_puncher = hole_puncher::

(endpoint_channel, socket_addr, self.handshake_timeout); + let hole_puncher = hole_puncher::

(socket, socket_addr, self.handshake_timeout); let (sender, receiver) = oneshot::channel(); @@ -274,19 +335,6 @@ impl Transport for GenTransport

{ mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut errored = Vec::new(); - for (key, dialer) in &mut self.dialer { - if let Poll::Ready(_error) = dialer.poll(cx) { - errored.push(*key); - } - } - - for key in errored { - // Endpoint driver of dialer crashed. - // Drop dialer and all pending dials so that the connection receiver is notified. - self.dialer.remove(&key); - } - while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) { match ev { TransportEvent::Incoming { @@ -331,112 +379,22 @@ impl From for TransportError { } } -/// Dialer for addresses if no matching listener exists. -#[derive(Debug)] -struct Dialer { - /// Channel to the [`crate::endpoint::Driver`] that - /// is driving the endpoint. - endpoint_channel: endpoint::Channel, - /// Queued dials for the endpoint. - state: DialerState, -} - -impl Dialer { - fn new( - config: QuinnConfig, - socket_family: SocketFamily, - ) -> Result> { - let endpoint_channel = endpoint::Channel::new_dialer::

(config, socket_family) - .map_err(TransportError::Other)?; - Ok(Dialer { - endpoint_channel, - state: DialerState::default(), - }) - } - - fn poll(&mut self, cx: &mut Context<'_>) -> Poll { - self.state.poll(&mut self.endpoint_channel, cx) - } -} - -impl Drop for Dialer { - fn drop(&mut self) { - self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled); - } -} - -/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`] -/// has capacity -#[derive(Default, Debug)] -struct DialerState { - pending_dials: VecDeque, - waker: Option, -} - -impl DialerState { - fn new_dial( - &mut self, - address: SocketAddr, - timeout: Duration, - version: ProtocolVersion, - ) -> BoxFuture<'static, Result<(PeerId, Connection), Error>> { - let (rx, tx) = oneshot::channel(); - - let message = ToEndpoint::Dial { - addr: address, - result: rx, - version, - }; - - self.pending_dials.push_back(message); - - if let Some(waker) = self.waker.take() { - waker.wake(); - } - - async move { - // Our oneshot getting dropped means the message didn't make it to the endpoint driver. - let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??; - let (peer, connection) = Connecting::new(connection, timeout).await?; - - Ok((peer, connection)) - } - .boxed() - } - - /// Send all pending dials into the given [`endpoint::Channel`]. - /// - /// This only ever returns [`Poll::Pending`], or an error in case the channel is closed. - fn poll(&mut self, channel: &mut endpoint::Channel, cx: &mut Context<'_>) -> Poll { - while let Some(to_endpoint) = self.pending_dials.pop_front() { - match channel.try_send(to_endpoint, cx) { - Ok(Ok(())) => {} - Ok(Err(to_endpoint)) => { - self.pending_dials.push_front(to_endpoint); - break; - } - Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed), - } - } - self.waker = Some(cx.waker().clone()); - Poll::Pending - } -} - /// Listener for incoming connections. struct Listener { /// Id of the listener. listener_id: ListenerId, + /// Version of the supported quic protocol. version: ProtocolVersion, - /// Channel to the endpoint to initiate dials. - endpoint_channel: endpoint::Channel, - /// Queued dials. - dialer_state: DialerState, + /// Endpoint + endpoint: quinn::Endpoint, - /// Channel where new connections are being sent. - new_connections_rx: mpsc::Receiver, + /// An underlying copy of the socket to be able to hole punch with + socket: UdpSocket, + + /// A future to poll new incoming connections. + accept: BoxFuture<'static, Option>, /// Timeout for connection establishment on inbound connections. handshake_timeout: Duration, @@ -458,38 +416,39 @@ struct Listener { impl Listener

{ fn new( listener_id: ListenerId, - socket_addr: SocketAddr, - config: QuinnConfig, + socket: UdpSocket, + endpoint: quinn::Endpoint, handshake_timeout: Duration, version: ProtocolVersion, ) -> Result { - let (endpoint_channel, new_connections_rx) = - endpoint::Channel::new_bidirectional::

(config, socket_addr)?; - let if_watcher; let pending_event; - if socket_addr.ip().is_unspecified() { + let local_addr = socket.local_addr()?; + if local_addr.ip().is_unspecified() { if_watcher = Some(P::new_if_watcher()?); pending_event = None; } else { if_watcher = None; - let ma = socketaddr_to_multiaddr(endpoint_channel.socket_addr(), version); + let ma = socketaddr_to_multiaddr(&local_addr, version); pending_event = Some(TransportEvent::NewAddress { listener_id, listen_addr: ma, }) } + let endpoint_c = endpoint.clone(); + let accept = async move { endpoint_c.accept().await }.boxed(); + Ok(Listener { - endpoint_channel, + endpoint, + socket, + accept, listener_id, version, - new_connections_rx, handshake_timeout, if_watcher, is_closed: false, pending_event, - dialer_state: DialerState::default(), close_listener_waker: None, }) } @@ -500,6 +459,7 @@ impl Listener

{ if self.is_closed { return; } + self.endpoint.close(From::from(0u32), &[]); self.pending_event = Some(TransportEvent::ListenerClosed { listener_id: self.listener_id, reason, @@ -512,8 +472,20 @@ impl Listener

{ } } + /// Clone underlying socket (for hole punching). + fn try_clone_socket(&self) -> std::io::Result { + self.socket.try_clone() + } + + fn socket_addr(&self) -> SocketAddr { + self.socket + .local_addr() + .expect("Cannot fail because the socket is bound") + } + /// Poll for a next If Event. fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<::Item> { + let endpoint_addr = self.socket_addr(); let if_watcher = match self.if_watcher.as_mut() { Some(iw) => iw, None => return Poll::Pending, @@ -521,11 +493,9 @@ impl Listener

{ loop { match ready!(P::poll_if_event(if_watcher, cx)) { Ok(IfEvent::Up(inet)) => { - if let Some(listen_addr) = ip_to_listenaddr( - self.endpoint_channel.socket_addr(), - inet.addr(), - self.version, - ) { + if let Some(listen_addr) = + ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version) + { log::debug!("New listen address: {}", listen_addr); return Poll::Ready(TransportEvent::NewAddress { listener_id: self.listener_id, @@ -534,11 +504,9 @@ impl Listener

{ } } Ok(IfEvent::Down(inet)) => { - if let Some(listen_addr) = ip_to_listenaddr( - self.endpoint_channel.socket_addr(), - inet.addr(), - self.version, - ) { + if let Some(listen_addr) = + ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version) + { log::debug!("Expired listen address: {}", listen_addr); return Poll::Ready(TransportEvent::AddressExpired { listener_id: self.listener_id, @@ -555,21 +523,10 @@ impl Listener

{ } } } - - /// Poll [`DialerState`] to initiate requested dials. - fn poll_dialer(&mut self, cx: &mut Context<'_>) -> Poll { - let Self { - dialer_state, - endpoint_channel, - .. - } = self; - - dialer_state.poll(endpoint_channel, cx) - } } impl Stream for Listener

{ - type Item = TransportEvent; + type Item = TransportEvent< as Transport>::ListenerUpgrade, Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { if let Some(event) = self.pending_event.take() { @@ -581,17 +538,18 @@ impl Stream for Listener

{ if let Poll::Ready(event) = self.poll_if_addr(cx) { return Poll::Ready(Some(event)); } - if let Poll::Ready(error) = self.poll_dialer(cx) { - self.close(Err(error)); - continue; - } - match self.new_connections_rx.poll_next_unpin(cx) { - Poll::Ready(Some(connection)) => { - let local_addr = socketaddr_to_multiaddr(connection.local_addr(), self.version); - let send_back_addr = - socketaddr_to_multiaddr(&connection.remote_addr(), self.version); + + match self.accept.poll_unpin(cx) { + Poll::Ready(Some(connecting)) => { + let endpoint = self.endpoint.clone(); + self.accept = async move { endpoint.accept().await }.boxed(); + + let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version); + let remote_addr = connecting.remote_address(); + let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version); + let event = TransportEvent::Incoming { - upgrade: Connecting::new(connection, self.handshake_timeout), + upgrade: Connecting::new(connecting, self.handshake_timeout), local_addr, send_back_addr, listener_id: self.listener_id, @@ -616,9 +574,6 @@ impl fmt::Debug for Listener

{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Listener") .field("listener_id", &self.listener_id) - .field("endpoint_channel", &self.endpoint_channel) - .field("dialer_state", &self.dialer_state) - .field("new_connections_rx", &self.new_connections_rx) .field("handshake_timeout", &self.handshake_timeout) .field("is_closed", &self.is_closed) .field("pending_event", &self.pending_event) @@ -626,12 +581,6 @@ impl fmt::Debug for Listener

{ } } -impl Drop for Listener

{ - fn drop(&mut self) { - self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled); - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ProtocolVersion { V1, // i.e. RFC9000 @@ -766,7 +715,6 @@ fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) - #[cfg(any(feature = "async-std", feature = "tokio"))] mod test { use futures::future::poll_fn; - use futures_timer::Delay; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use super::*; @@ -882,15 +830,6 @@ mod test { .listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap()) .unwrap(); - // Copy channel to use it later. - let mut channel = transport - .listeners - .iter() - .next() - .unwrap() - .endpoint_channel - .clone(); - match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await { TransportEvent::NewAddress { listener_id, @@ -923,14 +862,6 @@ mod test { .now_or_never() .is_none()); assert!(transport.listeners.is_empty()); - - // Check that the [`Driver`] has shut down. - Delay::new(Duration::from_millis(10)).await; - poll_fn(|cx| { - assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err()); - Poll::Ready(()) - }) - .await; } } @@ -945,32 +876,9 @@ mod test { .dial("/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap()) .unwrap(); - // Expect a dialer and its background task to exist. - let mut channel = transport - .dialer - .get(&SocketFamily::Ipv4) - .unwrap() - .endpoint_channel - .clone(); + assert!(transport.dialer.contains_key(&SocketFamily::Ipv4)); assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6)); - // Send dummy dial to check that the endpoint driver is running. - poll_fn(|cx| { - let (tx, _) = oneshot::channel(); - let _ = channel - .try_send( - ToEndpoint::Dial { - addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - result: tx, - version: ProtocolVersion::V1, - }, - cx, - ) - .unwrap(); - Poll::Ready(()) - }) - .await; - // Start listening so that the dialer and driver are dropped. transport .listen_on( @@ -979,13 +887,5 @@ mod test { ) .unwrap(); assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4)); - - // Check that the [`Driver`] has shut down. - Delay::new(Duration::from_millis(10)).await; - poll_fn(|cx| { - assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err()); - Poll::Ready(()) - }) - .await; } } diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index 8a6d689a7b0..93adfa68013 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -428,7 +428,7 @@ async fn smoke() { assert_eq!(b_connected, a_peer_id); } -async fn build_streams() -> (SubstreamBox, SubstreamBox) { +async fn build_streams() -> (SubstreamBox, SubstreamBox) { let (_, mut a_transport) = create_default_transport::

(); let (_, mut b_transport) = create_default_transport::

(); @@ -522,7 +522,7 @@ async fn start_listening(transport: &mut Boxed<(PeerId, StreamMuxerBox)>, addr: } } -fn prop( +fn prop( number_listeners: NonZeroU8, number_streams: NonZeroU8, ) -> quickcheck::TestResult { @@ -599,7 +599,7 @@ fn prop( quickcheck::TestResult::passed() } -async fn answer_inbound_streams( +async fn answer_inbound_streams( mut connection: StreamMuxerBox, ) { loop { @@ -634,7 +634,7 @@ async fn answer_inbound_streams( } } -async fn open_outbound_streams( +async fn open_outbound_streams( mut connection: StreamMuxerBox, number_streams: usize, completed_streams_tx: mpsc::Sender<()>, @@ -740,3 +740,22 @@ impl BlockOn for libp2p_quic::tokio::Provider { .unwrap() } } + +trait Spawn { + /// Run the given future in the background until it ends. + fn spawn(future: impl Future + Send + 'static); +} + +#[cfg(feature = "async-std")] +impl Spawn for libp2p_quic::async_std::Provider { + fn spawn(future: impl Future + Send + 'static) { + async_std::task::spawn(future); + } +} + +#[cfg(feature = "tokio")] +impl Spawn for libp2p_quic::tokio::Provider { + fn spawn(future: impl Future + Send + 'static) { + tokio::spawn(future); + } +}