diff --git a/src/monitoring_service/database.py b/src/monitoring_service/database.py index 0d2b37d26..fd99bf02f 100644 --- a/src/monitoring_service/database.py +++ b/src/monitoring_service/database.py @@ -39,22 +39,21 @@ class SharedDatabase(BaseDatabase): schema_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.sql") def upsert_monitor_request(self, request: MonitorRequest) -> None: - values = [ - hex256(request.channel_identifier), - to_checksum_address(request.token_network_address), - request.balance_hash, - hex256(request.nonce), - request.additional_hash, - to_hex(request.closing_signature), - to_hex(request.non_closing_signature), - hex256(request.reward_amount), - to_hex(request.reward_proof_signature), - to_checksum_address(request.non_closing_signer), - ] - upsert_sql = "INSERT OR REPLACE INTO monitor_request VALUES ({})".format( - ", ".join("?" * len(values)) + self.upsert( + "monitor_request", + dict( + channel_identifier=hex256(request.channel_identifier), + token_network_address=to_checksum_address(request.token_network_address), + balance_hash=request.balance_hash, + nonce=hex256(request.nonce), + additional_hash=request.additional_hash, + closing_signature=to_hex(request.closing_signature), + non_closing_signature=to_hex(request.non_closing_signature), + reward_amount=hex256(request.reward_amount), + reward_proof_signature=to_hex(request.reward_proof_signature), + non_closing_signer=to_checksum_address(request.non_closing_signer), + ), ) - self.conn.execute(upsert_sql, values) def get_monitor_request( self, @@ -80,7 +79,11 @@ def get_monitor_request( if row is None: return None - kwargs = {key: val for key, val in zip(row.keys(), row) if key != "non_closing_signer"} + kwargs = { + key: val + for key, val in zip(row.keys(), row) + if key not in ("non_closing_signer", "saved_at", "waiting_for_channel") + } kwargs["token_network_address"] = to_canonical_address(kwargs["token_network_address"]) kwargs["closing_signature"] = decode_hex(kwargs["closing_signature"]) kwargs["non_closing_signature"] = decode_hex(kwargs["non_closing_signature"]) diff --git a/src/monitoring_service/schema.sql b/src/monitoring_service/schema.sql index 3e0f969fe..169b3343c 100644 --- a/src/monitoring_service/schema.sql +++ b/src/monitoring_service/schema.sql @@ -48,10 +48,20 @@ CREATE TABLE monitor_request ( reward_proof_signature CHAR(132) NOT NULL, non_closing_signer CHAR(42) NOT NULL, + + -- These two columns are just for handling MRs before we have confirmed + -- that a matching channel exists. + -- * If `waiting_for_channel` is false, we've already checked that such a + -- channel exists and everything is ok. + -- * If `saved_at` is sufficiently recent, a missing channel is acceptable. + -- * If `saved_at` is old, we will delete the MR if not matching channel is + -- found. + saved_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + waiting_for_channel BOOL NOT NULL DEFAULT TRUE, + PRIMARY KEY (channel_identifier, token_network_address, non_closing_signer) - --FOREIGN KEY (channel_identifier, token_network_address) - -- REFERENCES channels(channel_identifier, token_network_address) ON DELETE CASCADE ); +CREATE INDEX old_mr_idx ON monitor_request(saved_at) WHERE (waiting_for_channel); CREATE TABLE waiting_transactions ( transaction_hash CHAR(66) NOT NULL diff --git a/src/monitoring_service/service.py b/src/monitoring_service/service.py index 1e41549b2..21d5a0ac4 100644 --- a/src/monitoring_service/service.py +++ b/src/monitoring_service/service.py @@ -1,5 +1,6 @@ import sys import time +from datetime import datetime, timedelta from typing import Callable, Dict import structlog @@ -122,6 +123,7 @@ def start( last_block = min(last_confirmed_block, max_query_interval_end_block) self._process_new_blocks(last_block) + self._purge_old_monitor_requests() try: wait_function(self.poll_interval) @@ -186,3 +188,32 @@ def _process_new_blocks(self, last_block: BlockNumber) -> None: transaction_hash=tx_hash, receipt=receipt, ) + + def _purge_old_monitor_requests(self) -> None: + """ Delete all old MRs for which still no channel exists. + + Also marks all MRs which have a channel as not waiting_for_channel to + avoid checking them again, every time. + """ + with self.context.db.conn: + self.context.db.conn.execute( + """ + UPDATE monitor_request SET waiting_for_channel = false + WHERE waiting_for_channel + AND EXISTS ( + SELECT 1 + FROM channel + WHERE (channel.identifier, channel.token_network_address) + = (monitor_request.channel_identifier, monitor_request.token_network_address) + ) + """ + ) + before_this_is_old = datetime.utcnow() - timedelta(minutes=15) + self.context.db.conn.execute( + """ + DELETE FROM monitor_request + WHERE waiting_for_channel + AND saved_at < ? + """, + [before_this_is_old], + ) diff --git a/tests/monitoring/fixtures/__init__.py b/tests/monitoring/fixtures/__init__.py index 484b3658d..138ec55d7 100644 --- a/tests/monitoring/fixtures/__init__.py +++ b/tests/monitoring/fixtures/__init__.py @@ -1,2 +1,3 @@ from .contracts import * # noqa +from .factories import * # noqa from .server import * # noqa diff --git a/tests/monitoring/fixtures/factories.py b/tests/monitoring/fixtures/factories.py new file mode 100644 index 000000000..08593fd4e --- /dev/null +++ b/tests/monitoring/fixtures/factories.py @@ -0,0 +1,44 @@ +import pytest +from eth_utils import encode_hex, to_checksum_address + +from monitoring_service.states import HashedBalanceProof +from raiden.messages import RequestMonitoring +from raiden.utils.typing import Address, ChannelID, Nonce, TokenAmount, TokenNetworkAddress +from raiden_contracts.tests.utils.address import get_random_privkey +from raiden_contracts.utils.type_aliases import ChainID +from raiden_libs.utils import private_key_to_address + + +@pytest.fixture +def build_request_monitoring(): + non_closing_privkey = get_random_privkey() + non_closing_address = private_key_to_address(non_closing_privkey) + + def f( + chain_id: ChainID = ChainID(1), + amount: TokenAmount = TokenAmount(50), + nonce: Nonce = Nonce(1), + channel_id: ChannelID = ChannelID(1), + ) -> RequestMonitoring: + balance_proof = HashedBalanceProof( + channel_identifier=channel_id, + token_network_address=TokenNetworkAddress(b"1" * 20), + chain_id=chain_id, + nonce=nonce, + additional_hash="", + balance_hash=encode_hex(bytes([amount])), + priv_key=get_random_privkey(), + ) + request_monitoring = balance_proof.get_request_monitoring( + privkey=non_closing_privkey, + reward_amount=TokenAmount(55), + monitoring_service_contract_address=Address(bytes([11] * 20)), + ) + + # usually not a property of RequestMonitoring, but added for convenience in these tests + request_monitoring.non_closing_signer = to_checksum_address( # type: ignore + non_closing_address + ) + return request_monitoring + + return f diff --git a/tests/monitoring/monitoring_service/test_database.py b/tests/monitoring/monitoring_service/test_database.py index f5da38aed..eb94ad0e2 100644 --- a/tests/monitoring/monitoring_service/test_database.py +++ b/tests/monitoring/monitoring_service/test_database.py @@ -9,7 +9,7 @@ from monitoring_service.database import Database from monitoring_service.events import ActionMonitoringTriggeredEvent, ScheduledEvent -from monitoring_service.states import OnChainUpdateStatus +from monitoring_service.states import Channel, OnChainUpdateStatus from raiden.constants import UINT256_MAX from raiden.utils.typing import ( Address, @@ -18,6 +18,7 @@ TokenNetworkAddress, TransactionHash, ) +from raiden_libs.database import hex256 def test_scheduled_events(ms_database: Database): @@ -113,3 +114,52 @@ def test_save_and_load_channel(ms_database: Database): token_network_address=channel.token_network_address, channel_id=channel.identifier ) assert loaded_channel == channel + + +def test_purge_old_monitor_requests( + ms_database, build_request_monitoring, request_collector, monitoring_service +): + # We'll test the purge on MRs for three different channels + req_mons = [ + build_request_monitoring(channel_id=1), + build_request_monitoring(channel_id=2), + build_request_monitoring(channel_id=3), + ] + for req_mon in req_mons: + request_collector.on_monitor_request(req_mon) + + # Channel 1 exists in the db + token_network_address = req_mons[0].balance_proof.token_network_address + ms_database.conn.execute( + "INSERT INTO token_network VALUES (?)", [to_checksum_address(token_network_address)] + ) + ms_database.upsert_channel( + Channel( + identifier=ChannelID(1), + token_network_address=token_network_address, + participant1=Address(b"1" * 20), + participant2=Address(b"2" * 20), + settle_timeout=10, + ) + ) + + # The request for channel 2 is recent (default), but the one for channel 3 + # has been added 16 minutes ago. + ms_database.conn.execute( + """ + UPDATE monitor_request + SET saved_at = datetime('now', '-16 minutes') + WHERE channel_identifier = ? + """, + [hex256(3)], + ) + + monitoring_service._purge_old_monitor_requests() # pylint: disable=protected-access + remaining_mrs = ms_database.conn.execute( + """ + SELECT channel_identifier, waiting_for_channel + FROM monitor_request ORDER BY channel_identifier + """ + ).fetchall() + # sqlite returns booleans as 0/1 + assert [tuple(mr) for mr in remaining_mrs] == [(1, 0), (2, 1)] diff --git a/tests/monitoring/request_collector/test_server.py b/tests/monitoring/request_collector/test_server.py index cf7245359..4a409b354 100644 --- a/tests/monitoring/request_collector/test_server.py +++ b/tests/monitoring/request_collector/test_server.py @@ -1,54 +1,7 @@ # pylint: disable=redefined-outer-name -import pytest -from eth_utils import encode_hex, to_checksum_address +from eth_utils import to_checksum_address -from monitoring_service.states import HashedBalanceProof -from raiden.messages import RequestMonitoring from raiden.storage.serialization.serializer import DictSerializer -from raiden.utils.typing import ( - Address, - ChainID, - ChannelID, - Nonce, - TokenAmount, - TokenNetworkAddress, -) -from raiden_contracts.tests.utils import get_random_privkey -from raiden_libs.utils import private_key_to_address - - -@pytest.fixture -def build_request_monitoring(): - non_closing_privkey = get_random_privkey() - non_closing_address = private_key_to_address(non_closing_privkey) - - def f( - chain_id: ChainID = ChainID(1), - amount: TokenAmount = TokenAmount(50), - nonce: Nonce = Nonce(1), - ) -> RequestMonitoring: - balance_proof = HashedBalanceProof( - channel_identifier=ChannelID(1), - token_network_address=TokenNetworkAddress(b"1" * 20), - chain_id=chain_id, - nonce=nonce, - additional_hash="", - balance_hash=encode_hex(bytes([amount])), - priv_key=get_random_privkey(), - ) - request_monitoring = balance_proof.get_request_monitoring( - privkey=non_closing_privkey, - reward_amount=TokenAmount(55), - monitoring_service_contract_address=Address(bytes([11] * 20)), - ) - - # usually not a property of RequestMonitoring, but added for convenience in these tests - request_monitoring.non_closing_signer = to_checksum_address( # type: ignore - non_closing_address - ) - return request_monitoring - - return f def test_invalid_request(ms_database, build_request_monitoring, request_collector):