diff --git a/test/functional/feature_loan_payback_with_collateral.py b/test/functional/feature_loan_payback_with_collateral.py index 076cb5805f..16352a702a 100755 --- a/test/functional/feature_loan_payback_with_collateral.py +++ b/test/functional/feature_loan_payback_with_collateral.py @@ -36,17 +36,6 @@ def set_test_params(self): '-simulatemainnet=1' ]] - def rollback_to(self, block): - self.log.info("rollback to: %d", block) - node = self.nodes[0] - current_height = node.getblockcount() - if current_height == block: - return - blockhash = node.getblockhash(block + 1) - node.invalidateblock(blockhash) - node.clearmempool() - assert_equal(block, node.getblockcount()) - def createOracles(self): self.oracle_address1 = self.nodes[0].getnewaddress("", "legacy") price_feeds = [{"currency": "USD", "token": "DFI"}, diff --git a/test/functional/feature_negative_interest.py b/test/functional/feature_negative_interest.py index 4da02842df..401d2714b7 100755 --- a/test/functional/feature_negative_interest.py +++ b/test/functional/feature_negative_interest.py @@ -19,16 +19,6 @@ def getDecimalAmount(amount): class NegativeInterestTest (DefiTestFramework): - def rollback_to(self, block): - self.log.info("rollback to: %d", block) - current_height = self.nodes[0].getblockcount() - if current_height == block: - return - blockhash = self.nodes[0].getblockhash(block + 1) - self.nodes[0].invalidateblock(blockhash) - self.nodes[0].clearmempool() - assert_equal(block, self.nodes[0].getblockcount()) - def set_test_params(self): self.num_nodes = 1 self.setup_clean_chain = True diff --git a/test/functional/feature_on_chain_government_fee_distribution.py b/test/functional/feature_on_chain_government_fee_distribution.py index 2e6f262831..8cad8a41ea 100755 --- a/test/functional/feature_on_chain_government_fee_distribution.py +++ b/test/functional/feature_on_chain_government_fee_distribution.py @@ -7,9 +7,7 @@ from test_framework.test_framework import DefiTestFramework from test_framework.util import ( - assert_equal, - connect_nodes_bi, - disconnect_nodes, + assert_equal ) from decimal import ROUND_DOWN, Decimal @@ -120,19 +118,7 @@ def test_cfp_fee_distribution(self, amount, expectedFee, burnPct, vote, cycles=2 history = self.nodes[0].listaccounthistory(mn3['ownerAuthAddress'], {"txtype": "ProposalFeeRedistribution"}) assert_equal(history, []) - # Disconnect nodes and check connection count - for i in range(self.num_nodes - 1): - disconnect_nodes(self.nodes[i], i + 1) - assert_equal(self.nodes[i].getconnectioncount(), 0) - assert_equal(self.nodes[3].getconnectioncount(), 0) - - # Rollback nodes in isolation - for i in range(self.num_nodes): - self.rollback_to(height, nodes=[i]) - - # Connect nodes - for i in range(self.num_nodes - 1): - connect_nodes_bi(self.nodes, i, i + 1) + self.rollback_to(height) def setup(self): # Get MN addresses diff --git a/test/functional/feature_on_chain_government_govvar_update.py b/test/functional/feature_on_chain_government_govvar_update.py index e4c3ac9803..82bef5eb02 100755 --- a/test/functional/feature_on_chain_government_govvar_update.py +++ b/test/functional/feature_on_chain_government_govvar_update.py @@ -93,7 +93,7 @@ def test_cfp_update_automatic_payout(self): account = self.nodes[0].getaccount(address) assert_equal(account, ['100.00000000@DFI']) - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_quorum(self): height = self.nodes[0].getblockcount() @@ -157,7 +157,7 @@ def test_cfp_update_quorum(self): proposal = self.nodes[0].getgovproposal(propId) assert_equal(proposal['status'], 'Rejected') - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_approval_threshold(self): height = self.nodes[0].getblockcount() @@ -230,7 +230,7 @@ def test_cfp_update_approval_threshold(self): proposal = self.nodes[0].getgovproposal(propId) assert_equal(proposal['status'], 'Rejected') - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_fee_redistribution(self): height = self.nodes[0].getblockcount() @@ -297,7 +297,7 @@ def test_cfp_update_fee_redistribution(self): account1 = self.nodes[0].getaccount(mn1['ownerAuthAddress']) assert_equal(account1[0], expectedAmount) - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_cfp_fee(self): height = self.nodes[0].getblockcount() @@ -358,7 +358,7 @@ def test_cfp_update_cfp_fee(self): account1 = self.nodes[0].getaccount(mn1['ownerAuthAddress']) assert_equal(account1[0], expectedAmount) - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_voting_period(self): height = self.nodes[0].getblockcount() @@ -408,7 +408,7 @@ def test_cfp_update_voting_period(self): proposal = self.nodes[0].getgovproposal(propId) assert_equal(proposal['status'], 'Completed') - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_voc_emergency_period(self): height = self.nodes[0].getblockcount() @@ -433,7 +433,7 @@ def test_cfp_update_voc_emergency_period(self): proposal = self.nodes[0].getgovproposal(propId) assert_equal(proposal['status'], 'Rejected') - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_update_voc_emergency_fee(self): height = self.nodes[0].getblockcount() @@ -478,7 +478,7 @@ def test_cfp_update_voc_emergency_fee(self): account1 = self.nodes[0].getaccount(mn1['ownerAuthAddress']) assert_equal(account1[0], expectedAmount) - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def test_cfp_state_after_update(self): height = self.nodes[0].getblockcount() @@ -527,7 +527,7 @@ def test_cfp_state_after_update(self): proposal = self.nodes[0].getgovproposal(propId) assert_equal(proposal['status'], 'Completed') - self.rollback_to(height, nodes=[0, 1, 2, 3]) + self.rollback_to(height) def setup(self): # Get MN addresses diff --git a/test/functional/feature_poolswap.py b/test/functional/feature_poolswap.py index 19ef7080da..97e24d4d54 100755 --- a/test/functional/feature_poolswap.py +++ b/test/functional/feature_poolswap.py @@ -563,7 +563,7 @@ def test_testpoolswap_errors(self): "amountFrom": 0, "tokenFrom": self.symbolLTC, "tokenTo": self.symbolBTC, "from": self.accountGN0, "to": self.accountSN1, "maxPrice": 0.1}) def revert_to_initial_state(self): - self.rollback_to(block=0, nodes=[0, 1, 2]) + self.rollback_to(block=0) assert_equal(len(self.nodes[0].listpoolpairs()), 0) assert_equal(len(self.nodes[1].listpoolpairs()), 0) assert_equal(len(self.nodes[2].listpoolpairs()), 0) diff --git a/test/functional/feature_restore_utxo.py b/test/functional/feature_restore_utxo.py index a5f862ca2e..1e37a9a007 100755 --- a/test/functional/feature_restore_utxo.py +++ b/test/functional/feature_restore_utxo.py @@ -72,7 +72,7 @@ def run_test(self): self.nodes[1].generate(1) self.nodes[1].accounttoaccount(node1_source, {node1_source: "1@BTC"}) self.nodes[1].generate(1) - self.rollback_to(block, nodes=[1]) + self.rollback_to(block) assert_equal(len(self.nodes[1].listunspent()), node1_utxos) if __name__ == '__main__': diff --git a/test/functional/rpc_getstoredinterest.py b/test/functional/rpc_getstoredinterest.py index 8a91fedc49..f2e15cb958 100755 --- a/test/functional/rpc_getstoredinterest.py +++ b/test/functional/rpc_getstoredinterest.py @@ -43,16 +43,6 @@ def set_test_params(self): '-jellyfish_regtest=1', '-txindex=1', '-simulatemainnet=1'] ] - # Utils - def rollback_to(self, block): - node = self.nodes[0] - current_height = node.getblockcount() - if current_height == block: - return - blockhash = node.getblockhash(block + 1) - node.invalidateblock(blockhash) - node.clearmempool() - assert_equal(block, node.getblockcount()) def new_vault(self, loan_scheme, deposit=10): vaultId = self.nodes[0].createvault(self.account0, loan_scheme) diff --git a/test/functional/test_framework/test_framework.py b/test/functional/test_framework/test_framework.py index 8fc0321411..2598d6a11a 100755 --- a/test/functional/test_framework/test_framework.py +++ b/test/functional/test_framework/test_framework.py @@ -15,6 +15,7 @@ import sys import tempfile import time +import re from typing import List from .authproxy import JSONRPCException @@ -26,6 +27,7 @@ PortSeed, assert_equal, check_json_precision, + connect_nodes, connect_nodes_bi, disconnect_nodes, get_datadir_path, @@ -406,8 +408,7 @@ def import_deterministic_coinbase_privkeys(self): n.importprivkey(privkey=n.get_genesis_keys().operatorPrivKey, label='coinbase', rescan=True) # rollback one node (Default = node 0) - def _rollback_to(self, block, node=0): - node = self.nodes[node] + def _rollback_to(self, block, node): current_height = node.getblockcount() if current_height == block: return @@ -418,11 +419,26 @@ def _rollback_to(self, block, node=0): # rollback to block # nodes param is a list of node numbers to roll back ([0, 1, 2, 3...] (Default -> None -> node 0) def rollback_to(self, block, nodes=None): - if nodes is None: - self._rollback_to(block) - else: - for node in nodes: - self._rollback_to(block, node=node) + nodes = nodes or self.nodes + connections = {} + for node in nodes: + nodes_connections = [] + for x in node.getpeerinfo(): + if not x['inbound']: + node_number = re.findall(r'\d+', x['subver'])[-1] + nodes_connections.append(int(node_number)) + connections[node] = nodes_connections + + for node in nodes: + for x in connections[node]: + disconnect_nodes(node, x) + + for node in nodes: + self._rollback_to(block, node) + + for node in nodes: + for x in connections[node]: + connect_nodes(node, x) def run_test(self): """Tests must override this method to define test logic"""