diff --git a/Cargo.lock b/Cargo.lock index 12c836ef05..9fd5eab7bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6174,6 +6174,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util 0.6.10", + "toml 0.5.11", "tower", "tracing", "yamux", diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index f960ca9bd7..73acc051ae 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -88,7 +88,6 @@ pub fn setup_contacts_service( auto_request: true, ..Default::default() }, - excluded_dial_addresses: vec![], ..Default::default() }, allow_test_addresses: true, diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 7579d5954e..dd2cc4e3e4 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -332,7 +332,7 @@ async fn configure_comms_and_dht( .with_listener_liveness_allowlist_cidrs(listener_liveness_allowlist_cidrs) .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500))) .with_peer_storage(peer_database, Some(file_lock)) - .with_excluded_dial_addresses(config.dht.excluded_dial_addresses.clone()); + .with_excluded_dial_addresses(config.dht.excluded_dial_addresses.clone().into_vec().clone()); let mut comms = match config.auxiliary_tcp_listener_address { Some(ref addr) => builder.with_auxiliary_tcp_listener_address(addr.clone()).build()?, diff --git a/base_layer/wallet_ffi/src/error.rs b/base_layer/wallet_ffi/src/error.rs index 187ff1e5d1..e90770be83 100644 --- a/base_layer/wallet_ffi/src/error.rs +++ b/base_layer/wallet_ffi/src/error.rs @@ -57,6 +57,8 @@ pub enum InterfaceError { InvalidEmojiId, #[error("An error has occurred due to an invalid argument: `{0}`")] InvalidArgument(String), + #[error("An internal error has occurred: `{0}`")] + InternalError(String), #[error("Balance Unavailable")] BalanceError, } @@ -106,6 +108,10 @@ impl From for LibWalletError { code: 9, message: format!("Pointer error on {}:{:?}", p, v), }, + InterfaceError::InternalError(_) => Self { + code: 10, + message: format!("{:?}", v), + }, } } } diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 6e9f3e9b43..0949e15b32 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -126,6 +126,7 @@ use tari_common_types::{ }; use tari_comms::{ multiaddr::Multiaddr, + net_address::{MultiaddrRange, MultiaddrRangeList, IP4_TCP_TEST_ADDR_RANGE}, peer_manager::{NodeIdentity, PeerQuery}, transports::MemoryTransport, types::CommsPublicKey, @@ -5199,6 +5200,7 @@ pub unsafe extern "C" fn transport_config_destroy(transport: *mut TariTransportC /// `database_path` - The database path char array pointer which. This is the folder path where the /// database files will be created and the application has write access to /// `discovery_timeout_in_secs`: specify how long the Discovery Timeout for the wallet is. +/// `exclude_dial_test_addresses`: exclude dialing of test addresses; this should be 'true' for production wallets /// `error_out` - Pointer to an int which will be modified to an error code should one occur, may not be null. Functions /// as an out parameter. /// @@ -5217,6 +5219,7 @@ pub unsafe extern "C" fn comms_config_create( datastore_path: *const c_char, discovery_timeout_in_secs: c_ulonglong, saf_message_duration_in_secs: c_ulonglong, + exclude_dial_test_addresses: bool, error_out: *mut c_int, ) -> *mut TariCommsConfig { let mut error = 0; @@ -5294,6 +5297,20 @@ pub unsafe extern "C" fn comms_config_create( MultiaddrList::from(vec![public_address]) }; + let excluded_dial_addresses = if exclude_dial_test_addresses { + let multi_addr_range = match MultiaddrRange::from_str(IP4_TCP_TEST_ADDR_RANGE) { + Ok(val) => val, + Err(e) => { + error = LibWalletError::from(InterfaceError::InternalError(e)).code; + ptr::swap(error_out, &mut error as *mut c_int); + return ptr::null_mut(); + }, + }; + MultiaddrRangeList::from(vec![multi_addr_range]) + } else { + MultiaddrRangeList::from(vec![]) + }; + let config = TariCommsConfig { override_from: None, public_addresses: addresses, @@ -5326,7 +5343,7 @@ pub unsafe extern "C" fn comms_config_create( minimum_desired_tcpv4_node_ratio: 0.0, ..Default::default() }, - excluded_dial_addresses: vec![], + excluded_dial_addresses, ..Default::default() }, allow_test_addresses: true, @@ -10237,6 +10254,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10401,6 +10419,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10628,6 +10647,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); @@ -10691,6 +10711,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); @@ -10774,6 +10795,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10951,6 +10973,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11089,6 +11112,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11308,6 +11332,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11534,6 +11559,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11795,6 +11821,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; @@ -12175,6 +12202,7 @@ mod test { alice_db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; @@ -12239,6 +12267,7 @@ mod test { bob_db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; diff --git a/base_layer/wallet_ffi/wallet.h b/base_layer/wallet_ffi/wallet.h index 21a13dc6e4..c8419af904 100644 --- a/base_layer/wallet_ffi/wallet.h +++ b/base_layer/wallet_ffi/wallet.h @@ -2758,6 +2758,7 @@ void transport_config_destroy(TariTransportConfig *transport); * `database_path` - The database path char array pointer which. This is the folder path where the * database files will be created and the application has write access to * `discovery_timeout_in_secs`: specify how long the Discovery Timeout for the wallet is. + * `exclude_dial_test_addresses`: exclude dialing of test addresses; this should be 'true' for production wallets * `error_out` - Pointer to an int which will be modified to an error code should one occur, may not be null. Functions * as an out parameter. * @@ -2774,6 +2775,7 @@ TariCommsConfig *comms_config_create(const char *public_address, const char *datastore_path, unsigned long long discovery_timeout_in_secs, unsigned long long saf_message_duration_in_secs, + bool exclude_dial_test_addresses, int *error_out); /** diff --git a/clients/ffi_client/index.js b/clients/ffi_client/index.js index bfb30fdea1..6faea7baee 100644 --- a/clients/ffi_client/index.js +++ b/clients/ffi_client/index.js @@ -39,6 +39,7 @@ try { "./wallet", 30, 600, + false, err ); diff --git a/clients/ffi_client/recovery.js b/clients/ffi_client/recovery.js index 02875dea0a..286e4a5943 100644 --- a/clients/ffi_client/recovery.js +++ b/clients/ffi_client/recovery.js @@ -53,6 +53,7 @@ try { "./recovery", 30, 600, + false, err ); diff --git a/common/config/presets/c_base_node_c.toml b/common/config/presets/c_base_node_c.toml index 4c7b147950..3d72a8037c 100644 --- a/common/config/presets/c_base_node_c.toml +++ b/common/config/presets/c_base_node_c.toml @@ -303,8 +303,6 @@ database_url = "data/base_node/dht.db" #ban_duration = 21_600 # 6 * 60 * 60 # Length of time to ban a peer for a "short" duration. Default: 60 mins #ban_duration_short = 3_600 # 60 * 60 -# This allows the use of test addresses in the network like 127.0.0.1. Default: false -#allow_test_addresses = false # The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for `ban_duration_short`) # Default: 100_000 messages #flood_ban_max_msg_count = 100_000 @@ -316,5 +314,9 @@ database_url = "data/base_node/dht.db" # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 -# Addresses that should never be dialed (default value = []) -#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] +# Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP range. +# Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other range) +# `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` +# or +# `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` +#excluded_dial_addresses = [] diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index 93b5a8f920..6a3c5d5f27 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -347,8 +347,6 @@ network_discovery.initial_peer_sync_delay = 25 #ban_duration = 21_600 # 6 * 60 * 60 # Length of time to ban a peer for a "short" duration. Default: 60 mins #ban_duration_short = 3_600 # 60 * 60 -# This allows the use of test addresses in the network like 127.0.0.1. Default: false -#allow_test_addresses = false # The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for `ban_duration_short`) # Default: 100_000 messages #flood_ban_max_msg_count = 100_000 @@ -360,5 +358,9 @@ network_discovery.initial_peer_sync_delay = 25 # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 -# Addresses that should never be dialed (default value = []) -#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] +# Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP range. +# Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other range) +# `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` +# or +# `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` +#excluded_dial_addresses = [] diff --git a/comms/core/Cargo.toml b/comms/core/Cargo.toml index fef234ee85..8814759bf4 100644 --- a/comms/core/Cargo.toml +++ b/comms/core/Cargo.toml @@ -52,6 +52,7 @@ zeroize = "1" [dev-dependencies] tari_test_utils = { path = "../../infrastructure/test_utils" } tari_comms_rpc_macros = { path = "../rpc_macros" } +toml = { version = "0.5"} env_logger = "0.7.0" serde_json = "1.0.39" diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 727d13cf6f..26455eabf2 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -45,6 +45,7 @@ use crate::{ connection_manager::{ConnectionManagerConfig, ConnectionManagerRequester}, connectivity::{ConnectivityConfig, ConnectivityRequester}, multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeIdentity, PeerManager}, peer_validator::PeerValidatorConfig, protocol::{NodeNetworkInfo, ProtocolExtensions}, @@ -242,7 +243,7 @@ impl CommsBuilder { self } - pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { + pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { self.connection_manager_config.excluded_dial_addresses = excluded_addresses; self } diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 245d3e4308..357491ae22 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -55,7 +55,7 @@ use crate::{ }, multiaddr::Multiaddr, multiplexing::Yamux, - net_address::PeerAddressSource, + net_address::{MultiaddrRange, PeerAddressSource}, noise::{NoiseConfig, NoiseSocket}, peer_manager::{NodeId, NodeIdentity, Peer, PeerManager}, protocol::ProtocolId, @@ -557,7 +557,7 @@ where noise_config: &NoiseConfig, transport: &TTransport, network_byte: u8, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, ) -> ( DialState, Result<(NoiseSocket, Multiaddr), ConnectionManagerError>, @@ -568,7 +568,7 @@ where .clone() .into_vec() .iter() - .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| a == excluded)) + .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| excluded.contains(a))) .cloned() .collect::>(); if addresses.is_empty() { diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index 67c28679cd..a646a3dd41 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -49,6 +49,7 @@ use crate::{ backoff::Backoff, connection_manager::ConnectionId, multiplexing::Substream, + net_address::MultiaddrRange, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity, PeerManagerError}, peer_validator::PeerValidatorConfig, @@ -134,7 +135,7 @@ pub struct ConnectionManagerConfig { /// Peer validation configuration. See [PeerValidatorConfig] pub peer_validation_config: PeerValidatorConfig, /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + pub excluded_dial_addresses: Vec, } impl Default for ConnectionManagerConfig { diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index a1c244b838..e73f052379 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -287,7 +287,7 @@ async fn excluded_yes() { let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let connection_manager_config = ConnectionManagerConfig { - excluded_dial_addresses: vec![address.clone()], + excluded_dial_addresses: vec![address.to_string().parse().unwrap()], ..Default::default() }; let mut dialer = Dialer::new( diff --git a/comms/core/src/net_address/mod.rs b/comms/core/src/net_address/mod.rs index a437636096..1219f0d95a 100644 --- a/comms/core/src/net_address/mod.rs +++ b/comms/core/src/net_address/mod.rs @@ -27,3 +27,6 @@ pub use multiaddr_with_stats::{MultiaddrWithStats, PeerAddressSource}; mod mutliaddresses_with_stats; pub use mutliaddresses_with_stats::MultiaddressesWithStats; + +mod multiaddr_range; +pub use multiaddr_range::{MultiaddrRange, MultiaddrRangeList, IP4_TCP_TEST_ADDR_RANGE}; diff --git a/comms/core/src/net_address/multiaddr_range.rs b/comms/core/src/net_address/multiaddr_range.rs new file mode 100644 index 0000000000..1bbf2b913a --- /dev/null +++ b/comms/core/src/net_address/multiaddr_range.rs @@ -0,0 +1,506 @@ +// Copyright 2022 The Tari Project +// SPDX-License-Identifier: BSD-3-Clause + +use std::{fmt, net::Ipv4Addr, ops::Deref, slice, str::FromStr}; + +use multiaddr::{Multiaddr, Protocol}; +use serde::{ + de, + de::{Error, SeqAccess, Visitor}, + Deserialize, + Deserializer, + Serialize, +}; + +/// A MultiaddrRange for testing purposes that matches any IPv4 address and any port +pub const IP4_TCP_TEST_ADDR_RANGE: &str = "/ip4/127.*.*.*/tcp/*"; + +/// ----------------- MultiaddrRange ----------------- +/// A struct containing either an Ipv4AddrRange or a Multiaddr. If a range of IP addresses and/or ports needs to be +/// specified, the MultiaddrRange can be used, but it only supports IPv4 addresses with the TCP protocol. +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub enum MultiaddrRange { + Ipv4AddrRange(Ipv4AddrRange), + Multiaddr(Multiaddr), +} + +impl FromStr for MultiaddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + if let Ok(multiaddr) = Multiaddr::from_str(s) { + Ok(MultiaddrRange::Multiaddr(multiaddr)) + } else if let Ok(ipv4_addr_range) = Ipv4AddrRange::from_str(s) { + Ok(MultiaddrRange::Ipv4AddrRange(ipv4_addr_range)) + } else { + Err("Invalid format for both Multiaddr and Ipv4AddrRange".to_string()) + } + } +} + +impl fmt::Display for MultiaddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MultiaddrRange::Ipv4AddrRange(ipv4_addr_range) => write!(f, "{}", ipv4_addr_range), + MultiaddrRange::Multiaddr(multiaddr) => write!(f, "{}", multiaddr), + } + } +} + +impl MultiaddrRange { + /// Check if the given Multiaddr is contained within the MultiaddrRange range + pub fn contains(&self, addr: &Multiaddr) -> bool { + match self { + MultiaddrRange::Ipv4AddrRange(ipv4_addr_range) => ipv4_addr_range.contains(addr), + MultiaddrRange::Multiaddr(multiaddr) => multiaddr == addr, + } + } +} + +/// ----------------- Ipv4AddrRange ----------------- +/// A struct containing an Ipv4Range and a PortRange +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct Ipv4AddrRange { + ip_range: Ipv4Range, + port_range: PortRange, +} + +impl FromStr for Ipv4AddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split('/').collect(); + if parts.len() != 5 { + return Err("Invalid multiaddr format".to_string()); + } + + if parts[1] != "ip4" { + return Err("Only IPv4 addresses are supported".to_string()); + } + + let ip_range = Ipv4Range::new(parts[2])?; + if parts[3] != "tcp" { + return Err("Only TCP protocol is supported".to_string()); + } + + let port_range = PortRange::new(parts[4])?; + Ok(Ipv4AddrRange { ip_range, port_range }) + } +} + +impl Ipv4AddrRange { + fn contains(&self, addr: &Multiaddr) -> bool { + let mut ip = None; + let mut port = None; + + for protocol in addr { + match protocol { + Protocol::Ip4(ipv4) => ip = Some(ipv4), + Protocol::Tcp(tcp_port) => port = Some(tcp_port), + _ => {}, + } + } + + if let (Some(ip), Some(port)) = (ip, port) { + return self.ip_range.contains(ip) && self.port_range.contains(port); + } + + false + } +} + +impl fmt::Display for Ipv4AddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "/ip4/{}/tcp/{}", self.ip_range, self.port_range) + } +} + +// ----------------- Ipv4Range ----------------- +// A struct containing the start and end Ipv4Addr +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +struct Ipv4Range { + start: Ipv4Addr, + end: Ipv4Addr, +} + +impl Ipv4Range { + fn new(range_str: &str) -> Result { + let parts: Vec<&str> = range_str.split('.').collect(); + if parts.len() != 4 { + return Err("Invalid IP range format".to_string()); + } + + let mut start_octets = [0u8; 4]; + let mut end_octets = [0u8; 4]; + + for (i, part) in parts.iter().enumerate() { + if i == 0 { + start_octets[i] = part.parse().map_err(|_| "Invalid first IPv4 octet".to_string())?; + end_octets[i] = start_octets[i]; + } else if part == &"*" { + start_octets[i] = 0; + end_octets[i] = u8::MAX; + } else if part.contains(':') { + let range_parts: Vec<&str> = part.split(':').collect(); + if range_parts.len() != 2 { + return Err(format!("Invalid range format for IPv4 octet {}", i)); + } + start_octets[i] = range_parts[0] + .parse() + .map_err(|_| format!("Invalid range start for IPv4 octet {}", i))?; + end_octets[i] = range_parts[1] + .parse() + .map_err(|_| format!("Invalid range end for IPv4 octet {}", i))?; + } else { + start_octets[i] = part.parse().map_err(|_| format!("Invalid IPv4 octet {}", i))?; + end_octets[i] = start_octets[i]; + } + } + + Ok(Ipv4Range { + start: Ipv4Addr::new(start_octets[0], start_octets[1], start_octets[2], start_octets[3]), + end: Ipv4Addr::new(end_octets[0], end_octets[1], end_octets[2], end_octets[3]), + }) + } + + fn contains(&self, addr: Ipv4Addr) -> bool { + let octets = addr.octets(); + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + + for i in 0..4 { + if octets[i] < start_octets[i] || octets[i] > end_octets[i] { + return false; + } + } + true + } +} + +impl fmt::Display for Ipv4Range { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + write!( + f, + "{}.{}.{}.{}", + start_octets[0], + if start_octets[1] == 0 && end_octets[1] == u8::MAX { + "*".to_string() + } else if start_octets[1] == end_octets[1] { + start_octets[1].to_string() + } else { + format!("{}:{}", start_octets[1], end_octets[1]) + }, + if start_octets[2] == 0 && end_octets[2] == u8::MAX { + "*".to_string() + } else if start_octets[2] == end_octets[2] { + start_octets[2].to_string() + } else { + format!("{}:{}", start_octets[2], end_octets[2]) + }, + if start_octets[3] == 0 && end_octets[3] == u8::MAX { + "*".to_string() + } else if start_octets[3] == end_octets[3] { + start_octets[3].to_string() + } else { + format!("{}:{}", start_octets[3], end_octets[3]) + } + ) + } +} + +// ----------------- PortRange ----------------- +// A struct containing the start and end port +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +struct PortRange { + start: u16, + end: u16, +} + +impl PortRange { + fn new(range_str: &str) -> Result { + if range_str == "*" { + return Ok(PortRange { + start: 0, + end: u16::MAX, + }); + } + + if range_str.contains(':') { + let parts: Vec<&str> = range_str.split(':').collect(); + if parts.len() != 2 { + return Err("Invalid port range format".to_string()); + } + let start = parts[0] + .parse() + .map_err(|_| format!("Invalid port range start '{}'", parts[0]))?; + let end = parts[1] + .parse() + .map_err(|_| format!("Invalid port range end '{}'", parts[1]))?; + if end < start { + return Err(format!( + "Invalid port range '{}', end `{}` is less than start `{}`", + range_str, end, start + )); + } + return Ok(PortRange { start, end }); + } + + let port = range_str.parse().map_err(|_| "Invalid port".to_string())?; + Ok(PortRange { start: port, end: port }) + } + + fn contains(&self, port: u16) -> bool { + port >= self.start && port <= self.end + } +} + +impl fmt::Display for PortRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.start <= 1 && self.end == u16::MAX { + write!(f, "*") + } else if self.start == self.end { + write!(f, "{}", self.start) + } else { + write!(f, "{}:{}", self.start, self.end) + } + } +} + +/// ----------------- MultiaddrRangeList ----------------- +/// Supports deserialization from a sequence of strings or comma-delimited strings +#[derive(Debug, Default, Clone, Serialize, PartialEq, Eq)] +pub struct MultiaddrRangeList(Vec); + +impl MultiaddrRangeList { + pub fn new() -> Self { + Self(vec![]) + } + + pub fn with_capacity(size: usize) -> Self { + Self(Vec::with_capacity(size)) + } + + pub fn into_vec(self) -> Vec { + self.0 + } + + pub fn as_slice(&self) -> &[MultiaddrRange] { + self.0.as_slice() + } +} + +impl Deref for MultiaddrRangeList { + type Target = [MultiaddrRange]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[MultiaddrRange]> for MultiaddrRangeList { + fn as_ref(&self) -> &[MultiaddrRange] { + self.0.as_ref() + } +} + +impl From> for MultiaddrRangeList { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl IntoIterator for MultiaddrRangeList { + type IntoIter = as IntoIterator>::IntoIter; + type Item = as IntoIterator>::Item; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a MultiaddrRangeList { + type IntoIter = slice::Iter<'a, MultiaddrRange>; + type Item = ::Item; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'de> Deserialize<'de> for MultiaddrRangeList { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + struct MultiaddrRangeListVisitor; + + impl<'de> Visitor<'de> for MultiaddrRangeListVisitor { + type Value = MultiaddrRangeList; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a comma delimited string or multiple string elements") + } + + fn visit_str(self, v: &str) -> Result + where E: de::Error { + let strings = v + .split(',') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect::>(); + let multiaddr_ranges: Result, _> = strings + .into_iter() + .map(|item| MultiaddrRange::from_str(item).map_err(E::custom)) + .collect(); + Ok(MultiaddrRangeList(multiaddr_ranges.map_err(E::custom)?)) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where D: Deserializer<'de> { + deserializer.deserialize_seq(MultiaddrRangeListVisitor) + } + + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let mut buf = seq.size_hint().map(Vec::with_capacity).unwrap_or_default(); + while let Some(v) = seq.next_element::()? { + buf.push(v) + } + Ok(MultiaddrRangeList(buf)) + } + } + + if deserializer.is_human_readable() { + deserializer.deserialize_seq(MultiaddrRangeListVisitor) + } else { + deserializer.deserialize_newtype_struct("MultiaddrRangeList", MultiaddrRangeListVisitor) + } + } +} + +impl<'de> Deserialize<'de> for MultiaddrRange { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?; + MultiaddrRange::from_str(&s).map_err(D::Error::custom) + } +} + +#[cfg(test)] +mod test { + use std::{ + net::{IpAddr, Ipv6Addr}, + str::FromStr, + }; + + use serde::Deserialize; + + use crate::{ + multiaddr::Multiaddr, + net_address::{multiaddr_range::IP4_TCP_TEST_ADDR_RANGE, MultiaddrRange, MultiaddrRangeList}, + }; + + #[derive(Deserialize)] + struct Test { + something: MultiaddrRangeList, + } + + #[test] + fn it_parses_with_serde() { + // Random tests + let config_str = r#"something = [ + "/ip4/127.*.100:200.*/tcp/18000:19000", + "/ip4/127.0.150.1/tcp/18500", + "/ip4/127.0.0.1/udt/sctp/5678", + "/ip4/127.*.*.*/tcp/*" + ]"#; + let item_vec = toml::from_str::(config_str).unwrap().something.into_vec(); + assert_eq!(item_vec, vec![ + MultiaddrRange::from_str("/ip4/127.*.100:200.*/tcp/18000:19000").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.150.1/tcp/18500").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.0.1/udt/sctp/5678").unwrap(), + MultiaddrRange::from_str(IP4_TCP_TEST_ADDR_RANGE).unwrap() + ]); + + // Allowing only '/ip4/127.0.0.1/tcp/0:18189' + let config_str = r#"something = [ + "/ip4/127.*.*.*/tcp/0:18188", + "/ip4/127.*.*.*/tcp/18190:65535", + "/ip4/127.0.0.0/tcp/18189", + "/ip4/127.1:255.1:255.2:255/tcp/18189" + ]"#; + let item_vec = toml::from_str::(config_str).unwrap().something.into_vec(); + assert_eq!(item_vec, vec![ + MultiaddrRange::from_str("/ip4/127.*.*.*/tcp/0:18188").unwrap(), + MultiaddrRange::from_str("/ip4/127.*.*.*/tcp/18190:65535").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.0.0/tcp/18189").unwrap(), + MultiaddrRange::from_str("/ip4/127.1:255.1:255.2:255/tcp/18189").unwrap(), + ]); + + for item in item_vec { + assert!(!item.contains(&Multiaddr::from_str("/ip4/127.0.0.1/tcp/18189").unwrap())); + } + } + + #[test] + fn it_parses_properly_and_verify_inclusion() { + // MultiaddrRange for ip4 with tcp + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/*".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18188".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = IP4_TCP_TEST_ADDR_RANGE.parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.1.2.3/tcp/555".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + + // MultiaddrRange for other protocols + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5679".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))) + .to_string() + .parse() + .unwrap(); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))); + assert!(my_addr_range.contains(&addr)); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0, 0, 0, 0))); + assert!(!my_addr_range.contains(&addr)); + } +} diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 57790c0d44..0291e81bc4 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -35,7 +35,7 @@ use log::*; use tari_comms::{ connection_manager::ConnectionManagerError, connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, - multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, PeerConnection, @@ -386,7 +386,7 @@ impl DhtActor { // Helper function to check if all peer addresses are excluded async fn check_if_addresses_excluded( - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, peer_manager: &PeerManager, node_id: NodeId, ) -> Result<(), DhtActorError> { @@ -394,7 +394,7 @@ impl DhtActor { let addresses = peer_manager.get_peer_multi_addresses(&node_id).await?; if addresses .iter() - .all(|addr| excluded_dial_addresses.contains(addr.address())) + .all(|addr| excluded_dial_addresses.iter().any(|v| v.contains(addr.address()))) { warn!( target: LOG_TARGET, @@ -419,7 +419,7 @@ impl DhtActor { Box::pin(Self::broadcast_join( node_identity, peer_manager, - excluded_dial_addresses, + excluded_dial_addresses.into_vec(), outbound_requester, )) }, @@ -502,7 +502,7 @@ impl DhtActor { Box::pin(async move { DhtActor::check_if_addresses_excluded( - excluded_dial_addresses, + excluded_dial_addresses.into_vec(), &peer_manager, node_identity.node_id().clone(), ) @@ -533,7 +533,7 @@ impl DhtActor { async fn broadcast_join( node_identity: Arc, peer_manager: Arc, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, mut outbound_requester: OutboundMessageRequester, ) -> Result<(), DhtActorError> { DhtActor::check_if_addresses_excluded( @@ -748,10 +748,12 @@ impl DhtActor { let mut filtered_peers = Vec::with_capacity(peers.len()); for id in &peers { let addresses = peer_manager.get_peer_multi_addresses(id).await?; - if addresses - .iter() - .all(|addr| config.excluded_dial_addresses.contains(addr.address())) - { + if addresses.iter().all(|addr| { + config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { trace!(target: LOG_TARGET, "Peer '{}' has only excluded addresses. Skipping.", id); } else { filtered_peers.push(id.clone()); diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 6f3053539a..069c77b9f1 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -24,7 +24,7 @@ use std::{path::Path, time::Duration}; use serde::{Deserialize, Serialize}; use tari_common::configuration::serializers; -use tari_comms::{multiaddr::Multiaddr, peer_validator::PeerValidatorConfig}; +use tari_comms::{net_address::MultiaddrRangeList, peer_validator::PeerValidatorConfig}; use crate::{ actor::OffenceSeverity, @@ -94,7 +94,6 @@ pub struct DhtConfig { /// Default: 10 mins #[serde(with = "serializers::seconds")] pub ban_duration_short: Duration, - /// The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for /// `ban_duration_short`) Default: 100_000 messages pub flood_ban_max_msg_count: usize, @@ -115,8 +114,12 @@ pub struct DhtConfig { /// Configuration for peer validation /// See [PeerValidatorConfig] pub peer_validator_config: PeerValidatorConfig, - /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + /// Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP + /// range. Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other + /// range) `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` + /// or + /// `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` + pub excluded_dial_addresses: MultiaddrRangeList, } impl DhtConfig { @@ -195,7 +198,7 @@ impl Default for DhtConfig { max_permitted_peer_claims: 5, offline_peer_cooldown: Duration::from_secs(24 * 60 * 60), peer_validator_config: Default::default(), - excluded_dial_addresses: vec![], + excluded_dial_addresses: vec![].into(), } } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index b0294e9184..47d85ce8b4 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -870,10 +870,12 @@ impl DhtConnectivity { let mut neighbours = Vec::with_capacity(self.neighbours.len()); for peer in &self.neighbours { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { neighbours.push(peer.clone()); } } @@ -882,10 +884,12 @@ impl DhtConnectivity { let mut random_pool = Vec::with_capacity(self.random_pool.len()); for peer in &self.random_pool { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { random_pool.push(peer.clone()); } } diff --git a/integration_tests/src/ffi/comms_config.rs b/integration_tests/src/ffi/comms_config.rs index 2d1eb6e8da..b8de76866b 100644 --- a/integration_tests/src/ffi/comms_config.rs +++ b/integration_tests/src/ffi/comms_config.rs @@ -49,6 +49,7 @@ impl CommsConfig { CString::new(base_dir).unwrap().into_raw(), 30, 600, + false, // This needs to be 'false' for the tests to pass &mut error, ); if error > 0 { diff --git a/integration_tests/src/ffi/ffi_import.rs b/integration_tests/src/ffi/ffi_import.rs index 70830118ed..539c73ff7b 100644 --- a/integration_tests/src/ffi/ffi_import.rs +++ b/integration_tests/src/ffi/ffi_import.rs @@ -367,6 +367,7 @@ extern "C" { datastore_path: *const c_char, discovery_timeout_in_secs: c_ulonglong, saf_message_duration_in_secs: c_ulonglong, + exclude_dial_test_addresses: bool, error_out: *mut c_int, ) -> *mut TariCommsConfig; pub fn comms_config_destroy(wc: *mut TariCommsConfig);