diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index f960ca9bd7..73acc051ae 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -88,7 +88,6 @@ pub fn setup_contacts_service( auto_request: true, ..Default::default() }, - excluded_dial_addresses: vec![], ..Default::default() }, allow_test_addresses: true, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 895163c93c..c8e1164c10 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -126,6 +126,7 @@ use tari_common_types::{ }; use tari_comms::{ multiaddr::Multiaddr, + net_address::IP4_TCP_TEST_ADDR_RANGE, peer_manager::{NodeIdentity, PeerQuery}, transports::MemoryTransport, types::CommsPublicKey, @@ -5327,7 +5328,7 @@ pub unsafe extern "C" fn comms_config_create( minimum_desired_tcpv4_node_ratio: 0.0, ..Default::default() }, - excluded_dial_addresses: vec![], + excluded_dial_addresses: vec![IP4_TCP_TEST_ADDR_RANGE.parse().expect("valid address range")], ..Default::default() }, allow_test_addresses: true, diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 727d13cf6f..26455eabf2 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -45,6 +45,7 @@ use crate::{ connection_manager::{ConnectionManagerConfig, ConnectionManagerRequester}, connectivity::{ConnectivityConfig, ConnectivityRequester}, multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeIdentity, PeerManager}, peer_validator::PeerValidatorConfig, protocol::{NodeNetworkInfo, ProtocolExtensions}, @@ -242,7 +243,7 @@ impl CommsBuilder { self } - pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { + pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { self.connection_manager_config.excluded_dial_addresses = excluded_addresses; self } diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 245d3e4308..357491ae22 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -55,7 +55,7 @@ use crate::{ }, multiaddr::Multiaddr, multiplexing::Yamux, - net_address::PeerAddressSource, + net_address::{MultiaddrRange, PeerAddressSource}, noise::{NoiseConfig, NoiseSocket}, peer_manager::{NodeId, NodeIdentity, Peer, PeerManager}, protocol::ProtocolId, @@ -557,7 +557,7 @@ where noise_config: &NoiseConfig, transport: &TTransport, network_byte: u8, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, ) -> ( DialState, Result<(NoiseSocket, Multiaddr), ConnectionManagerError>, @@ -568,7 +568,7 @@ where .clone() .into_vec() .iter() - .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| a == excluded)) + .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| excluded.contains(a))) .cloned() .collect::>(); if addresses.is_empty() { diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index 67c28679cd..a646a3dd41 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -49,6 +49,7 @@ use crate::{ backoff::Backoff, connection_manager::ConnectionId, multiplexing::Substream, + net_address::MultiaddrRange, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity, PeerManagerError}, peer_validator::PeerValidatorConfig, @@ -134,7 +135,7 @@ pub struct ConnectionManagerConfig { /// Peer validation configuration. See [PeerValidatorConfig] pub peer_validation_config: PeerValidatorConfig, /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + pub excluded_dial_addresses: Vec, } impl Default for ConnectionManagerConfig { diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index a1c244b838..e73f052379 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -287,7 +287,7 @@ async fn excluded_yes() { let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let connection_manager_config = ConnectionManagerConfig { - excluded_dial_addresses: vec![address.clone()], + excluded_dial_addresses: vec![address.to_string().parse().unwrap()], ..Default::default() }; let mut dialer = Dialer::new( diff --git a/comms/core/src/net_address/mod.rs b/comms/core/src/net_address/mod.rs index a437636096..5efd1f557d 100644 --- a/comms/core/src/net_address/mod.rs +++ b/comms/core/src/net_address/mod.rs @@ -27,3 +27,6 @@ pub use multiaddr_with_stats::{MultiaddrWithStats, PeerAddressSource}; mod mutliaddresses_with_stats; pub use mutliaddresses_with_stats::MultiaddressesWithStats; + +mod multiaddr_range; +pub use multiaddr_range::{MultiaddrRange, IP4_TCP_TEST_ADDR_RANGE}; diff --git a/comms/core/src/net_address/multiaddr_range.rs b/comms/core/src/net_address/multiaddr_range.rs new file mode 100644 index 0000000000..985d598607 --- /dev/null +++ b/comms/core/src/net_address/multiaddr_range.rs @@ -0,0 +1,333 @@ +// Copyright 2022 The Tari Project +// SPDX-License-Identifier: BSD-3-Clause + +use std::{fmt, net::Ipv4Addr, str::FromStr}; + +use multiaddr::{Multiaddr, Protocol}; +use serde_derive::{Deserialize, Serialize}; + +/// A MultiaddrRange for testing purposes that matches any IPv4 address and any port +pub const IP4_TCP_TEST_ADDR_RANGE: &str = "/ip4/127.*.*.*/tcp/*"; + +/// ----------------- MultiaddrRange ----------------- +/// A struct containing either an Ipv4AddrRange or a Multiaddr. If a range of IP addresses and/or ports needs to be +/// specified, the MultiaddrRange can be used, but it only supports IPv4 addresses with the TCP protocol. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MultiaddrRange { + ipv4_addr_range: Option, + multiaddr: Option, +} + +impl FromStr for MultiaddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + if let Ok(multiaddr) = Multiaddr::from_str(s) { + Ok(MultiaddrRange { + ipv4_addr_range: None, + multiaddr: Some(multiaddr), + }) + } else if let Ok(ipv4_addr_range) = Ipv4AddrRange::from_str(s) { + Ok(MultiaddrRange { + ipv4_addr_range: Some(ipv4_addr_range), + multiaddr: None, + }) + } else { + Err("Invalid format for both Multiaddr and Ipv4AddrRange".to_string()) + } + } +} + +impl fmt::Display for MultiaddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ipv4_addr_range) = &self.ipv4_addr_range { + write!(f, "{}", ipv4_addr_range) + } else if let Some(multiaddr) = &self.multiaddr { + write!(f, "{}", multiaddr) + } else { + write!(f, "None") + } + } +} + +impl MultiaddrRange { + /// Check if the given Multiaddr is contained within the MultiaddrRange range + pub fn contains(&self, addr: &Multiaddr) -> bool { + if let Some(ipv4_addr_range) = &self.ipv4_addr_range { + return ipv4_addr_range.contains(addr); + } + if let Some(multiaddr) = &self.multiaddr { + return multiaddr == addr; + } + false + } +} + +// ----------------- Ipv4AddrRange ----------------- +// A struct containing an Ipv4Range and a PortRange +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Ipv4AddrRange { + ip_range: Ipv4Range, + port_range: PortRange, +} + +impl FromStr for Ipv4AddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split('/').collect(); + if parts.len() != 5 { + return Err("Invalid multiaddr format".to_string()); + } + + if parts[1] != "ip4" { + return Err("Only IPv4 addresses are supported".to_string()); + } + + let ip_range = Ipv4Range::new(parts[2])?; + if parts[3] != "tcp" { + return Err("Only TCP protocol is supported".to_string()); + } + + let port_range = PortRange::new(parts[4])?; + Ok(Ipv4AddrRange { ip_range, port_range }) + } +} + +impl Ipv4AddrRange { + fn contains(&self, addr: &Multiaddr) -> bool { + let mut ip = None; + let mut port = None; + + for protocol in addr { + match protocol { + Protocol::Ip4(ipv4) => ip = Some(ipv4), + Protocol::Tcp(tcp_port) => port = Some(tcp_port), + _ => {}, + } + } + + if let (Some(ip), Some(port)) = (ip, port) { + return self.ip_range.contains(ip) && self.port_range.contains(port); + } + + false + } +} + +impl fmt::Display for Ipv4AddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "/ip4/{}/tcp/{}", self.ip_range, self.port_range) + } +} + +// ----------------- Ipv4Range ----------------- +// A struct containing the start and end Ipv4Addr +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Ipv4Range { + start: Ipv4Addr, + end: Ipv4Addr, +} + +impl Ipv4Range { + fn new(range_str: &str) -> Result { + let parts: Vec<&str> = range_str.split('.').collect(); + if parts.len() != 4 { + return Err("Invalid IP range format".to_string()); + } + + let mut start_octets = [0u8; 4]; + let mut end_octets = [0u8; 4]; + + for (i, part) in parts.iter().enumerate() { + if i == 0 { + start_octets[i] = part.parse().map_err(|_| "Invalid first octet".to_string())?; + end_octets[i] = start_octets[i]; + } else if part == &"*" { + start_octets[i] = 0; + end_octets[i] = u8::MAX; + } else if part.contains(':') { + let range_parts: Vec<&str> = part.split(':').collect(); + if range_parts.len() != 2 { + return Err("Invalid range format".to_string()); + } + start_octets[i] = range_parts[0].parse().map_err(|_| "Invalid range start".to_string())?; + end_octets[i] = range_parts[1].parse().map_err(|_| "Invalid range end".to_string())?; + } else { + start_octets[i] = part.parse().map_err(|_| "Invalid octet".to_string())?; + end_octets[i] = start_octets[i]; + } + } + + Ok(Ipv4Range { + start: Ipv4Addr::new(start_octets[0], start_octets[1], start_octets[2], start_octets[3]), + end: Ipv4Addr::new(end_octets[0], end_octets[1], end_octets[2], end_octets[3]), + }) + } + + fn contains(&self, addr: Ipv4Addr) -> bool { + let octets = addr.octets(); + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + + for i in 0..4 { + if octets[i] < start_octets[i] || octets[i] > end_octets[i] { + return false; + } + } + true + } +} + +impl fmt::Display for Ipv4Range { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + write!( + f, + "{}.{}.{}.{}", + start_octets[0], + if start_octets[1] == 0 && end_octets[1] == u8::MAX { + "*".to_string() + } else if start_octets[1] == end_octets[1] { + start_octets[1].to_string() + } else { + format!("{}:{}", start_octets[1], end_octets[1]) + }, + if start_octets[2] == 0 && end_octets[2] == u8::MAX { + "*".to_string() + } else if start_octets[2] == end_octets[2] { + start_octets[2].to_string() + } else { + format!("{}:{}", start_octets[2], end_octets[2]) + }, + if start_octets[3] == 0 && end_octets[3] == u8::MAX { + "*".to_string() + } else if start_octets[3] == end_octets[3] { + start_octets[3].to_string() + } else { + format!("{}:{}", start_octets[3], end_octets[3]) + } + ) + } +} + +// ----------------- PortRange ----------------- +// A struct containing the start and end port +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PortRange { + start: u16, + end: u16, +} + +impl PortRange { + fn new(range_str: &str) -> Result { + if range_str == "*" { + return Ok(PortRange { + start: 0, + end: u16::MAX, + }); + } + + if range_str.contains(':') { + let parts: Vec<&str> = range_str.split(':').collect(); + if parts.len() != 2 { + return Err("Invalid port range format".to_string()); + } + let start = parts[0].parse().map_err(|_| "Invalid port range start".to_string())?; + let end = parts[1].parse().map_err(|_| "Invalid port range end".to_string())?; + return Ok(PortRange { start, end }); + } + + let port = range_str.parse().map_err(|_| "Invalid port".to_string())?; + Ok(PortRange { start: port, end: port }) + } + + fn contains(&self, port: u16) -> bool { + port >= self.start && port <= self.end + } +} + +impl fmt::Display for PortRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.start == 0 && self.end == u16::MAX { + write!(f, "*") + } else if self.start == self.end { + write!(f, "{}", self.start) + } else { + write!(f, "{}:{}", self.start, self.end) + } + } +} + +#[cfg(test)] +mod test { + use std::net::{IpAddr, Ipv6Addr}; + + use crate::{ + multiaddr::Multiaddr, + net_address::{multiaddr_range::IP4_TCP_TEST_ADDR_RANGE, MultiaddrRange}, + }; + + #[test] + fn it_parses_properly_and_verify_inclusion() { + // MultiaddrRange for ip4 with tcp + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/*".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18188".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = IP4_TCP_TEST_ADDR_RANGE.parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.1.2.3/tcp/555".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + + // MultiaddrRange for other protocols + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5679".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))) + .to_string() + .parse() + .unwrap(); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))); + assert!(my_addr_range.contains(&addr)); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0, 0, 0, 0))); + assert!(!my_addr_range.contains(&addr)); + } +} diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 57790c0d44..64c1bebf85 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -35,7 +35,7 @@ use log::*; use tari_comms::{ connection_manager::ConnectionManagerError, connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, - multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, PeerConnection, @@ -386,7 +386,7 @@ impl DhtActor { // Helper function to check if all peer addresses are excluded async fn check_if_addresses_excluded( - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, peer_manager: &PeerManager, node_id: NodeId, ) -> Result<(), DhtActorError> { @@ -394,7 +394,7 @@ impl DhtActor { let addresses = peer_manager.get_peer_multi_addresses(&node_id).await?; if addresses .iter() - .all(|addr| excluded_dial_addresses.contains(addr.address())) + .all(|addr| excluded_dial_addresses.iter().any(|v| v.contains(addr.address()))) { warn!( target: LOG_TARGET, @@ -533,7 +533,7 @@ impl DhtActor { async fn broadcast_join( node_identity: Arc, peer_manager: Arc, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, mut outbound_requester: OutboundMessageRequester, ) -> Result<(), DhtActorError> { DhtActor::check_if_addresses_excluded( @@ -748,10 +748,12 @@ impl DhtActor { let mut filtered_peers = Vec::with_capacity(peers.len()); for id in &peers { let addresses = peer_manager.get_peer_multi_addresses(id).await?; - if addresses - .iter() - .all(|addr| config.excluded_dial_addresses.contains(addr.address())) - { + if addresses.iter().all(|addr| { + config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { trace!(target: LOG_TARGET, "Peer '{}' has only excluded addresses. Skipping.", id); } else { filtered_peers.push(id.clone()); diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 6f3053539a..796246fe3f 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -24,7 +24,7 @@ use std::{path::Path, time::Duration}; use serde::{Deserialize, Serialize}; use tari_common::configuration::serializers; -use tari_comms::{multiaddr::Multiaddr, peer_validator::PeerValidatorConfig}; +use tari_comms::{net_address::MultiaddrRange, peer_validator::PeerValidatorConfig}; use crate::{ actor::OffenceSeverity, @@ -116,7 +116,7 @@ pub struct DhtConfig { /// See [PeerValidatorConfig] pub peer_validator_config: PeerValidatorConfig, /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + pub excluded_dial_addresses: Vec, } impl DhtConfig { diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index b0294e9184..47d85ce8b4 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -870,10 +870,12 @@ impl DhtConnectivity { let mut neighbours = Vec::with_capacity(self.neighbours.len()); for peer in &self.neighbours { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { neighbours.push(peer.clone()); } } @@ -882,10 +884,12 @@ impl DhtConnectivity { let mut random_pool = Vec::with_capacity(self.random_pool.len()); for peer in &self.random_pool { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { random_pool.push(peer.clone()); } }