Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: edge-case fixes for wallet peer switching in console wallet #3226

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions applications/tari_console_wallet/src/ui/state/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex};
use tari_shutdown::ShutdownSignal;
use tari_wallet::{
base_node_service::{handle::BaseNodeEventReceiver, service::BaseNodeState},
connectivity_service::WalletConnectivityHandle,
contacts_service::storage::database::Contact,
output_manager_service::{handle::OutputManagerEventReceiver, service::Balance, TxId, TxoValidationType},
transaction_service::{
Expand Down Expand Up @@ -652,6 +653,10 @@ impl AppStateInner {
self.wallet.comms.connectivity().get_event_subscription().fuse()
}

pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle {
self.wallet.wallet_connectivity.clone()
}

pub fn get_base_node_event_stream(&self) -> Fuse<BaseNodeEventReceiver> {
self.wallet.base_node_service.clone().get_event_stream_fused()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ impl WalletEventMonitor {
.get_output_manager_service_event_stream();

let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream();
let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity();
let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse();

let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream();

Expand Down Expand Up @@ -105,17 +107,18 @@ impl WalletEventMonitor {
Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"),
}
},
status = connectivity_status.select_next_some() => {
trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status);
self.trigger_peer_state_refresh().await;
},
result = connectivity_events.select_next_some() => {
match result {
Ok(msg) => {
trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg);
match &*msg {
ConnectivityEvent::PeerDisconnected(_) |
ConnectivityEvent::ManagedPeerDisconnected(_) |
ConnectivityEvent::PeerConnected(_) |
ConnectivityEvent::PeerBanned(_) |
ConnectivityEvent::PeerOffline(_) |
ConnectivityEvent::PeerConnectionWillClose(_, _) => {
ConnectivityEvent::PeerConnected(_) => {
self.trigger_peer_state_refresh().await;
},
// Only the above variants trigger state refresh
Expand Down
2 changes: 1 addition & 1 deletion base_layer/wallet/src/base_node_service/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> {
}

async fn update_connectivity_status(&self) -> NodeId {
let mut watcher = self.wallet_connectivity.get_connectivity_status_watcher();
let mut watcher = self.wallet_connectivity.get_connectivity_status_watch();
loop {
use OnlineStatus::*;
match watcher.recv().await.unwrap_or(Offline) {
Expand Down
19 changes: 8 additions & 11 deletions base_layer/wallet/src/connectivity_service/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use super::service::OnlineStatus;
use crate::connectivity_service::error::WalletConnectivityError;
use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch};
use futures::{
channel::{mpsc, oneshot},
SinkExt,
Expand All @@ -36,33 +36,30 @@ use tokio::sync::watch;
pub enum WalletConnectivityRequest {
ObtainBaseNodeWalletRpcClient(oneshot::Sender<RpcClientLease<BaseNodeWalletRpcClient>>),
ObtainBaseNodeSyncRpcClient(oneshot::Sender<RpcClientLease<BaseNodeSyncRpcClient>>),
SetBaseNode(Box<Peer>),
}

#[derive(Clone)]
pub struct WalletConnectivityHandle {
sender: mpsc::Sender<WalletConnectivityRequest>,
base_node_watch_rx: watch::Receiver<Option<Peer>>,
base_node_watch: Watch<Option<Peer>>,
online_status_rx: watch::Receiver<OnlineStatus>,
}

impl WalletConnectivityHandle {
pub(super) fn new(
sender: mpsc::Sender<WalletConnectivityRequest>,
base_node_watch_rx: watch::Receiver<Option<Peer>>,
base_node_watch: Watch<Option<Peer>>,
online_status_rx: watch::Receiver<OnlineStatus>,
) -> Self {
Self {
sender,
base_node_watch_rx,
base_node_watch,
online_status_rx,
}
}

pub async fn set_base_node(&mut self, base_node_peer: Peer) -> Result<(), WalletConnectivityError> {
self.sender
.send(WalletConnectivityRequest::SetBaseNode(Box::new(base_node_peer)))
.await?;
self.base_node_watch.broadcast(Some(base_node_peer));
Ok(())
}

Expand Down Expand Up @@ -109,15 +106,15 @@ impl WalletConnectivityHandle {
self.online_status_rx.recv().await.unwrap_or(OnlineStatus::Offline)
}

pub fn get_connectivity_status_watcher(&self) -> watch::Receiver<OnlineStatus> {
pub fn get_connectivity_status_watch(&self) -> watch::Receiver<OnlineStatus> {
self.online_status_rx.clone()
}

pub fn get_current_base_node_peer(&self) -> Option<Peer> {
self.base_node_watch_rx.borrow().clone()
self.base_node_watch.borrow().clone()
}

pub fn get_current_base_node_id(&self) -> Option<NodeId> {
self.base_node_watch_rx.borrow().as_ref().map(|p| p.node_id.clone())
self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ServiceInitializer for WalletConnectivityInitializer {
let online_status_watch = Watch::new(OnlineStatus::Offline);
context.register_handle(WalletConnectivityHandle::new(
sender,
base_node_watch.get_receiver(),
base_node_watch.clone(),
online_status_watch.get_receiver(),
));

Expand Down
47 changes: 36 additions & 11 deletions base_layer/wallet/src/connectivity_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::{
use core::mem;
use futures::{
channel::{mpsc, oneshot},
future,
future::Either,
stream::Fuse,
StreamExt,
};
Expand All @@ -35,6 +37,7 @@ use tari_comms::{
connectivity::ConnectivityRequester,
peer_manager::{NodeId, Peer},
protocol::rpc::{RpcClientLease, RpcClientPool},
PeerConnection,
};
use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient};
use tokio::time;
Expand Down Expand Up @@ -110,10 +113,6 @@ impl WalletConnectivityService {
ObtainBaseNodeSyncRpcClient(reply) => {
self.handle_pool_request(reply.into()).await;
},

SetBaseNode(peer) => {
self.set_base_node_peer(*peer);
},
}
}

Expand Down Expand Up @@ -202,11 +201,15 @@ impl WalletConnectivityService {
self.base_node_watch.broadcast(Some(peer));
}

fn current_base_node(&self) -> Option<NodeId> {
self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone())
}

async fn setup_base_node_connection(&mut self) {
self.pools = None;
loop {
let node_id = match self.base_node_watch.borrow().as_ref() {
Some(p) => p.node_id.clone(),
let node_id = match self.current_base_node() {
Some(n) => n,
None => return,
};
debug!(
Expand All @@ -215,16 +218,24 @@ impl WalletConnectivityService {
);
self.set_online_status(OnlineStatus::Connecting);
match self.try_setup_rpc_pool(node_id.clone()).await {
Ok(_) => {
Ok(true) => {
self.set_online_status(OnlineStatus::Online);
debug!(
target: LOG_TARGET,
"Wallet is ONLINE and connected to base node {}", node_id
);
break;
},
Ok(false) => {
// Retry with updated peer
continue;
},
Err(e) => {
self.set_online_status(OnlineStatus::Offline);
if self.current_base_node() != Some(node_id) {
self.set_online_status(OnlineStatus::Connecting);
} else {
self.set_online_status(OnlineStatus::Offline);
}
error!(target: LOG_TARGET, "{}", e);
time::delay_for(self.config.base_node_monitor_refresh_interval).await;
continue;
Expand All @@ -237,9 +248,12 @@ impl WalletConnectivityService {
let _ = self.online_status_watch.broadcast(status);
}

async fn try_setup_rpc_pool(&mut self, peer: NodeId) -> Result<(), WalletConnectivityError> {
async fn try_setup_rpc_pool(&mut self, peer: NodeId) -> Result<bool, WalletConnectivityError> {
self.connectivity.add_managed_peers(vec![peer.clone()]).await?;
let conn = self.connectivity.dial_peer(peer).await?;
let conn = match self.try_dial_peer(peer).await? {
Some(peer) => peer,
None => return Ok(false),
};
debug!(
target: LOG_TARGET,
"Successfully established peer connection to base node {}",
Expand All @@ -257,7 +271,18 @@ impl WalletConnectivityService {
"Successfully established RPC connection {}",
conn.peer_node_id()
);
Ok(())
Ok(true)
}

async fn try_dial_peer(&mut self, peer: NodeId) -> Result<Option<PeerConnection>, WalletConnectivityError> {
let recv_fut = self.base_node_watch.recv();
futures::pin_mut!(recv_fut);
let dial_fut = self.connectivity.dial_peer(peer);
futures::pin_mut!(dial_fut);
match future::select(recv_fut, dial_fut).await {
Either::Left(_) => Ok(None),
Either::Right((conn, _)) => Ok(Some(conn?)),
}
}

async fn notify_pending_requests(&mut self) -> Result<(), WalletConnectivityError> {
Expand Down
6 changes: 3 additions & 3 deletions base_layer/wallet/src/connectivity_service/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn setup() -> (
let (tx, rx) = mpsc::channel(1);
let base_node_watch = Watch::new(None);
let online_status_watch = Watch::new(OnlineStatus::Offline);
let handle = WalletConnectivityHandle::new(tx, base_node_watch.get_receiver(), online_status_watch.get_receiver());
let handle = WalletConnectivityHandle::new(tx, base_node_watch.clone(), online_status_watch.get_receiver());
let (connectivity, mock) = create_connectivity_mock();
let mock_state = mock.spawn();
// let peer_manager = create_peer_manager(tempdir().unwrap());
Expand Down Expand Up @@ -138,7 +138,7 @@ async fn it_changes_to_a_new_base_node() {

mock_state.await_call_count(2).await;
mock_state.expect_dial_peer(base_node_peer1.node_id()).await;
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 1);
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2);
let _ = mock_state.take_calls().await;

let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap();
Expand All @@ -149,7 +149,7 @@ async fn it_changes_to_a_new_base_node() {

mock_state.await_call_count(2).await;
mock_state.expect_dial_peer(base_node_peer2.node_id()).await;
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 1);
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2);

let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap();
assert!(rpc_client.is_connected());
Expand Down
2 changes: 1 addition & 1 deletion base_layer/wallet/src/connectivity_service/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<T: Clone> Watch<T> {
self.receiver_mut().recv().await
}

pub fn borrow(&mut self) -> watch::Ref<'_, T> {
pub fn borrow(&self) -> watch::Ref<'_, T> {
self.receiver().borrow()
}

Expand Down
18 changes: 8 additions & 10 deletions common/src/configuration/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,18 +518,16 @@ fn convert_node_config(
let console_wallet_notify_file = optional(cfg.get_str(key))?.map(PathBuf::from);

let key = "wallet.base_node_service_refresh_interval";
let wallet_base_node_service_refresh_interval = match cfg.get_int(key) {
Ok(seconds) => seconds as u64,
Err(ConfigError::NotFound(_)) => 10,
Err(e) => return Err(ConfigurationError::new(&key, &e.to_string())),
};
let wallet_base_node_service_refresh_interval = cfg
.get_int(key)
.map(|seconds| seconds as u64)
.map_err(|e| ConfigurationError::new(&key, &e.to_string()))?;

let key = "wallet.base_node_service_request_max_age";
let wallet_base_node_service_request_max_age = match cfg.get_int(key) {
Ok(seconds) => seconds as u64,
Err(ConfigError::NotFound(_)) => 60,
Err(e) => return Err(ConfigurationError::new(&key, &e.to_string())),
};
let wallet_base_node_service_request_max_age = cfg
.get_int(key)
.map(|seconds| seconds as u64)
.map_err(|e| ConfigurationError::new(&key, &e.to_string()))?;

let key = "common.liveness_max_sessions";
let liveness_max_sessions = cfg
Expand Down
3 changes: 2 additions & 1 deletion common/src/configuration/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config {
)
.unwrap();
cfg.set_default("wallet.base_node_query_timeout", 60).unwrap();
// 60 sec * 60 minutes * 12 hours.
cfg.set_default("wallet.base_node_service_refresh_interval", 5).unwrap();
cfg.set_default("wallet.base_node_service_request_max_age", 60).unwrap();
cfg.set_default("wallet.scan_for_utxo_interval", 60 * 60 * 12).unwrap();
cfg.set_default("wallet.transaction_broadcast_monitoring_timeout", 60)
.unwrap();
Expand Down
17 changes: 11 additions & 6 deletions comms/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,14 @@ where
return Ok(());
}

debug!(
target: LOG_TARGET,
"[Peer=`{}`] Got request {}", self.node_id, decoded_msg
);

let msg_flags = RpcMessageFlags::from_bits_truncate(decoded_msg.flags as u8);
if msg_flags.contains(RpcMessageFlags::ACK) {
debug!(target: LOG_TARGET, "[Peer=`{}`] ACK.", self.node_id);
debug!(
target: LOG_TARGET,
"[Peer=`{}` {}] sending ACK response.",
self.node_id,
self.protocol_name()
);
let ack = proto::rpc::RpcResponse {
request_id,
status: RpcStatus::ok().as_code(),
Expand All @@ -482,6 +482,11 @@ where
return Ok(());
}

debug!(
target: LOG_TARGET,
"[Peer=`{}`] Got request {}", self.node_id, decoded_msg
);

let req = Request::with_context(
self.create_request_context(request_id),
method,
Expand Down