From aca61f39cdb65cccfafa02565b9b13bdf8fc3d99 Mon Sep 17 00:00:00 2001 From: Hansie Odendaal <39146854+hansieodendaal@users.noreply.github.com> Date: Wed, 18 Sep 2024 10:22:59 +0200 Subject: [PATCH] feat: add multiaddr with range checks for use with universe (#6557) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description --- Added `MultiaddrRange`, which implements range checks for `Multiaddr`, when using IP4 with TCP. This enables specifying a range of IP4 with TCP addresses. As an exmple, any communication node can enable test addresses to connect to them (`allow_test_addresses = true`), but refrain from dialling any test addresses in return (`excluded_dial_addresses = ["/ip4/127.*.*.*/tcp/*"]`). With application to universe: - TCP seed node settings: ``` allow_test_addresses = true excluded_dial_addresses = [ "/ip4/127.*.*.*/tcp/0:18188", "/ip4/127.*.*.*/tcp/18190:65534", "/ip4/127.0.0.0/tcp/18189", "/ip4/127.1:255.1:255.2:255/tcp/18189" ] # Only '/ip4/127.0.0.1/tcp/0:18189' allowed ``` - Universe base node settings: ``` type = "tcp" public_addresses = ["/ip4/127.0.0.1/tcp/18189"] tcp.listener_address = "/ip4/0.0.0.0/tcp/18189" allow_test_addresses = true public_addresses = ["/ip4/127.0.0.1/tcp/18189"] #excluded_dial_addresses = [] ``` - Universe wallet settings: ``` dns_seeds = [] custom_base_node = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx::/ip4/127.0.0.1/tcp/18189" type = "tcp" public_addresses = ["/ip4/127.0.0.1/tcp/18188"] tcp.listener_address = "/ip4/0.0.0.0/tcp/18188" allow_test_addresses = true public_addresses = ["/ip4/127.0.0.1/tcp/18188"] #excluded_dial_addresses = [] ``` Motivation and Context --- Currently, Universe base nodes and wallets use `/ip4/172.2.3.4/tcp/18189` and `/ip4/172.2.3.4/tcp/18188` as their public addresses respectively, but any node trying to contact them is not able to. This results in many wasted resources. The Universe wallets also maintain connections with the seed nodes, which is not ideal. How Has This Been Tested? --- - Added new unit tests - System-level testing using the suggested settings^ (simulated seed node, simulated universe base node, simulated universe wallet) **From the seed node to the universe wallet** ``` >> add-peer 4602fb85883fec887e6b5e5a93cbc9547f19817685e78ff0a9e585826e322b44 /ip4/127.0.0.1/tcp/18188 Peer with node id '9a7764f742bf4e4bae6c5bd4f6' was added to the base node. >> ☎️ Dialing peer... ☠️ ConnectionFailed: All peer addresses are excluded for peer 9a7764f742bf4e4bae6c5bd4f6 ``` **From the seed node to the universe base node** ``` >> add-peer ee9ad9dce31a2d4a9225f4965e50df98ae4f85b58f94b34b9db9cc44f2aa2921 /ip4/127.0.0.1/tcp/18189 Peer with node id '998eb49cf4f2dd3b3d5a394c8e' was added to the base node. ☎️ Dialing peer... ⚡️ Peer connected in 0ms! Connection: Id: 1, Node ID: 998eb49cf4f2dd3b, Direction: Inbound, Peer Address: /ip4/192.168.5.114/tcp/62398, Age: 1913s, #Substreams: 2, #Refs: 2 ``` **From the universe base node to the universe wallet** ``` >> add-peer 4602fb85883fec887e6b5e5a93cbc9547f19817685e78ff0a9e585826e322b44 /ip4/127.0.0.1/tcp/18188 Peer with node id '9a7764f742bf4e4bae6c5bd4f6' was added to the base node. ☎️ Dialing peer... ⚡️ Peer connected in 0ms! Connection: Id: 8, Node ID: 9a7764f742bf4e4b, Direction: Inbound, Peer Address: /ip4/127.0.0.1/tcp/62412, Age: 2054s, #Substreams: 6, #Refs: 2 ``` **From the universe base node to the seed node** ``` >> add-peer 6677c4d401b98f403de671712a98ad8bf2976db27a1d411b08bedfd86751e048 /ip4/192.168.5.114/tcp/9991 Peer with node id '85d605836f02951c65651f99d0' was added to the base node. ☎️ Dialing peer... >> ⚡️ Peer connected in 0ms! Connection: Id: 0, Node ID: 85d605836f02951c, Direction: Outbound, Peer Address: /ip4/192.168.5.114/tcp/9991, Age: 2050s, #Substreams: 2, #Refs: 3 ``` What process can a PR reviewer use to test or verify this change? --- - Code review - System-level testing Breaking Changes --- - [x] None - [ ] Requires data directory on base node to be deleted - [ ] Requires hard fork - [ ] Other - Please specify --- Cargo.lock | 1 + base_layer/contacts/tests/contacts_service.rs | 1 - base_layer/p2p/src/initialization.rs | 2 +- base_layer/wallet_ffi/src/error.rs | 6 + base_layer/wallet_ffi/src/lib.rs | 31 +- base_layer/wallet_ffi/wallet.h | 2 + clients/ffi_client/index.js | 1 + clients/ffi_client/recovery.js | 1 + common/config/presets/c_base_node_c.toml | 10 +- common/config/presets/d_console_wallet.toml | 10 +- comms/core/Cargo.toml | 1 + comms/core/src/builder/mod.rs | 3 +- comms/core/src/connection_manager/dialer.rs | 6 +- comms/core/src/connection_manager/manager.rs | 3 +- .../tests/listener_dialer.rs | 2 +- comms/core/src/net_address/mod.rs | 3 + comms/core/src/net_address/multiaddr_range.rs | 506 ++++++++++++++++++ comms/dht/src/actor.rs | 22 +- comms/dht/src/config.rs | 13 +- comms/dht/src/connectivity/mod.rs | 20 +- integration_tests/src/ffi/comms_config.rs | 1 + integration_tests/src/ffi/ffi_import.rs | 1 + 22 files changed, 606 insertions(+), 40 deletions(-) create mode 100644 comms/core/src/net_address/multiaddr_range.rs diff --git a/Cargo.lock b/Cargo.lock index 12c836ef05..9fd5eab7bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6174,6 +6174,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util 0.6.10", + "toml 0.5.11", "tower", "tracing", "yamux", diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index f960ca9bd7..73acc051ae 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -88,7 +88,6 @@ pub fn setup_contacts_service( auto_request: true, ..Default::default() }, - excluded_dial_addresses: vec![], ..Default::default() }, allow_test_addresses: true, diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 7579d5954e..dd2cc4e3e4 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -332,7 +332,7 @@ async fn configure_comms_and_dht( .with_listener_liveness_allowlist_cidrs(listener_liveness_allowlist_cidrs) .with_dial_backoff(ConstantBackoff::new(Duration::from_millis(500))) .with_peer_storage(peer_database, Some(file_lock)) - .with_excluded_dial_addresses(config.dht.excluded_dial_addresses.clone()); + .with_excluded_dial_addresses(config.dht.excluded_dial_addresses.clone().into_vec().clone()); let mut comms = match config.auxiliary_tcp_listener_address { Some(ref addr) => builder.with_auxiliary_tcp_listener_address(addr.clone()).build()?, diff --git a/base_layer/wallet_ffi/src/error.rs b/base_layer/wallet_ffi/src/error.rs index 187ff1e5d1..e90770be83 100644 --- a/base_layer/wallet_ffi/src/error.rs +++ b/base_layer/wallet_ffi/src/error.rs @@ -57,6 +57,8 @@ pub enum InterfaceError { InvalidEmojiId, #[error("An error has occurred due to an invalid argument: `{0}`")] InvalidArgument(String), + #[error("An internal error has occurred: `{0}`")] + InternalError(String), #[error("Balance Unavailable")] BalanceError, } @@ -106,6 +108,10 @@ impl From for LibWalletError { code: 9, message: format!("Pointer error on {}:{:?}", p, v), }, + InterfaceError::InternalError(_) => Self { + code: 10, + message: format!("{:?}", v), + }, } } } diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 6e9f3e9b43..0949e15b32 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -126,6 +126,7 @@ use tari_common_types::{ }; use tari_comms::{ multiaddr::Multiaddr, + net_address::{MultiaddrRange, MultiaddrRangeList, IP4_TCP_TEST_ADDR_RANGE}, peer_manager::{NodeIdentity, PeerQuery}, transports::MemoryTransport, types::CommsPublicKey, @@ -5199,6 +5200,7 @@ pub unsafe extern "C" fn transport_config_destroy(transport: *mut TariTransportC /// `database_path` - The database path char array pointer which. This is the folder path where the /// database files will be created and the application has write access to /// `discovery_timeout_in_secs`: specify how long the Discovery Timeout for the wallet is. +/// `exclude_dial_test_addresses`: exclude dialing of test addresses; this should be 'true' for production wallets /// `error_out` - Pointer to an int which will be modified to an error code should one occur, may not be null. Functions /// as an out parameter. /// @@ -5217,6 +5219,7 @@ pub unsafe extern "C" fn comms_config_create( datastore_path: *const c_char, discovery_timeout_in_secs: c_ulonglong, saf_message_duration_in_secs: c_ulonglong, + exclude_dial_test_addresses: bool, error_out: *mut c_int, ) -> *mut TariCommsConfig { let mut error = 0; @@ -5294,6 +5297,20 @@ pub unsafe extern "C" fn comms_config_create( MultiaddrList::from(vec![public_address]) }; + let excluded_dial_addresses = if exclude_dial_test_addresses { + let multi_addr_range = match MultiaddrRange::from_str(IP4_TCP_TEST_ADDR_RANGE) { + Ok(val) => val, + Err(e) => { + error = LibWalletError::from(InterfaceError::InternalError(e)).code; + ptr::swap(error_out, &mut error as *mut c_int); + return ptr::null_mut(); + }, + }; + MultiaddrRangeList::from(vec![multi_addr_range]) + } else { + MultiaddrRangeList::from(vec![]) + }; + let config = TariCommsConfig { override_from: None, public_addresses: addresses, @@ -5326,7 +5343,7 @@ pub unsafe extern "C" fn comms_config_create( minimum_desired_tcpv4_node_ratio: 0.0, ..Default::default() }, - excluded_dial_addresses: vec![], + excluded_dial_addresses, ..Default::default() }, allow_test_addresses: true, @@ -10237,6 +10254,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10401,6 +10419,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10628,6 +10647,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); @@ -10691,6 +10711,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); @@ -10774,6 +10795,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -10951,6 +10973,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11089,6 +11112,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11308,6 +11332,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11534,6 +11559,7 @@ mod test { db_path_alice_str, 20, 10800, + false, error_ptr, ); @@ -11795,6 +11821,7 @@ mod test { db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; @@ -12175,6 +12202,7 @@ mod test { alice_db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; @@ -12239,6 +12267,7 @@ mod test { bob_db_path_str, 20, 10800, + false, error_ptr, ); let passphrase: *const c_char = CString::into_raw(CString::new("niao").unwrap()) as *const c_char; diff --git a/base_layer/wallet_ffi/wallet.h b/base_layer/wallet_ffi/wallet.h index 21a13dc6e4..c8419af904 100644 --- a/base_layer/wallet_ffi/wallet.h +++ b/base_layer/wallet_ffi/wallet.h @@ -2758,6 +2758,7 @@ void transport_config_destroy(TariTransportConfig *transport); * `database_path` - The database path char array pointer which. This is the folder path where the * database files will be created and the application has write access to * `discovery_timeout_in_secs`: specify how long the Discovery Timeout for the wallet is. + * `exclude_dial_test_addresses`: exclude dialing of test addresses; this should be 'true' for production wallets * `error_out` - Pointer to an int which will be modified to an error code should one occur, may not be null. Functions * as an out parameter. * @@ -2774,6 +2775,7 @@ TariCommsConfig *comms_config_create(const char *public_address, const char *datastore_path, unsigned long long discovery_timeout_in_secs, unsigned long long saf_message_duration_in_secs, + bool exclude_dial_test_addresses, int *error_out); /** diff --git a/clients/ffi_client/index.js b/clients/ffi_client/index.js index bfb30fdea1..6faea7baee 100644 --- a/clients/ffi_client/index.js +++ b/clients/ffi_client/index.js @@ -39,6 +39,7 @@ try { "./wallet", 30, 600, + false, err ); diff --git a/clients/ffi_client/recovery.js b/clients/ffi_client/recovery.js index 02875dea0a..286e4a5943 100644 --- a/clients/ffi_client/recovery.js +++ b/clients/ffi_client/recovery.js @@ -53,6 +53,7 @@ try { "./recovery", 30, 600, + false, err ); diff --git a/common/config/presets/c_base_node_c.toml b/common/config/presets/c_base_node_c.toml index 4c7b147950..3d72a8037c 100644 --- a/common/config/presets/c_base_node_c.toml +++ b/common/config/presets/c_base_node_c.toml @@ -303,8 +303,6 @@ database_url = "data/base_node/dht.db" #ban_duration = 21_600 # 6 * 60 * 60 # Length of time to ban a peer for a "short" duration. Default: 60 mins #ban_duration_short = 3_600 # 60 * 60 -# This allows the use of test addresses in the network like 127.0.0.1. Default: false -#allow_test_addresses = false # The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for `ban_duration_short`) # Default: 100_000 messages #flood_ban_max_msg_count = 100_000 @@ -316,5 +314,9 @@ database_url = "data/base_node/dht.db" # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 -# Addresses that should never be dialed (default value = []) -#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] +# Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP range. +# Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other range) +# `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` +# or +# `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` +#excluded_dial_addresses = [] diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index 93b5a8f920..6a3c5d5f27 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -347,8 +347,6 @@ network_discovery.initial_peer_sync_delay = 25 #ban_duration = 21_600 # 6 * 60 * 60 # Length of time to ban a peer for a "short" duration. Default: 60 mins #ban_duration_short = 3_600 # 60 * 60 -# This allows the use of test addresses in the network like 127.0.0.1. Default: false -#allow_test_addresses = false # The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for `ban_duration_short`) # Default: 100_000 messages #flood_ban_max_msg_count = 100_000 @@ -360,5 +358,9 @@ network_discovery.initial_peer_sync_delay = 25 # In a situation where a node is not well-connected and many nodes are locally marked as offline, we can retry # peers that were previously tried. Default: 2 hours #offline_peer_cooldown = 7_200 # 2 * 60 * 60 -# Addresses that should never be dialed (default value = []) -#excluded_dial_addresses = ["/ip4/x.x.x.x/tcp/xxxx", "/ip4/x.y.x.y/tcp/xyxy"] +# Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP range. +# Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other range) +# `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` +# or +# `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` +#excluded_dial_addresses = [] diff --git a/comms/core/Cargo.toml b/comms/core/Cargo.toml index fef234ee85..8814759bf4 100644 --- a/comms/core/Cargo.toml +++ b/comms/core/Cargo.toml @@ -52,6 +52,7 @@ zeroize = "1" [dev-dependencies] tari_test_utils = { path = "../../infrastructure/test_utils" } tari_comms_rpc_macros = { path = "../rpc_macros" } +toml = { version = "0.5"} env_logger = "0.7.0" serde_json = "1.0.39" diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 727d13cf6f..26455eabf2 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -45,6 +45,7 @@ use crate::{ connection_manager::{ConnectionManagerConfig, ConnectionManagerRequester}, connectivity::{ConnectivityConfig, ConnectivityRequester}, multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeIdentity, PeerManager}, peer_validator::PeerValidatorConfig, protocol::{NodeNetworkInfo, ProtocolExtensions}, @@ -242,7 +243,7 @@ impl CommsBuilder { self } - pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { + pub fn with_excluded_dial_addresses(mut self, excluded_addresses: Vec) -> Self { self.connection_manager_config.excluded_dial_addresses = excluded_addresses; self } diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 245d3e4308..357491ae22 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -55,7 +55,7 @@ use crate::{ }, multiaddr::Multiaddr, multiplexing::Yamux, - net_address::PeerAddressSource, + net_address::{MultiaddrRange, PeerAddressSource}, noise::{NoiseConfig, NoiseSocket}, peer_manager::{NodeId, NodeIdentity, Peer, PeerManager}, protocol::ProtocolId, @@ -557,7 +557,7 @@ where noise_config: &NoiseConfig, transport: &TTransport, network_byte: u8, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, ) -> ( DialState, Result<(NoiseSocket, Multiaddr), ConnectionManagerError>, @@ -568,7 +568,7 @@ where .clone() .into_vec() .iter() - .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| a == excluded)) + .filter(|&a| !excluded_dial_addresses.iter().any(|excluded| excluded.contains(a))) .cloned() .collect::>(); if addresses.is_empty() { diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index 67c28679cd..a646a3dd41 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -49,6 +49,7 @@ use crate::{ backoff::Backoff, connection_manager::ConnectionId, multiplexing::Substream, + net_address::MultiaddrRange, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity, PeerManagerError}, peer_validator::PeerValidatorConfig, @@ -134,7 +135,7 @@ pub struct ConnectionManagerConfig { /// Peer validation configuration. See [PeerValidatorConfig] pub peer_validation_config: PeerValidatorConfig, /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + pub excluded_dial_addresses: Vec, } impl Default for ConnectionManagerConfig { diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index a1c244b838..e73f052379 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -287,7 +287,7 @@ async fn excluded_yes() { let (request_tx, request_rx) = mpsc::channel(1); let peer_manager2 = build_peer_manager(); let connection_manager_config = ConnectionManagerConfig { - excluded_dial_addresses: vec![address.clone()], + excluded_dial_addresses: vec![address.to_string().parse().unwrap()], ..Default::default() }; let mut dialer = Dialer::new( diff --git a/comms/core/src/net_address/mod.rs b/comms/core/src/net_address/mod.rs index a437636096..1219f0d95a 100644 --- a/comms/core/src/net_address/mod.rs +++ b/comms/core/src/net_address/mod.rs @@ -27,3 +27,6 @@ pub use multiaddr_with_stats::{MultiaddrWithStats, PeerAddressSource}; mod mutliaddresses_with_stats; pub use mutliaddresses_with_stats::MultiaddressesWithStats; + +mod multiaddr_range; +pub use multiaddr_range::{MultiaddrRange, MultiaddrRangeList, IP4_TCP_TEST_ADDR_RANGE}; diff --git a/comms/core/src/net_address/multiaddr_range.rs b/comms/core/src/net_address/multiaddr_range.rs new file mode 100644 index 0000000000..1bbf2b913a --- /dev/null +++ b/comms/core/src/net_address/multiaddr_range.rs @@ -0,0 +1,506 @@ +// Copyright 2022 The Tari Project +// SPDX-License-Identifier: BSD-3-Clause + +use std::{fmt, net::Ipv4Addr, ops::Deref, slice, str::FromStr}; + +use multiaddr::{Multiaddr, Protocol}; +use serde::{ + de, + de::{Error, SeqAccess, Visitor}, + Deserialize, + Deserializer, + Serialize, +}; + +/// A MultiaddrRange for testing purposes that matches any IPv4 address and any port +pub const IP4_TCP_TEST_ADDR_RANGE: &str = "/ip4/127.*.*.*/tcp/*"; + +/// ----------------- MultiaddrRange ----------------- +/// A struct containing either an Ipv4AddrRange or a Multiaddr. If a range of IP addresses and/or ports needs to be +/// specified, the MultiaddrRange can be used, but it only supports IPv4 addresses with the TCP protocol. +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub enum MultiaddrRange { + Ipv4AddrRange(Ipv4AddrRange), + Multiaddr(Multiaddr), +} + +impl FromStr for MultiaddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + if let Ok(multiaddr) = Multiaddr::from_str(s) { + Ok(MultiaddrRange::Multiaddr(multiaddr)) + } else if let Ok(ipv4_addr_range) = Ipv4AddrRange::from_str(s) { + Ok(MultiaddrRange::Ipv4AddrRange(ipv4_addr_range)) + } else { + Err("Invalid format for both Multiaddr and Ipv4AddrRange".to_string()) + } + } +} + +impl fmt::Display for MultiaddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MultiaddrRange::Ipv4AddrRange(ipv4_addr_range) => write!(f, "{}", ipv4_addr_range), + MultiaddrRange::Multiaddr(multiaddr) => write!(f, "{}", multiaddr), + } + } +} + +impl MultiaddrRange { + /// Check if the given Multiaddr is contained within the MultiaddrRange range + pub fn contains(&self, addr: &Multiaddr) -> bool { + match self { + MultiaddrRange::Ipv4AddrRange(ipv4_addr_range) => ipv4_addr_range.contains(addr), + MultiaddrRange::Multiaddr(multiaddr) => multiaddr == addr, + } + } +} + +/// ----------------- Ipv4AddrRange ----------------- +/// A struct containing an Ipv4Range and a PortRange +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct Ipv4AddrRange { + ip_range: Ipv4Range, + port_range: PortRange, +} + +impl FromStr for Ipv4AddrRange { + type Err = String; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split('/').collect(); + if parts.len() != 5 { + return Err("Invalid multiaddr format".to_string()); + } + + if parts[1] != "ip4" { + return Err("Only IPv4 addresses are supported".to_string()); + } + + let ip_range = Ipv4Range::new(parts[2])?; + if parts[3] != "tcp" { + return Err("Only TCP protocol is supported".to_string()); + } + + let port_range = PortRange::new(parts[4])?; + Ok(Ipv4AddrRange { ip_range, port_range }) + } +} + +impl Ipv4AddrRange { + fn contains(&self, addr: &Multiaddr) -> bool { + let mut ip = None; + let mut port = None; + + for protocol in addr { + match protocol { + Protocol::Ip4(ipv4) => ip = Some(ipv4), + Protocol::Tcp(tcp_port) => port = Some(tcp_port), + _ => {}, + } + } + + if let (Some(ip), Some(port)) = (ip, port) { + return self.ip_range.contains(ip) && self.port_range.contains(port); + } + + false + } +} + +impl fmt::Display for Ipv4AddrRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "/ip4/{}/tcp/{}", self.ip_range, self.port_range) + } +} + +// ----------------- Ipv4Range ----------------- +// A struct containing the start and end Ipv4Addr +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +struct Ipv4Range { + start: Ipv4Addr, + end: Ipv4Addr, +} + +impl Ipv4Range { + fn new(range_str: &str) -> Result { + let parts: Vec<&str> = range_str.split('.').collect(); + if parts.len() != 4 { + return Err("Invalid IP range format".to_string()); + } + + let mut start_octets = [0u8; 4]; + let mut end_octets = [0u8; 4]; + + for (i, part) in parts.iter().enumerate() { + if i == 0 { + start_octets[i] = part.parse().map_err(|_| "Invalid first IPv4 octet".to_string())?; + end_octets[i] = start_octets[i]; + } else if part == &"*" { + start_octets[i] = 0; + end_octets[i] = u8::MAX; + } else if part.contains(':') { + let range_parts: Vec<&str> = part.split(':').collect(); + if range_parts.len() != 2 { + return Err(format!("Invalid range format for IPv4 octet {}", i)); + } + start_octets[i] = range_parts[0] + .parse() + .map_err(|_| format!("Invalid range start for IPv4 octet {}", i))?; + end_octets[i] = range_parts[1] + .parse() + .map_err(|_| format!("Invalid range end for IPv4 octet {}", i))?; + } else { + start_octets[i] = part.parse().map_err(|_| format!("Invalid IPv4 octet {}", i))?; + end_octets[i] = start_octets[i]; + } + } + + Ok(Ipv4Range { + start: Ipv4Addr::new(start_octets[0], start_octets[1], start_octets[2], start_octets[3]), + end: Ipv4Addr::new(end_octets[0], end_octets[1], end_octets[2], end_octets[3]), + }) + } + + fn contains(&self, addr: Ipv4Addr) -> bool { + let octets = addr.octets(); + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + + for i in 0..4 { + if octets[i] < start_octets[i] || octets[i] > end_octets[i] { + return false; + } + } + true + } +} + +impl fmt::Display for Ipv4Range { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let start_octets = self.start.octets(); + let end_octets = self.end.octets(); + write!( + f, + "{}.{}.{}.{}", + start_octets[0], + if start_octets[1] == 0 && end_octets[1] == u8::MAX { + "*".to_string() + } else if start_octets[1] == end_octets[1] { + start_octets[1].to_string() + } else { + format!("{}:{}", start_octets[1], end_octets[1]) + }, + if start_octets[2] == 0 && end_octets[2] == u8::MAX { + "*".to_string() + } else if start_octets[2] == end_octets[2] { + start_octets[2].to_string() + } else { + format!("{}:{}", start_octets[2], end_octets[2]) + }, + if start_octets[3] == 0 && end_octets[3] == u8::MAX { + "*".to_string() + } else if start_octets[3] == end_octets[3] { + start_octets[3].to_string() + } else { + format!("{}:{}", start_octets[3], end_octets[3]) + } + ) + } +} + +// ----------------- PortRange ----------------- +// A struct containing the start and end port +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +struct PortRange { + start: u16, + end: u16, +} + +impl PortRange { + fn new(range_str: &str) -> Result { + if range_str == "*" { + return Ok(PortRange { + start: 0, + end: u16::MAX, + }); + } + + if range_str.contains(':') { + let parts: Vec<&str> = range_str.split(':').collect(); + if parts.len() != 2 { + return Err("Invalid port range format".to_string()); + } + let start = parts[0] + .parse() + .map_err(|_| format!("Invalid port range start '{}'", parts[0]))?; + let end = parts[1] + .parse() + .map_err(|_| format!("Invalid port range end '{}'", parts[1]))?; + if end < start { + return Err(format!( + "Invalid port range '{}', end `{}` is less than start `{}`", + range_str, end, start + )); + } + return Ok(PortRange { start, end }); + } + + let port = range_str.parse().map_err(|_| "Invalid port".to_string())?; + Ok(PortRange { start: port, end: port }) + } + + fn contains(&self, port: u16) -> bool { + port >= self.start && port <= self.end + } +} + +impl fmt::Display for PortRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.start <= 1 && self.end == u16::MAX { + write!(f, "*") + } else if self.start == self.end { + write!(f, "{}", self.start) + } else { + write!(f, "{}:{}", self.start, self.end) + } + } +} + +/// ----------------- MultiaddrRangeList ----------------- +/// Supports deserialization from a sequence of strings or comma-delimited strings +#[derive(Debug, Default, Clone, Serialize, PartialEq, Eq)] +pub struct MultiaddrRangeList(Vec); + +impl MultiaddrRangeList { + pub fn new() -> Self { + Self(vec![]) + } + + pub fn with_capacity(size: usize) -> Self { + Self(Vec::with_capacity(size)) + } + + pub fn into_vec(self) -> Vec { + self.0 + } + + pub fn as_slice(&self) -> &[MultiaddrRange] { + self.0.as_slice() + } +} + +impl Deref for MultiaddrRangeList { + type Target = [MultiaddrRange]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[MultiaddrRange]> for MultiaddrRangeList { + fn as_ref(&self) -> &[MultiaddrRange] { + self.0.as_ref() + } +} + +impl From> for MultiaddrRangeList { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl IntoIterator for MultiaddrRangeList { + type IntoIter = as IntoIterator>::IntoIter; + type Item = as IntoIterator>::Item; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a MultiaddrRangeList { + type IntoIter = slice::Iter<'a, MultiaddrRange>; + type Item = ::Item; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'de> Deserialize<'de> for MultiaddrRangeList { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + struct MultiaddrRangeListVisitor; + + impl<'de> Visitor<'de> for MultiaddrRangeListVisitor { + type Value = MultiaddrRangeList; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a comma delimited string or multiple string elements") + } + + fn visit_str(self, v: &str) -> Result + where E: de::Error { + let strings = v + .split(',') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect::>(); + let multiaddr_ranges: Result, _> = strings + .into_iter() + .map(|item| MultiaddrRange::from_str(item).map_err(E::custom)) + .collect(); + Ok(MultiaddrRangeList(multiaddr_ranges.map_err(E::custom)?)) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where D: Deserializer<'de> { + deserializer.deserialize_seq(MultiaddrRangeListVisitor) + } + + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let mut buf = seq.size_hint().map(Vec::with_capacity).unwrap_or_default(); + while let Some(v) = seq.next_element::()? { + buf.push(v) + } + Ok(MultiaddrRangeList(buf)) + } + } + + if deserializer.is_human_readable() { + deserializer.deserialize_seq(MultiaddrRangeListVisitor) + } else { + deserializer.deserialize_newtype_struct("MultiaddrRangeList", MultiaddrRangeListVisitor) + } + } +} + +impl<'de> Deserialize<'de> for MultiaddrRange { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?; + MultiaddrRange::from_str(&s).map_err(D::Error::custom) + } +} + +#[cfg(test)] +mod test { + use std::{ + net::{IpAddr, Ipv6Addr}, + str::FromStr, + }; + + use serde::Deserialize; + + use crate::{ + multiaddr::Multiaddr, + net_address::{multiaddr_range::IP4_TCP_TEST_ADDR_RANGE, MultiaddrRange, MultiaddrRangeList}, + }; + + #[derive(Deserialize)] + struct Test { + something: MultiaddrRangeList, + } + + #[test] + fn it_parses_with_serde() { + // Random tests + let config_str = r#"something = [ + "/ip4/127.*.100:200.*/tcp/18000:19000", + "/ip4/127.0.150.1/tcp/18500", + "/ip4/127.0.0.1/udt/sctp/5678", + "/ip4/127.*.*.*/tcp/*" + ]"#; + let item_vec = toml::from_str::(config_str).unwrap().something.into_vec(); + assert_eq!(item_vec, vec![ + MultiaddrRange::from_str("/ip4/127.*.100:200.*/tcp/18000:19000").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.150.1/tcp/18500").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.0.1/udt/sctp/5678").unwrap(), + MultiaddrRange::from_str(IP4_TCP_TEST_ADDR_RANGE).unwrap() + ]); + + // Allowing only '/ip4/127.0.0.1/tcp/0:18189' + let config_str = r#"something = [ + "/ip4/127.*.*.*/tcp/0:18188", + "/ip4/127.*.*.*/tcp/18190:65535", + "/ip4/127.0.0.0/tcp/18189", + "/ip4/127.1:255.1:255.2:255/tcp/18189" + ]"#; + let item_vec = toml::from_str::(config_str).unwrap().something.into_vec(); + assert_eq!(item_vec, vec![ + MultiaddrRange::from_str("/ip4/127.*.*.*/tcp/0:18188").unwrap(), + MultiaddrRange::from_str("/ip4/127.*.*.*/tcp/18190:65535").unwrap(), + MultiaddrRange::from_str("/ip4/127.0.0.0/tcp/18189").unwrap(), + MultiaddrRange::from_str("/ip4/127.1:255.1:255.2:255/tcp/18189").unwrap(), + ]); + + for item in item_vec { + assert!(!item.contains(&Multiaddr::from_str("/ip4/127.0.0.1/tcp/18189").unwrap())); + } + } + + #[test] + fn it_parses_properly_and_verify_inclusion() { + // MultiaddrRange for ip4 with tcp + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.*.100:200.*/tcp/*".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.150.1/tcp/17500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.50.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18000:19000".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18500".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/17500".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.1.1/tcp/18188".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = IP4_TCP_TEST_ADDR_RANGE.parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18188".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/tcp/18189".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.1.2.3/tcp/555".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + + // MultiaddrRange for other protocols + + let my_addr_range: MultiaddrRange = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5678".parse().unwrap(); + assert!(my_addr_range.contains(&addr)); + let addr: Multiaddr = "/ip4/127.0.0.1/udt/sctp/5679".parse().unwrap(); + assert!(!my_addr_range.contains(&addr)); + + let my_addr_range: MultiaddrRange = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))) + .to_string() + .parse() + .unwrap(); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0x1, 0, 0, 0))); + assert!(my_addr_range.contains(&addr)); + let addr = Multiaddr::from(IpAddr::V6(Ipv6Addr::new(0x2001, 0x2, 0, 0, 0, 0, 0, 0))); + assert!(!my_addr_range.contains(&addr)); + } +} diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 57790c0d44..0291e81bc4 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -35,7 +35,7 @@ use log::*; use tari_comms::{ connection_manager::ConnectionManagerError, connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, - multiaddr::Multiaddr, + net_address::MultiaddrRange, peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, PeerConnection, @@ -386,7 +386,7 @@ impl DhtActor { // Helper function to check if all peer addresses are excluded async fn check_if_addresses_excluded( - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, peer_manager: &PeerManager, node_id: NodeId, ) -> Result<(), DhtActorError> { @@ -394,7 +394,7 @@ impl DhtActor { let addresses = peer_manager.get_peer_multi_addresses(&node_id).await?; if addresses .iter() - .all(|addr| excluded_dial_addresses.contains(addr.address())) + .all(|addr| excluded_dial_addresses.iter().any(|v| v.contains(addr.address()))) { warn!( target: LOG_TARGET, @@ -419,7 +419,7 @@ impl DhtActor { Box::pin(Self::broadcast_join( node_identity, peer_manager, - excluded_dial_addresses, + excluded_dial_addresses.into_vec(), outbound_requester, )) }, @@ -502,7 +502,7 @@ impl DhtActor { Box::pin(async move { DhtActor::check_if_addresses_excluded( - excluded_dial_addresses, + excluded_dial_addresses.into_vec(), &peer_manager, node_identity.node_id().clone(), ) @@ -533,7 +533,7 @@ impl DhtActor { async fn broadcast_join( node_identity: Arc, peer_manager: Arc, - excluded_dial_addresses: Vec, + excluded_dial_addresses: Vec, mut outbound_requester: OutboundMessageRequester, ) -> Result<(), DhtActorError> { DhtActor::check_if_addresses_excluded( @@ -748,10 +748,12 @@ impl DhtActor { let mut filtered_peers = Vec::with_capacity(peers.len()); for id in &peers { let addresses = peer_manager.get_peer_multi_addresses(id).await?; - if addresses - .iter() - .all(|addr| config.excluded_dial_addresses.contains(addr.address())) - { + if addresses.iter().all(|addr| { + config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { trace!(target: LOG_TARGET, "Peer '{}' has only excluded addresses. Skipping.", id); } else { filtered_peers.push(id.clone()); diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 6f3053539a..069c77b9f1 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -24,7 +24,7 @@ use std::{path::Path, time::Duration}; use serde::{Deserialize, Serialize}; use tari_common::configuration::serializers; -use tari_comms::{multiaddr::Multiaddr, peer_validator::PeerValidatorConfig}; +use tari_comms::{net_address::MultiaddrRangeList, peer_validator::PeerValidatorConfig}; use crate::{ actor::OffenceSeverity, @@ -94,7 +94,6 @@ pub struct DhtConfig { /// Default: 10 mins #[serde(with = "serializers::seconds")] pub ban_duration_short: Duration, - /// The maximum number of messages over `flood_ban_timespan` to allow before banning the peer (for /// `ban_duration_short`) Default: 100_000 messages pub flood_ban_max_msg_count: usize, @@ -115,8 +114,12 @@ pub struct DhtConfig { /// Configuration for peer validation /// See [PeerValidatorConfig] pub peer_validator_config: PeerValidatorConfig, - /// Addresses that should never be dialed - pub excluded_dial_addresses: Vec, + /// Addresses that should never be dialed (default value = []). This can be a specific address or an IPv4/TCP + /// range. Example: When used in conjunction with `allow_test_addresses = true` (but it could be any other + /// range) `excluded_dial_addresses = ["/ip4/127.*.0:49.*/tcp/*", "/ip4/127.*.101:255.*/tcp/*"]` + /// or + /// `excluded_dial_addresses = ["/ip4/127.0:0.1/tcp/122", "/ip4/127.0:0.1/tcp/1000:2000"]` + pub excluded_dial_addresses: MultiaddrRangeList, } impl DhtConfig { @@ -195,7 +198,7 @@ impl Default for DhtConfig { max_permitted_peer_claims: 5, offline_peer_cooldown: Duration::from_secs(24 * 60 * 60), peer_validator_config: Default::default(), - excluded_dial_addresses: vec![], + excluded_dial_addresses: vec![].into(), } } } diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index b0294e9184..47d85ce8b4 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -870,10 +870,12 @@ impl DhtConnectivity { let mut neighbours = Vec::with_capacity(self.neighbours.len()); for peer in &self.neighbours { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { neighbours.push(peer.clone()); } } @@ -882,10 +884,12 @@ impl DhtConnectivity { let mut random_pool = Vec::with_capacity(self.random_pool.len()); for peer in &self.random_pool { let addresses = self.peer_manager.get_peer_multi_addresses(peer).await?; - if !addresses - .iter() - .all(|addr| self.config.excluded_dial_addresses.contains(addr.address())) - { + if !addresses.iter().all(|addr| { + self.config + .excluded_dial_addresses + .iter() + .any(|v| v.contains(addr.address())) + }) { random_pool.push(peer.clone()); } } diff --git a/integration_tests/src/ffi/comms_config.rs b/integration_tests/src/ffi/comms_config.rs index 2d1eb6e8da..b8de76866b 100644 --- a/integration_tests/src/ffi/comms_config.rs +++ b/integration_tests/src/ffi/comms_config.rs @@ -49,6 +49,7 @@ impl CommsConfig { CString::new(base_dir).unwrap().into_raw(), 30, 600, + false, // This needs to be 'false' for the tests to pass &mut error, ); if error > 0 { diff --git a/integration_tests/src/ffi/ffi_import.rs b/integration_tests/src/ffi/ffi_import.rs index 70830118ed..539c73ff7b 100644 --- a/integration_tests/src/ffi/ffi_import.rs +++ b/integration_tests/src/ffi/ffi_import.rs @@ -367,6 +367,7 @@ extern "C" { datastore_path: *const c_char, discovery_timeout_in_secs: c_ulonglong, saf_message_duration_in_secs: c_ulonglong, + exclude_dial_test_addresses: bool, error_out: *mut c_int, ) -> *mut TariCommsConfig; pub fn comms_config_destroy(wc: *mut TariCommsConfig);