Skip to content

Commit

Permalink
fix: edge-case fixes for wallet peer switching in console wallet
Browse files Browse the repository at this point in the history
- set peer using the watch to allow the connectivity service to
  immediately be aware of the new peer
- aborted the dial early if necessary, should the user set a different peer
- slightly reduce busy-ness of the wallet monitor by monitoring for less
  comms connectivity events
- monitor for wallet connectivity peer status changes to improve the
  responsiveness of the status ui update.
  • Loading branch information
sdbondi committed Aug 23, 2021
1 parent db4c0b9 commit 56b65b8
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 51 deletions.
5 changes: 5 additions & 0 deletions applications/tari_console_wallet/src/ui/state/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use tari_crypto::{ristretto::RistrettoPublicKey, tari_utilities::hex::Hex};
use tari_shutdown::ShutdownSignal;
use tari_wallet::{
base_node_service::{handle::BaseNodeEventReceiver, service::BaseNodeState},
connectivity_service::WalletConnectivityHandle,
contacts_service::storage::database::Contact,
output_manager_service::{handle::OutputManagerEventReceiver, service::Balance, TxId, TxoValidationType},
transaction_service::{
Expand Down Expand Up @@ -652,6 +653,10 @@ impl AppStateInner {
self.wallet.comms.connectivity().get_event_subscription().fuse()
}

pub fn get_wallet_connectivity(&self) -> WalletConnectivityHandle {
self.wallet.wallet_connectivity.clone()
}

pub fn get_base_node_event_stream(&self) -> Fuse<BaseNodeEventReceiver> {
self.wallet.base_node_service.clone().get_event_stream_fused()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ impl WalletEventMonitor {
.get_output_manager_service_event_stream();

let mut connectivity_events = self.app_state_inner.read().await.get_connectivity_event_stream();
let wallet_connectivity = self.app_state_inner.read().await.get_wallet_connectivity();
let mut connectivity_status = wallet_connectivity.get_connectivity_status_watch().fuse();

let mut base_node_events = self.app_state_inner.read().await.get_base_node_event_stream();

Expand Down Expand Up @@ -105,17 +107,18 @@ impl WalletEventMonitor {
Err(_) => debug!(target: LOG_TARGET, "Lagging read on Transaction Service event broadcast channel"),
}
},
status = connectivity_status.select_next_some() => {
trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity status {:?}", status);
self.trigger_peer_state_refresh().await;
},
result = connectivity_events.select_next_some() => {
match result {
Ok(msg) => {
trace!(target: LOG_TARGET, "Wallet Event Monitor received wallet connectivity event {:?}", msg);
match &*msg {
ConnectivityEvent::PeerDisconnected(_) |
ConnectivityEvent::ManagedPeerDisconnected(_) |
ConnectivityEvent::PeerConnected(_) |
ConnectivityEvent::PeerBanned(_) |
ConnectivityEvent::PeerOffline(_) |
ConnectivityEvent::PeerConnectionWillClose(_, _) => {
ConnectivityEvent::PeerDisconnected(node_id) |
ConnectivityEvent::ManagedPeerDisconnected(node_id) ||
ConnectivityEvent::PeerConnected(conn) => {
self.trigger_peer_state_refresh().await;
},
// Only the above variants trigger state refresh
Expand Down
2 changes: 1 addition & 1 deletion base_layer/wallet/src/base_node_service/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl<T: WalletBackend + 'static> BaseNodeMonitor<T> {
}

async fn update_connectivity_status(&self) -> NodeId {
let mut watcher = self.wallet_connectivity.get_connectivity_status_watcher();
let mut watcher = self.wallet_connectivity.get_connectivity_status_watch();
loop {
use OnlineStatus::*;
match watcher.recv().await.unwrap_or(Offline) {
Expand Down
19 changes: 8 additions & 11 deletions base_layer/wallet/src/connectivity_service/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use super::service::OnlineStatus;
use crate::connectivity_service::error::WalletConnectivityError;
use crate::connectivity_service::{error::WalletConnectivityError, watch::Watch};
use futures::{
channel::{mpsc, oneshot},
SinkExt,
Expand All @@ -36,33 +36,30 @@ use tokio::sync::watch;
pub enum WalletConnectivityRequest {
ObtainBaseNodeWalletRpcClient(oneshot::Sender<RpcClientLease<BaseNodeWalletRpcClient>>),
ObtainBaseNodeSyncRpcClient(oneshot::Sender<RpcClientLease<BaseNodeSyncRpcClient>>),
SetBaseNode(Box<Peer>),
}

#[derive(Clone)]
pub struct WalletConnectivityHandle {
sender: mpsc::Sender<WalletConnectivityRequest>,
base_node_watch_rx: watch::Receiver<Option<Peer>>,
base_node_watch: Watch<Option<Peer>>,
online_status_rx: watch::Receiver<OnlineStatus>,
}

impl WalletConnectivityHandle {
pub(super) fn new(
sender: mpsc::Sender<WalletConnectivityRequest>,
base_node_watch_rx: watch::Receiver<Option<Peer>>,
base_node_watch: Watch<Option<Peer>>,
online_status_rx: watch::Receiver<OnlineStatus>,
) -> Self {
Self {
sender,
base_node_watch_rx,
base_node_watch,
online_status_rx,
}
}

pub async fn set_base_node(&mut self, base_node_peer: Peer) -> Result<(), WalletConnectivityError> {
self.sender
.send(WalletConnectivityRequest::SetBaseNode(Box::new(base_node_peer)))
.await?;
self.base_node_watch.broadcast(Some(base_node_peer));
Ok(())
}

Expand Down Expand Up @@ -109,15 +106,15 @@ impl WalletConnectivityHandle {
self.online_status_rx.recv().await.unwrap_or(OnlineStatus::Offline)
}

pub fn get_connectivity_status_watcher(&self) -> watch::Receiver<OnlineStatus> {
pub fn get_connectivity_status_watch(&self) -> watch::Receiver<OnlineStatus> {
self.online_status_rx.clone()
}

pub fn get_current_base_node_peer(&self) -> Option<Peer> {
self.base_node_watch_rx.borrow().clone()
self.base_node_watch.borrow().clone()
}

pub fn get_current_base_node_id(&self) -> Option<NodeId> {
self.base_node_watch_rx.borrow().as_ref().map(|p| p.node_id.clone())
self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone())
}
}
2 changes: 1 addition & 1 deletion base_layer/wallet/src/connectivity_service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ServiceInitializer for WalletConnectivityInitializer {
let online_status_watch = Watch::new(OnlineStatus::Offline);
context.register_handle(WalletConnectivityHandle::new(
sender,
base_node_watch.get_receiver(),
base_node_watch.clone(),
online_status_watch.get_receiver(),
));

Expand Down
47 changes: 36 additions & 11 deletions base_layer/wallet/src/connectivity_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::{
use core::mem;
use futures::{
channel::{mpsc, oneshot},
future,
future::Either,
stream::Fuse,
StreamExt,
};
Expand All @@ -35,6 +37,7 @@ use tari_comms::{
connectivity::ConnectivityRequester,
peer_manager::{NodeId, Peer},
protocol::rpc::{RpcClientLease, RpcClientPool},
PeerConnection,
};
use tari_core::base_node::{rpc::BaseNodeWalletRpcClient, sync::rpc::BaseNodeSyncRpcClient};
use tokio::time;
Expand Down Expand Up @@ -110,10 +113,6 @@ impl WalletConnectivityService {
ObtainBaseNodeSyncRpcClient(reply) => {
self.handle_pool_request(reply.into()).await;
},

SetBaseNode(peer) => {
self.set_base_node_peer(*peer);
},
}
}

Expand Down Expand Up @@ -202,11 +201,15 @@ impl WalletConnectivityService {
self.base_node_watch.broadcast(Some(peer));
}

fn current_base_node(&self) -> Option<NodeId> {
self.base_node_watch.borrow().as_ref().map(|p| p.node_id.clone())
}

async fn setup_base_node_connection(&mut self) {
self.pools = None;
loop {
let node_id = match self.base_node_watch.borrow().as_ref() {
Some(p) => p.node_id.clone(),
let node_id = match self.current_base_node() {
Some(n) => n,
None => return,
};
debug!(
Expand All @@ -215,16 +218,24 @@ impl WalletConnectivityService {
);
self.set_online_status(OnlineStatus::Connecting);
match self.try_setup_rpc_pool(node_id.clone()).await {
Ok(_) => {
Ok(true) => {
self.set_online_status(OnlineStatus::Online);
debug!(
target: LOG_TARGET,
"Wallet is ONLINE and connected to base node {}", node_id
);
break;
},
Ok(false) => {
// Retry with updated peer
continue;
},
Err(e) => {
self.set_online_status(OnlineStatus::Offline);
if self.current_base_node() != Some(node_id) {
self.set_online_status(OnlineStatus::Connecting);
} else {
self.set_online_status(OnlineStatus::Offline);
}
error!(target: LOG_TARGET, "{}", e);
time::delay_for(self.config.base_node_monitor_refresh_interval).await;
continue;
Expand All @@ -237,9 +248,12 @@ impl WalletConnectivityService {
let _ = self.online_status_watch.broadcast(status);
}

async fn try_setup_rpc_pool(&mut self, peer: NodeId) -> Result<(), WalletConnectivityError> {
async fn try_setup_rpc_pool(&mut self, peer: NodeId) -> Result<bool, WalletConnectivityError> {
self.connectivity.add_managed_peers(vec![peer.clone()]).await?;
let conn = self.connectivity.dial_peer(peer).await?;
let conn = match self.try_dial_peer(peer).await? {
Some(peer) => peer,
None => return Ok(false),
};
debug!(
target: LOG_TARGET,
"Successfully established peer connection to base node {}",
Expand All @@ -257,7 +271,18 @@ impl WalletConnectivityService {
"Successfully established RPC connection {}",
conn.peer_node_id()
);
Ok(())
Ok(true)
}

async fn try_dial_peer(&mut self, peer: NodeId) -> Result<Option<PeerConnection>, WalletConnectivityError> {
let recv_fut = self.base_node_watch.recv();
futures::pin_mut!(recv_fut);
let dial_fut = self.connectivity.dial_peer(peer);
futures::pin_mut!(dial_fut);
match future::select(recv_fut, dial_fut).await {
Either::Left(_) => Ok(None),
Either::Right((conn, _)) => Ok(Some(conn?)),
}
}

async fn notify_pending_requests(&mut self) -> Result<(), WalletConnectivityError> {
Expand Down
6 changes: 3 additions & 3 deletions base_layer/wallet/src/connectivity_service/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn setup() -> (
let (tx, rx) = mpsc::channel(1);
let base_node_watch = Watch::new(None);
let online_status_watch = Watch::new(OnlineStatus::Offline);
let handle = WalletConnectivityHandle::new(tx, base_node_watch.get_receiver(), online_status_watch.get_receiver());
let handle = WalletConnectivityHandle::new(tx, base_node_watch.clone(), online_status_watch.get_receiver());
let (connectivity, mock) = create_connectivity_mock();
let mock_state = mock.spawn();
// let peer_manager = create_peer_manager(tempdir().unwrap());
Expand Down Expand Up @@ -138,7 +138,7 @@ async fn it_changes_to_a_new_base_node() {

mock_state.await_call_count(2).await;
mock_state.expect_dial_peer(base_node_peer1.node_id()).await;
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 1);
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2);
let _ = mock_state.take_calls().await;

let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap();
Expand All @@ -149,7 +149,7 @@ async fn it_changes_to_a_new_base_node() {

mock_state.await_call_count(2).await;
mock_state.expect_dial_peer(base_node_peer2.node_id()).await;
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 1);
assert_eq!(mock_state.count_calls_containing("AddManagedPeer").await, 2);

let rpc_client = handle.obtain_base_node_wallet_rpc_client().await.unwrap();
assert!(rpc_client.is_connected());
Expand Down
2 changes: 1 addition & 1 deletion base_layer/wallet/src/connectivity_service/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<T: Clone> Watch<T> {
self.receiver_mut().recv().await
}

pub fn borrow(&mut self) -> watch::Ref<'_, T> {
pub fn borrow(&self) -> watch::Ref<'_, T> {
self.receiver().borrow()
}

Expand Down
18 changes: 8 additions & 10 deletions common/src/configuration/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,18 +518,16 @@ fn convert_node_config(
let console_wallet_notify_file = optional(cfg.get_str(key))?.map(PathBuf::from);

let key = "wallet.base_node_service_refresh_interval";
let wallet_base_node_service_refresh_interval = match cfg.get_int(key) {
Ok(seconds) => seconds as u64,
Err(ConfigError::NotFound(_)) => 10,
Err(e) => return Err(ConfigurationError::new(&key, &e.to_string())),
};
let wallet_base_node_service_refresh_interval = cfg
.get_int(key)
.map(|seconds| seconds as u64)
.map_err(|e| ConfigurationError::new(&key, &e.to_string()))?;

let key = "wallet.base_node_service_request_max_age";
let wallet_base_node_service_request_max_age = match cfg.get_int(key) {
Ok(seconds) => seconds as u64,
Err(ConfigError::NotFound(_)) => 60,
Err(e) => return Err(ConfigurationError::new(&key, &e.to_string())),
};
let wallet_base_node_service_request_max_age = cfg
.get_int(key)
.map(|seconds| seconds as u64)
.map_err(|e| ConfigurationError::new(&key, &e.to_string()))?;

let key = "common.liveness_max_sessions";
let liveness_max_sessions = cfg
Expand Down
3 changes: 2 additions & 1 deletion common/src/configuration/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ pub fn default_config(bootstrap: &ConfigBootstrap) -> Config {
)
.unwrap();
cfg.set_default("wallet.base_node_query_timeout", 60).unwrap();
// 60 sec * 60 minutes * 12 hours.
cfg.set_default("wallet.base_node_service_refresh_interval", 5).unwrap();
cfg.set_default("wallet.base_node_service_request_max_age", 60).unwrap();
cfg.set_default("wallet.scan_for_utxo_interval", 60 * 60 * 12).unwrap();
cfg.set_default("wallet.transaction_broadcast_monitoring_timeout", 60)
.unwrap();
Expand Down
17 changes: 11 additions & 6 deletions comms/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,14 @@ where
return Ok(());
}

debug!(
target: LOG_TARGET,
"[Peer=`{}`] Got request {}", self.node_id, decoded_msg
);

let msg_flags = RpcMessageFlags::from_bits_truncate(decoded_msg.flags as u8);
if msg_flags.contains(RpcMessageFlags::ACK) {
debug!(target: LOG_TARGET, "[Peer=`{}`] ACK.", self.node_id);
debug!(
target: LOG_TARGET,
"[Peer=`{}` {}] sending ACK response.",
self.node_id,
self.protocol_name()
);
let ack = proto::rpc::RpcResponse {
request_id,
status: RpcStatus::ok().as_code(),
Expand All @@ -482,6 +482,11 @@ where
return Ok(());
}

debug!(
target: LOG_TARGET,
"[Peer=`{}`] Got request {}", self.node_id, decoded_msg
);

let req = Request::with_context(
self.create_request_context(request_id),
method,
Expand Down

0 comments on commit 56b65b8

Please sign in to comment.